xref: /netbsd-src/common/dist/zlib/infback.c (revision 3117ece4fc4a4ca4489ba793710b60b0d26bab6c)
1 /*	$NetBSD: infback.c,v 1.5 2024/09/22 19:12:27 christos Exp $	*/
2 
3 /* infback.c -- inflate using a call-back interface
4  * Copyright (C) 1995-2022 Mark Adler
5  * For conditions of distribution and use, see copyright notice in zlib.h
6  */
7 
8 /*
9    This code is largely copied from inflate.c.  Normally either infback.o or
10    inflate.o would be linked into an application--not both.  The interface
11    with inffast.c is retained so that optimized assembler-coded versions of
12    inflate_fast() can be used with either inflate.c or infback.c.
13  */
14 
15 #include "zutil.h"
16 #include "inftrees.h"
17 #include "inflate.h"
18 #include "inffast.h"
19 
20 /*
21    strm provides memory allocation functions in zalloc and zfree, or
22    Z_NULL to use the library memory allocation functions.
23 
24    windowBits is in the range 8..15, and window is a user-supplied
25    window and output buffer that is 2**windowBits bytes.
26  */
27 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
28                              unsigned char FAR *window, const char *version,
29                              int stream_size) {
30     struct inflate_state FAR *state;
31 
32     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
33         stream_size != (int)(sizeof(z_stream)))
34         return Z_VERSION_ERROR;
35     if (strm == Z_NULL || window == Z_NULL ||
36         windowBits < 8 || windowBits > 15)
37         return Z_STREAM_ERROR;
38     strm->msg = Z_NULL;                 /* in case we return an error */
39     if (strm->zalloc == (alloc_func)0) {
40 #ifdef Z_SOLO
41         return Z_STREAM_ERROR;
42 #else
43         strm->zalloc = zcalloc;
44         strm->opaque = (voidpf)0;
45 #endif
46     }
47     if (strm->zfree == (free_func)0)
48 #ifdef Z_SOLO
49         return Z_STREAM_ERROR;
50 #else
51     strm->zfree = zcfree;
52 #endif
53     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
54                                                sizeof(struct inflate_state));
55     if (state == Z_NULL) return Z_MEM_ERROR;
56     Tracev((stderr, "inflate: allocated\n"));
57     strm->state = (struct internal_state FAR *)state;
58     state->dmax = 32768U;
59     state->wbits = (uInt)windowBits;
60     state->wsize = 1U << windowBits;
61     state->window = window;
62     state->wnext = 0;
63     state->whave = 0;
64     state->sane = 1;
65     return Z_OK;
66 }
67 
68 /*
69    Return state with length and distance decoding tables and index sizes set to
70    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
71    If BUILDFIXED is defined, then instead this routine builds the tables the
72    first time it's called, and returns those same tables on that call and
73    every later one.  This reduces the size of the code by about 2K bytes, in
74    exchange for a little execution time.  However, BUILDFIXED should not be
75    used for threaded applications, since the rewriting of the tables and virgin
76    may not be thread-safe.
77  */
78 local void fixedtables(struct inflate_state FAR *state) {
79 #ifdef BUILDFIXED
80     static int virgin = 1;
81     static code *lenfix, *distfix;
82     static code fixed[544];
83 
84     /* build fixed huffman tables if first call (may not be thread safe) */
85     if (virgin) {
86         unsigned sym, bits;
87         static code *next;
88 
89         /* literal/length table */
90         sym = 0;
91         while (sym < 144) state->lens[sym++] = 8;
92         while (sym < 256) state->lens[sym++] = 9;
93         while (sym < 280) state->lens[sym++] = 7;
94         while (sym < 288) state->lens[sym++] = 8;
95         next = fixed;
96         lenfix = next;
97         bits = 9;
98         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
99 
100         /* distance table */
101         sym = 0;
102         while (sym < 32) state->lens[sym++] = 5;
103         distfix = next;
104         bits = 5;
105         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
106 
107         /* do this just once */
108         virgin = 0;
109     }
110 #else /* !BUILDFIXED */
111 #   include "inffixed.h"
112 #endif /* BUILDFIXED */
113     state->lencode = lenfix;
114     state->lenbits = 9;
115     state->distcode = distfix;
116     state->distbits = 5;
117 }
118 
/*
   Helper macros for inflateBack().  They are not self-contained: they read
   and write inflateBack()'s locals (strm, state, next, have, put, left,
   hold, bits, ret) and its in/in_desc and out/out_desc parameters, and
   PULL()/ROOM() may jump to its inf_leave label.
 */

