/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
 */

/// Zstandard educational decoder implementation
/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md

#include <stdint.h> // uint8_t, etc.
#include <stdlib.h> // malloc, free, exit
#include <stdio.h>  // fprintf
#include <string.h> // memset, memcpy
#include "zstd_decompress.h"


/******* IMPORTANT CONSTANTS *********************************************/

// Zstandard frame
// "Magic_Number
// 4 Bytes, little-endian format. Value : 0xFD2FB528"
#define ZSTD_MAGIC_NUMBER 0xFD2FB528U

// The size of `Block_Content` is limited by `Block_Maximum_Size`; this decoder
// uses a fixed upper bound of 128 KiB.
#define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)

// Literal blocks can't be larger than their block
#define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX


/******* UTILITY MACROS AND TYPES *********************************************/
// NOTE: both macros evaluate the selected argument twice, so do not pass
// expressions with side effects (e.g. `i++`).
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// MESSAGE(...) prints a printf-style diagnostic to stderr, or expands to
// nothing when ZDEC_NO_MESSAGE is defined. The leading "" is concatenated with
// the first argument, so the first argument must be a string literal.
#if defined(ZDEC_NO_MESSAGE)
#define MESSAGE(...)
#else
#define MESSAGE(...) fprintf(stderr, "" __VA_ARGS__)
#endif

/// This decoder calls exit(1) when it encounters an error, however a production
/// library should propagate error codes.
/// ERROR(s) reports the string `s` through MESSAGE and terminates the process;
/// the macros below are shorthands for the error conditions used in this file.
#define ERROR(s) \
    do { \
        MESSAGE("Error: %s\n", s); \
        exit(1); \
    } while (0)
#define INP_SIZE() \
    ERROR("Input buffer smaller than it should be or input is " \
          "corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")

// Short aliases for the fixed-width unsigned integer types
typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

// Short aliases for the fixed-width signed integer types
typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
/******* END UTILITY MACROS AND TYPES *****************************************/

/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
/// The implementations for these functions can be found at the bottom of this
/// file. They implement low-level functionality needed for the higher level
/// decompression functions.
/*** IO STREAM OPERATIONS *************/

/// ostream_t/istream_t are used to wrap the pointers/length data passed into
/// ZSTD_decompress, so that all IO operations are safely bounds checked.
/// They are written/read forward, and reads are treated as little-endian.
/// They should be used opaquely to ensure safety.
typedef struct {
    // Current write position; ZSTD_decompress_with_dict derives the number of
    // bytes produced from how far this has moved past the start of `dst`.
    u8 *ptr;
    // NOTE(review): presumably the number of bytes still writable at `ptr` --
    // confirm against the IO_* implementations.
    size_t len;
} ostream_t;

typedef struct {
    // Current read position
    const u8 *ptr;
    // NOTE(review): presumably the number of bytes still readable at `ptr` --
    // confirm against the IO_* implementations.
    size_t len;

    // Input often reads a few bits at a time, so maintain an internal offset
    int bit_offset;
} istream_t;

/// The following bit-level functions are the only ones that allow the istream
/// to be non-byte aligned

/// Reads `num_bits` bits from a bitstream, and updates the internal offset
static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
/// Backs-up the stream by `num_bits` bits so they can be read again
static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
/// If the remaining bits in a byte will be unused, advance to the end of the
/// byte
static inline void IO_align_stream(istream_t *const in);

/// Write the given byte into the output stream
static inline void IO_write_byte(ostream_t *const out, u8 symb);

/// Returns the number of bytes left to be read in this stream. The stream must
/// be byte aligned.
static inline size_t IO_istream_len(const istream_t *const in);

/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped. The stream must be byte aligned.
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped so it can be written to.
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);

/// Advance the inner state by `len` bytes. The stream must be byte aligned.
static inline void IO_advance_input(istream_t *const in, size_t len);

/// Returns an `ostream_t` constructed from the given pointer and length.
static inline ostream_t IO_make_ostream(u8 *out, size_t len);
/// Returns an `istream_t` constructed from the given pointer and length.
static inline istream_t IO_make_istream(const u8 *in, size_t len);

/// Returns an `istream_t` with the same base as `in`, and length `len`.
/// Then, advance `in` to account for the consumed bytes.
/// `in` must be byte aligned.
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
/*** END IO STREAM OPERATIONS *********/

/*** BITSTREAM OPERATIONS *************/
/// Read `num_bits` bits (up to 64) from `src + offset`, where `offset` is in
/// bits, and return them interpreted as a little-endian unsigned integer.
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
                               const size_t offset);

/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
/// `src + offset`. If the offset becomes negative, the extra bits at the
/// bottom are filled in with `0` bits instead of reading from before `src`.
static inline u64 STREAM_read_bits(const u8 *src, const int bits,
                                   i64 *const offset);
/*** END BITSTREAM OPERATIONS *********/

/*** BIT COUNTING OPERATIONS **********/
/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
static inline int highest_set_bit(const u64 num);
/*** END BIT COUNTING OPERATIONS ******/

/*** HUFFMAN PRIMITIVES ***************/
// Table decode method uses exponential memory, so we need to limit depth
#define HUF_MAX_BITS (16)

// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
#define HUF_MAX_SYMBS (256)

/// Structure containing all tables necessary for efficient Huffman decoding
typedef struct {
    // Heap-allocated (released by HUF_free_dtable). A NULL `symbols` pointer
    // is treated by decode_literals_compressed as "no table available yet".
    u8 *symbols;
    // NOTE(review): presumably the code length for each table entry -- confirm
    // against HUF_init_dtable.
    u8 *num_bits;
    // NOTE(review): presumably the longest code length in the table, bounded
    // by HUF_MAX_BITS -- confirm against HUF_init_dtable.
    int max_bits;
} HUF_dtable;

/// Decode a single symbol and read in enough bits to refresh the state
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);
/// Read in a full state's worth of bits to initialize it
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// `src_len` must be the exact length of the Huffman-coded block.
/// NOTE(review): there is no `src_len` parameter in this prototype; presumably
/// the length of `in` is meant -- confirm against the implementation.
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
/// specification.
/// `src_len` must be the exact length of the Huffman-coded block.
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);

/// Initialize a Huffman decoding table using the table of bit counts provided
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
                            const int num_symbs);
/// Initialize a Huffman decoding table using the table of weights provided.
/// Weights follow the definition provided in the Zstandard specification.
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs);

/// Free the malloc'ed parts of a decoding table
static void HUF_free_dtable(HUF_dtable *const dtable);
/*** END HUFFMAN PRIMITIVES ***********/

/*** FSE PRIMITIVES *******************/
/// For more description of FSE see
/// https://github.com/Cyan4973/FiniteStateEntropy/

// FSE table decoding uses exponential memory, so limit the maximum accuracy
#define FSE_MAX_ACCURACY_LOG (15)
// Limit the maximum number of symbols so they can be stored in a single byte
#define FSE_MAX_SYMBS (256)

/// The tables needed to decode FSE encoded streams.
/// The pointer members are heap-allocated and released by FSE_free_dtable.
typedef struct {
    u8 *symbols;
    u8 *num_bits;
    u16 *new_state_base;
    // NOTE(review): presumably log2 of the number of table entries, bounded by
    // FSE_MAX_ACCURACY_LOG -- confirm against FSE_init_dtable.
    int accuracy_log;
} FSE_dtable;

/// Return the symbol for the current state
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
                                 const u16 state);
/// Read the number of bits necessary to update state, update, and shift offset
/// back to reflect the bits read
static inline void FSE_update_state(const FSE_dtable *const dtable,
                                    u16 *const state, const u8 *const src,
                                    i64 *const offset);

/// Combine peek and update: decode a symbol and update the state
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);

/// Read bits from the stream to initialize the state
/// and shift offset back
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
/// using an FSE decoding table. `src_len` must be the exact length of the
/// block.
/// NOTE(review): there is no `src_len` parameter in this prototype; presumably
/// the length of `in` is meant -- confirm against the implementation.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in);

/// Initialize a decoding table using normalized frequencies.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log);

/// Decode an FSE header as defined in the Zstandard format specification and
/// use the decoded frequencies to initialize a decoding table.
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
                              const int max_accuracy_log);

/// Initialize an FSE table that will always return the same symbol and consume
/// 0 bits per symbol, to be used for RLE mode in sequence commands
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);

/// Free the malloc'ed parts of a decoding table
static void FSE_free_dtable(FSE_dtable *const dtable);
/*** END FSE PRIMITIVES ***************/

/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/

/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/

/// A small structure that can be reused in various places that need to access
/// frame header information. It is filled in by parse_frame_header.
typedef struct {
    // The size of window that we need to be able to contiguously store for
    // references. For single-segment frames this is set to frame_content_size.
    size_t window_size;
    // The total output size of this compressed frame, or 0 when the header
    // does not carry a Frame_Content_Size field
    size_t frame_content_size;

    // The dictionary id if this frame uses one, otherwise 0
    u32 dictionary_id;

    // Whether or not the content of this frame has a checksum
    int content_checksum_flag;
    // Whether or not the output for this frame is in a single segment
    int single_segment_flag;
} frame_header_t;

/// The context needed to decode blocks in a frame
typedef struct {
    frame_header_t header;

    // The total amount of data available for backreferences, to determine if an
    // offset is too large to be correct
    size_t current_total_output;

    // NOTE(review): presumably the raw dictionary content used for
    // backreferences, set up by frame_context_apply_dict -- confirm there.
    const u8 *dict_content;
    size_t dict_content_len;

    // Entropy encoding tables so they can be repeated by future blocks instead
    // of retransmitting. Released by free_frame_context.
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // The last 3 offsets for the special "repeat offsets".
    // init_frame_context seeds these with 1, 4, 8.
    u64 previous_offsets[3];
} frame_context_t;

/// The decoded contents of a dictionary so that it doesn't have to be repeated
/// for each frame that uses it
struct dictionary_s {
    // Entropy tables
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // Raw content for backreferences
    u8 *content;
    size_t content_size;

    // Offset history to prepopulate the frame's history
    u64 previous_offsets[3];

    u32 dictionary_id;
};

