/* xref: /inferno-os/os/boot/libflate/inflate.c (revision 74a4d8c26dd3c1e9febcb717cfd6cb6512991a7a) */
#include "lib9.h"
#include <flate.h>
3 
/*
 * sizes and bit widths used by the decoder
 */
enum {
	HistorySize=	32*1024,	/* bytes in History.his */
	BufSize=	4*1024,
	MaxHuffBits=	17,	/* maximum bits in an encoded code */
	Nlitlen=	288,	/* number of litlen codes */
	Noff=		32,	/* number of offset codes */
	Nclen=		19,	/* number of codelen codes */
	LenShift=	10,	/* code = len<<LenShift|code */
	LitlenBits=	7,	/* number of bits in litlen decode table */
	OffBits=	6,	/* number of bits in offset decode table */
	ClenBits=	6,	/* number of bits in code len decode table */
	MaxFlatBits=	LitlenBits,	/* Huff.flat has 1<<MaxFlatBits entries */
	MaxLeaf=	Nlitlen	/* Huff.decode has MaxLeaf entries */
};
18 
/* short names for the decoder's private structures */
typedef struct Input	Input;
typedef struct History	History;
typedef struct Huff	Huff;
22 
23 struct Input
24 {
25 	int	error;		/* first error encountered, or FlateOk */
26 	void	*wr;
27 	int	(*w)(void*, void*, int);
28 	void	*getr;
29 	int	(*get)(void*);
30 	ulong	sreg;
31 	int	nbits;
32 };
33 
34 struct History
35 {
36 	uchar	his[HistorySize];
37 	uchar	*cp;		/* current pointer in history */
38 	int	full;		/* his has been filled up at least once */
39 };
40 
41 struct Huff
42 {
43 	int	maxbits;	/* max bits for any code */
44 	int	minbits;	/* min bits to get before looking in flat */
45 	int	flatmask;	/* bits used in "flat" fast decoding table */
46 	ulong	flat[1<<MaxFlatBits];
47 	ulong	maxcode[MaxHuffBits];
48 	ulong	last[MaxHuffBits];
49 	ulong	decode[MaxLeaf];
50 };
51 
52 /* litlen code words 257-285 extra bits */
53 static int litlenextra[Nlitlen-257] =
54 {
55 /* 257 */	0, 0, 0,
56 /* 260 */	0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
57 /* 270 */	2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
58 /* 280 */	4, 5, 5, 5, 5, 0, 0, 0
59 };
60 
61 static int litlenbase[Nlitlen-257];
62 
63 /* offset code word extra bits */
64 static int offextra[Noff] =
65 {
66 	0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
67 	4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
68 	9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
69 	0,  0,
70 };
71 static int offbase[Noff];
72 
73 /* order code lengths */
74 static int clenorder[Nclen] =
75 {
76         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
77 };
78 
79 /* for static huffman tables */
80 static	Huff	litlentab;
81 static	Huff	offtab;
82 static	uchar	revtab[256];
83 
84 static int	uncblock(Input *in, History*);
85 static int	fixedblock(Input *in, History*);
86 static int	dynamicblock(Input *in, History*);
87 static int	sregfill(Input *in, int n);
88 static int	sregunget(Input *in);
89 static int	decode(Input*, History*, Huff*, Huff*);
90 static int	hufftab(Huff*, char*, int, int);
91 static int	hdecsym(Input *in, Huff *h, int b);
92 
93 int
inflateinit(void)94 inflateinit(void)
95 {
96 	char *len;
97 	int i, j, base;
98 
99 	/* byte reverse table */
100 	for(i=0; i<256; i++)
101 		for(j=0; j<8; j++)
102 			if(i & (1<<j))
103 				revtab[i] |= 0x80 >> j;
104 
105 	for(i=257,base=3; i<Nlitlen; i++) {
106 		litlenbase[i-257] = base;
107 		base += 1<<litlenextra[i-257];
108 	}
109 	/* strange table entry in spec... */
110 	litlenbase[285-257]--;
111 
112 	for(i=0,base=1; i<Noff; i++) {
113 		offbase[i] = base;
114 		base += 1<<offextra[i];
115 	}
116 
117 	len = malloc(MaxLeaf);
118 	if(len == nil)
119 		return FlateNoMem;
120 
121 	/* static Litlen bit lengths */
122 	for(i=0; i<144; i++)
123 		len[i] = 8;
124 	for(i=144; i<256; i++)
125 		len[i] = 9;
126 	for(i=256; i<280; i++)
127 		len[i] = 7;
128 	for(i=280; i<Nlitlen; i++)
129 		len[i] = 8;
130 
131 	if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
132 		return FlateInternal;
133 
134 	/* static Offset bit lengths */
135 	for(i=0; i<Noff; i++)
136 		len[i] = 5;
137 
138 	if(!hufftab(&offtab, len, Noff, MaxFlatBits))
139 		return FlateInternal;
140 	free(len);
141 
142 	return FlateOk;
143 }
144 
145 int
inflate(void * wr,int (* w)(void *,void *,int),void * getr,int (* get)(void *))146 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
147 {
148 	History *his;
149 	Input in;
150 	int final, type;
151 
152 	his = malloc(sizeof(History));
153 	if(his == nil)
154 		return FlateNoMem;
155 	his->cp = his->his;
156 	his->full = 0;
157 	in.getr = getr;
158 	in.get = get;
159 	in.wr = wr;
160 	in.w = w;
161 	in.nbits = 0;
162 	in.sreg = 0;
163 	in.error = FlateOk;
164 
165 	do {
166 		if(!sregfill(&in, 3))
167 			goto bad;
168 		final = in.sreg & 0x1;
169 		type = (in.sreg>>1) & 0x3;
170 		in.sreg >>= 3;
171 		in.nbits -= 3;
172 		switch(type) {
173 		default:
174 			in.error = FlateCorrupted;
175 			goto bad;
176 		case 0:
177 			/* uncompressed */
178 			if(!uncblock(&in, his))
179 				goto bad;
180 			break;
181 		case 1:
182 			/* fixed huffman */
183 			if(!fixedblock(&in, his))
184 				goto bad;
185 			break;
186 		case 2:
187 			/* dynamic huffman */
188 			if(!dynamicblock(&in, his))
189 				goto bad;
190 			break;
191 		}
192 	} while(!final);
193 
194 	if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
195 		in.error = FlateOutputFail;
196 		goto bad;
197 	}
198 
199 	if(!sregunget(&in))
200 		goto bad;
201 
202 	free(his);
203 	if(in.error != FlateOk)
204 		return FlateInternal;
205 	return FlateOk;
206 
207 bad:
208 	free(his);
209 	if(in.error == FlateOk)
210 		return FlateInternal;
211 	return in.error;
212 }
213 
214 static int
uncblock(Input * in,History * his)215 uncblock(Input *in, History *his)
216 {
217 	int len, nlen, c;
218 	uchar *hs, *hp, *he;
219 
220 	if(!sregunget(in))
221 		return 0;
222 	len = (*in->get)(in->getr);
223 	len |= (*in->get)(in->getr)<<8;
224 	nlen = (*in->get)(in->getr);
225 	nlen |= (*in->get)(in->getr)<<8;
226 	if(len != (~nlen&0xffff)) {
227 		in->error = FlateCorrupted;
228 		return 0;
229 	}
230 
231 	hp = his->cp;
232 	hs = his->his;
233 	he = hs + HistorySize;
234 
235 	while(len > 0) {
236 		c = (*in->get)(in->getr);
237 		if(c < 0)
238 			return 0;
239 		*hp++ = c;
240 		if(hp == he) {
241 			his->full = 1;
242 			if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
243 				in->error = FlateOutputFail;
244 				return 0;
245 			}
246 			hp = hs;
247 		}
248 		len--;
249 	}
250 
251 	his->cp = hp;
252 
253 	return 1;
254 }
255 
256 static int
fixedblock(Input * in,History * his)257 fixedblock(Input *in, History *his)
258 {
259 	return decode(in, his, &litlentab, &offtab);
260 }
261 
262 static int
dynamicblock(Input * in,History * his)263 dynamicblock(Input *in, History *his)
264 {
265 	Huff *lentab, *offtab;
266 	char *len;
267 	int i, j, n, c, nlit, ndist, nclen, res, nb;
268 
269 	if(!sregfill(in, 14))
270 		return 0;
271 	nlit = (in->sreg&0x1f) + 257;
272 	ndist = ((in->sreg>>5) & 0x1f) + 1;
273 	nclen = ((in->sreg>>10) & 0xf) + 4;
274 	in->sreg >>= 14;
275 	in->nbits -= 14;
276 
277 	if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
278 		in->error = FlateCorrupted;
279 		return 0;
280 	}
281 
282 	/* huff table header */
283 	len = malloc(Nlitlen+Noff);
284 	lentab = malloc(sizeof(Huff));
285 	offtab = malloc(sizeof(Huff));
286 	if(len == nil || lentab == nil || offtab == nil){
287 		in->error = FlateNoMem;
288 		goto bad;
289 	}
290 	for(i=0; i < Nclen; i++)
291 		len[i] = 0;
292 	for(i=0; i<nclen; i++) {
293 		if(!sregfill(in, 3))
294 			goto bad;
295 		len[clenorder[i]] = in->sreg & 0x7;
296 		in->sreg >>= 3;
297 		in->nbits -= 3;
298 	}
299 
300 	if(!hufftab(lentab, len, Nclen, ClenBits)){
301 		in->error = FlateCorrupted;
302 		goto bad;
303 	}
304 
305 	n = nlit+ndist;
306 	for(i=0; i<n;) {
307 		nb = lentab->minbits;
308 		for(;;){
309 			if(in->nbits<nb && !sregfill(in, nb))
310 				goto bad;
311 			c = lentab->flat[in->sreg & lentab->flatmask];
312 			nb = c & 0xff;
313 			if(nb > in->nbits){
314 				if(nb != 0xff)
315 					continue;
316 				c = hdecsym(in, lentab, c);
317 				if(c < 0)
318 					goto bad;
319 			}else{
320 				c >>= 8;
321 				in->sreg >>= nb;
322 				in->nbits -= nb;
323 			}
324 			break;
325 		}
326 
327 		if(c < 16) {
328 			j = 1;
329 		} else if(c == 16) {
330 			if(in->nbits<2 && !sregfill(in, 2))
331 				goto bad;
332 			j = (in->sreg&0x3)+3;
333 			in->sreg >>= 2;
334 			in->nbits -= 2;
335 			if(i == 0) {
336 				in->error = FlateCorrupted;
337 				goto bad;
338 			}
339 			c = len[i-1];
340 		} else if(c == 17) {
341 			if(in->nbits<3 && !sregfill(in, 3))
342 				goto bad;
343 			j = (in->sreg&0x7)+3;
344 			in->sreg >>= 3;
345 			in->nbits -= 3;
346 			c = 0;
347 		} else if(c == 18) {
348 			if(in->nbits<7 && !sregfill(in, 7))
349 				goto bad;
350 			j = (in->sreg&0x7f)+11;
351 			in->sreg >>= 7;
352 			in->nbits -= 7;
353 			c = 0;
354 		} else {
355 			in->error = FlateCorrupted;
356 			goto bad;
357 		}
358 
359 		if(i+j > n) {
360 			in->error = FlateCorrupted;
361 			goto bad;
362 		}
363 
364 		while(j) {
365 			len[i] = c;
366 			i++;
367 			j--;
368 		}
369 	}
370 
371 	if(!hufftab(lentab, len, nlit, LitlenBits)
372 	|| !hufftab(offtab, &len[nlit], ndist, OffBits)){
373 		in->error = FlateCorrupted;
374 		goto bad;
375 	}
376 
377 	res = decode(in, his, lentab, offtab);
378 
379 	free(len);
380 	free(lentab);
381 	free(offtab);
382 
383 	return res;
384 
385 bad:
386 	free(len);
387 	free(lentab);
388 	free(offtab);
389 	return 0;
390 }
391 
392 static int
decode(Input * in,History * his,Huff * litlentab,Huff * offtab)393 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
394 {
395 	int len, off;
396 	uchar *hs, *hp, *hq, *he;
397 	int c;
398 	int nb;
399 
400 	hs = his->his;
401 	he = hs + HistorySize;
402 	hp = his->cp;
403 
404 	for(;;) {
405 		nb = litlentab->minbits;
406 		for(;;){
407 			if(in->nbits<nb && !sregfill(in, nb))
408 				return 0;
409 			c = litlentab->flat[in->sreg & litlentab->flatmask];
410 			nb = c & 0xff;
411 			if(nb > in->nbits){
412 				if(nb != 0xff)
413 					continue;
414 				c = hdecsym(in, litlentab, c);
415 				if(c < 0)
416 					return 0;
417 			}else{
418 				c >>= 8;
419 				in->sreg >>= nb;
420 				in->nbits -= nb;
421 			}
422 			break;
423 		}
424 
425 		if(c < 256) {
426 			/* literal */
427 			*hp++ = c;
428 			if(hp == he) {
429 				his->full = 1;
430 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
431 					in->error = FlateOutputFail;
432 					return 0;
433 				}
434 				hp = hs;
435 			}
436 			continue;
437 		}
438 
439 		if(c == 256)
440 			break;
441 
442 		if(c > 285) {
443 			in->error = FlateCorrupted;
444 			return 0;
445 		}
446 
447 		c -= 257;
448 		nb = litlenextra[c];
449 		if(in->nbits < nb && !sregfill(in, nb))
450 			return 0;
451 		len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
452 		in->sreg >>= nb;
453 		in->nbits -= nb;
454 
455 		/* get offset */
456 		nb = offtab->minbits;
457 		for(;;){
458 			if(in->nbits<nb && !sregfill(in, nb))
459 				return 0;
460 			c = offtab->flat[in->sreg & offtab->flatmask];
461 			nb = c & 0xff;
462 			if(nb > in->nbits){
463 				if(nb != 0xff)
464 					continue;
465 				c = hdecsym(in, offtab, c);
466 				if(c < 0)
467 					return 0;
468 			}else{
469 				c >>= 8;
470 				in->sreg >>= nb;
471 				in->nbits -= nb;
472 			}
473 			break;
474 		}
475 
476 		if(c > 29) {
477 			in->error = FlateCorrupted;
478 			return 0;
479 		}
480 
481 		nb = offextra[c];
482 		if(in->nbits < nb && !sregfill(in, nb))
483 			return 0;
484 
485 		off = offbase[c] + (in->sreg & ((1<<nb)-1));
486 		in->sreg >>= nb;
487 		in->nbits -= nb;
488 
489 		hq = hp - off;
490 		if(hq < hs) {
491 			if(!his->full) {
492 				in->error = FlateCorrupted;
493 				return 0;
494 			}
495 			hq += HistorySize;
496 		}
497 
498 		/* slow but correct */
499 		while(len) {
500 			*hp = *hq;
501 			hq++;
502 			hp++;
503 			if(hq >= he)
504 				hq = hs;
505 			if(hp == he) {
506 				his->full = 1;
507 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
508 					in->error = FlateOutputFail;
509 					return 0;
510 				}
511 				hp = hs;
512 			}
513 			len--;
514 		}
515 
516 	}
517 
518 	his->cp = hp;
519 
520 	return 1;
521 }
522 
523 static int
revcode(int c,int b)524 revcode(int c, int b)
525 {
526 	/* shift encode up so it starts on bit 15 then reverse */
527 	c <<= (16-b);
528 	c = revtab[c>>8] | (revtab[c&0xff]<<8);
529 	return c;
530 }
531 
532 /*
533  * construct the huffman decoding arrays and a fast lookup table.
534  * the fast lookup is a table indexed by the next flatbits bits,
535  * which returns the symbol matched and the number of bits consumed,
536  * or the minimum number of bits needed and 0xff if more than flatbits
537  * bits are needed.
538  *
539  * flatbits can be longer than the smallest huffman code,
540  * because shorter codes are assigned smaller lexical prefixes.
541  * this means assuming zeros for the next few bits will give a
542  * conservative answer, in the sense that it will either give the
543  * correct answer, or return the minimum number of bits which
544  * are needed for an answer.
545  */
546 static int
hufftab(Huff * h,char * hb,int maxleaf,int flatbits)547 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
548 {
549 	ulong bitcount[MaxHuffBits];
550 	ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
551 	int i, b, minbits, maxbits;
552 
553 	for(i = 0; i < MaxHuffBits; i++)
554 		bitcount[i] = 0;
555 	maxbits = -1;
556 	minbits = MaxHuffBits + 1;
557 	for(i=0; i < maxleaf; i++){
558 		b = hb[i];
559 		if(b){
560 			bitcount[b]++;
561 			if(b < minbits)
562 				minbits = b;
563 			if(b > maxbits)
564 				maxbits = b;
565 		}
566 	}
567 
568 	h->maxbits = maxbits;
569 	if(maxbits <= 0){
570 		h->maxbits = 0;
571 		h->minbits = 0;
572 		h->flatmask = 0;
573 		return 1;
574 	}
575 	code = 0;
576 	c = 0;
577 	for(b = 0; b <= maxbits; b++){
578 		h->last[b] = c;
579 		c += bitcount[b];
580 		mincode = code << 1;
581 		nc[b] = mincode;
582 		code = mincode + bitcount[b];
583 		if(code > (1 << b))
584 			return 0;
585 		h->maxcode[b] = code - 1;
586 		h->last[b] += code - 1;
587 	}
588 
589 	if(flatbits > maxbits)
590 		flatbits = maxbits;
591 	h->flatmask = (1 << flatbits) - 1;
592 	if(minbits > flatbits)
593 		minbits = flatbits;
594 	h->minbits = minbits;
595 
596 	b = 1 << flatbits;
597 	for(i = 0; i < b; i++)
598 		h->flat[i] = ~0;
599 
600 	/*
601 	 * initialize the flat table to include the minimum possible
602 	 * bit length for each code prefix
603 	 */
604 	for(b = maxbits; b > flatbits; b--){
605 		code = h->maxcode[b];
606 		if(code == -1)
607 			break;
608 		mincode = code + 1 - bitcount[b];
609 		mincode >>= b - flatbits;
610 		code >>= b - flatbits;
611 		for(; mincode <= code; mincode++)
612 			h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
613 	}
614 
615 	for(i = 0; i < maxleaf; i++){
616 		b = hb[i];
617 		if(b <= 0)
618 			continue;
619 		c = nc[b]++;
620 		if(b <= flatbits){
621 			code = (i << 8) | b;
622 			ec = (c + 1) << (flatbits - b);
623 			if(ec > (1<<flatbits))
624 				return 0;	/* this is actually an internal error */
625 			for(fc = c << (flatbits - b); fc < ec; fc++)
626 				h->flat[revcode(fc, flatbits)] = code;
627 		}
628 		if(b > minbits){
629 			c = h->last[b] - c;
630 			if(c >= maxleaf)
631 				return 0;
632 			h->decode[c] = i;
633 		}
634 	}
635 	return 1;
636 }
637 
638 static int
hdecsym(Input * in,Huff * h,int nb)639 hdecsym(Input *in, Huff *h, int nb)
640 {
641 	long c;
642 
643 	if((nb & 0xff) == 0xff)
644 		nb = nb >> 8;
645 	else
646 		nb = nb & 0xff;
647 	for(; nb <= h->maxbits; nb++){
648 		if(in->nbits<nb && !sregfill(in, nb))
649 			return -1;
650 		c = revtab[in->sreg&0xff]<<8;
651 		c |= revtab[(in->sreg>>8)&0xff];
652 		c >>= (16-nb);
653 		if(c <= h->maxcode[nb]){
654 			in->sreg >>= nb;
655 			in->nbits -= nb;
656 			return h->decode[h->last[nb] - c];
657 		}
658 	}
659 	in->error = FlateCorrupted;
660 	return -1;
661 }
662 
663 static int
sregfill(Input * in,int n)664 sregfill(Input *in, int n)
665 {
666 	int c;
667 
668 	while(n > in->nbits) {
669 		c = (*in->get)(in->getr);
670 		if(c < 0){
671 			in->error = FlateInputFail;
672 			return 0;
673 		}
674 		in->sreg |= c<<in->nbits;
675 		in->nbits += 8;
676 	}
677 	return 1;
678 }
679 
680 static int
sregunget(Input * in)681 sregunget(Input *in)
682 {
683 	if(in->nbits >= 8) {
684 		in->error = FlateInternal;
685 		return 0;
686 	}
687 
688 	/* throw other bits on the floor */
689 	in->nbits = 0;
690 	in->sreg = 0;
691 	return 1;
692 }
693