/* LOAD(): refresh the local copies from strm and state after
   inflate_fast() has updated them. */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* RESTORE(): publish the local copies into strm and state so that
   inflate_fast() can continue from them. */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* INITBITS(): empty the input bit accumulator. */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* PULL(): make sure at least one input byte is on hand, asking in() for
   more when the current buffer is used up.  If in() supplies nothing,
   next is set to Z_NULL (marking the failure as coming from in()) and
   inflateBack() leaves with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* PULLBYTE(): consume one input byte, placing it above the bits already
   held in the accumulator; leaves via PULL() when input runs out. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* NEEDBITS(n): keep pulling bytes until the accumulator holds at least
   n bits; leaves inflateBack() if the input runs out first. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)

/* BITS(n): the low n bits of the accumulator, without consuming them
   (n < 16). */
#define BITS(n) \
    ((unsigned)(hold) & ((1U << (n)) - 1U))

/* DROPBITS(n): discard the low n bits of the accumulator. */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* BYTEBITS(): discard 0..7 bits so the accumulator ends on a byte
   boundary. */
#define BYTEBITS() \
    do { \
        hold >>= (bits & 7U); \
        bits &= ~7U; \
    } while (0)

/* ROOM(): when the window is full, hand all of it to out() and restart
   output at the window's beginning; an out() failure makes
   inflateBack() leave with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            left = state->wsize; \
            put = state->window; \
            state->whave = state->wsize; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
216 
217 /*
218    strm provides the memory allocation functions and window buffer on input,
219    and provides information on the unused input on return.  For Z_DATA_ERROR
220    returns, strm will also provide an error message.
221 
222    in() and out() are the call-back input and output functions.  When
223    inflateBack() needs more input, it calls in().  When inflateBack() has
224    filled the window with output, or when it completes with data in the
225    window, it calls out() to write out the data.  The application must not
226    change the provided input until in() is called again or inflateBack()
227    returns.  The application must not change the window/output buffer until
228    inflateBack() returns.
229 
230    in() and out() are called with a descriptor parameter provided in the
231    inflateBack() call.  This parameter can be a structure that provides the
232    information required to do the read or write, as well as accumulated
233    information on the input and output such as totals and check values.
234 
235    in() should return zero on failure.  out() should return non-zero on
236    failure.  If either in() or out() fails, then inflateBack() returns a
237    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
238    was in() or out() that caused the error.  Otherwise, inflateBack()
239    returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
240    error.  (A Z_MEM_ERROR comes from inflateBackInit_(), not inflateBack().)
241    inflateBack() can also return Z_STREAM_ERROR if the input parameters
242    are not correct, i.e. strm is Z_NULL or the state was not initialized.
243  */
/* inflateBack(): decode deflate blocks until the block marked "last" has
   been processed.  Input is pulled through in(in_desc, &buf); output is
   accumulated in state->window and handed to out(out_desc, buf, len)
   whenever the window fills, with any remainder flushed once more on exit.
   No zlib or gzip header is parsed here: decoding starts directly at a
   block header (mode TYPE). */
