xref: /netbsd-src/common/dist/zlib/infback.c (revision c42dbd0ed2e61fe6eda8590caa852ccf34719964)
1 /*	$NetBSD: infback.c,v 1.4 2022/10/15 19:49:32 christos Exp $	*/
2 
3 /* infback.c -- inflate using a call-back interface
4  * Copyright (C) 1995-2022 Mark Adler
5  * For conditions of distribution and use, see copyright notice in zlib.h
6  */
7 
8 /*
9    This code is largely copied from inflate.c.  Normally either infback.o or
10    inflate.o would be linked into an application--not both.  The interface
11    with inffast.c is retained so that optimized assembler-coded versions of
12    inflate_fast() can be used with either inflate.c or infback.c.
13  */
14 
15 #include "zutil.h"
16 #include "inftrees.h"
17 #include "inflate.h"
18 #include "inffast.h"
19 
20 /* function prototypes */
21 local void fixedtables OF((struct inflate_state FAR *state));
22 
23 /*
24    strm provides memory allocation functions in zalloc and zfree, or
25    Z_NULL to use the library memory allocation functions.
26 
27    windowBits is in the range 8..15, and window is a user-supplied
28    window and output buffer that is 2**windowBits bytes.
29  */
30 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
31 z_streamp strm;
32 int windowBits;
33 unsigned char FAR *window;
34 const char *version;
35 int stream_size;
36 {
37     struct inflate_state FAR *state;
38 
39     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
40         stream_size != (int)(sizeof(z_stream)))
41         return Z_VERSION_ERROR;
42     if (strm == Z_NULL || window == Z_NULL ||
43         windowBits < 8 || windowBits > 15)
44         return Z_STREAM_ERROR;
45     strm->msg = Z_NULL;                 /* in case we return an error */
46     if (strm->zalloc == (alloc_func)0) {
47 #ifdef Z_SOLO
48         return Z_STREAM_ERROR;
49 #else
50         strm->zalloc = zcalloc;
51         strm->opaque = (voidpf)0;
52 #endif
53     }
54     if (strm->zfree == (free_func)0)
55 #ifdef Z_SOLO
56         return Z_STREAM_ERROR;
57 #else
58     strm->zfree = zcfree;
59 #endif
60     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
61                                                sizeof(struct inflate_state));
62     if (state == Z_NULL) return Z_MEM_ERROR;
63     Tracev((stderr, "inflate: allocated\n"));
64     strm->state = (struct internal_state FAR *)state;
65     state->dmax = 32768U;
66     state->wbits = (uInt)windowBits;
67     state->wsize = 1U << windowBits;
68     state->window = window;
69     state->wnext = 0;
70     state->whave = 0;
71     state->sane = 1;
72     return Z_OK;
73 }
74 
75 /*
76    Return state with length and distance decoding tables and index sizes set to
77    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
78    If BUILDFIXED is defined, then instead this routine builds the tables the
79    first time it's called, and returns those tables the first time and
80    thereafter.  This reduces the size of the code by about 2K bytes, in
81    exchange for a little execution time.  However, BUILDFIXED should not be
82    used for threaded applications, since the rewriting of the tables and virgin
83    may not be thread-safe.
84  */
85 local void fixedtables(state)
86 struct inflate_state FAR *state;
87 {
88 #ifdef BUILDFIXED
89     static int virgin = 1;
90     static code *lenfix, *distfix;
91     static code fixed[544];
92 
93     /* build fixed huffman tables if first call (may not be thread safe) */
94     if (virgin) {
95         unsigned sym, bits;
96         static code *next;
97 
98         /* literal/length table */
99         sym = 0;
100         while (sym < 144) state->lens[sym++] = 8;
101         while (sym < 256) state->lens[sym++] = 9;
102         while (sym < 280) state->lens[sym++] = 7;
103         while (sym < 288) state->lens[sym++] = 8;
104         next = fixed;
105         lenfix = next;
106         bits = 9;
107         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
108 
109         /* distance table */
110         sym = 0;
111         while (sym < 32) state->lens[sym++] = 5;
112         distfix = next;
113         bits = 5;
114         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
115 
116         /* do this just once */
117         virgin = 0;
118     }
119 #else /* !BUILDFIXED */
120 #   include "inffixed.h"
121 #endif /* BUILDFIXED */
122     state->lencode = lenfix;
123     state->lenbits = 9;
124     state->distcode = distfix;
125     state->distbits = 5;
126 }
127 
/*
   Helper macros for inflateBack().  They read and write inflateBack()'s own
   locals (strm, state, next, have, put, left, hold, bits, ret) and its
   in()/out() call-backs.  Several of them leave the function by jumping to
   inf_leave, so they are only meaningful inside inflateBack().
 */

