1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under both the BSD-style license (found in the 6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7 * in the COPYING file in the root directory of this source tree). 8 * You may select, at your option, one of the above-listed licenses. 9 */ 10 11 /* 12 This program takes a file in input, 13 performs a zstd round-trip test (compression - decompress) 14 compares the result with original 15 and generates a crash (double free) on corruption detection. 16 */ 17 18 /*=========================================== 19 * Dependencies 20 *==========================================*/ 21 #include <stddef.h> /* size_t */ 22 #include <stdlib.h> /* malloc, free, exit */ 23 #include <stdio.h> /* fprintf */ 24 #include <string.h> /* strcmp */ 25 #include <sys/types.h> /* stat */ 26 #include <sys/stat.h> /* stat */ 27 #include "xxhash.h" 28 29 #define ZSTD_STATIC_LINKING_ONLY 30 #include "zstd.h" 31 32 /*=========================================== 33 * Macros 34 *==========================================*/ 35 #define MIN(a,b) ( (a) < (b) ? (a) : (b) ) 36 37 static void crash(int errorCode){ 38 /* abort if AFL/libfuzzer, exit otherwise */ 39 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION /* could also use __AFL_COMPILER */ 40 abort(); 41 #else 42 exit(errorCode); 43 #endif 44 } 45 46 #define CHECK_Z(f) { \ 47 size_t const err = f; \ 48 if (ZSTD_isError(err)) { \ 49 fprintf(stderr, \ 50 "Error=> %s: %s", \ 51 #f, ZSTD_getErrorName(err)); \ 52 crash(1); \ 53 } } 54 55 /** roundTripTest() : 56 * Compresses `srcBuff` into `compressedBuff`, 57 * then decompresses `compressedBuff` into `resultBuff`. 58 * Compression level used is derived from first content byte. 59 * @return : result of decompression, which should be == `srcSize` 60 * or an error code if either compression or decompression fails. 61 * Note : `compressedBuffCapacity` should be `>= ZSTD_compressBound(srcSize)` 62 * for compression to be guaranteed to work */ 63 static size_t roundTripTest(void* resultBuff, size_t resultBuffCapacity, 64 void* compressedBuff, size_t compressedBuffCapacity, 65 const void* srcBuff, size_t srcBuffSize) 66 { 67 static const int maxClevel = 19; 68 size_t const hashLength = MIN(128, srcBuffSize); 69 unsigned const h32 = XXH32(srcBuff, hashLength, 0); 70 int const cLevel = h32 % maxClevel; 71 size_t const cSize = ZSTD_compress(compressedBuff, compressedBuffCapacity, srcBuff, srcBuffSize, cLevel); 72 if (ZSTD_isError(cSize)) { 73 fprintf(stderr, "Compression error : %s \n", ZSTD_getErrorName(cSize)); 74 return cSize; 75 } 76 return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, cSize); 77 } 78 79 /** cctxParamRoundTripTest() : 80 * Same as roundTripTest() except allows experimenting with ZSTD_CCtx_params. */ 81 static size_t cctxParamRoundTripTest(void* resultBuff, size_t resultBuffCapacity, 82 void* compressedBuff, size_t compressedBuffCapacity, 83 const void* srcBuff, size_t srcBuffSize) 84 { 85 ZSTD_CCtx* const cctx = ZSTD_createCCtx(); 86 ZSTD_CCtx_params* const cctxParams = ZSTD_createCCtxParams(); 87 ZSTD_inBuffer inBuffer = { srcBuff, srcBuffSize, 0 }; 88 ZSTD_outBuffer outBuffer = { compressedBuff, compressedBuffCapacity, 0 }; 89 90 static const int maxClevel = 19; 91 size_t const hashLength = MIN(128, srcBuffSize); 92 unsigned const h32 = XXH32(srcBuff, hashLength, 0); 93 int const cLevel = h32 % maxClevel; 94 95 /* Set parameters */ 96 CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_compressionLevel, cLevel) ); 97 CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_nbWorkers, 2) ); 98 CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_overlapLog, 5) ); 99 100 101 /* Apply parameters */ 102 CHECK_Z( ZSTD_CCtx_setParametersUsingCCtxParams(cctx, cctxParams) ); 103 104 CHECK_Z (ZSTD_compressStream2(cctx, &outBuffer, &inBuffer, ZSTD_e_end) ); 105 106 ZSTD_freeCCtxParams(cctxParams); 107 ZSTD_freeCCtx(cctx); 108 109 return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, outBuffer.pos); 110 } 111 112 static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize) 113 { 114 const char* ip1 = (const char*)buff1; 115 const char* ip2 = (const char*)buff2; 116 size_t pos; 117 118 for (pos=0; pos<buffSize; pos++) 119 if (ip1[pos]!=ip2[pos]) 120 break; 121 122 return pos; 123 } 124 125 static void roundTripCheck(const void* srcBuff, size_t srcBuffSize, int testCCtxParams) 126 { 127 size_t const cBuffSize = ZSTD_compressBound(srcBuffSize); 128 void* cBuff = malloc(cBuffSize); 129 void* rBuff = malloc(cBuffSize); 130 131 if (!cBuff || !rBuff) { 132 fprintf(stderr, "not enough memory ! \n"); 133 exit (1); 134 } 135 136 { size_t const result = testCCtxParams ? 137 cctxParamRoundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize) 138 : roundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize); 139 if (ZSTD_isError(result)) { 140 fprintf(stderr, "roundTripTest error : %s \n", ZSTD_getErrorName(result)); 141 crash(1); 142 } 143 if (result != srcBuffSize) { 144 fprintf(stderr, "Incorrect regenerated size : %u != %u\n", (unsigned)result, (unsigned)srcBuffSize); 145 crash(1); 146 } 147 if (checkBuffers(srcBuff, rBuff, srcBuffSize) != srcBuffSize) { 148 fprintf(stderr, "Silent decoding corruption !!!"); 149 crash(1); 150 } 151 } 152 153 free(cBuff); 154 free(rBuff); 155 } 156 157 158 static size_t getFileSize(const char* infilename) 159 { 160 int r; 161 #if defined(_MSC_VER) 162 struct _stat64 statbuf; 163 r = _stat64(infilename, &statbuf); 164 if (r || !(statbuf.st_mode & S_IFREG)) return 0; /* No good... */ 165 #else 166 struct stat statbuf; 167 r = stat(infilename, &statbuf); 168 if (r || !S_ISREG(statbuf.st_mode)) return 0; /* No good... */ 169 #endif 170 return (size_t)statbuf.st_size; 171 } 172 173 174 static int isDirectory(const char* infilename) 175 { 176 int r; 177 #if defined(_MSC_VER) 178 struct _stat64 statbuf; 179 r = _stat64(infilename, &statbuf); 180 if (!r && (statbuf.st_mode & _S_IFDIR)) return 1; 181 #else 182 struct stat statbuf; 183 r = stat(infilename, &statbuf); 184 if (!r && S_ISDIR(statbuf.st_mode)) return 1; 185 #endif 186 return 0; 187 } 188 189 190 /** loadFile() : 191 * requirement : `buffer` size >= `fileSize` */ 192 static void loadFile(void* buffer, const char* fileName, size_t fileSize) 193 { 194 FILE* const f = fopen(fileName, "rb"); 195 if (isDirectory(fileName)) { 196 fprintf(stderr, "Ignoring %s directory \n", fileName); 197 exit(2); 198 } 199 if (f==NULL) { 200 fprintf(stderr, "Impossible to open %s \n", fileName); 201 exit(3); 202 } 203 { size_t const readSize = fread(buffer, 1, fileSize, f); 204 if (readSize != fileSize) { 205 fprintf(stderr, "Error reading %s \n", fileName); 206 exit(5); 207 } } 208 fclose(f); 209 } 210 211 212 static void fileCheck(const char* fileName, int testCCtxParams) 213 { 214 size_t const fileSize = getFileSize(fileName); 215 void* const buffer = malloc(fileSize + !fileSize /* avoid 0 */); 216 if (!buffer) { 217 fprintf(stderr, "not enough memory \n"); 218 exit(4); 219 } 220 loadFile(buffer, fileName, fileSize); 221 roundTripCheck(buffer, fileSize, testCCtxParams); 222 free (buffer); 223 } 224 225 int main(int argCount, const char** argv) { 226 int argNb = 1; 227 int testCCtxParams = 0; 228 if (argCount < 2) { 229 fprintf(stderr, "Error : no argument : need input file \n"); 230 exit(9); 231 } 232 233 if (!strcmp(argv[argNb], "--cctxParams")) { 234 testCCtxParams = 1; 235 argNb++; 236 } 237 238 fileCheck(argv[argNb], testCCtxParams); 239 fprintf(stderr, "no pb detected\n"); 240 return 0; 241 } 242