244 int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
245                         out_func out, void FAR *out_desc) {
246     struct inflate_state FAR *state;
247     z_const unsigned char FAR *next;    /* next input */
248     unsigned char FAR *put;     /* next output */
249     unsigned have, left;        /* available input and output */
250     unsigned long hold;         /* bit buffer */
251     unsigned bits;              /* bits in bit buffer */
252     unsigned copy;              /* number of stored or match bytes to copy */
253     unsigned char FAR *from;    /* where to copy match bytes from */
254     code here;                  /* current decoding table entry */
255     code last;                  /* parent table entry */
256     unsigned len;               /* length to copy for repeats, bits to drop */
257     int ret;                    /* return code */
258     static const unsigned short order[19] = /* permutation of code lengths */
259         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
260 
261     /* Check that the strm exists and that the state was initialized */
262     if (strm == Z_NULL || strm->state == Z_NULL)
263         return Z_STREAM_ERROR;
264     state = (struct inflate_state FAR *)strm->state;
265 
266     /* Reset the state */
267     strm->msg = Z_NULL;
268     state->mode = TYPE;
269     state->last = 0;
270     state->whave = 0;
    /* A Z_NULL next_in means no input is pre-supplied; all input will then
       be obtained through in(). */
271     next = strm->next_in;
272     have = next != Z_NULL ? strm->avail_in : 0;
273     hold = 0;
274     bits = 0;
    /* Output is built up in the caller's window; put/left track the free
       space remaining in it. */
275     put = state->window;
276     left = state->wsize;
277 
278     /* Inflate until end of block marked as last */
279     for (;;)
280         switch (state->mode) {
281         case TYPE:
282             /* determine and dispatch block type */
283             if (state->last) {
                /* all blocks done: discard padding up to a byte boundary */
284                 BYTEBITS();
285                 state->mode = DONE;
286                 break;
287             }
288             NEEDBITS(3);
289             state->last = BITS(1);
290             DROPBITS(1);
291             switch (BITS(2)) {
292             case 0:                             /* stored block */
293                 Tracev((stderr, "inflate:     stored block%s\n",
294                         state->last ? " (last)" : ""));
295                 state->mode = STORED;
296                 break;
297             case 1:                             /* fixed block */
298                 fixedtables(state);
299                 Tracev((stderr, "inflate:     fixed codes block%s\n",
300                         state->last ? " (last)" : ""));
301                 state->mode = LEN;              /* decode codes */
302                 break;
303             case 2:                             /* dynamic block */
304                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
305                         state->last ? " (last)" : ""));
306                 state->mode = TABLE;
307                 break;
308             case 3:
309                 strm->msg = __UNCONST("invalid block type");
310                 state->mode = BAD;
311             }
312             DROPBITS(2);
313             break;
314 
315         case STORED:
316             /* get and verify stored block length */
317             BYTEBITS();                         /* go to byte boundary */
318             NEEDBITS(32);
            /* the low 16 bits (length) must equal the one's complement of
               the next 16 bits */
