1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under both the BSD-style license (found in the 6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7 * in the COPYING file in the root directory of this source tree). 8 * You may select, at your option, one of the above-listed licenses. 9 */ 10 11 #include "method.h" 12 13 #include <stdio.h> 14 #include <stdlib.h> 15 16 #define ZSTD_STATIC_LINKING_ONLY 17 #include <zstd.h> 18 19 #define MIN(x, y) ((x) < (y) ? (x) : (y)) 20 21 static char const* g_zstdcli = NULL; 22 23 void method_set_zstdcli(char const* zstdcli) { 24 g_zstdcli = zstdcli; 25 } 26 27 /** 28 * Macro to get a pointer of type, given ptr, which is a member variable with 29 * the given name, member. 30 * 31 * method_state_t* base = ...; 32 * buffer_state_t* state = container_of(base, buffer_state_t, base); 33 */ 34 #define container_of(ptr, type, member) \ 35 ((type*)(ptr == NULL ? NULL : (char*)(ptr)-offsetof(type, member))) 36 37 /** State to reuse the same buffers between compression calls. */ 38 typedef struct { 39 method_state_t base; 40 data_buffers_t inputs; /**< The input buffer for each file. */ 41 data_buffer_t dictionary; /**< The dictionary. */ 42 data_buffer_t compressed; /**< The compressed data buffer. */ 43 data_buffer_t decompressed; /**< The decompressed data buffer. */ 44 } buffer_state_t; 45 46 static size_t buffers_max_size(data_buffers_t buffers) { 47 size_t max = 0; 48 for (size_t i = 0; i < buffers.size; ++i) { 49 if (buffers.buffers[i].size > max) 50 max = buffers.buffers[i].size; 51 } 52 return max; 53 } 54 55 static method_state_t* buffer_state_create(data_t const* data) { 56 buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t)); 57 if (state == NULL) 58 return NULL; 59 state->base.data = data; 60 state->inputs = data_buffers_get(data); 61 state->dictionary = data_buffer_get_dict(data); 62 size_t const max_size = buffers_max_size(state->inputs); 63 state->compressed = data_buffer_create(ZSTD_compressBound(max_size)); 64 state->decompressed = data_buffer_create(max_size); 65 return &state->base; 66 } 67 68 static void buffer_state_destroy(method_state_t* base) { 69 if (base == NULL) 70 return; 71 buffer_state_t* state = container_of(base, buffer_state_t, base); 72 free(state); 73 } 74 75 static int buffer_state_bad( 76 buffer_state_t const* state, 77 config_t const* config) { 78 if (state == NULL) { 79 fprintf(stderr, "buffer_state_t is NULL\n"); 80 return 1; 81 } 82 if (state->inputs.size == 0 || state->compressed.data == NULL || 83 state->decompressed.data == NULL) { 84 fprintf(stderr, "buffer state allocation failure\n"); 85 return 1; 86 } 87 if (config->use_dictionary && state->dictionary.data == NULL) { 88 fprintf(stderr, "dictionary loading failed\n"); 89 return 1; 90 } 91 return 0; 92 } 93 94 static result_t simple_compress(method_state_t* base, config_t const* config) { 95 buffer_state_t* state = container_of(base, buffer_state_t, base); 96 97 if (buffer_state_bad(state, config)) 98 return result_error(result_error_system_error); 99 100 /* Keep the tests short by skipping directories, since behavior shouldn't 101 * change. 102 */ 103 if (base->data->type != data_type_file) 104 return result_error(result_error_skip); 105 106 if (config->advanced_api_only) 107 return result_error(result_error_skip); 108 109 if (config->use_dictionary || config->no_pledged_src_size) 110 return result_error(result_error_skip); 111 112 /* If the config doesn't specify a level, skip. */ 113 int const level = config_get_level(config); 114 if (level == CONFIG_NO_LEVEL) 115 return result_error(result_error_skip); 116 117 data_buffer_t const input = state->inputs.buffers[0]; 118 119 /* Compress, decompress, and check the result. */ 120 state->compressed.size = ZSTD_compress( 121 state->compressed.data, 122 state->compressed.capacity, 123 input.data, 124 input.size, 125 level); 126 if (ZSTD_isError(state->compressed.size)) 127 return result_error(result_error_compression_error); 128 129 state->decompressed.size = ZSTD_decompress( 130 state->decompressed.data, 131 state->decompressed.capacity, 132 state->compressed.data, 133 state->compressed.size); 134 if (ZSTD_isError(state->decompressed.size)) 135 return result_error(result_error_decompression_error); 136 if (data_buffer_compare(input, state->decompressed)) 137 return result_error(result_error_round_trip_error); 138 139 result_data_t data; 140 data.total_size = state->compressed.size; 141 return result_data(data); 142 } 143 144 static result_t compress_cctx_compress( 145 method_state_t* base, 146 config_t const* config) { 147 buffer_state_t* state = container_of(base, buffer_state_t, base); 148 149 if (buffer_state_bad(state, config)) 150 return result_error(result_error_system_error); 151 152 if (config->no_pledged_src_size) 153 return result_error(result_error_skip); 154 155 if (base->data->type != data_type_dir) 156 return result_error(result_error_skip); 157 158 if (config->advanced_api_only) 159 return result_error(result_error_skip); 160 161 int const level = config_get_level(config); 162 163 ZSTD_CCtx* cctx = ZSTD_createCCtx(); 164 ZSTD_DCtx* dctx = ZSTD_createDCtx(); 165 if (cctx == NULL || dctx == NULL) { 166 fprintf(stderr, "context creation failed\n"); 167 return result_error(result_error_system_error); 168 } 169 170 result_t result; 171 result_data_t data = {.total_size = 0}; 172 for (size_t i = 0; i < state->inputs.size; ++i) { 173 data_buffer_t const input = state->inputs.buffers[i]; 174 ZSTD_parameters const params = 175 config_get_zstd_params(config, input.size, state->dictionary.size); 176 177 if (level == CONFIG_NO_LEVEL) 178 state->compressed.size = ZSTD_compress_advanced( 179 cctx, 180 state->compressed.data, 181 state->compressed.capacity, 182 input.data, 183 input.size, 184 config->use_dictionary ? state->dictionary.data : NULL, 185 config->use_dictionary ? state->dictionary.size : 0, 186 params); 187 else if (config->use_dictionary) 188 state->compressed.size = ZSTD_compress_usingDict( 189 cctx, 190 state->compressed.data, 191 state->compressed.capacity, 192 input.data, 193 input.size, 194 state->dictionary.data, 195 state->dictionary.size, 196 level); 197 else 198 state->compressed.size = ZSTD_compressCCtx( 199 cctx, 200 state->compressed.data, 201 state->compressed.capacity, 202 input.data, 203 input.size, 204 level); 205 206 if (ZSTD_isError(state->compressed.size)) { 207 result = result_error(result_error_compression_error); 208 goto out; 209 } 210 211 if (config->use_dictionary) 212 state->decompressed.size = ZSTD_decompress_usingDict( 213 dctx, 214 state->decompressed.data, 215 state->decompressed.capacity, 216 state->compressed.data, 217 state->compressed.size, 218 state->dictionary.data, 219 state->dictionary.size); 220 else 221 state->decompressed.size = ZSTD_decompressDCtx( 222 dctx, 223 state->decompressed.data, 224 state->decompressed.capacity, 225 state->compressed.data, 226 state->compressed.size); 227 if (ZSTD_isError(state->decompressed.size)) { 228 result = result_error(result_error_decompression_error); 229 goto out; 230 } 231 if (data_buffer_compare(input, state->decompressed)) { 232 result = result_error(result_error_round_trip_error); 233 goto out; 234 } 235 236 data.total_size += state->compressed.size; 237 } 238 239 result = result_data(data); 240 out: 241 ZSTD_freeCCtx(cctx); 242 ZSTD_freeDCtx(dctx); 243 return result; 244 } 245 246 /** Generic state creation function. */ 247 static method_state_t* method_state_create(data_t const* data) { 248 method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t)); 249 if (state == NULL) 250 return NULL; 251 state->data = data; 252 return state; 253 } 254 255 static void method_state_destroy(method_state_t* state) { 256 free(state); 257 } 258 259 static result_t cli_compress(method_state_t* state, config_t const* config) { 260 if (config->cli_args == NULL) 261 return result_error(result_error_skip); 262 263 if (config->advanced_api_only) 264 return result_error(result_error_skip); 265 266 /* We don't support no pledged source size with directories. Too slow. */ 267 if (state->data->type == data_type_dir && config->no_pledged_src_size) 268 return result_error(result_error_skip); 269 270 if (g_zstdcli == NULL) 271 return result_error(result_error_system_error); 272 273 /* '<zstd>' -cqr <args> [-D '<dict>'] '<file/dir>' */ 274 char cmd[1024]; 275 size_t const cmd_size = snprintf( 276 cmd, 277 sizeof(cmd), 278 "'%s' -cqr %s %s%s%s %s '%s'", 279 g_zstdcli, 280 config->cli_args, 281 config->use_dictionary ? "-D '" : "", 282 config->use_dictionary ? state->data->dict.path : "", 283 config->use_dictionary ? "'" : "", 284 config->no_pledged_src_size ? "<" : "", 285 state->data->data.path); 286 if (cmd_size >= sizeof(cmd)) { 287 fprintf(stderr, "command too large: %s\n", cmd); 288 return result_error(result_error_system_error); 289 } 290 FILE* zstd = popen(cmd, "r"); 291 if (zstd == NULL) { 292 fprintf(stderr, "failed to popen command: %s\n", cmd); 293 return result_error(result_error_system_error); 294 } 295 296 char out[4096]; 297 size_t total_size = 0; 298 while (1) { 299 size_t const size = fread(out, 1, sizeof(out), zstd); 300 total_size += size; 301 if (size != sizeof(out)) 302 break; 303 } 304 if (ferror(zstd) || pclose(zstd) != 0) { 305 fprintf(stderr, "zstd failed with command: %s\n", cmd); 306 return result_error(result_error_compression_error); 307 } 308 309 result_data_t const data = {.total_size = total_size}; 310 return result_data(data); 311 } 312 313 static int advanced_config( 314 ZSTD_CCtx* cctx, 315 buffer_state_t* state, 316 config_t const* config) { 317 ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters); 318 for (size_t p = 0; p < config->param_values.size; ++p) { 319 param_value_t const pv = config->param_values.data[p]; 320 if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, pv.param, pv.value))) { 321 return 1; 322 } 323 } 324 if (config->use_dictionary) { 325 if (ZSTD_isError(ZSTD_CCtx_loadDictionary( 326 cctx, state->dictionary.data, state->dictionary.size))) { 327 return 1; 328 } 329 } 330 return 0; 331 } 332 333 static result_t advanced_one_pass_compress_output_adjustment( 334 method_state_t* base, 335 config_t const* config, 336 size_t const subtract) { 337 buffer_state_t* state = container_of(base, buffer_state_t, base); 338 339 if (buffer_state_bad(state, config)) 340 return result_error(result_error_system_error); 341 342 ZSTD_CCtx* cctx = ZSTD_createCCtx(); 343 result_t result; 344 345 if (!cctx || advanced_config(cctx, state, config)) { 346 result = result_error(result_error_compression_error); 347 goto out; 348 } 349 350 result_data_t data = {.total_size = 0}; 351 for (size_t i = 0; i < state->inputs.size; ++i) { 352 data_buffer_t const input = state->inputs.buffers[i]; 353 354 if (!config->no_pledged_src_size) { 355 if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) { 356 result = result_error(result_error_compression_error); 357 goto out; 358 } 359 } 360 size_t const size = ZSTD_compress2( 361 cctx, 362 state->compressed.data, 363 ZSTD_compressBound(input.size) - subtract, 364 input.data, 365 input.size); 366 if (ZSTD_isError(size)) { 367 result = result_error(result_error_compression_error); 368 goto out; 369 } 370 data.total_size += size; 371 } 372 373 result = result_data(data); 374 out: 375 ZSTD_freeCCtx(cctx); 376 return result; 377 } 378 379 static result_t advanced_one_pass_compress( 380 method_state_t* base, 381 config_t const* config) { 382 return advanced_one_pass_compress_output_adjustment(base, config, 0); 383 } 384 385 static result_t advanced_one_pass_compress_small_output( 386 method_state_t* base, 387 config_t const* config) { 388 return advanced_one_pass_compress_output_adjustment(base, config, 1); 389 } 390 391 static result_t advanced_streaming_compress( 392 method_state_t* base, 393 config_t const* config) { 394 buffer_state_t* state = container_of(base, buffer_state_t, base); 395 396 if (buffer_state_bad(state, config)) 397 return result_error(result_error_system_error); 398 399 ZSTD_CCtx* cctx = ZSTD_createCCtx(); 400 result_t result; 401 402 if (!cctx || advanced_config(cctx, state, config)) { 403 result = result_error(result_error_compression_error); 404 goto out; 405 } 406 407 result_data_t data = {.total_size = 0}; 408 for (size_t i = 0; i < state->inputs.size; ++i) { 409 data_buffer_t input = state->inputs.buffers[i]; 410 411 if (!config->no_pledged_src_size) { 412 if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) { 413 result = result_error(result_error_compression_error); 414 goto out; 415 } 416 } 417 418 while (input.size > 0) { 419 ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)}; 420 input.data += in.size; 421 input.size -= in.size; 422 ZSTD_EndDirective const op = 423 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end; 424 size_t ret = 0; 425 while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) { 426 ZSTD_outBuffer out = {state->compressed.data, 427 MIN(state->compressed.capacity, 1024)}; 428 ret = ZSTD_compressStream2(cctx, &out, &in, op); 429 if (ZSTD_isError(ret)) { 430 result = result_error(result_error_compression_error); 431 goto out; 432 } 433 data.total_size += out.pos; 434 } 435 } 436 } 437 438 result = result_data(data); 439 out: 440 ZSTD_freeCCtx(cctx); 441 return result; 442 } 443 444 static int init_cstream( 445 buffer_state_t* state, 446 ZSTD_CStream* zcs, 447 config_t const* config, 448 int const advanced, 449 ZSTD_CDict** cdict) 450 { 451 size_t zret; 452 if (advanced) { 453 ZSTD_parameters const params = config_get_zstd_params(config, 0, 0); 454 ZSTD_CDict* dict = NULL; 455 if (cdict) { 456 if (!config->use_dictionary) 457 return 1; 458 *cdict = ZSTD_createCDict_advanced( 459 state->dictionary.data, 460 state->dictionary.size, 461 ZSTD_dlm_byRef, 462 ZSTD_dct_auto, 463 params.cParams, 464 ZSTD_defaultCMem); 465 if (!*cdict) { 466 return 1; 467 } 468 zret = ZSTD_initCStream_usingCDict_advanced( 469 zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN); 470 } else { 471 zret = ZSTD_initCStream_advanced( 472 zcs, 473 config->use_dictionary ? state->dictionary.data : NULL, 474 config->use_dictionary ? state->dictionary.size : 0, 475 params, 476 ZSTD_CONTENTSIZE_UNKNOWN); 477 } 478 } else { 479 int const level = config_get_level(config); 480 if (level == CONFIG_NO_LEVEL) 481 return 1; 482 if (cdict) { 483 if (!config->use_dictionary) 484 return 1; 485 *cdict = ZSTD_createCDict( 486 state->dictionary.data, 487 state->dictionary.size, 488 level); 489 if (!*cdict) { 490 return 1; 491 } 492 zret = ZSTD_initCStream_usingCDict(zcs, *cdict); 493 } else if (config->use_dictionary) { 494 zret = ZSTD_initCStream_usingDict( 495 zcs, 496 state->dictionary.data, 497 state->dictionary.size, 498 level); 499 } else { 500 zret = ZSTD_initCStream(zcs, level); 501 } 502 } 503 if (ZSTD_isError(zret)) { 504 return 1; 505 } 506 return 0; 507 } 508 509 static result_t old_streaming_compress_internal( 510 method_state_t* base, 511 config_t const* config, 512 int const advanced, 513 int const cdict) { 514 buffer_state_t* state = container_of(base, buffer_state_t, base); 515 516 if (buffer_state_bad(state, config)) 517 return result_error(result_error_system_error); 518 519 520 ZSTD_CStream* zcs = ZSTD_createCStream(); 521 ZSTD_CDict* cd = NULL; 522 result_t result; 523 if (zcs == NULL) { 524 result = result_error(result_error_compression_error); 525 goto out; 526 } 527 if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) { 528 result = result_error(result_error_skip); 529 goto out; 530 } 531 if (cdict && !config->use_dictionary) { 532 result = result_error(result_error_skip); 533 goto out; 534 } 535 if (config->advanced_api_only) { 536 result = result_error(result_error_skip); 537 goto out; 538 } 539 if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) { 540 result = result_error(result_error_compression_error); 541 goto out; 542 } 543 544 result_data_t data = {.total_size = 0}; 545 for (size_t i = 0; i < state->inputs.size; ++i) { 546 data_buffer_t input = state->inputs.buffers[i]; 547 size_t zret = ZSTD_resetCStream( 548 zcs, 549 config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN : input.size); 550 if (ZSTD_isError(zret)) { 551 result = result_error(result_error_compression_error); 552 goto out; 553 } 554 555 while (input.size > 0) { 556 ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)}; 557 input.data += in.size; 558 input.size -= in.size; 559 ZSTD_EndDirective const op = 560 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end; 561 zret = 0; 562 while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) { 563 ZSTD_outBuffer out = {state->compressed.data, 564 MIN(state->compressed.capacity, 1024)}; 565 if (op == ZSTD_e_continue || in.pos < in.size) 566 zret = ZSTD_compressStream(zcs, &out, &in); 567 else 568 zret = ZSTD_endStream(zcs, &out); 569 if (ZSTD_isError(zret)) { 570 result = result_error(result_error_compression_error); 571 goto out; 572 } 573 data.total_size += out.pos; 574 } 575 } 576 } 577 578 result = result_data(data); 579 out: 580 ZSTD_freeCStream(zcs); 581 ZSTD_freeCDict(cd); 582 return result; 583 } 584 585 static result_t old_streaming_compress( 586 method_state_t* base, 587 config_t const* config) 588 { 589 return old_streaming_compress_internal( 590 base, config, /* advanced */ 0, /* cdict */ 0); 591 } 592 593 static result_t old_streaming_compress_advanced( 594 method_state_t* base, 595 config_t const* config) 596 { 597 return old_streaming_compress_internal( 598 base, config, /* advanced */ 1, /* cdict */ 0); 599 } 600 601 static result_t old_streaming_compress_cdict( 602 method_state_t* base, 603 config_t const* config) 604 { 605 return old_streaming_compress_internal( 606 base, config, /* advanced */ 0, /* cdict */ 1); 607 } 608 609 static result_t old_streaming_compress_cdict_advanced( 610 method_state_t* base, 611 config_t const* config) 612 { 613 return old_streaming_compress_internal( 614 base, config, /* advanced */ 1, /* cdict */ 1); 615 } 616 617 method_t const simple = { 618 .name = "compress simple", 619 .create = buffer_state_create, 620 .compress = simple_compress, 621 .destroy = buffer_state_destroy, 622 }; 623 624 method_t const compress_cctx = { 625 .name = "compress cctx", 626 .create = buffer_state_create, 627 .compress = compress_cctx_compress, 628 .destroy = buffer_state_destroy, 629 }; 630 631 method_t const advanced_one_pass = { 632 .name = "advanced one pass", 633 .create = buffer_state_create, 634 .compress = advanced_one_pass_compress, 635 .destroy = buffer_state_destroy, 636 }; 637 638 method_t const advanced_one_pass_small_out = { 639 .name = "advanced one pass small out", 640 .create = buffer_state_create, 641 .compress = advanced_one_pass_compress, 642 .destroy = buffer_state_destroy, 643 }; 644 645 method_t const advanced_streaming = { 646 .name = "advanced streaming", 647 .create = buffer_state_create, 648 .compress = advanced_streaming_compress, 649 .destroy = buffer_state_destroy, 650 }; 651 652 method_t const old_streaming = { 653 .name = "old streaming", 654 .create = buffer_state_create, 655 .compress = old_streaming_compress, 656 .destroy = buffer_state_destroy, 657 }; 658 659 method_t const old_streaming_advanced = { 660 .name = "old streaming advanced", 661 .create = buffer_state_create, 662 .compress = old_streaming_compress_advanced, 663 .destroy = buffer_state_destroy, 664 }; 665 666 method_t const old_streaming_cdict = { 667 .name = "old streaming cdict", 668 .create = buffer_state_create, 669 .compress = old_streaming_compress_cdict, 670 .destroy = buffer_state_destroy, 671 }; 672 673 method_t const old_streaming_advanced_cdict = { 674 .name = "old streaming advanced cdict", 675 .create = buffer_state_create, 676 .compress = old_streaming_compress_cdict_advanced, 677 .destroy = buffer_state_destroy, 678 }; 679 680 method_t const cli = { 681 .name = "zstdcli", 682 .create = method_state_create, 683 .compress = cli_compress, 684 .destroy = method_state_destroy, 685 }; 686 687 static method_t const* g_methods[] = { 688 &simple, 689 &compress_cctx, 690 &cli, 691 &advanced_one_pass, 692 &advanced_one_pass_small_out, 693 &advanced_streaming, 694 &old_streaming, 695 &old_streaming_advanced, 696 &old_streaming_cdict, 697 &old_streaming_advanced_cdict, 698 NULL, 699 }; 700 701 method_t const* const* methods = g_methods; 702