1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under both the BSD-style license (found in the 6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7 * in the COPYING file in the root directory of this source tree). 8 */ 9 #include "platform.h" /* Large Files support, SET_BINARY_MODE */ 10 #include "Pzstd.h" 11 #include "SkippableFrame.h" 12 #include "utils/FileSystem.h" 13 #include "utils/Portability.h" 14 #include "utils/Range.h" 15 #include "utils/ScopeGuard.h" 16 #include "utils/ThreadPool.h" 17 #include "utils/WorkQueue.h" 18 19 #include <algorithm> 20 #include <chrono> 21 #include <cinttypes> 22 #include <cstddef> 23 #include <cstdio> 24 #include <memory> 25 #include <string> 26 27 28 namespace pzstd { 29 30 namespace { 31 #ifdef _WIN32 32 const std::string nullOutput = "nul"; 33 #else 34 const std::string nullOutput = "/dev/null"; 35 #endif 36 } 37 38 using std::size_t; 39 40 static std::uintmax_t fileSizeOrZero(const std::string &file) { 41 if (file == "-") { 42 return 0; 43 } 44 std::error_code ec; 45 auto size = file_size(file, ec); 46 if (ec) { 47 size = 0; 48 } 49 return size; 50 } 51 52 static std::uint64_t handleOneInput(const Options &options, 53 const std::string &inputFile, 54 FILE* inputFd, 55 const std::string &outputFile, 56 FILE* outputFd, 57 SharedState& state) { 58 auto inputSize = fileSizeOrZero(inputFile); 59 // WorkQueue outlives ThreadPool so in the case of error we are certain 60 // we don't accidentally try to call push() on it after it is destroyed 61 WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1}; 62 std::uint64_t bytesRead; 63 std::uint64_t bytesWritten; 64 { 65 // Initialize the (de)compression thread pool with numThreads 66 ThreadPool executor(options.numThreads); 67 // Run the reader thread on an extra thread 68 ThreadPool readExecutor(1); 69 if (!options.decompress) { 70 // Add a job that reads the input and starts all the compression jobs 71 readExecutor.add( 72 [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] { 73 bytesRead = asyncCompressChunks( 74 state, 75 outs, 76 executor, 77 inputFd, 78 inputSize, 79 options.numThreads, 80 options.determineParameters()); 81 }); 82 // Start writing 83 bytesWritten = writeFile(state, outs, outputFd, options.decompress); 84 } else { 85 // Add a job that reads the input and starts all the decompression jobs 86 readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] { 87 bytesRead = asyncDecompressFrames(state, outs, executor, inputFd); 88 }); 89 // Start writing 90 bytesWritten = writeFile(state, outs, outputFd, options.decompress); 91 } 92 } 93 if (!state.errorHolder.hasError()) { 94 std::string inputFileName = inputFile == "-" ? "stdin" : inputFile; 95 std::string outputFileName = outputFile == "-" ? "stdout" : outputFile; 96 if (!options.decompress) { 97 double ratio = static_cast<double>(bytesWritten) / 98 static_cast<double>(bytesRead + !bytesRead); 99 state.log(kLogInfo, "%-20s :%6.2f%% (%6" PRIu64 " => %6" PRIu64 100 " bytes, %s)\n", 101 inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten, 102 outputFileName.c_str()); 103 } else { 104 state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n", 105 inputFileName.c_str(),bytesWritten); 106 } 107 } 108 return bytesWritten; 109 } 110 111 static FILE *openInputFile(const std::string &inputFile, 112 ErrorHolder &errorHolder) { 113 if (inputFile == "-") { 114 SET_BINARY_MODE(stdin); 115 return stdin; 116 } 117 // Check if input file is a directory 118 { 119 std::error_code ec; 120 if (is_directory(inputFile, ec)) { 121 errorHolder.setError("Output file is a directory -- ignored"); 122 return nullptr; 123 } 124 } 125 auto inputFd = std::fopen(inputFile.c_str(), "rb"); 126 if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) { 127 return nullptr; 128 } 129 return inputFd; 130 } 131 132 static FILE *openOutputFile(const Options &options, 133 const std::string &outputFile, 134 SharedState& state) { 135 if (outputFile == "-") { 136 SET_BINARY_MODE(stdout); 137 return stdout; 138 } 139 // Check if the output file exists and then open it 140 if (!options.overwrite && outputFile != nullOutput) { 141 auto outputFd = std::fopen(outputFile.c_str(), "rb"); 142 if (outputFd != nullptr) { 143 std::fclose(outputFd); 144 if (!state.log.logsAt(kLogInfo)) { 145 state.errorHolder.setError("Output file exists"); 146 return nullptr; 147 } 148 state.log( 149 kLogInfo, 150 "pzstd: %s already exists; do you wish to overwrite (y/n) ? ", 151 outputFile.c_str()); 152 int c = getchar(); 153 if (c != 'y' && c != 'Y') { 154 state.errorHolder.setError("Not overwritten"); 155 return nullptr; 156 } 157 } 158 } 159 auto outputFd = std::fopen(outputFile.c_str(), "wb"); 160 if (!state.errorHolder.check( 161 outputFd != nullptr, "Failed to open output file")) { 162 return nullptr; 163 } 164 return outputFd; 165 } 166 167 int pzstdMain(const Options &options) { 168 int returnCode = 0; 169 SharedState state(options); 170 for (const auto& input : options.inputFiles) { 171 // Setup the shared state 172 auto printErrorGuard = makeScopeGuard([&] { 173 if (state.errorHolder.hasError()) { 174 returnCode = 1; 175 state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(), 176 state.errorHolder.getError().c_str()); 177 } 178 }); 179 // Open the input file 180 auto inputFd = openInputFile(input, state.errorHolder); 181 if (inputFd == nullptr) { 182 continue; 183 } 184 auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); }); 185 // Open the output file 186 auto outputFile = options.getOutputFile(input); 187 if (!state.errorHolder.check(outputFile != "", 188 "Input file does not have extension .zst")) { 189 continue; 190 } 191 auto outputFd = openOutputFile(options, outputFile, state); 192 if (outputFd == nullptr) { 193 continue; 194 } 195 auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); }); 196 // (de)compress the file 197 handleOneInput(options, input, inputFd, outputFile, outputFd, state); 198 if (state.errorHolder.hasError()) { 199 continue; 200 } 201 // Delete the input file if necessary 202 if (!options.keepSource) { 203 // Be sure that we are done and have written everything before we delete 204 if (!state.errorHolder.check(std::fclose(inputFd) == 0, 205 "Failed to close input file")) { 206 continue; 207 } 208 closeInputGuard.dismiss(); 209 if (!state.errorHolder.check(std::fclose(outputFd) == 0, 210 "Failed to close output file")) { 211 continue; 212 } 213 closeOutputGuard.dismiss(); 214 if (std::remove(input.c_str()) != 0) { 215 state.errorHolder.setError("Failed to remove input file"); 216 continue; 217 } 218 } 219 } 220 // Returns 1 if any of the files failed to (de)compress. 221 return returnCode; 222 } 223 224 /// Construct a `ZSTD_inBuffer` that points to the data in `buffer`. 225 static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) { 226 return ZSTD_inBuffer{buffer.data(), buffer.size(), 0}; 227 } 228 229 /** 230 * Advance `buffer` and `inBuffer` by the amount of data read, as indicated by 231 * `inBuffer.pos`. 232 */ 233 void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) { 234 auto pos = inBuffer.pos; 235 inBuffer.src = static_cast<const unsigned char*>(inBuffer.src) + pos; 236 inBuffer.size -= pos; 237 inBuffer.pos = 0; 238 return buffer.advance(pos); 239 } 240 241 /// Construct a `ZSTD_outBuffer` that points to the data in `buffer`. 242 static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) { 243 return ZSTD_outBuffer{buffer.data(), buffer.size(), 0}; 244 } 245 246 /** 247 * Split `buffer` and advance `outBuffer` by the amount of data written, as 248 * indicated by `outBuffer.pos`. 249 */ 250 Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) { 251 auto pos = outBuffer.pos; 252 outBuffer.dst = static_cast<unsigned char*>(outBuffer.dst) + pos; 253 outBuffer.size -= pos; 254 outBuffer.pos = 0; 255 return buffer.splitAt(pos); 256 } 257 258 /** 259 * Stream chunks of input from `in`, compress it, and stream it out to `out`. 260 * 261 * @param state The shared state 262 * @param in Queue that we `pop()` input buffers from 263 * @param out Queue that we `push()` compressed output buffers to 264 * @param maxInputSize An upper bound on the size of the input 265 */ 266 static void compress( 267 SharedState& state, 268 std::shared_ptr<BufferWorkQueue> in, 269 std::shared_ptr<BufferWorkQueue> out, 270 size_t maxInputSize) { 271 auto& errorHolder = state.errorHolder; 272 auto guard = makeScopeGuard([&] { out->finish(); }); 273 // Initialize the CCtx 274 auto ctx = state.cStreamPool->get(); 275 if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) { 276 return; 277 } 278 { 279 auto err = ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only); 280 if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) { 281 return; 282 } 283 } 284 285 // Allocate space for the result 286 auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize)); 287 auto zstdOutBuffer = makeZstdOutBuffer(outBuffer); 288 { 289 Buffer inBuffer; 290 // Read a buffer in from the input queue 291 while (in->pop(inBuffer) && !errorHolder.hasError()) { 292 auto zstdInBuffer = makeZstdInBuffer(inBuffer); 293 // Compress the whole buffer and send it to the output queue 294 while (!inBuffer.empty() && !errorHolder.hasError()) { 295 if (!errorHolder.check( 296 !outBuffer.empty(), "ZSTD_compressBound() was too small")) { 297 return; 298 } 299 // Compress 300 auto err = 301 ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer); 302 if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) { 303 return; 304 } 305 // Split the compressed data off outBuffer and pass to the output queue 306 out->push(split(outBuffer, zstdOutBuffer)); 307 // Forget about the data we already compressed 308 advance(inBuffer, zstdInBuffer); 309 } 310 } 311 } 312 // Write the epilog 313 size_t bytesLeft; 314 do { 315 if (!errorHolder.check( 316 !outBuffer.empty(), "ZSTD_compressBound() was too small")) { 317 return; 318 } 319 bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer); 320 if (!errorHolder.check( 321 !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) { 322 return; 323 } 324 out->push(split(outBuffer, zstdOutBuffer)); 325 } while (bytesLeft != 0 && !errorHolder.hasError()); 326 } 327 328 /** 329 * Calculates how large each independently compressed frame should be. 330 * 331 * @param size The size of the source if known, 0 otherwise 332 * @param numThreads The number of threads available to run compression jobs on 333 * @param params The zstd parameters to be used for compression 334 */ 335 static size_t calculateStep( 336 std::uintmax_t size, 337 size_t numThreads, 338 const ZSTD_parameters ¶ms) { 339 (void)size; 340 (void)numThreads; 341 // Not validated to work correctly for window logs > 23. 342 // It will definitely fail if windowLog + 2 is >= 4GB because 343 // the skippable frame can only store sizes up to 4GB. 344 assert(params.cParams.windowLog <= 23); 345 return size_t{1} << (params.cParams.windowLog + 2); 346 } 347 348 namespace { 349 enum class FileStatus { Continue, Done, Error }; 350 /// Determines the status of the file descriptor `fd`. 351 FileStatus fileStatus(FILE* fd) { 352 if (std::feof(fd)) { 353 return FileStatus::Done; 354 } else if (std::ferror(fd)) { 355 return FileStatus::Error; 356 } 357 return FileStatus::Continue; 358 } 359 } // anonymous namespace 360 361 /** 362 * Reads `size` data in chunks of `chunkSize` and puts it into `queue`. 363 * Will read less if an error or EOF occurs. 364 * Returns the status of the file after all of the reads have occurred. 365 */ 366 static FileStatus 367 readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd, 368 std::uint64_t *totalBytesRead) { 369 Buffer buffer(size); 370 while (!buffer.empty()) { 371 auto bytesRead = 372 std::fread(buffer.data(), 1, std::min(chunkSize, buffer.size()), fd); 373 *totalBytesRead += bytesRead; 374 queue.push(buffer.splitAt(bytesRead)); 375 auto status = fileStatus(fd); 376 if (status != FileStatus::Continue) { 377 return status; 378 } 379 } 380 return FileStatus::Continue; 381 } 382 383 std::uint64_t asyncCompressChunks( 384 SharedState& state, 385 WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks, 386 ThreadPool& executor, 387 FILE* fd, 388 std::uintmax_t size, 389 size_t numThreads, 390 ZSTD_parameters params) { 391 auto chunksGuard = makeScopeGuard([&] { chunks.finish(); }); 392 std::uint64_t bytesRead = 0; 393 394 // Break the input up into chunks of size `step` and compress each chunk 395 // independently. 396 size_t step = calculateStep(size, numThreads, params); 397 state.log(kLogDebug, "Chosen frame size: %zu\n", step); 398 auto status = FileStatus::Continue; 399 while (status == FileStatus::Continue && !state.errorHolder.hasError()) { 400 // Make a new input queue that we will put the chunk's input data into. 401 auto in = std::make_shared<BufferWorkQueue>(); 402 auto inGuard = makeScopeGuard([&] { in->finish(); }); 403 // Make a new output queue that compress will put the compressed data into. 404 auto out = std::make_shared<BufferWorkQueue>(); 405 // Start compression in the thread pool 406 executor.add([&state, in, out, step] { 407 return compress( 408 state, std::move(in), std::move(out), step); 409 }); 410 // Pass the output queue to the writer thread. 411 chunks.push(std::move(out)); 412 state.log(kLogVerbose, "%s\n", "Starting a new frame"); 413 // Fill the input queue for the compression job we just started 414 status = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead); 415 } 416 state.errorHolder.check(status != FileStatus::Error, "Error reading input"); 417 return bytesRead; 418 } 419 420 /** 421 * Decompress a frame, whose data is streamed into `in`, and stream the output 422 * to `out`. 423 * 424 * @param state The shared state 425 * @param in Queue that we `pop()` input buffers from. It contains 426 * exactly one compressed frame. 427 * @param out Queue that we `push()` decompressed output buffers to 428 */ 429 static void decompress( 430 SharedState& state, 431 std::shared_ptr<BufferWorkQueue> in, 432 std::shared_ptr<BufferWorkQueue> out) { 433 auto& errorHolder = state.errorHolder; 434 auto guard = makeScopeGuard([&] { out->finish(); }); 435 // Initialize the DCtx 436 auto ctx = state.dStreamPool->get(); 437 if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) { 438 return; 439 } 440 { 441 auto err = ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only); 442 if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) { 443 return; 444 } 445 } 446 447 const size_t outSize = ZSTD_DStreamOutSize(); 448 Buffer inBuffer; 449 size_t returnCode = 0; 450 // Read a buffer in from the input queue 451 while (in->pop(inBuffer) && !errorHolder.hasError()) { 452 auto zstdInBuffer = makeZstdInBuffer(inBuffer); 453 // Decompress the whole buffer and send it to the output queue 454 while (!inBuffer.empty() && !errorHolder.hasError()) { 455 // Allocate a buffer with at least outSize bytes. 456 Buffer outBuffer(outSize); 457 auto zstdOutBuffer = makeZstdOutBuffer(outBuffer); 458 // Decompress 459 returnCode = 460 ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer); 461 if (!errorHolder.check( 462 !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) { 463 return; 464 } 465 // Pass the buffer with the decompressed data to the output queue 466 out->push(split(outBuffer, zstdOutBuffer)); 467 // Advance past the input we already read 468 advance(inBuffer, zstdInBuffer); 469 if (returnCode == 0) { 470 // The frame is over, prepare to (maybe) start a new frame 471 ZSTD_initDStream(ctx.get()); 472 } 473 } 474 } 475 if (!errorHolder.check(returnCode <= 1, "Incomplete block")) { 476 return; 477 } 478 // We've given ZSTD_decompressStream all of our data, but there may still 479 // be data to read. 480 while (returnCode == 1) { 481 // Allocate a buffer with at least outSize bytes. 482 Buffer outBuffer(outSize); 483 auto zstdOutBuffer = makeZstdOutBuffer(outBuffer); 484 // Pass in no input. 485 ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0}; 486 // Decompress 487 returnCode = 488 ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer); 489 if (!errorHolder.check( 490 !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) { 491 return; 492 } 493 // Pass the buffer with the decompressed data to the output queue 494 out->push(split(outBuffer, zstdOutBuffer)); 495 } 496 } 497 498 std::uint64_t asyncDecompressFrames( 499 SharedState& state, 500 WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames, 501 ThreadPool& executor, 502 FILE* fd) { 503 auto framesGuard = makeScopeGuard([&] { frames.finish(); }); 504 std::uint64_t totalBytesRead = 0; 505 506 // Split the source up into its component frames. 507 // If we find our recognized skippable frame we know the next frames size 508 // which means that we can decompress each standard frame in independently. 509 // Otherwise, we will decompress using only one decompression task. 510 const size_t chunkSize = ZSTD_DStreamInSize(); 511 auto status = FileStatus::Continue; 512 while (status == FileStatus::Continue && !state.errorHolder.hasError()) { 513 // Make a new input queue that we will put the frames's bytes into. 514 auto in = std::make_shared<BufferWorkQueue>(); 515 auto inGuard = makeScopeGuard([&] { in->finish(); }); 516 // Make a output queue that decompress will put the decompressed data into 517 auto out = std::make_shared<BufferWorkQueue>(); 518 519 size_t frameSize; 520 { 521 // Calculate the size of the next frame. 522 // frameSize is 0 if the frame info can't be decoded. 523 Buffer buffer(SkippableFrame::kSize); 524 auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd); 525 totalBytesRead += bytesRead; 526 status = fileStatus(fd); 527 if (bytesRead == 0 && status != FileStatus::Continue) { 528 break; 529 } 530 buffer.subtract(buffer.size() - bytesRead); 531 frameSize = SkippableFrame::tryRead(buffer.range()); 532 in->push(std::move(buffer)); 533 } 534 if (frameSize == 0) { 535 // We hit a non SkippableFrame, so this will be the last job. 536 // Make sure that we don't use too much memory 537 in->setMaxSize(64); 538 out->setMaxSize(64); 539 } 540 // Start decompression in the thread pool 541 executor.add([&state, in, out] { 542 return decompress(state, std::move(in), std::move(out)); 543 }); 544 // Pass the output queue to the writer thread 545 frames.push(std::move(out)); 546 if (frameSize == 0) { 547 // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted 548 // Pass the rest of the source to this decompression task 549 state.log(kLogVerbose, "%s\n", 550 "Input not in pzstd format, falling back to serial decompression"); 551 while (status == FileStatus::Continue && !state.errorHolder.hasError()) { 552 status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead); 553 } 554 break; 555 } 556 state.log(kLogVerbose, "Decompressing a frame of size %zu", frameSize); 557 // Fill the input queue for the decompression job we just started 558 status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead); 559 } 560 state.errorHolder.check(status != FileStatus::Error, "Error reading input"); 561 return totalBytesRead; 562 } 563 564 /// Write `data` to `fd`, returns true iff success. 565 static bool writeData(ByteRange data, FILE* fd) { 566 while (!data.empty()) { 567 data.advance(std::fwrite(data.begin(), 1, data.size(), fd)); 568 if (std::ferror(fd)) { 569 return false; 570 } 571 } 572 return true; 573 } 574 575 std::uint64_t writeFile( 576 SharedState& state, 577 WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs, 578 FILE* outputFd, 579 bool decompress) { 580 auto& errorHolder = state.errorHolder; 581 auto lineClearGuard = makeScopeGuard([&state] { 582 state.log.clear(kLogInfo); 583 }); 584 std::uint64_t bytesWritten = 0; 585 std::shared_ptr<BufferWorkQueue> out; 586 // Grab the output queue for each decompression job (in order). 587 while (outs.pop(out)) { 588 if (errorHolder.hasError()) { 589 continue; 590 } 591 if (!decompress) { 592 // If we are compressing and want to write skippable frames we can't 593 // start writing before compression is done because we need to know the 594 // compressed size. 595 // Wait for the compressed size to be available and write skippable frame 596 assert(uint64_t(out->size()) < uint64_t(1) << 32); 597 SkippableFrame frame(uint32_t(out->size())); 598 if (!writeData(frame.data(), outputFd)) { 599 errorHolder.setError("Failed to write output"); 600 return bytesWritten; 601 } 602 bytesWritten += frame.kSize; 603 } 604 // For each chunk of the frame: Pop it from the queue and write it 605 Buffer buffer; 606 while (out->pop(buffer) && !errorHolder.hasError()) { 607 if (!writeData(buffer.range(), outputFd)) { 608 errorHolder.setError("Failed to write output"); 609 return bytesWritten; 610 } 611 bytesWritten += buffer.size(); 612 state.log.update(kLogInfo, "Written: %u MB ", 613 static_cast<std::uint32_t>(bytesWritten >> 20)); 614 } 615 } 616 return bytesWritten; 617 } 618 } 619