1*3117ece4Schristos /* 2*3117ece4Schristos * Copyright (c) Meta Platforms, Inc. and affiliates. 3*3117ece4Schristos * All rights reserved. 4*3117ece4Schristos * 5*3117ece4Schristos * This source code is licensed under both the BSD-style license (found in the 6*3117ece4Schristos * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7*3117ece4Schristos * in the COPYING file in the root directory of this source tree). 8*3117ece4Schristos * You may select, at your option, one of the above-listed licenses. 9*3117ece4Schristos */ 10*3117ece4Schristos 11*3117ece4Schristos /* 12*3117ece4Schristos This program takes a file in input, 13*3117ece4Schristos performs a zstd round-trip test (compression - decompress) 14*3117ece4Schristos compares the result with original 15*3117ece4Schristos and generates a crash (double free) on corruption detection. 16*3117ece4Schristos */ 17*3117ece4Schristos 18*3117ece4Schristos /*=========================================== 19*3117ece4Schristos * Dependencies 20*3117ece4Schristos *==========================================*/ 21*3117ece4Schristos #include <stddef.h> /* size_t */ 22*3117ece4Schristos #include <stdlib.h> /* malloc, free, exit */ 23*3117ece4Schristos #include <stdio.h> /* fprintf */ 24*3117ece4Schristos #include <string.h> /* strcmp */ 25*3117ece4Schristos #include <sys/types.h> /* stat */ 26*3117ece4Schristos #include <sys/stat.h> /* stat */ 27*3117ece4Schristos #include "xxhash.h" 28*3117ece4Schristos 29*3117ece4Schristos #define ZSTD_STATIC_LINKING_ONLY 30*3117ece4Schristos #include "zstd.h" 31*3117ece4Schristos 32*3117ece4Schristos /*=========================================== 33*3117ece4Schristos * Macros 34*3117ece4Schristos *==========================================*/ 35*3117ece4Schristos #define MIN(a,b) ( (a) < (b) ? 
(a) : (b) ) 36*3117ece4Schristos 37*3117ece4Schristos static void crash(int errorCode){ 38*3117ece4Schristos /* abort if AFL/libfuzzer, exit otherwise */ 39*3117ece4Schristos #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION /* could also use __AFL_COMPILER */ 40*3117ece4Schristos abort(); 41*3117ece4Schristos #else 42*3117ece4Schristos exit(errorCode); 43*3117ece4Schristos #endif 44*3117ece4Schristos } 45*3117ece4Schristos 46*3117ece4Schristos #define CHECK_Z(f) { \ 47*3117ece4Schristos size_t const err = f; \ 48*3117ece4Schristos if (ZSTD_isError(err)) { \ 49*3117ece4Schristos fprintf(stderr, \ 50*3117ece4Schristos "Error=> %s: %s", \ 51*3117ece4Schristos #f, ZSTD_getErrorName(err)); \ 52*3117ece4Schristos crash(1); \ 53*3117ece4Schristos } } 54*3117ece4Schristos 55*3117ece4Schristos /** roundTripTest() : 56*3117ece4Schristos * Compresses `srcBuff` into `compressedBuff`, 57*3117ece4Schristos * then decompresses `compressedBuff` into `resultBuff`. 58*3117ece4Schristos * Compression level used is derived from first content byte. 59*3117ece4Schristos * @return : result of decompression, which should be == `srcSize` 60*3117ece4Schristos * or an error code if either compression or decompression fails. 
61*3117ece4Schristos * Note : `compressedBuffCapacity` should be `>= ZSTD_compressBound(srcSize)` 62*3117ece4Schristos * for compression to be guaranteed to work */ 63*3117ece4Schristos static size_t roundTripTest(void* resultBuff, size_t resultBuffCapacity, 64*3117ece4Schristos void* compressedBuff, size_t compressedBuffCapacity, 65*3117ece4Schristos const void* srcBuff, size_t srcBuffSize) 66*3117ece4Schristos { 67*3117ece4Schristos static const int maxClevel = 19; 68*3117ece4Schristos size_t const hashLength = MIN(128, srcBuffSize); 69*3117ece4Schristos unsigned const h32 = XXH32(srcBuff, hashLength, 0); 70*3117ece4Schristos int const cLevel = h32 % maxClevel; 71*3117ece4Schristos size_t const cSize = ZSTD_compress(compressedBuff, compressedBuffCapacity, srcBuff, srcBuffSize, cLevel); 72*3117ece4Schristos if (ZSTD_isError(cSize)) { 73*3117ece4Schristos fprintf(stderr, "Compression error : %s \n", ZSTD_getErrorName(cSize)); 74*3117ece4Schristos return cSize; 75*3117ece4Schristos } 76*3117ece4Schristos return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, cSize); 77*3117ece4Schristos } 78*3117ece4Schristos 79*3117ece4Schristos /** cctxParamRoundTripTest() : 80*3117ece4Schristos * Same as roundTripTest() except allows experimenting with ZSTD_CCtx_params. 
*/ 81*3117ece4Schristos static size_t cctxParamRoundTripTest(void* resultBuff, size_t resultBuffCapacity, 82*3117ece4Schristos void* compressedBuff, size_t compressedBuffCapacity, 83*3117ece4Schristos const void* srcBuff, size_t srcBuffSize) 84*3117ece4Schristos { 85*3117ece4Schristos ZSTD_CCtx* const cctx = ZSTD_createCCtx(); 86*3117ece4Schristos ZSTD_CCtx_params* const cctxParams = ZSTD_createCCtxParams(); 87*3117ece4Schristos ZSTD_inBuffer inBuffer = { srcBuff, srcBuffSize, 0 }; 88*3117ece4Schristos ZSTD_outBuffer outBuffer = { compressedBuff, compressedBuffCapacity, 0 }; 89*3117ece4Schristos 90*3117ece4Schristos static const int maxClevel = 19; 91*3117ece4Schristos size_t const hashLength = MIN(128, srcBuffSize); 92*3117ece4Schristos unsigned const h32 = XXH32(srcBuff, hashLength, 0); 93*3117ece4Schristos int const cLevel = h32 % maxClevel; 94*3117ece4Schristos 95*3117ece4Schristos /* Set parameters */ 96*3117ece4Schristos CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_compressionLevel, cLevel) ); 97*3117ece4Schristos CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_nbWorkers, 2) ); 98*3117ece4Schristos CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_overlapLog, 5) ); 99*3117ece4Schristos 100*3117ece4Schristos 101*3117ece4Schristos /* Apply parameters */ 102*3117ece4Schristos CHECK_Z( ZSTD_CCtx_setParametersUsingCCtxParams(cctx, cctxParams) ); 103*3117ece4Schristos 104*3117ece4Schristos CHECK_Z (ZSTD_compressStream2(cctx, &outBuffer, &inBuffer, ZSTD_e_end) ); 105*3117ece4Schristos 106*3117ece4Schristos ZSTD_freeCCtxParams(cctxParams); 107*3117ece4Schristos ZSTD_freeCCtx(cctx); 108*3117ece4Schristos 109*3117ece4Schristos return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, outBuffer.pos); 110*3117ece4Schristos } 111*3117ece4Schristos 112*3117ece4Schristos static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize) 113*3117ece4Schristos { 114*3117ece4Schristos const char* ip1 = (const char*)buff1; 
115*3117ece4Schristos const char* ip2 = (const char*)buff2; 116*3117ece4Schristos size_t pos; 117*3117ece4Schristos 118*3117ece4Schristos for (pos=0; pos<buffSize; pos++) 119*3117ece4Schristos if (ip1[pos]!=ip2[pos]) 120*3117ece4Schristos break; 121*3117ece4Schristos 122*3117ece4Schristos return pos; 123*3117ece4Schristos } 124*3117ece4Schristos 125*3117ece4Schristos static void roundTripCheck(const void* srcBuff, size_t srcBuffSize, int testCCtxParams) 126*3117ece4Schristos { 127*3117ece4Schristos size_t const cBuffSize = ZSTD_compressBound(srcBuffSize); 128*3117ece4Schristos void* cBuff = malloc(cBuffSize); 129*3117ece4Schristos void* rBuff = malloc(cBuffSize); 130*3117ece4Schristos 131*3117ece4Schristos if (!cBuff || !rBuff) { 132*3117ece4Schristos fprintf(stderr, "not enough memory ! \n"); 133*3117ece4Schristos exit (1); 134*3117ece4Schristos } 135*3117ece4Schristos 136*3117ece4Schristos { size_t const result = testCCtxParams ? 137*3117ece4Schristos cctxParamRoundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize) 138*3117ece4Schristos : roundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize); 139*3117ece4Schristos if (ZSTD_isError(result)) { 140*3117ece4Schristos fprintf(stderr, "roundTripTest error : %s \n", ZSTD_getErrorName(result)); 141*3117ece4Schristos crash(1); 142*3117ece4Schristos } 143*3117ece4Schristos if (result != srcBuffSize) { 144*3117ece4Schristos fprintf(stderr, "Incorrect regenerated size : %u != %u\n", (unsigned)result, (unsigned)srcBuffSize); 145*3117ece4Schristos crash(1); 146*3117ece4Schristos } 147*3117ece4Schristos if (checkBuffers(srcBuff, rBuff, srcBuffSize) != srcBuffSize) { 148*3117ece4Schristos fprintf(stderr, "Silent decoding corruption !!!"); 149*3117ece4Schristos crash(1); 150*3117ece4Schristos } 151*3117ece4Schristos } 152*3117ece4Schristos 153*3117ece4Schristos free(cBuff); 154*3117ece4Schristos free(rBuff); 155*3117ece4Schristos } 156*3117ece4Schristos 157*3117ece4Schristos 158*3117ece4Schristos 
/** getFileSize() :
 *  @return : size of `infilename` as reported by stat(),
 *            or 0 when stat() fails or the path is not a regular file. */
