xref: /netbsd-src/external/bsd/zstd/dist/lib/common/entropy_common.c (revision 3117ece4fc4a4ca4489ba793710b60b0d26bab6c)
1*3117ece4Schristos /* ******************************************************************
2*3117ece4Schristos  * Common functions of New Generation Entropy library
3*3117ece4Schristos  * Copyright (c) Meta Platforms, Inc. and affiliates.
4*3117ece4Schristos  *
5*3117ece4Schristos  *  You can contact the author at :
6*3117ece4Schristos  *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
7*3117ece4Schristos  *  - Public forum : https://groups.google.com/forum/#!forum/lz4c
8*3117ece4Schristos  *
9*3117ece4Schristos  * This source code is licensed under both the BSD-style license (found in the
10*3117ece4Schristos  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
11*3117ece4Schristos  * in the COPYING file in the root directory of this source tree).
12*3117ece4Schristos  * You may select, at your option, one of the above-listed licenses.
13*3117ece4Schristos ****************************************************************** */
14*3117ece4Schristos 
15*3117ece4Schristos /* *************************************
16*3117ece4Schristos *  Dependencies
17*3117ece4Schristos ***************************************/
18*3117ece4Schristos #include "mem.h"
19*3117ece4Schristos #include "error_private.h"       /* ERR_*, ERROR */
20*3117ece4Schristos #define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
21*3117ece4Schristos #include "fse.h"
22*3117ece4Schristos #include "huf.h"
23*3117ece4Schristos #include "bits.h"                /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */
24*3117ece4Schristos 
25*3117ece4Schristos 
26*3117ece4Schristos /*===   Version   ===*/
27*3117ece4Schristos unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
28*3117ece4Schristos 
29*3117ece4Schristos 
30*3117ece4Schristos /*===   Error Management   ===*/
/*! FSE_isError() :
 *  FSE-named entry point that forwards `code` to ERR_isError().
 *  @return : whatever ERR_isError(code) reports.
 */
unsigned FSE_isError(size_t code)
{
    unsigned const verdict = ERR_isError(code);
    return verdict;
}
/*! FSE_getErrorName() :
 *  FSE-named entry point that forwards `code` to ERR_getErrorName().
 *  @return : the string returned by ERR_getErrorName(code).
 */
const char* FSE_getErrorName(size_t code)
{
    const char* const errName = ERR_getErrorName(code);
    return errName;
}
33*3117ece4Schristos 
/*! HUF_isError() :
 *  HUF-named entry point that forwards `code` to ERR_isError().
 *  @return : whatever ERR_isError(code) reports.
 */
unsigned HUF_isError(size_t code)
{
    unsigned const verdict = ERR_isError(code);
    return verdict;
}
/*! HUF_getErrorName() :
 *  HUF-named entry point that forwards `code` to ERR_getErrorName().
 *  @return : the string returned by ERR_getErrorName(code).
 */
