xref: /netbsd-src/external/bsd/zstd/dist/contrib/pzstd/Pzstd.cpp (revision 3117ece4fc4a4ca4489ba793710b60b0d26bab6c)
1*3117ece4Schristos /*
2*3117ece4Schristos  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*3117ece4Schristos  * All rights reserved.
4*3117ece4Schristos  *
5*3117ece4Schristos  * This source code is licensed under both the BSD-style license (found in the
6*3117ece4Schristos  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7*3117ece4Schristos  * in the COPYING file in the root directory of this source tree).
8*3117ece4Schristos  */
9*3117ece4Schristos #include "platform.h"   /* Large Files support, SET_BINARY_MODE */
10*3117ece4Schristos #include "Pzstd.h"
11*3117ece4Schristos #include "SkippableFrame.h"
12*3117ece4Schristos #include "utils/FileSystem.h"
13*3117ece4Schristos #include "utils/Portability.h"
14*3117ece4Schristos #include "utils/Range.h"
15*3117ece4Schristos #include "utils/ScopeGuard.h"
16*3117ece4Schristos #include "utils/ThreadPool.h"
17*3117ece4Schristos #include "utils/WorkQueue.h"
18*3117ece4Schristos 
19*3117ece4Schristos #include <algorithm>
20*3117ece4Schristos #include <chrono>
21*3117ece4Schristos #include <cinttypes>
22*3117ece4Schristos #include <cstddef>
23*3117ece4Schristos #include <cstdio>
24*3117ece4Schristos #include <memory>
25*3117ece4Schristos #include <string>
26*3117ece4Schristos 
27*3117ece4Schristos 
28*3117ece4Schristos namespace pzstd {
29*3117ece4Schristos 
namespace {
/// Path of the platform's null device. Writing to it never needs the
/// "output file exists" overwrite confirmation.
const std::string nullOutput =
#ifdef _WIN32
    "nul";
#else
    "/dev/null";
#endif
} // anonymous namespace

using std::size_t;
39*3117ece4Schristos 
40*3117ece4Schristos static std::uintmax_t fileSizeOrZero(const std::string &file) {
41*3117ece4Schristos   if (file == "-") {
42*3117ece4Schristos     return 0;
43*3117ece4Schristos   }
44*3117ece4Schristos   std::error_code ec;
45*3117ece4Schristos   auto size = file_size(file, ec);
46*3117ece4Schristos   if (ec) {
47*3117ece4Schristos     size = 0;
48*3117ece4Schristos   }
49*3117ece4Schristos   return size;
50*3117ece4Schristos }
51*3117ece4Schristos 
52*3117ece4Schristos static std::uint64_t handleOneInput(const Options &options,
53*3117ece4Schristos                              const std::string &inputFile,
54*3117ece4Schristos                              FILE* inputFd,
55*3117ece4Schristos                              const std::string &outputFile,
56*3117ece4Schristos                              FILE* outputFd,
57*3117ece4Schristos                              SharedState& state) {
58*3117ece4Schristos   auto inputSize = fileSizeOrZero(inputFile);
59*3117ece4Schristos   // WorkQueue outlives ThreadPool so in the case of error we are certain
60*3117ece4Schristos   // we don't accidentally try to call push() on it after it is destroyed
61*3117ece4Schristos   WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1};
62*3117ece4Schristos   std::uint64_t bytesRead;
63*3117ece4Schristos   std::uint64_t bytesWritten;
64*3117ece4Schristos   {
65*3117ece4Schristos     // Initialize the (de)compression thread pool with numThreads
66*3117ece4Schristos     ThreadPool executor(options.numThreads);
67*3117ece4Schristos     // Run the reader thread on an extra thread
68*3117ece4Schristos     ThreadPool readExecutor(1);
69*3117ece4Schristos     if (!options.decompress) {
70*3117ece4Schristos       // Add a job that reads the input and starts all the compression jobs
71*3117ece4Schristos       readExecutor.add(
72*3117ece4Schristos           [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] {
73*3117ece4Schristos             bytesRead = asyncCompressChunks(
74*3117ece4Schristos                 state,
75*3117ece4Schristos                 outs,
76*3117ece4Schristos                 executor,
77*3117ece4Schristos                 inputFd,
78*3117ece4Schristos                 inputSize,
79*3117ece4Schristos                 options.numThreads,
80*3117ece4Schristos                 options.determineParameters());
81*3117ece4Schristos           });
82*3117ece4Schristos       // Start writing
83*3117ece4Schristos       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
84*3117ece4Schristos     } else {
85*3117ece4Schristos       // Add a job that reads the input and starts all the decompression jobs
86*3117ece4Schristos       readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] {
87*3117ece4Schristos         bytesRead = asyncDecompressFrames(state, outs, executor, inputFd);
88*3117ece4Schristos       });
89*3117ece4Schristos       // Start writing
90*3117ece4Schristos       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
91*3117ece4Schristos     }
92*3117ece4Schristos   }
93*3117ece4Schristos   if (!state.errorHolder.hasError()) {
94*3117ece4Schristos     std::string inputFileName = inputFile == "-" ? "stdin" : inputFile;
95*3117ece4Schristos     std::string outputFileName = outputFile == "-" ? "stdout" : outputFile;
96*3117ece4Schristos     if (!options.decompress) {
97*3117ece4Schristos       double ratio = static_cast<double>(bytesWritten) /
98*3117ece4Schristos                      static_cast<double>(bytesRead + !bytesRead);
99*3117ece4Schristos       state.log(kLogInfo, "%-20s :%6.2f%%   (%6" PRIu64 " => %6" PRIu64
100*3117ece4Schristos                    " bytes, %s)\n",
101*3117ece4Schristos                    inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten,
102*3117ece4Schristos                    outputFileName.c_str());
103*3117ece4Schristos     } else {
104*3117ece4Schristos       state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n",
105*3117ece4Schristos                    inputFileName.c_str(),bytesWritten);
106*3117ece4Schristos     }
107*3117ece4Schristos   }
108*3117ece4Schristos   return bytesWritten;
109*3117ece4Schristos }
110*3117ece4Schristos 
111*3117ece4Schristos static FILE *openInputFile(const std::string &inputFile,
112*3117ece4Schristos                            ErrorHolder &errorHolder) {
113*3117ece4Schristos   if (inputFile == "-") {
114*3117ece4Schristos     SET_BINARY_MODE(stdin);
115*3117ece4Schristos     return stdin;
116*3117ece4Schristos   }
117*3117ece4Schristos   // Check if input file is a directory
118*3117ece4Schristos   {
119*3117ece4Schristos     std::error_code ec;
120*3117ece4Schristos     if (is_directory(inputFile, ec)) {
121*3117ece4Schristos       errorHolder.setError("Output file is a directory -- ignored");
122*3117ece4Schristos       return nullptr;
123*3117ece4Schristos     }
124*3117ece4Schristos   }
125*3117ece4Schristos   auto inputFd = std::fopen(inputFile.c_str(), "rb");
126*3117ece4Schristos   if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
127*3117ece4Schristos     return nullptr;
128*3117ece4Schristos   }
129*3117ece4Schristos   return inputFd;
130*3117ece4Schristos }
131*3117ece4Schristos 
132*3117ece4Schristos static FILE *openOutputFile(const Options &options,
133*3117ece4Schristos                             const std::string &outputFile,
134*3117ece4Schristos                             SharedState& state) {
135*3117ece4Schristos   if (outputFile == "-") {
136*3117ece4Schristos     SET_BINARY_MODE(stdout);
137*3117ece4Schristos     return stdout;
138*3117ece4Schristos   }
139*3117ece4Schristos   // Check if the output file exists and then open it
140*3117ece4Schristos   if (!options.overwrite && outputFile != nullOutput) {
141*3117ece4Schristos     auto outputFd = std::fopen(outputFile.c_str(), "rb");
142*3117ece4Schristos     if (outputFd != nullptr) {
143*3117ece4Schristos       std::fclose(outputFd);
144*3117ece4Schristos       if (!state.log.logsAt(kLogInfo)) {
145*3117ece4Schristos         state.errorHolder.setError("Output file exists");
146*3117ece4Schristos         return nullptr;
147*3117ece4Schristos       }
148*3117ece4Schristos       state.log(
149*3117ece4Schristos           kLogInfo,
150*3117ece4Schristos           "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
151*3117ece4Schristos           outputFile.c_str());
152*3117ece4Schristos       int c = getchar();
153*3117ece4Schristos       if (c != 'y' && c != 'Y') {
154*3117ece4Schristos         state.errorHolder.setError("Not overwritten");
155*3117ece4Schristos         return nullptr;
156*3117ece4Schristos       }
157*3117ece4Schristos     }
158*3117ece4Schristos   }
159*3117ece4Schristos   auto outputFd = std::fopen(outputFile.c_str(), "wb");
160*3117ece4Schristos   if (!state.errorHolder.check(
161*3117ece4Schristos           outputFd != nullptr, "Failed to open output file")) {
162*3117ece4Schristos     return nullptr;
163*3117ece4Schristos   }
164*3117ece4Schristos   return outputFd;
165*3117ece4Schristos }
166*3117ece4Schristos 
167*3117ece4Schristos int pzstdMain(const Options &options) {
168*3117ece4Schristos   int returnCode = 0;
169*3117ece4Schristos   SharedState state(options);
170*3117ece4Schristos   for (const auto& input : options.inputFiles) {
171*3117ece4Schristos     // Setup the shared state
172*3117ece4Schristos     auto printErrorGuard = makeScopeGuard([&] {
173*3117ece4Schristos       if (state.errorHolder.hasError()) {
174*3117ece4Schristos         returnCode = 1;
175*3117ece4Schristos         state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(),
176*3117ece4Schristos                   state.errorHolder.getError().c_str());
177*3117ece4Schristos       }
178*3117ece4Schristos     });
179*3117ece4Schristos     // Open the input file
180*3117ece4Schristos     auto inputFd = openInputFile(input, state.errorHolder);
181*3117ece4Schristos     if (inputFd == nullptr) {
182*3117ece4Schristos       continue;
183*3117ece4Schristos     }
184*3117ece4Schristos     auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
185*3117ece4Schristos     // Open the output file
186*3117ece4Schristos     auto outputFile = options.getOutputFile(input);
187*3117ece4Schristos     if (!state.errorHolder.check(outputFile != "",
188*3117ece4Schristos                            "Input file does not have extension .zst")) {
189*3117ece4Schristos       continue;
190*3117ece4Schristos     }
191*3117ece4Schristos     auto outputFd = openOutputFile(options, outputFile, state);
192*3117ece4Schristos     if (outputFd == nullptr) {
193*3117ece4Schristos       continue;
194*3117ece4Schristos     }
195*3117ece4Schristos     auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
196*3117ece4Schristos     // (de)compress the file
197*3117ece4Schristos     handleOneInput(options, input, inputFd, outputFile, outputFd, state);
198*3117ece4Schristos     if (state.errorHolder.hasError()) {
199*3117ece4Schristos       continue;
200*3117ece4Schristos     }
201*3117ece4Schristos     // Delete the input file if necessary
202*3117ece4Schristos     if (!options.keepSource) {
203*3117ece4Schristos       // Be sure that we are done and have written everything before we delete
204*3117ece4Schristos       if (!state.errorHolder.check(std::fclose(inputFd) == 0,
205*3117ece4Schristos                              "Failed to close input file")) {
206*3117ece4Schristos         continue;
207*3117ece4Schristos       }
208*3117ece4Schristos       closeInputGuard.dismiss();
209*3117ece4Schristos       if (!state.errorHolder.check(std::fclose(outputFd) == 0,
210*3117ece4Schristos                              "Failed to close output file")) {
211*3117ece4Schristos         continue;
212*3117ece4Schristos       }
213*3117ece4Schristos       closeOutputGuard.dismiss();
214*3117ece4Schristos       if (std::remove(input.c_str()) != 0) {
215*3117ece4Schristos         state.errorHolder.setError("Failed to remove input file");
216*3117ece4Schristos         continue;
217*3117ece4Schristos       }
218*3117ece4Schristos     }
219*3117ece4Schristos   }
220*3117ece4Schristos   // Returns 1 if any of the files failed to (de)compress.
221*3117ece4Schristos   return returnCode;
222*3117ece4Schristos }
223*3117ece4Schristos 
224*3117ece4Schristos /// Construct a `ZSTD_inBuffer` that points to the data in `buffer`.
225*3117ece4Schristos static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
226*3117ece4Schristos   return ZSTD_inBuffer{buffer.data(), buffer.size(), 0};
227*3117ece4Schristos }
228*3117ece4Schristos 
229*3117ece4Schristos /**
230*3117ece4Schristos  * Advance `buffer` and `inBuffer` by the amount of data read, as indicated by
231*3117ece4Schristos  * `inBuffer.pos`.
232*3117ece4Schristos  */
233*3117ece4Schristos void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
234*3117ece4Schristos   auto pos = inBuffer.pos;
235*3117ece4Schristos   inBuffer.src = static_cast<const unsigned char*>(inBuffer.src) + pos;
236*3117ece4Schristos   inBuffer.size -= pos;
237*3117ece4Schristos   inBuffer.pos = 0;
238*3117ece4Schristos   return buffer.advance(pos);
239*3117ece4Schristos }
240*3117ece4Schristos 
241*3117ece4Schristos /// Construct a `ZSTD_outBuffer` that points to the data in `buffer`.
242*3117ece4Schristos static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
243*3117ece4Schristos   return ZSTD_outBuffer{buffer.data(), buffer.size(), 0};
244*3117ece4Schristos }
245*3117ece4Schristos 
246*3117ece4Schristos /**
247*3117ece4Schristos  * Split `buffer` and advance `outBuffer` by the amount of data written, as
248*3117ece4Schristos  * indicated by `outBuffer.pos`.
249*3117ece4Schristos  */
250*3117ece4Schristos Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
251*3117ece4Schristos   auto pos = outBuffer.pos;
252*3117ece4Schristos   outBuffer.dst = static_cast<unsigned char*>(outBuffer.dst) + pos;
253*3117ece4Schristos   outBuffer.size -= pos;
254*3117ece4Schristos   outBuffer.pos = 0;
255*3117ece4Schristos   return buffer.splitAt(pos);
256*3117ece4Schristos }
257*3117ece4Schristos 
/**
 * Stream chunks of input from `in`, compress them into one zstd frame, and
 * stream the compressed bytes out to `out`.
 *
 * Failures are recorded in `state.errorHolder` rather than returned.
 * `out->finish()` runs on every exit path (via `guard`), so the consumer of
 * `out` is always released.
 *
 * @param state        The shared state (error holder and CStream pool)
 * @param in           Queue that we `pop()` input buffers from
 * @param out          Queue that we `push()` compressed output buffers to
 * @param maxInputSize An upper bound on the total size of the input. A single
 *                     output buffer of ZSTD_compressBound(maxInputSize) bytes
 *                     is allocated up front and never grown.
 */