/* Reload the local registers from strm/state after inflate_fast() returns */
#define LOAD() \
    do { \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
        bits = state->bits; \
        hold = state->hold; \
    } while (0)

/* Publish the local registers to strm/state before calling inflate_fast() */
#define RESTORE() \
    do { \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
        state->bits = bits; \
        state->hold = hold; \
    } while (0)

/* Empty the bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure at least one input byte is buffered, asking in() for more when
   the current buffer is used up.  If in() supplies nothing, set next to
   Z_NULL (so the caller can tell an in() failure from an out() failure) and
   leave inflateBack() with Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Append one input byte above the bits already in the accumulator, leaving
   inflateBack() through PULL() if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* Pull bytes until the accumulator holds at least n bits, or leave
   inflateBack() with an error if the input runs dry. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)

/* Peek at the low n bits of the accumulator without consuming them
   (callers keep n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Consume n bits from the low end of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the zero to seven bits left before the next byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits &= ~7U; \
    } while (0)

/* Make sure there is room for at least one output byte.  When the window is
   full, pass all of it to out() and restart output at its beginning.  From
   then on, the whole window counts as history (whave = wsize).  If out()
   reports failure, leave inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            put = state->window; \
            left = state->wsize; \
            state->whave = left; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
225 
/*
   inflateBack() decodes a raw deflate stream, with no zlib or gzip header
   or trailer (decoding starts directly at the first block header), using
   caller-supplied input and output call-back functions.

   strm provides the memory allocation functions and window buffer on input,
   and provides information on the unused input on return.  For Z_DATA_ERROR
   returns, strm->msg also points to an error message.

   in() and out() are the call-back input and output functions.  When
   inflateBack() needs more input, it calls in().  When inflateBack() has
   filled the window with output, or when it finishes with data in the
   window, it calls out() to write out the data.  The application must not
   change the provided input until in() is called again or inflateBack()
   returns.  The application must not change the window/output buffer until
   inflateBack() returns.

   in() and out() are called with a descriptor parameter provided in the
   inflateBack() call.  This parameter can be a structure that provides the
   information required to do the read or write, as well as accumulated
   information on the input and output such as totals and check values.

   in() should return zero on failure.  out() should return non-zero on
   failure.  If either in() or out() fails, then inflateBack() returns
   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
   was in() or out() that caused the error.  Otherwise, inflateBack()
   returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
   error.  inflateBack() itself allocates no memory, so it never returns
   Z_MEM_ERROR; allocation failures are reported by inflateBackInit_().
   inflateBack() returns Z_STREAM_ERROR if the input parameters are not
   correct, i.e. strm is Z_NULL or the state was not initialized.

   Whatever the outcome, any output still in the window is passed to out()
   before returning.  If that final out() call fails, Z_BUF_ERROR replaces
   only a Z_STREAM_END result; other return codes are kept.
 */
