xref: /openbsd-src/lib/libz/infback.c (revision 3374c67d44f9b75b98444cbf63020f777792342e)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20 
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24 
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29 z_streamp strm;
30 int windowBits;
31 unsigned char FAR *window;
32 const char *version;
33 int stream_size;
34 {
35     struct inflate_state FAR *state;
36 
37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38         stream_size != (int)(sizeof(z_stream)))
39         return Z_VERSION_ERROR;
40     if (strm == Z_NULL || window == Z_NULL ||
41         windowBits < 8 || windowBits > 15)
42         return Z_STREAM_ERROR;
43     strm->msg = Z_NULL;                 /* in case we return an error */
44     if (strm->zalloc == (alloc_func)0) {
45 #ifdef Z_SOLO
46         return Z_STREAM_ERROR;
47 #else
48         strm->zalloc = zcalloc;
49         strm->opaque = (voidpf)0;
50 #endif
51     }
52     if (strm->zfree == (free_func)0)
53 #ifdef Z_SOLO
54         return Z_STREAM_ERROR;
55 #else
56     strm->zfree = zcfree;
57 #endif
58     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
59                                                sizeof(struct inflate_state));
60     if (state == Z_NULL) return Z_MEM_ERROR;
61     Tracev((stderr, "inflate: allocated\n"));
62     strm->state = (struct internal_state FAR *)state;
63     state->dmax = 32768U;
64     state->wbits = (uInt)windowBits;
65     state->wsize = 1U << windowBits;
66     state->window = window;
67     state->wnext = 0;
68     state->whave = 0;
69     state->sane = 1;
70     return Z_OK;
71 }
72 
73 /*
74    Return state with length and distance decoding tables and index sizes set to
75    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
76    If BUILDFIXED is defined, then instead this routine builds the tables the
77    first time it's called, and returns those tables the first time and
78    thereafter.  This reduces the size of the code by about 2K bytes, in
79    exchange for a little execution time.  However, BUILDFIXED should not be
80    used for threaded applications, since the rewriting of the tables and virgin
81    may not be thread-safe.
82  */
83 local void fixedtables(state)
84 struct inflate_state FAR *state;
85 {
86 #ifdef BUILDFIXED
87     static int virgin = 1;
88     static code *lenfix, *distfix;
89     static code fixed[544];
90 
91     /* build fixed huffman tables if first call (may not be thread safe) */
92     if (virgin) {
93         unsigned sym, bits;
94         static code *next;
95 
96         /* literal/length table */
97         sym = 0;
98         while (sym < 144) state->lens[sym++] = 8;
99         while (sym < 256) state->lens[sym++] = 9;
100         while (sym < 280) state->lens[sym++] = 7;
101         while (sym < 288) state->lens[sym++] = 8;
102         next = fixed;
103         lenfix = next;
104         bits = 9;
105         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
106 
107         /* distance table */
108         sym = 0;
109         while (sym < 32) state->lens[sym++] = 5;
110         distfix = next;
111         bits = 5;
112         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
113 
114         /* do this just once */
115         virgin = 0;
116     }
117 #else /* !BUILDFIXED */
118 #   include "inffixed.h"
119 #endif /* BUILDFIXED */
120     state->lencode = lenfix;
121     state->lenbits = 9;
122     state->distcode = distfix;
123     state->distbits = 5;
124 }
125 
/* Macros for inflateBack(): */

/* LOAD(): after inflate_fast() returns, pull the stream position and the
   bit buffer it left in strm/state back into inflateBack()'s locals. */
#define LOAD() \
    do { \
        left = strm->avail_out; \
        put = strm->next_out; \
        have = strm->avail_in; \
        next = strm->next_in; \
        bits = state->bits; \
        hold = state->hold; \
    } while (0)

/* RESTORE(): publish inflateBack()'s locals into strm/state so that
   inflate_fast() starts from the current position. */
#define RESTORE() \
    do { \
        state->bits = bits; \
        state->hold = hold; \
        strm->avail_in = have; \
        strm->next_in = next; \
        strm->avail_out = left; \
        strm->next_out = put; \
    } while (0)