static size_t getFileSize(const char* infilename)
{
    int r;
#if defined(_MSC_VER)
    struct _stat64 statbuf;
    r = _stat64(infilename, &statbuf);
    if (r || !(statbuf.st_mode & S_IFREG)) return 0;   /* No good... */
#else
    struct stat statbuf;
    r = stat(infilename, &statbuf);
    if (r || !S_ISREG(statbuf.st_mode)) return 0;   /* No good... */
#endif
    return (size_t)statbuf.st_size;
}


/** isDirectory() :
 *  @return : 1 when stat() succeeds and reports `infilename` as a directory,
 *            0 otherwise (including when stat() fails). */
static int isDirectory(const char* infilename)
{
    int r;
#if defined(_MSC_VER)
    struct _stat64 statbuf;
    r = _stat64(infilename, &statbuf);
    if (!r && (statbuf.st_mode & _S_IFDIR)) return 1;
#else
    struct stat statbuf;
    r = stat(infilename, &statbuf);
    if (!r && S_ISDIR(statbuf.st_mode)) return 1;
#endif
    return 0;
}


/** loadFile() :
 *  Reads exactly `fileSize` bytes from `fileName` into `buffer`.
 *  requirement : `buffer` size >= `fileSize`
 *  Exits the process with status 2 when `fileName` is a directory,
 *  3 when it cannot be opened, 5 when fewer than `fileSize` bytes are read. */