319             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
320                 strm->msg = __UNCONST("invalid stored block lengths");
321                 state->mode = BAD;
322                 break;
323             }
324             state->length = (unsigned)hold & 0xffff;
325             Tracev((stderr, "inflate:       stored length %u\n",
326                     state->length));
327             INITBITS();
328 
329             /* copy stored block from input to output */
330             while (state->length != 0) {
331                 copy = state->length;
332                 PULL();
333                 ROOM();
334                 if (copy > have) copy = have;
335                 if (copy > left) copy = left;
336                 zmemcpy(put, next, copy);
337                 have -= copy;
338                 next += copy;
339                 left -= copy;
340                 put += copy;
341                 state->length -= copy;
342             }
343             Tracev((stderr, "inflate:       stored end\n"));
344             state->mode = TYPE;
345             break;
346 
347         case TABLE:
348             /* get dynamic table entries descriptor */
349             NEEDBITS(14);
350             state->nlen = BITS(5) + 257;
351             DROPBITS(5);
352             state->ndist = BITS(5) + 1;
353             DROPBITS(5);
354             state->ncode = BITS(4) + 4;
355             DROPBITS(4);
356 #ifndef PKZIP_BUG_WORKAROUND
357             if (state->nlen > 286 || state->ndist > 30) {
358                 strm->msg = __UNCONST("too many length or distance symbols");
359                 state->mode = BAD;
360                 break;
361             }
362 #endif
363             Tracev((stderr, "inflate:       table sizes ok\n"));
364 
365             /* get code length code lengths (not a typo) */
366             state->have = 0;
367             while (state->have < state->ncode) {
368                 NEEDBITS(3);
369                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
370                 DROPBITS(3);
371             }
372             while (state->have < 19)
373                 state->lens[order[state->have++]] = 0;
374             state->next = state->codes;
375             state->lencode = (code const FAR *)(state->next);
376             state->lenbits = 7;
377             ret = inflate_table(CODES, state->lens, 19, &(state->next),
378                                 &(state->lenbits), state->work);
379             if (ret) {
380                 strm->msg = __UNCONST("invalid code lengths set");
381                 state->mode = BAD;
382                 break;
383             }
384             Tracev((stderr, "inflate:       code lengths ok\n"));
385 
386             /* get length and distance code code lengths */
387             state->have = 0;
388             while (state->have < state->nlen + state->ndist) {
389                 for (;;) {
390                     here = state->lencode[BITS(state->lenbits)];
391                     if ((unsigned)(here.bits) <= bits) break;
392                     PULLBYTE();
393                 }
394                 if (here.val < 16) {
395                     DROPBITS(here.bits);
396                     state->lens[state->have++] = here.val;
397                 }
398                 else {
                    /* repeat codes: 16 repeats the previous length 3..6
                       times, 17 writes 3..10 zeros, 18 writes 11..138 zeros */
399                     if (here.val == 16) {
400                         NEEDBITS(here.bits + 2);
401                         DROPBITS(here.bits);
402                         if (state->have == 0) {
403                             strm->msg = __UNCONST("invalid bit length repeat");
404                             state->mode = BAD;
405                             break;
406                         }
407                         len = (unsigned)(state->lens[state->have - 1]);
408                         copy = 3 + BITS(2);
409                         DROPBITS(2);
410                     }
411                     else if (here.val == 17) {
412                         NEEDBITS(here.bits + 3);
413                         DROPBITS(here.bits);
414                         len = 0;
415                         copy = 3 + BITS(3);
416                         DROPBITS(3);
417                     }
418                     else {
419                         NEEDBITS(here.bits + 7);
420                         DROPBITS(here.bits);
421                         len = 0;
422                         copy = 11 + BITS(7);
423                         DROPBITS(7);
424                     }
425                     if (state->have + copy > state->nlen + state->ndist) {
426                         strm->msg = __UNCONST("invalid bit length repeat");
427                         state->mode = BAD;
428                         break;
429                     }
430                     while (copy--)
431                         state->lens[state->have++] = (unsigned short)len;
432                 }
433             }
434 
435             /* handle error breaks in while */
436             if (state->mode == BAD) break;
437 
438             /* check for end-of-block code (better have one) */
439             if (state->lens[256] == 0) {
440                 strm->msg = __UNCONST("invalid code -- missing end-of-block");
441                 state->mode = BAD;
442                 break;
443             }
444 
445             /* build code tables -- note: do not change the lenbits or distbits
446                values here (9 and 6) without reading the comments in inftrees.h
447                concerning the ENOUGH constants, which depend on those values */
448             state->next = state->codes;
449             state->lencode = (code const FAR *)(state->next);
450             state->lenbits = 9;
451             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
452                                 &(state->lenbits), state->work);
453             if (ret) {
454                 strm->msg = __UNCONST("invalid literal/lengths set");
455                 state->mode = BAD;
456                 break;
457             }
458             state->distcode = (code const FAR *)(state->next);
459             state->distbits = 6;
460             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
461                             &(state->next), &(state->distbits), state->work);
462             if (ret) {
463                 strm->msg = __UNCONST("invalid distances set");
464                 state->mode = BAD;
465                 break;
466             }
467             Tracev((stderr, "inflate:       codes ok\n"));
468             state->mode = LEN;
469                 /* fallthrough */
470 
471         case LEN:
472             /* use inflate_fast() if we have enough input and output */
            /* NOTE(review): 6 input / 258 output bytes are presumably
               inflate_fast()'s worst-case needs per step -- confirm against
               inffast.c before changing */
473             if (have >= 6 && left >= 258) {
474                 RESTORE();
                /* whave = amount of valid history in the window: until the
                   window has been filled once, that is the wsize - left
                   bytes written so far */
475                 if (state->whave < state->wsize)
476                     state->whave = state->wsize - left;
477                 inflate_fast(strm, state->wsize);
478                 LOAD();
479                 break;
480             }
481 
482             /* get a literal, length, or end-of-block code */
483             for (;;) {
484                 here = state->lencode[BITS(state->lenbits)];
485                 if ((unsigned)(here.bits) <= bits) break;
486                 PULLBYTE();
487             }
            /* nonzero op with a clear upper nibble: the entry links to a
               second-level table at last.val, indexed by last.op more bits */