/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
/// command
typedef struct {
    u32 literal_length;
    u32 match_length;
    u32 offset;
} sequence_command_t;

/// The decoder works top-down, starting at the high level like Zstd frames, and
/// working down to lower more technical levels such as blocks, literals, and
/// sequences.
/// The high-level functions roughly follow the outline of the
/// format specification:
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md

/// Before the implementation of each high-level function declared here, the
/// prototypes for their helper functions are defined and explained

/// Decode a single Zstd frame, or error if the input is not a valid frame.
/// Accepts a dict argument, which may be NULL indicating no dictionary.
/// See
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict);

// Decode data in a compressed block
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in);

// Decode the literals section of a block. On return `*literals` points to a
// malloc'ed buffer holding the literals (decompress_block frees it) and the
// return value is the number of literal bytes.
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals);

// Decode the sequences part of a block. decompress_block treats the return
// value as the number of sequences and frees `*sequences` afterwards.
static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
                               sequence_command_t **const sequences);

// Execute the decoded sequences on the literals block
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences);

// Copies literals and returns the total literal length that was copied
static u32 copy_literals(const size_t seq, istream_t *litstream,
                         ostream_t *const out);

// Given an offset code from a sequence command (either an actual offset value
// or an index for previous offset), computes the correct offset and updates
// the offset history
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);