static void loadFile(void* buffer, const char* fileName, size_t fileSize)
{
    FILE* f;

    /* reject directories before attempting to open anything */
    if (isDirectory(fileName)) {
        fprintf(stderr, "Ignoring %s directory \n", fileName);
        exit(2);
    }
    f = fopen(fileName, "rb");
    if (f==NULL) {
        fprintf(stderr, "Impossible to open %s \n", fileName);
        exit(3);
    }
    {   size_t const readSize = fread(buffer, 1, fileSize, f);
        if (readSize != fileSize) {
            fprintf(stderr, "Error reading %s \n", fileName);
            exit(5);
    }   }
    fclose(f);
}


/** fileCheck() :
 *  Loads `fileName` in memory and runs roundTripCheck() on its content.
 *  `testCCtxParams` is forwarded as-is to roundTripCheck().
 *  Exits the process with status 4 when the buffer cannot be allocated. */
static void fileCheck(const char* fileName, int testCCtxParams)
{
    size_t const fileSize = getFileSize(fileName);
    void* const buffer = malloc(fileSize + !fileSize /* avoid 0 */);
    if (!buffer) {
        fprintf(stderr, "not enough memory \n");
        exit(4);
    }
    loadFile(buffer, fileName, fileSize);
    roundTripCheck(buffer, fileSize, testCCtxParams);
    free (buffer);
}

/** main() :
 *  usage : prog [--cctxParams] FILE
 *  `--cctxParams`, when given as first argument, selects the
 *  ZSTD_CCtx_params variant of the round-trip test.
 *  Exits with status 9 when no input file is provided.
 *  @return : 0 when the round trip completes without detecting any problem. */
int main(int argCount, const char** argv) {
    int argNb = 1;
    int testCCtxParams = 0;
    if (argCount < 2) {
        fprintf(stderr, "Error : no argument : need input file \n");
        exit(9);
    }

    if (!strcmp(argv[argNb], "--cctxParams")) {
        testCCtxParams = 1;
        argNb++;
    }

    /* the flag may have been the only argument :
     * argv[argCount] is NULL and must not be used as a file name */
    if (argNb >= argCount) {
        fprintf(stderr, "Error : no argument : need input file \n");
        exit(9);
    }

    fileCheck(argv[argNb], testCCtxParams);
    fprintf(stderr, "no pb detected\n");
    return 0;
}