static void compress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out,
    size_t maxInputSize) {
  auto& errorHolder = state.errorHolder;
  // Close `out` on every return path so the writer never waits on it forever.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the CCtx
  auto ctx = state.cStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
    return;
  }
  {
    // Start a new frame on the pooled context. A session-only reset keeps the
    // parameters already set on it (presumably configured by the pool).
    auto err = ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  // Allocate space for the result. Each compressed piece is split off the
  // front of this buffer (see split()), so the buffer shrinks as data is sent.
  auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
  auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
  {
    Buffer inBuffer;
    // Read a buffer in from the input queue
    while (in->pop(inBuffer) && !errorHolder.hasError()) {
      auto zstdInBuffer = makeZstdInBuffer(inBuffer);
      // Compress the whole buffer and send it to the output queue
      while (!inBuffer.empty() && !errorHolder.hasError()) {
        // The output buffer is never grown, so running out of room is fatal.
        if (!errorHolder.check(
                !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
          return;
        }
        // Compress
        auto err =
            ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
        if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
          return;
        }
        // Split the compressed data off outBuffer and pass to the output queue
        out->push(split(outBuffer, zstdOutBuffer));
        // Forget about the data we already compressed
        advance(inBuffer, zstdInBuffer);
      }
    }
  }
  // Write the epilog: ZSTD_endStream() returns how many bytes are still left
  // to flush, so keep calling it until it reports 0.
  size_t bytesLeft;
  do {
    if (!errorHolder.check(
            !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
      return;
    }
    bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
      return;
    }
    out->push(split(outBuffer, zstdOutBuffer));
  } while (bytesLeft != 0 && !errorHolder.hasError());
}
327*3117ece4Schristos 
328*3117ece4Schristos /**
329*3117ece4Schristos  * Calculates how large each independently compressed frame should be.
330*3117ece4Schristos  *
331*3117ece4Schristos  * @param size       The size of the source if known, 0 otherwise
332*3117ece4Schristos  * @param numThreads The number of threads available to run compression jobs on
333*3117ece4Schristos  * @param params     The zstd parameters to be used for compression
334*3117ece4Schristos  */
335*3117ece4Schristos static size_t calculateStep(
336*3117ece4Schristos     std::uintmax_t size,
337*3117ece4Schristos     size_t numThreads,
338*3117ece4Schristos     const ZSTD_parameters &params) {
339*3117ece4Schristos   (void)size;
340*3117ece4Schristos   (void)numThreads;
341*3117ece4Schristos   // Not validated to work correctly for window logs > 23.
342*3117ece4Schristos   // It will definitely fail if windowLog + 2 is >= 4GB because
343*3117ece4Schristos   // the skippable frame can only store sizes up to 4GB.
344*3117ece4Schristos   assert(params.cParams.windowLog <= 23);
345*3117ece4Schristos   return size_t{1} << (params.cParams.windowLog + 2);
346*3117ece4Schristos }
347*3117ece4Schristos 
namespace {
/// Outcome of the reads performed on a stream so far.
enum class FileStatus { Continue, Done, Error };
/**
 * Determines the status of the file descriptor `fd`.
 *
 * The error indicator is checked first, so a read error is never reported as
 * a clean end of input when the end-of-file indicator is also set.
 *
 * @returns Error if ferror(fd) is set, otherwise Done if feof(fd) is set,
 *          otherwise Continue.
 */
FileStatus fileStatus(FILE* fd) {
  if (std::ferror(fd)) {
    return FileStatus::Error;
  }
  if (std::feof(fd)) {
    return FileStatus::Done;
  }
  return FileStatus::Continue;
}
} // anonymous namespace
360*3117ece4Schristos 
361*3117ece4Schristos /**
362*3117ece4Schristos  * Reads `size` data in chunks of `chunkSize` and puts it into `queue`.
363*3117ece4Schristos  * Will read less if an error or EOF occurs.
364*3117ece4Schristos  * Returns the status of the file after all of the reads have occurred.
365*3117ece4Schristos  */
366*3117ece4Schristos static FileStatus
367*3117ece4Schristos readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd,
368*3117ece4Schristos          std::uint64_t *totalBytesRead) {
369*3117ece4Schristos   Buffer buffer(size);
370*3117ece4Schristos   while (!buffer.empty()) {
371*3117ece4Schristos     auto bytesRead =
372*3117ece4Schristos         std::fread(buffer.data(), 1, std::min(chunkSize, buffer.size()), fd);
373*3117ece4Schristos     *totalBytesRead += bytesRead;
374*3117ece4Schristos     queue.push(buffer.splitAt(bytesRead));
375*3117ece4Schristos     auto status = fileStatus(fd);
376*3117ece4Schristos     if (status != FileStatus::Continue) {
377*3117ece4Schristos       return status;
378*3117ece4Schristos     }
379*3117ece4Schristos   }
380*3117ece4Schristos   return FileStatus::Continue;
381*3117ece4Schristos }
382*3117ece4Schristos 
/**
 * Reader-thread driver for compression.
 *
 * Splits the input from `fd` into chunks of calculateStep() bytes. For each
 * chunk it starts an independent compress() job on `executor` and hands that
 * job's output queue to the writer through `chunks`, in input order.
 * `chunks.finish()` runs on every exit path.
 *
 * Because fread() only sets EOF once a read runs past the end, input whose
 * length is a multiple of the step ends with one extra job that gets no
 * input. Empty input likewise produces a single job with no input.
 *
 * @returns the total number of bytes read from `fd`.
 */
