Skip to content

Commit 396ef5b

Browse files
Nick Terrellterrelln
Nick Terrell
authored andcommitted
Fix & refactor Huffman repeat tables for dictionaries
The Huffman repeat mode checker assumed that the CTable was zeroed in the region `[maxSymbolValue + 1, 256)`. This assumption didn't hold for tables built in the dictionaries, because it didn't go through the same codepath. Since this code was originally written, we added a header to the CTable that specifies the `tableLog`. Add `maxSymbolValue` to that header, and check that the table's `maxSymbolValue` is at least the block's `maxSymbolValue`. This solution is cleaner because we write this header for every CTable we build, so it can't be missed in any code path. Credit to OSS-Fuzz
1 parent c27fa39 commit 396ef5b

File tree

2 files changed

+58
-18
lines changed

2 files changed

+58
-18
lines changed

lib/common/huf.h

+14-1
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
197197

198198
/** HUF_getNbBitsFromCTable() :
199199
* Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX
200-
* Note 1 : is not inlined, as HUF_CElt definition is private */
200+
* Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0
201+
* Note 2 : is not inlined, as HUF_CElt definition is private
202+
*/
201203
U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue);
202204

205+
typedef struct {
206+
BYTE tableLog;
207+
BYTE maxSymbolValue;
208+
BYTE unused[sizeof(size_t) - 2];
209+
} HUF_CTableHeader;
210+
211+
/** HUF_readCTableHeader() :
212+
* @returns The header from the CTable specifying the tableLog and the maxSymbolValue.
213+
*/
214+
HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable);
215+
203216
/*
204217
* HUF_decompress() does the following:
205218
* 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics

lib/compress/huf_compress.c

+44-17
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,25 @@ static void HUF_setValue(HUF_CElt* elt, size_t value)
220220
}
221221
}
222222

223+
HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable)
224+
{
225+
HUF_CTableHeader header;
226+
ZSTD_memcpy(&header, ctable, sizeof(header));
227+
return header;
228+
}
229+
230+
static void HUF_writeCTableHeader(HUF_CElt* ctable, U32 tableLog, U32 maxSymbolValue)
231+
{
232+
HUF_CTableHeader header;
233+
HUF_STATIC_ASSERT(sizeof(ctable[0]) == sizeof(header));
234+
ZSTD_memset(&header, 0, sizeof(header));
235+
assert(tableLog < 256);
236+
header.tableLog = (BYTE)tableLog;
237+
assert(maxSymbolValue < 256);
238+
header.maxSymbolValue = (BYTE)maxSymbolValue;
239+
ZSTD_memcpy(ctable, &header, sizeof(header));
240+
}
241+
223242
typedef struct {
224243
HUF_CompressWeightsWksp wksp;
225244
BYTE bitsToWeight[HUF_TABLELOG_MAX + 1]; /* precomputed conversion table */
@@ -237,6 +256,9 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize,
237256

238257
HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE >= sizeof(HUF_WriteCTableWksp));
239258

259+
assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue);
260+
assert(HUF_readCTableHeader(CTable).tableLog == huffLog);
261+
240262
/* check conditions */
241263
if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC);
242264
if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge);
@@ -283,7 +305,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
283305
if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge);
284306
if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall);
285307

286-
CTable[0] = tableLog;
308+
*maxSymbolValuePtr = nbSymbols - 1;
309+
310+
HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr);
287311

288312
/* Prepare base value per rank */
289313
{ U32 n, nextRankStart = 0;
@@ -315,14 +339,15 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
315339
{ U32 n; for (n=0; n<nbSymbols; n++) HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++); }
316340
}
317341

318-
*maxSymbolValuePtr = nbSymbols - 1;
319342
return readSize;
320343
}
321344

322345
U32 HUF_getNbBitsFromCTable(HUF_CElt const* CTable, U32 symbolValue)
323346
{
324347
const HUF_CElt* const ct = CTable + 1;
325348
assert(symbolValue <= HUF_SYMBOLVALUE_MAX);
349+
if (symbolValue > HUF_readCTableHeader(CTable).maxSymbolValue)
350+
return 0;
326351
return (U32)HUF_getNbBits(ct[symbolValue]);
327352
}
328353

@@ -723,7 +748,8 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i
723748
HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits); /* push nbBits per symbol, symbol order */
724749
for (n=0; n<alphabetSize; n++)
725750
HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++); /* assign value within rank, symbol order */
726-
CTable[0] = maxNbBits;
751+
752+
HUF_writeCTableHeader(CTable, maxNbBits, maxSymbolValue);
727753
}
728754

729755
size_t
@@ -776,13 +802,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count,
776802
}
777803

778804
int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) {
779-
HUF_CElt const* ct = CTable + 1;
780-
int bad = 0;
781-
int s;
782-
for (s = 0; s <= (int)maxSymbolValue; ++s) {
783-
bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
784-
}
785-
return !bad;
805+
HUF_CTableHeader header = HUF_readCTableHeader(CTable);
806+
HUF_CElt const* ct = CTable + 1;
807+
int bad = 0;
808+
int s;
809+
810+
assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX);
811+
812+
if (header.maxSymbolValue < maxSymbolValue)
813+
return 0;
814+
815+
for (s = 0; s <= (int)maxSymbolValue; ++s) {
816+
bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
817+
}
818+
return !bad;
786819
}
787820

788821
size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); }
@@ -1024,7 +1057,7 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize,
10241057
const void* src, size_t srcSize,
10251058
const HUF_CElt* CTable)
10261059
{
1027-
U32 const tableLog = (U32)CTable[0];
1060+
U32 const tableLog = HUF_readCTableHeader(CTable).tableLog;
10281061
HUF_CElt const* ct = CTable + 1;
10291062
const BYTE* ip = (const BYTE*) src;
10301063
BYTE* const ostart = (BYTE*)dst;
@@ -1372,12 +1405,6 @@ HUF_compress_internal (void* dst, size_t dstSize,
13721405
huffLog = (U32)maxBits;
13731406
DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1));
13741407
}
1375-
/* Zero unused symbols in CTable, so we can check it for validity */
1376-
{
1377-
size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue);
1378-
size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt);
1379-
ZSTD_memset(table->CTable + ctableSize, 0, unusedSize);
1380-
}
13811408

13821409
/* Write table description header */
13831410
{ CHECK_V_F(hSize, HUF_writeCTable_wksp(op, dstSize, table->CTable, maxSymbolValue, huffLog,

0 commit comments

Comments
 (0)