149 
/* INITBITS(): empty the input bit accumulator. */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* PULL(): ensure at least one input byte is available, asking in() for a
   new buffer when the current one is used up.  If in() supplies nothing,
   next is set to Z_NULL (so the caller can tell that in() failed) and
   inflateBack() returns Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have != 0) \
            break; \
        have = in(in_desc, &next); \
        if (have == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* PULLBYTE(): append one input byte above the bits already held, or leave
   inflateBack() with Z_BUF_ERROR if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* NEEDBITS(n): keep pulling bytes until at least n bits are held; leaves
   inflateBack() with Z_BUF_ERROR if the input runs dry first. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* BITS(n): the low n bits of the accumulator, without consuming them
   (n < 16) */
#define BITS(n) \
    ((unsigned)(hold) & ((1U << (n)) - 1U))

/* DROPBITS(n): consume the low n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* BYTEBITS(): discard the 0..7 bits that sit below the next byte boundary */
#define BYTEBITS() \
    do { \
        unsigned byte_pad_ = bits & 7; \
        hold >>= byte_pad_; \
        bits -= byte_pad_; \
    } while (0)
207 
/* ROOM(): ensure the window has space for at least one more output byte.
   When it is full (left == 0), the whole window is handed to out() and
   marked as complete history (whave = wsize), and writing restarts at the
   start of the window.  If out() reports failure (non-zero), inflateBack()
   returns Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left != 0) \
            break; \
        put = state->window; \
        left = state->wsize; \
        state->whave = left; \
        if (out(out_desc, put, left) != 0) { \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)
223 
224 /*
225    strm provides the memory allocation functions and window buffer on input,
226    and provides information on the unused input on return.  For Z_DATA_ERROR
227    returns, strm will also provide an error message.
228 
229    in() and out() are the call-back input and output functions.  When
230    inflateBack() needs more input, it calls in().  When inflateBack() has
231    filled the window with output, or when it completes with data in the
232    window, it calls out() to write out the data.  The application must not
233    change the provided input until in() is called again or inflateBack()
234    returns.  The application must not change the window/output buffer until
235    inflateBack() returns.
236 
237    in() and out() are called with a descriptor parameter provided in the
238    inflateBack() call.  This parameter can be a structure that provides the
239    information required to do the read or write, as well as accumulated
240    information on the input and output such as totals and check values.
241 
242    in() should return zero on failure.  out() should return non-zero on
243    failure.  If either in() or out() fails, then inflateBack() returns
244    Z_BUF_ERROR; strm->next_in can be checked for Z_NULL to see whether it
245    was in() or out() that caused the error.  Otherwise, inflateBack()
246    returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
247    error; it allocates nothing itself, so Z_MEM_ERROR can only come from
248    inflateBackInit_().  inflateBack() can also return Z_STREAM_ERROR if
249    strm is Z_NULL or the state was not initialized.
250  */
/* inflateBack(): decode a raw deflate stream.  No zlib/gzip header is
   parsed; decoding starts at the first block header.  Input is pulled
   through in(), and output is written into the window given to
   inflateBackInit_().  That window is flushed through out() whenever it
   fills, and once more for any pending output before returning.  Returns
   Z_STREAM_END, Z_BUF_ERROR, Z_DATA_ERROR, or Z_STREAM_ERROR. */
251 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
252 z_streamp strm;
253 in_func in;
254 void FAR *in_desc;
255 out_func out;
256 void FAR *out_desc;
257 {
258     struct inflate_state FAR *state;
259     z_const unsigned char FAR *next;    /* next input */
260     unsigned char FAR *put;     /* next output */
261     unsigned have, left;        /* available input and output */
262     unsigned long hold;         /* bit buffer */
263     unsigned bits;              /* bits in bit buffer */
264     unsigned copy;              /* number of stored or match bytes to copy */
265     unsigned char FAR *from;    /* where to copy match bytes from */
266     code here;                  /* current decoding table entry */
267     code last;                  /* parent table entry */
268     unsigned len;               /* code length value written by repeat codes 16-18 */
269     int ret;                    /* return code */
        /* order[i] is the lens[] slot that receives the i-th transmitted
           code length code length */
270     static const unsigned short order[19] = /* permutation of code lengths */
271         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
272 
273     /* Check that the strm exists and that the state was initialized */
274     if (strm == Z_NULL || strm->state == Z_NULL)
275         return Z_STREAM_ERROR;
276     state = (struct inflate_state FAR *)strm->state;
277 
278     /* Reset the state */
        /* every call starts at a block header with an empty bit buffer and
           no window history (whave == 0) */
279     strm->msg = Z_NULL;
280     state->mode = TYPE;
281     state->last = 0;
282     state->whave = 0;
283     next = strm->next_in;
        /* a Z_NULL next_in means no initial input; the first PULL() will
           call in() */
284     have = next != Z_NULL ? strm->avail_in : 0;
285     hold = 0;
286     bits = 0;
        /* output is written directly into the window */
287     put = state->window;
288     left = state->wsize;
289 
290     /* Inflate until end of block marked as last */
        /* each case either sets state->mode and breaks back to this dispatch
           loop, or leaves through inf_leave */
291     for (;;)
292         switch (state->mode) {
293         case TYPE:
294             /* determine and dispatch block type */
295             if (state->last) {
                    /* the previous block was the final one: discard padding
                       bits up to the next byte boundary and finish */
296                 BYTEBITS();
297                 state->mode = DONE;
298                 break;
299             }
                /* block header: 1 "last block" bit, then 2 block-type bits */
300             NEEDBITS(3);
301             state->last = BITS(1);
302             DROPBITS(1);
303             switch (BITS(2)) {
304             case 0:                             /* stored block */
305                 Tracev((stderr, "inflate:     stored block%s\n",
306                         state->last ? " (last)" : ""));
307                 state->mode = STORED;
308                 break;
309             case 1:                             /* fixed block */
310                 fixedtables(state);
311                 Tracev((stderr, "inflate:     fixed codes block%s\n",
312                         state->last ? " (last)" : ""));
313                 state->mode = LEN;              /* decode codes */
314                 break;
315             case 2:                             /* dynamic block */
316                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
317                         state->last ? " (last)" : ""));
318                 state->mode = TABLE;
319                 break;
320             case 3:
                    /* block type 3 is not a valid block type */
321 #ifdef SMALL
322                 strm->msg = "error";
323 #else
324                 strm->msg = (char *)"invalid block type";
325 #endif
326                 state->mode = BAD;
327             }
                /* the two type bits are dropped on every path, including BAD */
328             DROPBITS(2);
329             break;
330 
331         case STORED:
332             /* get and verify stored block length */
333             BYTEBITS();                         /* go to byte boundary */
                /* 16-bit length followed by its one's complement */
334             NEEDBITS(32);
335             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
336 #ifdef SMALL
337                 strm->msg = "error";
338 #else
339                 strm->msg = (char *)"invalid stored block lengths";
340 #endif
341                 state->mode = BAD;
342                 break;
343             }
344             state->length = (unsigned)hold & 0xffff;
345             Tracev((stderr, "inflate:       stored length %u\n",
346                     state->length));
347             INITBITS();
348 
349             /* copy stored block from input to output */
                /* each pass copies as much as both the current input buffer
                   and the free window space allow */
