
Commit fd035e5

Merge pull request #117 from bsergean/patch-1
Add SetNbWorkers api to the writer code (see #108)

Authored Jun 6, 2022
2 parents d64f463 + c798238, commit fd035e5

3 files changed: +46 -6 lines

zstd.go

+1 -1
@@ -4,7 +4,7 @@ package zstd
 // support decoding of "legacy" zstd payloads from versions [0.4, 0.8], matching the
 // default configuration of the zstd command line tool:
 // https://github.com/facebook/zstd/blob/dev/programs/README.md
-#cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4
+#cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4 -DZSTD_MULTITHREAD=1
 
 #include "zstd.h"
 */

zstd_stream.go

+23
@@ -72,6 +72,7 @@ import (
 
 var errShortRead = errors.New("short read")
 var errReaderClosed = errors.New("Reader is closed")
+var ErrNoParallelSupport = errors.New("No parallel support")
 
 // Writer is an io.WriteCloser that zstd-compresses its input.
 type Writer struct {
@@ -302,6 +303,28 @@ func (w *Writer) Close() error {
 	return getError(int(C.ZSTD_freeCStream(w.ctx)))
 }
 
+// SetNbWorkers sets the number of workers used to run the compression in parallel using multiple threads.
+// If > 1, the Write() call becomes asynchronous: data is buffered until it has been processed.
+// If you call Write() too fast, you might incur a memory buffer up to as large as your input.
+// Consider calling Flush() periodically if you need to compress a very large file that would not fit all in memory.
+// By default only one worker is used.
+func (w *Writer) SetNbWorkers(n int) error {
+	if w.firstError != nil {
+		return w.firstError
+	}
+	if err := getError(int(C.ZSTD_CCtx_setParameter(w.ctx, C.ZSTD_c_nbWorkers, C.int(n)))); err != nil {
+		w.firstError = err
+		// First error case: a shared library is used, and that library was compiled without parallel support.
+		if err.Error() == "Unsupported parameter" {
+			return ErrNoParallelSupport
+		} else {
+			// This could happen if a very large number is passed in: zstd may refuse to create that many threads, or the OS may fail to do so.
+			return err
+		}
+	}
+	return nil
+}
+
 // cSize is the recommended size of reader.compressionBuffer. This func and
 // invocation allow for a one-time check for validity.
 var cSize = func() int {
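
With this change the stream writer exposes parallel compression directly. A minimal usage sketch, assuming the usual github.com/DataDog/zstd import path for this package; the output file name and the choice to fall back to single-threaded compression on ErrNoParallelSupport are this example's own, not something the API mandates:

package main

import (
	"log"
	"os"

	"github.com/DataDog/zstd"
)

func main() {
	out, err := os.Create("payload.zst")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	w := zstd.NewWriter(out)
	defer w.Close()

	// Request 4 worker threads; with more than one worker, Write() becomes
	// asynchronous and data is buffered until the workers have processed it.
	if err := w.SetNbWorkers(4); err != nil {
		if err == zstd.ErrNoParallelSupport {
			// The linked libzstd was built without ZSTD_MULTITHREAD;
			// continue with the default single worker.
			log.Println("parallel compression unavailable, compressing single-threaded")
		} else {
			log.Fatal(err)
		}
	}

	if _, err := w.Write([]byte("some large payload ...")); err != nil {
		log.Fatal(err)
	}
}

As the new doc comment notes, with several workers the writes buffer internally, so a very large input may warrant periodic flushing rather than one giant Write().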

zstd_stream_test.go

+22 -5
@@ -9,6 +9,7 @@ import (
 	"log"
 	"os"
 	"runtime/debug"
+	"strings"
 	"testing"
 )
 
@@ -19,9 +20,16 @@ func failOnError(t *testing.T, msg string, err error) {
 	}
 }
 
-func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
+func testCompressionDecompression(t *testing.T, dict []byte, payload []byte, nbWorkers int) {
 	var w bytes.Buffer
 	writer := NewWriterLevelDict(&w, DefaultCompression, dict)
+
+	if nbWorkers > 1 {
+		if err := writer.SetNbWorkers(nbWorkers); err == ErrNoParallelSupport {
+			t.Skip()
+		}
+	}
+
 	_, err := writer.Write(payload)
 	failOnError(t, "Failed writing to compress object", err)
 	failOnError(t, "Failed to close compress object", writer.Close())
@@ -79,19 +87,19 @@ func TestResize(t *testing.T) {
 }
 
 func TestStreamSimpleCompressionDecompression(t *testing.T) {
-	testCompressionDecompression(t, nil, []byte("Hello world!"))
+	testCompressionDecompression(t, nil, []byte("Hello world!"), 1)
 }
 
 func TestStreamEmptySlice(t *testing.T) {
-	testCompressionDecompression(t, nil, []byte{})
+	testCompressionDecompression(t, nil, []byte{}, 1)
 }
 
 func TestZstdReaderLong(t *testing.T) {
 	var long bytes.Buffer
 	for i := 0; i < 10000; i++ {
 		long.Write([]byte("Hellow World!"))
 	}
-	testCompressionDecompression(t, nil, long.Bytes())
+	testCompressionDecompression(t, nil, long.Bytes(), 1)
 }
 
 func doStreamCompressionDecompression() error {
@@ -186,7 +194,7 @@ func TestStreamRealPayload(t *testing.T) {
 	if raw == nil {
 		t.Skip(ErrNoPayloadEnv)
 	}
-	testCompressionDecompression(t, nil, raw)
+	testCompressionDecompression(t, nil, raw, 1)
 }
 
 func TestStreamEmptyPayload(t *testing.T) {
@@ -398,12 +406,21 @@ func TestStreamWriteNoGoPointers(t *testing.T) {
 	})
 }
 
+func TestStreamSetNbWorkers(t *testing.T) {
+	// Build a big string first
+	s := strings.Repeat("foobaa", 1000*1000)
+
+	nbWorkers := 4
+	testCompressionDecompression(t, nil, []byte(s), nbWorkers)
+}
+
 func BenchmarkStreamCompression(b *testing.B) {
 	if raw == nil {
 		b.Fatal(ErrNoPayloadEnv)
 	}
 	var intermediate bytes.Buffer
 	w := NewWriter(&intermediate)
+	// w.SetNbWorkers(8)
 	defer w.Close()
 	b.SetBytes(int64(len(raw)))
 	b.ResetTimer()
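
The commented-out w.SetNbWorkers(8) line hints at how a parallel benchmark could be wired up. A sketch of a separate benchmark alongside the existing ones (raw, ErrNoPayloadEnv, and NewWriter come from the existing test file; the benchmark name, the io.Discard sink and its "io" import, and the skip-on-ErrNoParallelSupport handling are this example's assumptions):

func BenchmarkStreamCompressionParallel(b *testing.B) {
	if raw == nil {
		b.Fatal(ErrNoPayloadEnv)
	}
	w := NewWriter(io.Discard)
	defer w.Close()
	// Request 8 workers; skip the benchmark when the linked libzstd
	// was built without ZSTD_MULTITHREAD support.
	if err := w.SetNbWorkers(8); err == ErrNoParallelSupport {
		b.Skip(err)
	} else if err != nil {
		b.Fatal(err)
	}
	b.SetBytes(int64(len(raw)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := w.Write(raw); err != nil {
			b.Fatal(err)
		}
	}
}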
