xref: /plan9/sys/src/9/port/devssl.c (revision a6a9e07217f318acf170f99684a55fba5200524f)
1 /*
2  *  devssl - secure sockets layer
3  */
4 #include	"u.h"
5 #include	"../port/lib.h"
6 #include	"mem.h"
7 #include	"dat.h"
8 #include	"fns.h"
9 #include	"../port/error.h"
10 
11 #include	<libsec.h>
12 
13 #define NOSPOOKS 1
14 
15 typedef struct OneWay OneWay;
16 struct OneWay
17 {
18 	QLock	q;
19 	QLock	ctlq;
20 
21 	void	*state;		/* encryption state */
22 	int	slen;		/* hash data length */
23 	uchar	*secret;	/* secret */
24 	ulong	mid;		/* message id */
25 };
26 
/* connection states; Sdigenc is the union of the digest and cipher bits */
enum
{
	Sincomplete	= 0,	/* not usable: no fd yet, or hung up */
	Sclear		= 1,	/* record framing only */
	Sencrypting	= 2,
	Sdigesting	= 4,
	Sdigenc		= Sencrypting|Sdigesting,
};

/* encryption algorithms, stored in Dstate.encryptalg */
enum
{
	Noencryption	= 0,
	DESCBC		= 1,
	DESECB		= 2,
	RC4		= 3,
};
42 
43 typedef struct Dstate Dstate;
44 struct Dstate
45 {
46 	Chan	*c;		/* io channel */
47 	uchar	state;		/* state of connection */
48 	int	ref;		/* serialized by dslock for atomic destroy */
49 
50 	uchar	encryptalg;	/* encryption algorithm */
51 	ushort	blocklen;	/* blocking length */
52 
53 	ushort	diglen;		/* length of digest */
54 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);	/* hash func */
55 
56 	/* for SSL format */
57 	int	max;		/* maximum unpadded data per msg */
58 	int	maxpad;		/* maximum padded data per msg */
59 
60 	/* input side */
61 	OneWay	in;
62 	Block	*processed;
63 	Block	*unprocessed;
64 
65 	/* output side */
66 	OneWay	out;
67 
68 	/* protections */
69 	char	*user;
70 	int	perm;
71 };
72 
enum
{
	Maxdmsg		= 1<<16,	/* read size requested from the underlying channel */
	Maxdstate	= 128,		/* must be a power of 2: CONV masks with Maxdstate-1 */
};
78 
79 Lock	dslock;
80 int	dshiwat;
81 char	*dsname[Maxdstate];
82 Dstate	*dstate[Maxdstate];
83 char	*encalgs;
84 char	*hashalgs;
85 
/* file types, stored in the low bits of a qid path (see QID) */
enum{
	Qtopdir		= 1,	/* top level directory */
	Qprotodir	= 2,	/* the ssl directory */
	Qclonus		= 3,	/* clone file */
	Qconvdir	= 4,	/* directory for a conversation */
	Qdata		= 5,
	Qctl		= 6,
	Qsecretin	= 7,
	Qsecretout	= 8,
	Qencalgs	= 9,
	Qhashalgs	= 10,
};
98 
/*
 *  A qid path holds the conversation number above a fixed-width
 *  type field.  TYPE must mask every bit that QID reserves for the
 *  type (it used to mask only 4 of the 5), so all three macros
 *  derive from a single width.
 */
