Skip to content

Commit 30c4b29

Browse files
committed
[zstd] Add a sanity limit to decompress buffer size allocation
Fix #60 Before we were blindly trusting the data returned by ZSTD_getDecompressedSize. This mean with a specifically crafter payload, we would try to allocate a lot of memory resulting in potential DOS. Now set a sane limit and fall back to streaming
1 parent eaf4b06 commit 30c4b29

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

zstd.go

+33-12
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ var (
2828
ErrEmptySlice = errors.New("Bytes slice is empty")
2929
)
3030

31+
const (
32+
zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it
33+
)
34+
3135
// CompressBound returns the worst case size needed for a destination buffer,
3236
// which can be used to preallocate a destination buffer or select a previously
3337
// allocated buffer from a pool.
@@ -46,6 +50,30 @@ func cCompressBound(srcSize int) int {
4650
return int(C.ZSTD_compressBound(C.size_t(srcSize)))
4751
}
4852

53+
// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
54+
// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
55+
func decompressSizeHint(src []byte) int {
56+
// 1 MB or 10x input size
57+
upperBound := 10 * len(src)
58+
if upperBound < 1000*1000 {
59+
upperBound = 1000 * 1000
60+
}
61+
62+
hint := upperBound
63+
if len(src) >= zstdFrameHeaderSizeMax {
64+
hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
65+
if hint < 0 { // On error, just use upperBound
66+
hint = upperBound
67+
}
68+
}
69+
70+
// Take the minimum of both
71+
if hint > upperBound {
72+
return upperBound
73+
}
74+
return hint
75+
}
76+
4977
// Compress src into dst. If you have a buffer to use, you can pass it to
5078
// prevent allocation. If it is too small, or if nil is passed, a new buffer
5179
// will be allocated and returned.
@@ -113,18 +141,11 @@ func Decompress(dst, src []byte) ([]byte, error) {
113141
return dst[:written], nil
114142
}
115143

116-
if len(dst) == 0 {
117-
// Attempt to use zStd to determine decompressed size (may result in error or 0)
118-
size := int(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
119-
if err := getError(size); err != nil {
120-
return nil, err
121-
}
122-
123-
if size > 0 {
124-
dst = make([]byte, size)
125-
} else {
126-
dst = make([]byte, len(src)*3) // starting guess
127-
}
144+
bound := decompressSizeHint(src)
145+
if cap(dst) >= bound {
146+
dst = dst[0:cap(dst)]
147+
} else {
148+
dst = make([]byte, bound)
128149
}
129150
for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
130151
result, err := decompress(dst, src)

0 commit comments

Comments
 (0)