std::uint64_t asyncCompressChunks(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
    ThreadPool& executor,
    FILE* fd,
    std::uintmax_t size,
    size_t numThreads,
    ZSTD_parameters params) {
  // Tell the writer no more chunks are coming, however this function exits.
  auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
  std::uint64_t bytesRead = 0;

  // Break the input up into chunks of size `step` and compress each chunk
  // independently.
  size_t step = calculateStep(size, numThreads, params);
  state.log(kLogDebug, "Chosen frame size: %zu\n", step);
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
    // Make a new input queue that we will put the chunk's input data into.
    auto in = std::make_shared<BufferWorkQueue>();
    // Finish `in` at the end of this iteration so the job sees end of input.
    auto inGuard = makeScopeGuard([&] { in->finish(); });
    // Make a new output queue that compress will put the compressed data into.
    auto out = std::make_shared<BufferWorkQueue>();
    // Start compression in the thread pool. The queues are captured by value
    // so they stay alive for the job. (The lambda is not `mutable`, so the
    // std::move calls below actually copy the shared_ptrs.)
    executor.add([&state, in, out, step] {
      return compress(
          state, std::move(in), std::move(out), step);
    });
    // Pass the output queue to the writer thread.
    chunks.push(std::move(out));
    state.log(kLogVerbose, "%s\n", "Starting a new frame");
    // Fill the input queue for the compression job we just started
    status = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead);
  }
  state.errorHolder.check(status != FileStatus::Error, "Error reading input");
  return bytesRead;
}
419*3117ece4Schristos 
/**
 * Decompress the data streamed into `in`, and stream the output to `out`.
 *
 * Failures are recorded in `state.errorHolder`. `out->finish()` runs on every
 * exit path (via `guard`). If `in` holds several consecutive frames, which
 * happens in the serial fallback path, the context is re-initialized after
 * each completed frame and decoding continues.
 *
 * @param state        The shared state (error holder and DStream pool)
 * @param in           Queue that we `pop()` input buffers from. In the
 *                      parallel path it contains exactly one compressed frame
 *                      (preceded by its skippable frame header).
 * @param out          Queue that we `push()` decompressed output buffers to
 */
