Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[zstd][#60] Add decompression size sanity Check #115

Merged
merged 5 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ var (
ErrEmptySlice = errors.New("Bytes slice is empty")
)

const (
zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it
)

// CompressBound returns the worst case size needed for a destination buffer,
// which can be used to preallocate a destination buffer or select a previously
// allocated buffer from a pool.
Expand All @@ -46,6 +50,30 @@ func cCompressBound(srcSize int) int {
return int(C.ZSTD_compressBound(C.size_t(srcSize)))
}

// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
func decompressSizeHint(src []byte) int {
// 1 MB or 10x input size
upperBound := 10 * len(src)
if upperBound < 1000*1000 {
upperBound = 1000 * 1000
}

hint := upperBound
if len(src) >= zstdFrameHeaderSizeMax {
hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if hint < 0 { // On error, just use upperBound
hint = upperBound
}
}

// Take the minimum of both
if hint > upperBound {
return upperBound
}
return hint
}

// Compress src into dst. If you have a buffer to use, you can pass it to
// prevent allocation. If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
Expand Down Expand Up @@ -97,41 +125,25 @@ func Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, ErrEmptySlice
}
decompress := func(dst, src []byte) ([]byte, error) {

cWritten := C.ZSTD_decompress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)))

written := int(cWritten)
// Check error
if err := getError(written); err != nil {
return nil, err
}
return dst[:written], nil
bound := decompressSizeHint(src)
if cap(dst) >= bound {
dst = dst[0:cap(dst)]
} else {
dst = make([]byte, bound)
}

if len(dst) == 0 {
// Attempt to use zStd to determine decompressed size (may result in error or 0)
size := int(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if err := getError(size); err != nil {
return nil, err
}

if size > 0 {
dst = make([]byte, size)
} else {
dst = make([]byte, len(src)*3) // starting guess
}
written := int(C.ZSTD_decompress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src))))
err := getError(written)
if err == nil {
return dst[:written], nil
}
for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
result, err := decompress(dst, src)
if !IsDstSizeTooSmallError(err) {
return result, err
}
dst = make([]byte, len(dst)*2) // Grow buffer by 2
if !IsDstSizeTooSmallError(err) {
return nil, err
}

// We failed getting a dst buffer of correct size, use stream API
Expand Down
9 changes: 2 additions & 7 deletions zstd_bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ var (
ErrEmptyDictionary = errors.New("Dictionary is empty")
// ErrBadDictionary is returned when cannot load the given dictionary
ErrBadDictionary = errors.New("Cannot load dictionary")
// ErrContentSize is returned when cannot determine the content size
ErrContentSize = errors.New("Cannot determine the content size")
)

// BulkProcessor implements Bulk processing dictionary API.
Expand Down Expand Up @@ -111,12 +109,9 @@ func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return nil, ErrEmptySlice
}
contentSize := uint64(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if contentSize == C.ZSTD_CONTENTSIZE_ERROR || contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN {
return nil, ErrContentSize
}

if cap(dst) >= int(contentSize) {
contentSize := decompressSizeHint(src)
if cap(dst) >= contentSize {
dst = dst[0:contentSize]
} else {
dst = make([]byte, contentSize)
Expand Down
48 changes: 16 additions & 32 deletions zstd_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,43 +96,27 @@ func (c *ctx) Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, ErrEmptySlice
}
decompress := func(dst, src []byte) ([]byte, error) {

cWritten := C.ZSTD_decompressDCtx(
c.dctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)))

written := int(cWritten)
// Check error
if err := getError(written); err != nil {
return nil, err
}
return dst[:written], nil
bound := decompressSizeHint(src)
if cap(dst) >= bound {
dst = dst[0:cap(dst)]
} else {
dst = make([]byte, bound)
}

if len(dst) == 0 {
// Attempt to use zStd to determine decompressed size (may result in error or 0)
size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))

if err := getError(size); err != nil {
return nil, err
}
written := int(C.ZSTD_decompressDCtx(
c.dctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src))))

if size > 0 {
dst = make([]byte, size)
} else {
dst = make([]byte, len(src)*3) // starting guess
}
err := getError(written)
if err == nil {
return dst[:written], nil
}
for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
result, err := decompress(dst, src)
if !IsDstSizeTooSmallError(err) {
return result, err
}
dst = make([]byte, len(dst)*2) // Grow buffer by 2
if !IsDstSizeTooSmallError(err) {
return nil, err
}

// We failed getting a dst buffer of correct size, use stream API
Expand Down
9 changes: 9 additions & 0 deletions zstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zstd

import (
"bytes"
b64 "encoding/base64"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -284,6 +285,14 @@ func TestLegacy(t *testing.T) {
}
}

func TestBadPayloadZipBomb(t *testing.T) {
payload, _ := b64.StdEncoding.DecodeString("KLUv/dcwMDAwMDAwMDAwMAAA")
_, err := Decompress(nil, payload)
if err.Error() != "Src size is incorrect" {
t.Fatal("zstd should detect that the size is incorrect")
}
}

func BenchmarkCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)
Expand Down