/* ******************************************************************
 * Common functions of New Generation Entropy library
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 *  You can contact the author at :
 *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
 *  - Public forum : https://groups.google.com/forum/#!forum/lz4c
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
****************************************************************** */

/* *************************************
*  Dependencies
***************************************/
#include "mem.h"
#include "error_private.h"       /* ERR_*, ERROR */
#define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
#include "fse.h"
#include "huf.h"
#include "bits.h"                /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */


/*===   Version   ===*/
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }


/*===   Error Management   ===*/
/* Thin wrappers over the shared error layer (error_private.h),
 * exposed under the FSE_ public prefix. */
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }

/* HUF_ error wrappers : same shared error layer, HUF_ public prefix. */
unsigned HUF_isError(size_t code) { return ERR_isError(code); }
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }


/*-**************************************************************
*  FSE NCount encoding-decoding
****************************************************************/
/*! FSE_readNCount_body() :
 *  Decode an FSE "NCount" header (the table of normalized symbol counts)
 *  from `headerBuffer`.
 *  On success : fills normalizedCounter[0..*maxSVPtr] (symbols absent from
 *  the header keep a count of 0), updates *maxSVPtr to the last symbol
 *  actually decoded, and writes the table log into *tableLogPtr.
 * @return : nb of bytes consumed from headerBuffer,
 *           or an error code, testable with FSE_isError(). */
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;
    int remaining;    /* counts down from (1<<tableLog)+1 as probabilities are consumed */
    int threshold;
    U32 bitStream;
    int bitCount;     /* nb of bits already consumed within the 32-bit window at ip */
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 :
         * copy into a zero-padded local buffer and recurse once,
         * then verify the reported size fits the real input. */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
             */
            int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    /* near the end of input : stop advancing ip past iend-4,
                     * track the extra consumed bits in bitCount instead */
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            /* Decode one count : values below `max` use nbBits-1 bits,
             * the rest use the full nbBits (range-folded encoding). */
            int const max = (2*threshold-1) - remaining;
            int count;

            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy : stored value is count+1 ; -1 encodes "less than 1" probability */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;   /* a zero count triggers the repeat-flag path next iteration */

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                nbBits = ZSTD_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    ip += (bitCount+7)>>3;
    return ip-istart;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

#if DYNAMIC_BMI2
/* Same body, compiled with BMI2 codegen enabled (selected at runtime). */
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
#endif

/* Dispatch to the BMI2 variant when requested and available at build time. */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
#endif
    (void)bmi2;
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

/* Public entry point : plain (non-BMI2) decode. */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0);
}

/*! HUF_readStats() :
    Read compact Huffman tree, saved by HUF_writeCTable().
    `huffWeight` is destination buffer.
    `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
    @return : size read from `src` , or an error Code .
    Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
    Convenience wrapper : allocates the FSE workspace on the stack.
*/
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize)
{
    U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
    return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0);
}

/*! HUF_readStats_body() :
 *  Decode Huffman symbol weights into huffWeight[], either from a raw
 *  4-bit-packed header (first byte >= 128) or an FSE-compressed header.
 *  Fills rankStats[0..HUF_TABLELOG_MAX] with per-weight counts, reconstructs
 *  the implied last weight, and writes *nbSymbolsPtr and *tableLogPtr.
 * @return : nb of bytes consumed from src, or an error code. */
FORCE_INLINE_TEMPLATE size_t
HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                   U32* nbSymbolsPtr, U32* tableLogPtr,
                   const void* src, size_t srcSize,
                   void* workSpace, size_t wkspSize,
                   int bmi2)
{
    U32 weightTotal;
    const BYTE* ip = (const BYTE*) src;
    size_t iSize;   /* nb of header bytes following the first byte */
    size_t oSize;   /* nb of weights decoded (last one still implied) */

    if (!srcSize) return ERROR(srcSize_wrong);
    iSize = ip[0];
    /* ZSTD_memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */

    if (iSize >= 128) {  /* special header : weights stored raw, 2 per byte */
        oSize = iSize - 127;
        iSize = ((oSize+1)/2);
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        if (oSize >= hwSize) return ERROR(corruption_detected);
        ip += 1;
        {   U32 n;
            for (n=0; n<oSize; n+=2) {
                huffWeight[n]   = ip[n/2] >> 4;   /* high nibble first */
                huffWeight[n+1] = ip[n/2] & 15;
    }   }   }
    else  {   /* header compressed with FSE (normal case) */
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        /* max (hwSize-1) values decoded, as last one is implied */
        oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
        if (FSE_isError(oSize)) return oSize;
    }

    /* collect weight stats */
    ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
    weightTotal = 0;
    {   U32 n; for (n=0; n<oSize; n++) {
            if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
            rankStats[huffWeight[n]]++;
            weightTotal += (1 << huffWeight[n]) >> 1;   /* weight w contributes 2^(w-1) ; weight 0 contributes nothing */
    }   }
    if (weightTotal == 0) return ERROR(corruption_detected);

    /* get last non-null symbol weight (implied, total must be 2^n) */
    {   U32 const tableLog = ZSTD_highbit32(weightTotal) + 1;
        if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
        *tableLogPtr = tableLog;
        /* determine last weight */
        {   U32 const total = 1 << tableLog;
            U32 const rest = total - weightTotal;
            U32 const verif = 1 << ZSTD_highbit32(rest);
            U32 const lastWeight = ZSTD_highbit32(rest) + 1;
            if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
            huffWeight[oSize] = (BYTE)lastWeight;
            rankStats[lastWeight]++;
    }   }

    /* check tree construction validity */
    if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);  /* by construction : at least 2 elts of rank 1, must be even */

    /* results */
    *nbSymbolsPtr = (U32)(oSize+1);
    return iSize+1;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
}

#if DYNAMIC_BMI2
/* Same body, compiled with BMI2 codegen enabled (selected at runtime). */
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1);
}
#endif

/* Public entry point : dispatches on the HUF_flags_bmi2 flag when
 * DYNAMIC_BMI2 builds include the BMI2 variant. */
size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize,
                     int flags)
{
#if DYNAMIC_BMI2
    if (flags & HUF_flags_bmi2) {
        return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
    }
#endif
    (void)flags;
    return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
}