static void decompress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out) {
  auto& errorHolder = state.errorHolder;
  // Close `out` on every return path so the writer never waits on it forever.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the DCtx
  auto ctx = state.dStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
    return;
  }
  {
    // Start a fresh session on the pooled context.
    auto err = ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  const size_t outSize = ZSTD_DStreamOutSize();
  Buffer inBuffer;
  // Last ZSTD_decompressStream() result: 0 when a frame just ended, otherwise
  // a hint of how much more input the current frame expects.
  size_t returnCode = 0;
  // Read a buffer in from the input queue
  while (in->pop(inBuffer) && !errorHolder.hasError()) {
    auto zstdInBuffer = makeZstdInBuffer(inBuffer);
    // Decompress the whole buffer and send it to the output queue
    while (!inBuffer.empty() && !errorHolder.hasError()) {
      // Allocate a buffer with at least outSize bytes.
      Buffer outBuffer(outSize);
      auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
      // Decompress
      returnCode =
          ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
      if (!errorHolder.check(
              !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
        return;
      }
      // Pass the buffer with the decompressed data to the output queue
      out->push(split(outBuffer, zstdOutBuffer));
      // Advance past the input we already read
      advance(inBuffer, zstdInBuffer);
      if (returnCode == 0) {
        // The frame is over, prepare to (maybe) start a new frame
        ZSTD_initDStream(ctx.get());
      }
    }
  }
  // All input is consumed. A hint larger than 1 means the last frame is
  // missing data (truncated input).
  if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
    return;
  }
  // We've given ZSTD_decompressStream all of our data, but there may still
  // be data to read.
  // NOTE(review): if a call keeps returning 1 without producing output, this
  // loop never terminates -- confirm zstd cannot do that here.
  while (returnCode == 1) {
    // Allocate a buffer with at least outSize bytes.
    Buffer outBuffer(outSize);
    auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
    // Pass in no input.
    ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
    // Decompress
    returnCode =
        ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
      return;
    }
    // Pass the buffer with the decompressed data to the output queue
    out->push(split(outBuffer, zstdOutBuffer));
  }
}
497*3117ece4Schristos 
498*3117ece4Schristos std::uint64_t asyncDecompressFrames(
499*3117ece4Schristos     SharedState& state,
500*3117ece4Schristos     WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
501*3117ece4Schristos     ThreadPool& executor,
502*3117ece4Schristos     FILE* fd) {
503*3117ece4Schristos   auto framesGuard = makeScopeGuard([&] { frames.finish(); });
504*3117ece4Schristos   std::uint64_t totalBytesRead = 0;
505*3117ece4Schristos 
506*3117ece4Schristos   // Split the source up into its component frames.
507*3117ece4Schristos   // If we find our recognized skippable frame we know the next frames size
508*3117ece4Schristos   // which means that we can decompress each standard frame in independently.
509*3117ece4Schristos   // Otherwise, we will decompress using only one decompression task.
510*3117ece4Schristos   const size_t chunkSize = ZSTD_DStreamInSize();
511*3117ece4Schristos   auto status = FileStatus::Continue;
512*3117ece4Schristos   while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
513*3117ece4Schristos     // Make a new input queue that we will put the frames's bytes into.
514*3117ece4Schristos     auto in = std::make_shared<BufferWorkQueue>();
515*3117ece4Schristos     auto inGuard = makeScopeGuard([&] { in->finish(); });
516*3117ece4Schristos     // Make a output queue that decompress will put the decompressed data into
517*3117ece4Schristos     auto out = std::make_shared<BufferWorkQueue>();
518*3117ece4Schristos 
519*3117ece4Schristos     size_t frameSize;
520*3117ece4Schristos     {
521*3117ece4Schristos       // Calculate the size of the next frame.
522*3117ece4Schristos       // frameSize is 0 if the frame info can't be decoded.
523*3117ece4Schristos       Buffer buffer(SkippableFrame::kSize);
524*3117ece4Schristos       auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
525*3117ece4Schristos       totalBytesRead += bytesRead;
526*3117ece4Schristos       status = fileStatus(fd);
527*3117ece4Schristos       if (bytesRead == 0 && status != FileStatus::Continue) {
528*3117ece4Schristos         break;
529*3117ece4Schristos       }
530*3117ece4Schristos       buffer.subtract(buffer.size() - bytesRead);
531*3117ece4Schristos       frameSize = SkippableFrame::tryRead(buffer.range());
532*3117ece4Schristos       in->push(std::move(buffer));
533*3117ece4Schristos     }
534*3117ece4Schristos     if (frameSize == 0) {
535*3117ece4Schristos       // We hit a non SkippableFrame, so this will be the last job.
536*3117ece4Schristos       // Make sure that we don't use too much memory
537*3117ece4Schristos       in->setMaxSize(64);
538*3117ece4Schristos       out->setMaxSize(64);
539*3117ece4Schristos     }
540*3117ece4Schristos     // Start decompression in the thread pool
541*3117ece4Schristos     executor.add([&state, in, out] {
542*3117ece4Schristos       return decompress(state, std::move(in), std::move(out));
543*3117ece4Schristos     });
544*3117ece4Schristos     // Pass the output queue to the writer thread
545*3117ece4Schristos     frames.push(std::move(out));
546*3117ece4Schristos     if (frameSize == 0) {
547*3117ece4Schristos       // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
548*3117ece4Schristos       // Pass the rest of the source to this decompression task
549*3117ece4Schristos       state.log(kLogVerbose, "%s\n",
550*3117ece4Schristos           "Input not in pzstd format, falling back to serial decompression");
551*3117ece4Schristos       while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
552*3117ece4Schristos         status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
553*3117ece4Schristos       }
554*3117ece4Schristos       break;
555*3117ece4Schristos     }
556*3117ece4Schristos     state.log(kLogVerbose, "Decompressing a frame of size %zu", frameSize);
557*3117ece4Schristos     // Fill the input queue for the decompression job we just started
558*3117ece4Schristos     status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
559*3117ece4Schristos   }
560*3117ece4Schristos   state.errorHolder.check(status != FileStatus::Error, "Error reading input");
561*3117ece4Schristos   return totalBytesRead;
562*3117ece4Schristos }
563*3117ece4Schristos 
564*3117ece4Schristos /// Write `data` to `fd`, returns true iff success.
565*3117ece4Schristos static bool writeData(ByteRange data, FILE* fd) {
566*3117ece4Schristos   while (!data.empty()) {
567*3117ece4Schristos     data.advance(std::fwrite(data.begin(), 1, data.size(), fd));
568*3117ece4Schristos     if (std::ferror(fd)) {
569*3117ece4Schristos       return false;
570*3117ece4Schristos     }
571*3117ece4Schristos   }
572*3117ece4Schristos   return true;
573*3117ece4Schristos }
574*3117ece4Schristos 
575*3117ece4Schristos std::uint64_t writeFile(
576*3117ece4Schristos     SharedState& state,
577*3117ece4Schristos     WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
578*3117ece4Schristos     FILE* outputFd,
579*3117ece4Schristos     bool decompress) {
580*3117ece4Schristos   auto& errorHolder = state.errorHolder;
581*3117ece4Schristos   auto lineClearGuard = makeScopeGuard([&state] {
582*3117ece4Schristos     state.log.clear(kLogInfo);
583*3117ece4Schristos   });
584*3117ece4Schristos   std::uint64_t bytesWritten = 0;
585*3117ece4Schristos   std::shared_ptr<BufferWorkQueue> out;
586*3117ece4Schristos   // Grab the output queue for each decompression job (in order).
587*3117ece4Schristos   while (outs.pop(out)) {
588*3117ece4Schristos     if (errorHolder.hasError()) {
589*3117ece4Schristos       continue;
590*3117ece4Schristos     }
591*3117ece4Schristos     if (!decompress) {
592*3117ece4Schristos       // If we are compressing and want to write skippable frames we can't
593*3117ece4Schristos       // start writing before compression is done because we need to know the
594*3117ece4Schristos       // compressed size.
595*3117ece4Schristos       // Wait for the compressed size to be available and write skippable frame
596*3117ece4Schristos       assert(uint64_t(out->size()) < uint64_t(1) << 32);
597*3117ece4Schristos       SkippableFrame frame(uint32_t(out->size()));
598*3117ece4Schristos       if (!writeData(frame.data(), outputFd)) {
599*3117ece4Schristos         errorHolder.setError("Failed to write output");
600*3117ece4Schristos         return bytesWritten;
601*3117ece4Schristos       }
602*3117ece4Schristos       bytesWritten += frame.kSize;
603*3117ece4Schristos     }
604*3117ece4Schristos     // For each chunk of the frame: Pop it from the queue and write it
605*3117ece4Schristos     Buffer buffer;
606*3117ece4Schristos     while (out->pop(buffer) && !errorHolder.hasError()) {
607*3117ece4Schristos       if (!writeData(buffer.range(), outputFd)) {
608*3117ece4Schristos         errorHolder.setError("Failed to write output");
609*3117ece4Schristos         return bytesWritten;
610*3117ece4Schristos       }
611*3117ece4Schristos       bytesWritten += buffer.size();
612*3117ece4Schristos       state.log.update(kLogInfo, "Written: %u MB   ",
613*3117ece4Schristos                 static_cast<std::uint32_t>(bytesWritten >> 20));
614*3117ece4Schristos     }
615*3117ece4Schristos   }
616*3117ece4Schristos   return bytesWritten;
617*3117ece4Schristos }
618*3117ece4Schristos }
619