xref: /freebsd-src/sys/contrib/zstd/doc/educational_decoder/zstd_decompress.c (revision 5ff13fbc199bdf5f0572845351c68ee5ca828e71)
10c16b537SWarner Losh /*
2*5ff13fbcSAllan Jude  * Copyright (c) Facebook, Inc.
30c16b537SWarner Losh  * All rights reserved.
40c16b537SWarner Losh  *
50c16b537SWarner Losh  * This source code is licensed under both the BSD-style license (found in the
60c16b537SWarner Losh  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
70c16b537SWarner Losh  * in the COPYING file in the root directory of this source tree).
837f1f268SConrad Meyer  * You may select, at your option, one of the above-listed licenses.
90c16b537SWarner Losh  */
100c16b537SWarner Losh 
110c16b537SWarner Losh /// Zstandard educational decoder implementation
120c16b537SWarner Losh /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
130c16b537SWarner Losh 
1437f1f268SConrad Meyer #include <stdint.h>   // uint8_t, etc.
1537f1f268SConrad Meyer #include <stdlib.h>   // malloc, free, exit
1637f1f268SConrad Meyer #include <stdio.h>    // fprintf
1737f1f268SConrad Meyer #include <string.h>   // memset, memcpy
180c16b537SWarner Losh #include "zstd_decompress.h"
190c16b537SWarner Losh 
200c16b537SWarner Losh 
2137f1f268SConrad Meyer /******* IMPORTANT CONSTANTS *********************************************/
2237f1f268SConrad Meyer 
// Zstandard frame magic number.  Quoting the format specification:
// "Magic_Number
// 4 Bytes, little-endian format. Value : 0xFD2FB528"
#define ZSTD_MAGIC_NUMBER 0xFD2FB528U

// Upper bound used for `Block_Content` (`Block_Maximum_Size`): 128 KiB,
// expressed as a size_t.
#define ZSTD_BLOCK_SIZE_MAX (1024 * (size_t)128)

// A literals section lives inside a block, so it cannot exceed the block limit
#define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX
3337f1f268SConrad Meyer 
3437f1f268SConrad Meyer 
3537f1f268SConrad Meyer /******* UTILITY MACROS AND TYPES *********************************************/
// Larger / smaller of two values.  These are plain macros, so an argument may
// be evaluated more than once; do not pass expressions with side effects.
#define MAX(a, b) ((b) < (a) ? (a) : (b))
#define MIN(a, b) ((b) > (a) ? (a) : (b))
380c16b537SWarner Losh 
// printf-style diagnostic output on stderr.  Defining ZDEC_NO_MESSAGE turns
// every MESSAGE(...) into nothing.
#ifndef ZDEC_NO_MESSAGE
// The leading "" is concatenated with the first argument, which therefore has
// to be a string literal.
#define MESSAGE(...)  fprintf(stderr, "" __VA_ARGS__)
#else
#define MESSAGE(...)
#endif
4437f1f268SConrad Meyer 
/// Error handling for this decoder: report the message through MESSAGE() and
/// terminate the process with exit(1).  A production library should propagate
/// error codes instead.
#define ERROR(s)                                                               \
    do {                                                                       \
        MESSAGE("Error: %s\n", s);                                             \
        exit(1);                                                               \
    } while (0)

/// Shorthands for the error conditions reported by the decoder
#define INP_SIZE()                                                             \
    ERROR("Input buffer smaller than it should be or "                         \
          "input is corrupted")
#define OUT_SIZE()   ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC()  ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")
590c16b537SWarner Losh 
// Short aliases for the fixed-width unsigned integer types
typedef uint8_t  u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

// Short aliases for the fixed-width signed integer types
typedef int8_t  i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
690c16b537SWarner Losh /******* END UTILITY MACROS AND TYPES *****************************************/
700c16b537SWarner Losh 
710c16b537SWarner Losh /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
720c16b537SWarner Losh /// The implementations for these functions can be found at the bottom of this
730c16b537SWarner Losh /// file.  They implement low-level functionality needed for the higher level
740c16b537SWarner Losh /// decompression functions.
750c16b537SWarner Losh 
760c16b537SWarner Losh /*** IO STREAM OPERATIONS *************/
770c16b537SWarner Losh 
780c16b537SWarner Losh /// ostream_t/istream_t are used to wrap the pointers/length data passed into
790c16b537SWarner Losh /// ZSTD_decompress, so that all IO operations are safely bounds checked
800c16b537SWarner Losh /// They are written/read forward, and reads are treated as little-endian
810c16b537SWarner Losh /// They should be used opaquely to ensure safety
820c16b537SWarner Losh typedef struct {
830c16b537SWarner Losh     u8 *ptr;
840c16b537SWarner Losh     size_t len;
850c16b537SWarner Losh } ostream_t;
860c16b537SWarner Losh 
870c16b537SWarner Losh typedef struct {
880c16b537SWarner Losh     const u8 *ptr;
890c16b537SWarner Losh     size_t len;
900c16b537SWarner Losh 
910c16b537SWarner Losh     // Input often reads a few bits at a time, so maintain an internal offset
920c16b537SWarner Losh     int bit_offset;
930c16b537SWarner Losh } istream_t;
940c16b537SWarner Losh 
950c16b537SWarner Losh /// The following two functions are the only ones that allow the istream to be
960c16b537SWarner Losh /// non-byte aligned
970c16b537SWarner Losh 
980c16b537SWarner Losh /// Reads `num` bits from a bitstream, and updates the internal offset
990c16b537SWarner Losh static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
1000c16b537SWarner Losh /// Backs-up the stream by `num` bits so they can be read again
1010c16b537SWarner Losh static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
1020c16b537SWarner Losh /// If the remaining bits in a byte will be unused, advance to the end of the
1030c16b537SWarner Losh /// byte
1040c16b537SWarner Losh static inline void IO_align_stream(istream_t *const in);
1050c16b537SWarner Losh 
1060c16b537SWarner Losh /// Write the given byte into the output stream
1070c16b537SWarner Losh static inline void IO_write_byte(ostream_t *const out, u8 symb);
1080c16b537SWarner Losh 
1090c16b537SWarner Losh /// Returns the number of bytes left to be read in this stream.  The stream must
1100c16b537SWarner Losh /// be byte aligned.
1110c16b537SWarner Losh static inline size_t IO_istream_len(const istream_t *const in);
1120c16b537SWarner Losh 
1130c16b537SWarner Losh /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
1140c16b537SWarner Losh /// was skipped.  The stream must be byte aligned.
1150c16b537SWarner Losh static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
1160c16b537SWarner Losh /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
1170c16b537SWarner Losh /// was skipped so it can be written to.
1180c16b537SWarner Losh static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
1190c16b537SWarner Losh 
1200c16b537SWarner Losh /// Advance the inner state by `len` bytes.  The stream must be byte aligned.
1210c16b537SWarner Losh static inline void IO_advance_input(istream_t *const in, size_t len);
1220c16b537SWarner Losh 
1230c16b537SWarner Losh /// Returns an `ostream_t` constructed from the given pointer and length.
1240c16b537SWarner Losh static inline ostream_t IO_make_ostream(u8 *out, size_t len);
1250c16b537SWarner Losh /// Returns an `istream_t` constructed from the given pointer and length.
1260c16b537SWarner Losh static inline istream_t IO_make_istream(const u8 *in, size_t len);
1270c16b537SWarner Losh 
1280c16b537SWarner Losh /// Returns an `istream_t` with the same base as `in`, and length `len`.
1290c16b537SWarner Losh /// Then, advance `in` to account for the consumed bytes.
1300c16b537SWarner Losh /// `in` must be byte aligned.
1310c16b537SWarner Losh static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
1320c16b537SWarner Losh /*** END IO STREAM OPERATIONS *********/
1330c16b537SWarner Losh 
1340c16b537SWarner Losh /*** BITSTREAM OPERATIONS *************/
1350c16b537SWarner Losh /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
1360c16b537SWarner Losh /// and return them interpreted as a little-endian unsigned integer.
1370c16b537SWarner Losh static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1380c16b537SWarner Losh                                const size_t offset);
1390c16b537SWarner Losh 
1400c16b537SWarner Losh /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
1410c16b537SWarner Losh /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1420c16b537SWarner Losh /// `src + offset`.  If the offset becomes negative, the extra bits at the
1430c16b537SWarner Losh /// bottom are filled in with `0` bits instead of reading from before `src`.
1440c16b537SWarner Losh static inline u64 STREAM_read_bits(const u8 *src, const int bits,
1450c16b537SWarner Losh                                    i64 *const offset);
1460c16b537SWarner Losh /*** END BITSTREAM OPERATIONS *********/
1470c16b537SWarner Losh 
1480c16b537SWarner Losh /*** BIT COUNTING OPERATIONS **********/
1490c16b537SWarner Losh /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
1500c16b537SWarner Losh static inline int highest_set_bit(const u64 num);
1510c16b537SWarner Losh /*** END BIT COUNTING OPERATIONS ******/
1520c16b537SWarner Losh 
1530c16b537SWarner Losh /*** HUFFMAN PRIMITIVES ***************/
// The table decode method needs memory exponential in the code depth, so the
// depth has to be capped
#define HUF_MAX_BITS 16

