@@ -3,8 +3,35 @@ package zstd
3
3
/*
4
4
#define ZSTD_STATIC_LINKING_ONLY
5
5
#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
6
+ #include "stdint.h" // for uintptr_t
6
7
#include "zstd.h"
7
8
#include "zbuff.h"
9
+
10
+ typedef struct compressStream2_result_s {
11
+ size_t return_code;
12
+ size_t bytes_consumed;
13
+ size_t bytes_written;
14
+ } compressStream2_result;
15
+
16
+ static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
17
+ ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
18
+ ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
19
+ size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_continue);
20
+
21
+ result->return_code = retCode;
22
+ result->bytes_consumed = inBuffer.pos;
23
+ result->bytes_written = outBuffer.pos;
24
+ }
25
+
26
+ static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
27
+ ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
28
+ ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
29
+ size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_end);
30
+
31
+ result->return_code = retCode;
32
+ result->bytes_consumed = inBuffer.pos;
33
+ result->bytes_written = outBuffer.pos;
34
+ }
8
35
*/
9
36
import "C"
10
37
import (
@@ -24,9 +51,11 @@ type Writer struct {
24
51
25
52
ctx * C.ZSTD_CCtx
26
53
dict []byte
54
+ srcBuffer []byte
27
55
dstBuffer []byte
28
56
firstError error
29
57
underlyingWriter io.Writer
58
+ resultBuffer * C.compressStream2_result
30
59
}
31
60
32
61
func resize (in []byte , newSize int ) []byte {
@@ -61,26 +90,30 @@ func NewWriterLevel(w io.Writer, level int) *Writer {
61
90
// should not be modified until the writer is closed.
62
91
func NewWriterLevelDict (w io.Writer , level int , dict []byte ) * Writer {
63
92
var err error
64
- ctx := C .ZSTD_createCCtx ()
93
+ ctx := C .ZSTD_createCStream ()
65
94
66
- if dict == nil {
67
- err = getError (int (C .ZSTD_compressBegin (ctx ,
68
- C .int (level ))))
69
- } else {
70
- err = getError (int (C .ZSTD_compressBegin_usingDict (
71
- ctx ,
95
+ // Load dictionnary if any
96
+ if dict != nil {
97
+ err = getError (int (C .ZSTD_CCtx_loadDictionary (ctx ,
72
98
unsafe .Pointer (& dict [0 ]),
73
99
C .size_t (len (dict )),
74
- C .int (level ))))
100
+ )))
101
+ }
102
+
103
+ if err == nil {
104
+ // Only set level if the ctx is not in error already
105
+ err = getError (int (C .ZSTD_CCtx_setParameter (ctx , C .ZSTD_c_compressionLevel , C .int (level ))))
75
106
}
76
107
77
108
return & Writer {
78
109
CompressionLevel : level ,
79
110
ctx : ctx ,
80
111
dict : dict ,
112
+ srcBuffer : make ([]byte , 0 ),
81
113
dstBuffer : make ([]byte , CompressBound (1024 )),
82
114
firstError : err ,
83
115
underlyingWriter : w ,
116
+ resultBuffer : new (C.compressStream2_result ),
84
117
}
85
118
}
86
119
@@ -93,22 +126,56 @@ func (w *Writer) Write(p []byte) (int, error) {
93
126
return 0 , nil
94
127
}
95
128
// Check if dstBuffer is enough
129
+ w .dstBuffer = w .dstBuffer [0 :cap (w .dstBuffer )]
96
130
if len (w .dstBuffer ) < CompressBound (len (p )) {
97
131
w .dstBuffer = make ([]byte , CompressBound (len (p )))
98
132
}
99
133
100
- retCode := C .ZSTD_compressContinue (
134
+ // Do not do an extra memcopy if zstd ingest all input data
135
+ srcData := p
136
+ fastPath := len (w .srcBuffer ) == 0
137
+ if ! fastPath {
138
+ w .srcBuffer = append (w .srcBuffer , p ... )
139
+ srcData = w .srcBuffer
140
+ }
141
+
142
+ srcPtr := C .uintptr_t (uintptr (0 )) // Do not point anywhere, if src is empty
143
+ if len (srcData ) > 0 {
144
+ srcPtr = C .uintptr_t (uintptr (unsafe .Pointer (& srcData [0 ])))
145
+ }
146
+
147
+ C .ZSTD_compressStream2_wrapper (
148
+ w .resultBuffer ,
101
149
w .ctx ,
102
- unsafe .Pointer (& w .dstBuffer [0 ]),
150
+ C . uintptr_t ( uintptr ( unsafe .Pointer (& w .dstBuffer [0 ])) ),
103
151
C .size_t (len (w .dstBuffer )),
104
- unsafe .Pointer (& p [0 ]),
105
- C .size_t (len (p )))
106
-
107
- if err := getError (int (retCode )); err != nil {
152
+ srcPtr ,
153
+ C .size_t (len (srcData )),
154
+ )
155
+ runtime .KeepAlive (p ) // Ensure p is kept until here so pointer doesn't disappear during C call
156
+ ret := int (w .resultBuffer .return_code )
157
+ if err := getError (ret ); err != nil {
108
158
return 0 , err
109
159
}
110
- written := int (retCode )
111
160
161
+ consumed := int (w .resultBuffer .bytes_consumed )
162
+ if ! fastPath {
163
+ w .srcBuffer = w .srcBuffer [consumed :]
164
+ } else {
165
+ remaining := len (p ) - consumed
166
+ if remaining > 0 {
167
+ // We still have some non-consumed data, copy remaining data to srcBuffer
168
+ // Try to not reallocate w.srcBuffer if we already have enough space
169
+ if cap (w .srcBuffer ) >= remaining {
170
+ w .srcBuffer = w .srcBuffer [0 :remaining ]
171
+ } else {
172
+ w .srcBuffer = make ([]byte , remaining )
173
+ }
174
+ copy (w .srcBuffer , p [consumed :])
175
+ }
176
+ }
177
+
178
+ written := int (w .resultBuffer .bytes_written )
112
179
// Write to underlying buffer
113
180
_ , err := w .underlyingWriter .Write (w .dstBuffer [:written ])
114
181
@@ -123,28 +190,42 @@ func (w *Writer) Write(p []byte) (int, error) {
123
190
// Close closes the Writer, flushing any unwritten data to the underlying
124
191
// io.Writer and freeing objects, but does not close the underlying io.Writer.
125
192
func (w * Writer ) Close () error {
126
- retCode := C .ZSTD_compressEnd (
127
- w .ctx ,
128
- unsafe .Pointer (& w .dstBuffer [0 ]),
129
- C .size_t (len (w .dstBuffer )),
130
- unsafe .Pointer (nil ),
131
- C .size_t (0 ))
132
-
133
- if err := getError (int (retCode )); err != nil {
134
- return err
193
+ if w .firstError != nil {
194
+ return w .firstError
135
195
}
136
- written := int (retCode )
137
- retCode = C .ZSTD_freeCCtx (w .ctx ) // Safely close buffer before writing the end
138
196
139
- if err := getError (int (retCode )); err != nil {
140
- return err
141
- }
197
+ ret := 1 // So we loop at least once
198
+ for ret > 0 {
199
+ srcPtr := C .uintptr_t (uintptr (0 )) // Do not point anywhere, if src is empty
200
+ if len (w .srcBuffer ) > 0 {
201
+ srcPtr = C .uintptr_t (uintptr (unsafe .Pointer (& w .srcBuffer [0 ])))
202
+ }
142
203
143
- _ , err := w .underlyingWriter .Write (w .dstBuffer [:written ])
144
- if err != nil {
145
- return err
204
+ C .ZSTD_compressStream2_finish (
205
+ w .resultBuffer ,
206
+ w .ctx ,
207
+ C .uintptr_t (uintptr (unsafe .Pointer (& w .dstBuffer [0 ]))),
208
+ C .size_t (len (w .dstBuffer )),
209
+ srcPtr ,
210
+ C .size_t (len (w .srcBuffer )),
211
+ )
212
+ ret = int (w .resultBuffer .return_code )
213
+ if err := getError (ret ); err != nil {
214
+ return err
215
+ }
216
+ w .srcBuffer = w .srcBuffer [w .resultBuffer .bytes_consumed :]
217
+ written := int (w .resultBuffer .bytes_written )
218
+ w .underlyingWriter .Write (w .dstBuffer [:written ])
219
+
220
+ if ret > 0 { // We have a hint if we need to resize the dstBuffer
221
+ w .dstBuffer = w .dstBuffer [:cap (w .dstBuffer )]
222
+ if len (w .dstBuffer ) < ret {
223
+ w .dstBuffer = make ([]byte , ret )
224
+ }
225
+ }
146
226
}
147
- return nil
227
+
228
+ return getError (int (C .ZSTD_freeCStream (w .ctx )))
148
229
}
149
230
150
231
// cSize is the recommended size of reader.compressionBuffer. This func and
0 commit comments