350             while (state->length != 0) {
351                 copy = state->length;
352                 PULL();
353                 ROOM();
354                 if (copy > have) copy = have;
355                 if (copy > left) copy = left;
356                 zmemcpy(put, next, copy);
357                 have -= copy;
358                 next += copy;
359                 left -= copy;
360                 put += copy;
361                 state->length -= copy;
362             }
363             Tracev((stderr, "inflate:       stored end\n"));
364             state->mode = TYPE;
365             break;
366 
367         case TABLE:
368             /* get dynamic table entries descriptor */
                /* 5 bits: number of length codes - 257; 5 bits: number of
                   distance codes - 1; 4 bits: number of code length codes - 4 */
369             NEEDBITS(14);
370             state->nlen = BITS(5) + 257;
371             DROPBITS(5);
372             state->ndist = BITS(5) + 1;
373             DROPBITS(5);
374             state->ncode = BITS(4) + 4;
375             DROPBITS(4);
            /* the 286/30 limits are not enforced when PKZIP_BUG_WORKAROUND
               is defined */
376 #ifndef PKZIP_BUG_WORKAROUND
377             if (state->nlen > 286 || state->ndist > 30) {
378 #ifdef SMALL
379                 strm->msg = "error";
380 #else
381                 strm->msg = (char *)"too many length or distance symbols";
382 #endif
383                 state->mode = BAD;
384                 break;
385             }
386 #endif
387             Tracev((stderr, "inflate:       table sizes ok\n"));
388 
389             /* get code length code lengths (not a typo) */
390             state->have = 0;
391             while (state->have < state->ncode) {
392                 NEEDBITS(3);
393                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
394                 DROPBITS(3);
395             }
                /* code length codes that were not transmitted get length 0 */
