Skip to content

Commit 4b8fdba

Browse files
authored
Merge pull request #86 from delthas/1.x
[zstd_stream] Add Writer.Flush()
2 parents 558004b + ddae45c commit 4b8fdba

File tree

3 files changed

+138
-1
lines changed

3 files changed

+138
-1
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
6060
// Write compresses the input data and write it to the underlying writer
6161
(w *Writer) Write(p []byte) (int, error)
6262

63+
// Flush writes any unwritten data to the underlying writer
64+
(w *Writer) Flush() error
65+
6366
// Close flushes the buffer and frees C zstd objects
6467
(w *Writer) Close() error
6568
```

zstd_stream.go

+58-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CC
2121
result->bytes_written = outBuffer.pos;
2222
}
2323
24+
static void ZSTD_compressStream2_flush(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
25+
ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
26+
ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
27+
size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_flush);
28+
29+
result->return_code = retCode;
30+
result->bytes_consumed = inBuffer.pos;
31+
result->bytes_written = outBuffer.pos;
32+
}
33+
2434
static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
2535
ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
2636
ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
@@ -203,6 +213,49 @@ func (w *Writer) Write(p []byte) (int, error) {
203213
return len(p), err
204214
}
205215

216+
// Flush writes any unwritten data to the underlying io.Writer.
217+
func (w *Writer) Flush() error {
218+
if w.firstError != nil {
219+
return w.firstError
220+
}
221+
222+
ret := 1 // So we loop at least once
223+
for ret > 0 {
224+
srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
225+
if len(w.srcBuffer) > 0 {
226+
srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&w.srcBuffer[0])))
227+
}
228+
229+
C.ZSTD_compressStream2_flush(
230+
w.resultBuffer,
231+
w.ctx,
232+
C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
233+
C.size_t(len(w.dstBuffer)),
234+
srcPtr,
235+
C.size_t(len(w.srcBuffer)),
236+
)
237+
ret = int(w.resultBuffer.return_code)
238+
if err := getError(ret); err != nil {
239+
return err
240+
}
241+
w.srcBuffer = w.srcBuffer[w.resultBuffer.bytes_consumed:]
242+
written := int(w.resultBuffer.bytes_written)
243+
_, err := w.underlyingWriter.Write(w.dstBuffer[:written])
244+
if err != nil {
245+
return err
246+
}
247+
248+
if ret > 0 { // We have a hint if we need to resize the dstBuffer
249+
w.dstBuffer = w.dstBuffer[:cap(w.dstBuffer)]
250+
if len(w.dstBuffer) < ret {
251+
w.dstBuffer = make([]byte, ret)
252+
}
253+
}
254+
}
255+
256+
return nil
257+
}
258+
206259
// Close closes the Writer, flushing any unwritten data to the underlying
207260
// io.Writer and freeing objects, but does not close the underlying io.Writer.
208261
func (w *Writer) Close() error {
@@ -231,7 +284,11 @@ func (w *Writer) Close() error {
231284
}
232285
w.srcBuffer = w.srcBuffer[w.resultBuffer.bytes_consumed:]
233286
written := int(w.resultBuffer.bytes_written)
234-
w.underlyingWriter.Write(w.dstBuffer[:written])
287+
_, err := w.underlyingWriter.Write(w.dstBuffer[:written])
288+
if err != nil {
289+
C.ZSTD_freeCStream(w.ctx)
290+
return err
291+
}
235292

236293
if ret > 0 { // We have a hint if we need to resize the dstBuffer
237294
w.dstBuffer = w.dstBuffer[:cap(w.dstBuffer)]

zstd_stream_test.go

+77
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package zstd
22

33
import (
44
"bytes"
5+
"errors"
56
"io"
67
"io/ioutil"
8+
"log"
79
"runtime/debug"
810
"testing"
911
)
@@ -149,6 +151,81 @@ func TestStreamEmptyPayload(t *testing.T) {
149151
}
150152
}
151153

154+
func TestStreamFlush(t *testing.T) {
155+
var w bytes.Buffer
156+
writer := NewWriter(&w)
157+
reader := NewReader(&w)
158+
159+
payload := "cc" // keep the payload short to make sure it will not be automatically flushed by zstd
160+
buf := make([]byte, len(payload))
161+
162+
for i := 0; i < 5; i++ {
163+
_, err := writer.Write([]byte(payload))
164+
failOnError(t, "Failed writing to compress object", err)
165+
166+
err = writer.Flush()
167+
failOnError(t, "Failed flushing compress object", err)
168+
169+
_, err = io.ReadFull(reader, buf)
170+
failOnError(t, "Failed reading uncompress object", err)
171+
172+
if string(buf) != payload {
173+
debug.PrintStack()
174+
log.Fatal("Uncompressed object mismatch")
175+
}
176+
}
177+
178+
failOnError(t, "Failed to close compress object", writer.Close())
179+
failOnError(t, "Failed to close uncompress object", reader.Close())
180+
}
181+
182+
type closeableWriter struct{
183+
w io.Writer
184+
closed bool
185+
}
186+
187+
func (c *closeableWriter) Write(p []byte) (n int, err error) {
188+
if c.closed {
189+
return 0, errors.New("io: Write on a closed closeableWriter")
190+
}
191+
return c.w.Write(p)
192+
}
193+
194+
func (c *closeableWriter) Close() error {
195+
c.closed = true
196+
return nil
197+
}
198+
199+
func TestStreamFlushError(t *testing.T) {
200+
var bw bytes.Buffer
201+
w := closeableWriter{w: &bw}
202+
writer := NewWriter(&w)
203+
204+
_, err := writer.Write([]byte("cc"))
205+
failOnError(t, "Failed writing to compress object", err)
206+
207+
w.Close()
208+
if err = writer.Flush(); err == nil {
209+
debug.PrintStack()
210+
t.Fatal("Writer.Flush returned no error when writing to underlying io.Writer failed")
211+
}
212+
}
213+
214+
func TestStreamCloseError(t *testing.T) {
215+
var bw bytes.Buffer
216+
w := closeableWriter{w: &bw}
217+
writer := NewWriter(&w)
218+
219+
_, err := writer.Write([]byte("cc"))
220+
failOnError(t, "Failed writing to compress object", err)
221+
222+
w.Close()
223+
if err = writer.Close(); err == nil {
224+
debug.PrintStack()
225+
t.Fatal("Writer.Close returned no error when writing to underlying io.Writer failed")
226+
}
227+
}
228+
152229
type breakingReader struct{}
153230

154231
func (r *breakingReader) Read(p []byte) (int, error) {

0 commit comments

Comments
 (0)