int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
z_streamp strm;
in_func in;
void FAR *in_desc;
out_func out;
void FAR *out_desc;
{
    struct inflate_state FAR *state;
    z_const unsigned char FAR *next;    /* next input */
    unsigned char FAR *put;     /* next output */
    unsigned have, left;        /* available input and output */
    unsigned long hold;         /* bit buffer */
    unsigned bits;              /* bits in bit buffer */
    unsigned copy;              /* number of stored or match bytes to copy */
    unsigned char FAR *from;    /* where to copy match bytes from */
    code here;                  /* current decoding table entry */
    code last;                  /* parent table entry */
    unsigned len;               /* code length value written by repeat codes */
    int ret;                    /* return code */
    static const unsigned short order[19] = /* permutation of code lengths */
        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

    /* Check that the strm exists and that the state was initialized */
    if (strm == Z_NULL || strm->state == Z_NULL)
        return Z_STREAM_ERROR;
    state = (struct inflate_state FAR *)strm->state;

    /* Reset the state: start at a block header with an empty bit buffer.
       Output begins at the start of the caller's window, and whave = 0 marks
       that no history exists yet.  A Z_NULL next_in means no input is
       available until in() is called. */
    strm->msg = Z_NULL;
    state->mode = TYPE;
    state->last = 0;
    state->whave = 0;
    next = strm->next_in;
    have = next != Z_NULL ? strm->avail_in : 0;
    hold = 0;
    bits = 0;
    put = state->window;
    left = state->wsize;

    /* Inflate until end of block marked as last */
    for (;;)
        switch (state->mode) {
        case TYPE:
            /* determine and dispatch block type */
            if (state->last) {
                /* previous block was the last one: drop the padding bits */
                BYTEBITS();
                state->mode = DONE;
                break;
            }
            NEEDBITS(3);
            state->last = BITS(1);
            DROPBITS(1);
            switch (BITS(2)) {
            case 0:                             /* stored block */
                Tracev((stderr, "inflate:     stored block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = STORED;
                break;
            case 1:                             /* fixed block */
                fixedtables(state);
                Tracev((stderr, "inflate:     fixed codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = LEN;              /* decode codes */
                break;
            case 2:                             /* dynamic block */
                Tracev((stderr, "inflate:     dynamic codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = TABLE;
                break;
            case 3:
                /* reserved block type; BAD is acted on in the next pass */
                strm->msg = __UNCONST("invalid block type");
                state->mode = BAD;
            }
            DROPBITS(2);
            break;

        case STORED:
            /* get and verify stored block length */
            BYTEBITS();                         /* go to byte boundary */
            NEEDBITS(32);
            /* the 16-bit length must be followed by its one's complement */
            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
                strm->msg = __UNCONST("invalid stored block lengths");
                state->mode = BAD;
                break;
            }
            state->length = (unsigned)hold & 0xffff;
            Tracev((stderr, "inflate:       stored length %u\n",
                    state->length));
            INITBITS();

            /* copy stored block from input to output, limited each pass by
               the input on hand and the room left in the window */
            while (state->length != 0) {
                copy = state->length;
                PULL();
                ROOM();
                if (copy > have) copy = have;
                if (copy > left) copy = left;
                zmemcpy(put, next, copy);
                have -= copy;
                next += copy;
                left -= copy;
                put += copy;
                state->length -= copy;
            }
            Tracev((stderr, "inflate:       stored end\n"));
            state->mode = TYPE;
            break;

        case TABLE:
            /* get dynamic table entries descriptor */
            NEEDBITS(14);
            state->nlen = BITS(5) + 257;
            DROPBITS(5);
            state->ndist = BITS(5) + 1;
            DROPBITS(5);
            state->ncode = BITS(4) + 4;
            DROPBITS(4);
#ifndef PKZIP_BUG_WORKAROUND
            /* reject more than 286 literal/length or 30 distance codes
               unless the PKZIP workaround is compiled in */
            if (state->nlen > 286 || state->ndist > 30) {
                strm->msg = __UNCONST("too many length or distance symbols");
                state->mode = BAD;
                break;
            }
#endif
            Tracev((stderr, "inflate:       table sizes ok\n"));

            /* get code length code lengths (not a typo); they arrive in the
               permuted order given by order[], and missing ones are zero */
            state->have = 0;
            while (state->have < state->ncode) {
                NEEDBITS(3);
                state->lens[order[state->have++]] = (unsigned short)BITS(3);
                DROPBITS(3);
            }
            while (state->have < 19)
                state->lens[order[state->have++]] = 0;
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 7;
            ret = inflate_table(CODES, state->lens, 19, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = __UNCONST("invalid code lengths set");
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       code lengths ok\n"));

            /* get length and distance code code lengths */
            state->have = 0;
            while (state->have < state->nlen + state->ndist) {
                /* pull bytes until the decoded entry's length is available */
                for (;;) {
                    here = state->lencode[BITS(state->lenbits)];
                    if ((unsigned)(here.bits) <= bits) break;
                    PULLBYTE();
                }
                if (here.val < 16) {
                    /* 0..15: a literal code length */
                    DROPBITS(here.bits);
                    state->lens[state->have++] = here.val;
                }
                else {
                    if (here.val == 16) {
                        /* 16: repeat the previous length 3..6 times */
                        NEEDBITS(here.bits + 2);
                        DROPBITS(here.bits);
                        if (state->have == 0) {
                            strm->msg = __UNCONST("invalid bit length repeat");
                            state->mode = BAD;
                            break;
                        }
                        len = (unsigned)(state->lens[state->have - 1]);
                        copy = 3 + BITS(2);
                        DROPBITS(2);
                    }
                    else if (here.val == 17) {
                        /* 17: repeat a zero length 3..10 times */
                        NEEDBITS(here.bits + 3);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 3 + BITS(3);
                        DROPBITS(3);
                    }
                    else {
                        /* 18: repeat a zero length 11..138 times */
                        NEEDBITS(here.bits + 7);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 11 + BITS(7);
                        DROPBITS(7);
                    }
                    /* a repeat may not run past the declared code count */
                    if (state->have + copy > state->nlen + state->ndist) {
                        strm->msg = __UNCONST("invalid bit length repeat");
                        state->mode = BAD;
                        break;
                    }
                    while (copy--)
                        state->lens[state->have++] = (unsigned short)len;
                }
            }

            /* handle error breaks in while */
            if (state->mode == BAD) break;

            /* check for end-of-block code (better have one) */
            if (state->lens[256] == 0) {
                strm->msg = __UNCONST("invalid code -- missing end-of-block");
                state->mode = BAD;
                break;
            }

            /* build code tables -- note: do not change the lenbits or distbits
               values here (9 and 6) without reading the comments in inftrees.h
               concerning the ENOUGH constants, which depend on those values */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 9;
            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = __UNCONST("invalid literal/lengths set");
                state->mode = BAD;
                break;
            }
            /* the distance table is built right after the length table */
            state->distcode = (code const FAR *)(state->next);
            state->distbits = 6;
            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
                            &(state->next), &(state->distbits), state->work);
            if (ret) {
                strm->msg = __UNCONST("invalid distances set");
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       codes ok\n"));
            state->mode = LEN;
            /* fallthrough */

        case LEN:
            /* use inflate_fast() if we have enough input and output
               (inflate_fast() relies on at least 6 input bytes and 258 bytes
               of output room, 258 being the longest deflate match) */
            if (have >= 6 && left >= 258) {
                RESTORE();
                /* until the window is first flushed, only the bytes already
                   written into it are valid history */
                if (state->whave < state->wsize)
                    state->whave = state->wsize - left;
                inflate_fast(strm, state->wsize);
                LOAD();
                break;
            }

            /* get a literal, length, or end-of-block code */
            for (;;) {
                here = state->lencode[BITS(state->lenbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* a non-zero op with a zero high nibble links to a sub-table,
               indexed by the next last.op bits after the root bits */
            if (here.op && (here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->lencode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            state->length = (unsigned)here.val;

            /* process literal */
            if (here.op == 0) {
                Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
                        "inflate:         literal '%c'\n" :
                        "inflate:         literal 0x%02x\n", here.val));
                ROOM();
                *put++ = (unsigned char)(state->length);
                left--;
                state->mode = LEN;
                break;
            }

            /* process end of block */
            if (here.op & 32) {
                Tracevv((stderr, "inflate:         end of block\n"));
                state->mode = TYPE;
                break;
            }

            /* invalid code */
            if (here.op & 64) {
                strm->msg = __UNCONST("invalid literal/length code");
                state->mode = BAD;
                break;
            }

            /* length code -- get extra bits, if any (count in low 4 bits of op) */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->length += BITS(state->extra);
                DROPBITS(state->extra);
            }
            Tracevv((stderr, "inflate:         length %u\n", state->length));

            /* get distance code (same two-level lookup as above) */
            for (;;) {
                here = state->distcode[BITS(state->distbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            if ((here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->distcode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            if (here.op & 64) {
                strm->msg = __UNCONST("invalid distance code");
                state->mode = BAD;
                break;
            }
            state->offset = (unsigned)here.val;

            /* get distance extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->offset += BITS(state->extra);
                DROPBITS(state->extra);
            }
            /* before the window has been flushed once, only wsize - left
               bytes of history exist; afterwards the full window does */
            if (state->offset > state->wsize - (state->whave < state->wsize ?
                                                left : 0)) {
                strm->msg = __UNCONST("invalid distance too far back");
                state->mode = BAD;
                break;
            }
            Tracevv((stderr, "inflate:         distance %u\n", state->offset));

            /* copy match from window to output.  The window is circular: if
               the source would start before window[0], it is taken from the
               matching position near the window's end, and each pass copies
               no further than the end of that run or of the output room.
               Bytes are copied one at a time so that overlapping matches
               (offset < length) repeat earlier output correctly. */
            do {
                ROOM();
                copy = state->wsize - state->offset;
                if (copy < left) {
                    from = put + copy;
                    copy = left - copy;
                }
                else {
                    from = put - state->offset;
                    copy = left;
                }
                if (copy > state->length) copy = state->length;
                state->length -= copy;
                left -= copy;
                do {
                    *put++ = *from++;
                } while (--copy);
            } while (state->length != 0);
            break;

        case DONE:
            /* inflate stream terminated properly */
            ret = Z_STREAM_END;
            goto inf_leave;

        case BAD:
            ret = Z_DATA_ERROR;
            goto inf_leave;

        default:
            /* can't happen, but makes compilers happy */
            ret = Z_STREAM_ERROR;
            goto inf_leave;
        }

    /* Write leftover output and return unused input.  The window is flushed
       on every exit path; an out() failure only overrides Z_STREAM_END. */
  inf_leave:
    if (left < state->wsize) {
        if (out(out_desc, state->window, state->wsize - left) &&
            ret == Z_STREAM_END)
            ret = Z_BUF_ERROR;
    }
    strm->next_in = next;
    strm->avail_in = have;
    return ret;
}
636 
637 int ZEXPORT inflateBackEnd(strm)
638 z_streamp strm;
639 {
640     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
641         return Z_STREAM_ERROR;
642     ZFREE(strm, strm->state);
643     strm->state = Z_NULL;
644     Tracev((stderr, "inflate: end\n"));
645     return Z_OK;
646 }
647