396             while (state->have < 19)
397                 state->lens[order[state->have++]] = 0;
                /* build the code length decoding table; lenbits (7 requested)
                   is passed by address, so inflate_table() may adjust it */
398             state->next = state->codes;
399             state->lencode = (code const FAR *)(state->next);
400             state->lenbits = 7;
401             ret = inflate_table(CODES, state->lens, 19, &(state->next),
402                                 &(state->lenbits), state->work);
403             if (ret) {
404 #ifdef SMALL
405                 strm->msg = "error";
406 #else
407                 strm->msg = (char *)"invalid code lengths set";
408 #endif
409                 state->mode = BAD;
410                 break;
411             }
412             Tracev((stderr, "inflate:       code lengths ok\n"));
413 
414             /* get length and distance code code lengths */
415             state->have = 0;
416             while (state->have < state->nlen + state->ndist) {
                    /* pull bytes until the matched entry's code length is
                       fully present in the bit buffer */
417                 for (;;) {
418                     here = state->lencode[BITS(state->lenbits)];
419                     if ((unsigned)(here.bits) <= bits) break;
420                     PULLBYTE();
421                 }
422                 if (here.val < 16) {
                        /* 0..15: a code length stored as-is */
423                     DROPBITS(here.bits);
424                     state->lens[state->have++] = here.val;
425                 }
426                 else {
427                     if (here.val == 16) {
                            /* 16: repeat the previous length 3..6 times
                               (2 extra bits); needs a previous length */
428                         NEEDBITS(here.bits + 2);
429                         DROPBITS(here.bits);
430                         if (state->have == 0) {
431 #ifdef SMALL
432                             strm->msg = "error";
433 #else
434                             strm->msg = (char *)"invalid bit length repeat";
435 #endif
436                             state->mode = BAD;
437                             break;
438                         }
439                         len = (unsigned)(state->lens[state->have - 1]);
440                         copy = 3 + BITS(2);
441                         DROPBITS(2);
442                     }
443                     else if (here.val == 17) {
                            /* 17: 3..10 zero lengths (3 extra bits) */
444                         NEEDBITS(here.bits + 3);
445                         DROPBITS(here.bits);
446                         len = 0;
447                         copy = 3 + BITS(3);
448                         DROPBITS(3);
449                     }
450                     else {
                            /* remaining symbol (18): 11..138 zero lengths
                               (7 extra bits) */
451                         NEEDBITS(here.bits + 7);
452                         DROPBITS(here.bits);
453                         len = 0;
454                         copy = 11 + BITS(7);
455                         DROPBITS(7);
456                     }
                        /* reject a repeat that would run past the declared
                           number of lengths */
457                     if (state->have + copy > state->nlen + state->ndist) {
458 #ifdef SMALL
459                         strm->msg = "error";
460 #else
461                         strm->msg = (char *)"invalid bit length repeat";
462 #endif
463                         state->mode = BAD;
464                         break;
465                     }
466                     while (copy--)
467                         state->lens[state->have++] = (unsigned short)len;
468                 }
469             }
470 
471             /* handle error breaks in while */
472             if (state->mode == BAD) break;
473 
474             /* check for end-of-block code (better have one) */
475             if (state->lens[256] == 0) {
476 #ifdef SMALL
477                 strm->msg = "error";
478 #else
479                 strm->msg = (char *)"invalid code -- missing end-of-block";
480 #endif
481                 state->mode = BAD;
482                 break;
483             }
484 
485             /* build code tables -- note: do not change the lenbits or distbits
486                values here (9 and 6) without reading the comments in inftrees.h
487                concerning the ENOUGH constants, which depend on those values */
488             state->next = state->codes;
489             state->lencode = (code const FAR *)(state->next);
490             state->lenbits = 9;
491             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
492                                 &(state->lenbits), state->work);
493             if (ret) {
494 #ifdef SMALL
495                 strm->msg = "error";
496 #else
497                 strm->msg = (char *)"invalid literal/lengths set";
498 #endif
499                 state->mode = BAD;
500                 break;
501             }
                /* state->next was advanced by inflate_table() past the length
                   table, so the distance table follows it in codes[]; the
                   distance code lengths follow the nlen length code lengths */
