/* ******************************************************************
 * Common functions of New Generation Entropy library
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * You can contact the author at :
 * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
 * - Public forum : https://groups.google.com/forum/#!forum/lz4c
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
****************************************************************** */

/* *************************************
*  Dependencies
***************************************/
#include "mem.h"
#include "error_private.h"       /* ERR_*, ERROR */
#define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
#include "fse.h"
#include "huf.h"
#include "bits.h"                /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */


/*===   Version   ===*/
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }


/*===   Error Management   ===*/
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }

unsigned HUF_isError(size_t code) { return ERR_isError(code); }
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }


/*-**************************************************************
*  FSE NCount encoding-decoding
****************************************************************/
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;
    int remaining;
    int threshold;
    U32 bitStream;
    int bitCount;
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
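             * (Illustrative example, derived from the code below :
             * the repeat-code sequence 0b11, 0b11, 0b10 skips
             * 3 + 3 + 2 = 8 more symbols whose count stays 0.)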
             */
            int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            int const max = (2*threshold-1) - remaining;
            int count;

            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                nbBits = ZSTD_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    ip += (bitCount+7)>>3;
    return ip-istart;
}

/* Avoids the FORCE_INLINE of the _body() function.
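 * (Added note : the force-inlined _body() above is instantiated twice,
 * in the portable wrapper below and, when DYNAMIC_BMI2 is enabled, in a
 * BMI2-targeted wrapper, so the variant can be selected at runtime.)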
 */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

#if DYNAMIC_BMI2
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}
#endif

size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2) {
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
#endif
    (void)bmi2;
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
}

size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0);
}


/*! HUF_readStats() :
    Read compact Huffman tree, saved by HUF_writeCTable().
    `huffWeight` is destination buffer.
    `rankStats` is assumed to be a table of at least (HUF_TABLELOG_MAX + 1) U32.
    @return : size read from `src`, or an error code.
    Note : Needed by HUF_readCTable() and HUF_readDTableX?().
*/
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize)
{
    U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
    return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0);
}

FORCE_INLINE_TEMPLATE size_t
HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                   U32* nbSymbolsPtr, U32* tableLogPtr,
                   const void* src, size_t srcSize,
                   void* workSpace, size_t wkspSize,
                   int bmi2)
{
    U32 weightTotal;
    const BYTE* ip = (const BYTE*) src;
    size_t iSize;
    size_t oSize;

    if (!srcSize) return ERROR(srcSize_wrong);
    iSize = ip[0];
    /* ZSTD_memset(huffWeight, 0, hwSize); */  /* is not necessary, even though some analyzers complain ...
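     * (Added note : every huffWeight[] entry read below, i.e. indices 0..oSize,
     * is written first, either by the 4-bit unpacking loop, by
     * FSE_decompress_wksp_bmi2(), or as the implied last weight.)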
     */

    if (iSize >= 128) {  /* special header */
        oSize = iSize - 127;
        iSize = ((oSize+1)/2);
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        if (oSize >= hwSize) return ERROR(corruption_detected);
        ip += 1;
        {   U32 n;
            for (n=0; n<oSize; n+=2) {
                huffWeight[n]   = ip[n/2] >> 4;
                huffWeight[n+1] = ip[n/2] & 15;
    }   }   }
    else  {   /* header compressed with FSE (normal case) */
        if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
        /* max (hwSize-1) values decoded, as last one is implied */
        oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
        if (FSE_isError(oSize)) return oSize;
    }

    /* collect weight stats */
    ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
    weightTotal = 0;
    {   U32 n; for (n=0; n<oSize; n++) {
            if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
            rankStats[huffWeight[n]]++;
            weightTotal += (1 << huffWeight[n]) >> 1;
    }   }
    if (weightTotal == 0) return ERROR(corruption_detected);

    /* get last non-null symbol weight (implied, total must be 2^n) */
    {   U32 const tableLog = ZSTD_highbit32(weightTotal) + 1;
        if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
        *tableLogPtr = tableLog;
        /* determine last weight */
        {   U32 const total = 1 << tableLog;
            U32 const rest = total - weightTotal;
            U32 const verif = 1 << ZSTD_highbit32(rest);
            U32 const lastWeight = ZSTD_highbit32(rest) + 1;
            if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
            huffWeight[oSize] = (BYTE)lastWeight;
            rankStats[lastWeight]++;
    }   }

    /* check tree construction validity */
    if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */

    /* results */
    *nbSymbolsPtr = (U32)(oSize+1);
    return iSize+1;
}

/* Avoids the FORCE_INLINE of the _body() function. */
static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
}

#if DYNAMIC_BMI2
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1);
}
#endif

size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                          U32* nbSymbolsPtr, U32* tableLogPtr,
                          const void* src, size_t srcSize,
                          void* workSpace, size_t wkspSize,
                          int flags)
{
#if DYNAMIC_BMI2
    if (flags & HUF_flags_bmi2) {
        return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
    }
#endif
    (void)flags;
    return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
}
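
/* Illustrative usage sketch (added note, not part of the library) :
 * variable names below are hypothetical; the constants come from fse.h
 * (with FSE_STATIC_LINKING_ONLY defined) and huf.h.
 *
 *     short ncount[FSE_MAX_SYMBOL_VALUE + 1];
 *     unsigned maxSV = FSE_MAX_SYMBOL_VALUE;
 *     unsigned tableLog;
 *     size_t const hSize = FSE_readNCount(ncount, &maxSV, &tableLog, src, srcSize);
 *     if (FSE_isError(hSize)) return hSize;
 *
 * On success, ncount[0..maxSV] and tableLog describe the normalized
 * distribution, and hSize bytes of src were consumed. Similarly for
 * Huffman weights :
 *
 *     BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1];
 *     U32  rankStats[HUF_TABLELOG_MAX + 1];
 *     U32  nbSymbols, huffLog;
 *     size_t const wSize = HUF_readStats(huffWeight, sizeof(huffWeight), rankStats,
 *                                        &nbSymbols, &huffLog, src, srcSize);
 *     if (HUF_isError(wSize)) return wSize;
 */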