#define QIDTYPEBITS	5
#define TYPE(x)		((x).path & ((1<<QIDTYPEBITS)-1))
#define CONV(x)		(((x).path >> QIDTYPEBITS)&(Maxdstate-1))
#define QID(c, y)	(((c)<<QIDTYPEBITS) | (y))
102 
103 static void	ensure(Dstate*, Block**, int);
104 static void	consume(Block**, uchar*, int);
105 static void	setsecret(OneWay*, uchar*, int);
106 static Block*	encryptb(Dstate*, Block*, int);
107 static Block*	decryptb(Dstate*, Block*);
108 static Block*	digestb(Dstate*, Block*, int);
109 static void	checkdigestb(Dstate*, Block*);
110 static Chan*	buftochan(char*);
111 static void	sslhangup(Dstate*);
112 static Dstate*	dsclone(Chan *c);
113 static void	dsnew(Chan *c, Dstate **);
114 static long	sslput(Dstate *s, Block * volatile b);
115 
116 char *sslnames[] = {
117 [Qclonus]	"clone",
118 [Qdata]		"data",
119 [Qctl]		"ctl",
120 [Qsecretin]	"secretin",
121 [Qsecretout]	"secretout",
122 [Qencalgs]	"encalgs",
123 [Qhashalgs]	"hashalgs",
124 };
125 
126 static int
127 sslgen(Chan *c, char*, Dirtab *d, int nd, int s, Dir *dp)
128 {
129 	Qid q;
130 	Dstate *ds;
131 	char name[16], *p, *nm;
132 	int ft;
133 
134 	USED(nd);
135 	USED(d);
136 
137 	q.type = QTFILE;
138 	q.vers = 0;
139 
140 	ft = TYPE(c->qid);
141 	switch(ft) {
142 	case Qtopdir:
143 		if(s == DEVDOTDOT){
144 			q.path = QID(0, Qtopdir);
145 			q.type = QTDIR;
146 			devdir(c, q, "#D", 0, eve, 0555, dp);
147 			return 1;
148 		}
149 		if(s > 0)
150 			return -1;
151 		q.path = QID(0, Qprotodir);
152 		q.type = QTDIR;
153 		devdir(c, q, "ssl", 0, eve, 0555, dp);
154 		return 1;
155 	case Qprotodir:
156 		if(s == DEVDOTDOT){
157 			q.path = QID(0, Qtopdir);
158 			q.type = QTDIR;
159 			devdir(c, q, ".", 0, eve, 0555, dp);
160 			return 1;
161 		}
162 		if(s < dshiwat) {
163 			q.path = QID(s, Qconvdir);
164 			q.type = QTDIR;
165 			ds = dstate[s];
166 			if(ds != 0)
167 				nm = ds->user;
168 			else
169 				nm = eve;
170 			if(dsname[s] == nil){
171 				sprint(name, "%d", s);
172 				kstrdup(&dsname[s], name);
173 			}
174 			devdir(c, q, dsname[s], 0, nm, 0555, dp);
175 			return 1;
176 		}
177 		if(s > dshiwat)
178 			return -1;
179 		q.path = QID(0, Qclonus);
180 		devdir(c, q, "clone", 0, eve, 0555, dp);
181 		return 1;
182 	case Qconvdir:
183 		if(s == DEVDOTDOT){
184 			q.path = QID(0, Qprotodir);
185 			q.type = QTDIR;
186 			devdir(c, q, "ssl", 0, eve, 0555, dp);
187 			return 1;
188 		}
189 		ds = dstate[CONV(c->qid)];
190 		if(ds != 0)
191 			nm = ds->user;
192 		else
193 			nm = eve;
194 		switch(s) {
195 		default:
196 			return -1;
197 		case 0:
198 			q.path = QID(CONV(c->qid), Qctl);
199 			p = "ctl";
200 			break;
201 		case 1:
202 			q.path = QID(CONV(c->qid), Qdata);
203 			p = "data";
204 			break;
205 		case 2:
206 			q.path = QID(CONV(c->qid), Qsecretin);
207 			p = "secretin";
208 			break;
209 		case 3:
210 			q.path = QID(CONV(c->qid), Qsecretout);
211 			p = "secretout";
212 			break;
213 		case 4:
214 			q.path = QID(CONV(c->qid), Qencalgs);
215 			p = "encalgs";
216 			break;
217 		case 5:
218 			q.path = QID(CONV(c->qid), Qhashalgs);
219 			p = "hashalgs";
220 			break;
221 		}
222 		devdir(c, q, p, 0, nm, 0660, dp);
223 		return 1;
224 	case Qclonus:
225 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, eve, 0555, dp);
226 		return 1;
227 	default:
228 		ds = dstate[CONV(c->qid)];
229 		if(ds != 0)
230 			nm = ds->user;
231 		else
232 			nm = eve;
233 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, nm, 0660, dp);
234 		return 1;
235 	}
236 	return -1;
237 }
238 
239 static Chan*
240 sslattach(char *spec)
241 {
242 	Chan *c;
243 
244 	c = devattach('D', spec);
245 	c->qid.path = QID(0, Qtopdir);
246 	c->qid.vers = 0;
247 	c->qid.type = QTDIR;
248 	return c;
249 }
250 
251 static Walkqid*
252 sslwalk(Chan *c, Chan *nc, char **name, int nname)
253 {
254 	return devwalk(c, nc, name, nname, nil, 0, sslgen);
255 }
256 
257 static int
258 sslstat(Chan *c, uchar *db, int n)
259 {
260 	return devstat(c, db, n, nil, 0, sslgen);
261 }
262 
263 static Chan*
264 sslopen(Chan *c, int omode)
265 {
266 	Dstate *s, **pp;
267 	int perm;
268 	int ft;
269 
270 	perm = 0;
271 	omode &= 3;
272 	switch(omode) {
273 	case OREAD:
274 		perm = 4;
275 		break;
276 	case OWRITE:
277 		perm = 2;
278 		break;
279 	case ORDWR:
280 		perm = 6;
281 		break;
282 	}
283 
284 	ft = TYPE(c->qid);
285 	switch(ft) {
286 	default:
287 		panic("sslopen");
288 	case Qtopdir:
289 	case Qprotodir:
290 	case Qconvdir:
291 		if(omode != OREAD)
292 			error(Eperm);
293 		break;
294 	case Qclonus:
295 		s = dsclone(c);
296 		if(s == 0)
297 			error(Enodev);
298 		break;
299 	case Qctl:
300 	case Qdata:
301 	case Qsecretin:
302 	case Qsecretout:
303 		if(waserror()) {
304 			unlock(&dslock);
305 			nexterror();
306 		}
307 		lock(&dslock);
308 		pp = &dstate[CONV(c->qid)];
309 		s = *pp;
310 		if(s == 0)
311 			dsnew(c, pp);
312 		else {
313 			if((perm & (s->perm>>6)) != perm
314 			   && (strcmp(up->user, s->user) != 0
315 			     || (perm & s->perm) != perm))
316 				error(Eperm);
317 
318 			s->ref++;
319 		}
320 		unlock(&dslock);
321 		poperror();
322 		break;
323 	case Qencalgs:
324 	case Qhashalgs:
325 		if(omode != OREAD)
326 			error(Eperm);
327 		break;
328 	}
329 	c->mode = openmode(omode);
330 	c->flag |= COPEN;
331 	c->offset = 0;
332 	return c;
333 }
334 
335 static int
336 sslwstat(Chan *c, uchar *db, int n)
337 {
338 	Dir *dir;
339 	Dstate *s;
340 	int m;
341 
342 	s = dstate[CONV(c->qid)];
343 	if(s == 0)
344 		error(Ebadusefd);
345 	if(strcmp(s->user, up->user) != 0)
346 		error(Eperm);
347 
348 	dir = smalloc(sizeof(Dir)+n);
349 	m = convM2D(db, n, &dir[0], (char*)&dir[1]);
350 	if(m == 0){
351 		free(dir);
352 		error(Eshortstat);
353 	}
354 
355 	if(!emptystr(dir->uid))
356 		kstrdup(&s->user, dir->uid);
357 	if(dir->mode != ~0UL)
358 		s->perm = dir->mode;
359 
360 	free(dir);
361 	return m;
362 }
363 
364 static void
365 sslclose(Chan *c)
366 {
367 	Dstate *s;
368 	int ft;
369 
370 	ft = TYPE(c->qid);
371 	switch(ft) {
372 	case Qctl:
373 	case Qdata:
374 	case Qsecretin:
375 	case Qsecretout:
376 		if((c->flag & COPEN) == 0)
377 			break;
378 
379 		s = dstate[CONV(c->qid)];
380 		if(s == 0)
381 			break;
382 
383 		lock(&dslock);
384 		if(--s->ref > 0) {
385 			unlock(&dslock);
386 			break;
387 		}
388 		dstate[CONV(c->qid)] = 0;
389 		unlock(&dslock);
390 
391 		if(s->user != nil)
392 			free(s->user);
393 		sslhangup(s);
394 		if(s->c)
395 			cclose(s->c);
396 		if(s->in.secret)
397 			free(s->in.secret);
398 		if(s->out.secret)
399 			free(s->out.secret);
400 		if(s->in.state)
401 			free(s->in.state);
402 		if(s->out.state)
403 			free(s->out.state);
404 		free(s);
405 
406 	}
407 }
408 
409 /*
410  *  make sure we have at least 'n' bytes in list 'l'
411  */
412 static void
413 ensure(Dstate *s, Block **l, int n)
414 {
415 	int sofar, i;
416 	Block *b, *bl;
417 
418 	sofar = 0;
419 	for(b = *l; b; b = b->next){
420 		sofar += BLEN(b);
421 		if(sofar >= n)
422 			return;
423 		l = &b->next;
424 	}
425 
426 	while(sofar < n){
427 		bl = devtab[s->c->type]->bread(s->c, Maxdmsg, 0);
428 		if(bl == 0)
429 			nexterror();
430 		*l = bl;
431 		i = 0;
432 		for(b = bl; b; b = b->next){
433 			i += BLEN(b);
434 			l = &b->next;
435 		}
436 		if(i == 0)
437 			error(Ehungup);
438 		sofar += i;
439 	}
440 }
441 
442 /*
443  *  copy 'n' bytes from 'l' into 'p' and free
444  *  the bytes in 'l'
445  */
446 static void
447 consume(Block **l, uchar *p, int n)
448 {
449 	Block *b;
450 	int i;
451 
452 	for(; *l && n > 0; n -= i){
453 		b = *l;
454 		i = BLEN(b);
455 		if(i > n)
456 			i = n;
457 		memmove(p, b->rp, i);
458 		b->rp += i;
459 		p += i;
460 		if(BLEN(b) < 0)
461 			panic("consume");
462 		if(BLEN(b))
463 			break;
464 		*l = b->next;
465 		freeb(b);
466 	}
467 }
468 
469 /*
470  *  give back n bytes
471 static void
472 regurgitate(Dstate *s, uchar *p, int n)
473 {
474 	Block *b;
475 
476 	if(n <= 0)
477 		return;
478 	b = s->unprocessed;
479 	if(s->unprocessed == nil || b->rp - b->base < n) {
480 		b = allocb(n);
481 		memmove(b->wp, p, n);
482 		b->wp += n;
483 		b->next = s->unprocessed;
484 		s->unprocessed = b;
485 	} else {
486 		b->rp -= n;
487 		memmove(b->rp, p, n);
488 	}
489 }
490  */
491 
492 /*
493  *  remove at most n bytes from the queue, if discard is set
494  *  dump the remainder
495  */
496 static Block*
497 qtake(Block **l, int n, int discard)
498 {
499 	Block *nb, *b, *first;
500 	int i;
501 
502 	first = *l;
503 	for(b = first; b; b = b->next){
504 		i = BLEN(b);
505 		if(i == n){
506 			if(discard){
507 				freeblist(b->next);
508 				*l = 0;
509 			} else
510 				*l = b->next;
511 			b->next = 0;
512 			return first;
513 		} else if(i > n){
514 			i -= n;
515 			if(discard){
516 				freeblist(b->next);
517 				b->wp -= i;
518 				*l = 0;
519 			} else {
520 				nb = allocb(i);
521 				memmove(nb->wp, b->rp+n, i);
522 				nb->wp += i;
523 				b->wp -= i;
524 				nb->next = b->next;
525 				*l = nb;
526 			}
527 			b->next = 0;
528 			if(BLEN(b) < 0)
529 				panic("qtake");
530 			return first;
531 		} else
532 			n -= i;
533 		if(BLEN(b) < 0)
534 			panic("qtake");
535 	}
536 	*l = 0;
537 	return first;
538 }
539 
540 /*
541  *  We can't let Eintr's lose data since the program
542  *  doing the read may be able to handle it.  The only
543  *  places Eintr is possible is during the read's in consume.
544  *  Therefore, we make sure we can always put back the bytes
545  *  consumed before the last ensure.
546  */
547 static Block*
548 sslbread(Chan *c, long n, ulong)
549 {
550 	Dstate * volatile s;
551 	Block *b;
552 	uchar consumed[3], *p;
553 	int toconsume;
554 	int len, pad;
555 
556 	s = dstate[CONV(c->qid)];
557 	if(s == 0)
558 		panic("sslbread");
559 	if(s->state == Sincomplete)
560 		error(Ebadusefd);
561 
562 	qlock(&s->in.q);
563 	if(waserror()){
564 		qunlock(&s->in.q);
565 		nexterror();
566 	}
567 
568 	if(s->processed == 0){
569 		/*
570 		 * Read in the whole message.  Until we've got it all,
571 		 * it stays on s->unprocessed, so that if we get Eintr,
572 		 * we'll pick up where we left off.
573 		 */
574 		ensure(s, &s->unprocessed, 3);
575 		s->unprocessed = pullupblock(s->unprocessed, 2);
576 		p = s->unprocessed->rp;
577 		if(p[0] & 0x80){
578 			len = ((p[0] & 0x7f)<<8) | p[1];
579 			ensure(s, &s->unprocessed, len);
580 			pad = 0;
581 			toconsume = 2;
582 		} else {
583 			s->unprocessed = pullupblock(s->unprocessed, 3);
584 			len = ((p[0] & 0x3f)<<8) | p[1];
585 			pad = p[2];
586 			if(pad > len){
587 				print("pad %d buf len %d\n", pad, len);
588 				error("bad pad in ssl message");
589 			}
590 			toconsume = 3;
591 		}
592 		ensure(s, &s->unprocessed, toconsume+len);
593 
594 		/* skip header */
595 		consume(&s->unprocessed, consumed, toconsume);
596 
597 		/* grab the next message and decode/decrypt it */
598 		b = qtake(&s->unprocessed, len, 0);
599 
600 		if(blocklen(b) != len)
601 			print("devssl: sslbread got wrong count %d != %d", blocklen(b), len);
602 
603 		if(waserror()){
604 			qunlock(&s->in.ctlq);
605 			if(b != nil)
606 				freeb(b);
607 			nexterror();
608 		}
609 		qlock(&s->in.ctlq);
610 		switch(s->state){
611 		case Sencrypting:
612 			if(b == nil)
613 				error("ssl message too short (encrypting)");
614 			b = decryptb(s, b);
615 			break;
616 		case Sdigesting:
617 			b = pullupblock(b, s->diglen);
618 			if(b == nil)
619 				error("ssl message too short (digesting)");
620 			checkdigestb(s, b);
621 			b->rp += s->diglen;
622 			break;
623 		case Sdigenc:
624 			b = decryptb(s, b);
625 			b = pullupblock(b, s->diglen);
626 			if(b == nil)
627 				error("ssl message too short (dig+enc)");
628 			checkdigestb(s, b);
629 			b->rp += s->diglen;
630 			len -= s->diglen;
631 			break;
632 		}
633 
634 		/* remove pad */
635 		if(pad)
636 			s->processed = qtake(&b, len - pad, 1);
637 		else
638 			s->processed = b;
639 		b = nil;
640 		s->in.mid++;
641 		qunlock(&s->in.ctlq);
642 		poperror();
643 	}
644 
645 	/* return at most what was asked for */
646 	b = qtake(&s->processed, n, 0);
647 
648 	qunlock(&s->in.q);
649 	poperror();
650 
651 	return b;
652 }
653 
654 static long
655 sslread(Chan *c, void *a, long n, vlong off)
656 {
657 	Block * volatile b;
658 	Block *nb;
659 	uchar *va;
660 	int i;
661 	char buf[128];
662 	ulong offset = off;
663 	int ft;
664 
665 	if(c->qid.type & QTDIR)
666 		return devdirread(c, a, n, 0, 0, sslgen);
667 
668 	ft = TYPE(c->qid);
669 	switch(ft) {
670 	default:
671 		error(Ebadusefd);
672 	case Qctl:
673 		ft = CONV(c->qid);
674 		sprint(buf, "%d", ft);
675 		return readstr(offset, a, n, buf);
676 	case Qdata:
677 		b = sslbread(c, n, offset);
678 		break;
679 	case Qencalgs:
680 		return readstr(offset, a, n, encalgs);
681 		break;
682 	case Qhashalgs:
683 		return readstr(offset, a, n, hashalgs);
684 		break;
685 	}
686 
687 	if(waserror()){
688 		freeblist(b);
689 		nexterror();
690 	}
691 
692 	n = 0;
693 	va = a;
694 	for(nb = b; nb; nb = nb->next){
695 		i = BLEN(nb);
696 		memmove(va+n, nb->rp, i);
697 		n += i;
698 	}
699 
700 	freeblist(b);
701 	poperror();
702 
703 	return n;
704 }
705 
706 /*
707  *  this algorithm doesn't have to be great since we're just
708  *  trying to obscure the block fill
709  */
710 static void
711 randfill(uchar *buf, int len)
712 {
713 	while(len-- > 0)
714 		*buf++ = nrand(256);
715 }
716 
717 static long
718 sslbwrite(Chan *c, Block *b, ulong)
719 {
720 	Dstate * volatile s;
721 	long rv;
722 
723 	s = dstate[CONV(c->qid)];
724 	if(s == nil)
725 		panic("sslbwrite");
726 
727 	if(s->state == Sincomplete){
728 		freeb(b);
729 		error(Ebadusefd);
730 	}
731 
732 	/* lock so split writes won't interleave */
733 	if(waserror()){
734 		qunlock(&s->out.q);
735 		nexterror();
736 	}
737 	qlock(&s->out.q);
738 
739 	rv = sslput(s, b);
740 
741 	poperror();
742 	qunlock(&s->out.q);
743 
744 	return rv;
745 }
746 
747 /*
748  *  use SSL record format, add in count, digest and/or encrypt.
749  *  the write is interruptable.  if it is interrupted, we'll
750  *  get out of sync with the far side.  not much we can do about
751  *  it since we don't know if any bytes have been written.
752  */
753 static long
754 sslput(Dstate *s, Block * volatile b)
755 {
756 	Block *nb;
757 	int h, n, m, pad, rv;
758 	uchar *p;
759 	int offset;
760 
761 	if(waserror()){
762 		if(b != nil)
763 			free(b);
764 		nexterror();
765 	}
766 
767 	rv = 0;
768 	while(b != nil){
769 		m = n = BLEN(b);
770 		h = s->diglen + 2;
771 
772 		/* trim to maximum block size */
773 		pad = 0;
774 		if(m > s->max){
775 			m = s->max;
776 		} else if(s->blocklen != 1){
777 			pad = (m + s->diglen)%s->blocklen;
778 			if(pad){
779 				if(m > s->maxpad){
780 					pad = 0;
781 					m = s->maxpad;
782 				} else {
783 					pad = s->blocklen - pad;
784 					h++;
785 				}
786 			}
787 		}
788 
789 		rv += m;
790 		if(m != n){
791 			nb = allocb(m + h + pad);
792 			memmove(nb->wp + h, b->rp, m);
793 			nb->wp += m + h;
794 			b->rp += m;
795 		} else {
796 			/* add header space */
797 			nb = padblock(b, h);
798 			b = 0;
799 		}
800 		m += s->diglen;
801 
802 		/* SSL style count */
803 		if(pad){
804 			nb = padblock(nb, -pad);
805 			randfill(nb->wp, pad);
806 			nb->wp += pad;
807 			m += pad;
808 
809 			p = nb->rp;
810 			p[0] = (m>>8);
811 			p[1] = m;
812 			p[2] = pad;
813 			offset = 3;
814 		} else {
815 			p = nb->rp;
816 			p[0] = (m>>8) | 0x80;
817 			p[1] = m;
818 			offset = 2;
819 		}
820 
821 		switch(s->state){
822 		case Sencrypting:
823 			nb = encryptb(s, nb, offset);
824 			break;
825 		case Sdigesting:
826 			nb = digestb(s, nb, offset);
827 			break;
828 		case Sdigenc:
829 			nb = digestb(s, nb, offset);
830 			nb = encryptb(s, nb, offset);
831 			break;
832 		}
833 
834 		s->out.mid++;
835 
836 		m = BLEN(nb);
837 		devtab[s->c->type]->bwrite(s->c, nb, s->c->offset);
838 		s->c->offset += m;
839 	}
840 
841 	poperror();
842 	return rv;
843 }
844 
845 static void
846 setsecret(OneWay *w, uchar *secret, int n)
847 {
848 	if(w->secret)
849 		free(w->secret);
850 
851 	w->secret = smalloc(n);
852 	memmove(w->secret, secret, n);
853 	w->slen = n;
854 }
855 
856 static void
857 initDESkey(OneWay *w)
858 {
859 	if(w->state){
860 		free(w->state);
861 		w->state = 0;
862 	}
863 
864 	w->state = smalloc(sizeof(DESstate));
865 	if(w->slen >= 16)
866 		setupDESstate(w->state, w->secret, w->secret+8);
867 	else if(w->slen >= 8)
868 		setupDESstate(w->state, w->secret, 0);
869 	else
870 		error("secret too short");
871 }
872 
873 /*
874  *  40 bit DES is the same as 56 bit DES.  However,
875  *  16 bits of the key are masked to zero.
876  */
877 static void
878 initDESkey_40(OneWay *w)
879 {
880 	uchar key[8];
881 
882 	if(w->state){
883 		free(w->state);
884 		w->state = 0;
885 	}
886 
887 	if(w->slen >= 8){
888 		memmove(key, w->secret, 8);
889 		key[0] &= 0x0f;
890 		key[2] &= 0x0f;
891 		key[4] &= 0x0f;
892 		key[6] &= 0x0f;
893 	}
894 
895 	w->state = malloc(sizeof(DESstate));
896 	if(w->slen >= 16)
897 		setupDESstate(w->state, key, w->secret+8);
898 	else if(w->slen >= 8)
899 		setupDESstate(w->state, key, 0);
900 	else
901 		error("secret too short");
902 }
903 
904 static void
905 initRC4key(OneWay *w)
906 {
907 	if(w->state){
908 		free(w->state);
909 		w->state = 0;
910 	}
911 
912 	w->state = smalloc(sizeof(RC4state));
913 	setupRC4state(w->state, w->secret, w->slen);
914 }
915 
916 /*
917  *  40 bit RC4 is the same as n-bit RC4.  However,
918  *  we ignore all but the first 40 bits of the key.
919  */
920 static void
921 initRC4key_40(OneWay *w)
922 {
923 	if(w->state){
924 		free(w->state);
925 		w->state = 0;
926 	}
927 
928 	if(w->slen > 5)
929 		w->slen = 5;
930 
931 	w->state = malloc(sizeof(RC4state));
932 	setupRC4state(w->state, w->secret, w->slen);
933 }
934 
935 /*
936  *  128 bit RC4 is the same as n-bit RC4.  However,
937  *  we ignore all but the first 128 bits of the key.
938  */
939 static void
940 initRC4key_128(OneWay *w)
941 {
942 	if(w->state){
943 		free(w->state);
944 		w->state = 0;
945 	}
946 
947 	if(w->slen > 16)
948 		w->slen = 16;
949 
950 	w->state = malloc(sizeof(RC4state));
951 	setupRC4state(w->state, w->secret, w->slen);
952 }
953 
954 
955 typedef struct Hashalg Hashalg;
956 struct Hashalg
957 {
958 	char	*name;
959 	int	diglen;
960 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);
961 };
962 
963 Hashalg hashtab[] =
964 {
965 	{ "md4", MD4dlen, md4, },
966 	{ "md5", MD5dlen, md5, },
967 	{ "sha1", SHA1dlen, sha1, },
968 	{ "sha", SHA1dlen, sha1, },
969 	{ 0 }
970 };
971 
972 static int
973 parsehashalg(char *p, Dstate *s)
974 {
975 	Hashalg *ha;
976 
977 	for(ha = hashtab; ha->name; ha++){
978 		if(strcmp(p, ha->name) == 0){
979 			s->hf = ha->hf;
980 			s->diglen = ha->diglen;
981 			s->state &= ~Sclear;
982 			s->state |= Sdigesting;
983 			return 0;
984 		}
985 	}
986 	return -1;
987 }
988 
989 typedef struct Encalg Encalg;
990 struct Encalg
991 {
992 	char	*name;
993 	int	blocklen;
994 	int	alg;
995 	void	(*keyinit)(OneWay*);
996 };
997 
998 #ifdef NOSPOOKS
999 Encalg encrypttab[] =
1000 {
1001 	{ "descbc", 8, DESCBC, initDESkey, },           /* DEPRECATED -- use des_56_cbc */
1002 	{ "desecb", 8, DESECB, initDESkey, },           /* DEPRECATED -- use des_56_ecb */
1003 	{ "des_56_cbc", 8, DESCBC, initDESkey, },
1004 	{ "des_56_ecb", 8, DESECB, initDESkey, },
1005 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1006 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1007 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1008 	{ "rc4_256", 1, RC4, initRC4key, },
1009 	{ "rc4_128", 1, RC4, initRC4key_128, },
1010 	{ "rc4_40", 1, RC4, initRC4key_40, },
1011 	{ 0 }
1012 };
1013 #else
1014 Encalg encrypttab[] =
1015 {
1016 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1017 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1018 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1019 	{ "rc4_40", 1, RC4, initRC4key_40, },
1020 	{ 0 }
1021 };
1022 #endif NOSPOOKS
1023 
1024 static int
1025 parseencryptalg(char *p, Dstate *s)
1026 {
1027 	Encalg *ea;
1028 
1029 	for(ea = encrypttab; ea->name; ea++){
1030 		if(strcmp(p, ea->name) == 0){
1031 			s->encryptalg = ea->alg;
1032 			s->blocklen = ea->blocklen;
1033 			(*ea->keyinit)(&s->in);
1034 			(*ea->keyinit)(&s->out);
1035 			s->state &= ~Sclear;
1036 			s->state |= Sencrypting;
1037 			return 0;
1038 		}
1039 	}
1040 	return -1;
1041 }
1042 
1043 static long
1044 sslwrite(Chan *c, void *a, long n, vlong)
1045 {
1046 	Dstate * volatile s;
1047 	Block * volatile b;
1048 	int m, t;
1049 	char *p, *np, *e, buf[128];
1050 	uchar *x;
1051 
1052 	s = dstate[CONV(c->qid)];
1053 	if(s == 0)
1054 		panic("sslwrite");
1055 
1056 	t = TYPE(c->qid);
1057 	if(t == Qdata){
1058 		if(s->state == Sincomplete)
1059 			error(Ebadusefd);
1060 
1061 		/* lock should a write gets split over multiple records */
1062 		if(waserror()){
1063 			qunlock(&s->out.q);
1064 			nexterror();
1065 		}
1066 		qlock(&s->out.q);
1067 
1068 		p = a;
1069 		e = p + n;
1070 		do {
1071 			m = e - p;
1072 			if(m > s->max)
1073 				m = s->max;
1074 
1075 			b = allocb(m);
1076 			if(waserror()){
1077 				freeb(b);
1078 				nexterror();
1079 			}
1080 			memmove(b->wp, p, m);
1081 			poperror();
1082 			b->wp += m;
1083 
1084 			sslput(s, b);
1085 
1086 			p += m;
1087 		} while(p < e);
1088 
1089 		poperror();
1090 		qunlock(&s->out.q);
1091 		return n;
1092 	}
1093 
1094 	/* mutex with operations using what we're about to change */
1095 	if(waserror()){
1096 		qunlock(&s->in.ctlq);
1097 		qunlock(&s->out.q);
1098 		nexterror();
1099 	}
1100 	qlock(&s->in.ctlq);
1101 	qlock(&s->out.q);
1102 
1103 	switch(t){
1104 	default:
1105 		panic("sslwrite");
1106 	case Qsecretin:
1107 		setsecret(&s->in, a, n);
1108 		goto out;
1109 	case Qsecretout:
1110 		setsecret(&s->out, a, n);
1111 		goto out;
1112 	case Qctl:
1113 		break;
1114 	}
1115 
1116 	if(n >= sizeof(buf))
1117 		error("arg too long");
1118 	strncpy(buf, a, n);
1119 	buf[n] = 0;
1120 	p = strchr(buf, '\n');
1121 	if(p)
1122 		*p = 0;
1123 	p = strchr(buf, ' ');
1124 	if(p)
1125 		*p++ = 0;
1126 
1127 	if(strcmp(buf, "fd") == 0){
1128 		s->c = buftochan(p);
1129 
1130 		/* default is clear (msg delimiters only) */
1131 		s->state = Sclear;
1132 		s->blocklen = 1;
1133 		s->diglen = 0;
1134 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1135 		s->in.mid = 0;
1136 		s->out.mid = 0;
1137 	} else if(strcmp(buf, "alg") == 0 && p != 0){
1138 		s->blocklen = 1;
1139 		s->diglen = 0;
1140 
1141 		if(s->c == 0)
1142 			error("must set fd before algorithm");
1143 
1144 		s->state = Sclear;
1145 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1146 		if(strcmp(p, "clear") == 0){
1147 			goto out;
1148 		}
1149 
1150 		if(s->in.secret && s->out.secret == 0)
1151 			setsecret(&s->out, s->in.secret, s->in.slen);
1152 		if(s->out.secret && s->in.secret == 0)
1153 			setsecret(&s->in, s->out.secret, s->out.slen);
1154 		if(s->in.secret == 0 || s->out.secret == 0)
1155 			error("algorithm but no secret");
1156 
1157 		s->hf = 0;
1158 		s->encryptalg = Noencryption;
1159 		s->blocklen = 1;
1160 
1161 		for(;;){
1162 			np = strchr(p, ' ');
1163 			if(np)
1164 				*np++ = 0;
1165 
1166 			if(parsehashalg(p, s) < 0)
1167 			if(parseencryptalg(p, s) < 0)
1168 				error("bad algorithm");
1169 
1170 			if(np == 0)
1171 				break;
1172 			p = np;
1173 		}
1174 
1175 		if(s->hf == 0 && s->encryptalg == Noencryption)
1176 			error("bad algorithm");
1177 
1178 		if(s->blocklen != 1){
1179 			s->max = (1<<15) - s->diglen - 1;
1180 			s->max -= s->max % s->blocklen;
1181 			s->maxpad = (1<<14) - s->diglen - 1;
1182 			s->maxpad -= s->maxpad % s->blocklen;
1183 		} else
1184 			s->maxpad = s->max = (1<<15) - s->diglen - 1;
1185 	} else if(strcmp(buf, "secretin") == 0 && p != 0) {
1186 		m = (strlen(p)*3)/2;
1187 		x = smalloc(m);
1188 		t = dec64(x, m, p, strlen(p));
1189 		setsecret(&s->in, x, t);
1190 		free(x);
1191 	} else if(strcmp(buf, "secretout") == 0 && p != 0) {
1192 		m = (strlen(p)*3)/2 + 1;
1193 		x = smalloc(m);
1194 		t = dec64(x, m, p, strlen(p));
1195 		setsecret(&s->out, x, t);
1196 		free(x);
1197 	} else
1198 		error(Ebadarg);
1199 
1200 out:
1201 	qunlock(&s->in.ctlq);
1202 	qunlock(&s->out.q);
1203 	poperror();
1204 	return n;
1205 }
1206 
1207 static void
1208 sslinit(void)
1209 {
1210 	struct Encalg *e;
1211 	struct Hashalg *h;
1212 	int n;
1213 	char *cp;
1214 
1215 	n = 1;
1216 	for(e = encrypttab; e->name != nil; e++)
1217 		n += strlen(e->name) + 1;
1218 	cp = encalgs = smalloc(n);
1219 	for(e = encrypttab;;){
1220 		strcpy(cp, e->name);
1221 		cp += strlen(e->name);
1222 		e++;
1223 		if(e->name == nil)
1224 			break;
1225 		*cp++ = ' ';
1226 	}
1227 	*cp = 0;
1228 
1229 	n = 1;
1230 	for(h = hashtab; h->name != nil; h++)
1231 		n += strlen(h->name) + 1;
1232 	cp = hashalgs = smalloc(n);
1233 	for(h = hashtab;;){
1234 		strcpy(cp, h->name);
1235 		cp += strlen(h->name);
1236 		h++;
1237 		if(h->name == nil)
1238 			break;
1239 		*cp++ = ' ';
1240 	}
1241 	*cp = 0;
1242 }
1243 
1244 Dev ssldevtab = {
1245 	'D',
1246 	"ssl",
1247 
1248 	devreset,
1249 	sslinit,
1250 	devshutdown,
1251 	sslattach,
1252 	sslwalk,
1253 	sslstat,
1254 	sslopen,
1255 	devcreate,
1256 	sslclose,
1257 	sslread,
1258 	sslbread,
1259 	sslwrite,
1260 	sslbwrite,
1261 	devremove,
1262 	sslwstat,
1263 };
1264 
1265 static Block*
1266 encryptb(Dstate *s, Block *b, int offset)
1267 {
1268 	uchar *p, *ep, *p2, *ip, *eip;
1269 	DESstate *ds;
1270 
1271 	switch(s->encryptalg){
1272 	case DESECB:
1273 		ds = s->out.state;
1274 		ep = b->rp + BLEN(b);
1275 		for(p = b->rp + offset; p < ep; p += 8)
1276 			block_cipher(ds->expanded, p, 0);
1277 		break;
1278 	case DESCBC:
1279 		ds = s->out.state;
1280 		ep = b->rp + BLEN(b);
1281 		for(p = b->rp + offset; p < ep; p += 8){
1282 			p2 = p;
1283 			ip = ds->ivec;
1284 			for(eip = ip+8; ip < eip; )
1285 				*p2++ ^= *ip++;
1286 			block_cipher(ds->expanded, p, 0);
1287 			memmove(ds->ivec, p, 8);
1288 		}
1289 		break;
1290 	case RC4:
1291 		rc4(s->out.state, b->rp + offset, BLEN(b) - offset);
1292 		break;
1293 	}
1294 	return b;
1295 }
1296 
1297 static Block*
1298 decryptb(Dstate *s, Block *bin)
1299 {
1300 	Block *b, **l;
1301 	uchar *p, *ep, *tp, *ip, *eip;
1302 	DESstate *ds;
1303 	uchar tmp[8];
1304 	int i;
1305 
1306 	l = &bin;
1307 	for(b = bin; b; b = b->next){
1308 		/* make sure we have a multiple of s->blocklen */
1309 		if(s->blocklen > 1){
1310 			i = BLEN(b);
1311 			if(i % s->blocklen){
1312 				*l = b = pullupblock(b, i + s->blocklen - (i%s->blocklen));
1313 				if(b == 0)
1314 					error("ssl encrypted message too short");
1315 			}
1316 		}
1317 		l = &b->next;
1318 
1319 		/* decrypt */
1320 		switch(s->encryptalg){
1321 		case DESECB:
1322 			ds = s->in.state;
1323 			ep = b->rp + BLEN(b);
1324 			for(p = b->rp; p < ep; p += 8)
1325 				block_cipher(ds->expanded, p, 1);
1326 			break;
1327 		case DESCBC:
1328 			ds = s->in.state;
1329 			ep = b->rp + BLEN(b);
1330 			for(p = b->rp; p < ep;){
1331 				memmove(tmp, p, 8);
1332 				block_cipher(ds->expanded, p, 1);
1333 				tp = tmp;
1334 				ip = ds->ivec;
1335 				for(eip = ip+8; ip < eip; ){
1336 					*p++ ^= *ip;
1337 					*ip++ = *tp++;
1338 				}
1339 			}
1340 			break;
1341 		case RC4:
1342 			rc4(s->in.state, b->rp, BLEN(b));
1343 			break;
1344 		}
1345 	}
1346 	return bin;
1347 }
1348 
1349 static Block*
1350 digestb(Dstate *s, Block *b, int offset)
1351 {
1352 	uchar *p;
1353 	DigestState ss;
1354 	uchar msgid[4];
1355 	ulong n, h;
1356 	OneWay *w;
1357 
1358 	w = &s->out;
1359 
1360 	memset(&ss, 0, sizeof(ss));
1361 	h = s->diglen + offset;
1362 	n = BLEN(b) - h;
1363 
1364 	/* hash secret + message */
1365 	(*s->hf)(w->secret, w->slen, 0, &ss);
1366 	(*s->hf)(b->rp + h, n, 0, &ss);
1367 
1368 	/* hash message id */
1369 	p = msgid;
1370 	n = w->mid;
1371 	*p++ = n>>24;
1372 	*p++ = n>>16;
1373 	*p++ = n>>8;
1374 	*p = n;
1375 	(*s->hf)(msgid, 4, b->rp + offset, &ss);
1376 
1377 	return b;
1378 }
1379 
1380 static void
1381 checkdigestb(Dstate *s, Block *bin)
1382 {
1383 	uchar *p;
1384 	DigestState ss;
1385 	uchar msgid[4];
1386 	int n, h;
1387 	OneWay *w;
1388 	uchar digest[128];
1389 	Block *b;
1390 
1391 	w = &s->in;
1392 
1393 	memset(&ss, 0, sizeof(ss));
1394 
1395 	/* hash secret */
1396 	(*s->hf)(w->secret, w->slen, 0, &ss);
1397 
1398 	/* hash message */
1399 	h = s->diglen;
1400 	for(b = bin; b; b = b->next){
1401 		n = BLEN(b) - h;
1402 		if(n < 0)
1403 			panic("checkdigestb");
1404 		(*s->hf)(b->rp + h, n, 0, &ss);
1405 		h = 0;
1406 	}
1407 
1408 	/* hash message id */
1409 	p = msgid;
1410 	n = w->mid;
1411 	*p++ = n>>24;
1412 	*p++ = n>>16;
1413 	*p++ = n>>8;
1414 	*p = n;
1415 	(*s->hf)(msgid, 4, digest, &ss);
1416 
1417 	if(memcmp(digest, bin->rp, s->diglen) != 0)
1418 		error("bad digest");
1419 }
1420 
1421 /* get channel associated with an fd */
1422 static Chan*
1423 buftochan(char *p)
1424 {
1425 	Chan *c;
1426 	int fd;
1427 
1428 	if(p == 0)
1429 		error(Ebadarg);
1430 	fd = strtoul(p, 0, 0);
1431 	if(fd < 0)
1432 		error(Ebadarg);
1433 	c = fdtochan(fd, -1, 0, 1);	/* error check and inc ref */
1434 	if(devtab[c->type] == &ssldevtab){
1435 		cclose(c);
1436 		error("cannot ssl encrypt devssl files");
1437 	}
1438 	return c;
1439 }
1440 
1441 /* hand up a digest connection */
1442 static void
1443 sslhangup(Dstate *s)
1444 {
1445 	Block *b;
1446 
1447 	qlock(&s->in.q);
1448 	for(b = s->processed; b; b = s->processed){
1449 		s->processed = b->next;
1450 		freeb(b);
1451 	}
1452 	if(s->unprocessed){
1453 		freeb(s->unprocessed);
1454 		s->unprocessed = 0;
1455 	}
1456 	s->state = Sincomplete;
1457 	qunlock(&s->in.q);
1458 }
1459 
1460 static Dstate*
1461 dsclone(Chan *ch)
1462 {
1463 	int i;
1464 	Dstate *ret;
1465 
1466 	if(waserror()) {
1467 		unlock(&dslock);
1468 		nexterror();
1469 	}
1470 	lock(&dslock);
1471 	ret = nil;
1472 	for(i=0; i<Maxdstate; i++){
1473 		if(dstate[i] == nil){
1474 			dsnew(ch, &dstate[i]);
1475 			ret = dstate[i];
1476 			break;
1477 		}
1478 	}
1479 	unlock(&dslock);
1480 	poperror();
1481 	return ret;
1482 }
1483 
1484 static void
1485 dsnew(Chan *ch, Dstate **pp)
1486 {
1487 	Dstate *s;
1488 	int t;
1489 
1490 	*pp = s = malloc(sizeof(*s));
1491 	if(!s)
1492 		error(Enomem);
1493 	if(pp - dstate >= dshiwat)
1494 		dshiwat++;
1495 	memset(s, 0, sizeof(*s));
1496 	s->state = Sincomplete;
1497 	s->ref = 1;
1498 	kstrdup(&s->user, up->user);
1499 	s->perm = 0660;
1500 	t = TYPE(ch->qid);
1501 	if(t == Qclonus)
1502 		t = Qctl;
1503 	ch->qid.path = QID(pp - dstate, t);
1504 	ch->qid.vers = 0;
1505 }
1506