1 #include "lib9.h"
2 #include <flate.h>
3
/* sizes and limits used throughout the inflater */
enum {
	/* window and buffer sizes */
	HistorySize	= 32*1024,	/* bytes of output kept for back references */
	BufSize		= 4*1024,

	/* alphabet sizes */
	Nlitlen		= 288,		/* number of literal/length codes */
	Noff		= 32,		/* number of offset (distance) codes */
	Nclen		= 19,		/* number of code length codes */
	MaxLeaf		= Nlitlen,	/* size of Huff.decode: the largest alphabet */

	/* huffman code limits and fast-table widths */
	MaxHuffBits	= 17,		/* bound on bits in an encoded code */
	LenShift	= 10,		/* code = len<<LenShift|code */
	LitlenBits	= 7,		/* bits in the litlen flat decode table */
	OffBits		= 6,		/* bits in the offset flat decode table */
	ClenBits	= 6,		/* bits in the code length flat decode table */
	MaxFlatBits	= LitlenBits	/* widest flat table, sizes Huff.flat */
};
18
19 typedef struct Input Input;
20 typedef struct History History;
21 typedef struct Huff Huff;
22
23 struct Input
24 {
25 int error; /* first error encountered, or FlateOk */
26 void *wr;
27 int (*w)(void*, void*, int);
28 void *getr;
29 int (*get)(void*);
30 ulong sreg;
31 int nbits;
32 };
33
34 struct History
35 {
36 uchar his[HistorySize];
37 uchar *cp; /* current pointer in history */
38 int full; /* his has been filled up at least once */
39 };
40
41 struct Huff
42 {
43 int maxbits; /* max bits for any code */
44 int minbits; /* min bits to get before looking in flat */
45 int flatmask; /* bits used in "flat" fast decoding table */
46 ulong flat[1<<MaxFlatBits];
47 ulong maxcode[MaxHuffBits];
48 ulong last[MaxHuffBits];
49 ulong decode[MaxLeaf];
50 };
51
52 /* litlen code words 257-285 extra bits */
53 static int litlenextra[Nlitlen-257] =
54 {
55 /* 257 */ 0, 0, 0,
56 /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
57 /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
58 /* 280 */ 4, 5, 5, 5, 5, 0, 0, 0
59 };
60
61 static int litlenbase[Nlitlen-257];
62
63 /* offset code word extra bits */
64 static int offextra[Noff] =
65 {
66 0, 0, 0, 0, 1, 1, 2, 2, 3, 3,
67 4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
68 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
69 0, 0,
70 };
71 static int offbase[Noff];
72
73 /* order code lengths */
74 static int clenorder[Nclen] =
75 {
76 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
77 };
78
79 /* for static huffman tables */
80 static Huff litlentab;
81 static Huff offtab;
82 static uchar revtab[256];
83
84 static int uncblock(Input *in, History*);
85 static int fixedblock(Input *in, History*);
86 static int dynamicblock(Input *in, History*);
87 static int sregfill(Input *in, int n);
88 static int sregunget(Input *in);
89 static int decode(Input*, History*, Huff*, Huff*);
90 static int hufftab(Huff*, char*, int, int);
91 static int hdecsym(Input *in, Huff *h, int b);
92
93 int
inflateinit(void)94 inflateinit(void)
95 {
96 char *len;
97 int i, j, base;
98
99 /* byte reverse table */
100 for(i=0; i<256; i++)
101 for(j=0; j<8; j++)
102 if(i & (1<<j))
103 revtab[i] |= 0x80 >> j;
104
105 for(i=257,base=3; i<Nlitlen; i++) {
106 litlenbase[i-257] = base;
107 base += 1<<litlenextra[i-257];
108 }
109 /* strange table entry in spec... */
110 litlenbase[285-257]--;
111
112 for(i=0,base=1; i<Noff; i++) {
113 offbase[i] = base;
114 base += 1<<offextra[i];
115 }
116
117 len = malloc(MaxLeaf);
118 if(len == nil)
119 return FlateNoMem;
120
121 /* static Litlen bit lengths */
122 for(i=0; i<144; i++)
123 len[i] = 8;
124 for(i=144; i<256; i++)
125 len[i] = 9;
126 for(i=256; i<280; i++)
127 len[i] = 7;
128 for(i=280; i<Nlitlen; i++)
129 len[i] = 8;
130
131 if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
132 return FlateInternal;
133
134 /* static Offset bit lengths */
135 for(i=0; i<Noff; i++)
136 len[i] = 5;
137
138 if(!hufftab(&offtab, len, Noff, MaxFlatBits))
139 return FlateInternal;
140 free(len);
141
142 return FlateOk;
143 }
144
145 int
inflate(void * wr,int (* w)(void *,void *,int),void * getr,int (* get)(void *))146 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
147 {
148 History *his;
149 Input in;
150 int final, type;
151
152 his = malloc(sizeof(History));
153 if(his == nil)
154 return FlateNoMem;
155 his->cp = his->his;
156 his->full = 0;
157 in.getr = getr;
158 in.get = get;
159 in.wr = wr;
160 in.w = w;
161 in.nbits = 0;
162 in.sreg = 0;
163 in.error = FlateOk;
164
165 do {
166 if(!sregfill(&in, 3))
167 goto bad;
168 final = in.sreg & 0x1;
169 type = (in.sreg>>1) & 0x3;
170 in.sreg >>= 3;
171 in.nbits -= 3;
172 switch(type) {
173 default:
174 in.error = FlateCorrupted;
175 goto bad;
176 case 0:
177 /* uncompressed */
178 if(!uncblock(&in, his))
179 goto bad;
180 break;
181 case 1:
182 /* fixed huffman */
183 if(!fixedblock(&in, his))
184 goto bad;
185 break;
186 case 2:
187 /* dynamic huffman */
188 if(!dynamicblock(&in, his))
189 goto bad;
190 break;
191 }
192 } while(!final);
193
194 if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
195 in.error = FlateOutputFail;
196 goto bad;
197 }
198
199 if(!sregunget(&in))
200 goto bad;
201
202 free(his);
203 if(in.error != FlateOk)
204 return FlateInternal;
205 return FlateOk;
206
207 bad:
208 free(his);
209 if(in.error == FlateOk)
210 return FlateInternal;
211 return in.error;
212 }
213
214 static int
uncblock(Input * in,History * his)215 uncblock(Input *in, History *his)
216 {
217 int len, nlen, c;
218 uchar *hs, *hp, *he;
219
220 if(!sregunget(in))
221 return 0;
222 len = (*in->get)(in->getr);
223 len |= (*in->get)(in->getr)<<8;
224 nlen = (*in->get)(in->getr);
225 nlen |= (*in->get)(in->getr)<<8;
226 if(len != (~nlen&0xffff)) {
227 in->error = FlateCorrupted;
228 return 0;
229 }
230
231 hp = his->cp;
232 hs = his->his;
233 he = hs + HistorySize;
234
235 while(len > 0) {
236 c = (*in->get)(in->getr);
237 if(c < 0)
238 return 0;
239 *hp++ = c;
240 if(hp == he) {
241 his->full = 1;
242 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
243 in->error = FlateOutputFail;
244 return 0;
245 }
246 hp = hs;
247 }
248 len--;
249 }
250
251 his->cp = hp;
252
253 return 1;
254 }
255
256 static int
fixedblock(Input * in,History * his)257 fixedblock(Input *in, History *his)
258 {
259 return decode(in, his, &litlentab, &offtab);
260 }
261
262 static int
dynamicblock(Input * in,History * his)263 dynamicblock(Input *in, History *his)
264 {
265 Huff *lentab, *offtab;
266 char *len;
267 int i, j, n, c, nlit, ndist, nclen, res, nb;
268
269 if(!sregfill(in, 14))
270 return 0;
271 nlit = (in->sreg&0x1f) + 257;
272 ndist = ((in->sreg>>5) & 0x1f) + 1;
273 nclen = ((in->sreg>>10) & 0xf) + 4;
274 in->sreg >>= 14;
275 in->nbits -= 14;
276
277 if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
278 in->error = FlateCorrupted;
279 return 0;
280 }
281
282 /* huff table header */
283 len = malloc(Nlitlen+Noff);
284 lentab = malloc(sizeof(Huff));
285 offtab = malloc(sizeof(Huff));
286 if(len == nil || lentab == nil || offtab == nil){
287 in->error = FlateNoMem;
288 goto bad;
289 }
290 for(i=0; i < Nclen; i++)
291 len[i] = 0;
292 for(i=0; i<nclen; i++) {
293 if(!sregfill(in, 3))
294 goto bad;
295 len[clenorder[i]] = in->sreg & 0x7;
296 in->sreg >>= 3;
297 in->nbits -= 3;
298 }
299
300 if(!hufftab(lentab, len, Nclen, ClenBits)){
301 in->error = FlateCorrupted;
302 goto bad;
303 }
304
305 n = nlit+ndist;
306 for(i=0; i<n;) {
307 nb = lentab->minbits;
308 for(;;){
309 if(in->nbits<nb && !sregfill(in, nb))
310 goto bad;
311 c = lentab->flat[in->sreg & lentab->flatmask];
312 nb = c & 0xff;
313 if(nb > in->nbits){
314 if(nb != 0xff)
315 continue;
316 c = hdecsym(in, lentab, c);
317 if(c < 0)
318 goto bad;
319 }else{
320 c >>= 8;
321 in->sreg >>= nb;
322 in->nbits -= nb;
323 }
324 break;
325 }
326
327 if(c < 16) {
328 j = 1;
329 } else if(c == 16) {
330 if(in->nbits<2 && !sregfill(in, 2))
331 goto bad;
332 j = (in->sreg&0x3)+3;
333 in->sreg >>= 2;
334 in->nbits -= 2;
335 if(i == 0) {
336 in->error = FlateCorrupted;
337 goto bad;
338 }
339 c = len[i-1];
340 } else if(c == 17) {
341 if(in->nbits<3 && !sregfill(in, 3))
342 goto bad;
343 j = (in->sreg&0x7)+3;
344 in->sreg >>= 3;
345 in->nbits -= 3;
346 c = 0;
347 } else if(c == 18) {
348 if(in->nbits<7 && !sregfill(in, 7))
349 goto bad;
350 j = (in->sreg&0x7f)+11;
351 in->sreg >>= 7;
352 in->nbits -= 7;
353 c = 0;
354 } else {
355 in->error = FlateCorrupted;
356 goto bad;
357 }
358
359 if(i+j > n) {
360 in->error = FlateCorrupted;
361 goto bad;
362 }
363
364 while(j) {
365 len[i] = c;
366 i++;
367 j--;
368 }
369 }
370
371 if(!hufftab(lentab, len, nlit, LitlenBits)
372 || !hufftab(offtab, &len[nlit], ndist, OffBits)){
373 in->error = FlateCorrupted;
374 goto bad;
375 }
376
377 res = decode(in, his, lentab, offtab);
378
379 free(len);
380 free(lentab);
381 free(offtab);
382
383 return res;
384
385 bad:
386 free(len);
387 free(lentab);
388 free(offtab);
389 return 0;
390 }
391
392 static int
decode(Input * in,History * his,Huff * litlentab,Huff * offtab)393 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
394 {
395 int len, off;
396 uchar *hs, *hp, *hq, *he;
397 int c;
398 int nb;
399
400 hs = his->his;
401 he = hs + HistorySize;
402 hp = his->cp;
403
404 for(;;) {
405 nb = litlentab->minbits;
406 for(;;){
407 if(in->nbits<nb && !sregfill(in, nb))
408 return 0;
409 c = litlentab->flat[in->sreg & litlentab->flatmask];
410 nb = c & 0xff;
411 if(nb > in->nbits){
412 if(nb != 0xff)
413 continue;
414 c = hdecsym(in, litlentab, c);
415 if(c < 0)
416 return 0;
417 }else{
418 c >>= 8;
419 in->sreg >>= nb;
420 in->nbits -= nb;
421 }
422 break;
423 }
424
425 if(c < 256) {
426 /* literal */
427 *hp++ = c;
428 if(hp == he) {
429 his->full = 1;
430 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
431 in->error = FlateOutputFail;
432 return 0;
433 }
434 hp = hs;
435 }
436 continue;
437 }
438
439 if(c == 256)
440 break;
441
442 if(c > 285) {
443 in->error = FlateCorrupted;
444 return 0;
445 }
446
447 c -= 257;
448 nb = litlenextra[c];
449 if(in->nbits < nb && !sregfill(in, nb))
450 return 0;
451 len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
452 in->sreg >>= nb;
453 in->nbits -= nb;
454
455 /* get offset */
456 nb = offtab->minbits;
457 for(;;){
458 if(in->nbits<nb && !sregfill(in, nb))
459 return 0;
460 c = offtab->flat[in->sreg & offtab->flatmask];
461 nb = c & 0xff;
462 if(nb > in->nbits){
463 if(nb != 0xff)
464 continue;
465 c = hdecsym(in, offtab, c);
466 if(c < 0)
467 return 0;
468 }else{
469 c >>= 8;
470 in->sreg >>= nb;
471 in->nbits -= nb;
472 }
473 break;
474 }
475
476 if(c > 29) {
477 in->error = FlateCorrupted;
478 return 0;
479 }
480
481 nb = offextra[c];
482 if(in->nbits < nb && !sregfill(in, nb))
483 return 0;
484
485 off = offbase[c] + (in->sreg & ((1<<nb)-1));
486 in->sreg >>= nb;
487 in->nbits -= nb;
488
489 hq = hp - off;
490 if(hq < hs) {
491 if(!his->full) {
492 in->error = FlateCorrupted;
493 return 0;
494 }
495 hq += HistorySize;
496 }
497
498 /* slow but correct */
499 while(len) {
500 *hp = *hq;
501 hq++;
502 hp++;
503 if(hq >= he)
504 hq = hs;
505 if(hp == he) {
506 his->full = 1;
507 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
508 in->error = FlateOutputFail;
509 return 0;
510 }
511 hp = hs;
512 }
513 len--;
514 }
515
516 }
517
518 his->cp = hp;
519
520 return 1;
521 }
522
523 static int
revcode(int c,int b)524 revcode(int c, int b)
525 {
526 /* shift encode up so it starts on bit 15 then reverse */
527 c <<= (16-b);
528 c = revtab[c>>8] | (revtab[c&0xff]<<8);
529 return c;
530 }
531
532 /*
533 * construct the huffman decoding arrays and a fast lookup table.
534 * the fast lookup is a table indexed by the next flatbits bits,
535 * which returns the symbol matched and the number of bits consumed,
536 * or the minimum number of bits needed and 0xff if more than flatbits
537 * bits are needed.
538 *
539 * flatbits can be longer than the smallest huffman code,
540 * because shorter codes are assigned smaller lexical prefixes.
541 * this means assuming zeros for the next few bits will give a
542 * conservative answer, in the sense that it will either give the
543 * correct answer, or return the minimum number of bits which
544 * are needed for an answer.
545 */
546 static int
hufftab(Huff * h,char * hb,int maxleaf,int flatbits)547 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
548 {
549 ulong bitcount[MaxHuffBits];
550 ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
551 int i, b, minbits, maxbits;
552
553 for(i = 0; i < MaxHuffBits; i++)
554 bitcount[i] = 0;
555 maxbits = -1;
556 minbits = MaxHuffBits + 1;
557 for(i=0; i < maxleaf; i++){
558 b = hb[i];
559 if(b){
560 bitcount[b]++;
561 if(b < minbits)
562 minbits = b;
563 if(b > maxbits)
564 maxbits = b;
565 }
566 }
567
568 h->maxbits = maxbits;
569 if(maxbits <= 0){
570 h->maxbits = 0;
571 h->minbits = 0;
572 h->flatmask = 0;
573 return 1;
574 }
575 code = 0;
576 c = 0;
577 for(b = 0; b <= maxbits; b++){
578 h->last[b] = c;
579 c += bitcount[b];
580 mincode = code << 1;
581 nc[b] = mincode;
582 code = mincode + bitcount[b];
583 if(code > (1 << b))
584 return 0;
585 h->maxcode[b] = code - 1;
586 h->last[b] += code - 1;
587 }
588
589 if(flatbits > maxbits)
590 flatbits = maxbits;
591 h->flatmask = (1 << flatbits) - 1;
592 if(minbits > flatbits)
593 minbits = flatbits;
594 h->minbits = minbits;
595
596 b = 1 << flatbits;
597 for(i = 0; i < b; i++)
598 h->flat[i] = ~0;
599
600 /*
601 * initialize the flat table to include the minimum possible
602 * bit length for each code prefix
603 */
604 for(b = maxbits; b > flatbits; b--){
605 code = h->maxcode[b];
606 if(code == -1)
607 break;
608 mincode = code + 1 - bitcount[b];
609 mincode >>= b - flatbits;
610 code >>= b - flatbits;
611 for(; mincode <= code; mincode++)
612 h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
613 }
614
615 for(i = 0; i < maxleaf; i++){
616 b = hb[i];
617 if(b <= 0)
618 continue;
619 c = nc[b]++;
620 if(b <= flatbits){
621 code = (i << 8) | b;
622 ec = (c + 1) << (flatbits - b);
623 if(ec > (1<<flatbits))
624 return 0; /* this is actually an internal error */
625 for(fc = c << (flatbits - b); fc < ec; fc++)
626 h->flat[revcode(fc, flatbits)] = code;
627 }
628 if(b > minbits){
629 c = h->last[b] - c;
630 if(c >= maxleaf)
631 return 0;
632 h->decode[c] = i;
633 }
634 }
635 return 1;
636 }
637
638 static int
hdecsym(Input * in,Huff * h,int nb)639 hdecsym(Input *in, Huff *h, int nb)
640 {
641 long c;
642
643 if((nb & 0xff) == 0xff)
644 nb = nb >> 8;
645 else
646 nb = nb & 0xff;
647 for(; nb <= h->maxbits; nb++){
648 if(in->nbits<nb && !sregfill(in, nb))
649 return -1;
650 c = revtab[in->sreg&0xff]<<8;
651 c |= revtab[(in->sreg>>8)&0xff];
652 c >>= (16-nb);
653 if(c <= h->maxcode[nb]){
654 in->sreg >>= nb;
655 in->nbits -= nb;
656 return h->decode[h->last[nb] - c];
657 }
658 }
659 in->error = FlateCorrupted;
660 return -1;
661 }
662
663 static int
sregfill(Input * in,int n)664 sregfill(Input *in, int n)
665 {
666 int c;
667
668 while(n > in->nbits) {
669 c = (*in->get)(in->getr);
670 if(c < 0){
671 in->error = FlateInputFail;
672 return 0;
673 }
674 in->sreg |= c<<in->nbits;
675 in->nbits += 8;
676 }
677 return 1;
678 }
679
680 static int
sregunget(Input * in)681 sregunget(Input *in)
682 {
683 if(in->nbits >= 8) {
684 in->error = FlateInternal;
685 return 0;
686 }
687
688 /* throw other bits on the floor */
689 in->nbits = 0;
690 in->sreg = 0;
691 return 1;
692 }
693