// Cap the number of symbols at 256 so that a symbol fits in a single byte
#define HUF_MAX_SYMBS 256
1590c16b537SWarner Losh 
1600c16b537SWarner Losh /// Structure containing all tables necessary for efficient Huffman decoding
1610c16b537SWarner Losh typedef struct {
1620c16b537SWarner Losh     u8 *symbols;
1630c16b537SWarner Losh     u8 *num_bits;
1640c16b537SWarner Losh     int max_bits;
1650c16b537SWarner Losh } HUF_dtable;
1660c16b537SWarner Losh 
1670c16b537SWarner Losh /// Decode a single symbol and read in enough bits to refresh the state
1680c16b537SWarner Losh static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1690c16b537SWarner Losh                                    u16 *const state, const u8 *const src,
1700c16b537SWarner Losh                                    i64 *const offset);
1710c16b537SWarner Losh /// Read in a full state's worth of bits to initialize it
1720c16b537SWarner Losh static inline void HUF_init_state(const HUF_dtable *const dtable,
1730c16b537SWarner Losh                                   u16 *const state, const u8 *const src,
1740c16b537SWarner Losh                                   i64 *const offset);
1750c16b537SWarner Losh 
1760c16b537SWarner Losh /// Decompresses a single Huffman stream, returns the number of bytes decoded.
1770c16b537SWarner Losh /// `src_len` must be the exact length of the Huffman-coded block.
1780c16b537SWarner Losh static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1790c16b537SWarner Losh                                      ostream_t *const out, istream_t *const in);
1800c16b537SWarner Losh /// Same as previous but decodes 4 streams, formatted as in the Zstandard
1810c16b537SWarner Losh /// specification.
1820c16b537SWarner Losh /// `src_len` must be the exact length of the Huffman-coded block.
1830c16b537SWarner Losh static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1840c16b537SWarner Losh                                      ostream_t *const out, istream_t *const in);
1850c16b537SWarner Losh 
1860c16b537SWarner Losh /// Initialize a Huffman decoding table using the table of bit counts provided
1870c16b537SWarner Losh static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1880c16b537SWarner Losh                             const int num_symbs);
1890c16b537SWarner Losh /// Initialize a Huffman decoding table using the table of weights provided
1900c16b537SWarner Losh /// Weights follow the definition provided in the Zstandard specification
1910c16b537SWarner Losh static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1920c16b537SWarner Losh                                          const u8 *const weights,
1930c16b537SWarner Losh                                          const int num_symbs);
1940c16b537SWarner Losh 
1950c16b537SWarner Losh /// Free the malloc'ed parts of a decoding table
1960c16b537SWarner Losh static void HUF_free_dtable(HUF_dtable *const dtable);
1970c16b537SWarner Losh /*** END HUFFMAN PRIMITIVES ***********/
1980c16b537SWarner Losh 
1990c16b537SWarner Losh /*** FSE PRIMITIVES *******************/
2000c16b537SWarner Losh /// For more description of FSE see
2010c16b537SWarner Losh /// https://github.com/Cyan4973/FiniteStateEntropy/
2020c16b537SWarner Losh 
// FSE table decoding needs memory exponential in the accuracy log, so the
// maximum accuracy has to be capped
#define FSE_MAX_ACCURACY_LOG 15
// Cap the number of symbols so that each one can be stored in a single byte
#define FSE_MAX_SYMBS 256
2070c16b537SWarner Losh 
2080c16b537SWarner Losh /// The tables needed to decode FSE encoded streams
2090c16b537SWarner Losh typedef struct {
2100c16b537SWarner Losh     u8 *symbols;
2110c16b537SWarner Losh     u8 *num_bits;
2120c16b537SWarner Losh     u16 *new_state_base;
2130c16b537SWarner Losh     int accuracy_log;
2140c16b537SWarner Losh } FSE_dtable;
2150c16b537SWarner Losh 
2160c16b537SWarner Losh /// Return the symbol for the current state
2170c16b537SWarner Losh static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
2180c16b537SWarner Losh                                  const u16 state);
2190c16b537SWarner Losh /// Read the number of bits necessary to update state, update, and shift offset
2200c16b537SWarner Losh /// back to reflect the bits read
2210c16b537SWarner Losh static inline void FSE_update_state(const FSE_dtable *const dtable,
2220c16b537SWarner Losh                                     u16 *const state, const u8 *const src,
2230c16b537SWarner Losh                                     i64 *const offset);
2240c16b537SWarner Losh 
2250c16b537SWarner Losh /// Combine peek and update: decode a symbol and update the state
2260c16b537SWarner Losh static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
2270c16b537SWarner Losh                                    u16 *const state, const u8 *const src,
2280c16b537SWarner Losh                                    i64 *const offset);
2290c16b537SWarner Losh 
2300c16b537SWarner Losh /// Read bits from the stream to initialize the state and shift offset back
2310c16b537SWarner Losh static inline void FSE_init_state(const FSE_dtable *const dtable,
2320c16b537SWarner Losh                                   u16 *const state, const u8 *const src,
2330c16b537SWarner Losh                                   i64 *const offset);
2340c16b537SWarner Losh 
2350c16b537SWarner Losh /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
2360c16b537SWarner Losh /// using an FSE decoding table.  `src_len` must be the exact length of the
2370c16b537SWarner Losh /// block.
2380c16b537SWarner Losh static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2390c16b537SWarner Losh                                           ostream_t *const out,
2400c16b537SWarner Losh                                           istream_t *const in);
2410c16b537SWarner Losh 
2420c16b537SWarner Losh /// Initialize a decoding table using normalized frequencies.
2430c16b537SWarner Losh static void FSE_init_dtable(FSE_dtable *const dtable,
2440c16b537SWarner Losh                             const i16 *const norm_freqs, const int num_symbs,
2450c16b537SWarner Losh                             const int accuracy_log);
2460c16b537SWarner Losh 
2470c16b537SWarner Losh /// Decode an FSE header as defined in the Zstandard format specification and
2480c16b537SWarner Losh /// use the decoded frequencies to initialize a decoding table.
2490c16b537SWarner Losh static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2500c16b537SWarner Losh                                 const int max_accuracy_log);
2510c16b537SWarner Losh 
2520c16b537SWarner Losh /// Initialize an FSE table that will always return the same symbol and consume
2530c16b537SWarner Losh /// 0 bits per symbol, to be used for RLE mode in sequence commands
2540c16b537SWarner Losh static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
2550c16b537SWarner Losh 
2560c16b537SWarner Losh /// Free the malloc'ed parts of a decoding table
2570c16b537SWarner Losh static void FSE_free_dtable(FSE_dtable *const dtable);
2580c16b537SWarner Losh /*** END FSE PRIMITIVES ***************/
2590c16b537SWarner Losh 
2600c16b537SWarner Losh /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
2610c16b537SWarner Losh 
2620c16b537SWarner Losh /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
2630c16b537SWarner Losh 
2640c16b537SWarner Losh /// A small structure that can be reused in various places that need to access
2650c16b537SWarner Losh /// frame header information
2660c16b537SWarner Losh typedef struct {
2670c16b537SWarner Losh     // The size of window that we need to be able to contiguously store for
2680c16b537SWarner Losh     // references
2690c16b537SWarner Losh     size_t window_size;
2700c16b537SWarner Losh     // The total output size of this compressed frame
2710c16b537SWarner Losh     size_t frame_content_size;
2720c16b537SWarner Losh 
2730c16b537SWarner Losh     // The dictionary id if this frame uses one
2740c16b537SWarner Losh     u32 dictionary_id;
2750c16b537SWarner Losh 
2760c16b537SWarner Losh     // Whether or not the content of this frame has a checksum
2770c16b537SWarner Losh     int content_checksum_flag;
2780c16b537SWarner Losh     // Whether or not the output for this frame is in a single segment
2790c16b537SWarner Losh     int single_segment_flag;
2800c16b537SWarner Losh } frame_header_t;
2810c16b537SWarner Losh 
2820c16b537SWarner Losh /// The context needed to decode blocks in a frame
2830c16b537SWarner Losh typedef struct {
2840c16b537SWarner Losh     frame_header_t header;
2850c16b537SWarner Losh 
2860c16b537SWarner Losh     // The total amount of data available for backreferences, to determine if an
2870c16b537SWarner Losh     // offset too large to be correct
2880c16b537SWarner Losh     size_t current_total_output;
2890c16b537SWarner Losh 
2900c16b537SWarner Losh     const u8 *dict_content;
2910c16b537SWarner Losh     size_t dict_content_len;
2920c16b537SWarner Losh 
2930c16b537SWarner Losh     // Entropy encoding tables so they can be repeated by future blocks instead
2940c16b537SWarner Losh     // of retransmitting
2950c16b537SWarner Losh     HUF_dtable literals_dtable;
2960c16b537SWarner Losh     FSE_dtable ll_dtable;
2970c16b537SWarner Losh     FSE_dtable ml_dtable;
2980c16b537SWarner Losh     FSE_dtable of_dtable;
2990c16b537SWarner Losh 
3000c16b537SWarner Losh     // The last 3 offsets for the special "repeat offsets".
3010c16b537SWarner Losh     u64 previous_offsets[3];
3020c16b537SWarner Losh } frame_context_t;
3030c16b537SWarner Losh 
3040c16b537SWarner Losh /// The decoded contents of a dictionary so that it doesn't have to be repeated
3050c16b537SWarner Losh /// for each frame that uses it
3060c16b537SWarner Losh struct dictionary_s {
3070c16b537SWarner Losh     // Entropy tables
3080c16b537SWarner Losh     HUF_dtable literals_dtable;
3090c16b537SWarner Losh     FSE_dtable ll_dtable;
3100c16b537SWarner Losh     FSE_dtable ml_dtable;
3110c16b537SWarner Losh     FSE_dtable of_dtable;
3120c16b537SWarner Losh 
3130c16b537SWarner Losh     // Raw content for backreferences
3140c16b537SWarner Losh     u8 *content;
3150c16b537SWarner Losh     size_t content_size;
3160c16b537SWarner Losh 
3170c16b537SWarner Losh     // Offset history to prepopulate the frame's history
3180c16b537SWarner Losh     u64 previous_offsets[3];
3190c16b537SWarner Losh 
3200c16b537SWarner Losh     u32 dictionary_id;
3210c16b537SWarner Losh };
3220c16b537SWarner Losh 
3230c16b537SWarner Losh /// A tuple containing the parts necessary to decode and execute a ZSTD sequence
3240c16b537SWarner Losh /// command
3250c16b537SWarner Losh typedef struct {
3260c16b537SWarner Losh     u32 literal_length;
3270c16b537SWarner Losh     u32 match_length;
3280c16b537SWarner Losh     u32 offset;
3290c16b537SWarner Losh } sequence_command_t;
3300c16b537SWarner Losh 
/// The decoder works top-down: it starts at the high level (Zstd frames) and
/// works down to the lower, more technical levels such as blocks, literals,
/// and sequences.  The high-level functions roughly follow the outline of the
3340c16b537SWarner Losh /// format specification:
3350c16b537SWarner Losh /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
3360c16b537SWarner Losh 
3370c16b537SWarner Losh /// Before the implementation of each high-level function declared here, the
3380c16b537SWarner Losh /// prototypes for their helper functions are defined and explained
3390c16b537SWarner Losh 
3400c16b537SWarner Losh /// Decode a single Zstd frame, or error if the input is not a valid frame.
3410c16b537SWarner Losh /// Accepts a dict argument, which may be NULL indicating no dictionary.
3420c16b537SWarner Losh /// See
3430c16b537SWarner Losh /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
3440c16b537SWarner Losh static void decode_frame(ostream_t *const out, istream_t *const in,
3450c16b537SWarner Losh                          const dictionary_t *const dict);
3460c16b537SWarner Losh 
3470c16b537SWarner Losh // Decode data in a compressed block
3480c16b537SWarner Losh static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
3490c16b537SWarner Losh                              istream_t *const in);
3500c16b537SWarner Losh 
3510c16b537SWarner Losh // Decode the literals section of a block
3520c16b537SWarner Losh static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
3530c16b537SWarner Losh                               u8 **const literals);
3540c16b537SWarner Losh 
3550c16b537SWarner Losh // Decode the sequences part of a block
3560c16b537SWarner Losh static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
3570c16b537SWarner Losh                                sequence_command_t **const sequences);
3580c16b537SWarner Losh 
3590c16b537SWarner Losh // Execute the decoded sequences on the literals block
3600c16b537SWarner Losh static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
3610c16b537SWarner Losh                               const u8 *const literals,
3620c16b537SWarner Losh                               const size_t literals_len,
3630c16b537SWarner Losh                               const sequence_command_t *const sequences,
3640c16b537SWarner Losh                               const size_t num_sequences);
3650c16b537SWarner Losh 
3660c16b537SWarner Losh // Copies literals and returns the total literal length that was copied
3670c16b537SWarner Losh static u32 copy_literals(const size_t seq, istream_t *litstream,
3680c16b537SWarner Losh                          ostream_t *const out);
3690c16b537SWarner Losh 
3700c16b537SWarner Losh // Given an offset code from a sequence command (either an actual offset value
3712b9c00cbSConrad Meyer // or an index for previous offset), computes the correct offset and updates
3720c16b537SWarner Losh // the offset history
3730c16b537SWarner Losh static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
3740c16b537SWarner Losh 
3750c16b537SWarner Losh // Given an offset, match length, and total output, as well as the frame
3760c16b537SWarner Losh // context for the dictionary, determines if the dictionary is used and
3770c16b537SWarner Losh // executes the copy operation
3780c16b537SWarner Losh static void execute_match_copy(frame_context_t *const ctx, size_t offset,
3790c16b537SWarner Losh                               size_t match_length, size_t total_output,
3800c16b537SWarner Losh                               ostream_t *const out);
3810c16b537SWarner Losh 
3820c16b537SWarner Losh /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
3830c16b537SWarner Losh 
ZSTD_decompress(void * const dst,const size_t dst_len,const void * const src,const size_t src_len)3840c16b537SWarner Losh size_t ZSTD_decompress(void *const dst, const size_t dst_len,
3850c16b537SWarner Losh                        const void *const src, const size_t src_len) {
38637f1f268SConrad Meyer     dictionary_t* const uninit_dict = create_dictionary();
3870c16b537SWarner Losh     size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
3880c16b537SWarner Losh                                                          src_len, uninit_dict);
3890c16b537SWarner Losh     free_dictionary(uninit_dict);
3900c16b537SWarner Losh     return decomp_size;
3910c16b537SWarner Losh }
3920c16b537SWarner Losh 
/// Decompresses one frame from `src` (`src_len` bytes) into `dst` (`dst_len`
/// bytes), handing `parsed_dict` on to the frame decoder.
/// Returns how far the output position has moved from `dst`, i.e. the number
/// of bytes written.
size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
                                 const void *const src, const size_t src_len,
                                 dictionary_t* parsed_dict) {
    u8 *const dst_begin = (u8 *)dst;

    // Wrap the raw buffers in streams before any decoding takes place
    istream_t in = IO_make_istream(src, src_len);
    ostream_t out = IO_make_ostream(dst_begin, dst_len);

    // From the format specification:
    // "A content compressed by Zstandard is transformed into a Zstandard frame.
    // Multiple frames can be appended into a single file or stream. A frame is
    // totally independent, has a defined beginning and end, and a set of
    // parameters which tells the decoder how to decompress it."

    /* only a single frame is decoded by this decoder */
    decode_frame(&out, &in, parsed_dict);

    return (size_t)(out.ptr - dst_begin);
}
4100c16b537SWarner Losh 
4110c16b537SWarner Losh /******* FRAME DECODING ******************************************************/
4120c16b537SWarner Losh 
4130c16b537SWarner Losh static void decode_data_frame(ostream_t *const out, istream_t *const in,
4140c16b537SWarner Losh                               const dictionary_t *const dict);
4150c16b537SWarner Losh static void init_frame_context(frame_context_t *const context,
4160c16b537SWarner Losh                                istream_t *const in,
4170c16b537SWarner Losh                                const dictionary_t *const dict);
4180c16b537SWarner Losh static void free_frame_context(frame_context_t *const context);
4190c16b537SWarner Losh static void parse_frame_header(frame_header_t *const header,
4200c16b537SWarner Losh                                istream_t *const in);
4210c16b537SWarner Losh static void frame_context_apply_dict(frame_context_t *const ctx,
4220c16b537SWarner Losh                                      const dictionary_t *const dict);
4230c16b537SWarner Losh 
4240c16b537SWarner Losh static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
4250c16b537SWarner Losh                             istream_t *const in);
4260c16b537SWarner Losh 
/// Reads the 32-bit magic number from `in` and, if it identifies a Zstandard
/// frame, decodes that frame into `out` using `dict`.  Any other magic number
/// (including a skippable frame) is reported through ERROR(), which exits.
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict) {
    const u32 frame_magic = (u32)IO_read_bits(in, 32);

    if (frame_magic != ZSTD_MAGIC_NUMBER) {
        // not a real frame or a skippable frame; ERROR() does not return
        ERROR("Tried to decode non-ZSTD frame");
    }

    // ZSTD frame
    decode_data_frame(out, in, dict);
}
4400c16b537SWarner Losh 
4410c16b537SWarner Losh /// Decode a frame that contains compressed data.  Not all frames do as there
4420c16b537SWarner Losh /// are skippable frames.
4430c16b537SWarner Losh /// See
4440c16b537SWarner Losh /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
decode_data_frame(ostream_t * const out,istream_t * const in,const dictionary_t * const dict)4450c16b537SWarner Losh static void decode_data_frame(ostream_t *const out, istream_t *const in,
4460c16b537SWarner Losh                               const dictionary_t *const dict) {
4470c16b537SWarner Losh     frame_context_t ctx;
4480c16b537SWarner Losh 
4490c16b537SWarner Losh     // Initialize the context that needs to be carried from block to block
4500c16b537SWarner Losh     init_frame_context(&ctx, in, dict);
4510c16b537SWarner Losh 
4520c16b537SWarner Losh     if (ctx.header.frame_content_size != 0 &&
4530c16b537SWarner Losh         ctx.header.frame_content_size > out->len) {
4540c16b537SWarner Losh         OUT_SIZE();
4550c16b537SWarner Losh     }
4560c16b537SWarner Losh 
4570c16b537SWarner Losh     decompress_data(&ctx, out, in);
4580c16b537SWarner Losh 
4590c16b537SWarner Losh     free_frame_context(&ctx);
4600c16b537SWarner Losh }
4610c16b537SWarner Losh 
4620c16b537SWarner Losh /// Takes the information provided in the header and dictionary, and initializes
4630c16b537SWarner Losh /// the context for this frame
init_frame_context(frame_context_t * const context,istream_t * const in,const dictionary_t * const dict)4640c16b537SWarner Losh static void init_frame_context(frame_context_t *const context,
4650c16b537SWarner Losh                                istream_t *const in,
4660c16b537SWarner Losh                                const dictionary_t *const dict) {
4670c16b537SWarner Losh     // Most fields in context are correct when initialized to 0
4680c16b537SWarner Losh     memset(context, 0, sizeof(frame_context_t));
4690c16b537SWarner Losh 
4700c16b537SWarner Losh     // Parse data from the frame header
4710c16b537SWarner Losh     parse_frame_header(&context->header, in);
4720c16b537SWarner Losh 
4730c16b537SWarner Losh     // Set up the offset history for the repeat offset commands
4740c16b537SWarner Losh     context->previous_offsets[0] = 1;
4750c16b537SWarner Losh     context->previous_offsets[1] = 4;
4760c16b537SWarner Losh     context->previous_offsets[2] = 8;
4770c16b537SWarner Losh 
4780c16b537SWarner Losh     // Apply details from the dict if it exists
4790c16b537SWarner Losh     frame_context_apply_dict(context, dict);
4800c16b537SWarner Losh }
4810c16b537SWarner Losh 
/// Releases the Huffman and FSE decoding tables held by `context`, then zeroes
/// the whole context
static void free_frame_context(frame_context_t *const context) {
    FSE_dtable *const fse_tables[] = {
        &context->ll_dtable,
        &context->ml_dtable,
        &context->of_dtable,
    };

    HUF_free_dtable(&context->literals_dtable);

    for (size_t i = 0; i < sizeof(fse_tables) / sizeof(fse_tables[0]); i++) {
        FSE_free_dtable(fse_tables[i]);
    }

    // Leave nothing behind that still refers to the released tables
    memset(context, 0, sizeof(*context));
}
4910c16b537SWarner Losh 
/// Parses the Frame_Header at the current position of `in` and stores the
/// result in `header`.
///
/// Fields written: `single_segment_flag`, `content_checksum_flag`,
/// `window_size`, `dictionary_id` (0 when absent) and `frame_content_size`
/// (0 when absent).
///
/// Calls CORRUPTION() if the reserved bit is set, or if the declared window
/// size cannot be represented in a `size_t` on this platform.
static void parse_frame_header(frame_header_t *const header,
                               istream_t *const in) {
    // "The first header's byte is called the Frame_Header_Descriptor. It tells
    // which other fields are present. Decoding this byte is enough to tell the
    // size of Frame_Header.
    //
    // Bit number   Field name
    // 7-6  Frame_Content_Size_flag
    // 5    Single_Segment_flag
    // 4    Unused_bit
    // 3    Reserved_bit
    // 2    Content_Checksum_flag
    // 1-0  Dictionary_ID_flag"
    const u8 descriptor = (u8)IO_read_bits(in, 8);

    // decode frame header descriptor into flags
    const u8 frame_content_size_flag = descriptor >> 6;
    const u8 single_segment_flag = (descriptor >> 5) & 1;
    const u8 reserved_bit = (descriptor >> 3) & 1;
    const u8 content_checksum_flag = (descriptor >> 2) & 1;
    const u8 dictionary_id_flag = descriptor & 3;

    if (reserved_bit != 0) {
        CORRUPTION();
    }

    header->single_segment_flag = single_segment_flag;
    header->content_checksum_flag = content_checksum_flag;

    // decode window size
    if (!single_segment_flag) {
        // "Provides guarantees on maximum back-reference distance that will be
        // used within compressed data. This information is important for
        // decoders to allocate enough memory.
        //
        // Bit numbers  7-3         2-0
        // Field name   Exponent    Mantissa"
        const u8 window_descriptor = (u8)IO_read_bits(in, 8);
        const u8 exponent = window_descriptor >> 3;
        const u8 mantissa = window_descriptor & 7;

        // Use the algorithm from the specification to compute window size
        // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
        //
        // `exponent` is a 5 bit field, so the shift count can reach 41. Do the
        // arithmetic in 64 bits: shifting a narrower size_t by that much would
        // be undefined behaviour.
        const uint64_t window_base = (uint64_t)1 << (10 + exponent);
        const uint64_t window_add = (window_base / 8) * mantissa;
        const uint64_t window_size = window_base + window_add;

        // A window that does not fit in size_t cannot be handled on this
        // platform; there is no dedicated error for that, so reject the frame
        if (window_size > SIZE_MAX) {
            CORRUPTION();
        }
        header->window_size = (size_t)window_size;
    }

    // decode dictionary id if it exists
    if (dictionary_id_flag) {
        // "This is a variable size field, which contains the ID of the
        // dictionary required to properly decode the frame. Note that this
        // field is optional. When it's not present, it's up to the caller to
        // make sure it uses the correct dictionary. Format is little-endian."
        const int bytes_array[] = {0, 1, 2, 4};
        const int bytes = bytes_array[dictionary_id_flag];

        header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
    } else {
        header->dictionary_id = 0;
    }

    // decode frame content size if it exists
    if (single_segment_flag || frame_content_size_flag) {
        // "This is the original (uncompressed) size. This information is
        // optional. The Field_Size is provided according to value of
        // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
        // present), 1, 2, 4 or 8 bytes. Format is little-endian."
        //
        // if frame_content_size_flag == 0 but single_segment_flag is set, we
        // still have a 1 byte field
        const int bytes_array[] = {1, 2, 4, 8};
        const int bytes = bytes_array[frame_content_size_flag];

        header->frame_content_size = IO_read_bits(in, bytes * 8);
        if (bytes == 2) {
            // "When Field_Size is 2, the offset of 256 is added."
            header->frame_content_size += 256;
        }
    } else {
        header->frame_content_size = 0;
    }

    if (single_segment_flag) {
        // "The Window_Descriptor byte is optional. It is absent when
        // Single_Segment_flag is set. In this case, the maximum back-reference
        // distance is the content size itself, which can be any value from 1 to
        // 2^64-1 bytes (16 EB)."
        header->window_size = header->frame_content_size;
    }
}
5830c16b537SWarner Losh 
5840c16b537SWarner Losh /// Decompress the data from a frame block by block
decompress_data(frame_context_t * const ctx,ostream_t * const out,istream_t * const in)5850c16b537SWarner Losh static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
5860c16b537SWarner Losh                             istream_t *const in) {
5870c16b537SWarner Losh     // "A frame encapsulates one or multiple blocks. Each block can be
5880c16b537SWarner Losh     // compressed or not, and has a guaranteed maximum content size, which
5890c16b537SWarner Losh     // depends on frame parameters. Unlike frames, each block depends on
5900c16b537SWarner Losh     // previous blocks for proper decoding. However, each block can be
5910c16b537SWarner Losh     // decompressed without waiting for its successor, allowing streaming
5920c16b537SWarner Losh     // operations."
5930c16b537SWarner Losh     int last_block = 0;
5940c16b537SWarner Losh     do {
5950c16b537SWarner Losh         // "Last_Block
5960c16b537SWarner Losh         //
5970c16b537SWarner Losh         // The lowest bit signals if this block is the last one. Frame ends
5980c16b537SWarner Losh         // right after this block.
5990c16b537SWarner Losh         //
6000c16b537SWarner Losh         // Block_Type and Block_Size
6010c16b537SWarner Losh         //
6020c16b537SWarner Losh         // The next 2 bits represent the Block_Type, while the remaining 21 bits
6030c16b537SWarner Losh         // represent the Block_Size. Format is little-endian."
6049cbefe25SConrad Meyer         last_block = (int)IO_read_bits(in, 1);
6059cbefe25SConrad Meyer         const int block_type = (int)IO_read_bits(in, 2);
6060c16b537SWarner Losh         const size_t block_len = IO_read_bits(in, 21);
6070c16b537SWarner Losh 
6080c16b537SWarner Losh         switch (block_type) {
6090c16b537SWarner Losh         case 0: {
6100c16b537SWarner Losh             // "Raw_Block - this is an uncompressed block. Block_Size is the
6110c16b537SWarner Losh             // number of bytes to read and copy."
6120c16b537SWarner Losh             const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
6130c16b537SWarner Losh             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
6140c16b537SWarner Losh 
6150c16b537SWarner Losh             // Copy the raw data into the output
6160c16b537SWarner Losh             memcpy(write_ptr, read_ptr, block_len);
6170c16b537SWarner Losh 
6180c16b537SWarner Losh             ctx->current_total_output += block_len;
6190c16b537SWarner Losh             break;
6200c16b537SWarner Losh         }
6210c16b537SWarner Losh         case 1: {
6220c16b537SWarner Losh             // "RLE_Block - this is a single byte, repeated N times. In which
6230c16b537SWarner Losh             // case, Block_Size is the size to regenerate, while the
6240c16b537SWarner Losh             // "compressed" block is just 1 byte (the byte to repeat)."
6250c16b537SWarner Losh             const u8 *const read_ptr = IO_get_read_ptr(in, 1);
6260c16b537SWarner Losh             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
6270c16b537SWarner Losh 
6280c16b537SWarner Losh             // Copy `block_len` copies of `read_ptr[0]` to the output
6290c16b537SWarner Losh             memset(write_ptr, read_ptr[0], block_len);
6300c16b537SWarner Losh 
6310c16b537SWarner Losh             ctx->current_total_output += block_len;
6320c16b537SWarner Losh             break;
6330c16b537SWarner Losh         }
6340c16b537SWarner Losh         case 2: {
6350c16b537SWarner Losh             // "Compressed_Block - this is a Zstandard compressed block,
6360c16b537SWarner Losh             // detailed in another section of this specification. Block_Size is
6370c16b537SWarner Losh             // the compressed size.
6380c16b537SWarner Losh 
6390c16b537SWarner Losh             // Create a sub-stream for the block
6400c16b537SWarner Losh             istream_t block_stream = IO_make_sub_istream(in, block_len);
6410c16b537SWarner Losh             decompress_block(ctx, out, &block_stream);
6420c16b537SWarner Losh             break;
6430c16b537SWarner Losh         }
6440c16b537SWarner Losh         case 3:
6450c16b537SWarner Losh             // "Reserved - this is not a block. This value cannot be used with
6460c16b537SWarner Losh             // current version of this specification."
6470c16b537SWarner Losh             CORRUPTION();
6480c16b537SWarner Losh             break;
6490c16b537SWarner Losh         default:
6500c16b537SWarner Losh             IMPOSSIBLE();
6510c16b537SWarner Losh         }
6520c16b537SWarner Losh     } while (!last_block);
6530c16b537SWarner Losh 
6540c16b537SWarner Losh     if (ctx->header.content_checksum_flag) {
6550c16b537SWarner Losh         // This program does not support checking the checksum, so skip over it
6560c16b537SWarner Losh         // if it's present
6570c16b537SWarner Losh         IO_advance_input(in, 4);
6580c16b537SWarner Losh     }
6590c16b537SWarner Losh }
6600c16b537SWarner Losh /******* END FRAME DECODING ***************************************************/
6610c16b537SWarner Losh 
6620c16b537SWarner Losh /******* BLOCK DECOMPRESSION **************************************************/
/// Decodes one compressed block from `in` and writes the regenerated data to
/// `out`. The literals and sequence buffers produced along the way are owned
/// by this function and freed before it returns.
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in) {
    // "A compressed block consists of 2 sections :
    //
    // Literals_Section
    // Sequences_Section"
    u8 *lit_buf = NULL;
    sequence_command_t *seq_buf = NULL;

    // Step 1: the literals section comes first in the block
    const size_t lit_len = decode_literals(ctx, in, &lit_buf);

    // Step 2: the sequences section follows it
    const size_t seq_count = decode_sequences(ctx, in, &seq_buf);

    // Step 3: replay the sequence commands over the literals to produce output
    execute_sequences(ctx, out, lit_buf, lit_len, seq_buf, seq_count);

    free(lit_buf);
    free(seq_buf);
}
6860c16b537SWarner Losh /******* END BLOCK DECOMPRESSION **********************************************/
6870c16b537SWarner Losh 
6880c16b537SWarner Losh /******* LITERALS DECODING ****************************************************/
6890c16b537SWarner Losh static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
6900c16b537SWarner Losh                                      const int block_type,
6910c16b537SWarner Losh                                      const int size_format);
6920c16b537SWarner Losh static size_t decode_literals_compressed(frame_context_t *const ctx,
6930c16b537SWarner Losh                                          istream_t *const in,
6940c16b537SWarner Losh                                          u8 **const literals,
6950c16b537SWarner Losh                                          const int block_type,
6960c16b537SWarner Losh                                          const int size_format);
6970c16b537SWarner Losh static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
6980c16b537SWarner Losh static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
6990c16b537SWarner Losh                                     int *const num_symbs);
7000c16b537SWarner Losh 
/// Reads the start of the Literals_Section_Header from `in` and dispatches to
/// the decoder matching the literals block type. The decoded literals are
/// returned through `*literals`; the return value is the callee's size result.
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals) {
    // "Literals can be stored uncompressed or compressed using Huffman prefix
    // codes. When compressed, an optional tree description can be present,
    // followed by 1 or 4 streams."
    //
    // "Literals_Section_Header
    //
    // Header is in charge of describing how literals are packed. It's a
    // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
    // little-endian convention."
    //
    // "Literals_Block_Type
    //
    // This field uses 2 lowest bits of first byte, describing 4 different block
    // types"
    const int block_type = (int)IO_read_bits(in, 2);
    // size_format takes between 1 and 2 bits; two are read here and the
    // callee sorts out the 1 bit case
    const int size_format = (int)IO_read_bits(in, 2);

    if (block_type > 1) {
        // Types 2 and 3: Huffman compressed literals
        return decode_literals_compressed(ctx, in, literals, block_type,
                                          size_format);
    }

    // Types 0 and 1: raw or RLE literals
    return decode_literals_simple(in, literals, block_type, size_format);
}
7320c16b537SWarner Losh 
7330c16b537SWarner Losh /// Decodes literals blocks in raw or RLE form
decode_literals_simple(istream_t * const in,u8 ** const literals,const int block_type,const int size_format)7340c16b537SWarner Losh static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
7350c16b537SWarner Losh                                      const int block_type,
7360c16b537SWarner Losh                                      const int size_format) {
7370c16b537SWarner Losh     size_t size;
7380c16b537SWarner Losh     switch (size_format) {
7390c16b537SWarner Losh     // These cases are in the form ?0
7400c16b537SWarner Losh     // In this case, the ? bit is actually part of the size field
7410c16b537SWarner Losh     case 0:
7420c16b537SWarner Losh     case 2:
7430c16b537SWarner Losh         // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
7440c16b537SWarner Losh         IO_rewind_bits(in, 1);
7450c16b537SWarner Losh         size = IO_read_bits(in, 5);
7460c16b537SWarner Losh         break;
7470c16b537SWarner Losh     case 1:
7480c16b537SWarner Losh         // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
7490c16b537SWarner Losh         size = IO_read_bits(in, 12);
7500c16b537SWarner Losh         break;
7510c16b537SWarner Losh     case 3:
7520c16b537SWarner Losh         // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
7530c16b537SWarner Losh         size = IO_read_bits(in, 20);
7540c16b537SWarner Losh         break;
7550c16b537SWarner Losh     default:
7560c16b537SWarner Losh         // Size format is in range 0-3
7570c16b537SWarner Losh         IMPOSSIBLE();
7580c16b537SWarner Losh     }
7590c16b537SWarner Losh 
7600c16b537SWarner Losh     if (size > MAX_LITERALS_SIZE) {
7610c16b537SWarner Losh         CORRUPTION();
7620c16b537SWarner Losh     }
7630c16b537SWarner Losh 
7640c16b537SWarner Losh     *literals = malloc(size);
7650c16b537SWarner Losh     if (!*literals) {
7660c16b537SWarner Losh         BAD_ALLOC();
7670c16b537SWarner Losh     }
7680c16b537SWarner Losh 
7690c16b537SWarner Losh     switch (block_type) {
7700c16b537SWarner Losh     case 0: {
7710c16b537SWarner Losh         // "Raw_Literals_Block - Literals are stored uncompressed."
7720c16b537SWarner Losh         const u8 *const read_ptr = IO_get_read_ptr(in, size);
7730c16b537SWarner Losh         memcpy(*literals, read_ptr, size);
7740c16b537SWarner Losh         break;
7750c16b537SWarner Losh     }
7760c16b537SWarner Losh     case 1: {
7770c16b537SWarner Losh         // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
7780c16b537SWarner Losh         const u8 *const read_ptr = IO_get_read_ptr(in, 1);
7790c16b537SWarner Losh         memset(*literals, read_ptr[0], size);
7800c16b537SWarner Losh         break;
7810c16b537SWarner Losh     }
7820c16b537SWarner Losh     default:
7830c16b537SWarner Losh         IMPOSSIBLE();
7840c16b537SWarner Losh     }
7850c16b537SWarner Losh 
7860c16b537SWarner Losh     return size;
7870c16b537SWarner Losh }
7880c16b537SWarner Losh 
7890c16b537SWarner Losh /// Decodes Huffman compressed literals
decode_literals_compressed(frame_context_t * const ctx,istream_t * const in,u8 ** const literals,const int block_type,const int size_format)7900c16b537SWarner Losh static size_t decode_literals_compressed(frame_context_t *const ctx,
7910c16b537SWarner Losh                                          istream_t *const in,
7920c16b537SWarner Losh                                          u8 **const literals,
7930c16b537SWarner Losh                                          const int block_type,
7940c16b537SWarner Losh                                          const int size_format) {
7950c16b537SWarner Losh     size_t regenerated_size, compressed_size;
7960c16b537SWarner Losh     // Only size_format=0 has 1 stream, so default to 4
7970c16b537SWarner Losh     int num_streams = 4;
7980c16b537SWarner Losh     switch (size_format) {
7990c16b537SWarner Losh     case 0:
8000c16b537SWarner Losh         // "A single stream. Both Compressed_Size and Regenerated_Size use 10
8010c16b537SWarner Losh         // bits (0-1023)."
8020c16b537SWarner Losh         num_streams = 1;
8030c16b537SWarner Losh     // Fall through as it has the same size format
8049cbefe25SConrad Meyer         /* fallthrough */
8050c16b537SWarner Losh     case 1:
8060c16b537SWarner Losh         // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
8070c16b537SWarner Losh         // (0-1023)."
8080c16b537SWarner Losh         regenerated_size = IO_read_bits(in, 10);
8090c16b537SWarner Losh         compressed_size = IO_read_bits(in, 10);
8100c16b537SWarner Losh         break;
8110c16b537SWarner Losh     case 2:
8120c16b537SWarner Losh         // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
8130c16b537SWarner Losh         // (0-16383)."
8140c16b537SWarner Losh         regenerated_size = IO_read_bits(in, 14);
8150c16b537SWarner Losh         compressed_size = IO_read_bits(in, 14);
8160c16b537SWarner Losh         break;
8170c16b537SWarner Losh     case 3:
8180c16b537SWarner Losh         // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
8190c16b537SWarner Losh         // (0-262143)."
8200c16b537SWarner Losh         regenerated_size = IO_read_bits(in, 18);
8210c16b537SWarner Losh         compressed_size = IO_read_bits(in, 18);
8220c16b537SWarner Losh         break;
8230c16b537SWarner Losh     default:
8240c16b537SWarner Losh         // Impossible
8250c16b537SWarner Losh         IMPOSSIBLE();
8260c16b537SWarner Losh     }
8279cbefe25SConrad Meyer     if (regenerated_size > MAX_LITERALS_SIZE) {
8280c16b537SWarner Losh         CORRUPTION();
8290c16b537SWarner Losh     }
8300c16b537SWarner Losh 
8310c16b537SWarner Losh     *literals = malloc(regenerated_size);
8320c16b537SWarner Losh     if (!*literals) {
8330c16b537SWarner Losh         BAD_ALLOC();
8340c16b537SWarner Losh     }
8350c16b537SWarner Losh 
8360c16b537SWarner Losh     ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
8370c16b537SWarner Losh     istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
8380c16b537SWarner Losh 
8390c16b537SWarner Losh     if (block_type == 2) {
8400c16b537SWarner Losh         // Decode the provided Huffman table
8410c16b537SWarner Losh         // "This section is only present when Literals_Block_Type type is
8420c16b537SWarner Losh         // Compressed_Literals_Block (2)."
8430c16b537SWarner Losh 
8440c16b537SWarner Losh         HUF_free_dtable(&ctx->literals_dtable);
8450c16b537SWarner Losh         decode_huf_table(&ctx->literals_dtable, &huf_stream);
8460c16b537SWarner Losh     } else {
8470c16b537SWarner Losh         // If the previous Huffman table is being repeated, ensure it exists
8480c16b537SWarner Losh         if (!ctx->literals_dtable.symbols) {
8490c16b537SWarner Losh             CORRUPTION();
8500c16b537SWarner Losh         }
8510c16b537SWarner Losh     }
8520c16b537SWarner Losh 
8530c16b537SWarner Losh     size_t symbols_decoded;
8540c16b537SWarner Losh     if (num_streams == 1) {
8550c16b537SWarner Losh         symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
8560c16b537SWarner Losh     } else {
8570c16b537SWarner Losh         symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
8580c16b537SWarner Losh     }
8590c16b537SWarner Losh 
8600c16b537SWarner Losh     if (symbols_decoded != regenerated_size) {
8610c16b537SWarner Losh         CORRUPTION();
8620c16b537SWarner Losh     }
8630c16b537SWarner Losh 
8640c16b537SWarner Losh     return regenerated_size;
8650c16b537SWarner Losh }
8660c16b537SWarner Losh 
8670c16b537SWarner Losh // Decode the Huffman table description
decode_huf_table(HUF_dtable * const dtable,istream_t * const in)8680c16b537SWarner Losh static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
8690c16b537SWarner Losh     // "All literal values from zero (included) to last present one (excluded)
8700c16b537SWarner Losh     // are represented by Weight with values from 0 to Max_Number_of_Bits."
8710c16b537SWarner Losh 
8720c16b537SWarner Losh     // "This is a single byte value (0-255), which describes how to decode the list of weights."
8730c16b537SWarner Losh     const u8 header = IO_read_bits(in, 8);
8740c16b537SWarner Losh 
8750c16b537SWarner Losh     u8 weights[HUF_MAX_SYMBS];
8760c16b537SWarner Losh     memset(weights, 0, sizeof(weights));
8770c16b537SWarner Losh 
8780c16b537SWarner Losh     int num_symbs;
8790c16b537SWarner Losh 
8800c16b537SWarner Losh     if (header >= 128) {
8810c16b537SWarner Losh         // "This is a direct representation, where each Weight is written
8820c16b537SWarner Losh         // directly as a 4 bits field (0-15). The full representation occupies
8830c16b537SWarner Losh         // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
8840c16b537SWarner Losh         // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
8850c16b537SWarner Losh         // 127"
8860c16b537SWarner Losh         num_symbs = header - 127;
8870c16b537SWarner Losh         const size_t bytes = (num_symbs + 1) / 2;
8880c16b537SWarner Losh 
8890c16b537SWarner Losh         const u8 *const weight_src = IO_get_read_ptr(in, bytes);
8900c16b537SWarner Losh 
8910c16b537SWarner Losh         for (int i = 0; i < num_symbs; i++) {
8920c16b537SWarner Losh             // "They are encoded forward, 2
8930c16b537SWarner Losh             // weights to a byte with the first weight taking the top four bits
8940c16b537SWarner Losh             // and the second taking the bottom four (e.g. the following
8950c16b537SWarner Losh             // operations could be used to read the weights: Weight[0] =
8960c16b537SWarner Losh             // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
8970c16b537SWarner Losh             if (i % 2 == 0) {
8980c16b537SWarner Losh                 weights[i] = weight_src[i / 2] >> 4;
8990c16b537SWarner Losh             } else {
9000c16b537SWarner Losh                 weights[i] = weight_src[i / 2] & 0xf;
9010c16b537SWarner Losh             }
9020c16b537SWarner Losh         }
9030c16b537SWarner Losh     } else {
9040c16b537SWarner Losh         // The weights are FSE encoded, decode them before we can construct the
9050c16b537SWarner Losh         // table
9060c16b537SWarner Losh         istream_t fse_stream = IO_make_sub_istream(in, header);
9070c16b537SWarner Losh         ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
9080c16b537SWarner Losh         fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
9090c16b537SWarner Losh     }
9100c16b537SWarner Losh 
9110c16b537SWarner Losh     // Construct the table using the decoded weights
9120c16b537SWarner Losh     HUF_init_dtable_usingweights(dtable, weights, num_symbs);
9130c16b537SWarner Losh }
9140c16b537SWarner Losh 
/// Decodes FSE compressed Huffman weights from `in` into `weights` and stores
/// the value returned by the FSE decoder in `*num_symbs`. The temporary FSE
/// table is built and released inside this function.
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                    int *const num_symbs) {
    // "An FSE bitstream starts by a header, describing probabilities
    // distribution. It will create a Decoding Table. For a list of Huffman
    // weights, maximum accuracy is 7 bits."
    const int weights_max_accuracy_log = 7;

    FSE_dtable weight_table;
    FSE_decode_header(&weight_table, in, weights_max_accuracy_log);

    // With the table in place, the rest of the stream holds the weights
    const int decoded =
        (int)FSE_decompress_interleaved2(&weight_table, weights, in);
    *num_symbs = decoded;

    FSE_free_dtable(&weight_table);
}
9310c16b537SWarner Losh /******* END LITERALS DECODING ************************************************/
9320c16b537SWarner Losh 
9330c16b537SWarner Losh /******* SEQUENCE DECODING ****************************************************/
9340c16b537SWarner Losh /// The combination of FSE states needed to decode sequences
9350c16b537SWarner Losh typedef struct {
9360c16b537SWarner Losh     FSE_dtable ll_table;
9370c16b537SWarner Losh     FSE_dtable of_table;
9380c16b537SWarner Losh     FSE_dtable ml_table;
9390c16b537SWarner Losh 
9400c16b537SWarner Losh     u16 ll_state;
9410c16b537SWarner Losh     u16 of_state;
9420c16b537SWarner Losh     u16 ml_state;
9430c16b537SWarner Losh } sequence_states_t;
9440c16b537SWarner Losh 
/// Tells decode_seq_table which part of a sequence a table is being decoded
/// for
typedef enum {
    seq_literal_length = 0, // literal length table
    seq_offset = 1,         // offset table
    seq_match_length = 2,   // match length table
} seq_part_t;
9510c16b537SWarner Losh 
/// The ways in which the table for a sequence part can be specified
typedef enum {
    seq_predefined = 0, // use the predefined distribution
    seq_rle = 1,        // RLE mode
    seq_fse = 2,        // FSE mode
    seq_repeat = 3,     // repeat mode
} seq_mode_t;
9580c16b537SWarner Losh 
9590c16b537SWarner Losh /// The predefined FSE distribution tables for `seq_predefined` mode
9600c16b537SWarner Losh static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
9610c16b537SWarner Losh     4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
9620c16b537SWarner Losh     2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
9630c16b537SWarner Losh static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
9640c16b537SWarner Losh     1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
9650c16b537SWarner Losh     1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
9660c16b537SWarner Losh static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
9670c16b537SWarner Losh     1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
9680c16b537SWarner Losh     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
9690c16b537SWarner Losh     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
9700c16b537SWarner Losh 
9710c16b537SWarner Losh /// The sequence decoding baseline and number of additional bits to read/add
9720c16b537SWarner Losh /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
9730c16b537SWarner Losh static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
9740c16b537SWarner Losh     0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
9750c16b537SWarner Losh     12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
9769cbefe25SConrad Meyer     48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
9770c16b537SWarner Losh static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
9780c16b537SWarner Losh     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
9790c16b537SWarner Losh     1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
9800c16b537SWarner Losh 
9810c16b537SWarner Losh static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
9820c16b537SWarner Losh     3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
9830c16b537SWarner Losh     17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
9840c16b537SWarner Losh     31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
9850c16b537SWarner Losh     99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
9860c16b537SWarner Losh static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
9870c16b537SWarner Losh     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
9880c16b537SWarner Losh     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
9890c16b537SWarner Losh     2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
9900c16b537SWarner Losh 
9910c16b537SWarner Losh /// Offset decoding is simpler so we just need a maximum code value
9929cbefe25SConrad Meyer static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
9930c16b537SWarner Losh 
9940c16b537SWarner Losh static void decompress_sequences(frame_context_t *const ctx,
9950c16b537SWarner Losh                                  istream_t *const in,
9960c16b537SWarner Losh                                  sequence_command_t *const sequences,
9970c16b537SWarner Losh                                  const size_t num_sequences);
9980c16b537SWarner Losh static sequence_command_t decode_sequence(sequence_states_t *const state,
9990c16b537SWarner Losh                                           const u8 *const src,
10000c16b537SWarner Losh                                           i64 *const offset);
10010c16b537SWarner Losh static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
10020c16b537SWarner Losh                                const seq_part_t type, const seq_mode_t mode);
10030c16b537SWarner Losh 
/// Read the Number_of_Sequences field from `in`, allocate an array of that many
/// sequence commands and fill it via decompress_sequences().
///
/// On return `*sequences` is the array obtained from malloc() (this function
/// does not free it), or NULL when the block contains no sequences.
/// Returns the number of sequences.  Calls BAD_ALLOC() if malloc() fails.
static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
                               sequence_command_t **const sequences) {
    // "A compressed block is a succession of sequences . A sequence is a
    // literal copy command, followed by a match copy command."

    // "Number_of_Sequences
    //
    // This is a variable size field using between 1 and 3 bytes. Let's call its
    // first byte byte0."
    const u8 byte0 = IO_read_bits(in, 8);

    if (byte0 == 0) {
        // "There are no sequences. The sequence section stops there.
        // Regenerated content is defined entirely by literals section."
        *sequences = NULL;
        return 0;
    }

    size_t count;
    if (byte0 < 128) {
        // "Number_of_Sequences = byte0 . Uses 1 byte."
        count = byte0;
    } else if (byte0 < 255) {
        // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
        count = ((byte0 - 128) << 8) + IO_read_bits(in, 8);
    } else {
        // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
        count = IO_read_bits(in, 16) + 0x7F00;
    }

    sequence_command_t *const commands =
        malloc(count * sizeof(sequence_command_t));
    *sequences = commands;
    if (!commands) {
        BAD_ALLOC();
    }

    decompress_sequences(ctx, in, commands, count);
    return count;
}
10440c16b537SWarner Losh 
10450c16b537SWarner Losh /// Decompress the FSE encoded sequence commands
decompress_sequences(frame_context_t * const ctx,istream_t * in,sequence_command_t * const sequences,const size_t num_sequences)10460c16b537SWarner Losh static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
10470c16b537SWarner Losh                                  sequence_command_t *const sequences,
10480c16b537SWarner Losh                                  const size_t num_sequences) {
10490c16b537SWarner Losh     // "The Sequences_Section regroup all symbols required to decode commands.
10500c16b537SWarner Losh     // There are 3 symbol types : literals lengths, offsets and match lengths.
10510c16b537SWarner Losh     // They are encoded together, interleaved, in a single bitstream."
10520c16b537SWarner Losh 
10530c16b537SWarner Losh     // "Symbol compression modes
10540c16b537SWarner Losh     //
10550c16b537SWarner Losh     // This is a single byte, defining the compression mode of each symbol
10560c16b537SWarner Losh     // type."
10570c16b537SWarner Losh     //
10580c16b537SWarner Losh     // Bit number : Field name
10590c16b537SWarner Losh     // 7-6        : Literals_Lengths_Mode
10600c16b537SWarner Losh     // 5-4        : Offsets_Mode
10610c16b537SWarner Losh     // 3-2        : Match_Lengths_Mode
10620c16b537SWarner Losh     // 1-0        : Reserved
10630c16b537SWarner Losh     u8 compression_modes = IO_read_bits(in, 8);
10640c16b537SWarner Losh 
10650c16b537SWarner Losh     if ((compression_modes & 3) != 0) {
10660c16b537SWarner Losh         // Reserved bits set
10670c16b537SWarner Losh         CORRUPTION();
10680c16b537SWarner Losh     }
10690c16b537SWarner Losh 
10700c16b537SWarner Losh     // "Following the header, up to 3 distribution tables can be described. When
10710c16b537SWarner Losh     // present, they are in this order :
10720c16b537SWarner Losh     //
10730c16b537SWarner Losh     // Literals lengths
10740c16b537SWarner Losh     // Offsets
10750c16b537SWarner Losh     // Match Lengths"
10760c16b537SWarner Losh     // Update the tables we have stored in the context
10770c16b537SWarner Losh     decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
10780c16b537SWarner Losh                      (compression_modes >> 6) & 3);
10790c16b537SWarner Losh 
10800c16b537SWarner Losh     decode_seq_table(&ctx->of_dtable, in, seq_offset,
10810c16b537SWarner Losh                      (compression_modes >> 4) & 3);
10820c16b537SWarner Losh 
10830c16b537SWarner Losh     decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
10840c16b537SWarner Losh                      (compression_modes >> 2) & 3);
10850c16b537SWarner Losh 
10860c16b537SWarner Losh 
10870c16b537SWarner Losh     sequence_states_t states;
10880c16b537SWarner Losh 
10890c16b537SWarner Losh     // Initialize the decoding tables
10900c16b537SWarner Losh     {
10910c16b537SWarner Losh         states.ll_table = ctx->ll_dtable;
10920c16b537SWarner Losh         states.of_table = ctx->of_dtable;
10930c16b537SWarner Losh         states.ml_table = ctx->ml_dtable;
10940c16b537SWarner Losh     }
10950c16b537SWarner Losh 
10960c16b537SWarner Losh     const size_t len = IO_istream_len(in);
10970c16b537SWarner Losh     const u8 *const src = IO_get_read_ptr(in, len);
10980c16b537SWarner Losh 
10990c16b537SWarner Losh     // "After writing the last bit containing information, the compressor writes
11000c16b537SWarner Losh     // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
11010c16b537SWarner Losh     const int padding = 8 - highest_set_bit(src[len - 1]);
11020c16b537SWarner Losh     // The offset starts at the end because FSE streams are read backwards
11039cbefe25SConrad Meyer     i64 bit_offset = (i64)(len * 8 - (size_t)padding);
11040c16b537SWarner Losh 
11050c16b537SWarner Losh     // "The bitstream starts with initial state values, each using the required
11060c16b537SWarner Losh     // number of bits in their respective accuracy, decoded previously from
11070c16b537SWarner Losh     // their normalized distribution.
11080c16b537SWarner Losh     //
11090c16b537SWarner Losh     // It starts by Literals_Length_State, followed by Offset_State, and finally
11100c16b537SWarner Losh     // Match_Length_State."
11110c16b537SWarner Losh     FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
11120c16b537SWarner Losh     FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
11130c16b537SWarner Losh     FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
11140c16b537SWarner Losh 
11150c16b537SWarner Losh     for (size_t i = 0; i < num_sequences; i++) {
11160c16b537SWarner Losh         // Decode sequences one by one
11170c16b537SWarner Losh         sequences[i] = decode_sequence(&states, src, &bit_offset);
11180c16b537SWarner Losh     }
11190c16b537SWarner Losh 
11200c16b537SWarner Losh     if (bit_offset != 0) {
11210c16b537SWarner Losh         CORRUPTION();
11220c16b537SWarner Losh     }
11230c16b537SWarner Losh }
11240c16b537SWarner Losh 
11250c16b537SWarner Losh // Decode a single sequence and update the state
decode_sequence(sequence_states_t * const states,const u8 * const src,i64 * const offset)11260c16b537SWarner Losh static sequence_command_t decode_sequence(sequence_states_t *const states,
11270c16b537SWarner Losh                                           const u8 *const src,
11280c16b537SWarner Losh                                           i64 *const offset) {
11290c16b537SWarner Losh     // "Each symbol is a code in its own context, which specifies Baseline and
11300c16b537SWarner Losh     // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
11310c16b537SWarner Losh     // additional bits in the same bitstream."
11320c16b537SWarner Losh 
11330c16b537SWarner Losh     // Decode symbols, but don't update states
11340c16b537SWarner Losh     const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
11350c16b537SWarner Losh     const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
11360c16b537SWarner Losh     const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
11370c16b537SWarner Losh 
11380c16b537SWarner Losh     // Offset doesn't need a max value as it's not decoded using a table
11390c16b537SWarner Losh     if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
11400c16b537SWarner Losh         ml_code > SEQ_MAX_CODES[seq_match_length]) {
11410c16b537SWarner Losh         CORRUPTION();
11420c16b537SWarner Losh     }
11430c16b537SWarner Losh 
11440c16b537SWarner Losh     // Read the interleaved bits
11450c16b537SWarner Losh     sequence_command_t seq;
11460c16b537SWarner Losh     // "Decoding starts by reading the Number_of_Bits required to decode Offset.
11470c16b537SWarner Losh     // It then does the same for Match_Length, and then for Literals_Length."
11480c16b537SWarner Losh     seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
11490c16b537SWarner Losh 
11500c16b537SWarner Losh     seq.match_length =
11510c16b537SWarner Losh         SEQ_MATCH_LENGTH_BASELINES[ml_code] +
11520c16b537SWarner Losh         STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
11530c16b537SWarner Losh 
11540c16b537SWarner Losh     seq.literal_length =
11550c16b537SWarner Losh         SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
11560c16b537SWarner Losh         STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
11570c16b537SWarner Losh 
11580c16b537SWarner Losh     // "If it is not the last sequence in the block, the next operation is to
11590c16b537SWarner Losh     // update states. Using the rules pre-calculated in the decoding tables,
11600c16b537SWarner Losh     // Literals_Length_State is updated, followed by Match_Length_State, and
11610c16b537SWarner Losh     // then Offset_State."
11620c16b537SWarner Losh     // If the stream is complete don't read bits to update state
11630c16b537SWarner Losh     if (*offset != 0) {
11640c16b537SWarner Losh         FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
11650c16b537SWarner Losh         FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
11660c16b537SWarner Losh         FSE_update_state(&states->of_table, &states->of_state, src, offset);
11670c16b537SWarner Losh     }
11680c16b537SWarner Losh 
11690c16b537SWarner Losh     return seq;
11700c16b537SWarner Losh }
11710c16b537SWarner Losh 
11720c16b537SWarner Losh /// Given a sequence part and table mode, decode the FSE distribution
11730c16b537SWarner Losh /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
decode_seq_table(FSE_dtable * const table,istream_t * const in,const seq_part_t type,const seq_mode_t mode)11740c16b537SWarner Losh static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
11750c16b537SWarner Losh                              const seq_part_t type, const seq_mode_t mode) {
11760c16b537SWarner Losh     // Constant arrays indexed by seq_part_t
11770c16b537SWarner Losh     const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
11780c16b537SWarner Losh                                                 SEQ_OFFSET_DEFAULT_DIST,
11790c16b537SWarner Losh                                                 SEQ_MATCH_LENGTH_DEFAULT_DIST};
11800c16b537SWarner Losh     const size_t default_distribution_lengths[] = {36, 29, 53};
11810c16b537SWarner Losh     const size_t default_distribution_accuracies[] = {6, 5, 6};
11820c16b537SWarner Losh 
11830c16b537SWarner Losh     const size_t max_accuracies[] = {9, 8, 9};
11840c16b537SWarner Losh 
11850c16b537SWarner Losh     if (mode != seq_repeat) {
11860c16b537SWarner Losh         // Free old one before overwriting
11870c16b537SWarner Losh         FSE_free_dtable(table);
11880c16b537SWarner Losh     }
11890c16b537SWarner Losh 
11900c16b537SWarner Losh     switch (mode) {
11910c16b537SWarner Losh     case seq_predefined: {
11920c16b537SWarner Losh         // "Predefined_Mode : uses a predefined distribution table."
11930c16b537SWarner Losh         const i16 *distribution = default_distributions[type];
11940c16b537SWarner Losh         const size_t symbs = default_distribution_lengths[type];
11950c16b537SWarner Losh         const size_t accuracy_log = default_distribution_accuracies[type];
11960c16b537SWarner Losh 
11970c16b537SWarner Losh         FSE_init_dtable(table, distribution, symbs, accuracy_log);
11980c16b537SWarner Losh         break;
11990c16b537SWarner Losh     }
12000c16b537SWarner Losh     case seq_rle: {
12010c16b537SWarner Losh         // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
12020c16b537SWarner Losh         const u8 symb = IO_get_read_ptr(in, 1)[0];
12030c16b537SWarner Losh         FSE_init_dtable_rle(table, symb);
12040c16b537SWarner Losh         break;
12050c16b537SWarner Losh     }
12060c16b537SWarner Losh     case seq_fse: {
12070c16b537SWarner Losh         // "FSE_Compressed_Mode : standard FSE compression. A distribution table
12080c16b537SWarner Losh         // will be present "
12090c16b537SWarner Losh         FSE_decode_header(table, in, max_accuracies[type]);
12100c16b537SWarner Losh         break;
12110c16b537SWarner Losh     }
12120c16b537SWarner Losh     case seq_repeat:
12130c16b537SWarner Losh         // "Repeat_Mode : re-use distribution table from previous compressed
12140c16b537SWarner Losh         // block."
12150c16b537SWarner Losh         // Nothing to do here, table will be unchanged
12160c16b537SWarner Losh         if (!table->symbols) {
12170c16b537SWarner Losh             // This mode is invalid if we don't already have a table
12180c16b537SWarner Losh             CORRUPTION();
12190c16b537SWarner Losh         }
12200c16b537SWarner Losh         break;
12210c16b537SWarner Losh     default:
12220c16b537SWarner Losh         // Impossible, as mode is from 0-3
12230c16b537SWarner Losh         IMPOSSIBLE();
12240c16b537SWarner Losh         break;
12250c16b537SWarner Losh     }
12260c16b537SWarner Losh 
12270c16b537SWarner Losh }
12280c16b537SWarner Losh /******* END SEQUENCE DECODING ************************************************/
12290c16b537SWarner Losh 
12300c16b537SWarner Losh /******* SEQUENCE EXECUTION ***************************************************/
/// Run every sequence command in order: copy the command's literals to `out`,
/// then perform its match copy.  Literals left over after the last command are
/// appended at the end.  `ctx->current_total_output` is advanced by the number
/// of bytes produced; the offset history in `ctx->previous_offsets` is updated
/// by compute_offset().
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences) {
    istream_t lit_in = IO_make_istream(literals, literals_len);
    u64 *const history = ctx->previous_offsets;
    size_t produced = ctx->current_total_output;

    for (size_t n = 0; n < num_sequences; n++) {
        const sequence_command_t cmd = sequences[n];

        // Literal part of the command
        const u32 lit_bytes = copy_literals(cmd.literal_length, &lit_in, out);
        produced += lit_bytes;

        // Match part of the command
        size_t const match_offset = compute_offset(cmd, history);
        size_t const match_bytes = cmd.match_length;
        execute_match_copy(ctx, match_offset, match_bytes, produced, out);
        produced += match_bytes;
    }

    // Whatever literals remain are emitted after the final match
    const size_t leftover = IO_istream_len(&lit_in);
    copy_literals(leftover, &lit_in, out);
    produced += leftover;

    ctx->current_total_output = produced;
}
12660c16b537SWarner Losh 
copy_literals(const size_t literal_length,istream_t * litstream,ostream_t * const out)12670c16b537SWarner Losh static u32 copy_literals(const size_t literal_length, istream_t *litstream,
12680c16b537SWarner Losh                          ostream_t *const out) {
12690c16b537SWarner Losh     // If the sequence asks for more literals than are left, the
12700c16b537SWarner Losh     // sequence must be corrupted
12710c16b537SWarner Losh     if (literal_length > IO_istream_len(litstream)) {
12720c16b537SWarner Losh         CORRUPTION();
12730c16b537SWarner Losh     }
12740c16b537SWarner Losh 
12750c16b537SWarner Losh     u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
12760c16b537SWarner Losh     const u8 *const read_ptr =
12770c16b537SWarner Losh          IO_get_read_ptr(litstream, literal_length);
12780c16b537SWarner Losh     // Copy literals to output
12790c16b537SWarner Losh     memcpy(write_ptr, read_ptr, literal_length);
12800c16b537SWarner Losh 
12810c16b537SWarner Losh     return literal_length;
12820c16b537SWarner Losh }
12830c16b537SWarner Losh 
/// Translate the raw offset value of `seq` into an actual match distance and
/// update the three-entry repeat-offset history `offset_hist` (most recent
/// first).
///
/// Returns the match distance in bytes.  Calls CORRUPTION() if the resulting
/// distance is 0: execute_match_copy() copies from `write_ptr - offset`, so a
/// zero distance would make every output byte a copy of itself, i.e. of
/// whatever happened to be in the output buffer.
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
    size_t offset;
    // Offsets are special, we need to handle the repeat offsets
    if (seq.offset <= 3) {
        // "The first 3 values define a repeated offset and we will call
        // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
        // They are sorted in recency order, with Repeated_Offset1 meaning
        // 'most recent one'".

        // Use 0 indexing for the array
        u32 idx = seq.offset - 1;
        if (seq.literal_length == 0) {
            // "There is an exception though, when current sequence's
            // literals length is 0. In this case, repeated offsets are
            // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
            // Repeated_Offset2 becomes Repeated_Offset3, and
            // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
            idx++;
        }

        if (idx == 0) {
            // Most recent offset reused: history order does not change
            offset = offset_hist[0];
        } else {
            // If idx == 3 then literal length was 0 and the offset was 3,
            // as per the exception listed above
            offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;

            // If idx == 1 we don't need to modify offset_hist[2], since
            // we're using the second-most recent code
            if (idx > 1) {
                offset_hist[2] = offset_hist[1];
            }
            offset_hist[1] = offset_hist[0];
            offset_hist[0] = offset;
        }
    } else {
        // When it's not a repeat offset:
        // "if (Offset_Value > 3) offset = Offset_Value - 3;"
        offset = seq.offset - 3;

        // Shift back history
        offset_hist[2] = offset_hist[1];
        offset_hist[1] = offset_hist[0];
        offset_hist[0] = offset;
    }

    // A distance of 0 can only come from `Repeated_Offset1 - 1_byte` with
    // Repeated_Offset1 == 1, or from a zero entry already in the history.
    // Either way the input is corrupted.
    if (offset == 0) {
        CORRUPTION();
    }
    return offset;
}
13310c16b537SWarner Losh 
/// Append `match_length` bytes to `out`, each copied from `offset` bytes
/// behind the write position.  `total_output` is the number of bytes produced
/// before this match.  While `total_output` does not exceed the window size,
/// the part of the match that reaches back past the start of the output is
/// taken from the dictionary content in `ctx`.
/// Calls CORRUPTION() when `offset` reaches further back than is allowed.
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                              size_t match_length, size_t total_output,
                              ostream_t *const out) {
    u8 *write_ptr = IO_get_write_ptr(out, match_length);

    if (total_output > ctx->header.window_size) {
        // More than a full window has been produced: the match is limited
        // to the window
        if (offset > ctx->header.window_size) {
            CORRUPTION();
        }
    } else {
        // In this case offset might go back into the dictionary
        if (offset > total_output + ctx->dict_content_len) {
            // The offset goes beyond even the dictionary
            CORRUPTION();
        }

        if (offset > total_output) {
            // "The rest of the dictionary is its content. The content act
            // as a "past" in front of data to compress or decompress, so it
            // can be referenced in sequence commands."
            const size_t before_output = offset - total_output;
            const size_t from_dict = MIN(before_output, match_length);
            const u8 *const dict_src =
                ctx->dict_content + (ctx->dict_content_len - before_output);

            memcpy(write_ptr, dict_src, from_dict);
            write_ptr += from_dict;
            match_length -= from_dict;
        }
    }

    // The copy goes one byte at a time because source and destination may
    // overlap when the match is longer than the offset
    // ex: if the output so far was "abc", a command with offset=3 and
    // match_length=6 would produce "abcabcabc" as the new output
    for (size_t copied = 0; copied != match_length; copied++) {
        u8 *const dst = write_ptr + copied;
        *dst = *(dst - offset);
    }
}
13690c16b537SWarner Losh /******* END SEQUENCE EXECUTION ***********************************************/
13700c16b537SWarner Losh 
13710c16b537SWarner Losh /******* OUTPUT SIZE COUNTING *************************************************/
13720c16b537SWarner Losh /// Get the decompressed size of an input stream so memory can be allocated in
13730c16b537SWarner Losh /// advance.
13740c16b537SWarner Losh /// This implementation assumes `src` points to a single ZSTD-compressed frame
ZSTD_get_decompressed_size(const void * src,const size_t src_len)13750c16b537SWarner Losh size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
13760c16b537SWarner Losh     istream_t in = IO_make_istream(src, src_len);
13770c16b537SWarner Losh 
13780c16b537SWarner Losh     // get decompressed size from ZSTD frame header
13790c16b537SWarner Losh     {
13809cbefe25SConrad Meyer         const u32 magic_number = (u32)IO_read_bits(&in, 32);
13810c16b537SWarner Losh 
138237f1f268SConrad Meyer         if (magic_number == ZSTD_MAGIC_NUMBER) {
13830c16b537SWarner Losh             // ZSTD frame
13840c16b537SWarner Losh             frame_header_t header;
13850c16b537SWarner Losh             parse_frame_header(&header, &in);
13860c16b537SWarner Losh 
13870c16b537SWarner Losh             if (header.frame_content_size == 0 && !header.single_segment_flag) {
13880c16b537SWarner Losh                 // Content size not provided, we can't tell
13899cbefe25SConrad Meyer                 return (size_t)-1;
13900c16b537SWarner Losh             }
13910c16b537SWarner Losh 
13920c16b537SWarner Losh             return header.frame_content_size;
13930c16b537SWarner Losh         } else {
13940c16b537SWarner Losh             // not a real frame or skippable frame
13950c16b537SWarner Losh             ERROR("ZSTD frame magic number did not match");
13960c16b537SWarner Losh         }
13970c16b537SWarner Losh     }
13980c16b537SWarner Losh }
13990c16b537SWarner Losh /******* END OUTPUT SIZE COUNTING *********************************************/
14000c16b537SWarner Losh 
14010c16b537SWarner Losh /******* DICTIONARY PARSING ***************************************************/
create_dictionary()14020c16b537SWarner Losh dictionary_t* create_dictionary() {
140337f1f268SConrad Meyer     dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
14040c16b537SWarner Losh     if (!dict) {
14050c16b537SWarner Losh         BAD_ALLOC();
14060c16b537SWarner Losh     }
14070c16b537SWarner Losh     return dict;
14080c16b537SWarner Losh }
14090c16b537SWarner Losh 
141037f1f268SConrad Meyer /// Free an allocated dictionary
free_dictionary(dictionary_t * const dict)141137f1f268SConrad Meyer void free_dictionary(dictionary_t *const dict) {
141237f1f268SConrad Meyer     HUF_free_dtable(&dict->literals_dtable);
141337f1f268SConrad Meyer     FSE_free_dtable(&dict->ll_dtable);
141437f1f268SConrad Meyer     FSE_free_dtable(&dict->of_dtable);
141537f1f268SConrad Meyer     FSE_free_dtable(&dict->ml_dtable);
141637f1f268SConrad Meyer 
141737f1f268SConrad Meyer     free(dict->content);
141837f1f268SConrad Meyer 
141937f1f268SConrad Meyer     memset(dict, 0, sizeof(dictionary_t));
142037f1f268SConrad Meyer 
142137f1f268SConrad Meyer     free(dict);
142237f1f268SConrad Meyer }
142337f1f268SConrad Meyer 
142437f1f268SConrad Meyer 
142537f1f268SConrad Meyer #if !defined(ZDEC_NO_DICTIONARY)
142637f1f268SConrad Meyer #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
142737f1f268SConrad Meyer #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
142837f1f268SConrad Meyer 
14290c16b537SWarner Losh static void init_dictionary_content(dictionary_t *const dict,
14300c16b537SWarner Losh                                     istream_t *const in);
14310c16b537SWarner Losh 
/// Parse the `src_len` bytes at `src` into `dict`, which is zeroed first.
///
/// Input that does not start with the dictionary magic number 0xEC30A437 is
/// treated as a raw content dictionary: all of it becomes the content.
/// Otherwise the dictionary ID, the entropy tables and the three recent
/// offsets are read, and the remaining bytes become the content.
///
/// Raises an error when `src` is NULL, when `src_len` is below 8, or when a
/// recent offset is 0 or larger than the dictionary content.
void parse_dictionary(dictionary_t *const dict, const void *src,
                             size_t src_len) {
    const u8 *byte_src = (const u8 *)src;
    memset(dict, 0, sizeof(dictionary_t));
    if (src == NULL) { /* cannot initialize dictionary with null src */
        NULL_SRC();
    }
    if (src_len < 8) {
        DICT_SIZE_ERROR();
    }

    istream_t in = IO_make_istream(byte_src, src_len);

    const u32 magic_number = (u32)IO_read_bits(&in, 32);
    if (magic_number != 0xEC30A437) {
        // raw content dict: un-read the 4 bytes, they are part of the content
        IO_rewind_bits(&in, 32);
        init_dictionary_content(dict, &in);
        return;
    }

    dict->dictionary_id = IO_read_bits(&in, 32);

    // "Entropy_Tables : following the same format as the tables in compressed
    // blocks. They are stored in following order : Huffman tables for literals,
    // FSE table for offsets, FSE table for match lengths, and FSE table for
    // literals lengths. It's finally followed by 3 offset values, populating
    // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
    // little-endian each, for a total of 12 bytes. Each recent offset must have
    // a value < dictionary size."
    decode_huf_table(&dict->literals_dtable, &in);
    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);

    // Read in the previous offset history
    dict->previous_offsets[0] = IO_read_bits(&in, 32);
    dict->previous_offsets[1] = IO_read_bits(&in, 32);
    dict->previous_offsets[2] = IO_read_bits(&in, 32);

    // Validate the recent offsets.  This is stricter than the sentence quoted
    // above: a recent offset is a distance back into the dictionary content
    // (see execute_match_copy), so it must be at least 1 and must not reach
    // past the content, which is everything still unread in `in`.
    const size_t content_len = IO_istream_len(&in);
    for (int i = 0; i < 3; i++) {
        if (dict->previous_offsets[i] == 0 ||
            dict->previous_offsets[i] > content_len) {
            ERROR("Dictionary corrupted");
        }
    }

    // "Content : The rest of the dictionary is its content. The content act as
    // a "past" in front of data to compress or decompress, so it can be
    // referenced in sequence commands."
    init_dictionary_content(dict, &in);
}
14850c16b537SWarner Losh 
/// Takes everything still unread in `in` as the dictionary content: records
/// its size, allocates a buffer of that size, and copies the bytes into it
/// (consuming them from the stream).
static void init_dictionary_content(dictionary_t *const dict,
                                    istream_t *const in) {
    // The content is whatever remains of the input stream
    dict->content_size = IO_istream_len(in);
    const size_t content_len = dict->content_size;

    dict->content = malloc(content_len);
    if (dict->content == NULL) {
        BAD_ALLOC();
    }

    // Consume the remaining bytes from the stream and duplicate them
    memcpy(dict->content, IO_get_read_ptr(in, content_len), content_len);
}
14990c16b537SWarner Losh 
/// Deep-copies the Huffman decoding table `src` into `dst`, allocating new
/// `symbols` and `num_bits` arrays of `1 << max_bits` bytes each.  A source
/// whose `max_bits` is 0 is treated as empty: `dst` is zeroed and nothing is
/// allocated.
static void HUF_copy_dtable(HUF_dtable *const dst,
                            const HUF_dtable *const src) {
    if (src->max_bits != 0) {
        const size_t table_bytes = (size_t)1 << src->max_bits;

        dst->max_bits = src->max_bits;
        dst->symbols = malloc(table_bytes);
        dst->num_bits = malloc(table_bytes);
        if (dst->symbols == NULL || dst->num_bits == NULL) {
            BAD_ALLOC();
        }

        memcpy(dst->symbols, src->symbols, table_bytes);
        memcpy(dst->num_bits, src->num_bits, table_bytes);
    } else {
        // Nothing to copy from an empty table
        memset(dst, 0, sizeof(*dst));
    }
}
151937f1f268SConrad Meyer 
/// Deep-copies the FSE decoding table `src` into `dst`, allocating new
/// `symbols`, `num_bits` and `new_state_base` arrays with `1 << accuracy_log`
/// entries each.  A source whose `accuracy_log` is 0 is treated as empty:
/// `dst` is zeroed and nothing is allocated.
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
    if (src->accuracy_log != 0) {
        const size_t num_states = (size_t)1 << src->accuracy_log;
        // `new_state_base` holds one u16 per state
        const size_t base_bytes = num_states * sizeof(u16);

        dst->accuracy_log = src->accuracy_log;
        dst->symbols = malloc(num_states);
        dst->num_bits = malloc(num_states);
        dst->new_state_base = malloc(base_bytes);
        if (dst->symbols == NULL || dst->num_bits == NULL ||
            dst->new_state_base == NULL) {
            BAD_ALLOC();
        }

        memcpy(dst->symbols, src->symbols, num_states);
        memcpy(dst->num_bits, src->num_bits, num_states);
        memcpy(dst->new_state_base, src->new_state_base, base_bytes);
    } else {
        // Nothing to copy from an empty table
        memset(dst, 0, sizeof(*dst));
    }
}
154037f1f268SConrad Meyer 
154137f1f268SConrad Meyer /// A dictionary acts as initializing values for the frame context before
154237f1f268SConrad Meyer /// decompression, so we implement it by applying it's predetermined
154337f1f268SConrad Meyer /// tables and content to the context before beginning decompression
frame_context_apply_dict(frame_context_t * const ctx,const dictionary_t * const dict)154437f1f268SConrad Meyer static void frame_context_apply_dict(frame_context_t *const ctx,
154537f1f268SConrad Meyer                                      const dictionary_t *const dict) {
154637f1f268SConrad Meyer     // If the content pointer is NULL then it must be an empty dict
154737f1f268SConrad Meyer     if (!dict || !dict->content)
154837f1f268SConrad Meyer         return;
154937f1f268SConrad Meyer 
155037f1f268SConrad Meyer     // If the requested dictionary_id is non-zero, the correct dictionary must
155137f1f268SConrad Meyer     // be present
155237f1f268SConrad Meyer     if (ctx->header.dictionary_id != 0 &&
155337f1f268SConrad Meyer         ctx->header.dictionary_id != dict->dictionary_id) {
155437f1f268SConrad Meyer         ERROR("Wrong dictionary provided");
155537f1f268SConrad Meyer     }
155637f1f268SConrad Meyer 
155737f1f268SConrad Meyer     // Copy the dict content to the context for references during sequence
155837f1f268SConrad Meyer     // execution
155937f1f268SConrad Meyer     ctx->dict_content = dict->content;
156037f1f268SConrad Meyer     ctx->dict_content_len = dict->content_size;
156137f1f268SConrad Meyer 
156237f1f268SConrad Meyer     // If it's a formatted dict copy the precomputed tables in so they can
156337f1f268SConrad Meyer     // be used in the table repeat modes
156437f1f268SConrad Meyer     if (dict->dictionary_id != 0) {
156537f1f268SConrad Meyer         // Deep copy the entropy tables so they can be freed independently of
156637f1f268SConrad Meyer         // the dictionary struct
156737f1f268SConrad Meyer         HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
156837f1f268SConrad Meyer         FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
156937f1f268SConrad Meyer         FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
157037f1f268SConrad Meyer         FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
157137f1f268SConrad Meyer 
157237f1f268SConrad Meyer         // Copy the repeated offsets
157337f1f268SConrad Meyer         memcpy(ctx->previous_offsets, dict->previous_offsets,
157437f1f268SConrad Meyer                sizeof(ctx->previous_offsets));
157537f1f268SConrad Meyer     }
157637f1f268SConrad Meyer }
157737f1f268SConrad Meyer 
157837f1f268SConrad Meyer #else  // ZDEC_NO_DICTIONARY is defined
157937f1f268SConrad Meyer 
/// Variant used when dictionary support is compiled out: the context is left
/// untouched, and being handed a dictionary that has content is an error.
static void frame_context_apply_dict(frame_context_t *const ctx,
                                     const dictionary_t *const dict) {
    (void)ctx;
    if (dict == NULL) {
        return;
    }
    if (dict->content != NULL) {
        ERROR("dictionary not supported");
    }
}
158537f1f268SConrad Meyer 
158637f1f268SConrad Meyer #endif
15870c16b537SWarner Losh /******* END DICTIONARY PARSING ***********************************************/
15880c16b537SWarner Losh 
15890c16b537SWarner Losh /******* IO STREAM OPERATIONS *************************************************/
15909cbefe25SConrad Meyer 
15910c16b537SWarner Losh /// Reads `num` bits from a bitstream, and updates the internal offset
IO_read_bits(istream_t * const in,const int num_bits)15920c16b537SWarner Losh static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
15930c16b537SWarner Losh     if (num_bits > 64 || num_bits <= 0) {
15940c16b537SWarner Losh         ERROR("Attempt to read an invalid number of bits");
15950c16b537SWarner Losh     }
15960c16b537SWarner Losh 
15970c16b537SWarner Losh     const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
15980c16b537SWarner Losh     const size_t full_bytes = (num_bits + in->bit_offset) / 8;
15990c16b537SWarner Losh     if (bytes > in->len) {
16000c16b537SWarner Losh         INP_SIZE();
16010c16b537SWarner Losh     }
16020c16b537SWarner Losh 
16030c16b537SWarner Losh     const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
16040c16b537SWarner Losh 
16050c16b537SWarner Losh     in->bit_offset = (num_bits + in->bit_offset) % 8;
16060c16b537SWarner Losh     in->ptr += full_bytes;
16070c16b537SWarner Losh     in->len -= full_bytes;
16080c16b537SWarner Losh 
16090c16b537SWarner Losh     return result;
16100c16b537SWarner Losh }
16110c16b537SWarner Losh 
16120c16b537SWarner Losh /// If a non-zero number of bits have been read from the current byte, advance
16130c16b537SWarner Losh /// the offset to the next byte
IO_rewind_bits(istream_t * const in,int num_bits)16140c16b537SWarner Losh static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
16150c16b537SWarner Losh     if (num_bits < 0) {
16160c16b537SWarner Losh         ERROR("Attempting to rewind stream by a negative number of bits");
16170c16b537SWarner Losh     }
16180c16b537SWarner Losh 
16190c16b537SWarner Losh     // move the offset back by `num_bits` bits
16200c16b537SWarner Losh     const int new_offset = in->bit_offset - num_bits;
16210c16b537SWarner Losh     // determine the number of whole bytes we have to rewind, rounding up to an
16220c16b537SWarner Losh     // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
16230c16b537SWarner Losh     const i64 bytes = -(new_offset - 7) / 8;
16240c16b537SWarner Losh 
16250c16b537SWarner Losh     in->ptr -= bytes;
16260c16b537SWarner Losh     in->len += bytes;
16270c16b537SWarner Losh     // make sure the resulting `bit_offset` is positive, as mod in C does not
16280c16b537SWarner Losh     // convert numbers from negative to positive (e.g. -22 % 8 == -6)
16290c16b537SWarner Losh     in->bit_offset = ((new_offset % 8) + 8) % 8;
16300c16b537SWarner Losh }
16310c16b537SWarner Losh 
16320c16b537SWarner Losh /// If the remaining bits in a byte will be unused, advance to the end of the
16330c16b537SWarner Losh /// byte
IO_align_stream(istream_t * const in)16340c16b537SWarner Losh static inline void IO_align_stream(istream_t *const in) {
16350c16b537SWarner Losh     if (in->bit_offset != 0) {
16360c16b537SWarner Losh         if (in->len == 0) {
16370c16b537SWarner Losh             INP_SIZE();
16380c16b537SWarner Losh         }
16390c16b537SWarner Losh         in->ptr++;
16400c16b537SWarner Losh         in->len--;
16410c16b537SWarner Losh         in->bit_offset = 0;
16420c16b537SWarner Losh     }
16430c16b537SWarner Losh }
16440c16b537SWarner Losh 
16450c16b537SWarner Losh /// Write the given byte into the output stream
IO_write_byte(ostream_t * const out,u8 symb)16460c16b537SWarner Losh static inline void IO_write_byte(ostream_t *const out, u8 symb) {
16470c16b537SWarner Losh     if (out->len == 0) {
16480c16b537SWarner Losh         OUT_SIZE();
16490c16b537SWarner Losh     }
16500c16b537SWarner Losh 
16510c16b537SWarner Losh     out->ptr[0] = symb;
16520c16b537SWarner Losh     out->ptr++;
16530c16b537SWarner Losh     out->len--;
16540c16b537SWarner Losh }
16550c16b537SWarner Losh 
16560c16b537SWarner Losh /// Returns the number of bytes left to be read in this stream.  The stream must
16570c16b537SWarner Losh /// be byte aligned.
IO_istream_len(const istream_t * const in)16580c16b537SWarner Losh static inline size_t IO_istream_len(const istream_t *const in) {
16590c16b537SWarner Losh     return in->len;
16600c16b537SWarner Losh }
16610c16b537SWarner Losh 
16620c16b537SWarner Losh /// Returns a pointer where `len` bytes can be read, and advances the internal
16630c16b537SWarner Losh /// state.  The stream must be byte aligned.
IO_get_read_ptr(istream_t * const in,size_t len)16640c16b537SWarner Losh static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
16650c16b537SWarner Losh     if (len > in->len) {
16660c16b537SWarner Losh         INP_SIZE();
16670c16b537SWarner Losh     }
16680c16b537SWarner Losh     if (in->bit_offset != 0) {
16699cbefe25SConrad Meyer         ERROR("Attempting to operate on a non-byte aligned stream");
16700c16b537SWarner Losh     }
16710c16b537SWarner Losh     const u8 *const ptr = in->ptr;
16720c16b537SWarner Losh     in->ptr += len;
16730c16b537SWarner Losh     in->len -= len;
16740c16b537SWarner Losh 
16750c16b537SWarner Losh     return ptr;
16760c16b537SWarner Losh }
16770c16b537SWarner Losh /// Returns a pointer to write `len` bytes to, and advances the internal state
IO_get_write_ptr(ostream_t * const out,size_t len)16780c16b537SWarner Losh static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
16790c16b537SWarner Losh     if (len > out->len) {
16800c16b537SWarner Losh         OUT_SIZE();
16810c16b537SWarner Losh     }
16820c16b537SWarner Losh     u8 *const ptr = out->ptr;
16830c16b537SWarner Losh     out->ptr += len;
16840c16b537SWarner Losh     out->len -= len;
16850c16b537SWarner Losh 
16860c16b537SWarner Losh     return ptr;
16870c16b537SWarner Losh }
16880c16b537SWarner Losh 
16890c16b537SWarner Losh /// Advance the inner state by `len` bytes
IO_advance_input(istream_t * const in,size_t len)16900c16b537SWarner Losh static inline void IO_advance_input(istream_t *const in, size_t len) {
16910c16b537SWarner Losh     if (len > in->len) {
16920c16b537SWarner Losh          INP_SIZE();
16930c16b537SWarner Losh     }
16940c16b537SWarner Losh     if (in->bit_offset != 0) {
16959cbefe25SConrad Meyer         ERROR("Attempting to operate on a non-byte aligned stream");
16960c16b537SWarner Losh     }
16970c16b537SWarner Losh 
16980c16b537SWarner Losh     in->ptr += len;
16990c16b537SWarner Losh     in->len -= len;
17000c16b537SWarner Losh }
17010c16b537SWarner Losh 
17020c16b537SWarner Losh /// Returns an `ostream_t` constructed from the given pointer and length
IO_make_ostream(u8 * out,size_t len)17030c16b537SWarner Losh static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
17040c16b537SWarner Losh     return (ostream_t) { out, len };
17050c16b537SWarner Losh }
17060c16b537SWarner Losh 
17070c16b537SWarner Losh /// Returns an `istream_t` constructed from the given pointer and length
IO_make_istream(const u8 * in,size_t len)17080c16b537SWarner Losh static inline istream_t IO_make_istream(const u8 *in, size_t len) {
17090c16b537SWarner Losh     return (istream_t) { in, len, 0 };
17100c16b537SWarner Losh }
17110c16b537SWarner Losh 
17120c16b537SWarner Losh /// Returns an `istream_t` with the same base as `in`, and length `len`
17130c16b537SWarner Losh /// Then, advance `in` to account for the consumed bytes
17140c16b537SWarner Losh /// `in` must be byte aligned
IO_make_sub_istream(istream_t * const in,size_t len)17150c16b537SWarner Losh static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
17160c16b537SWarner Losh     // Consume `len` bytes of the parent stream
17170c16b537SWarner Losh     const u8 *const ptr = IO_get_read_ptr(in, len);
17180c16b537SWarner Losh 
17190c16b537SWarner Losh     // Make a substream using the pointer to those `len` bytes
17200c16b537SWarner Losh     return IO_make_istream(ptr, len);
17210c16b537SWarner Losh }
17220c16b537SWarner Losh /******* END IO STREAM OPERATIONS *********************************************/
17230c16b537SWarner Losh 
17240c16b537SWarner Losh /******* BITSTREAM OPERATIONS *************************************************/
17250c16b537SWarner Losh /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
read_bits_LE(const u8 * src,const int num_bits,const size_t offset)17260c16b537SWarner Losh static inline u64 read_bits_LE(const u8 *src, const int num_bits,
17270c16b537SWarner Losh                                const size_t offset) {
17280c16b537SWarner Losh     if (num_bits > 64) {
17290c16b537SWarner Losh         ERROR("Attempt to read an invalid number of bits");
17300c16b537SWarner Losh     }
17310c16b537SWarner Losh 
17320c16b537SWarner Losh     // Skip over bytes that aren't in range
17330c16b537SWarner Losh     src += offset / 8;
17340c16b537SWarner Losh     size_t bit_offset = offset % 8;
17350c16b537SWarner Losh     u64 res = 0;
17360c16b537SWarner Losh 
17370c16b537SWarner Losh     int shift = 0;
17380c16b537SWarner Losh     int left = num_bits;
17390c16b537SWarner Losh     while (left > 0) {
17400c16b537SWarner Losh         u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
17410c16b537SWarner Losh         // Read the next byte, shift it to account for the offset, and then mask
17420c16b537SWarner Losh         // out the top part if we don't need all the bits
17430c16b537SWarner Losh         res += (((u64)*src++ >> bit_offset) & mask) << shift;
17440c16b537SWarner Losh         shift += 8 - bit_offset;
17450c16b537SWarner Losh         left -= 8 - bit_offset;
17460c16b537SWarner Losh         bit_offset = 0;
17470c16b537SWarner Losh     }
17480c16b537SWarner Losh 
17490c16b537SWarner Losh     return res;
17500c16b537SWarner Losh }
17510c16b537SWarner Losh 
17520c16b537SWarner Losh /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
17530c16b537SWarner Losh /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
17540c16b537SWarner Losh /// `src + offset`.  If the offset becomes negative, the extra bits at the
17550c16b537SWarner Losh /// bottom are filled in with `0` bits instead of reading from before `src`.
STREAM_read_bits(const u8 * const src,const int bits,i64 * const offset)17560c16b537SWarner Losh static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
17570c16b537SWarner Losh                                    i64 *const offset) {
17580c16b537SWarner Losh     *offset = *offset - bits;
17590c16b537SWarner Losh     size_t actual_off = *offset;
17600c16b537SWarner Losh     size_t actual_bits = bits;
17610c16b537SWarner Losh     // Don't actually read bits from before the start of src, so if `*offset <
17620c16b537SWarner Losh     // 0` fix actual_off and actual_bits to reflect the quantity to read
17630c16b537SWarner Losh     if (*offset < 0) {
17640c16b537SWarner Losh         actual_bits += *offset;
17650c16b537SWarner Losh         actual_off = 0;
17660c16b537SWarner Losh     }
17670c16b537SWarner Losh     u64 res = read_bits_LE(src, actual_bits, actual_off);
17680c16b537SWarner Losh 
17690c16b537SWarner Losh     if (*offset < 0) {
17700c16b537SWarner Losh         // Fill in the bottom "overflowed" bits with 0's
17710c16b537SWarner Losh         res = -*offset >= 64 ? 0 : (res << -*offset);
17720c16b537SWarner Losh     }
17730c16b537SWarner Losh     return res;
17740c16b537SWarner Losh }
17750c16b537SWarner Losh /******* END BITSTREAM OPERATIONS *********************************************/
17760c16b537SWarner Losh 
17770c16b537SWarner Losh /******* BIT COUNTING OPERATIONS **********************************************/
17780c16b537SWarner Losh /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
17790c16b537SWarner Losh /// `num`, or `-1` if `num == 0`.
highest_set_bit(const u64 num)17800c16b537SWarner Losh static inline int highest_set_bit(const u64 num) {
17810c16b537SWarner Losh     for (int i = 63; i >= 0; i--) {
17820c16b537SWarner Losh         if (((u64)1 << i) <= num) {
17830c16b537SWarner Losh             return i;
17840c16b537SWarner Losh         }
17850c16b537SWarner Losh     }
17860c16b537SWarner Losh     return -1;
17870c16b537SWarner Losh }
17880c16b537SWarner Losh /******* END BIT COUNTING OPERATIONS ******************************************/
17890c16b537SWarner Losh 
17900c16b537SWarner Losh /******* HUFFMAN PRIMITIVES ***************************************************/
/// Decodes one Huffman symbol: returns the symbol that `dtable` associates
/// with the current `*state`, then refills the state with as many bits from
/// the backward-read stream `src` (position tracked by `*offset`) as that
/// symbol's code used.
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset) {
    const u16 current = *state;

    // The table maps a state to its symbol and to that symbol's code length
    const u8 symb = dtable->symbols[current];
    const u8 code_len = dtable->num_bits[current];
    const u16 fresh = STREAM_read_bits(src, code_len, offset);

    // Shift the used code bits out of the state, append the freshly read bits
    // at the bottom, and keep only `max_bits` bits
    const int state_mask = ((u16)1 << dtable->max_bits) - 1;
    *state = ((current << code_len) + fresh) & state_mask;

    return symb;
}
18050c16b537SWarner Losh 
/// Primes the Huffman decoding state by reading a full `dtable->max_bits` bits
/// from the backward-read stream `src`, updating `*offset` accordingly.
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    const u8 state_bits = dtable->max_bits;
    const u64 initial = STREAM_read_bits(src, state_bits, offset);
    *state = (u16)initial;
}
18130c16b537SWarner Losh 
/// Decodes a single Huffman-coded bitstream, consuming all of `in` and writing
/// one byte per decoded symbol to `out`.  Returns the number of symbols
/// written.
///
/// Raises an input-size error for an empty stream, and a corruption error when
/// the final byte carries no final-bit-flag or when the stream is not consumed
/// exactly.  The error macros are expected not to return.
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out,
                                     istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // "Each bitstream must be read backward, that is starting from the end down
    // to the beginning. Therefore it's necessary to know the size of each
    // bitstream.
    //
    // It's also necessary to know exactly which bit is the latest. This is
    // detected by a final bit flag : the highest bit of latest byte is a
    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
    // final-bit-flag itself is not part of the useful bitstream. Hence, the
    // last byte contains between 0 and 7 useful bits."
    const u8 last_byte = src[len - 1];
    if (last_byte == 0) {
        // Without a final-bit-flag the end of the stream cannot be located;
        // `highest_set_bit` would return -1 and yield a bogus padding of 9
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(last_byte);

    // Offset starts at the end because HUF streams are read backwards
    i64 bit_offset = len * 8 - padding;
    u16 state;

    HUF_init_state(dtable, &state, src, &bit_offset);

    size_t symbols_written = 0;
    while (bit_offset > -dtable->max_bits) {
        // Iterate over the stream, decoding one symbol at a time
        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
        symbols_written++;
    }
    // "The process continues up to reading the required number of symbols per
    // stream. If a bitstream is not entirely and exactly consumed, hence
    // reaching exactly its beginning position with all bits consumed, the
    // decoding process is considered faulty."

    // When all symbols have been decoded, the final state value shouldn't have
    // any data from the stream, so it should have "read" dtable->max_bits from
    // before the start of `src`
    // Therefore `offset`, the edge to start reading new bits at, should be
    // dtable->max_bits before the start of the stream
    if (bit_offset != -dtable->max_bits) {
        CORRUPTION();
    }

    return symbols_written;
}
18620c16b537SWarner Losh 
/// Decodes the 4-stream Huffman variant: reads the three explicit stream
/// sizes, splits `in` into four substreams, decodes them one after another
/// into `out`, and returns the total number of symbols written.
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in) {
    // "Compressed size is provided explicitly : in the 4-streams variant,
    // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
    // value represents the compressed size of one stream, in order. The last
    // stream size is deducted from total compressed size and from previously
    // decoded stream sizes"
    size_t explicit_sizes[3];
    for (int i = 0; i < 3; i++) {
        explicit_sizes[i] = IO_read_bits(in, 16);
    }

    istream_t streams[4];
    for (int i = 0; i < 3; i++) {
        streams[i] = IO_make_sub_istream(in, explicit_sizes[i]);
    }
    // The fourth stream is whatever input remains
    streams[3] = IO_make_sub_istream(in, IO_istream_len(in));

    // The streams are decoded sequentially for simplicity; they are
    // independent, so a faster decoder could interleave all four
    size_t total_output = 0;
    for (int i = 0; i < 4; i++) {
        total_output += HUF_decompress_1stream(dtable, out, &streams[i]);
    }

    return total_output;
}
18900c16b537SWarner Losh 
18910c16b537SWarner Losh /// Initializes a Huffman table using canonical Huffman codes
18920c16b537SWarner Losh /// For more explanation on canonical Huffman codes see
18930c16b537SWarner Losh /// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
18940c16b537SWarner Losh /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
18950c16b537SWarner Losh /// earlier codes)
HUF_init_dtable(HUF_dtable * const table,const u8 * const bits,const int num_symbs)18960c16b537SWarner Losh static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
18970c16b537SWarner Losh                             const int num_symbs) {
18980c16b537SWarner Losh     memset(table, 0, sizeof(HUF_dtable));
18990c16b537SWarner Losh     if (num_symbs > HUF_MAX_SYMBS) {
19000c16b537SWarner Losh         ERROR("Too many symbols for Huffman");
19010c16b537SWarner Losh     }
19020c16b537SWarner Losh 
19030c16b537SWarner Losh     u8 max_bits = 0;
19040c16b537SWarner Losh     u16 rank_count[HUF_MAX_BITS + 1];
19050c16b537SWarner Losh     memset(rank_count, 0, sizeof(rank_count));
19060c16b537SWarner Losh 
19070c16b537SWarner Losh     // Count the number of symbols for each number of bits, and determine the
19080c16b537SWarner Losh     // depth of the tree
19090c16b537SWarner Losh     for (int i = 0; i < num_symbs; i++) {
19100c16b537SWarner Losh         if (bits[i] > HUF_MAX_BITS) {
19110c16b537SWarner Losh             ERROR("Huffman table depth too large");
19120c16b537SWarner Losh         }
19130c16b537SWarner Losh         max_bits = MAX(max_bits, bits[i]);
19140c16b537SWarner Losh         rank_count[bits[i]]++;
19150c16b537SWarner Losh     }
19160c16b537SWarner Losh 
19170c16b537SWarner Losh     const size_t table_size = 1 << max_bits;
19180c16b537SWarner Losh     table->max_bits = max_bits;
19190c16b537SWarner Losh     table->symbols = malloc(table_size);
19200c16b537SWarner Losh     table->num_bits = malloc(table_size);
19210c16b537SWarner Losh 
19220c16b537SWarner Losh     if (!table->symbols || !table->num_bits) {
19230c16b537SWarner Losh         free(table->symbols);
19240c16b537SWarner Losh         free(table->num_bits);
19250c16b537SWarner Losh         BAD_ALLOC();
19260c16b537SWarner Losh     }
19270c16b537SWarner Losh 
19280c16b537SWarner Losh     // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
19290c16b537SWarner Losh     // order. Symbols with a Weight of zero are removed. Then, starting from
19300c16b537SWarner Losh     // lowest weight, prefix codes are distributed in order."
19310c16b537SWarner Losh 
19320c16b537SWarner Losh     u32 rank_idx[HUF_MAX_BITS + 1];
19330c16b537SWarner Losh     // Initialize the starting codes for each rank (number of bits)
19340c16b537SWarner Losh     rank_idx[max_bits] = 0;
19350c16b537SWarner Losh     for (int i = max_bits; i >= 1; i--) {
19360c16b537SWarner Losh         rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
19370c16b537SWarner Losh         // The entire range takes the same number of bits so we can memset it
19380c16b537SWarner Losh         memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
19390c16b537SWarner Losh     }
19400c16b537SWarner Losh 
19410c16b537SWarner Losh     if (rank_idx[0] != table_size) {
19420c16b537SWarner Losh         CORRUPTION();
19430c16b537SWarner Losh     }
19440c16b537SWarner Losh 
19450c16b537SWarner Losh     // Allocate codes and fill in the table
19460c16b537SWarner Losh     for (int i = 0; i < num_symbs; i++) {
19470c16b537SWarner Losh         if (bits[i] != 0) {
19480c16b537SWarner Losh             // Allocate a code for this symbol and set its range in the table
19490c16b537SWarner Losh             const u16 code = rank_idx[bits[i]];
19500c16b537SWarner Losh             // Since the code doesn't care about the bottom `max_bits - bits[i]`
19510c16b537SWarner Losh             // bits of state, it gets a range that spans all possible values of
19520c16b537SWarner Losh             // the lower bits
19530c16b537SWarner Losh             const u16 len = 1 << (max_bits - bits[i]);
19540c16b537SWarner Losh             memset(&table->symbols[code], i, len);
19550c16b537SWarner Losh             rank_idx[bits[i]] += len;
19560c16b537SWarner Losh         }
19570c16b537SWarner Losh     }
19580c16b537SWarner Losh }
19590c16b537SWarner Losh 
/// Build a Huffman decoding table from the weights transmitted in a header.
///
/// `weights` holds `num_symbs` entries. The weight of one additional, final
/// symbol is not transmitted; it is reconstructed here so that the sum of
/// `2^(weight - 1)` over all symbols is a power of 2. Each weight is then
/// converted to a bit count and the resulting `num_symbs + 1` bit counts are
/// handed to `HUF_init_dtable`.
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs) {
    // One extra symbol: the last weight is implied rather than transmitted
    const int total_symbs = num_symbs + 1;
    if (total_symbs > HUF_MAX_SYMBS) {
        ERROR("Too many symbols for Huffman");
    }

    // Validate every transmitted weight and accumulate 2^(weight - 1) for the
    // non-zero ones
    u64 weight_sum = 0;
    for (int s = 0; s < num_symbs; s++) {
        const u8 w = weights[s];
        // Weights share their valid range with bit counts
        if (w > HUF_MAX_BITS) {
            CORRUPTION();
        }
        if (w != 0) {
            weight_sum += (u64)1 << (w - 1);
        }
    }

    // The completed sum must be the first power of 2 above the partial sum
    const int max_bits = highest_set_bit(weight_sum) + 1;
    const u64 left_over = ((u64)1 << max_bits) - weight_sum;
    // The missing symbol contributes exactly one 2^(weight - 1) term, so any
    // gap that is not a power of 2 means the weights are invalid
    if ((left_over & (left_over - 1)) != 0) {
        CORRUPTION();
    }

    // Invert 2^(weight - 1) == left_over to recover the implied last weight
    const int last_weight = highest_set_bit(left_over) + 1;

    // Weight -> bit count: a zero weight stays zero (symbol absent), otherwise
    // the bit count is max_bits + 1 - weight
    u8 bits[HUF_MAX_SYMBS];
    for (int s = 0; s < num_symbs; s++) {
        const u8 w = weights[s];
        if (w == 0) {
            bits[s] = 0;
        } else {
            bits[s] = (u8)(max_bits + 1 - w);
        }
    }
    // The implied last weight is never zero, so no special case is needed
    bits[num_symbs] = (u8)(max_bits + 1 - last_weight);

    HUF_init_dtable(table, bits, total_symbs);
}
20000c16b537SWarner Losh 
/// Release the arrays owned by a Huffman decoding table and zero the struct so
/// that no stale pointers remain in it.
static void HUF_free_dtable(HUF_dtable *const dtable) {
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof *dtable);
}
20060c16b537SWarner Losh /******* END HUFFMAN PRIMITIVES ***********************************************/
20070c16b537SWarner Losh 
20080c16b537SWarner Losh /******* FSE PRIMITIVES *******************************************************/
20090c16b537SWarner Losh /// For more description of FSE see
20100c16b537SWarner Losh /// https://github.com/Cyan4973/FiniteStateEntropy/
20110c16b537SWarner Losh 
20120c16b537SWarner Losh /// Allow a symbol to be decoded without updating state
FSE_peek_symbol(const FSE_dtable * const dtable,const u16 state)20130c16b537SWarner Losh static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
20140c16b537SWarner Losh                                  const u16 state) {
20150c16b537SWarner Losh     return dtable->symbols[state];
20160c16b537SWarner Losh }
20170c16b537SWarner Losh 
20180c16b537SWarner Losh /// Consumes bits from the input and uses the current state to determine the
20190c16b537SWarner Losh /// next state
FSE_update_state(const FSE_dtable * const dtable,u16 * const state,const u8 * const src,i64 * const offset)20200c16b537SWarner Losh static inline void FSE_update_state(const FSE_dtable *const dtable,
20210c16b537SWarner Losh                                     u16 *const state, const u8 *const src,
20220c16b537SWarner Losh                                     i64 *const offset) {
20230c16b537SWarner Losh     const u8 bits = dtable->num_bits[*state];
20240c16b537SWarner Losh     const u16 rest = STREAM_read_bits(src, bits, offset);
20250c16b537SWarner Losh     *state = dtable->new_state_base[*state] + rest;
20260c16b537SWarner Losh }
20270c16b537SWarner Losh 
20280c16b537SWarner Losh /// Decodes a single FSE symbol and updates the offset
FSE_decode_symbol(const FSE_dtable * const dtable,u16 * const state,const u8 * const src,i64 * const offset)20290c16b537SWarner Losh static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
20300c16b537SWarner Losh                                    u16 *const state, const u8 *const src,
20310c16b537SWarner Losh                                    i64 *const offset) {
20320c16b537SWarner Losh     const u8 symb = FSE_peek_symbol(dtable, *state);
20330c16b537SWarner Losh     FSE_update_state(dtable, state, src, offset);
20340c16b537SWarner Losh     return symb;
20350c16b537SWarner Losh }
20360c16b537SWarner Losh 
/// Initialize an FSE state by reading a full `accuracy_log` bits from the
/// stream
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    const u8 nb_bits = dtable->accuracy_log;
    const u16 initial = STREAM_read_bits(src, nb_bits, offset);
    *state = initial;
}
20440c16b537SWarner Losh 
/// Decode an FSE bitstream that was encoded with two interleaved states,
/// emitting one byte per decoded symbol to `out`.
///
/// All bytes remaining in `in` are treated as the bitstream. An empty input is
/// reported through `INP_SIZE()`, and a stream whose last byte is zero (i.e.
/// has no final-bit-flag) is reported through `CORRUPTION()`.
///
/// Returns the number of symbols written to `out`.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // "Each bitstream must be read backward, that is starting from the end down
    // to the beginning. Therefore it's necessary to know the size of each
    // bitstream.
    //
    // It's also necessary to know exactly which bit is the latest. This is
    // detected by a final bit flag : the highest bit of latest byte is a
    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
    // final-bit-flag itself is not part of the useful bitstream. Hence, the
    // last byte contains between 0 and 7 useful bits."
    const u8 last_byte = src[len - 1];
    if (last_byte == 0) {
        // No final-bit-flag present: the position of the first useful bit
        // cannot be determined, so the stream is invalid
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(last_byte);
    i64 offset = len * 8 - padding;

    u16 state1, state2;
    // "The first state (State1) encodes the even indexed symbols, and the
    // second (State2) encodes the odd indexes. State1 is initialized first, and
    // then State2, and they take turns decoding a single symbol and updating
    // their state."
    FSE_init_state(dtable, &state1, src, &offset);
    FSE_init_state(dtable, &state2, src, &offset);

    // Decode until we overflow the stream
    // Since we decode in reverse order, overflowing the stream is offset going
    // negative
    size_t symbols_written = 0;
    while (1) {
        // "The number of symbols to decode is determined by tracking bitStream
        // overflow condition: If updating state after decoding a symbol would
        // require more bits than remain in the stream, it is assumed the extra
        // bits are 0. Then, the symbols for each of the final states are
        // decoded and the process is complete."
        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // There's still a symbol to decode in state2
            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
            symbols_written++;
            break;
        }

        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // There's still a symbol to decode in state1
            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
            symbols_written++;
            break;
        }
    }

    return symbols_written;
}
21050c16b537SWarner Losh 
/// Allocate and fill an FSE decoding table of `1 << accuracy_log` cells from
/// the normalized frequencies of `num_symbs` symbols.
///
/// A frequency of -1 marks a "less than 1" probability symbol, 0 marks an
/// absent symbol, and a positive value is the number of cells the symbol
/// receives. The three arrays allocated here are released by
/// `FSE_free_dtable`.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log) {
    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
        ERROR("FSE accuracy too large");
    }
    if (num_symbs > FSE_MAX_SYMBS) {
        ERROR("Too many symbols for FSE");
    }

    const size_t size = (size_t)1 << accuracy_log;

    dtable->accuracy_log = accuracy_log;
    dtable->symbols = malloc(size * sizeof(u8));
    dtable->num_bits = malloc(size * sizeof(u8));
    dtable->new_state_base = malloc(size * sizeof(u16));
    if (!(dtable->symbols && dtable->num_bits && dtable->new_state_base)) {
        BAD_ALLOC();
    }

    // Per-symbol running counter from which both the bit count and the
    // baseline of each of the symbol's cells are derived. It starts at the
    // symbol's cell count and is bumped once per cell, so it needs to be wider
    // than a byte.
    u16 state_desc[FSE_MAX_SYMBS];

    // "Symbols are scanned in their natural order for "less than 1"
    // probabilities. Symbols with this probability are being attributed a
    // single cell, starting from the end of the table. These symbols define a
    // full state reset, reading Accuracy_Log bits."
    int high_threshold = size;
    for (int sym = 0; sym < num_symbs; sym++) {
        if (norm_freqs[sym] != -1) {
            continue;
        }
        // Low probability symbols are stacked downward from the table's end
        high_threshold--;
        dtable->symbols[high_threshold] = sym;
        state_desc[sym] = 1;
    }

    // "All remaining symbols are sorted in their natural order. Starting from
    // symbol 0 and table position 0, each symbol gets attributed as many cells
    // as its probability. Cell allocation is spread, not linear."
    const u16 step = (size >> 1) + (size >> 3) + 3;
    const u16 mask = size - 1;
    u16 pos = 0;
    for (int sym = 0; sym < num_symbs; sym++) {
        const i16 freq = norm_freqs[sym];
        if (freq > 0) {
            state_desc[sym] = freq;

            // Hand `freq` cells to this symbol
            for (int n = 0; n < freq; n++) {
                dtable->symbols[pos] = sym;
                // "A position is skipped if already occupied, typically by a
                // "less than 1" probability symbol."
                // Only the reserved top region needs checking: `step` is
                // coprime to `size`, so the walk reaches each position exactly
                // once per cycle.
                do {
                    pos = (pos + step) & mask;
                } while (pos >= high_threshold);
            }
        }
    }
    // The walk must land back on its starting position
    if (pos != 0) {
        CORRUPTION();
    }

    // Derive the bit count and baseline of every cell
    for (size_t cell = 0; cell < size; cell++) {
        const u8 sym = dtable->symbols[cell];
        // Take the symbol's counter for this cell, then advance it for the
        // symbol's next cell
        const u16 desc = state_desc[sym];
        state_desc[sym] = desc + 1;

        // As the counter grows over a symbol's cells, fewer bits are read
        const u8 nb_bits = (u8)(accuracy_log - highest_set_bit(desc));
        dtable->num_bits[cell] = nb_bits;
        // The baseline climbs until the bit count drops, where it restarts
        // from 0
        dtable->new_state_base[cell] = ((u16)desc << nb_bits) - size;
    }
}
21910c16b537SWarner Losh 
21920c16b537SWarner Losh /// Decode an FSE header as defined in the Zstandard format specification and
21930c16b537SWarner Losh /// use the decoded frequencies to initialize a decoding table.
FSE_decode_header(FSE_dtable * const dtable,istream_t * const in,const int max_accuracy_log)21940c16b537SWarner Losh static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
21950c16b537SWarner Losh                                 const int max_accuracy_log) {
21960c16b537SWarner Losh     // "An FSE distribution table describes the probabilities of all symbols
21970c16b537SWarner Losh     // from 0 to the last present one (included) on a normalized scale of 1 <<
21980c16b537SWarner Losh     // Accuracy_Log .
21990c16b537SWarner Losh     //
22000c16b537SWarner Losh     // It's a bitstream which is read forward, in little-endian fashion. It's
22010c16b537SWarner Losh     // not necessary to know its exact size, since it will be discovered and
22020c16b537SWarner Losh     // reported by the decoding process.
22030c16b537SWarner Losh     if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
22040c16b537SWarner Losh         ERROR("FSE accuracy too large");
22050c16b537SWarner Losh     }
22060c16b537SWarner Losh 
22070c16b537SWarner Losh     // The bitstream starts by reporting on which scale it operates.
22080c16b537SWarner Losh     // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
22090c16b537SWarner Losh     // and match lengths is 9, and for offsets is 8. Higher values are
22100c16b537SWarner Losh     // considered errors."
22110c16b537SWarner Losh     const int accuracy_log = 5 + IO_read_bits(in, 4);
22120c16b537SWarner Losh     if (accuracy_log > max_accuracy_log) {
22130c16b537SWarner Losh         ERROR("FSE accuracy too large");
22140c16b537SWarner Losh     }
22150c16b537SWarner Losh 
22160c16b537SWarner Losh     // "Then follows each symbol value, from 0 to last present one. The number
22170c16b537SWarner Losh     // of bits used by each field is variable. It depends on :
22180c16b537SWarner Losh     //
22190c16b537SWarner Losh     // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
22200c16b537SWarner Losh     // and presuming 100 probabilities points have already been distributed, the
22210c16b537SWarner Losh     // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
22220c16b537SWarner Losh     // Therefore, it must read log2sup(156) == 8 bits.
22230c16b537SWarner Losh     //
22240c16b537SWarner Losh     // Value decoded : small values use 1 less bit : example : Presuming values
22250c16b537SWarner Losh     // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
22260c16b537SWarner Losh     // in an 8-bits field. They are used this way : first 99 values (hence from
22270c16b537SWarner Losh     // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
22280c16b537SWarner Losh 
22290c16b537SWarner Losh     i32 remaining = 1 << accuracy_log;
22300c16b537SWarner Losh     i16 frequencies[FSE_MAX_SYMBS];
22310c16b537SWarner Losh 
22320c16b537SWarner Losh     int symb = 0;
22330c16b537SWarner Losh     while (remaining > 0 && symb < FSE_MAX_SYMBS) {
22340c16b537SWarner Losh         // Log of the number of possible values we could read
22350c16b537SWarner Losh         int bits = highest_set_bit(remaining + 1) + 1;
22360c16b537SWarner Losh 
22370c16b537SWarner Losh         u16 val = IO_read_bits(in, bits);
22380c16b537SWarner Losh 
22390c16b537SWarner Losh         // Try to mask out the lower bits to see if it qualifies for the "small
22400c16b537SWarner Losh         // value" threshold
22410c16b537SWarner Losh         const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
22420c16b537SWarner Losh         const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
22430c16b537SWarner Losh 
22440c16b537SWarner Losh         if ((val & lower_mask) < threshold) {
22450c16b537SWarner Losh             IO_rewind_bits(in, 1);
22460c16b537SWarner Losh             val = val & lower_mask;
22470c16b537SWarner Losh         } else if (val > lower_mask) {
22480c16b537SWarner Losh             val = val - threshold;
22490c16b537SWarner Losh         }
22500c16b537SWarner Losh 
22510c16b537SWarner Losh         // "Probability is obtained from Value decoded by following formula :
22520c16b537SWarner Losh         // Proba = value - 1"
22530c16b537SWarner Losh         const i16 proba = (i16)val - 1;
22540c16b537SWarner Losh 
22550c16b537SWarner Losh         // "It means value 0 becomes negative probability -1. -1 is a special
22560c16b537SWarner Losh         // probability, which means "less than 1". Its effect on distribution
22570c16b537SWarner Losh         // table is described in next paragraph. For the purpose of calculating
22580c16b537SWarner Losh         // cumulated distribution, it counts as one."
22590c16b537SWarner Losh         remaining -= proba < 0 ? -proba : proba;
22600c16b537SWarner Losh 
22610c16b537SWarner Losh         frequencies[symb] = proba;
22620c16b537SWarner Losh         symb++;
22630c16b537SWarner Losh 
22640c16b537SWarner Losh         // "When a symbol has a probability of zero, it is followed by a 2-bits
22650c16b537SWarner Losh         // repeat flag. This repeat flag tells how many probabilities of zeroes
22660c16b537SWarner Losh         // follow the current one. It provides a number ranging from 0 to 3. If
22670c16b537SWarner Losh         // it is a 3, another 2-bits repeat flag follows, and so on."
22680c16b537SWarner Losh         if (proba == 0) {
22690c16b537SWarner Losh             // Read the next two bits to see how many more 0s
22700c16b537SWarner Losh             int repeat = IO_read_bits(in, 2);
22710c16b537SWarner Losh 
22720c16b537SWarner Losh             while (1) {
22730c16b537SWarner Losh                 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
22740c16b537SWarner Losh                     frequencies[symb++] = 0;
22750c16b537SWarner Losh                 }
22760c16b537SWarner Losh                 if (repeat == 3) {
22770c16b537SWarner Losh                     repeat = IO_read_bits(in, 2);
22780c16b537SWarner Losh                 } else {
22790c16b537SWarner Losh                     break;
22800c16b537SWarner Losh                 }
22810c16b537SWarner Losh             }
22820c16b537SWarner Losh         }
22830c16b537SWarner Losh     }
22840c16b537SWarner Losh     IO_align_stream(in);
22850c16b537SWarner Losh 
22860c16b537SWarner Losh     // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
22870c16b537SWarner Losh     // is complete. If the last symbol makes cumulated total go above 1 <<
22880c16b537SWarner Losh     // Accuracy_Log, distribution is considered corrupted."
22890c16b537SWarner Losh     if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
22900c16b537SWarner Losh         CORRUPTION();
22910c16b537SWarner Losh     }
22920c16b537SWarner Losh 
22930c16b537SWarner Losh     // Initialize the decoding table using the determined weights
22940c16b537SWarner Losh     FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
22950c16b537SWarner Losh }
22960c16b537SWarner Losh 
/// Initialize `dtable` as a single-cell table for one repeated symbol: the
/// only state is 0, it always yields `symb`, and no bits are ever consumed.
/// The arrays allocated here are released by `FSE_free_dtable`.
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
    dtable->symbols = malloc(sizeof(u8));
    dtable->num_bits = malloc(sizeof(u8));
    dtable->new_state_base = malloc(sizeof(u16));
    if (!(dtable->symbols && dtable->num_bits && dtable->new_state_base)) {
        BAD_ALLOC();
    }

    // accuracy_log of 0 means initializing a state reads zero bits
    dtable->accuracy_log = 0;
    // State 0 decodes to `symb`, reads no bits, and transitions back to 0
    dtable->new_state_base[0] = 0;
    dtable->num_bits[0] = 0;
    dtable->symbols[0] = symb;
}
23130c16b537SWarner Losh 
/// Release the three arrays owned by an FSE decoding table and zero the struct
/// so that no stale pointers remain in it.
static void FSE_free_dtable(FSE_dtable *const dtable) {
    free(dtable->new_state_base);
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof *dtable);
}
23200c16b537SWarner Losh /******* END FSE PRIMITIVES ***************************************************/
2321