502             state->distcode = (code const FAR *)(state->next);
503             state->distbits = 6;
504             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
505                             &(state->next), &(state->distbits), state->work);
506             if (ret) {
507 #ifdef SMALL
508                 strm->msg = "error";
509 #else
510                 strm->msg = (char *)"invalid distances set";
511 #endif
512                 state->mode = BAD;
513                 break;
514             }
515             Tracev((stderr, "inflate:       codes ok\n"));
516             state->mode = LEN;
517                 /* fallthrough */
518 
519         case LEN:
520 #ifndef SLOW
521             /* use inflate_fast() if we have enough input and output */
522             if (have >= 6 && left >= 258) {
523                 RESTORE();
                    /* until the window has filled once, the history available
                       is the part written so far */
524                 if (state->whave < state->wsize)
525                     state->whave = state->wsize - left;
526                 inflate_fast(strm, state->wsize);
527                 LOAD();
                    /* re-dispatch on state->mode, which inflate_fast() may
                       have changed */
528                 break;
529             }
530 #endif
531 
532             /* get a literal, length, or end-of-block code */
533             for (;;) {
534                 here = state->lencode[BITS(state->lenbits)];
535                 if ((unsigned)(here.bits) <= bits) break;
536                 PULLBYTE();
537             }
                /* a nonzero op with a zero high nibble links to a second-level
                   table: index it with the last.op bits that follow the
                   last.bits root bits */
538             if (here.op && (here.op & 0xf0) == 0) {
539                 last = here;
540                 for (;;) {
541                     here = state->lencode[last.val +
542                             (BITS(last.bits + last.op) >> last.bits)];
543                     if ((unsigned)(last.bits + here.bits) <= bits) break;
544                     PULLBYTE();
545                 }
546                 DROPBITS(last.bits);
547             }
548             DROPBITS(here.bits);
549             state->length = (unsigned)here.val;
550 
551             /* process literal */
552             if (here.op == 0) {
553                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
554                         "inflate:         literal '%c'\n" :
555                         "inflate:         literal 0x%02x\n", here.val));
556                 ROOM();
557                 *put++ = (unsigned char)(state->length);
558                 left--;
559                 state->mode = LEN;
560                 break;
561             }
562 
563             /* process end of block */
564             if (here.op & 32) {
565                 Tracevv((stderr, "inflate:         end of block\n"));
566                 state->mode = TYPE;
567                 break;
568             }
569 
570             /* invalid code */
571             if (here.op & 64) {
572 #ifdef SMALL
573                 strm->msg = "error";
574 #else
575                 strm->msg = (char *)"invalid literal/length code";
576 #endif
577                 state->mode = BAD;
578                 break;
579             }
580 
581             /* length code -- get extra bits, if any */
                /* the low 4 bits of op give the number of extra bits */