const char* HUF_getErrorName(size_t code)
{
    const char* const errName = ERR_getErrorName(code);
    return errName;
}
36*3117ece4Schristos 
37*3117ece4Schristos 
38*3117ece4Schristos /*-**************************************************************
39*3117ece4Schristos *  FSE NCount encoding-decoding
40*3117ece4Schristos ****************************************************************/
41*3117ece4Schristos FORCE_INLINE_TEMPLATE
42*3117ece4Schristos size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
43*3117ece4Schristos                            const void* headerBuffer, size_t hbSize)
44*3117ece4Schristos {
45*3117ece4Schristos     const BYTE* const istart = (const BYTE*) headerBuffer;
46*3117ece4Schristos     const BYTE* const iend = istart + hbSize;
47*3117ece4Schristos     const BYTE* ip = istart;
48*3117ece4Schristos     int nbBits;
49*3117ece4Schristos     int remaining;
50*3117ece4Schristos     int threshold;
51*3117ece4Schristos     U32 bitStream;
52*3117ece4Schristos     int bitCount;
53*3117ece4Schristos     unsigned charnum = 0;
54*3117ece4Schristos     unsigned const maxSV1 = *maxSVPtr + 1;
55*3117ece4Schristos     int previous0 = 0;
56*3117ece4Schristos 
57*3117ece4Schristos     if (hbSize < 8) {
58*3117ece4Schristos         /* This function only works when hbSize >= 8 */
59*3117ece4Schristos         char buffer[8] = {0};
60*3117ece4Schristos         ZSTD_memcpy(buffer, headerBuffer, hbSize);
61*3117ece4Schristos         {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
62*3117ece4Schristos                                                     buffer, sizeof(buffer));
63*3117ece4Schristos             if (FSE_isError(countSize)) return countSize;
64*3117ece4Schristos             if (countSize > hbSize) return ERROR(corruption_detected);
65*3117ece4Schristos             return countSize;
66*3117ece4Schristos     }   }
67*3117ece4Schristos     assert(hbSize >= 8);
68*3117ece4Schristos 
69*3117ece4Schristos     /* init */
70*3117ece4Schristos     ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
71*3117ece4Schristos     bitStream = MEM_readLE32(ip);
72*3117ece4Schristos     nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
73*3117ece4Schristos     if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
74*3117ece4Schristos     bitStream >>= 4;
75*3117ece4Schristos     bitCount = 4;
76*3117ece4Schristos     *tableLogPtr = nbBits;
77*3117ece4Schristos     remaining = (1<<nbBits)+1;
78*3117ece4Schristos     threshold = 1<<nbBits;
79*3117ece4Schristos     nbBits++;
80*3117ece4Schristos 
81*3117ece4Schristos     for (;;) {
82*3117ece4Schristos         if (previous0) {
83*3117ece4Schristos             /* Count the number of repeats. Each time the
84*3117ece4Schristos              * 2-bit repeat code is 0b11 there is another
85*3117ece4Schristos              * repeat.
86*3117ece4Schristos              * Avoid UB by setting the high bit to 1.
87*3117ece4Schristos              */
88*3117ece4Schristos             int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
89*3117ece4Schristos             while (repeats >= 12) {
90*3117ece4Schristos                 charnum += 3 * 12;
91*3117ece4Schristos                 if (LIKELY(ip <= iend-7)) {
92*3117ece4Schristos                     ip += 3;
93*3117ece4Schristos                 } else {
94*3117ece4Schristos                     bitCount -= (int)(8 * (iend - 7 - ip));
95*3117ece4Schristos                     bitCount &= 31;
96*3117ece4Schristos                     ip = iend - 4;
97*3117ece4Schristos                 }
98*3117ece4Schristos                 bitStream = MEM_readLE32(ip) >> bitCount;
99*3117ece4Schristos                 repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
100*3117ece4Schristos             }
101*3117ece4Schristos             charnum += 3 * repeats;
102*3117ece4Schristos             bitStream >>= 2 * repeats;
103*3117ece4Schristos             bitCount += 2 * repeats;
104*3117ece4Schristos 
105*3117ece4Schristos             /* Add the final repeat which isn't 0b11. */
106*3117ece4Schristos             assert((bitStream & 3) < 3);
107*3117ece4Schristos             charnum += bitStream & 3;
108*3117ece4Schristos             bitCount += 2;
109*3117ece4Schristos 
110*3117ece4Schristos             /* This is an error, but break and return an error
111*3117ece4Schristos              * at the end, because returning out of a loop makes
112*3117ece4Schristos              * it harder for the compiler to optimize.
113*3117ece4Schristos              */
114*3117ece4Schristos             if (charnum >= maxSV1) break;
115*3117ece4Schristos 
116*3117ece4Schristos             /* We don't need to set the normalized count to 0
117*3117ece4Schristos              * because we already memset the whole buffer to 0.
118*3117ece4Schristos              */
119*3117ece4Schristos 
120*3117ece4Schristos             if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
121*3117ece4Schristos                 assert((bitCount >> 3) <= 3); /* For first condition to work */
122*3117ece4Schristos                 ip += bitCount>>3;
123*3117ece4Schristos                 bitCount &= 7;
124*3117ece4Schristos             } else {
125*3117ece4Schristos                 bitCount -= (int)(8 * (iend - 4 - ip));
126*3117ece4Schristos                 bitCount &= 31;
127*3117ece4Schristos                 ip = iend - 4;
128*3117ece4Schristos             }
129*3117ece4Schristos             bitStream = MEM_readLE32(ip) >> bitCount;
130*3117ece4Schristos         }
131*3117ece4Schristos         {
132*3117ece4Schristos             int const max = (2*threshold-1) - remaining;
133*3117ece4Schristos             int count;
134*3117ece4Schristos 
135*3117ece4Schristos             if ((bitStream & (threshold-1)) < (U32)max) {
136*3117ece4Schristos                 count = bitStream & (threshold-1);
137*3117ece4Schristos                 bitCount += nbBits-1;
138*3117ece4Schristos             } else {
139*3117ece4Schristos                 count = bitStream & (2*threshold-1);
140*3117ece4Schristos                 if (count >= threshold) count -= max;
141*3117ece4Schristos                 bitCount += nbBits;
142*3117ece4Schristos             }
143*3117ece4Schristos 
144*3117ece4Schristos             count--;   /* extra accuracy */
145*3117ece4Schristos             /* When it matters (small blocks), this is a
146*3117ece4Schristos              * predictable branch, because we don't use -1.
147*3117ece4Schristos              */
148*3117ece4Schristos             if (count >= 0) {
149*3117ece4Schristos                 remaining -= count;
150*3117ece4Schristos             } else {
151*3117ece4Schristos                 assert(count == -1);
152*3117ece4Schristos                 remaining += count;
153*3117ece4Schristos             }
154*3117ece4Schristos             normalizedCounter[charnum++] = (short)count;
155*3117ece4Schristos             previous0 = !count;
156*3117ece4Schristos 
157*3117ece4Schristos             assert(threshold > 1);
158*3117ece4Schristos             if (remaining < threshold) {
159*3117ece4Schristos                 /* This branch can be folded into the
160*3117ece4Schristos                  * threshold update condition because we
161*3117ece4Schristos                  * know that threshold > 1.
162*3117ece4Schristos                  */
163*3117ece4Schristos                 if (remaining <= 1) break;
164*3117ece4Schristos                 nbBits = ZSTD_highbit32(remaining) + 1;
165*3117ece4Schristos                 threshold = 1 << (nbBits - 1);
166*3117ece4Schristos             }
167*3117ece4Schristos             if (charnum >= maxSV1) break;
168*3117ece4Schristos 
169*3117ece4Schristos             if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
170*3117ece4Schristos                 ip += bitCount>>3;
171*3117ece4Schristos                 bitCount &= 7;
172*3117ece4Schristos             } else {
173*3117ece4Schristos                 bitCount -= (int)(8 * (iend - 4 - ip));
174*3117ece4Schristos                 bitCount &= 31;
175*3117ece4Schristos                 ip = iend - 4;
176*3117ece4Schristos             }
177*3117ece4Schristos             bitStream = MEM_readLE32(ip) >> bitCount;
178*3117ece4Schristos     }   }
179*3117ece4Schristos     if (remaining != 1) return ERROR(corruption_detected);
180*3117ece4Schristos     /* Only possible when there are too many zeros. */
181*3117ece4Schristos     if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
182*3117ece4Schristos     if (bitCount > 32) return ERROR(corruption_detected);
183*3117ece4Schristos     *maxSVPtr = charnum-1;
184*3117ece4Schristos 
185*3117ece4Schristos     ip += (bitCount+7)>>3;
186*3117ece4Schristos     return ip-istart;
187*3117ece4Schristos }
188*3117ece4Schristos 
/* Plain (non force-inlined) instantiation of FSE_readNCount_body(),
 * so that callers get a regular function instead of the FORCE_INLINE body. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
196*3117ece4Schristos 
#if DYNAMIC_BMI2
/* Second instantiation of FSE_readNCount_body(), carrying
 * BMI2_TARGET_ATTRIBUTE (presumably enables BMI2 code generation for the
 * inlined body - see the macro's definition). Only built when DYNAMIC_BMI2
 * is set. */
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
#endif
205*3117ece4Schristos 
/*! FSE_readNCount_bmi2() :
 *  Same contract as FSE_readNCount(), with an extra `bmi2` switch.
 *  When DYNAMIC_BMI2 is enabled and `bmi2` is non-zero, the
 *  BMI2_TARGET_ATTRIBUTE instantiation is used; otherwise the default one.
 *  @return : bytes consumed from `headerBuffer`, or an error code.
 */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2 != 0) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                                        headerBuffer, hbSize);
    }
