xref: /plan9/sys/src/libflate/inflate.c (revision 6b6b9ac8b0b103b1e30e4d019522a78c950fce74)
1 #include <u.h>
2 #include <libc.h>
3 #include <flate.h>
4 
enum {
	/* buffer sizes */
	HistorySize=	1<<15,		/* 32K output window kept for back-references */
	BufSize=	1<<12,		/* 4K */
	MaxHuffBits=	17,		/* maximum bits in an encoded code */

	/* alphabet sizes */
	Nlitlen=	288,		/* number of litlen codes */
	Noff=		32,		/* number of offset codes */
	Nclen=		19,		/* number of codelen codes */
	MaxLeaf=	Nlitlen,	/* largest alphabet a Huff can hold */

	LenShift=	10,		/* code = len<<LenShift|code */

	/* index widths of the flat (single lookup) decode tables */
	LitlenBits=	7,		/* litlen table */
	OffBits=	6,		/* offset table */
	ClenBits=	6,		/* code length table */
	MaxFlatBits=	LitlenBits	/* widest flat table */
};
19 
20 typedef struct Input	Input;
21 typedef struct History	History;
22 typedef struct Huff	Huff;
23 
24 struct Input
25 {
26 	int	error;		/* first error encountered, or FlateOk */
27 	void	*wr;
28 	int	(*w)(void*, void*, int);
29 	void	*getr;
30 	int	(*get)(void*);
31 	ulong	sreg;
32 	int	nbits;
33 };
34 
35 struct History
36 {
37 	uchar	his[HistorySize];
38 	uchar	*cp;		/* current pointer in history */
39 	int	full;		/* his has been filled up at least once */
40 };
41 
42 struct Huff
43 {
44 	int	maxbits;	/* max bits for any code */
45 	int	minbits;	/* min bits to get before looking in flat */
46 	int	flatmask;	/* bits used in "flat" fast decoding table */
47 	ulong	flat[1<<MaxFlatBits];
48 	ulong	maxcode[MaxHuffBits];
49 	ulong	last[MaxHuffBits];
50 	ulong	decode[MaxLeaf];
51 };
52 
53 /* litlen code words 257-285 extra bits */
54 static int litlenextra[Nlitlen-257] =
55 {
56 /* 257 */	0, 0, 0,
57 /* 260 */	0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
58 /* 270 */	2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
59 /* 280 */	4, 5, 5, 5, 5, 0, 0, 0
60 };
61 
62 static int litlenbase[Nlitlen-257];
63 
64 /* offset code word extra bits */
65 static int offextra[Noff] =
66 {
67 	0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
68 	4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
69 	9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
70 	0,  0,
71 };
72 static int offbase[Noff];
73 
74 /* order code lengths */
75 static int clenorder[Nclen] =
76 {
77         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
78 };
79 
80 /* for static huffman tables */
81 static	Huff	litlentab;
82 static	Huff	offtab;
83 static	uchar	revtab[256];
84 
85 static int	uncblock(Input *in, History*);
86 static int	fixedblock(Input *in, History*);
87 static int	dynamicblock(Input *in, History*);
88 static int	sregfill(Input *in, int n);
89 static int	sregunget(Input *in);
90 static int	decode(Input*, History*, Huff*, Huff*);
91 static int	hufftab(Huff*, char*, int, int);
92 static int	hdecsym(Input *in, Huff *h, int b);
93 
94 int
inflateinit(void)95 inflateinit(void)
96 {
97 	char *len;
98 	int i, j, base;
99 
100 	/* byte reverse table */
101 	for(i=0; i<256; i++)
102 		for(j=0; j<8; j++)
103 			if(i & (1<<j))
104 				revtab[i] |= 0x80 >> j;
105 
106 	for(i=257,base=3; i<Nlitlen; i++) {
107 		litlenbase[i-257] = base;
108 		base += 1<<litlenextra[i-257];
109 	}
110 	/* strange table entry in spec... */
111 	litlenbase[285-257]--;
112 
113 	for(i=0,base=1; i<Noff; i++) {
114 		offbase[i] = base;
115 		base += 1<<offextra[i];
116 	}
117 
118 	len = malloc(MaxLeaf);
119 	if(len == nil)
120 		return FlateNoMem;
121 
122 	/* static Litlen bit lengths */
123 	for(i=0; i<144; i++)
124 		len[i] = 8;
125 	for(i=144; i<256; i++)
126 		len[i] = 9;
127 	for(i=256; i<280; i++)
128 		len[i] = 7;
129 	for(i=280; i<Nlitlen; i++)
130 		len[i] = 8;
131 
132 	if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
133 		return FlateInternal;
134 
135 	/* static Offset bit lengths */
136 	for(i=0; i<Noff; i++)
137 		len[i] = 5;
138 
139 	if(!hufftab(&offtab, len, Noff, MaxFlatBits))
140 		return FlateInternal;
141 	free(len);
142 
143 	return FlateOk;
144 }
145 
146 int
inflate(void * wr,int (* w)(void *,void *,int),void * getr,int (* get)(void *))147 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
148 {
149 	History *his;
150 	Input in;
151 	int final, type;
152 
153 	his = malloc(sizeof(History));
154 	if(his == nil)
155 		return FlateNoMem;
156 	his->cp = his->his;
157 	his->full = 0;
158 	in.getr = getr;
159 	in.get = get;
160 	in.wr = wr;
161 	in.w = w;
162 	in.nbits = 0;
163 	in.sreg = 0;
164 	in.error = FlateOk;
165 
166 	do {
167 		if(!sregfill(&in, 3))
168 			goto bad;
169 		final = in.sreg & 0x1;
170 		type = (in.sreg>>1) & 0x3;
171 		in.sreg >>= 3;
172 		in.nbits -= 3;
173 		switch(type) {
174 		default:
175 			in.error = FlateCorrupted;
176 			goto bad;
177 		case 0:
178 			/* uncompressed */
179 			if(!uncblock(&in, his))
180 				goto bad;
181 			break;
182 		case 1:
183 			/* fixed huffman */
184 			if(!fixedblock(&in, his))
185 				goto bad;
186 			break;
187 		case 2:
188 			/* dynamic huffman */
189 			if(!dynamicblock(&in, his))
190 				goto bad;
191 			break;
192 		}
193 	} while(!final);
194 
195 	if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
196 		in.error = FlateOutputFail;
197 		goto bad;
198 	}
199 
200 	if(!sregunget(&in))
201 		goto bad;
202 
203 	free(his);
204 	if(in.error != FlateOk)
205 		return FlateInternal;
206 	return FlateOk;
207 
208 bad:
209 	free(his);
210 	if(in.error == FlateOk)
211 		return FlateInternal;
212 	return in.error;
213 }
214 
215 static int
uncblock(Input * in,History * his)216 uncblock(Input *in, History *his)
217 {
218 	int len, nlen, c;
219 	uchar *hs, *hp, *he;
220 
221 	if(!sregunget(in))
222 		return 0;
223 	len = (*in->get)(in->getr);
224 	len |= (*in->get)(in->getr)<<8;
225 	nlen = (*in->get)(in->getr);
226 	nlen |= (*in->get)(in->getr)<<8;
227 	if(len != (~nlen&0xffff)) {
228 		in->error = FlateCorrupted;
229 		return 0;
230 	}
231 
232 	hp = his->cp;
233 	hs = his->his;
234 	he = hs + HistorySize;
235 
236 	while(len > 0) {
237 		c = (*in->get)(in->getr);
238 		if(c < 0)
239 			return 0;
240 		*hp++ = c;
241 		if(hp == he) {
242 			his->full = 1;
243 			if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
244 				in->error = FlateOutputFail;
245 				return 0;
246 			}
247 			hp = hs;
248 		}
249 		len--;
250 	}
251 
252 	his->cp = hp;
253 
254 	return 1;
255 }
256 
257 static int
fixedblock(Input * in,History * his)258 fixedblock(Input *in, History *his)
259 {
260 	return decode(in, his, &litlentab, &offtab);
261 }
262 
263 static int
dynamicblock(Input * in,History * his)264 dynamicblock(Input *in, History *his)
265 {
266 	Huff *lentab, *offtab;
267 	char *len;
268 	int i, j, n, c, nlit, ndist, nclen, res, nb;
269 
270 	if(!sregfill(in, 14))
271 		return 0;
272 	nlit = (in->sreg&0x1f) + 257;
273 	ndist = ((in->sreg>>5) & 0x1f) + 1;
274 	nclen = ((in->sreg>>10) & 0xf) + 4;
275 	in->sreg >>= 14;
276 	in->nbits -= 14;
277 
278 	if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
279 		in->error = FlateCorrupted;
280 		return 0;
281 	}
282 
283 	/* huff table header */
284 	len = malloc(Nlitlen+Noff);
285 	lentab = malloc(sizeof(Huff));
286 	offtab = malloc(sizeof(Huff));
287 	if(len == nil || lentab == nil || offtab == nil){
288 		in->error = FlateNoMem;
289 		goto bad;
290 	}
291 	for(i=0; i < Nclen; i++)
292 		len[i] = 0;
293 	for(i=0; i<nclen; i++) {
294 		if(!sregfill(in, 3))
295 			goto bad;
296 		len[clenorder[i]] = in->sreg & 0x7;
297 		in->sreg >>= 3;
298 		in->nbits -= 3;
299 	}
300 
301 	if(!hufftab(lentab, len, Nclen, ClenBits)){
302 		in->error = FlateCorrupted;
303 		goto bad;
304 	}
305 
306 	n = nlit+ndist;
307 	for(i=0; i<n;) {
308 		nb = lentab->minbits;
309 		for(;;){
310 			if(in->nbits<nb && !sregfill(in, nb))
311 				goto bad;
312 			c = lentab->flat[in->sreg & lentab->flatmask];
313 			nb = c & 0xff;
314 			if(nb > in->nbits){
315 				if(nb != 0xff)
316 					continue;
317 				c = hdecsym(in, lentab, c);
318 				if(c < 0)
319 					goto bad;
320 			}else{
321 				c >>= 8;
322 				in->sreg >>= nb;
323 				in->nbits -= nb;
324 			}
325 			break;
326 		}
327 
328 		if(c < 16) {
329 			j = 1;
330 		} else if(c == 16) {
331 			if(in->nbits<2 && !sregfill(in, 2))
332 				goto bad;
333 			j = (in->sreg&0x3)+3;
334 			in->sreg >>= 2;
335 			in->nbits -= 2;
336 			if(i == 0) {
337 				in->error = FlateCorrupted;
338 				goto bad;
339 			}
340 			c = len[i-1];
341 		} else if(c == 17) {
342 			if(in->nbits<3 && !sregfill(in, 3))
343 				goto bad;
344 			j = (in->sreg&0x7)+3;
345 			in->sreg >>= 3;
346 			in->nbits -= 3;
347 			c = 0;
348 		} else if(c == 18) {
349 			if(in->nbits<7 && !sregfill(in, 7))
350 				goto bad;
351 			j = (in->sreg&0x7f)+11;
352 			in->sreg >>= 7;
353 			in->nbits -= 7;
354 			c = 0;
355 		} else {
356 			in->error = FlateCorrupted;
357 			goto bad;
358 		}
359 
360 		if(i+j > n) {
361 			in->error = FlateCorrupted;
362 			goto bad;
363 		}
364 
365 		while(j) {
366 			len[i] = c;
367 			i++;
368 			j--;
369 		}
370 	}
371 
372 	if(!hufftab(lentab, len, nlit, LitlenBits)
373 	|| !hufftab(offtab, &len[nlit], ndist, OffBits)){
374 		in->error = FlateCorrupted;
375 		goto bad;
376 	}
377 
378 	res = decode(in, his, lentab, offtab);
379 
380 	free(len);
381 	free(lentab);
382 	free(offtab);
383 
384 	return res;
385 
386 bad:
387 	free(len);
388 	free(lentab);
389 	free(offtab);
390 	return 0;
391 }
392 
393 static int
decode(Input * in,History * his,Huff * litlentab,Huff * offtab)394 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
395 {
396 	int len, off;
397 	uchar *hs, *hp, *hq, *he;
398 	int c;
399 	int nb;
400 
401 	hs = his->his;
402 	he = hs + HistorySize;
403 	hp = his->cp;
404 
405 	for(;;) {
406 		nb = litlentab->minbits;
407 		for(;;){
408 			if(in->nbits<nb && !sregfill(in, nb))
409 				return 0;
410 			c = litlentab->flat[in->sreg & litlentab->flatmask];
411 			nb = c & 0xff;
412 			if(nb > in->nbits){
413 				if(nb != 0xff)
414 					continue;
415 				c = hdecsym(in, litlentab, c);
416 				if(c < 0)
417 					return 0;
418 			}else{
419 				c >>= 8;
420 				in->sreg >>= nb;
421 				in->nbits -= nb;
422 			}
423 			break;
424 		}
425 
426 		if(c < 256) {
427 			/* literal */
428 			*hp++ = c;
429 			if(hp == he) {
430 				his->full = 1;
431 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
432 					in->error = FlateOutputFail;
433 					return 0;
434 				}
435 				hp = hs;
436 			}
437 			continue;
438 		}
439 
440 		if(c == 256)
441 			break;
442 
443 		if(c > 285) {
444 			in->error = FlateCorrupted;
445 			return 0;
446 		}
447 
448 		c -= 257;
449 		nb = litlenextra[c];
450 		if(in->nbits < nb && !sregfill(in, nb))
451 			return 0;
452 		len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
453 		in->sreg >>= nb;
454 		in->nbits -= nb;
455 
456 		/* get offset */
457 		nb = offtab->minbits;
458 		for(;;){
459 			if(in->nbits<nb && !sregfill(in, nb))
460 				return 0;
461 			c = offtab->flat[in->sreg & offtab->flatmask];
462 			nb = c & 0xff;
463 			if(nb > in->nbits){
464 				if(nb != 0xff)
465 					continue;
466 				c = hdecsym(in, offtab, c);
467 				if(c < 0)
468 					return 0;
469 			}else{
470 				c >>= 8;
471 				in->sreg >>= nb;
472 				in->nbits -= nb;
473 			}
474 			break;
475 		}
476 
477 		if(c > 29) {
478 			in->error = FlateCorrupted;
479 			return 0;
480 		}
481 
482 		nb = offextra[c];
483 		if(in->nbits < nb && !sregfill(in, nb))
484 			return 0;
485 
486 		off = offbase[c] + (in->sreg & ((1<<nb)-1));
487 		in->sreg >>= nb;
488 		in->nbits -= nb;
489 
490 		hq = hp - off;
491 		if(hq < hs) {
492 			if(!his->full) {
493 				in->error = FlateCorrupted;
494 				return 0;
495 			}
496 			hq += HistorySize;
497 		}
498 
499 		/* slow but correct */
500 		while(len) {
501 			*hp = *hq;
502 			hq++;
503 			hp++;
504 			if(hq >= he)
505 				hq = hs;
506 			if(hp == he) {
507 				his->full = 1;
508 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
509 					in->error = FlateOutputFail;
510 					return 0;
511 				}
512 				hp = hs;
513 			}
514 			len--;
515 		}
516 
517 	}
518 
519 	his->cp = hp;
520 
521 	return 1;
522 }
523 
524 static int
revcode(int c,int b)525 revcode(int c, int b)
526 {
527 	/* shift encode up so it starts on bit 15 then reverse */
528 	c <<= (16-b);
529 	c = revtab[c>>8] | (revtab[c&0xff]<<8);
530 	return c;
531 }
532 
533 /*
534  * construct the huffman decoding arrays and a fast lookup table.
535  * the fast lookup is a table indexed by the next flatbits bits,
536  * which returns the symbol matched and the number of bits consumed,
537  * or the minimum number of bits needed and 0xff if more than flatbits
538  * bits are needed.
539  *
540  * flatbits can be longer than the smallest huffman code,
541  * because shorter codes are assigned smaller lexical prefixes.
542  * this means assuming zeros for the next few bits will give a
543  * conservative answer, in the sense that it will either give the
544  * correct answer, or return the minimum number of bits which
545  * are needed for an answer.
546  */
547 static int
hufftab(Huff * h,char * hb,int maxleaf,int flatbits)548 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
549 {
550 	ulong bitcount[MaxHuffBits];
551 	ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
552 	int i, b, minbits, maxbits;
553 
554 	for(i = 0; i < MaxHuffBits; i++)
555 		bitcount[i] = 0;
556 	maxbits = -1;
557 	minbits = MaxHuffBits + 1;
558 	for(i=0; i < maxleaf; i++){
559 		b = hb[i];
560 		if(b){
561 			bitcount[b]++;
562 			if(b < minbits)
563 				minbits = b;
564 			if(b > maxbits)
565 				maxbits = b;
566 		}
567 	}
568 
569 	h->maxbits = maxbits;
570 	if(maxbits <= 0){
571 		h->maxbits = 0;
572 		h->minbits = 0;
573 		h->flatmask = 0;
574 		return 1;
575 	}
576 	code = 0;
577 	c = 0;
578 	for(b = 0; b <= maxbits; b++){
579 		h->last[b] = c;
580 		c += bitcount[b];
581 		mincode = code << 1;
582 		nc[b] = mincode;
583 		code = mincode + bitcount[b];
584 		if(code > (1 << b))
585 			return 0;
586 		h->maxcode[b] = code - 1;
587 		h->last[b] += code - 1;
588 	}
589 
590 	if(flatbits > maxbits)
591 		flatbits = maxbits;
592 	h->flatmask = (1 << flatbits) - 1;
593 	if(minbits > flatbits)
594 		minbits = flatbits;
595 	h->minbits = minbits;
596 
597 	b = 1 << flatbits;
598 	for(i = 0; i < b; i++)
599 		h->flat[i] = ~0;
600 
601 	/*
602 	 * initialize the flat table to include the minimum possible
603 	 * bit length for each code prefix
604 	 */
605 	for(b = maxbits; b > flatbits; b--){
606 		code = h->maxcode[b];
607 		if(code == -1)
608 			break;
609 		mincode = code + 1 - bitcount[b];
610 		mincode >>= b - flatbits;
611 		code >>= b - flatbits;
612 		for(; mincode <= code; mincode++)
613 			h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
614 	}
615 
616 	for(i = 0; i < maxleaf; i++){
617 		b = hb[i];
618 		if(b <= 0)
619 			continue;
620 		c = nc[b]++;
621 		if(b <= flatbits){
622 			code = (i << 8) | b;
623 			ec = (c + 1) << (flatbits - b);
624 			if(ec > (1<<flatbits))
625 				return 0;	/* this is actually an internal error */
626 			for(fc = c << (flatbits - b); fc < ec; fc++)
627 				h->flat[revcode(fc, flatbits)] = code;
628 		}
629 		if(b > minbits){
630 			c = h->last[b] - c;
631 			if(c >= maxleaf)
632 				return 0;
633 			h->decode[c] = i;
634 		}
635 	}
636 	return 1;
637 }
638 
639 static int
hdecsym(Input * in,Huff * h,int nb)640 hdecsym(Input *in, Huff *h, int nb)
641 {
642 	long c;
643 
644 	if((nb & 0xff) == 0xff)
645 		nb = nb >> 8;
646 	else
647 		nb = nb & 0xff;
648 	for(; nb <= h->maxbits; nb++){
649 		if(in->nbits<nb && !sregfill(in, nb))
650 			return -1;
651 		c = revtab[in->sreg&0xff]<<8;
652 		c |= revtab[(in->sreg>>8)&0xff];
653 		c >>= (16-nb);
654 		if(c <= h->maxcode[nb]){
655 			in->sreg >>= nb;
656 			in->nbits -= nb;
657 			return h->decode[h->last[nb] - c];
658 		}
659 	}
660 	in->error = FlateCorrupted;
661 	return -1;
662 }
663 
664 static int
sregfill(Input * in,int n)665 sregfill(Input *in, int n)
666 {
667 	int c;
668 
669 	while(n > in->nbits) {
670 		c = (*in->get)(in->getr);
671 		if(c < 0){
672 			in->error = FlateInputFail;
673 			return 0;
674 		}
675 		in->sreg |= c<<in->nbits;
676 		in->nbits += 8;
677 	}
678 	return 1;
679 }
680 
681 static int
sregunget(Input * in)682 sregunget(Input *in)
683 {
684 	if(in->nbits >= 8) {
685 		in->error = FlateInternal;
686 		return 0;
687 	}
688 
689 	/* throw other bits on the floor */
690 	in->nbits = 0;
691 	in->sreg = 0;
692 	return 1;
693 }
694