582             state->extra = (unsigned)(here.op) & 15;
583             if (state->extra != 0) {
584                 NEEDBITS(state->extra);
585                 state->length += BITS(state->extra);
586                 DROPBITS(state->extra);
587             }
588             Tracevv((stderr, "inflate:         length %u\n", state->length));
589 
590             /* get distance code */
591             for (;;) {
592                 here = state->distcode[BITS(state->distbits)];
593                 if ((unsigned)(here.bits) <= bits) break;
594                 PULLBYTE();
595             }
                /* zero high nibble: second-level table, indexed as above */
596             if ((here.op & 0xf0) == 0) {
597                 last = here;
598                 for (;;) {
599                     here = state->distcode[last.val +
600                             (BITS(last.bits + last.op) >> last.bits)];
601                     if ((unsigned)(last.bits + here.bits) <= bits) break;
602                     PULLBYTE();
603                 }
604                 DROPBITS(last.bits);
605             }
606             DROPBITS(here.bits);
607             if (here.op & 64) {
608 #ifdef SMALL
609                 strm->msg = "error";
610 #else
611                 strm->msg = (char *)"invalid distance code";
612 #endif
613                 state->mode = BAD;
614                 break;
615             }
616             state->offset = (unsigned)here.val;
617 
618             /* get distance extra bits, if any */
619             state->extra = (unsigned)(here.op) & 15;
620             if (state->extra != 0) {
621                 NEEDBITS(state->extra);
622                 state->offset += BITS(state->extra);
623                 DROPBITS(state->extra);
624             }
                /* before the window has filled once (whave < wsize), only the
                   wsize - left bytes written so far are valid history */
625             if (state->offset > state->wsize - (state->whave < state->wsize ?
626                                                 left : 0)) {
627 #ifdef SMALL
628                 strm->msg = "error";
629 #else
630                 strm->msg = (char *)"invalid distance too far back";
631 #endif
632                 state->mode = BAD;
633                 break;
634             }
635             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
636 
637             /* copy match from window to output */
                /* put is at window + (wsize - left); the window is used as a
                   circular buffer */
638             do {
639                 ROOM();
640                 copy = state->wsize - state->offset;
641                 if (copy < left) {
                        /* offset reaches before the window start: the source
                           wraps to older data at put - offset + wsize, and
                           left - copy bytes remain before the window end */
642                     from = put + copy;
643                     copy = left - copy;
644                 }
645                 else {
                        /* source lies offset bytes behind put in the window */
646                     from = put - state->offset;
647                     copy = left;
648                 }
649                 if (copy > state->length) copy = state->length;
650                 state->length -= copy;
651                 left -= copy;
                    /* byte-at-a-time copy: source and destination overlap
                       when offset is smaller than the match length */
652                 do {
653                     *put++ = *from++;
654                 } while (--copy);
655             } while (state->length != 0);
656             break;
657 
658         case DONE:
659             /* inflate stream terminated properly */
660             ret = Z_STREAM_END;
661             goto inf_leave;
662 
663         case BAD:
664             ret = Z_DATA_ERROR;
665             goto inf_leave;
666 
667         default:
668             /* can't happen, but makes compilers happy */
669             ret = Z_STREAM_ERROR;
670             goto inf_leave;
671         }
672 
673     /* Write leftover output and return unused input */
674   inf_leave:
        /* flush a partially filled window; an out() failure here only turns
           a Z_STREAM_END into Z_BUF_ERROR, other codes are kept.  next is
           Z_NULL when in() failed. */
675     if (left < state->wsize) {
676         if (out(out_desc, state->window, state->wsize - left) &&
677             ret == Z_STREAM_END)
678             ret = Z_BUF_ERROR;
679     }
680     strm->next_in = next;
681     strm->avail_in = have;
682     return ret;
683 }
684 
685 int ZEXPORT inflateBackEnd(strm)
686 z_streamp strm;
687 {
688     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
689         return Z_STREAM_ERROR;
690     ZFREE(strm, strm->state);
691     strm->state = Z_NULL;
692     Tracev((stderr, "inflate: end\n"));
693     return Z_OK;
694 }
695