xref: /plan9/sys/src/cmd/unix/drawterm/kern/devssl.c (revision ec59a3ddbfceee0efe34584c2c9981a5e5ff1ec4)
1 /*
2  *  devssl - secure sockets layer
3  */
4 #include	"u.h"
5 #include	"lib.h"
6 #include	"dat.h"
7 #include	"fns.h"
8 #include	"error.h"
9 
10 #include	"libsec.h"
11 
12 #define NOSPOOKS 1
13 
14 typedef struct OneWay OneWay;
15 struct OneWay
16 {
17 	QLock	q;
18 	QLock	ctlq;
19 
20 	void	*state;		/* encryption state */
21 	int	slen;		/* hash data length */
22 	uchar	*secret;	/* secret */
23 	ulong	mid;		/* message id */
24 };
25 
enum
{
	/* connection states; Sencrypting and Sdigesting are combinable bits */
	Sincomplete=	0,	/* no fd attached yet, or hung up */
	Sclear=		1,	/* records framed but not protected */
	Sencrypting=	2,
	Sdigesting=	4,
	Sdigenc=	Sencrypting|Sdigesting,

	/* encryption algorithm identifiers (Dstate.encryptalg) */
	Noencryption=	0,
	DESCBC=		1,
	DESECB=		2,
	RC4=		3,
};
41 
42 typedef struct Dstate Dstate;
43 struct Dstate
44 {
45 	Chan	*c;		/* io channel */
46 	uchar	state;		/* state of connection */
47 	int	ref;		/* serialized by dslock for atomic destroy */
48 
49 	uchar	encryptalg;	/* encryption algorithm */
50 	ushort	blocklen;	/* blocking length */
51 
52 	ushort	diglen;		/* length of digest */
53 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);	/* hash func */
54 
55 	/* for SSL format */
56 	int	max;		/* maximum unpadded data per msg */
57 	int	maxpad;		/* maximum padded data per msg */
58 
59 	/* input side */
60 	OneWay	in;
61 	Block	*processed;
62 	Block	*unprocessed;
63 
64 	/* output side */
65 	OneWay	out;
66 
67 	/* protections */
68 	char	*user;
69 	int	perm;
70 };
71 
72 enum
73 {
74 	Maxdmsg=	1<<16,
75 	Maxdstate=	128,	/* must be a power of 2 */
76 };
77 
78 Lock	dslock;
79 int	dshiwat;
80 char	*dsname[Maxdstate];
81 Dstate	*dstate[Maxdstate];
82 char	*encalgs;
83 char	*hashalgs;
84 
85 enum{
86 	Qtopdir		= 1,	/* top level directory */
87 	Qprotodir,
88 	Qclonus,
89 	Qconvdir,		/* directory for a conversation */
90 	Qdata,
91 	Qctl,
92 	Qsecretin,
93 	Qsecretout,
94 	Qencalgs,
95 	Qhashalgs,
96 };
97 
98 #define TYPE(x) 	((x).path & 0xf)
99 #define CONV(x) 	(((x).path >> 5)&(Maxdstate-1))
100 #define QID(c, y) 	(((c)<<5) | (y))
101 
102 static void	ensure(Dstate*, Block**, int);
103 static void	consume(Block**, uchar*, int);
104 static void	setsecret(OneWay*, uchar*, int);
105 static Block*	encryptb(Dstate*, Block*, int);
106 static Block*	decryptb(Dstate*, Block*);
107 static Block*	digestb(Dstate*, Block*, int);
108 static void	checkdigestb(Dstate*, Block*);
109 static Chan*	buftochan(char*);
110 static void	sslhangup(Dstate*);
111 static Dstate*	dsclone(Chan *c);
112 static void	dsnew(Chan *c, Dstate **);
113 static long	sslput(Dstate *s, Block * volatile b);
114 
115 char *sslnames[] = {
116 	/* unused */ 0,
117 	/* topdir */ 0,
118 	/* protodir */ 0,
119 	"clone",
120 	/* convdir */ 0,
121 	"data",
122 	"ctl",
123 	"secretin",
124 	"secretout",
125 	"encalgs",
126 	"hashalgs",
127 };
128 
129 static int
sslgen(Chan * c,char * n,Dirtab * d,int nd,int s,Dir * dp)130 sslgen(Chan *c, char *n, Dirtab *d, int nd, int s, Dir *dp)
131 {
132 	Qid q;
133 	Dstate *ds;
134 	char name[16], *p, *nm;
135 	int ft;
136 
137 	USED(n);
138 	USED(nd);
139 	USED(d);
140 
141 	q.type = QTFILE;
142 	q.vers = 0;
143 
144 	ft = TYPE(c->qid);
145 	switch(ft) {
146 	case Qtopdir:
147 		if(s == DEVDOTDOT){
148 			q.path = QID(0, Qtopdir);
149 			q.type = QTDIR;
150 			devdir(c, q, "#D", 0, eve, 0555, dp);
151 			return 1;
152 		}
153 		if(s > 0)
154 			return -1;
155 		q.path = QID(0, Qprotodir);
156 		q.type = QTDIR;
157 		devdir(c, q, "ssl", 0, eve, 0555, dp);
158 		return 1;
159 	case Qprotodir:
160 		if(s == DEVDOTDOT){
161 			q.path = QID(0, Qtopdir);
162 			q.type = QTDIR;
163 			devdir(c, q, ".", 0, eve, 0555, dp);
164 			return 1;
165 		}
166 		if(s < dshiwat) {
167 			q.path = QID(s, Qconvdir);
168 			q.type = QTDIR;
169 			ds = dstate[s];
170 			if(ds != 0)
171 				nm = ds->user;
172 			else
173 				nm = eve;
174 			if(dsname[s] == nil){
175 				sprint(name, "%d", s);
176 				kstrdup(&dsname[s], name);
177 			}
178 			devdir(c, q, dsname[s], 0, nm, 0555, dp);
179 			return 1;
180 		}
181 		if(s > dshiwat)
182 			return -1;
183 		q.path = QID(0, Qclonus);
184 		devdir(c, q, "clone", 0, eve, 0555, dp);
185 		return 1;
186 	case Qconvdir:
187 		if(s == DEVDOTDOT){
188 			q.path = QID(0, Qprotodir);
189 			q.type = QTDIR;
190 			devdir(c, q, "ssl", 0, eve, 0555, dp);
191 			return 1;
192 		}
193 		ds = dstate[CONV(c->qid)];
194 		if(ds != 0)
195 			nm = ds->user;
196 		else
197 			nm = eve;
198 		switch(s) {
199 		default:
200 			return -1;
201 		case 0:
202 			q.path = QID(CONV(c->qid), Qctl);
203 			p = "ctl";
204 			break;
205 		case 1:
206 			q.path = QID(CONV(c->qid), Qdata);
207 			p = "data";
208 			break;
209 		case 2:
210 			q.path = QID(CONV(c->qid), Qsecretin);
211 			p = "secretin";
212 			break;
213 		case 3:
214 			q.path = QID(CONV(c->qid), Qsecretout);
215 			p = "secretout";
216 			break;
217 		case 4:
218 			q.path = QID(CONV(c->qid), Qencalgs);
219 			p = "encalgs";
220 			break;
221 		case 5:
222 			q.path = QID(CONV(c->qid), Qhashalgs);
223 			p = "hashalgs";
224 			break;
225 		}
226 		devdir(c, q, p, 0, nm, 0660, dp);
227 		return 1;
228 	case Qclonus:
229 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, eve, 0555, dp);
230 		return 1;
231 	default:
232 		ds = dstate[CONV(c->qid)];
233 		if(ds != 0)
234 			nm = ds->user;
235 		else
236 			nm = eve;
237 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, nm, 0660, dp);
238 		return 1;
239 	}
240 	return -1;
241 }
242 
243 static Chan*
sslattach(char * spec)244 sslattach(char *spec)
245 {
246 	Chan *c;
247 
248 	c = devattach('D', spec);
249 	c->qid.path = QID(0, Qtopdir);
250 	c->qid.vers = 0;
251 	c->qid.type = QTDIR;
252 	return c;
253 }
254 
255 static Walkqid*
sslwalk(Chan * c,Chan * nc,char ** name,int nname)256 sslwalk(Chan *c, Chan *nc, char **name, int nname)
257 {
258 	return devwalk(c, nc, name, nname, nil, 0, sslgen);
259 }
260 
261 static int
sslstat(Chan * c,uchar * db,int n)262 sslstat(Chan *c, uchar *db, int n)
263 {
264 	return devstat(c, db, n, nil, 0, sslgen);
265 }
266 
267 static Chan*
sslopen(Chan * c,int omode)268 sslopen(Chan *c, int omode)
269 {
270 	Dstate *s, **pp;
271 	int perm;
272 	int ft;
273 
274 	perm = 0;
275 	omode &= 3;
276 	switch(omode) {
277 	case OREAD:
278 		perm = 4;
279 		break;
280 	case OWRITE:
281 		perm = 2;
282 		break;
283 	case ORDWR:
284 		perm = 6;
285 		break;
286 	}
287 
288 	ft = TYPE(c->qid);
289 	switch(ft) {
290 	default:
291 		panic("sslopen");
292 	case Qtopdir:
293 	case Qprotodir:
294 	case Qconvdir:
295 		if(omode != OREAD)
296 			error(Eperm);
297 		break;
298 	case Qclonus:
299 		s = dsclone(c);
300 		if(s == 0)
301 			error(Enodev);
302 		break;
303 	case Qctl:
304 	case Qdata:
305 	case Qsecretin:
306 	case Qsecretout:
307 		if(waserror()) {
308 			unlock(&dslock);
309 			nexterror();
310 		}
311 		lock(&dslock);
312 		pp = &dstate[CONV(c->qid)];
313 		s = *pp;
314 		if(s == 0)
315 			dsnew(c, pp);
316 		else {
317 			if((perm & (s->perm>>6)) != perm
318 			   && (strcmp(up->user, s->user) != 0
319 			     || (perm & s->perm) != perm))
320 				error(Eperm);
321 
322 			s->ref++;
323 		}
324 		unlock(&dslock);
325 		poperror();
326 		break;
327 	case Qencalgs:
328 	case Qhashalgs:
329 		if(omode != OREAD)
330 			error(Eperm);
331 		break;
332 	}
333 	c->mode = openmode(omode);
334 	c->flag |= COPEN;
335 	c->offset = 0;
336 	return c;
337 }
338 
339 static int
sslwstat(Chan * c,uchar * db,int n)340 sslwstat(Chan *c, uchar *db, int n)
341 {
342 	Dir *dir;
343 	Dstate *s;
344 	int m;
345 
346 	s = dstate[CONV(c->qid)];
347 	if(s == 0)
348 		error(Ebadusefd);
349 	if(strcmp(s->user, up->user) != 0)
350 		error(Eperm);
351 
352 	dir = smalloc(sizeof(Dir)+n);
353 	m = convM2D(db, n, &dir[0], (char*)&dir[1]);
354 	if(m == 0){
355 		free(dir);
356 		error(Eshortstat);
357 	}
358 
359 	if(!emptystr(dir->uid))
360 		kstrdup(&s->user, dir->uid);
361 	if(dir->mode != ~0)
362 		s->perm = dir->mode;
363 
364 	free(dir);
365 	return m;
366 }
367 
368 static void
sslclose(Chan * c)369 sslclose(Chan *c)
370 {
371 	Dstate *s;
372 	int ft;
373 
374 	ft = TYPE(c->qid);
375 	switch(ft) {
376 	case Qctl:
377 	case Qdata:
378 	case Qsecretin:
379 	case Qsecretout:
380 		if((c->flag & COPEN) == 0)
381 			break;
382 
383 		s = dstate[CONV(c->qid)];
384 		if(s == 0)
385 			break;
386 
387 		lock(&dslock);
388 		if(--s->ref > 0) {
389 			unlock(&dslock);
390 			break;
391 		}
392 		dstate[CONV(c->qid)] = 0;
393 		unlock(&dslock);
394 
395 		if(s->user != nil)
396 			free(s->user);
397 		sslhangup(s);
398 		if(s->c)
399 			cclose(s->c);
400 		if(s->in.secret)
401 			free(s->in.secret);
402 		if(s->out.secret)
403 			free(s->out.secret);
404 		if(s->in.state)
405 			free(s->in.state);
406 		if(s->out.state)
407 			free(s->out.state);
408 		free(s);
409 
410 	}
411 }
412 
413 /*
414  *  make sure we have at least 'n' bytes in list 'l'
415  */
416 static void
ensure(Dstate * s,Block ** l,int n)417 ensure(Dstate *s, Block **l, int n)
418 {
419 	int sofar, i;
420 	Block *b, *bl;
421 
422 	sofar = 0;
423 	for(b = *l; b; b = b->next){
424 		sofar += BLEN(b);
425 		if(sofar >= n)
426 			return;
427 		l = &b->next;
428 	}
429 
430 	while(sofar < n){
431 		bl = devtab[s->c->type]->bread(s->c, Maxdmsg, 0);
432 		if(bl == 0)
433 			nexterror();
434 		*l = bl;
435 		i = 0;
436 		for(b = bl; b; b = b->next){
437 			i += BLEN(b);
438 			l = &b->next;
439 		}
440 		if(i == 0)
441 			error(Ehungup);
442 		sofar += i;
443 	}
444 }
445 
446 /*
447  *  copy 'n' bytes from 'l' into 'p' and free
448  *  the bytes in 'l'
449  */
450 static void
consume(Block ** l,uchar * p,int n)451 consume(Block **l, uchar *p, int n)
452 {
453 	Block *b;
454 	int i;
455 
456 	for(; *l && n > 0; n -= i){
457 		b = *l;
458 		i = BLEN(b);
459 		if(i > n)
460 			i = n;
461 		memmove(p, b->rp, i);
462 		b->rp += i;
463 		p += i;
464 		if(BLEN(b) < 0)
465 			panic("consume");
466 		if(BLEN(b))
467 			break;
468 		*l = b->next;
469 		freeb(b);
470 	}
471 }
472 
473 /*
474  *  give back n bytes
475 static void
476 regurgitate(Dstate *s, uchar *p, int n)
477 {
478 	Block *b;
479 
480 	if(n <= 0)
481 		return;
482 	b = s->unprocessed;
483 	if(s->unprocessed == nil || b->rp - b->base < n) {
484 		b = allocb(n);
485 		memmove(b->wp, p, n);
486 		b->wp += n;
487 		b->next = s->unprocessed;
488 		s->unprocessed = b;
489 	} else {
490 		b->rp -= n;
491 		memmove(b->rp, p, n);
492 	}
493 }
494  */
495 
496 /*
497  *  remove at most n bytes from the queue, if discard is set
498  *  dump the remainder
499  */
500 static Block*
qtake(Block ** l,int n,int discard)501 qtake(Block **l, int n, int discard)
502 {
503 	Block *nb, *b, *first;
504 	int i;
505 
506 	first = *l;
507 	for(b = first; b; b = b->next){
508 		i = BLEN(b);
509 		if(i == n){
510 			if(discard){
511 				freeblist(b->next);
512 				*l = 0;
513 			} else
514 				*l = b->next;
515 			b->next = 0;
516 			return first;
517 		} else if(i > n){
518 			i -= n;
519 			if(discard){
520 				freeblist(b->next);
521 				b->wp -= i;
522 				*l = 0;
523 			} else {
524 				nb = allocb(i);
525 				memmove(nb->wp, b->rp+n, i);
526 				nb->wp += i;
527 				b->wp -= i;
528 				nb->next = b->next;
529 				*l = nb;
530 			}
531 			b->next = 0;
532 			if(BLEN(b) < 0)
533 				panic("qtake");
534 			return first;
535 		} else
536 			n -= i;
537 		if(BLEN(b) < 0)
538 			panic("qtake");
539 	}
540 	*l = 0;
541 	return first;
542 }
543 
544 /*
545  *  We can't let Eintr's lose data since the program
546  *  doing the read may be able to handle it.  The only
547  *  places Eintr is possible is during the read's in consume.
548  *  Therefore, we make sure we can always put back the bytes
549  *  consumed before the last ensure.
550  */
551 static Block*
sslbread(Chan * c,long n,ulong o)552 sslbread(Chan *c, long n, ulong o)
553 {
554 	Dstate * volatile s;
555 	Block *b;
556 	uchar consumed[3], *p;
557 	int toconsume;
558 	int len, pad;
559 
560 	USED(o);
561 	s = dstate[CONV(c->qid)];
562 	if(s == 0)
563 		panic("sslbread");
564 	if(s->state == Sincomplete)
565 		error(Ebadusefd);
566 
567 	qlock(&s->in.q);
568 	if(waserror()){
569 		qunlock(&s->in.q);
570 		nexterror();
571 	}
572 
573 	if(s->processed == 0){
574 		/*
575 		 * Read in the whole message.  Until we've got it all,
576 		 * it stays on s->unprocessed, so that if we get Eintr,
577 		 * we'll pick up where we left off.
578 		 */
579 		ensure(s, &s->unprocessed, 3);
580 		s->unprocessed = pullupblock(s->unprocessed, 2);
581 		p = s->unprocessed->rp;
582 		if(p[0] & 0x80){
583 			len = ((p[0] & 0x7f)<<8) | p[1];
584 			ensure(s, &s->unprocessed, len);
585 			pad = 0;
586 			toconsume = 2;
587 		} else {
588 			s->unprocessed = pullupblock(s->unprocessed, 3);
589 			len = ((p[0] & 0x3f)<<8) | p[1];
590 			pad = p[2];
591 			if(pad > len){
592 				print("pad %d buf len %d\n", pad, len);
593 				error("bad pad in ssl message");
594 			}
595 			toconsume = 3;
596 		}
597 		ensure(s, &s->unprocessed, toconsume+len);
598 
599 		/* skip header */
600 		consume(&s->unprocessed, consumed, toconsume);
601 
602 		/* grab the next message and decode/decrypt it */
603 		b = qtake(&s->unprocessed, len, 0);
604 
605 		if(blocklen(b) != len)
606 			print("devssl: sslbread got wrong count %d != %d", blocklen(b), len);
607 
608 		if(waserror()){
609 			qunlock(&s->in.ctlq);
610 			if(b != nil)
611 				freeb(b);
612 			nexterror();
613 		}
614 		qlock(&s->in.ctlq);
615 		switch(s->state){
616 		case Sencrypting:
617 			if(b == nil)
618 				error("ssl message too short (encrypting)");
619 			b = decryptb(s, b);
620 			break;
621 		case Sdigesting:
622 			b = pullupblock(b, s->diglen);
623 			if(b == nil)
624 				error("ssl message too short (digesting)");
625 			checkdigestb(s, b);
626 			pullblock(&b, s->diglen);
627 			len -= s->diglen;
628 			break;
629 		case Sdigenc:
630 			b = decryptb(s, b);
631 			b = pullupblock(b, s->diglen);
632 			if(b == nil)
633 				error("ssl message too short (dig+enc)");
634 			checkdigestb(s, b);
635 			pullblock(&b, s->diglen);
636 			len -= s->diglen;
637 			break;
638 		}
639 
640 		/* remove pad */
641 		if(pad)
642 			s->processed = qtake(&b, len - pad, 1);
643 		else
644 			s->processed = b;
645 		b = nil;
646 		s->in.mid++;
647 		qunlock(&s->in.ctlq);
648 		poperror();
649 	}
650 
651 	/* return at most what was asked for */
652 	b = qtake(&s->processed, n, 0);
653 
654 	qunlock(&s->in.q);
655 	poperror();
656 
657 	return b;
658 }
659 
660 static long
sslread(Chan * c,void * a,long n,vlong off)661 sslread(Chan *c, void *a, long n, vlong off)
662 {
663 	Block * volatile b;
664 	Block *nb;
665 	uchar *va;
666 	int i;
667 	char buf[128];
668 	ulong offset = off;
669 	int ft;
670 
671 	if(c->qid.type & QTDIR)
672 		return devdirread(c, a, n, 0, 0, sslgen);
673 
674 	ft = TYPE(c->qid);
675 	switch(ft) {
676 	default:
677 		error(Ebadusefd);
678 	case Qctl:
679 		ft = CONV(c->qid);
680 		sprint(buf, "%d", ft);
681 		return readstr(offset, a, n, buf);
682 	case Qdata:
683 		b = sslbread(c, n, offset);
684 		break;
685 	case Qencalgs:
686 		return readstr(offset, a, n, encalgs);
687 		break;
688 	case Qhashalgs:
689 		return readstr(offset, a, n, hashalgs);
690 		break;
691 	}
692 
693 	if(waserror()){
694 		freeblist(b);
695 		nexterror();
696 	}
697 
698 	n = 0;
699 	va = a;
700 	for(nb = b; nb; nb = nb->next){
701 		i = BLEN(nb);
702 		memmove(va+n, nb->rp, i);
703 		n += i;
704 	}
705 
706 	freeblist(b);
707 	poperror();
708 
709 	return n;
710 }
711 
712 /*
713  *  this algorithm doesn't have to be great since we're just
714  *  trying to obscure the block fill
715  */
716 static void
randfill(uchar * buf,int len)717 randfill(uchar *buf, int len)
718 {
719 	while(len-- > 0)
720 		*buf++ = fastrand();
721 }
722 
723 static long
sslbwrite(Chan * c,Block * b,ulong o)724 sslbwrite(Chan *c, Block *b, ulong o)
725 {
726 	Dstate * volatile s;
727 	long rv;
728 
729 	USED(o);
730 	s = dstate[CONV(c->qid)];
731 	if(s == nil)
732 		panic("sslbwrite");
733 
734 	if(s->state == Sincomplete){
735 		freeb(b);
736 		error(Ebadusefd);
737 	}
738 
739 	/* lock so split writes won't interleave */
740 	if(waserror()){
741 		qunlock(&s->out.q);
742 		nexterror();
743 	}
744 	qlock(&s->out.q);
745 
746 	rv = sslput(s, b);
747 
748 	poperror();
749 	qunlock(&s->out.q);
750 
751 	return rv;
752 }
753 
754 /*
755  *  use SSL record format, add in count, digest and/or encrypt.
756  *  the write is interruptable.  if it is interrupted, we'll
757  *  get out of sync with the far side.  not much we can do about
758  *  it since we don't know if any bytes have been written.
759  */
760 static long
sslput(Dstate * s,Block * volatile b)761 sslput(Dstate *s, Block * volatile b)
762 {
763 	Block *nb;
764 	int h, n, m, pad, rv;
765 	uchar *p;
766 	int offset;
767 
768 	if(waserror()){
769 iprint("error: %s\n", up->errstr);
770 		if(b != nil)
771 			free(b);
772 		nexterror();
773 	}
774 
775 	rv = 0;
776 	while(b != nil){
777 		m = n = BLEN(b);
778 		h = s->diglen + 2;
779 
780 		/* trim to maximum block size */
781 		pad = 0;
782 		if(m > s->max){
783 			m = s->max;
784 		} else if(s->blocklen != 1){
785 			pad = (m + s->diglen)%s->blocklen;
786 			if(pad){
787 				if(m > s->maxpad){
788 					pad = 0;
789 					m = s->maxpad;
790 				} else {
791 					pad = s->blocklen - pad;
792 					h++;
793 				}
794 			}
795 		}
796 
797 		rv += m;
798 		if(m != n){
799 			nb = allocb(m + h + pad);
800 			memmove(nb->wp + h, b->rp, m);
801 			nb->wp += m + h;
802 			b->rp += m;
803 		} else {
804 			/* add header space */
805 			nb = padblock(b, h);
806 			b = 0;
807 		}
808 		m += s->diglen;
809 
810 		/* SSL style count */
811 		if(pad){
812 			nb = padblock(nb, -pad);
813 			randfill(nb->wp, pad);
814 			nb->wp += pad;
815 			m += pad;
816 
817 			p = nb->rp;
818 			p[0] = (m>>8);
819 			p[1] = m;
820 			p[2] = pad;
821 			offset = 3;
822 		} else {
823 			p = nb->rp;
824 			p[0] = (m>>8) | 0x80;
825 			p[1] = m;
826 			offset = 2;
827 		}
828 
829 		switch(s->state){
830 		case Sencrypting:
831 			nb = encryptb(s, nb, offset);
832 			break;
833 		case Sdigesting:
834 			nb = digestb(s, nb, offset);
835 			break;
836 		case Sdigenc:
837 			nb = digestb(s, nb, offset);
838 			nb = encryptb(s, nb, offset);
839 			break;
840 		}
841 
842 		s->out.mid++;
843 
844 		m = BLEN(nb);
845 		devtab[s->c->type]->bwrite(s->c, nb, s->c->offset);
846 		s->c->offset += m;
847 	}
848 
849 	poperror();
850 	return rv;
851 }
852 
853 static void
setsecret(OneWay * w,uchar * secret,int n)854 setsecret(OneWay *w, uchar *secret, int n)
855 {
856 	if(w->secret)
857 		free(w->secret);
858 
859 	w->secret = smalloc(n);
860 	memmove(w->secret, secret, n);
861 	w->slen = n;
862 }
863 
864 static void
initDESkey(OneWay * w)865 initDESkey(OneWay *w)
866 {
867 	if(w->state){
868 		free(w->state);
869 		w->state = 0;
870 	}
871 
872 	w->state = smalloc(sizeof(DESstate));
873 	if(w->slen >= 16)
874 		setupDESstate(w->state, w->secret, w->secret+8);
875 	else if(w->slen >= 8)
876 		setupDESstate(w->state, w->secret, 0);
877 	else
878 		error("secret too short");
879 }
880 
881 /*
882  *  40 bit DES is the same as 56 bit DES.  However,
883  *  16 bits of the key are masked to zero.
884  */
885 static void
initDESkey_40(OneWay * w)886 initDESkey_40(OneWay *w)
887 {
888 	uchar key[8];
889 
890 	if(w->state){
891 		free(w->state);
892 		w->state = 0;
893 	}
894 
895 	if(w->slen >= 8){
896 		memmove(key, w->secret, 8);
897 		key[0] &= 0x0f;
898 		key[2] &= 0x0f;
899 		key[4] &= 0x0f;
900 		key[6] &= 0x0f;
901 	}
902 
903 	w->state = malloc(sizeof(DESstate));
904 	if(w->slen >= 16)
905 		setupDESstate(w->state, key, w->secret+8);
906 	else if(w->slen >= 8)
907 		setupDESstate(w->state, key, 0);
908 	else
909 		error("secret too short");
910 }
911 
912 static void
initRC4key(OneWay * w)913 initRC4key(OneWay *w)
914 {
915 	if(w->state){
916 		free(w->state);
917 		w->state = 0;
918 	}
919 
920 	w->state = smalloc(sizeof(RC4state));
921 	setupRC4state(w->state, w->secret, w->slen);
922 }
923 
924 /*
925  *  40 bit RC4 is the same as n-bit RC4.  However,
926  *  we ignore all but the first 40 bits of the key.
927  */
928 static void
initRC4key_40(OneWay * w)929 initRC4key_40(OneWay *w)
930 {
931 	if(w->state){
932 		free(w->state);
933 		w->state = 0;
934 	}
935 
936 	if(w->slen > 5)
937 		w->slen = 5;
938 
939 	w->state = malloc(sizeof(RC4state));
940 	setupRC4state(w->state, w->secret, w->slen);
941 }
942 
943 /*
944  *  128 bit RC4 is the same as n-bit RC4.  However,
945  *  we ignore all but the first 128 bits of the key.
946  */
947 static void
initRC4key_128(OneWay * w)948 initRC4key_128(OneWay *w)
949 {
950 	if(w->state){
951 		free(w->state);
952 		w->state = 0;
953 	}
954 
955 	if(w->slen > 16)
956 		w->slen = 16;
957 
958 	w->state = malloc(sizeof(RC4state));
959 	setupRC4state(w->state, w->secret, w->slen);
960 }
961 
962 
963 typedef struct Hashalg Hashalg;
964 struct Hashalg
965 {
966 	char	*name;
967 	int	diglen;
968 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);
969 };
970 
971 Hashalg hashtab[] =
972 {
973 	{ "md4", MD4dlen, md4, },
974 	{ "md5", MD5dlen, md5, },
975 	{ "sha1", SHA1dlen, sha1, },
976 	{ "sha", SHA1dlen, sha1, },
977 	{ 0 }
978 };
979 
980 static int
parsehashalg(char * p,Dstate * s)981 parsehashalg(char *p, Dstate *s)
982 {
983 	Hashalg *ha;
984 
985 	for(ha = hashtab; ha->name; ha++){
986 		if(strcmp(p, ha->name) == 0){
987 			s->hf = ha->hf;
988 			s->diglen = ha->diglen;
989 			s->state &= ~Sclear;
990 			s->state |= Sdigesting;
991 			return 0;
992 		}
993 	}
994 	return -1;
995 }
996 
997 typedef struct Encalg Encalg;
998 struct Encalg
999 {
1000 	char	*name;
1001 	int	blocklen;
1002 	int	alg;
1003 	void	(*keyinit)(OneWay*);
1004 };
1005 
1006 #ifdef NOSPOOKS
1007 Encalg encrypttab[] =
1008 {
1009 	{ "descbc", 8, DESCBC, initDESkey, },           /* DEPRECATED -- use des_56_cbc */
1010 	{ "desecb", 8, DESECB, initDESkey, },           /* DEPRECATED -- use des_56_ecb */
1011 	{ "des_56_cbc", 8, DESCBC, initDESkey, },
1012 	{ "des_56_ecb", 8, DESECB, initDESkey, },
1013 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1014 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1015 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1016 	{ "rc4_256", 1, RC4, initRC4key, },
1017 	{ "rc4_128", 1, RC4, initRC4key_128, },
1018 	{ "rc4_40", 1, RC4, initRC4key_40, },
1019 	{ 0 }
1020 };
1021 #else
1022 Encalg encrypttab[] =
1023 {
1024 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1025 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1026 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1027 	{ "rc4_40", 1, RC4, initRC4key_40, },
1028 	{ 0 }
1029 };
1030 #endif /* NOSPOOKS */
1031 
1032 static int
parseencryptalg(char * p,Dstate * s)1033 parseencryptalg(char *p, Dstate *s)
1034 {
1035 	Encalg *ea;
1036 
1037 	for(ea = encrypttab; ea->name; ea++){
1038 		if(strcmp(p, ea->name) == 0){
1039 			s->encryptalg = ea->alg;
1040 			s->blocklen = ea->blocklen;
1041 			(*ea->keyinit)(&s->in);
1042 			(*ea->keyinit)(&s->out);
1043 			s->state &= ~Sclear;
1044 			s->state |= Sencrypting;
1045 			return 0;
1046 		}
1047 	}
1048 	return -1;
1049 }
1050 
1051 static long
sslwrite(Chan * c,void * a,long n,vlong o)1052 sslwrite(Chan *c, void *a, long n, vlong o)
1053 {
1054 	Dstate * volatile s;
1055 	Block * volatile b;
1056 	int m, t;
1057 	char *p, *np, *e, buf[128];
1058 	uchar *x;
1059 
1060 	USED(o);
1061 	s = dstate[CONV(c->qid)];
1062 	if(s == 0)
1063 		panic("sslwrite");
1064 
1065 	t = TYPE(c->qid);
1066 	if(t == Qdata){
1067 		if(s->state == Sincomplete)
1068 			error(Ebadusefd);
1069 
1070 		/* lock should a write gets split over multiple records */
1071 		if(waserror()){
1072 			qunlock(&s->out.q);
1073 			nexterror();
1074 		}
1075 		qlock(&s->out.q);
1076 		p = a;
1077 if(0) iprint("write %d %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux\n",
1078 	n, p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7]);
1079 		e = p + n;
1080 		do {
1081 			m = e - p;
1082 			if(m > s->max)
1083 				m = s->max;
1084 
1085 			b = allocb(m);
1086 			if(waserror()){
1087 				freeb(b);
1088 				nexterror();
1089 			}
1090 			memmove(b->wp, p, m);
1091 			poperror();
1092 			b->wp += m;
1093 
1094 			sslput(s, b);
1095 
1096 			p += m;
1097 		} while(p < e);
1098 		p = a;
1099 if(0) iprint("wrote %d %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux %.2ux\n",
1100 	n, p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7]);
1101 		poperror();
1102 		qunlock(&s->out.q);
1103 		return n;
1104 	}
1105 
1106 	/* mutex with operations using what we're about to change */
1107 	if(waserror()){
1108 		qunlock(&s->in.ctlq);
1109 		qunlock(&s->out.q);
1110 		nexterror();
1111 	}
1112 	qlock(&s->in.ctlq);
1113 	qlock(&s->out.q);
1114 
1115 	switch(t){
1116 	default:
1117 		panic("sslwrite");
1118 	case Qsecretin:
1119 		setsecret(&s->in, a, n);
1120 		goto out;
1121 	case Qsecretout:
1122 		setsecret(&s->out, a, n);
1123 		goto out;
1124 	case Qctl:
1125 		break;
1126 	}
1127 
1128 	if(n >= sizeof(buf))
1129 		error("arg too long");
1130 	strncpy(buf, a, n);
1131 	buf[n] = 0;
1132 	p = strchr(buf, '\n');
1133 	if(p)
1134 		*p = 0;
1135 	p = strchr(buf, ' ');
1136 	if(p)
1137 		*p++ = 0;
1138 
1139 	if(strcmp(buf, "fd") == 0){
1140 		s->c = buftochan(p);
1141 
1142 		/* default is clear (msg delimiters only) */
1143 		s->state = Sclear;
1144 		s->blocklen = 1;
1145 		s->diglen = 0;
1146 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1147 		s->in.mid = 0;
1148 		s->out.mid = 0;
1149 	} else if(strcmp(buf, "alg") == 0 && p != 0){
1150 		s->blocklen = 1;
1151 		s->diglen = 0;
1152 
1153 		if(s->c == 0)
1154 			error("must set fd before algorithm");
1155 
1156 		s->state = Sclear;
1157 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1158 		if(strcmp(p, "clear") == 0){
1159 			goto out;
1160 		}
1161 
1162 		if(s->in.secret && s->out.secret == 0)
1163 			setsecret(&s->out, s->in.secret, s->in.slen);
1164 		if(s->out.secret && s->in.secret == 0)
1165 			setsecret(&s->in, s->out.secret, s->out.slen);
1166 		if(s->in.secret == 0 || s->out.secret == 0)
1167 			error("algorithm but no secret");
1168 
1169 		s->hf = 0;
1170 		s->encryptalg = Noencryption;
1171 		s->blocklen = 1;
1172 
1173 		for(;;){
1174 			np = strchr(p, ' ');
1175 			if(np)
1176 				*np++ = 0;
1177 
1178 			if(parsehashalg(p, s) < 0)
1179 			if(parseencryptalg(p, s) < 0)
1180 				error("bad algorithm");
1181 
1182 			if(np == 0)
1183 				break;
1184 			p = np;
1185 		}
1186 
1187 		if(s->hf == 0 && s->encryptalg == Noencryption)
1188 			error("bad algorithm");
1189 
1190 		if(s->blocklen != 1){
1191 			s->max = (1<<15) - s->diglen - 1;
1192 			s->max -= s->max % s->blocklen;
1193 			s->maxpad = (1<<14) - s->diglen - 1;
1194 			s->maxpad -= s->maxpad % s->blocklen;
1195 		} else
1196 			s->maxpad = s->max = (1<<15) - s->diglen - 1;
1197 	} else if(strcmp(buf, "secretin") == 0 && p != 0) {
1198 		m = (strlen(p)*3)/2;
1199 		x = smalloc(m);
1200 		t = dec64(x, m, p, strlen(p));
1201 		setsecret(&s->in, x, t);
1202 		free(x);
1203 	} else if(strcmp(buf, "secretout") == 0 && p != 0) {
1204 		m = (strlen(p)*3)/2 + 1;
1205 		x = smalloc(m);
1206 		t = dec64(x, m, p, strlen(p));
1207 		setsecret(&s->out, x, t);
1208 		free(x);
1209 	} else
1210 		error(Ebadarg);
1211 
1212 out:
1213 	qunlock(&s->in.ctlq);
1214 	qunlock(&s->out.q);
1215 	poperror();
1216 	return n;
1217 }
1218 
1219 static void
sslinit(void)1220 sslinit(void)
1221 {
1222 	struct Encalg *e;
1223 	struct Hashalg *h;
1224 	int n;
1225 	char *cp;
1226 
1227 	n = 1;
1228 	for(e = encrypttab; e->name != nil; e++)
1229 		n += strlen(e->name) + 1;
1230 	cp = encalgs = smalloc(n);
1231 	for(e = encrypttab;;){
1232 		strcpy(cp, e->name);
1233 		cp += strlen(e->name);
1234 		e++;
1235 		if(e->name == nil)
1236 			break;
1237 		*cp++ = ' ';
1238 	}
1239 	*cp = 0;
1240 
1241 	n = 1;
1242 	for(h = hashtab; h->name != nil; h++)
1243 		n += strlen(h->name) + 1;
1244 	cp = hashalgs = smalloc(n);
1245 	for(h = hashtab;;){
1246 		strcpy(cp, h->name);
1247 		cp += strlen(h->name);
1248 		h++;
1249 		if(h->name == nil)
1250 			break;
1251 		*cp++ = ' ';
1252 	}
1253 	*cp = 0;
1254 }
1255 
1256 Dev ssldevtab = {
1257 	'D',
1258 	"ssl",
1259 
1260 	devreset,
1261 	sslinit,
1262 	devshutdown,
1263 	sslattach,
1264 	sslwalk,
1265 	sslstat,
1266 	sslopen,
1267 	devcreate,
1268 	sslclose,
1269 	sslread,
1270 	sslbread,
1271 	sslwrite,
1272 	sslbwrite,
1273 	devremove,
1274 	sslwstat,
1275 };
1276 
1277 static Block*
encryptb(Dstate * s,Block * b,int offset)1278 encryptb(Dstate *s, Block *b, int offset)
1279 {
1280 	uchar *p, *ep, *p2, *ip, *eip;
1281 	DESstate *ds;
1282 
1283 	switch(s->encryptalg){
1284 	case DESECB:
1285 		ds = s->out.state;
1286 		ep = b->rp + BLEN(b);
1287 		for(p = b->rp + offset; p < ep; p += 8)
1288 			block_cipher(ds->expanded, p, 0);
1289 		break;
1290 	case DESCBC:
1291 		ds = s->out.state;
1292 		ep = b->rp + BLEN(b);
1293 		for(p = b->rp + offset; p < ep; p += 8){
1294 			p2 = p;
1295 			ip = ds->ivec;
1296 			for(eip = ip+8; ip < eip; )
1297 				*p2++ ^= *ip++;
1298 			block_cipher(ds->expanded, p, 0);
1299 			memmove(ds->ivec, p, 8);
1300 		}
1301 		break;
1302 	case RC4:
1303 		rc4(s->out.state, b->rp + offset, BLEN(b) - offset);
1304 		break;
1305 	}
1306 	return b;
1307 }
1308 
1309 static Block*
decryptb(Dstate * s,Block * bin)1310 decryptb(Dstate *s, Block *bin)
1311 {
1312 	Block *b, **l;
1313 	uchar *p, *ep, *tp, *ip, *eip;
1314 	DESstate *ds;
1315 	uchar tmp[8];
1316 	int i;
1317 
1318 	l = &bin;
1319 	for(b = bin; b; b = b->next){
1320 		/* make sure we have a multiple of s->blocklen */
1321 		if(s->blocklen > 1){
1322 			i = BLEN(b);
1323 			if(i % s->blocklen){
1324 				*l = b = pullupblock(b, i + s->blocklen - (i%s->blocklen));
1325 				if(b == 0)
1326 					error("ssl encrypted message too short");
1327 			}
1328 		}
1329 		l = &b->next;
1330 
1331 		/* decrypt */
1332 		switch(s->encryptalg){
1333 		case DESECB:
1334 			ds = s->in.state;
1335 			ep = b->rp + BLEN(b);
1336 			for(p = b->rp; p < ep; p += 8)
1337 				block_cipher(ds->expanded, p, 1);
1338 			break;
1339 		case DESCBC:
1340 			ds = s->in.state;
1341 			ep = b->rp + BLEN(b);
1342 			for(p = b->rp; p < ep;){
1343 				memmove(tmp, p, 8);
1344 				block_cipher(ds->expanded, p, 1);
1345 				tp = tmp;
1346 				ip = ds->ivec;
1347 				for(eip = ip+8; ip < eip; ){
1348 					*p++ ^= *ip;
1349 					*ip++ = *tp++;
1350 				}
1351 			}
1352 			break;
1353 		case RC4:
1354 			rc4(s->in.state, b->rp, BLEN(b));
1355 			break;
1356 		}
1357 	}
1358 	return bin;
1359 }
1360 
1361 static Block*
digestb(Dstate * s,Block * b,int offset)1362 digestb(Dstate *s, Block *b, int offset)
1363 {
1364 	uchar *p;
1365 	DigestState ss;
1366 	uchar msgid[4];
1367 	ulong n, h;
1368 	OneWay *w;
1369 
1370 	w = &s->out;
1371 
1372 	memset(&ss, 0, sizeof(ss));
1373 	h = s->diglen + offset;
1374 	n = BLEN(b) - h;
1375 
1376 	/* hash secret + message */
1377 	(*s->hf)(w->secret, w->slen, 0, &ss);
1378 	(*s->hf)(b->rp + h, n, 0, &ss);
1379 
1380 	/* hash message id */
1381 	p = msgid;
1382 	n = w->mid;
1383 	*p++ = n>>24;
1384 	*p++ = n>>16;
1385 	*p++ = n>>8;
1386 	*p = n;
1387 	(*s->hf)(msgid, 4, b->rp + offset, &ss);
1388 
1389 	return b;
1390 }
1391 
1392 static void
checkdigestb(Dstate * s,Block * bin)1393 checkdigestb(Dstate *s, Block *bin)
1394 {
1395 	uchar *p;
1396 	DigestState ss;
1397 	uchar msgid[4];
1398 	int n, h;
1399 	OneWay *w;
1400 	uchar digest[128];
1401 	Block *b;
1402 
1403 	w = &s->in;
1404 
1405 	memset(&ss, 0, sizeof(ss));
1406 
1407 	/* hash secret */
1408 	(*s->hf)(w->secret, w->slen, 0, &ss);
1409 
1410 	/* hash message */
1411 	h = s->diglen;
1412 	for(b = bin; b; b = b->next){
1413 		n = BLEN(b) - h;
1414 		if(n < 0)
1415 			panic("checkdigestb");
1416 		(*s->hf)(b->rp + h, n, 0, &ss);
1417 		h = 0;
1418 	}
1419 
1420 	/* hash message id */
1421 	p = msgid;
1422 	n = w->mid;
1423 	*p++ = n>>24;
1424 	*p++ = n>>16;
1425 	*p++ = n>>8;
1426 	*p = n;
1427 	(*s->hf)(msgid, 4, digest, &ss);
1428 
1429 	if(memcmp(digest, bin->rp, s->diglen) != 0)
1430 		error("bad digest");
1431 }
1432 
1433 /* get channel associated with an fd */
1434 static Chan*
buftochan(char * p)1435 buftochan(char *p)
1436 {
1437 	Chan *c;
1438 	int fd;
1439 
1440 	if(p == 0)
1441 		error(Ebadarg);
1442 	fd = strtoul(p, 0, 0);
1443 	if(fd < 0)
1444 		error(Ebadarg);
1445 	c = fdtochan(fd, -1, 0, 1);	/* error check and inc ref */
1446 	if(devtab[c->type] == &ssldevtab){
1447 		cclose(c);
1448 		error("cannot ssl encrypt devssl files");
1449 	}
1450 	return c;
1451 }
1452 
1453 /* hand up a digest connection */
1454 static void
sslhangup(Dstate * s)1455 sslhangup(Dstate *s)
1456 {
1457 	Block *b;
1458 
1459 	qlock(&s->in.q);
1460 	for(b = s->processed; b; b = s->processed){
1461 		s->processed = b->next;
1462 		freeb(b);
1463 	}
1464 	if(s->unprocessed){
1465 		freeb(s->unprocessed);
1466 		s->unprocessed = 0;
1467 	}
1468 	s->state = Sincomplete;
1469 	qunlock(&s->in.q);
1470 }
1471 
1472 static Dstate*
dsclone(Chan * ch)1473 dsclone(Chan *ch)
1474 {
1475 	int i;
1476 	Dstate *ret;
1477 
1478 	if(waserror()) {
1479 		unlock(&dslock);
1480 		nexterror();
1481 	}
1482 	lock(&dslock);
1483 	ret = nil;
1484 	for(i=0; i<Maxdstate; i++){
1485 		if(dstate[i] == nil){
1486 			dsnew(ch, &dstate[i]);
1487 			ret = dstate[i];
1488 			break;
1489 		}
1490 	}
1491 	unlock(&dslock);
1492 	poperror();
1493 	return ret;
1494 }
1495 
1496 static void
dsnew(Chan * ch,Dstate ** pp)1497 dsnew(Chan *ch, Dstate **pp)
1498 {
1499 	Dstate *s;
1500 	int t;
1501 
1502 	*pp = s = malloc(sizeof(*s));
1503 	if(!s)
1504 		error(Enomem);
1505 	if(pp - dstate >= dshiwat)
1506 		dshiwat++;
1507 	memset(s, 0, sizeof(*s));
1508 	s->state = Sincomplete;
1509 	s->ref = 1;
1510 	kstrdup(&s->user, up->user);
1511 	s->perm = 0660;
1512 	t = TYPE(ch->qid);
1513 	if(t == Qclonus)
1514 		t = Qctl;
1515 	ch->qid.path = QID(pp - dstate, t);
1516 	ch->qid.vers = 0;
1517 }
1518