488             if (here.op && (here.op & 0xf0) == 0) {
489                 last = here;
490                 for (;;) {
491                     here = state->lencode[last.val +
492                             (BITS(last.bits + last.op) >> last.bits)];
493                     if ((unsigned)(last.bits + here.bits) <= bits) break;
494                     PULLBYTE();
495                 }
496                 DROPBITS(last.bits);
497             }
498             DROPBITS(here.bits);
499             state->length = (unsigned)here.val;
500 
501             /* process literal */
502             if (here.op == 0) {
503                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
504                         "inflate:         literal '%c'\n" :
505                         "inflate:         literal 0x%02x\n", here.val));
506                 ROOM();
507                 *put++ = (unsigned char)(state->length);
508                 left--;
509                 state->mode = LEN;
510                 break;
511             }
512 
513             /* process end of block */
514             if (here.op & 32) {
515                 Tracevv((stderr, "inflate:         end of block\n"));
516                 state->mode = TYPE;
517                 break;
518             }
519 
520             /* invalid code */
521             if (here.op & 64) {
522                 strm->msg = __UNCONST("invalid literal/length code");
523                 state->mode = BAD;
524                 break;
525             }
526 
527             /* length code -- get extra bits, if any */
528             state->extra = (unsigned)(here.op) & 15;
529             if (state->extra != 0) {
530                 NEEDBITS(state->extra);
531                 state->length += BITS(state->extra);
532                 DROPBITS(state->extra);
533             }
534             Tracevv((stderr, "inflate:         length %u\n", state->length));
535 
536             /* get distance code */
537             for (;;) {
538                 here = state->distcode[BITS(state->distbits)];
539                 if ((unsigned)(here.bits) <= bits) break;
540                 PULLBYTE();
541             }
            /* upper nibble clear: link to a second-level distance table */
542             if ((here.op & 0xf0) == 0) {
543                 last = here;
544                 for (;;) {
545                     here = state->distcode[last.val +
546                             (BITS(last.bits + last.op) >> last.bits)];
547                     if ((unsigned)(last.bits + here.bits) <= bits) break;
548                     PULLBYTE();
549                 }
550                 DROPBITS(last.bits);
551             }
552             DROPBITS(here.bits);
553             if (here.op & 64) {
554                 strm->msg = __UNCONST("invalid distance code");
555                 state->mode = BAD;
556                 break;
557             }
558             state->offset = (unsigned)here.val;
559 
560             /* get distance extra bits, if any */
561             state->extra = (unsigned)(here.op) & 15;
562             if (state->extra != 0) {
563                 NEEDBITS(state->extra);
564                 state->offset += BITS(state->extra);
565                 DROPBITS(state->extra);
566             }
            /* the match may only reach back over output already produced:
               the wsize - left bytes written so far until the window has
               been filled once (whave < wsize), else the whole window */
567             if (state->offset > state->wsize - (state->whave < state->wsize ?
568                                                 left : 0)) {
569                 strm->msg = __UNCONST("invalid distance too far back");
570                 state->mode = BAD;
571                 break;
572             }
573             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
574 
575             /* copy match from window to output */
576             do {
577                 ROOM();
                /* wsize - offset < left means offset exceeds the bytes
                   written on this pass, so the source lies in the window's
                   tail (from the previous pass) and this chunk runs up to
                   the window's end; otherwise copy from put - offset */
578                 copy = state->wsize - state->offset;
579                 if (copy < left) {
580                     from = put + copy;
581                     copy = left - copy;
582                 }
583                 else {
584                     from = put - state->offset;
585                     copy = left;
586                 }
587                 if (copy > state->length) copy = state->length;
588                 state->length -= copy;
589                 left -= copy;
                /* byte-by-byte so overlapping matches (offset < length)
                   replicate the bytes they have just written */
590                 do {
591                     *put++ = *from++;
592                 } while (--copy);
593             } while (state->length != 0);
594             break;
595 
596         case DONE:
597             /* inflate stream terminated properly */
598             ret = Z_STREAM_END;
599             goto inf_leave;
600 
601         case BAD:
602             ret = Z_DATA_ERROR;
603             goto inf_leave;
604 
605         default:
606             /* can't happen, but makes compilers happy */
607             ret = Z_STREAM_ERROR;
608             goto inf_leave;
609         }
610 
611     /* Write leftover output and return unused input */
    /* A failed final write only overrides a successful Z_STREAM_END; any
       other return code already chosen is kept. */
612   inf_leave:
613     if (left < state->wsize) {
614         if (out(out_desc, state->window, state->wsize - left) &&
615             ret == Z_STREAM_END)
616             ret = Z_BUF_ERROR;
617     }
618     strm->next_in = next;
619     strm->avail_in = have;
620     return ret;
621 }
622 
623 int ZEXPORT inflateBackEnd(z_streamp strm) {
624     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
625         return Z_STREAM_ERROR;
626     ZFREE(strm, strm->state);
627     strm->state = Z_NULL;
628     Tracev((stderr, "inflate: end\n"));
629     return Z_OK;
630 }
631