// Given an offset, match length, and total output,
as well as the frame 376 // context for the dictionary, determines if the dictionary is used and 377 // executes the copy operation 378 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 379 size_t match_length, size_t total_output, 380 ostream_t *const out); 381 382 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ 383 384 size_t ZSTD_decompress(void *const dst, const size_t dst_len, 385 const void *const src, const size_t src_len) { 386 dictionary_t* const uninit_dict = create_dictionary(); 387 size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src, 388 src_len, uninit_dict); 389 free_dictionary(uninit_dict); 390 return decomp_size; 391 } 392 393 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, 394 const void *const src, const size_t src_len, 395 dictionary_t* parsed_dict) { 396 397 istream_t in = IO_make_istream(src, src_len); 398 ostream_t out = IO_make_ostream(dst, dst_len); 399 400 // "A content compressed by Zstandard is transformed into a Zstandard frame. 401 // Multiple frames can be appended into a single file or stream. A frame is 402 // totally independent, has a defined beginning and end, and a set of 403 // parameters which tells the decoder how to decompress it." 
404 405 /* this decoder assumes decompression of a single frame */ 406 decode_frame(&out, &in, parsed_dict); 407 408 return (size_t)(out.ptr - (u8 *)dst); 409 } 410 411 /******* FRAME DECODING ******************************************************/ 412 413 static void decode_data_frame(ostream_t *const out, istream_t *const in, 414 const dictionary_t *const dict); 415 static void init_frame_context(frame_context_t *const context, 416 istream_t *const in, 417 const dictionary_t *const dict); 418 static void free_frame_context(frame_context_t *const context); 419 static void parse_frame_header(frame_header_t *const header, 420 istream_t *const in); 421 static void frame_context_apply_dict(frame_context_t *const ctx, 422 const dictionary_t *const dict); 423 424 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 425 istream_t *const in); 426 427 static void decode_frame(ostream_t *const out, istream_t *const in, 428 const dictionary_t *const dict) { 429 const u32 magic_number = (u32)IO_read_bits(in, 32); 430 if (magic_number == ZSTD_MAGIC_NUMBER) { 431 // ZSTD frame 432 decode_data_frame(out, in, dict); 433 434 return; 435 } 436 437 // not a real frame or a skippable frame 438 ERROR("Tried to decode non-ZSTD frame"); 439 } 440 441 /// Decode a frame that contains compressed data. Not all frames do as there 442 /// are skippable frames. 
443 /// See 444 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format 445 static void decode_data_frame(ostream_t *const out, istream_t *const in, 446 const dictionary_t *const dict) { 447 frame_context_t ctx; 448 449 // Initialize the context that needs to be carried from block to block 450 init_frame_context(&ctx, in, dict); 451 452 if (ctx.header.frame_content_size != 0 && 453 ctx.header.frame_content_size > out->len) { 454 OUT_SIZE(); 455 } 456 457 decompress_data(&ctx, out, in); 458 459 free_frame_context(&ctx); 460 } 461 462 /// Takes the information provided in the header and dictionary, and initializes 463 /// the context for this frame 464 static void init_frame_context(frame_context_t *const context, 465 istream_t *const in, 466 const dictionary_t *const dict) { 467 // Most fields in context are correct when initialized to 0 468 memset(context, 0, sizeof(frame_context_t)); 469 470 // Parse data from the frame header 471 parse_frame_header(&context->header, in); 472 473 // Set up the offset history for the repeat offset commands 474 context->previous_offsets[0] = 1; 475 context->previous_offsets[1] = 4; 476 context->previous_offsets[2] = 8; 477 478 // Apply details from the dict if it exists 479 frame_context_apply_dict(context, dict); 480 } 481 482 static void free_frame_context(frame_context_t *const context) { 483 HUF_free_dtable(&context->literals_dtable); 484 485 FSE_free_dtable(&context->ll_dtable); 486 FSE_free_dtable(&context->ml_dtable); 487 FSE_free_dtable(&context->of_dtable); 488 489 memset(context, 0, sizeof(frame_context_t)); 490 } 491 492 static void parse_frame_header(frame_header_t *const header, 493 istream_t *const in) { 494 // "The first header's byte is called the Frame_Header_Descriptor. It tells 495 // which other fields are present. Decoding this byte is enough to tell the 496 // size of Frame_Header. 
497 // 498 // Bit number Field name 499 // 7-6 Frame_Content_Size_flag 500 // 5 Single_Segment_flag 501 // 4 Unused_bit 502 // 3 Reserved_bit 503 // 2 Content_Checksum_flag 504 // 1-0 Dictionary_ID_flag" 505 const u8 descriptor = (u8)IO_read_bits(in, 8); 506 507 // decode frame header descriptor into flags 508 const u8 frame_content_size_flag = descriptor >> 6; 509 const u8 single_segment_flag = (descriptor >> 5) & 1; 510 const u8 reserved_bit = (descriptor >> 3) & 1; 511 const u8 content_checksum_flag = (descriptor >> 2) & 1; 512 const u8 dictionary_id_flag = descriptor & 3; 513 514 if (reserved_bit != 0) { 515 CORRUPTION(); 516 } 517 518 header->single_segment_flag = single_segment_flag; 519 header->content_checksum_flag = content_checksum_flag; 520 521 // decode window size 522 if (!single_segment_flag) { 523 // "Provides guarantees on maximum back-reference distance that will be 524 // used within compressed data. This information is important for 525 // decoders to allocate enough memory. 526 // 527 // Bit numbers 7-3 2-0 528 // Field name Exponent Mantissa" 529 u8 window_descriptor = (u8)IO_read_bits(in, 8); 530 u8 exponent = window_descriptor >> 3; 531 u8 mantissa = window_descriptor & 7; 532 533 // Use the algorithm from the specification to compute window size 534 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor 535 size_t window_base = (size_t)1 << (10 + exponent); 536 size_t window_add = (window_base / 8) * mantissa; 537 header->window_size = window_base + window_add; 538 } 539 540 // decode dictionary id if it exists 541 if (dictionary_id_flag) { 542 // "This is a variable size field, which contains the ID of the 543 // dictionary required to properly decode the frame. Note that this 544 // field is optional. When it's not present, it's up to the caller to 545 // make sure it uses the correct dictionary. Format is little-endian." 
546 const int bytes_array[] = {0, 1, 2, 4}; 547 const int bytes = bytes_array[dictionary_id_flag]; 548 549 header->dictionary_id = (u32)IO_read_bits(in, bytes * 8); 550 } else { 551 header->dictionary_id = 0; 552 } 553 554 // decode frame content size if it exists 555 if (single_segment_flag || frame_content_size_flag) { 556 // "This is the original (uncompressed) size. This information is 557 // optional. The Field_Size is provided according to value of 558 // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not 559 // present), 1, 2, 4 or 8 bytes. Format is little-endian." 560 // 561 // if frame_content_size_flag == 0 but single_segment_flag is set, we 562 // still have a 1 byte field 563 const int bytes_array[] = {1, 2, 4, 8}; 564 const int bytes = bytes_array[frame_content_size_flag]; 565 566 header->frame_content_size = IO_read_bits(in, bytes * 8); 567 if (bytes == 2) { 568 // "When Field_Size is 2, the offset of 256 is added." 569 header->frame_content_size += 256; 570 } 571 } else { 572 header->frame_content_size = 0; 573 } 574 575 if (single_segment_flag) { 576 // "The Window_Descriptor byte is optional. It is absent when 577 // Single_Segment_flag is set. In this case, the maximum back-reference 578 // distance is the content size itself, which can be any value from 1 to 579 // 2^64-1 bytes (16 EB)." 580 header->window_size = header->frame_content_size; 581 } 582 } 583 584 /// Decompress the data from a frame block by block 585 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 586 istream_t *const in) { 587 // "A frame encapsulates one or multiple blocks. Each block can be 588 // compressed or not, and has a guaranteed maximum content size, which 589 // depends on frame parameters. Unlike frames, each block depends on 590 // previous blocks for proper decoding. However, each block can be 591 // decompressed without waiting for its successor, allowing streaming 592 // operations." 
593 int last_block = 0; 594 do { 595 // "Last_Block 596 // 597 // The lowest bit signals if this block is the last one. Frame ends 598 // right after this block. 599 // 600 // Block_Type and Block_Size 601 // 602 // The next 2 bits represent the Block_Type, while the remaining 21 bits 603 // represent the Block_Size. Format is little-endian." 604 last_block = (int)IO_read_bits(in, 1); 605 const int block_type = (int)IO_read_bits(in, 2); 606 const size_t block_len = IO_read_bits(in, 21); 607 608 switch (block_type) { 609 case 0: { 610 // "Raw_Block - this is an uncompressed block. Block_Size is the 611 // number of bytes to read and copy." 612 const u8 *const read_ptr = IO_get_read_ptr(in, block_len); 613 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 614 615 // Copy the raw data into the output 616 memcpy(write_ptr, read_ptr, block_len); 617 618 ctx->current_total_output += block_len; 619 break; 620 } 621 case 1: { 622 // "RLE_Block - this is a single byte, repeated N times. In which 623 // case, Block_Size is the size to regenerate, while the 624 // "compressed" block is just 1 byte (the byte to repeat)." 625 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 626 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 627 628 // Copy `block_len` copies of `read_ptr[0]` to the output 629 memset(write_ptr, read_ptr[0], block_len); 630 631 ctx->current_total_output += block_len; 632 break; 633 } 634 case 2: { 635 // "Compressed_Block - this is a Zstandard compressed block, 636 // detailed in another section of this specification. Block_Size is 637 // the compressed size. 638 639 // Create a sub-stream for the block 640 istream_t block_stream = IO_make_sub_istream(in, block_len); 641 decompress_block(ctx, out, &block_stream); 642 break; 643 } 644 case 3: 645 // "Reserved - this is not a block. This value cannot be used with 646 // current version of this specification." 
647 CORRUPTION(); 648 break; 649 default: 650 IMPOSSIBLE(); 651 } 652 } while (!last_block); 653 654 if (ctx->header.content_checksum_flag) { 655 // This program does not support checking the checksum, so skip over it 656 // if it's present 657 IO_advance_input(in, 4); 658 } 659 } 660 /******* END FRAME DECODING ***************************************************/ 661 662 /******* BLOCK DECOMPRESSION **************************************************/ 663 static void decompress_block(frame_context_t *const ctx, ostream_t *const out, 664 istream_t *const in) { 665 // "A compressed block consists of 2 sections : 666 // 667 // Literals_Section 668 // Sequences_Section" 669 670 671 // Part 1: decode the literals block 672 u8 *literals = NULL; 673 const size_t literals_size = decode_literals(ctx, in, &literals); 674 675 // Part 2: decode the sequences block 676 sequence_command_t *sequences = NULL; 677 const size_t num_sequences = 678 decode_sequences(ctx, in, &sequences); 679 680 // Part 3: combine literals and sequence commands to generate output 681 execute_sequences(ctx, out, literals, literals_size, sequences, 682 num_sequences); 683 free(literals); 684 free(sequences); 685 } 686 /******* END BLOCK DECOMPRESSION **********************************************/ 687 688 /******* LITERALS DECODING ****************************************************/ 689 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 690 const int block_type, 691 const int size_format); 692 static size_t decode_literals_compressed(frame_context_t *const ctx, 693 istream_t *const in, 694 u8 **const literals, 695 const int block_type, 696 const int size_format); 697 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in); 698 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 699 int *const num_symbs); 700 701 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, 702 u8 **const literals) { 703 // 
"Literals can be stored uncompressed or compressed using Huffman prefix 704 // codes. When compressed, an optional tree description can be present, 705 // followed by 1 or 4 streams." 706 // 707 // "Literals_Section_Header 708 // 709 // Header is in charge of describing how literals are packed. It's a 710 // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using 711 // little-endian convention." 712 // 713 // "Literals_Block_Type 714 // 715 // This field uses 2 lowest bits of first byte, describing 4 different block 716 // types" 717 // 718 // size_format takes between 1 and 2 bits 719 int block_type = (int)IO_read_bits(in, 2); 720 int size_format = (int)IO_read_bits(in, 2); 721 722 if (block_type <= 1) { 723 // Raw or RLE literals block 724 return decode_literals_simple(in, literals, block_type, 725 size_format); 726 } else { 727 // Huffman compressed literals 728 return decode_literals_compressed(ctx, in, literals, block_type, 729 size_format); 730 } 731 } 732 733 /// Decodes literals blocks in raw or RLE form 734 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 735 const int block_type, 736 const int size_format) { 737 size_t size; 738 switch (size_format) { 739 // These cases are in the form ?0 740 // In this case, the ? bit is actually part of the size field 741 case 0: 742 case 2: 743 // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)." 744 IO_rewind_bits(in, 1); 745 size = IO_read_bits(in, 5); 746 break; 747 case 1: 748 // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)." 749 size = IO_read_bits(in, 12); 750 break; 751 case 3: 752 // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)." 
753 size = IO_read_bits(in, 20); 754 break; 755 default: 756 // Size format is in range 0-3 757 IMPOSSIBLE(); 758 } 759 760 if (size > MAX_LITERALS_SIZE) { 761 CORRUPTION(); 762 } 763 764 *literals = malloc(size); 765 if (!*literals) { 766 BAD_ALLOC(); 767 } 768 769 switch (block_type) { 770 case 0: { 771 // "Raw_Literals_Block - Literals are stored uncompressed." 772 const u8 *const read_ptr = IO_get_read_ptr(in, size); 773 memcpy(*literals, read_ptr, size); 774 break; 775 } 776 case 1: { 777 // "RLE_Literals_Block - Literals consist of a single byte value repeated N times." 778 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 779 memset(*literals, read_ptr[0], size); 780 break; 781 } 782 default: 783 IMPOSSIBLE(); 784 } 785 786 return size; 787 } 788 789 /// Decodes Huffman compressed literals 790 static size_t decode_literals_compressed(frame_context_t *const ctx, 791 istream_t *const in, 792 u8 **const literals, 793 const int block_type, 794 const int size_format) { 795 size_t regenerated_size, compressed_size; 796 // Only size_format=0 has 1 stream, so default to 4 797 int num_streams = 4; 798 switch (size_format) { 799 case 0: 800 // "A single stream. Both Compressed_Size and Regenerated_Size use 10 801 // bits (0-1023)." 802 num_streams = 1; 803 // Fall through as it has the same size format 804 /* fallthrough */ 805 case 1: 806 // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits 807 // (0-1023)." 808 regenerated_size = IO_read_bits(in, 10); 809 compressed_size = IO_read_bits(in, 10); 810 break; 811 case 2: 812 // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits 813 // (0-16383)." 814 regenerated_size = IO_read_bits(in, 14); 815 compressed_size = IO_read_bits(in, 14); 816 break; 817 case 3: 818 // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits 819 // (0-262143)." 
820 regenerated_size = IO_read_bits(in, 18); 821 compressed_size = IO_read_bits(in, 18); 822 break; 823 default: 824 // Impossible 825 IMPOSSIBLE(); 826 } 827 if (regenerated_size > MAX_LITERALS_SIZE) { 828 CORRUPTION(); 829 } 830 831 *literals = malloc(regenerated_size); 832 if (!*literals) { 833 BAD_ALLOC(); 834 } 835 836 ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size); 837 istream_t huf_stream = IO_make_sub_istream(in, compressed_size); 838 839 if (block_type == 2) { 840 // Decode the provided Huffman table 841 // "This section is only present when Literals_Block_Type type is 842 // Compressed_Literals_Block (2)." 843 844 HUF_free_dtable(&ctx->literals_dtable); 845 decode_huf_table(&ctx->literals_dtable, &huf_stream); 846 } else { 847 // If the previous Huffman table is being repeated, ensure it exists 848 if (!ctx->literals_dtable.symbols) { 849 CORRUPTION(); 850 } 851 } 852 853 size_t symbols_decoded; 854 if (num_streams == 1) { 855 symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 856 } else { 857 symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 858 } 859 860 if (symbols_decoded != regenerated_size) { 861 CORRUPTION(); 862 } 863 864 return regenerated_size; 865 } 866 867 // Decode the Huffman table description 868 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) { 869 // "All literal values from zero (included) to last present one (excluded) 870 // are represented by Weight with values from 0 to Max_Number_of_Bits." 871 872 // "This is a single byte value (0-255), which describes how to decode the list of weights." 873 const u8 header = IO_read_bits(in, 8); 874 875 u8 weights[HUF_MAX_SYMBS]; 876 memset(weights, 0, sizeof(weights)); 877 878 int num_symbs; 879 880 if (header >= 128) { 881 // "This is a direct representation, where each Weight is written 882 // directly as a 4 bits field (0-15). 
The full representation occupies 883 // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte 884 // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte - 885 // 127" 886 num_symbs = header - 127; 887 const size_t bytes = (num_symbs + 1) / 2; 888 889 const u8 *const weight_src = IO_get_read_ptr(in, bytes); 890 891 for (int i = 0; i < num_symbs; i++) { 892 // "They are encoded forward, 2 893 // weights to a byte with the first weight taking the top four bits 894 // and the second taking the bottom four (e.g. the following 895 // operations could be used to read the weights: Weight[0] = 896 // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)." 897 if (i % 2 == 0) { 898 weights[i] = weight_src[i / 2] >> 4; 899 } else { 900 weights[i] = weight_src[i / 2] & 0xf; 901 } 902 } 903 } else { 904 // The weights are FSE encoded, decode them before we can construct the 905 // table 906 istream_t fse_stream = IO_make_sub_istream(in, header); 907 ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS); 908 fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs); 909 } 910 911 // Construct the table using the decoded weights 912 HUF_init_dtable_usingweights(dtable, weights, num_symbs); 913 } 914 915 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 916 int *const num_symbs) { 917 const int MAX_ACCURACY_LOG = 7; 918 919 FSE_dtable dtable; 920 921 // "An FSE bitstream starts by a header, describing probabilities 922 // distribution. It will create a Decoding Table. For a list of Huffman 923 // weights, maximum accuracy is 7 bits." 
924 FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG); 925 926 // Decode the weights 927 *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in); 928 929 FSE_free_dtable(&dtable); 930 } 931 /******* END LITERALS DECODING ************************************************/ 932 933 /******* SEQUENCE DECODING ****************************************************/ 934 /// The combination of FSE states needed to decode sequences 935 typedef struct { 936 FSE_dtable ll_table; 937 FSE_dtable of_table; 938 FSE_dtable ml_table; 939 940 u16 ll_state; 941 u16 of_state; 942 u16 ml_state; 943 } sequence_states_t; 944 945 /// Different modes to signal to decode_seq_tables what to do 946 typedef enum { 947 seq_literal_length = 0, 948 seq_offset = 1, 949 seq_match_length = 2, 950 } seq_part_t; 951 952 typedef enum { 953 seq_predefined = 0, 954 seq_rle = 1, 955 seq_fse = 2, 956 seq_repeat = 3, 957 } seq_mode_t; 958 959 /// The predefined FSE distribution tables for `seq_predefined` mode 960 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = { 961 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 962 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; 963 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = { 964 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 965 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; 966 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = { 967 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 968 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 969 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; 970 971 /// The sequence decoding baseline and number of additional bits to read/add 972 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets 973 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = { 974 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 975 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 976 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 
32768, 65536}; 977 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = { 978 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 979 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 980 981 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = { 982 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 983 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 984 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 985 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539}; 986 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { 987 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 988 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 989 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 990 991 /// Offset decoding is simpler so we just need a maximum code value 992 static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52}; 993 994 static void decompress_sequences(frame_context_t *const ctx, 995 istream_t *const in, 996 sequence_command_t *const sequences, 997 const size_t num_sequences); 998 static sequence_command_t decode_sequence(sequence_states_t *const state, 999 const u8 *const src, 1000 i64 *const offset, 1001 int lastSequence); 1002 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1003 const seq_part_t type, const seq_mode_t mode); 1004 1005 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in, 1006 sequence_command_t **const sequences) { 1007 // "A compressed block is a succession of sequences . A sequence is a 1008 // literal copy command, followed by a match copy command. A literal copy 1009 // command specifies a length. It is the number of bytes to be copied (or 1010 // extracted) from the literal section. A match copy command specifies an 1011 // offset and a length. The offset gives the position to copy from, which 1012 // can be within a previous block." 
1013 1014 size_t num_sequences; 1015 1016 // "Number_of_Sequences 1017 // 1018 // This is a variable size field using between 1 and 3 bytes. Let's call its 1019 // first byte byte0." 1020 u8 header = IO_read_bits(in, 8); 1021 if (header < 128) { 1022 // "Number_of_Sequences = byte0 . Uses 1 byte." 1023 num_sequences = header; 1024 } else if (header < 255) { 1025 // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes." 1026 num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8); 1027 } else { 1028 // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes." 1029 num_sequences = IO_read_bits(in, 16) + 0x7F00; 1030 } 1031 1032 if (num_sequences == 0) { 1033 // "There are no sequences. The sequence section stops there." 1034 *sequences = NULL; 1035 return 0; 1036 } 1037 1038 *sequences = malloc(num_sequences * sizeof(sequence_command_t)); 1039 if (!*sequences) { 1040 BAD_ALLOC(); 1041 } 1042 1043 decompress_sequences(ctx, in, *sequences, num_sequences); 1044 return num_sequences; 1045 } 1046 1047 /// Decompress the FSE encoded sequence commands 1048 static void decompress_sequences(frame_context_t *const ctx, istream_t *in, 1049 sequence_command_t *const sequences, 1050 const size_t num_sequences) { 1051 // "The Sequences_Section regroup all symbols required to decode commands. 1052 // There are 3 symbol types : literals lengths, offsets and match lengths. 1053 // They are encoded together, interleaved, in a single bitstream." 1054 1055 // "Symbol compression modes 1056 // 1057 // This is a single byte, defining the compression mode of each symbol 1058 // type." 
1059 // 1060 // Bit number : Field name 1061 // 7-6 : Literals_Lengths_Mode 1062 // 5-4 : Offsets_Mode 1063 // 3-2 : Match_Lengths_Mode 1064 // 1-0 : Reserved 1065 u8 compression_modes = IO_read_bits(in, 8); 1066 1067 if ((compression_modes & 3) != 0) { 1068 // Reserved bits set 1069 CORRUPTION(); 1070 } 1071 1072 // "Following the header, up to 3 distribution tables can be described. When 1073 // present, they are in this order : 1074 // 1075 // Literals lengths 1076 // Offsets 1077 // Match Lengths" 1078 // Update the tables we have stored in the context 1079 decode_seq_table(&ctx->ll_dtable, in, seq_literal_length, 1080 (compression_modes >> 6) & 3); 1081 1082 decode_seq_table(&ctx->of_dtable, in, seq_offset, 1083 (compression_modes >> 4) & 3); 1084 1085 decode_seq_table(&ctx->ml_dtable, in, seq_match_length, 1086 (compression_modes >> 2) & 3); 1087 1088 1089 sequence_states_t states; 1090 1091 // Initialize the decoding tables 1092 { 1093 states.ll_table = ctx->ll_dtable; 1094 states.of_table = ctx->of_dtable; 1095 states.ml_table = ctx->ml_dtable; 1096 } 1097 1098 const size_t len = IO_istream_len(in); 1099 const u8 *const src = IO_get_read_ptr(in, len); 1100 1101 // "After writing the last bit containing information, the compressor writes 1102 // a single 1-bit and then fills the byte with 0-7 0 bits of padding." 1103 const int padding = 8 - highest_set_bit(src[len - 1]); 1104 // The offset starts at the end because FSE streams are read backwards 1105 i64 bit_offset = (i64)(len * 8 - (size_t)padding); 1106 1107 // "The bitstream starts with initial state values, each using the required 1108 // number of bits in their respective accuracy, decoded previously from 1109 // their normalized distribution. 1110 // 1111 // It starts by Literals_Length_State, followed by Offset_State, and finally 1112 // Match_Length_State." 
1113 FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset); 1114 FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset); 1115 FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset); 1116 1117 for (size_t i = 0; i < num_sequences; i++) { 1118 // Decode sequences one by one 1119 sequences[i] = decode_sequence(&states, src, &bit_offset, i==num_sequences-1); 1120 } 1121 1122 if (bit_offset != 0) { 1123 CORRUPTION(); 1124 } 1125 } 1126 1127 // Decode a single sequence and update the state 1128 static sequence_command_t decode_sequence(sequence_states_t *const states, 1129 const u8 *const src, 1130 i64 *const offset, 1131 int lastSequence) { 1132 // "Each symbol is a code in its own context, which specifies Baseline and 1133 // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw 1134 // additional bits in the same bitstream." 1135 1136 // Decode symbols, but don't update states 1137 const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state); 1138 const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state); 1139 const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state); 1140 1141 // Offset doesn't need a max value as it's not decoded using a table 1142 if (ll_code > SEQ_MAX_CODES[seq_literal_length] || 1143 ml_code > SEQ_MAX_CODES[seq_match_length]) { 1144 CORRUPTION(); 1145 } 1146 1147 // Read the interleaved bits 1148 sequence_command_t seq; 1149 // "Decoding starts by reading the Number_of_Bits required to decode Offset. 1150 // It then does the same for Match_Length, and then for Literals_Length." 
1151 seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset); 1152 1153 seq.match_length = 1154 SEQ_MATCH_LENGTH_BASELINES[ml_code] + 1155 STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset); 1156 1157 seq.literal_length = 1158 SEQ_LITERAL_LENGTH_BASELINES[ll_code] + 1159 STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset); 1160 1161 // "If it is not the last sequence in the block, the next operation is to 1162 // update states. Using the rules pre-calculated in the decoding tables, 1163 // Literals_Length_State is updated, followed by Match_Length_State, and 1164 // then Offset_State." 1165 // If the stream is complete don't read bits to update state 1166 if (!lastSequence) { 1167 FSE_update_state(&states->ll_table, &states->ll_state, src, offset); 1168 FSE_update_state(&states->ml_table, &states->ml_state, src, offset); 1169 FSE_update_state(&states->of_table, &states->of_state, src, offset); 1170 } 1171 1172 return seq; 1173 } 1174 1175 /// Given a sequence part and table mode, decode the FSE distribution 1176 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table` 1177 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1178 const seq_part_t type, const seq_mode_t mode) { 1179 // Constant arrays indexed by seq_part_t 1180 const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, 1181 SEQ_OFFSET_DEFAULT_DIST, 1182 SEQ_MATCH_LENGTH_DEFAULT_DIST}; 1183 const size_t default_distribution_lengths[] = {36, 29, 53}; 1184 const size_t default_distribution_accuracies[] = {6, 5, 6}; 1185 1186 const size_t max_accuracies[] = {9, 8, 9}; 1187 1188 if (mode != seq_repeat) { 1189 // Free old one before overwriting 1190 FSE_free_dtable(table); 1191 } 1192 1193 switch (mode) { 1194 case seq_predefined: { 1195 // "Predefined_Mode : uses a predefined distribution table." 
1196 const i16 *distribution = default_distributions[type]; 1197 const size_t symbs = default_distribution_lengths[type]; 1198 const size_t accuracy_log = default_distribution_accuracies[type]; 1199 1200 FSE_init_dtable(table, distribution, symbs, accuracy_log); 1201 break; 1202 } 1203 case seq_rle: { 1204 // "RLE_Mode : it's a single code, repeated Number_of_Sequences times." 1205 const u8 symb = IO_get_read_ptr(in, 1)[0]; 1206 FSE_init_dtable_rle(table, symb); 1207 break; 1208 } 1209 case seq_fse: { 1210 // "FSE_Compressed_Mode : standard FSE compression. A distribution table 1211 // will be present " 1212 FSE_decode_header(table, in, max_accuracies[type]); 1213 break; 1214 } 1215 case seq_repeat: 1216 // "Repeat_Mode : reuse distribution table from previous compressed 1217 // block." 1218 // Nothing to do here, table will be unchanged 1219 if (!table->symbols) { 1220 // This mode is invalid if we don't already have a table 1221 CORRUPTION(); 1222 } 1223 break; 1224 default: 1225 // Impossible, as mode is from 0-3 1226 IMPOSSIBLE(); 1227 break; 1228 } 1229 1230 } 1231 /******* END SEQUENCE DECODING ************************************************/ 1232 1233 /******* SEQUENCE EXECUTION ***************************************************/ 1234 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, 1235 const u8 *const literals, 1236 const size_t literals_len, 1237 const sequence_command_t *const sequences, 1238 const size_t num_sequences) { 1239 istream_t litstream = IO_make_istream(literals, literals_len); 1240 1241 u64 *const offset_hist = ctx->previous_offsets; 1242 size_t total_output = ctx->current_total_output; 1243 1244 for (size_t i = 0; i < num_sequences; i++) { 1245 const sequence_command_t seq = sequences[i]; 1246 { 1247 const u32 literals_size = copy_literals(seq.literal_length, &litstream, out); 1248 total_output += literals_size; 1249 } 1250 1251 size_t const offset = compute_offset(seq, offset_hist); 1252 1253 size_t const 
match_length = seq.match_length; 1254 1255 execute_match_copy(ctx, offset, match_length, total_output, out); 1256 1257 total_output += match_length; 1258 } 1259 1260 // Copy any leftover literals 1261 { 1262 size_t len = IO_istream_len(&litstream); 1263 copy_literals(len, &litstream, out); 1264 total_output += len; 1265 } 1266 1267 ctx->current_total_output = total_output; 1268 } 1269 1270 static u32 copy_literals(const size_t literal_length, istream_t *litstream, 1271 ostream_t *const out) { 1272 // If the sequence asks for more literals than are left, the 1273 // sequence must be corrupted 1274 if (literal_length > IO_istream_len(litstream)) { 1275 CORRUPTION(); 1276 } 1277 1278 u8 *const write_ptr = IO_get_write_ptr(out, literal_length); 1279 const u8 *const read_ptr = 1280 IO_get_read_ptr(litstream, literal_length); 1281 // Copy literals to output 1282 memcpy(write_ptr, read_ptr, literal_length); 1283 1284 return literal_length; 1285 } 1286 1287 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) { 1288 size_t offset; 1289 // Offsets are special, we need to handle the repeat offsets 1290 if (seq.offset <= 3) { 1291 // "The first 3 values define a repeated offset and we will call 1292 // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. 1293 // They are sorted in recency order, with Repeated_Offset1 meaning 1294 // 'most recent one'". 1295 1296 // Use 0 indexing for the array 1297 u32 idx = seq.offset - 1; 1298 if (seq.literal_length == 0) { 1299 // "There is an exception though, when current sequence's 1300 // literals length is 0. In this case, repeated offsets are 1301 // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, 1302 // Repeated_Offset2 becomes Repeated_Offset3, and 1303 // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." 
1304 idx++; 1305 } 1306 1307 if (idx == 0) { 1308 offset = offset_hist[0]; 1309 } else { 1310 // If idx == 3 then literal length was 0 and the offset was 3, 1311 // as per the exception listed above 1312 offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; 1313 1314 // If idx == 1 we don't need to modify offset_hist[2], since 1315 // we're using the second-most recent code 1316 if (idx > 1) { 1317 offset_hist[2] = offset_hist[1]; 1318 } 1319 offset_hist[1] = offset_hist[0]; 1320 offset_hist[0] = offset; 1321 } 1322 } else { 1323 // When it's not a repeat offset: 1324 // "if (Offset_Value > 3) offset = Offset_Value - 3;" 1325 offset = seq.offset - 3; 1326 1327 // Shift back history 1328 offset_hist[2] = offset_hist[1]; 1329 offset_hist[1] = offset_hist[0]; 1330 offset_hist[0] = offset; 1331 } 1332 return offset; 1333 } 1334 1335 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 1336 size_t match_length, size_t total_output, 1337 ostream_t *const out) { 1338 u8 *write_ptr = IO_get_write_ptr(out, match_length); 1339 if (total_output <= ctx->header.window_size) { 1340 // In this case offset might go back into the dictionary 1341 if (offset > total_output + ctx->dict_content_len) { 1342 // The offset goes beyond even the dictionary 1343 CORRUPTION(); 1344 } 1345 1346 if (offset > total_output) { 1347 // "The rest of the dictionary is its content. The content act 1348 // as a "past" in front of data to compress or decompress, so it 1349 // can be referenced in sequence commands." 
1350 const size_t dict_copy = 1351 MIN(offset - total_output, match_length); 1352 const size_t dict_offset = 1353 ctx->dict_content_len - (offset - total_output); 1354 1355 memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); 1356 write_ptr += dict_copy; 1357 match_length -= dict_copy; 1358 } 1359 } else if (offset > ctx->header.window_size) { 1360 CORRUPTION(); 1361 } 1362 1363 // We must copy byte by byte because the match length might be larger 1364 // than the offset 1365 // ex: if the output so far was "abc", a command with offset=3 and 1366 // match_length=6 would produce "abcabcabc" as the new output 1367 for (size_t j = 0; j < match_length; j++) { 1368 *write_ptr = *(write_ptr - offset); 1369 write_ptr++; 1370 } 1371 } 1372 /******* END SEQUENCE EXECUTION ***********************************************/ 1373 1374 /******* OUTPUT SIZE COUNTING *************************************************/ 1375 /// Get the decompressed size of an input stream so memory can be allocated in 1376 /// advance. 
1377 /// This implementation assumes `src` points to a single ZSTD-compressed frame 1378 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { 1379 istream_t in = IO_make_istream(src, src_len); 1380 1381 // get decompressed size from ZSTD frame header 1382 { 1383 const u32 magic_number = (u32)IO_read_bits(&in, 32); 1384 1385 if (magic_number == ZSTD_MAGIC_NUMBER) { 1386 // ZSTD frame 1387 frame_header_t header; 1388 parse_frame_header(&header, &in); 1389 1390 if (header.frame_content_size == 0 && !header.single_segment_flag) { 1391 // Content size not provided, we can't tell 1392 return (size_t)-1; 1393 } 1394 1395 return header.frame_content_size; 1396 } else { 1397 // not a real frame or skippable frame 1398 ERROR("ZSTD frame magic number did not match"); 1399 } 1400 } 1401 } 1402 /******* END OUTPUT SIZE COUNTING *********************************************/ 1403 1404 /******* DICTIONARY PARSING ***************************************************/ 1405 dictionary_t* create_dictionary(void) { 1406 dictionary_t* const dict = calloc(1, sizeof(dictionary_t)); 1407 if (!dict) { 1408 BAD_ALLOC(); 1409 } 1410 return dict; 1411 } 1412 1413 /// Free an allocated dictionary 1414 void free_dictionary(dictionary_t *const dict) { 1415 HUF_free_dtable(&dict->literals_dtable); 1416 FSE_free_dtable(&dict->ll_dtable); 1417 FSE_free_dtable(&dict->of_dtable); 1418 FSE_free_dtable(&dict->ml_dtable); 1419 1420 free(dict->content); 1421 1422 memset(dict, 0, sizeof(dictionary_t)); 1423 1424 free(dict); 1425 } 1426 1427 1428 #if !defined(ZDEC_NO_DICTIONARY) 1429 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes") 1430 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src"); 1431 1432 static void init_dictionary_content(dictionary_t *const dict, 1433 istream_t *const in); 1434 1435 void parse_dictionary(dictionary_t *const dict, const void *src, 1436 size_t src_len) { 1437 const u8 *byte_src = (const u8 *)src; 1438 
memset(dict, 0, sizeof(dictionary_t)); 1439 if (src == NULL) { /* cannot initialize dictionary with null src */ 1440 NULL_SRC(); 1441 } 1442 if (src_len < 8) { 1443 DICT_SIZE_ERROR(); 1444 } 1445 1446 istream_t in = IO_make_istream(byte_src, src_len); 1447 1448 const u32 magic_number = IO_read_bits(&in, 32); 1449 if (magic_number != 0xEC30A437) { 1450 // raw content dict 1451 IO_rewind_bits(&in, 32); 1452 init_dictionary_content(dict, &in); 1453 return; 1454 } 1455 1456 dict->dictionary_id = IO_read_bits(&in, 32); 1457 1458 // "Entropy_Tables : following the same format as the tables in compressed 1459 // blocks. They are stored in following order : Huffman tables for literals, 1460 // FSE table for offsets, FSE table for match lengths, and FSE table for 1461 // literals lengths. It's finally followed by 3 offset values, populating 1462 // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes 1463 // little-endian each, for a total of 12 bytes. Each recent offset must have 1464 // a value < dictionary size." 1465 decode_huf_table(&dict->literals_dtable, &in); 1466 decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse); 1467 decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse); 1468 decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse); 1469 1470 // Read in the previous offset history 1471 dict->previous_offsets[0] = IO_read_bits(&in, 32); 1472 dict->previous_offsets[1] = IO_read_bits(&in, 32); 1473 dict->previous_offsets[2] = IO_read_bits(&in, 32); 1474 1475 // Ensure the provided offsets aren't too large 1476 // "Each recent offset must have a value < dictionary size." 1477 for (int i = 0; i < 3; i++) { 1478 if (dict->previous_offsets[i] > src_len) { 1479 ERROR("Dictionary corrupted"); 1480 } 1481 } 1482 1483 // "Content : The rest of the dictionary is its content. The content act as 1484 // a "past" in front of data to compress or decompress, so it can be 1485 // referenced in sequence commands." 
1486 init_dictionary_content(dict, &in); 1487 } 1488 1489 static void init_dictionary_content(dictionary_t *const dict, 1490 istream_t *const in) { 1491 // Copy in the content 1492 dict->content_size = IO_istream_len(in); 1493 dict->content = malloc(dict->content_size); 1494 if (!dict->content) { 1495 BAD_ALLOC(); 1496 } 1497 1498 const u8 *const content = IO_get_read_ptr(in, dict->content_size); 1499 1500 memcpy(dict->content, content, dict->content_size); 1501 } 1502 1503 static void HUF_copy_dtable(HUF_dtable *const dst, 1504 const HUF_dtable *const src) { 1505 if (src->max_bits == 0) { 1506 memset(dst, 0, sizeof(HUF_dtable)); 1507 return; 1508 } 1509 1510 const size_t size = (size_t)1 << src->max_bits; 1511 dst->max_bits = src->max_bits; 1512 1513 dst->symbols = malloc(size); 1514 dst->num_bits = malloc(size); 1515 if (!dst->symbols || !dst->num_bits) { 1516 BAD_ALLOC(); 1517 } 1518 1519 memcpy(dst->symbols, src->symbols, size); 1520 memcpy(dst->num_bits, src->num_bits, size); 1521 } 1522 1523 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) { 1524 if (src->accuracy_log == 0) { 1525 memset(dst, 0, sizeof(FSE_dtable)); 1526 return; 1527 } 1528 1529 size_t size = (size_t)1 << src->accuracy_log; 1530 dst->accuracy_log = src->accuracy_log; 1531 1532 dst->symbols = malloc(size); 1533 dst->num_bits = malloc(size); 1534 dst->new_state_base = malloc(size * sizeof(u16)); 1535 if (!dst->symbols || !dst->num_bits || !dst->new_state_base) { 1536 BAD_ALLOC(); 1537 } 1538 1539 memcpy(dst->symbols, src->symbols, size); 1540 memcpy(dst->num_bits, src->num_bits, size); 1541 memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); 1542 } 1543 1544 /// A dictionary acts as initializing values for the frame context before 1545 /// decompression, so we implement it by applying it's predetermined 1546 /// tables and content to the context before beginning decompression 1547 static void frame_context_apply_dict(frame_context_t *const ctx, 
1548 const dictionary_t *const dict) { 1549 // If the content pointer is NULL then it must be an empty dict 1550 if (!dict || !dict->content) 1551 return; 1552 1553 // If the requested dictionary_id is non-zero, the correct dictionary must 1554 // be present 1555 if (ctx->header.dictionary_id != 0 && 1556 ctx->header.dictionary_id != dict->dictionary_id) { 1557 ERROR("Wrong dictionary provided"); 1558 } 1559 1560 // Copy the dict content to the context for references during sequence 1561 // execution 1562 ctx->dict_content = dict->content; 1563 ctx->dict_content_len = dict->content_size; 1564 1565 // If it's a formatted dict copy the precomputed tables in so they can 1566 // be used in the table repeat modes 1567 if (dict->dictionary_id != 0) { 1568 // Deep copy the entropy tables so they can be freed independently of 1569 // the dictionary struct 1570 HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); 1571 FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); 1572 FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); 1573 FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable); 1574 1575 // Copy the repeated offsets 1576 memcpy(ctx->previous_offsets, dict->previous_offsets, 1577 sizeof(ctx->previous_offsets)); 1578 } 1579 } 1580 1581 #else // ZDEC_NO_DICTIONARY is defined 1582 1583 static void frame_context_apply_dict(frame_context_t *const ctx, 1584 const dictionary_t *const dict) { 1585 (void)ctx; 1586 if (dict && dict->content) ERROR("dictionary not supported"); 1587 } 1588 1589 #endif 1590 /******* END DICTIONARY PARSING ***********************************************/ 1591 1592 /******* IO STREAM OPERATIONS *************************************************/ 1593 1594 /// Reads `num` bits from a bitstream, and updates the internal offset 1595 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) { 1596 if (num_bits > 64 || num_bits <= 0) { 1597 ERROR("Attempt to read an invalid number of bits"); 1598 } 1599 1600 const size_t bytes = 
(num_bits + in->bit_offset + 7) / 8; 1601 const size_t full_bytes = (num_bits + in->bit_offset) / 8; 1602 if (bytes > in->len) { 1603 INP_SIZE(); 1604 } 1605 1606 const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset); 1607 1608 in->bit_offset = (num_bits + in->bit_offset) % 8; 1609 in->ptr += full_bytes; 1610 in->len -= full_bytes; 1611 1612 return result; 1613 } 1614 1615 /// If a non-zero number of bits have been read from the current byte, advance 1616 /// the offset to the next byte 1617 static inline void IO_rewind_bits(istream_t *const in, int num_bits) { 1618 if (num_bits < 0) { 1619 ERROR("Attempting to rewind stream by a negative number of bits"); 1620 } 1621 1622 // move the offset back by `num_bits` bits 1623 const int new_offset = in->bit_offset - num_bits; 1624 // determine the number of whole bytes we have to rewind, rounding up to an 1625 // integer number (e.g. if `new_offset == -5`, `bytes == 1`) 1626 const i64 bytes = -(new_offset - 7) / 8; 1627 1628 in->ptr -= bytes; 1629 in->len += bytes; 1630 // make sure the resulting `bit_offset` is positive, as mod in C does not 1631 // convert numbers from negative to positive (e.g. -22 % 8 == -6) 1632 in->bit_offset = ((new_offset % 8) + 8) % 8; 1633 } 1634 1635 /// If the remaining bits in a byte will be unused, advance to the end of the 1636 /// byte 1637 static inline void IO_align_stream(istream_t *const in) { 1638 if (in->bit_offset != 0) { 1639 if (in->len == 0) { 1640 INP_SIZE(); 1641 } 1642 in->ptr++; 1643 in->len--; 1644 in->bit_offset = 0; 1645 } 1646 } 1647 1648 /// Write the given byte into the output stream 1649 static inline void IO_write_byte(ostream_t *const out, u8 symb) { 1650 if (out->len == 0) { 1651 OUT_SIZE(); 1652 } 1653 1654 out->ptr[0] = symb; 1655 out->ptr++; 1656 out->len--; 1657 } 1658 1659 /// Returns the number of bytes left to be read in this stream. The stream must 1660 /// be byte aligned. 
1661 static inline size_t IO_istream_len(const istream_t *const in) { 1662 return in->len; 1663 } 1664 1665 /// Returns a pointer where `len` bytes can be read, and advances the internal 1666 /// state. The stream must be byte aligned. 1667 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) { 1668 if (len > in->len) { 1669 INP_SIZE(); 1670 } 1671 if (in->bit_offset != 0) { 1672 ERROR("Attempting to operate on a non-byte aligned stream"); 1673 } 1674 const u8 *const ptr = in->ptr; 1675 in->ptr += len; 1676 in->len -= len; 1677 1678 return ptr; 1679 } 1680 /// Returns a pointer to write `len` bytes to, and advances the internal state 1681 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) { 1682 if (len > out->len) { 1683 OUT_SIZE(); 1684 } 1685 u8 *const ptr = out->ptr; 1686 out->ptr += len; 1687 out->len -= len; 1688 1689 return ptr; 1690 } 1691 1692 /// Advance the inner state by `len` bytes 1693 static inline void IO_advance_input(istream_t *const in, size_t len) { 1694 if (len > in->len) { 1695 INP_SIZE(); 1696 } 1697 if (in->bit_offset != 0) { 1698 ERROR("Attempting to operate on a non-byte aligned stream"); 1699 } 1700 1701 in->ptr += len; 1702 in->len -= len; 1703 } 1704 1705 /// Returns an `ostream_t` constructed from the given pointer and length 1706 static inline ostream_t IO_make_ostream(u8 *out, size_t len) { 1707 return (ostream_t) { out, len }; 1708 } 1709 1710 /// Returns an `istream_t` constructed from the given pointer and length 1711 static inline istream_t IO_make_istream(const u8 *in, size_t len) { 1712 return (istream_t) { in, len, 0 }; 1713 } 1714 1715 /// Returns an `istream_t` with the same base as `in`, and length `len` 1716 /// Then, advance `in` to account for the consumed bytes 1717 /// `in` must be byte aligned 1718 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) { 1719 // Consume `len` bytes of the parent stream 1720 const u8 *const ptr = IO_get_read_ptr(in, len); 1721 
1722 // Make a substream using the pointer to those `len` bytes 1723 return IO_make_istream(ptr, len); 1724 } 1725 /******* END IO STREAM OPERATIONS *********************************************/ 1726 1727 /******* BITSTREAM OPERATIONS *************************************************/ 1728 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits 1729 static inline u64 read_bits_LE(const u8 *src, const int num_bits, 1730 const size_t offset) { 1731 if (num_bits > 64) { 1732 ERROR("Attempt to read an invalid number of bits"); 1733 } 1734 1735 // Skip over bytes that aren't in range 1736 src += offset / 8; 1737 size_t bit_offset = offset % 8; 1738 u64 res = 0; 1739 1740 int shift = 0; 1741 int left = num_bits; 1742 while (left > 0) { 1743 u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); 1744 // Read the next byte, shift it to account for the offset, and then mask 1745 // out the top part if we don't need all the bits 1746 res += (((u64)*src++ >> bit_offset) & mask) << shift; 1747 shift += 8 - bit_offset; 1748 left -= 8 - bit_offset; 1749 bit_offset = 0; 1750 } 1751 1752 return res; 1753 } 1754 1755 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so 1756 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from 1757 /// `src + offset`. If the offset becomes negative, the extra bits at the 1758 /// bottom are filled in with `0` bits instead of reading from before `src`. 
1759 static inline u64 STREAM_read_bits(const u8 *const src, const int bits, 1760 i64 *const offset) { 1761 *offset = *offset - bits; 1762 size_t actual_off = *offset; 1763 size_t actual_bits = bits; 1764 // Don't actually read bits from before the start of src, so if `*offset < 1765 // 0` fix actual_off and actual_bits to reflect the quantity to read 1766 if (*offset < 0) { 1767 actual_bits += *offset; 1768 actual_off = 0; 1769 } 1770 u64 res = read_bits_LE(src, actual_bits, actual_off); 1771 1772 if (*offset < 0) { 1773 // Fill in the bottom "overflowed" bits with 0's 1774 res = -*offset >= 64 ? 0 : (res << -*offset); 1775 } 1776 return res; 1777 } 1778 /******* END BITSTREAM OPERATIONS *********************************************/ 1779 1780 /******* BIT COUNTING OPERATIONS **********************************************/ 1781 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to 1782 /// `num`, or `-1` if `num == 0`. 1783 static inline int highest_set_bit(const u64 num) { 1784 for (int i = 63; i >= 0; i--) { 1785 if (((u64)1 << i) <= num) { 1786 return i; 1787 } 1788 } 1789 return -1; 1790 } 1791 /******* END BIT COUNTING OPERATIONS ******************************************/ 1792 1793 /******* HUFFMAN PRIMITIVES ***************************************************/ 1794 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, 1795 u16 *const state, const u8 *const src, 1796 i64 *const offset) { 1797 // Look up the symbol and number of bits to read 1798 const u8 symb = dtable->symbols[*state]; 1799 const u8 bits = dtable->num_bits[*state]; 1800 const u16 rest = STREAM_read_bits(src, bits, offset); 1801 // Shift `bits` bits out of the state, keeping the low order bits that 1802 // weren't necessary to determine this symbol. Then add in the new bits 1803 // read from the stream. 
1804 *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); 1805 1806 return symb; 1807 } 1808 1809 static inline void HUF_init_state(const HUF_dtable *const dtable, 1810 u16 *const state, const u8 *const src, 1811 i64 *const offset) { 1812 // Read in a full `dtable->max_bits` bits to initialize the state 1813 const u8 bits = dtable->max_bits; 1814 *state = STREAM_read_bits(src, bits, offset); 1815 } 1816 1817 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, 1818 ostream_t *const out, 1819 istream_t *const in) { 1820 const size_t len = IO_istream_len(in); 1821 if (len == 0) { 1822 INP_SIZE(); 1823 } 1824 const u8 *const src = IO_get_read_ptr(in, len); 1825 1826 // "Each bitstream must be read backward, that is starting from the end down 1827 // to the beginning. Therefore it's necessary to know the size of each 1828 // bitstream. 1829 // 1830 // It's also necessary to know exactly which bit is the latest. This is 1831 // detected by a final bit flag : the highest bit of latest byte is a 1832 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 1833 // final-bit-flag itself is not part of the useful bitstream. Hence, the 1834 // last byte contains between 0 and 7 useful bits." 1835 const int padding = 8 - highest_set_bit(src[len - 1]); 1836 1837 // Offset starts at the end because HUF streams are read backwards 1838 i64 bit_offset = len * 8 - padding; 1839 u16 state; 1840 1841 HUF_init_state(dtable, &state, src, &bit_offset); 1842 1843 size_t symbols_written = 0; 1844 while (bit_offset > -dtable->max_bits) { 1845 // Iterate over the stream, decoding one symbol at a time 1846 IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset)); 1847 symbols_written++; 1848 } 1849 // "The process continues up to reading the required number of symbols per 1850 // stream. 
If a bitstream is not entirely and exactly consumed, hence 1851 // reaching exactly its beginning position with all bits consumed, the 1852 // decoding process is considered faulty." 1853 1854 // When all symbols have been decoded, the final state value shouldn't have 1855 // any data from the stream, so it should have "read" dtable->max_bits from 1856 // before the start of `src` 1857 // Therefore `offset`, the edge to start reading new bits at, should be 1858 // dtable->max_bits before the start of the stream 1859 if (bit_offset != -dtable->max_bits) { 1860 CORRUPTION(); 1861 } 1862 1863 return symbols_written; 1864 } 1865 1866 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, 1867 ostream_t *const out, istream_t *const in) { 1868 // "Compressed size is provided explicitly : in the 4-streams variant, 1869 // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each 1870 // value represents the compressed size of one stream, in order. The last 1871 // stream size is deducted from total compressed size and from previously 1872 // decoded stream sizes" 1873 const size_t csize1 = IO_read_bits(in, 16); 1874 const size_t csize2 = IO_read_bits(in, 16); 1875 const size_t csize3 = IO_read_bits(in, 16); 1876 1877 istream_t in1 = IO_make_sub_istream(in, csize1); 1878 istream_t in2 = IO_make_sub_istream(in, csize2); 1879 istream_t in3 = IO_make_sub_istream(in, csize3); 1880 istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in)); 1881 1882 size_t total_output = 0; 1883 // Decode each stream independently for simplicity 1884 // If we wanted to we could decode all 4 at the same time for speed, 1885 // utilizing more execution units 1886 total_output += HUF_decompress_1stream(dtable, out, &in1); 1887 total_output += HUF_decompress_1stream(dtable, out, &in2); 1888 total_output += HUF_decompress_1stream(dtable, out, &in3); 1889 total_output += HUF_decompress_1stream(dtable, out, &in4); 1890 1891 return total_output; 1892 } 1893 1894 /// 
Initializes a Huffman table using canonical Huffman codes 1895 /// For more explanation on canonical Huffman codes see 1896 /// https://www.cs.scranton.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html 1897 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get 1898 /// earlier codes) 1899 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, 1900 const int num_symbs) { 1901 memset(table, 0, sizeof(HUF_dtable)); 1902 if (num_symbs > HUF_MAX_SYMBS) { 1903 ERROR("Too many symbols for Huffman"); 1904 } 1905 1906 u8 max_bits = 0; 1907 u16 rank_count[HUF_MAX_BITS + 1]; 1908 memset(rank_count, 0, sizeof(rank_count)); 1909 1910 // Count the number of symbols for each number of bits, and determine the 1911 // depth of the tree 1912 for (int i = 0; i < num_symbs; i++) { 1913 if (bits[i] > HUF_MAX_BITS) { 1914 ERROR("Huffman table depth too large"); 1915 } 1916 max_bits = MAX(max_bits, bits[i]); 1917 rank_count[bits[i]]++; 1918 } 1919 1920 const size_t table_size = 1 << max_bits; 1921 table->max_bits = max_bits; 1922 table->symbols = malloc(table_size); 1923 table->num_bits = malloc(table_size); 1924 1925 if (!table->symbols || !table->num_bits) { 1926 free(table->symbols); 1927 free(table->num_bits); 1928 BAD_ALLOC(); 1929 } 1930 1931 // "Symbols are sorted by Weight. Within same Weight, symbols keep natural 1932 // order. Symbols with a Weight of zero are removed. Then, starting from 1933 // lowest weight, prefix codes are distributed in order." 
1934 1935 u32 rank_idx[HUF_MAX_BITS + 1]; 1936 // Initialize the starting codes for each rank (number of bits) 1937 rank_idx[max_bits] = 0; 1938 for (int i = max_bits; i >= 1; i--) { 1939 rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i)); 1940 // The entire range takes the same number of bits so we can memset it 1941 memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]); 1942 } 1943 1944 if (rank_idx[0] != table_size) { 1945 CORRUPTION(); 1946 } 1947 1948 // Allocate codes and fill in the table 1949 for (int i = 0; i < num_symbs; i++) { 1950 if (bits[i] != 0) { 1951 // Allocate a code for this symbol and set its range in the table 1952 const u16 code = rank_idx[bits[i]]; 1953 // Since the code doesn't care about the bottom `max_bits - bits[i]` 1954 // bits of state, it gets a range that spans all possible values of 1955 // the lower bits 1956 const u16 len = 1 << (max_bits - bits[i]); 1957 memset(&table->symbols[code], i, len); 1958 rank_idx[bits[i]] += len; 1959 } 1960 } 1961 } 1962 1963 static void HUF_init_dtable_usingweights(HUF_dtable *const table, 1964 const u8 *const weights, 1965 const int num_symbs) { 1966 // +1 because the last weight is not transmitted in the header 1967 if (num_symbs + 1 > HUF_MAX_SYMBS) { 1968 ERROR("Too many symbols for Huffman"); 1969 } 1970 1971 u8 bits[HUF_MAX_SYMBS]; 1972 1973 u64 weight_sum = 0; 1974 for (int i = 0; i < num_symbs; i++) { 1975 // Weights are in the same range as bit count 1976 if (weights[i] > HUF_MAX_BITS) { 1977 CORRUPTION(); 1978 } 1979 weight_sum += weights[i] > 0 ? 
(u64)1 << (weights[i] - 1) : 0; 1980 } 1981 1982 // Find the first power of 2 larger than the sum 1983 const int max_bits = highest_set_bit(weight_sum) + 1; 1984 const u64 left_over = ((u64)1 << max_bits) - weight_sum; 1985 // If the left over isn't a power of 2, the weights are invalid 1986 if (left_over & (left_over - 1)) { 1987 CORRUPTION(); 1988 } 1989 1990 // left_over is used to find the last weight as it's not transmitted 1991 // by inverting 2^(weight - 1) we can determine the value of last_weight 1992 const int last_weight = highest_set_bit(left_over) + 1; 1993 1994 for (int i = 0; i < num_symbs; i++) { 1995 // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0" 1996 bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0; 1997 } 1998 bits[num_symbs] = 1999 max_bits + 1 - last_weight; // Last weight is always non-zero 2000 2001 HUF_init_dtable(table, bits, num_symbs + 1); 2002 } 2003 2004 static void HUF_free_dtable(HUF_dtable *const dtable) { 2005 free(dtable->symbols); 2006 free(dtable->num_bits); 2007 memset(dtable, 0, sizeof(HUF_dtable)); 2008 } 2009 /******* END HUFFMAN PRIMITIVES ***********************************************/ 2010 2011 /******* FSE PRIMITIVES *******************************************************/ 2012 /// For more description of FSE see 2013 /// https://github.com/Cyan4973/FiniteStateEntropy/ 2014 2015 /// Allow a symbol to be decoded without updating state 2016 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, 2017 const u16 state) { 2018 return dtable->symbols[state]; 2019 } 2020 2021 /// Consumes bits from the input and uses the current state to determine the 2022 /// next state 2023 static inline void FSE_update_state(const FSE_dtable *const dtable, 2024 u16 *const state, const u8 *const src, 2025 i64 *const offset) { 2026 const u8 bits = dtable->num_bits[*state]; 2027 const u16 rest = STREAM_read_bits(src, bits, offset); 2028 *state = dtable->new_state_base[*state] + rest; 2029 } 2030 
2031 /// Decodes a single FSE symbol and updates the offset 2032 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, 2033 u16 *const state, const u8 *const src, 2034 i64 *const offset) { 2035 const u8 symb = FSE_peek_symbol(dtable, *state); 2036 FSE_update_state(dtable, state, src, offset); 2037 return symb; 2038 } 2039 2040 static inline void FSE_init_state(const FSE_dtable *const dtable, 2041 u16 *const state, const u8 *const src, 2042 i64 *const offset) { 2043 // Read in a full `accuracy_log` bits to initialize the state 2044 const u8 bits = dtable->accuracy_log; 2045 *state = STREAM_read_bits(src, bits, offset); 2046 } 2047 2048 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, 2049 ostream_t *const out, 2050 istream_t *const in) { 2051 const size_t len = IO_istream_len(in); 2052 if (len == 0) { 2053 INP_SIZE(); 2054 } 2055 const u8 *const src = IO_get_read_ptr(in, len); 2056 2057 // "Each bitstream must be read backward, that is starting from the end down 2058 // to the beginning. Therefore it's necessary to know the size of each 2059 // bitstream. 2060 // 2061 // It's also necessary to know exactly which bit is the latest. This is 2062 // detected by a final bit flag : the highest bit of latest byte is a 2063 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 2064 // final-bit-flag itself is not part of the useful bitstream. Hence, the 2065 // last byte contains between 0 and 7 useful bits." 2066 const int padding = 8 - highest_set_bit(src[len - 1]); 2067 i64 offset = len * 8 - padding; 2068 2069 u16 state1, state2; 2070 // "The first state (State1) encodes the even indexed symbols, and the 2071 // second (State2) encodes the odd indexes. State1 is initialized first, and 2072 // then State2, and they take turns decoding a single symbol and updating 2073 // their state." 
2074 FSE_init_state(dtable, &state1, src, &offset); 2075 FSE_init_state(dtable, &state2, src, &offset); 2076 2077 // Decode until we overflow the stream 2078 // Since we decode in reverse order, overflowing the stream is offset going 2079 // negative 2080 size_t symbols_written = 0; 2081 while (1) { 2082 // "The number of symbols to decode is determined by tracking bitStream 2083 // overflow condition: If updating state after decoding a symbol would 2084 // require more bits than remain in the stream, it is assumed the extra 2085 // bits are 0. Then, the symbols for each of the final states are 2086 // decoded and the process is complete." 2087 IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset)); 2088 symbols_written++; 2089 if (offset < 0) { 2090 // There's still a symbol to decode in state2 2091 IO_write_byte(out, FSE_peek_symbol(dtable, state2)); 2092 symbols_written++; 2093 break; 2094 } 2095 2096 IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset)); 2097 symbols_written++; 2098 if (offset < 0) { 2099 // There's still a symbol to decode in state1 2100 IO_write_byte(out, FSE_peek_symbol(dtable, state1)); 2101 symbols_written++; 2102 break; 2103 } 2104 } 2105 2106 return symbols_written; 2107 } 2108 2109 static void FSE_init_dtable(FSE_dtable *const dtable, 2110 const i16 *const norm_freqs, const int num_symbs, 2111 const int accuracy_log) { 2112 if (accuracy_log > FSE_MAX_ACCURACY_LOG) { 2113 ERROR("FSE accuracy too large"); 2114 } 2115 if (num_symbs > FSE_MAX_SYMBS) { 2116 ERROR("Too many symbols for FSE"); 2117 } 2118 2119 dtable->accuracy_log = accuracy_log; 2120 2121 const size_t size = (size_t)1 << accuracy_log; 2122 dtable->symbols = malloc(size * sizeof(u8)); 2123 dtable->num_bits = malloc(size * sizeof(u8)); 2124 dtable->new_state_base = malloc(size * sizeof(u16)); 2125 2126 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2127 BAD_ALLOC(); 2128 } 2129 2130 // Used to determine how many bits need to 
be read for each state, 2131 // and where the destination range should start 2132 // Needs to be u16 because max value is 2 * max number of symbols, 2133 // which can be larger than a byte can store 2134 u16 state_desc[FSE_MAX_SYMBS]; 2135 2136 // "Symbols are scanned in their natural order for "less than 1" 2137 // probabilities. Symbols with this probability are being attributed a 2138 // single cell, starting from the end of the table. These symbols define a 2139 // full state reset, reading Accuracy_Log bits." 2140 int high_threshold = size; 2141 for (int s = 0; s < num_symbs; s++) { 2142 // Scan for low probability symbols to put at the top 2143 if (norm_freqs[s] == -1) { 2144 dtable->symbols[--high_threshold] = s; 2145 state_desc[s] = 1; 2146 } 2147 } 2148 2149 // "All remaining symbols are sorted in their natural order. Starting from 2150 // symbol 0 and table position 0, each symbol gets attributed as many cells 2151 // as its probability. Cell allocation is spread, not linear." 2152 // Place the rest in the table 2153 const u16 step = (size >> 1) + (size >> 3) + 3; 2154 const u16 mask = size - 1; 2155 u16 pos = 0; 2156 for (int s = 0; s < num_symbs; s++) { 2157 if (norm_freqs[s] <= 0) { 2158 continue; 2159 } 2160 2161 state_desc[s] = norm_freqs[s]; 2162 2163 for (int i = 0; i < norm_freqs[s]; i++) { 2164 // Give `norm_freqs[s]` states to symbol s 2165 dtable->symbols[pos] = s; 2166 // "A position is skipped if already occupied, typically by a "less 2167 // than 1" probability symbol." 
2168 do { 2169 pos = (pos + step) & mask; 2170 } while (pos >= 2171 high_threshold); 2172 // Note: no other collision checking is necessary as `step` is 2173 // coprime to `size`, so the cycle will visit each position exactly 2174 // once 2175 } 2176 } 2177 if (pos != 0) { 2178 CORRUPTION(); 2179 } 2180 2181 // Now we can fill baseline and num bits 2182 for (size_t i = 0; i < size; i++) { 2183 u8 symbol = dtable->symbols[i]; 2184 u16 next_state_desc = state_desc[symbol]++; 2185 // Fills in the table appropriately, next_state_desc increases by symbol 2186 // over time, decreasing number of bits 2187 dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc)); 2188 // Baseline increases until the bit threshold is passed, at which point 2189 // it resets to 0 2190 dtable->new_state_base[i] = 2191 ((u16)next_state_desc << dtable->num_bits[i]) - size; 2192 } 2193 } 2194 2195 /// Decode an FSE header as defined in the Zstandard format specification and 2196 /// use the decoded frequencies to initialize a decoding table. 2197 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, 2198 const int max_accuracy_log) { 2199 // "An FSE distribution table describes the probabilities of all symbols 2200 // from 0 to the last present one (included) on a normalized scale of 1 << 2201 // Accuracy_Log . 2202 // 2203 // It's a bitstream which is read forward, in little-endian fashion. It's 2204 // not necessary to know its exact size, since it will be discovered and 2205 // reported by the decoding process. 2206 if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { 2207 ERROR("FSE accuracy too large"); 2208 } 2209 2210 // The bitstream starts by reporting on which scale it operates. 2211 // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal 2212 // and match lengths is 9, and for offsets is 8. Higher values are 2213 // considered errors." 
2214 const int accuracy_log = 5 + IO_read_bits(in, 4); 2215 if (accuracy_log > max_accuracy_log) { 2216 ERROR("FSE accuracy too large"); 2217 } 2218 2219 // "Then follows each symbol value, from 0 to last present one. The number 2220 // of bits used by each field is variable. It depends on : 2221 // 2222 // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8, 2223 // and presuming 100 probabilities points have already been distributed, the 2224 // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive). 2225 // Therefore, it must read log2sup(156) == 8 bits. 2226 // 2227 // Value decoded : small values use 1 less bit : example : Presuming values 2228 // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining 2229 // in an 8-bits field. They are used this way : first 99 values (hence from 2230 // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. " 2231 2232 i32 remaining = 1 << accuracy_log; 2233 i16 frequencies[FSE_MAX_SYMBS]; 2234 2235 int symb = 0; 2236 while (remaining > 0 && symb < FSE_MAX_SYMBS) { 2237 // Log of the number of possible values we could read 2238 int bits = highest_set_bit(remaining + 1) + 1; 2239 2240 u16 val = IO_read_bits(in, bits); 2241 2242 // Try to mask out the lower bits to see if it qualifies for the "small 2243 // value" threshold 2244 const u16 lower_mask = ((u16)1 << (bits - 1)) - 1; 2245 const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1); 2246 2247 if ((val & lower_mask) < threshold) { 2248 IO_rewind_bits(in, 1); 2249 val = val & lower_mask; 2250 } else if (val > lower_mask) { 2251 val = val - threshold; 2252 } 2253 2254 // "Probability is obtained from Value decoded by following formula : 2255 // Proba = value - 1" 2256 const i16 proba = (i16)val - 1; 2257 2258 // "It means value 0 becomes negative probability -1. -1 is a special 2259 // probability, which means "less than 1". Its effect on distribution 2260 // table is described in next paragraph. 
For the purpose of calculating 2261 // cumulated distribution, it counts as one." 2262 remaining -= proba < 0 ? -proba : proba; 2263 2264 frequencies[symb] = proba; 2265 symb++; 2266 2267 // "When a symbol has a probability of zero, it is followed by a 2-bits 2268 // repeat flag. This repeat flag tells how many probabilities of zeroes 2269 // follow the current one. It provides a number ranging from 0 to 3. If 2270 // it is a 3, another 2-bits repeat flag follows, and so on." 2271 if (proba == 0) { 2272 // Read the next two bits to see how many more 0s 2273 int repeat = IO_read_bits(in, 2); 2274 2275 while (1) { 2276 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { 2277 frequencies[symb++] = 0; 2278 } 2279 if (repeat == 3) { 2280 repeat = IO_read_bits(in, 2); 2281 } else { 2282 break; 2283 } 2284 } 2285 } 2286 } 2287 IO_align_stream(in); 2288 2289 // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding 2290 // is complete. If the last symbol makes cumulated total go above 1 << 2291 // Accuracy_Log, distribution is considered corrupted." 
2292 if (remaining != 0 || symb >= FSE_MAX_SYMBS) { 2293 CORRUPTION(); 2294 } 2295 2296 // Initialize the decoding table using the determined weights 2297 FSE_init_dtable(dtable, frequencies, symb, accuracy_log); 2298 } 2299 2300 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) { 2301 dtable->symbols = malloc(sizeof(u8)); 2302 dtable->num_bits = malloc(sizeof(u8)); 2303 dtable->new_state_base = malloc(sizeof(u16)); 2304 2305 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2306 BAD_ALLOC(); 2307 } 2308 2309 // This setup will always have a state of 0, always return symbol `symb`, 2310 // and never consume any bits 2311 dtable->symbols[0] = symb; 2312 dtable->num_bits[0] = 0; 2313 dtable->new_state_base[0] = 0; 2314 dtable->accuracy_log = 0; 2315 } 2316 2317 static void FSE_free_dtable(FSE_dtable *const dtable) { 2318 free(dtable->symbols); 2319 free(dtable->num_bits); 2320 free(dtable->new_state_base); 2321 memset(dtable, 0, sizeof(FSE_dtable)); 2322 } 2323 /******* END FSE PRIMITIVES ***************************************************/ 2324