#endif
    (void)bmi2;   /* unused when DYNAMIC_BMI2 is off */
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr,
                                       headerBuffer, hbSize);
}
218*3117ece4Schristos 
/*! FSE_readNCount() :
 *  Read a serialized normalized-count header.
 *  Forwards to FSE_readNCount_bmi2() with the bmi2 switch turned off.
 *  @return : bytes consumed from `headerBuffer`, or an error code.
 */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    int const noBmi2 = 0;
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                               headerBuffer, hbSize, noBmi2);
}
225*3117ece4Schristos 
226*3117ece4Schristos 
227*3117ece4Schristos /*! HUF_readStats() :
228*3117ece4Schristos     Read compact Huffman tree, saved by HUF_writeCTable().
229*3117ece4Schristos     `huffWeight` is destination buffer.
230*3117ece4Schristos     `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
231*3117ece4Schristos     @return : size read from `src` , or an error Code .
232*3117ece4Schristos     Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
233*3117ece4Schristos */
234*3117ece4Schristos size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
235*3117ece4Schristos                      U32* nbSymbolsPtr, U32* tableLogPtr,
236*3117ece4Schristos                      const void* src, size_t srcSize)
237*3117ece4Schristos {
238*3117ece4Schristos     U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
239*3117ece4Schristos     return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0);
240*3117ece4Schristos }
241*3117ece4Schristos 
242*3117ece4Schristos FORCE_INLINE_TEMPLATE size_t
243*3117ece4Schristos HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
244*3117ece4Schristos                    U32* nbSymbolsPtr, U32* tableLogPtr,
245*3117ece4Schristos                    const void* src, size_t srcSize,
246*3117ece4Schristos                    void* workSpace, size_t wkspSize,
247*3117ece4Schristos                    int bmi2)
248*3117ece4Schristos {
249*3117ece4Schristos     U32 weightTotal;
250*3117ece4Schristos     const BYTE* ip = (const BYTE*) src;
251*3117ece4Schristos     size_t iSize;
252*3117ece4Schristos     size_t oSize;
253*3117ece4Schristos 
254*3117ece4Schristos     if (!srcSize) return ERROR(srcSize_wrong);
255*3117ece4Schristos     iSize = ip[0];
256*3117ece4Schristos     /* ZSTD_memset(huffWeight, 0, hwSize);   *//* is not necessary, even though some analyzer complain ... */
257*3117ece4Schristos 
258*3117ece4Schristos     if (iSize >= 128) {  /* special header */
259*3117ece4Schristos         oSize = iSize - 127;
260*3117ece4Schristos         iSize = ((oSize+1)/2);
261*3117ece4Schristos         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
262*3117ece4Schristos         if (oSize >= hwSize) return ERROR(corruption_detected);
263*3117ece4Schristos         ip += 1;
264*3117ece4Schristos         {   U32 n;
265*3117ece4Schristos             for (n=0; n<oSize; n+=2) {
266*3117ece4Schristos                 huffWeight[n]   = ip[n/2] >> 4;
267*3117ece4Schristos                 huffWeight[n+1] = ip[n/2] & 15;
268*3117ece4Schristos     }   }   }
269*3117ece4Schristos     else  {   /* header compressed with FSE (normal case) */
270*3117ece4Schristos         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
271*3117ece4Schristos         /* max (hwSize-1) values decoded, as last one is implied */
272*3117ece4Schristos         oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
273*3117ece4Schristos         if (FSE_isError(oSize)) return oSize;
274*3117ece4Schristos     }
275*3117ece4Schristos 
276*3117ece4Schristos     /* collect weight stats */
277*3117ece4Schristos     ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
278*3117ece4Schristos     weightTotal = 0;
279*3117ece4Schristos     {   U32 n; for (n=0; n<oSize; n++) {
280*3117ece4Schristos             if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
281*3117ece4Schristos             rankStats[huffWeight[n]]++;
282*3117ece4Schristos             weightTotal += (1 << huffWeight[n]) >> 1;
283*3117ece4Schristos     }   }
284*3117ece4Schristos     if (weightTotal == 0) return ERROR(corruption_detected);
285*3117ece4Schristos 
286*3117ece4Schristos     /* get last non-null symbol weight (implied, total must be 2^n) */
287*3117ece4Schristos     {   U32 const tableLog = ZSTD_highbit32(weightTotal) + 1;
288*3117ece4Schristos         if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
289*3117ece4Schristos         *tableLogPtr = tableLog;
290*3117ece4Schristos         /* determine last weight */
291*3117ece4Schristos         {   U32 const total = 1 << tableLog;
292*3117ece4Schristos             U32 const rest = total - weightTotal;
293*3117ece4Schristos             U32 const verif = 1 << ZSTD_highbit32(rest);
294*3117ece4Schristos             U32 const lastWeight = ZSTD_highbit32(rest) + 1;
295*3117ece4Schristos             if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
296*3117ece4Schristos             huffWeight[oSize] = (BYTE)lastWeight;
297*3117ece4Schristos             rankStats[lastWeight]++;
298*3117ece4Schristos     }   }
299*3117ece4Schristos 
300*3117ece4Schristos     /* check tree construction validity */
301*3117ece4Schristos     if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */
302*3117ece4Schristos 
303*3117ece4Schristos     /* results */
304*3117ece4Schristos     *nbSymbolsPtr = (U32)(oSize+1);
305*3117ece4Schristos     return iSize+1;
306*3117ece4Schristos }
307*3117ece4Schristos 
308*3117ece4Schristos /* Avoids the FORCE_INLINE of the _body() function. */
309*3117ece4Schristos static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
310*3117ece4Schristos                      U32* nbSymbolsPtr, U32* tableLogPtr,
311*3117ece4Schristos                      const void* src, size_t srcSize,
312*3117ece4Schristos                      void* workSpace, size_t wkspSize)
313*3117ece4Schristos {
314*3117ece4Schristos     return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
315*3117ece4Schristos }
316*3117ece4Schristos 
#if DYNAMIC_BMI2
/* Second instantiation of HUF_readStats_body(), carrying
 * BMI2_TARGET_ATTRIBUTE and with its `bmi2` argument fixed to 1.
 * Only built when DYNAMIC_BMI2 is set. */
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    size_t const result = HUF_readStats_body(huffWeight, hwSize, rankStats,
                                             nbSymbolsPtr, tableLogPtr,
                                             src, srcSize,
                                             workSpace, wkspSize, /* bmi2 */ 1);
    return result;
}
#endif
326*3117ece4Schristos 
327*3117ece4Schristos size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
328*3117ece4Schristos                      U32* nbSymbolsPtr, U32* tableLogPtr,
329*3117ece4Schristos                      const void* src, size_t srcSize,
330*3117ece4Schristos                      void* workSpace, size_t wkspSize,
331*3117ece4Schristos                      int flags)
332*3117ece4Schristos {
333*3117ece4Schristos #if DYNAMIC_BMI2
334*3117ece4Schristos     if (flags & HUF_flags_bmi2) {
335*3117ece4Schristos         return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
336*3117ece4Schristos     }
337*3117ece4Schristos #endif
338*3117ece4Schristos     (void)flags;
339*3117ece4Schristos     return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
340*3117ece4Schristos }
341