Commit 98d11f8

Merge pull request #79 from DataDog/viq111/stream2
[zstd_stream] Now use ZSTD_compressStream2 C API
2 parents 4720f73 + 154f2b8 commit 98d11f8

File tree

3 files changed (+217, -57 lines):

  helpers_test.go      (+40)
  zstd_stream.go       (+114, -33)
  zstd_stream_test.go  (+63, -24)
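For orientation before the file-by-file diffs: the change is internal to the streaming Writer, so caller code keeps the same shape. A minimal round-trip sketch against this package's public API (the payload and variable names are illustrative, not taken from the commit):

package main

import (
    "bytes"
    "log"

    "github.com/DataDog/zstd" // this repository
)

func main() {
    var buf bytes.Buffer

    // Streaming compression: Write now feeds ZSTD_compressStream2 with
    // ZSTD_e_continue, and Close finishes the frame with ZSTD_e_end.
    w := zstd.NewWriter(&buf)
    if _, err := w.Write([]byte("hello zstd stream")); err != nil {
        log.Fatal(err)
    }
    if err := w.Close(); err != nil {
        log.Fatal(err)
    }

    // One-shot decompression of the streamed frame.
    out, err := zstd.Decompress(nil, buf.Bytes())
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("decompressed: %s", out)
}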

helpers_test.go (+40, new file)

@@ -0,0 +1,40 @@
+package zstd
+
+import (
+    "io"
+    "math/rand"
+    "time"
+)
+
+// randbytes creates a stream of non-crypto quality random bytes
+type randbytes struct {
+    rand.Source
+}
+
+// NewRandBytes creates a new random reader with a time source.
+func NewRandBytes() io.Reader {
+    return NewRandBytesFrom(rand.NewSource(time.Now().UnixNano()))
+}
+
+// NewRandBytesFrom creates a new reader from your own rand.Source
+func NewRandBytesFrom(src rand.Source) io.Reader {
+    return &randbytes{src}
+}
+
+// Read satisfies io.Reader
+func (r *randbytes) Read(p []byte) (n int, err error) {
+    todo := len(p)
+    offset := 0
+    for {
+        val := int64(r.Int63())
+        for i := 0; i < 8; i++ {
+            p[offset] = byte(val)
+            todo--
+            if todo == 0 {
+                return len(p), nil
+            }
+            offset++
+            val >>= 8
+        }
+    }
+}
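A quick usage note, not part of the diff: NewRandBytesFrom exists so a test can pin the seed and get a reproducible byte stream, while NewRandBytes is the time-seeded convenience wrapper. A small sketch of how it could be used, assuming it sits next to helpers_test.go in package zstd (the example name and seed are illustrative):

package zstd

import (
    "fmt"
    "math/rand"
)

func ExampleNewRandBytesFrom() {
    r := NewRandBytesFrom(rand.NewSource(42)) // fixed seed, so the stream is reproducible
    buf := make([]byte, 8)
    n, _ := r.Read(buf) // Read always fills the whole slice and reports its full length
    fmt.Println(n)
    // Output: 8
}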

zstd_stream.go (+114, -33)

@@ -3,8 +3,35 @@ package zstd
 /*
 #define ZSTD_STATIC_LINKING_ONLY
 #define ZBUFF_DISABLE_DEPRECATE_WARNINGS
+#include "stdint.h"  // for uintptr_t
 #include "zstd.h"
 #include "zbuff.h"
+
+typedef struct compressStream2_result_s {
+    size_t return_code;
+    size_t bytes_consumed;
+    size_t bytes_written;
+} compressStream2_result;
+
+static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
+    ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
+    ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+    size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_continue);
+
+    result->return_code = retCode;
+    result->bytes_consumed = inBuffer.pos;
+    result->bytes_written = outBuffer.pos;
+}
+
+static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
+    ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
+    ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+    size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_end);
+
+    result->return_code = retCode;
+    result->bytes_consumed = inBuffer.pos;
+    result->bytes_written = outBuffer.pos;
+}
 */
 import "C"
 import (
@@ -24,9 +51,11 @@ type Writer struct {
 
     ctx              *C.ZSTD_CCtx
     dict             []byte
+    srcBuffer        []byte
     dstBuffer        []byte
     firstError       error
     underlyingWriter io.Writer
+    resultBuffer     *C.compressStream2_result
 }
 
 func resize(in []byte, newSize int) []byte {
@@ -61,26 +90,30 @@ func NewWriterLevel(w io.Writer, level int) *Writer {
 // should not be modified until the writer is closed.
 func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
     var err error
-    ctx := C.ZSTD_createCCtx()
+    ctx := C.ZSTD_createCStream()
 
-    if dict == nil {
-        err = getError(int(C.ZSTD_compressBegin(ctx,
-            C.int(level))))
-    } else {
-        err = getError(int(C.ZSTD_compressBegin_usingDict(
-            ctx,
+    // Load dictionary if any
+    if dict != nil {
+        err = getError(int(C.ZSTD_CCtx_loadDictionary(ctx,
             unsafe.Pointer(&dict[0]),
             C.size_t(len(dict)),
-            C.int(level))))
+        )))
+    }
+
+    if err == nil {
+        // Only set level if the ctx is not in error already
+        err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_compressionLevel, C.int(level))))
     }
 
     return &Writer{
         CompressionLevel: level,
         ctx:              ctx,
         dict:             dict,
+        srcBuffer:        make([]byte, 0),
         dstBuffer:        make([]byte, CompressBound(1024)),
         firstError:       err,
         underlyingWriter: w,
+        resultBuffer:     new(C.compressStream2_result),
     }
 }
 
@@ -93,22 +126,56 @@ func (w *Writer) Write(p []byte) (int, error) {
         return 0, nil
     }
     // Check if dstBuffer is enough
+    w.dstBuffer = w.dstBuffer[0:cap(w.dstBuffer)]
     if len(w.dstBuffer) < CompressBound(len(p)) {
         w.dstBuffer = make([]byte, CompressBound(len(p)))
     }
 
-    retCode := C.ZSTD_compressContinue(
+    // Do not do an extra memcopy if zstd ingests all input data
+    srcData := p
+    fastPath := len(w.srcBuffer) == 0
+    if !fastPath {
+        w.srcBuffer = append(w.srcBuffer, p...)
+        srcData = w.srcBuffer
+    }
+
+    srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
+    if len(srcData) > 0 {
+        srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&srcData[0])))
+    }
+
+    C.ZSTD_compressStream2_wrapper(
+        w.resultBuffer,
         w.ctx,
-        unsafe.Pointer(&w.dstBuffer[0]),
+        C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
         C.size_t(len(w.dstBuffer)),
-        unsafe.Pointer(&p[0]),
-        C.size_t(len(p)))
-
-    if err := getError(int(retCode)); err != nil {
+        srcPtr,
+        C.size_t(len(srcData)),
+    )
+    runtime.KeepAlive(p) // Ensure p is kept until here so pointer doesn't disappear during C call
+    ret := int(w.resultBuffer.return_code)
+    if err := getError(ret); err != nil {
         return 0, err
     }
-    written := int(retCode)
 
+    consumed := int(w.resultBuffer.bytes_consumed)
+    if !fastPath {
+        w.srcBuffer = w.srcBuffer[consumed:]
+    } else {
+        remaining := len(p) - consumed
+        if remaining > 0 {
+            // We still have some non-consumed data, copy remaining data to srcBuffer
+            // Try to not reallocate w.srcBuffer if we already have enough space
+            if cap(w.srcBuffer) >= remaining {
+                w.srcBuffer = w.srcBuffer[0:remaining]
+            } else {
+                w.srcBuffer = make([]byte, remaining)
+            }
+            copy(w.srcBuffer, p[consumed:])
+        }
+    }
+
+    written := int(w.resultBuffer.bytes_written)
     // Write to underlying buffer
     _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
 
@@ -123,28 +190,42 @@ func (w *Writer) Write(p []byte) (int, error) {
 // Close closes the Writer, flushing any unwritten data to the underlying
 // io.Writer and freeing objects, but does not close the underlying io.Writer.
 func (w *Writer) Close() error {
-    retCode := C.ZSTD_compressEnd(
-        w.ctx,
-        unsafe.Pointer(&w.dstBuffer[0]),
-        C.size_t(len(w.dstBuffer)),
-        unsafe.Pointer(nil),
-        C.size_t(0))
-
-    if err := getError(int(retCode)); err != nil {
-        return err
+    if w.firstError != nil {
+        return w.firstError
     }
-    written := int(retCode)
-    retCode = C.ZSTD_freeCCtx(w.ctx) // Safely close buffer before writing the end
 
-    if err := getError(int(retCode)); err != nil {
-        return err
-    }
+    ret := 1 // So we loop at least once
+    for ret > 0 {
+        srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
+        if len(w.srcBuffer) > 0 {
+            srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&w.srcBuffer[0])))
+        }
 
-    _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
-    if err != nil {
-        return err
+        C.ZSTD_compressStream2_finish(
+            w.resultBuffer,
+            w.ctx,
+            C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
+            C.size_t(len(w.dstBuffer)),
+            srcPtr,
+            C.size_t(len(w.srcBuffer)),
+        )
+        ret = int(w.resultBuffer.return_code)
+        if err := getError(ret); err != nil {
+            return err
+        }
+        w.srcBuffer = w.srcBuffer[w.resultBuffer.bytes_consumed:]
+        written := int(w.resultBuffer.bytes_written)
+        w.underlyingWriter.Write(w.dstBuffer[:written])
 
+        if ret > 0 { // We have a hint if we need to resize the dstBuffer
+            w.dstBuffer = w.dstBuffer[:cap(w.dstBuffer)]
+            if len(w.dstBuffer) < ret {
+                w.dstBuffer = make([]byte, ret)
+            }
+        }
     }
-    return nil
+
+    return getError(int(C.ZSTD_freeCStream(w.ctx)))
 }
 
 // cSize is the recommended size of reader.compressionBuffer. This func and
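Background on the loop above, for readers new to the ZSTD_compressStream2 contract: when called with ZSTD_e_end, a return value of 0 means the frame is fully flushed, while a positive value means output is still pending (and doubles as a size hint for the destination buffer), so Close keeps calling until it reaches 0. A rough standalone sketch of that loop shape, with a made-up fakeFinish standing in for the cgo wrapper (illustrative only, not the real zstd behavior):

package main

import "fmt"

// fakeFinish imitates the (return_code, bytes_consumed, bytes_written) triple
// exposed by the compressStream2_result struct above: it "flushes" at most
// len(dst) bytes per call and reports how much input is still pending.
func fakeFinish(src, dst []byte) (remaining, consumed, written int) {
    n := len(src)
    if n > len(dst) {
        n = len(dst)
    }
    copy(dst, src[:n])
    return len(src) - n, n, n
}

func main() {
    src := []byte("pending data that still has to be flushed")
    dst := make([]byte, 8) // deliberately small so the loop runs several times
    var out []byte

    ret := 1 // so we loop at least once, as in Writer.Close above
    for ret > 0 {
        remaining, consumed, written := fakeFinish(src, dst)
        src = src[consumed:]
        out = append(out, dst[:written]...)
        ret = remaining // 0 means everything has been flushed
    }
    fmt.Printf("flushed %d bytes in total\n", len(out))
}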

zstd_stream_test.go (+63, -24)

@@ -149,6 +149,69 @@ func TestStreamEmptyPayload(t *testing.T) {
     }
 }
 
+type breakingReader struct{}
+
+func (r *breakingReader) Read(p []byte) (int, error) {
+    return len(p) - 1, io.ErrUnexpectedEOF
+}
+
+func TestStreamDecompressionUnexpectedEOFHandling(t *testing.T) {
+    r := NewReader(&breakingReader{})
+    _, err := r.Read(make([]byte, 1024))
+    if err == nil {
+        t.Error("Underlying error was handled silently")
+    }
+}
+
+func TestStreamCompressionDecompressionParallel(t *testing.T) {
+    for i := 0; i < 200; i++ {
+        t.Run("", func(t2 *testing.T) {
+            t2.Parallel()
+            TestStreamCompressionDecompression(t2)
+        })
+    }
+}
+
+func TestStreamCompressionChunks(t *testing.T) {
+    MB := 1024 * 1024
+    totalSize := 100 * MB
+    chunk := 1 * MB
+
+    rawData := make([]byte, totalSize)
+    r := NewRandBytes()
+    r.Read(rawData)
+
+    compressed, _ := Compress(nil, rawData)
+    var streamCompressed bytes.Buffer
+    w := NewWriter(&streamCompressed)
+    for i := 0; i < totalSize; i += chunk {
+        end := i + chunk
+        if end >= len(rawData) {
+            end = len(rawData)
+        }
+        n, err := w.Write(rawData[i:end])
+        if err != nil {
+            t.Fatalf("Error while writing: %s", err)
+        }
+        if n != (end - i) {
+            t.Fatalf("Did not write the full amount of data: %v != %v", n, end-i)
+        }
+    }
+    err := w.Close()
+    if err != nil {
+        t.Fatalf("Failed to close writer: %s", err)
+    }
+    streamCompressedBytes := streamCompressed.Bytes()
+    t.Logf("Compressed with single call=%v bytes, stream compressed=%v bytes", len(compressed), len(streamCompressedBytes))
+    decompressed, err := Decompress(nil, streamCompressedBytes)
+    if err != nil {
+        t.Fatalf("Failed to decompress: %s", err)
+    }
+    if !bytes.Equal(rawData, decompressed) {
+        t.Fatalf("Compression/Decompression data is not equal to original data")
+    }
+}
+
 func BenchmarkStreamCompression(b *testing.B) {
     if raw == nil {
         b.Fatal(ErrNoPayloadEnv)
@@ -194,27 +257,3 @@ func BenchmarkStreamDecompression(b *testing.B) {
         r.Close()
     }
 }
-
-type breakingReader struct {
-}
-
-func (r *breakingReader) Read(p []byte) (int, error) {
-    return len(p) - 1, io.ErrUnexpectedEOF
-}
-
-func TestUnexpectedEOFHandling(t *testing.T) {
-    r := NewReader(&breakingReader{})
-    _, err := r.Read(make([]byte, 1024))
-    if err == nil {
-        t.Error("Underlying error was handled silently")
-    }
-}
-
-func TestStreamCompressionDecompressionParallel(t *testing.T) {
-    for i := 0; i < 200; i++ {
-        t.Run("", func(t2 *testing.T) {
-            t2.Parallel()
-            TestStreamCompressionDecompression(t2)
-        })
-    }
-}
