xref: /plan9-contrib/sys/src/9k/port/devssl.c (revision 406c76facc4b13aa2a55454bf4091aab9f03da22)
1 /*
2  *  devssl - secure sockets layer
3  */
4 #include	"u.h"
5 #include	"../port/lib.h"
6 #include	"mem.h"
7 #include	"dat.h"
8 #include	"fns.h"
9 #include	"../port/error.h"
10 
11 #include	<libsec.h>
12 
13 #define NOSPOOKS 1
14 
15 typedef struct OneWay OneWay;
16 struct OneWay
17 {
18 	QLock	q;
19 	QLock	ctlq;
20 
21 	void	*state;		/* encryption state */
22 	int	slen;		/* hash data length */
23 	uchar	*secret;	/* secret */
24 	ulong	mid;		/* message id */
25 };
26 
27 enum
28 {
29 	/* connection states */
30 	Sincomplete=	0,
31 	Sclear=		1,
32 	Sencrypting=	2,
33 	Sdigesting=	4,
34 	Sdigenc=	Sencrypting|Sdigesting,
35 
36 	/* encryption algorithms */
37 	Noencryption=	0,
38 	DESCBC=		1,
39 	DESECB=		2,
40 	RC4=		3
41 };
42 
43 typedef struct Dstate Dstate;
44 struct Dstate
45 {
46 	Chan	*c;		/* io channel */
47 	uchar	state;		/* state of connection */
48 	int	ref;		/* serialized by dslock for atomic destroy */
49 
50 	uchar	encryptalg;	/* encryption algorithm */
51 	ushort	blocklen;	/* blocking length */
52 
53 	ushort	diglen;		/* length of digest */
54 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);	/* hash func */
55 
56 	/* for SSL format */
57 	int	max;		/* maximum unpadded data per msg */
58 	int	maxpad;		/* maximum padded data per msg */
59 
60 	/* input side */
61 	OneWay	in;
62 	Block	*processed;
63 	Block	*unprocessed;
64 
65 	/* output side */
66 	OneWay	out;
67 
68 	/* protections */
69 	char	*user;
70 	int	perm;
71 };
72 
73 enum
74 {
75 	Maxdmsg=	1<<16,
76 	Maxdstate=	512,	/* max. open ssl conn's; must be a power of 2 */
77 };
78 
79 static	Lock	dslock;
80 static	int	dshiwat;
81 static	char	*dsname[Maxdstate];
82 static	Dstate	*dstate[Maxdstate];
83 static	char	*encalgs;
84 static	char	*hashalgs;
85 
86 enum{
87 	Qtopdir		= 1,	/* top level directory */
88 	Qprotodir,
89 	Qclonus,
90 	Qconvdir,		/* directory for a conversation */
91 	Qdata,
92 	Qctl,
93 	Qsecretin,
94 	Qsecretout,
95 	Qencalgs,
96 	Qhashalgs,
97 };
98 
99 #define TYPE(x) 	((x).path & 0xf)
100 #define CONV(x) 	(((x).path >> 5)&(Maxdstate-1))
101 #define QID(c, y) 	(((c)<<5) | (y))
102 
103 static void	ensure(Dstate*, Block**, int);
104 static void	consume(Block**, uchar*, int);
105 static void	setsecret(OneWay*, uchar*, int);
106 static Block*	encryptb(Dstate*, Block*, int);
107 static Block*	decryptb(Dstate*, Block*);
108 static Block*	digestb(Dstate*, Block*, int);
109 static void	checkdigestb(Dstate*, Block*);
110 static Chan*	buftochan(char*);
111 static void	sslhangup(Dstate*);
112 static Dstate*	dsclone(Chan *c);
113 static void	dsnew(Chan *c, Dstate **);
114 static long	sslput(Dstate *s, Block * volatile b);
115 
116 char *sslnames[] = {
117 [Qclonus]	"clone",
118 [Qdata]		"data",
119 [Qctl]		"ctl",
120 [Qsecretin]	"secretin",
121 [Qsecretout]	"secretout",
122 [Qencalgs]	"encalgs",
123 [Qhashalgs]	"hashalgs",
124 };
125 
126 static int
sslgen(Chan * c,char *,Dirtab * d,int nd,int s,Dir * dp)127 sslgen(Chan *c, char*, Dirtab *d, int nd, int s, Dir *dp)
128 {
129 	Qid q;
130 	Dstate *ds;
131 	char name[16], *p, *nm;
132 	int ft;
133 
134 	USED(nd);
135 	USED(d);
136 
137 	q.type = QTFILE;
138 	q.vers = 0;
139 
140 	ft = TYPE(c->qid);
141 	switch(ft) {
142 	case Qtopdir:
143 		if(s == DEVDOTDOT){
144 			q.path = QID(0, Qtopdir);
145 			q.type = QTDIR;
146 			devdir(c, q, "#D", 0, eve, 0555, dp);
147 			return 1;
148 		}
149 		if(s > 0)
150 			return -1;
151 		q.path = QID(0, Qprotodir);
152 		q.type = QTDIR;
153 		devdir(c, q, "ssl", 0, eve, 0555, dp);
154 		return 1;
155 	case Qprotodir:
156 		if(s == DEVDOTDOT){
157 			q.path = QID(0, Qtopdir);
158 			q.type = QTDIR;
159 			devdir(c, q, ".", 0, eve, 0555, dp);
160 			return 1;
161 		}
162 		if(s < dshiwat) {
163 			q.path = QID(s, Qconvdir);
164 			q.type = QTDIR;
165 			ds = dstate[s];
166 			if(ds != 0)
167 				nm = ds->user;
168 			else
169 				nm = eve;
170 			if(dsname[s] == nil){
171 				snprint(name, sizeof name, "%d", s);
172 				kstrdup(&dsname[s], name);
173 			}
174 			devdir(c, q, dsname[s], 0, nm, 0555, dp);
175 			return 1;
176 		}
177 		if(s > dshiwat)
178 			return -1;
179 		q.path = QID(0, Qclonus);
180 		devdir(c, q, "clone", 0, eve, 0555, dp);
181 		return 1;
182 	case Qconvdir:
183 		if(s == DEVDOTDOT){
184 			q.path = QID(0, Qprotodir);
185 			q.type = QTDIR;
186 			devdir(c, q, "ssl", 0, eve, 0555, dp);
187 			return 1;
188 		}
189 		ds = dstate[CONV(c->qid)];
190 		if(ds != 0)
191 			nm = ds->user;
192 		else
193 			nm = eve;
194 		switch(s) {
195 		default:
196 			return -1;
197 		case 0:
198 			q.path = QID(CONV(c->qid), Qctl);
199 			p = "ctl";
200 			break;
201 		case 1:
202 			q.path = QID(CONV(c->qid), Qdata);
203 			p = "data";
204 			break;
205 		case 2:
206 			q.path = QID(CONV(c->qid), Qsecretin);
207 			p = "secretin";
208 			break;
209 		case 3:
210 			q.path = QID(CONV(c->qid), Qsecretout);
211 			p = "secretout";
212 			break;
213 		case 4:
214 			q.path = QID(CONV(c->qid), Qencalgs);
215 			p = "encalgs";
216 			break;
217 		case 5:
218 			q.path = QID(CONV(c->qid), Qhashalgs);
219 			p = "hashalgs";
220 			break;
221 		}
222 		devdir(c, q, p, 0, nm, 0660, dp);
223 		return 1;
224 	case Qclonus:
225 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, eve, 0555, dp);
226 		return 1;
227 	default:
228 		ds = dstate[CONV(c->qid)];
229 		if(ds != 0)
230 			nm = ds->user;
231 		else
232 			nm = eve;
233 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, nm, 0660, dp);
234 		return 1;
235 	}
236 }
237 
238 static Chan*
sslattach(char * spec)239 sslattach(char *spec)
240 {
241 	Chan *c;
242 
243 	c = devattach('D', spec);
244 	c->qid.path = QID(0, Qtopdir);
245 	c->qid.vers = 0;
246 	c->qid.type = QTDIR;
247 	return c;
248 }
249 
250 static Walkqid*
sslwalk(Chan * c,Chan * nc,char ** name,int nname)251 sslwalk(Chan *c, Chan *nc, char **name, int nname)
252 {
253 	return devwalk(c, nc, name, nname, nil, 0, sslgen);
254 }
255 
256 static long
sslstat(Chan * c,uchar * db,long n)257 sslstat(Chan *c, uchar *db, long n)
258 {
259 	return devstat(c, db, n, nil, 0, sslgen);
260 }
261 
262 static Chan*
sslopen(Chan * c,int omode)263 sslopen(Chan *c, int omode)
264 {
265 	Dstate *s, **pp;
266 	int perm;
267 	int ft;
268 
269 	perm = 0;
270 	omode &= 3;
271 	switch(omode) {
272 	case OREAD:
273 		perm = 4;
274 		break;
275 	case OWRITE:
276 		perm = 2;
277 		break;
278 	case ORDWR:
279 		perm = 6;
280 		break;
281 	}
282 
283 	ft = TYPE(c->qid);
284 	switch(ft) {
285 	default:
286 		panic("sslopen");
287 	case Qtopdir:
288 	case Qprotodir:
289 	case Qconvdir:
290 		if(omode != OREAD)
291 			error(Eperm);
292 		break;
293 	case Qclonus:
294 		s = dsclone(c);
295 		if(s == 0)
296 			error(Enodev);
297 		break;
298 	case Qctl:
299 	case Qdata:
300 	case Qsecretin:
301 	case Qsecretout:
302 		if(waserror()) {
303 			unlock(&dslock);
304 			nexterror();
305 		}
306 		lock(&dslock);
307 		pp = &dstate[CONV(c->qid)];
308 		s = *pp;
309 		if(s == 0)
310 			dsnew(c, pp);
311 		else {
312 			if((perm & (s->perm>>6)) != perm
313 			   && (strcmp(up->user, s->user) != 0
314 			     || (perm & s->perm) != perm))
315 				error(Eperm);
316 
317 			s->ref++;
318 		}
319 		unlock(&dslock);
320 		poperror();
321 		break;
322 	case Qencalgs:
323 	case Qhashalgs:
324 		if(omode != OREAD)
325 			error(Eperm);
326 		break;
327 	}
328 	c->mode = openmode(omode);
329 	c->flag |= COPEN;
330 	c->offset = 0;
331 	return c;
332 }
333 
334 static long
sslwstat(Chan * c,uchar * db,long n)335 sslwstat(Chan *c, uchar *db, long n)
336 {
337 	Dir *dir;
338 	Dstate *s;
339 	int l;
340 
341 	s = dstate[CONV(c->qid)];
342 	if(s == 0)
343 		error(Ebadusefd);
344 	if(strcmp(s->user, up->user) != 0)
345 		error(Eperm);
346 
347 	dir = smalloc(sizeof(Dir)+n);
348 	l = convM2D(db, n, &dir[0], (char*)&dir[1]);
349 	if(l == 0){
350 		free(dir);
351 		error(Eshortstat);
352 	}
353 
354 	if(!emptystr(dir->uid))
355 		kstrdup(&s->user, dir->uid);
356 	if(dir->mode != ~0UL)
357 		s->perm = dir->mode;
358 
359 	free(dir);
360 	return l;
361 }
362 
363 static void
sslclose(Chan * c)364 sslclose(Chan *c)
365 {
366 	Dstate *s;
367 	int ft;
368 
369 	ft = TYPE(c->qid);
370 	switch(ft) {
371 	case Qctl:
372 	case Qdata:
373 	case Qsecretin:
374 	case Qsecretout:
375 		if((c->flag & COPEN) == 0)
376 			break;
377 
378 		s = dstate[CONV(c->qid)];
379 		if(s == 0)
380 			break;
381 
382 		lock(&dslock);
383 		if(--s->ref > 0) {
384 			unlock(&dslock);
385 			break;
386 		}
387 		dstate[CONV(c->qid)] = 0;
388 		unlock(&dslock);
389 
390 		if(s->user != nil)
391 			free(s->user);
392 		sslhangup(s);
393 		if(s->c)
394 			cclose(s->c);
395 		if(s->in.secret)
396 			free(s->in.secret);
397 		if(s->out.secret)
398 			free(s->out.secret);
399 		if(s->in.state)
400 			free(s->in.state);
401 		if(s->out.state)
402 			free(s->out.state);
403 		free(s);
404 
405 	}
406 }
407 
408 /*
409  *  make sure we have at least 'n' bytes in list 'l'
410  */
411 static void
ensure(Dstate * s,Block ** l,int n)412 ensure(Dstate *s, Block **l, int n)
413 {
414 	int sofar, i;
415 	Block *b, *bl;
416 
417 	sofar = 0;
418 	for(b = *l; b; b = b->next){
419 		sofar += BLEN(b);
420 		if(sofar >= n)
421 			return;
422 		l = &b->next;
423 	}
424 
425 	while(sofar < n){
426 		bl = s->c->dev->bread(s->c, Maxdmsg, 0);
427 		if(bl == 0)
428 			nexterror();
429 		*l = bl;
430 		i = 0;
431 		for(b = bl; b; b = b->next){
432 			i += BLEN(b);
433 			l = &b->next;
434 		}
435 		if(i == 0)
436 			error(Ehungup);
437 		sofar += i;
438 	}
439 }
440 
441 /*
442  *  copy 'n' bytes from 'l' into 'p' and free
443  *  the bytes in 'l'
444  */
445 static void
consume(Block ** l,uchar * p,int n)446 consume(Block **l, uchar *p, int n)
447 {
448 	Block *b;
449 	int i;
450 
451 	for(; *l && n > 0; n -= i){
452 		b = *l;
453 		i = BLEN(b);
454 		if(i > n)
455 			i = n;
456 		memmove(p, b->rp, i);
457 		b->rp += i;
458 		p += i;
459 		if(BLEN(b) < 0)
460 			panic("consume");
461 		if(BLEN(b))
462 			break;
463 		*l = b->next;
464 		freeb(b);
465 	}
466 }
467 
468 /*
469  *  give back n bytes
470 static void
471 regurgitate(Dstate *s, uchar *p, int n)
472 {
473 	Block *b;
474 
475 	if(n <= 0)
476 		return;
477 	b = s->unprocessed;
478 	if(s->unprocessed == nil || b->rp - b->base < n) {
479 		b = allocb(n);
480 		memmove(b->wp, p, n);
481 		b->wp += n;
482 		b->next = s->unprocessed;
483 		s->unprocessed = b;
484 	} else {
485 		b->rp -= n;
486 		memmove(b->rp, p, n);
487 	}
488 }
489  */
490 
491 /*
492  *  remove at most n bytes from the queue, if discard is set
493  *  dump the remainder
494  */
495 static Block*
qtake(Block ** l,int n,int discard)496 qtake(Block **l, int n, int discard)
497 {
498 	Block *nb, *b, *first;
499 	int i;
500 
501 	first = *l;
502 	for(b = first; b; b = b->next){
503 		i = BLEN(b);
504 		if(i == n){
505 			if(discard){
506 				freeblist(b->next);
507 				*l = 0;
508 			} else
509 				*l = b->next;
510 			b->next = 0;
511 			return first;
512 		} else if(i > n){
513 			i -= n;
514 			if(discard){
515 				freeblist(b->next);
516 				b->wp -= i;
517 				*l = 0;
518 			} else {
519 				nb = allocb(i);
520 				memmove(nb->wp, b->rp+n, i);
521 				nb->wp += i;
522 				b->wp -= i;
523 				nb->next = b->next;
524 				*l = nb;
525 			}
526 			b->next = 0;
527 			if(BLEN(b) < 0)
528 				panic("qtake");
529 			return first;
530 		} else
531 			n -= i;
532 		if(BLEN(b) < 0)
533 			panic("qtake");
534 	}
535 	*l = 0;
536 	return first;
537 }
538 
539 /*
540  *  We can't let Eintr's lose data since the program
541  *  doing the read may be able to handle it.  The only
542  *  places Eintr is possible is during the read's in consume.
543  *  Therefore, we make sure we can always put back the bytes
544  *  consumed before the last ensure.
545  */
546 static Block*
sslbread(Chan * c,long n,vlong)547 sslbread(Chan *c, long n, vlong)
548 {
549 	Dstate * volatile s;
550 	Block *b;
551 	uchar consumed[3], *p;
552 	int toconsume;
553 	int len, pad;
554 
555 	s = dstate[CONV(c->qid)];
556 	if(s == 0)
557 		panic("sslbread");
558 	if(s->state == Sincomplete)
559 		error(Ebadusefd);
560 
561 	qlock(&s->in.q);
562 	if(waserror()){
563 		qunlock(&s->in.q);
564 		nexterror();
565 	}
566 
567 	if(s->processed == 0){
568 		/*
569 		 * Read in the whole message.  Until we've got it all,
570 		 * it stays on s->unprocessed, so that if we get Eintr,
571 		 * we'll pick up where we left off.
572 		 */
573 		ensure(s, &s->unprocessed, 3);
574 		s->unprocessed = pullupblock(s->unprocessed, 2);
575 		p = s->unprocessed->rp;
576 		if(p[0] & 0x80){
577 			len = ((p[0] & 0x7f)<<8) | p[1];
578 			ensure(s, &s->unprocessed, len);
579 			pad = 0;
580 			toconsume = 2;
581 		} else {
582 			s->unprocessed = pullupblock(s->unprocessed, 3);
583 			len = ((p[0] & 0x3f)<<8) | p[1];
584 			pad = p[2];
585 			if(pad > len){
586 				print("pad %d buf len %d\n", pad, len);
587 				error("bad pad in ssl message");
588 			}
589 			toconsume = 3;
590 		}
591 		ensure(s, &s->unprocessed, toconsume+len);
592 
593 		/* skip header */
594 		consume(&s->unprocessed, consumed, toconsume);
595 
596 		/* grab the next message and decode/decrypt it */
597 		b = qtake(&s->unprocessed, len, 0);
598 
599 		if(blocklen(b) != len)
600 			print("devssl: sslbread got wrong count %d != %d", blocklen(b), len);
601 
602 		if(waserror()){
603 			qunlock(&s->in.ctlq);
604 			if(b != nil)
605 				freeb(b);
606 			nexterror();
607 		}
608 		qlock(&s->in.ctlq);
609 		switch(s->state){
610 		case Sencrypting:
611 			if(b == nil)
612 				error("ssl message too short (encrypting)");
613 			b = decryptb(s, b);
614 			break;
615 		case Sdigesting:
616 			b = pullupblock(b, s->diglen);
617 			if(b == nil)
618 				error("ssl message too short (digesting)");
619 			checkdigestb(s, b);
620 			pullblock(&b, s->diglen);
621 			len -= s->diglen;
622 			break;
623 		case Sdigenc:
624 			b = decryptb(s, b);
625 			b = pullupblock(b, s->diglen);
626 			if(b == nil)
627 				error("ssl message too short (dig+enc)");
628 			checkdigestb(s, b);
629 			pullblock(&b, s->diglen);
630 			len -= s->diglen;
631 			break;
632 		}
633 
634 		/* remove pad */
635 		if(pad)
636 			s->processed = qtake(&b, len - pad, 1);
637 		else
638 			s->processed = b;
639 		b = nil;
640 		s->in.mid++;
641 		qunlock(&s->in.ctlq);
642 		poperror();
643 	}
644 
645 	/* return at most what was asked for */
646 	b = qtake(&s->processed, n, 0);
647 
648 	qunlock(&s->in.q);
649 	poperror();
650 
651 	return b;
652 }
653 
654 static long
sslread(Chan * c,void * a,long n,vlong off)655 sslread(Chan *c, void *a, long n, vlong off)
656 {
657 	Block * volatile b;
658 	Block *nb;
659 	uchar *va;
660 	int i;
661 	char buf[128];
662 	long offset;
663 	int ft;
664 
665 	if(c->qid.type & QTDIR)
666 		return devdirread(c, a, n, 0, 0, sslgen);
667 
668 	ft = TYPE(c->qid);
669 	offset = off;
670 	switch(ft) {
671 	default:
672 		error(Ebadusefd);
673 	case Qctl:
674 		ft = CONV(c->qid);
675 		snprint(buf, sizeof buf, "%d", ft);
676 		return readstr(offset, a, n, buf);
677 	case Qdata:
678 		b = sslbread(c, n, offset);
679 		break;
680 	case Qencalgs:
681 		return readstr(offset, a, n, encalgs);
682 		break;
683 	case Qhashalgs:
684 		return readstr(offset, a, n, hashalgs);
685 		break;
686 	}
687 
688 	if(waserror()){
689 		freeblist(b);
690 		nexterror();
691 	}
692 
693 	n = 0;
694 	va = a;
695 	for(nb = b; nb; nb = nb->next){
696 		i = BLEN(nb);
697 		memmove(va+n, nb->rp, i);
698 		n += i;
699 	}
700 
701 	freeblist(b);
702 	poperror();
703 
704 	return n;
705 }
706 
707 /*
708  *  this algorithm doesn't have to be great since we're just
709  *  trying to obscure the block fill
710  */
711 static void
randfill(uchar * buf,int len)712 randfill(uchar *buf, int len)
713 {
714 	while(len-- > 0)
715 		*buf++ = nrand(256);
716 }
717 
718 static long
sslbwrite(Chan * c,Block * b,vlong)719 sslbwrite(Chan *c, Block *b, vlong)
720 {
721 	Dstate * volatile s;
722 	long rv;
723 
724 	s = dstate[CONV(c->qid)];
725 	if(s == nil)
726 		panic("sslbwrite");
727 
728 	if(s->state == Sincomplete){
729 		freeb(b);
730 		error(Ebadusefd);
731 	}
732 
733 	/* lock so split writes won't interleave */
734 	if(waserror()){
735 		qunlock(&s->out.q);
736 		nexterror();
737 	}
738 	qlock(&s->out.q);
739 
740 	rv = sslput(s, b);
741 
742 	poperror();
743 	qunlock(&s->out.q);
744 
745 	return rv;
746 }
747 
748 /*
749  *  use SSL record format, add in count, digest and/or encrypt.
750  *  the write is interruptable.  if it is interrupted, we'll
751  *  get out of sync with the far side.  not much we can do about
752  *  it since we don't know if any bytes have been written.
753  */
754 static long
sslput(Dstate * s,Block * volatile b)755 sslput(Dstate *s, Block * volatile b)
756 {
757 	Block *nb;
758 	int h, n, l, pad, rv;
759 	uchar *p;
760 	int offset;
761 
762 	if(waserror()){
763 		if(b != nil)
764 			freeb(b);
765 		nexterror();
766 	}
767 
768 	rv = 0;
769 	while(b != nil){
770 		l = n = BLEN(b);
771 		h = s->diglen + 2;
772 
773 		/* trim to maximum block size */
774 		pad = 0;
775 		if(l > s->max){
776 			l = s->max;
777 		} else if(s->blocklen != 1){
778 			pad = (l + s->diglen)%s->blocklen;
779 			if(pad){
780 				if(l > s->maxpad){
781 					pad = 0;
782 					l = s->maxpad;
783 				} else {
784 					pad = s->blocklen - pad;
785 					h++;
786 				}
787 			}
788 		}
789 
790 		rv += l;
791 		if(l != n){
792 			nb = allocb(l + h + pad);
793 			memmove(nb->wp + h, b->rp, l);
794 			nb->wp += l + h;
795 			b->rp += l;
796 		} else {
797 			/* add header space */
798 			nb = padblock(b, h);
799 			b = 0;
800 		}
801 		l += s->diglen;
802 
803 		/* SSL style count */
804 		if(pad){
805 			nb = padblock(nb, -pad);
806 			randfill(nb->wp, pad);
807 			nb->wp += pad;
808 			l += pad;
809 
810 			p = nb->rp;
811 			p[0] = (l>>8);
812 			p[1] = l;
813 			p[2] = pad;
814 			offset = 3;
815 		} else {
816 			p = nb->rp;
817 			p[0] = (l>>8) | 0x80;
818 			p[1] = l;
819 			offset = 2;
820 		}
821 
822 		switch(s->state){
823 		case Sencrypting:
824 			nb = encryptb(s, nb, offset);
825 			break;
826 		case Sdigesting:
827 			nb = digestb(s, nb, offset);
828 			break;
829 		case Sdigenc:
830 			nb = digestb(s, nb, offset);
831 			nb = encryptb(s, nb, offset);
832 			break;
833 		}
834 
835 		s->out.mid++;
836 
837 		l = BLEN(nb);
838 		s->c->dev->bwrite(s->c, nb, s->c->offset);
839 		s->c->offset += l;
840 	}
841 
842 	poperror();
843 	return rv;
844 }
845 
846 static void
setsecret(OneWay * w,uchar * secret,int n)847 setsecret(OneWay *w, uchar *secret, int n)
848 {
849 	if(w->secret)
850 		free(w->secret);
851 
852 	w->secret = smalloc(n);
853 	memmove(w->secret, secret, n);
854 	w->slen = n;
855 }
856 
857 static void
initDESkey(OneWay * w)858 initDESkey(OneWay *w)
859 {
860 	if(w->state){
861 		free(w->state);
862 		w->state = 0;
863 	}
864 
865 	w->state = smalloc(sizeof(DESstate));
866 	if(w->slen >= 16)
867 		setupDESstate(w->state, w->secret, w->secret+8);
868 	else if(w->slen >= 8)
869 		setupDESstate(w->state, w->secret, 0);
870 	else
871 		error("secret too short");
872 }
873 
874 /*
875  *  40 bit DES is the same as 56 bit DES.  However,
876  *  16 bits of the key are masked to zero.
877  */
878 static void
initDESkey_40(OneWay * w)879 initDESkey_40(OneWay *w)
880 {
881 	uchar key[8];
882 
883 	if(w->state){
884 		free(w->state);
885 		w->state = 0;
886 	}
887 
888 	if(w->slen >= 8){
889 		memmove(key, w->secret, 8);
890 		key[0] &= 0x0f;
891 		key[2] &= 0x0f;
892 		key[4] &= 0x0f;
893 		key[6] &= 0x0f;
894 	}
895 
896 	w->state = malloc(sizeof(DESstate));
897 	if(w->state == nil)
898 		error(Enomem);
899 	if(w->slen >= 16)
900 		setupDESstate(w->state, key, w->secret+8);
901 	else if(w->slen >= 8)
902 		setupDESstate(w->state, key, 0);
903 	else
904 		error("secret too short");
905 }
906 
907 static void
initRC4key(OneWay * w)908 initRC4key(OneWay *w)
909 {
910 	if(w->state){
911 		free(w->state);
912 		w->state = 0;
913 	}
914 
915 	w->state = smalloc(sizeof(RC4state));
916 	setupRC4state(w->state, w->secret, w->slen);
917 }
918 
919 /*
920  *  40 bit RC4 is the same as n-bit RC4.  However,
921  *  we ignore all but the first 40 bits of the key.
922  */
923 static void
initRC4key_40(OneWay * w)924 initRC4key_40(OneWay *w)
925 {
926 	if(w->state){
927 		free(w->state);
928 		w->state = 0;
929 	}
930 
931 	if(w->slen > 5)
932 		w->slen = 5;
933 
934 	w->state = malloc(sizeof(RC4state));
935 	if(w->state == nil)
936 		error(Enomem);
937 	setupRC4state(w->state, w->secret, w->slen);
938 }
939 
940 /*
941  *  128 bit RC4 is the same as n-bit RC4.  However,
942  *  we ignore all but the first 128 bits of the key.
943  */
944 static void
initRC4key_128(OneWay * w)945 initRC4key_128(OneWay *w)
946 {
947 	if(w->state){
948 		free(w->state);
949 		w->state = 0;
950 	}
951 
952 	if(w->slen > 16)
953 		w->slen = 16;
954 
955 	w->state = malloc(sizeof(RC4state));
956 	if(w->state == nil)
957 		error(Enomem);
958 	setupRC4state(w->state, w->secret, w->slen);
959 }
960 
961 
962 typedef struct Hashalg Hashalg;
963 struct Hashalg
964 {
965 	char	*name;
966 	int	diglen;
967 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);
968 };
969 
970 Hashalg hashtab[] =
971 {
972 	{ "md4", MD4dlen, md4, },
973 	{ "md5", MD5dlen, md5, },
974 	{ "sha1", SHA1dlen, sha1, },
975 	{ "sha", SHA1dlen, sha1, },
976 	{ 0 }
977 };
978 
979 static int
parsehashalg(char * p,Dstate * s)980 parsehashalg(char *p, Dstate *s)
981 {
982 	Hashalg *ha;
983 
984 	for(ha = hashtab; ha->name; ha++){
985 		if(strcmp(p, ha->name) == 0){
986 			s->hf = ha->hf;
987 			s->diglen = ha->diglen;
988 			s->state &= ~Sclear;
989 			s->state |= Sdigesting;
990 			return 0;
991 		}
992 	}
993 	return -1;
994 }
995 
996 typedef struct Encalg Encalg;
997 struct Encalg
998 {
999 	char	*name;
1000 	int	blocklen;
1001 	int	alg;
1002 	void	(*keyinit)(OneWay*);
1003 };
1004 
1005 #ifdef NOSPOOKS
1006 Encalg encrypttab[] =
1007 {
1008 	{ "descbc", 8, DESCBC, initDESkey, },           /* DEPRECATED -- use des_56_cbc */
1009 	{ "desecb", 8, DESECB, initDESkey, },           /* DEPRECATED -- use des_56_ecb */
1010 	{ "des_56_cbc", 8, DESCBC, initDESkey, },
1011 	{ "des_56_ecb", 8, DESECB, initDESkey, },
1012 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1013 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1014 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1015 	{ "rc4_256", 1, RC4, initRC4key, },
1016 	{ "rc4_128", 1, RC4, initRC4key_128, },
1017 	{ "rc4_40", 1, RC4, initRC4key_40, },
1018 	{ 0 }
1019 };
1020 #else
1021 Encalg encrypttab[] =
1022 {
1023 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1024 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1025 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1026 	{ "rc4_40", 1, RC4, initRC4key_40, },
1027 	{ 0 }
1028 };
1029 #endif NOSPOOKS
1030 
1031 static int
parseencryptalg(char * p,Dstate * s)1032 parseencryptalg(char *p, Dstate *s)
1033 {
1034 	Encalg *ea;
1035 
1036 	for(ea = encrypttab; ea->name; ea++){
1037 		if(strcmp(p, ea->name) == 0){
1038 			s->encryptalg = ea->alg;
1039 			s->blocklen = ea->blocklen;
1040 			(*ea->keyinit)(&s->in);
1041 			(*ea->keyinit)(&s->out);
1042 			s->state &= ~Sclear;
1043 			s->state |= Sencrypting;
1044 			return 0;
1045 		}
1046 	}
1047 	return -1;
1048 }
1049 
1050 static long
sslwrite(Chan * c,void * a,long n,vlong)1051 sslwrite(Chan *c, void *a, long n, vlong)
1052 {
1053 	Dstate * volatile s;
1054 	Block * volatile b;
1055 	int l, t;
1056 	char *p, *np, *e, buf[128];
1057 	uchar *x;
1058 
1059 	x = nil;
1060 	s = dstate[CONV(c->qid)];
1061 	if(s == 0)
1062 		panic("sslwrite");
1063 
1064 	t = TYPE(c->qid);
1065 	if(t == Qdata){
1066 		if(s->state == Sincomplete)
1067 			error(Ebadusefd);
1068 
1069 		/* lock should a write gets split over multiple records */
1070 		if(waserror()){
1071 			qunlock(&s->out.q);
1072 			nexterror();
1073 		}
1074 		qlock(&s->out.q);
1075 
1076 		p = a;
1077 		e = p + n;
1078 		do {
1079 			l = e - p;
1080 			if(l > s->max)
1081 				l = s->max;
1082 
1083 			b = allocb(l);
1084 			if(waserror()){
1085 				freeb(b);
1086 				nexterror();
1087 			}
1088 			memmove(b->wp, p, l);
1089 			poperror();
1090 			b->wp += l;
1091 
1092 			sslput(s, b);
1093 
1094 			p += l;
1095 		} while(p < e);
1096 
1097 		poperror();
1098 		qunlock(&s->out.q);
1099 		return n;
1100 	}
1101 
1102 	/* mutex with operations using what we're about to change */
1103 	if(waserror()){
1104 		qunlock(&s->in.ctlq);
1105 		qunlock(&s->out.q);
1106 		nexterror();
1107 	}
1108 	qlock(&s->in.ctlq);
1109 	qlock(&s->out.q);
1110 
1111 	switch(t){
1112 	default:
1113 		panic("sslwrite");
1114 	case Qsecretin:
1115 		setsecret(&s->in, a, n);
1116 		goto out;
1117 	case Qsecretout:
1118 		setsecret(&s->out, a, n);
1119 		goto out;
1120 	case Qctl:
1121 		break;
1122 	}
1123 
1124 	if(n >= sizeof(buf))
1125 		error("arg too long");
1126 	strncpy(buf, a, n);
1127 	buf[n] = 0;
1128 	p = strchr(buf, '\n');
1129 	if(p)
1130 		*p = 0;
1131 	p = strchr(buf, ' ');
1132 	if(p)
1133 		*p++ = 0;
1134 
1135 	if(waserror()){
1136 		free(x);
1137 		nexterror();
1138 	}
1139 	if(strcmp(buf, "fd") == 0){
1140 		s->c = buftochan(p);
1141 
1142 		/* default is clear (msg delimiters only) */
1143 		s->state = Sclear;
1144 		s->blocklen = 1;
1145 		s->diglen = 0;
1146 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1147 		s->in.mid = 0;
1148 		s->out.mid = 0;
1149 	} else if(strcmp(buf, "alg") == 0 && p != 0){
1150 		s->blocklen = 1;
1151 		s->diglen = 0;
1152 
1153 		if(s->c == 0)
1154 			error("must set fd before algorithm");
1155 
1156 		s->state = Sclear;
1157 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1158 		if(strcmp(p, "clear") == 0)
1159 			goto outx;
1160 
1161 		if(s->in.secret && s->out.secret == 0)
1162 			setsecret(&s->out, s->in.secret, s->in.slen);
1163 		if(s->out.secret && s->in.secret == 0)
1164 			setsecret(&s->in, s->out.secret, s->out.slen);
1165 		if(s->in.secret == 0 || s->out.secret == 0)
1166 			error("algorithm but no secret");
1167 
1168 		s->hf = 0;
1169 		s->encryptalg = Noencryption;
1170 		s->blocklen = 1;
1171 
1172 		for(;;){
1173 			np = strchr(p, ' ');
1174 			if(np)
1175 				*np++ = 0;
1176 
1177 			if(parsehashalg(p, s) < 0)
1178 			if(parseencryptalg(p, s) < 0)
1179 				error("bad algorithm");
1180 
1181 			if(np == 0)
1182 				break;
1183 			p = np;
1184 		}
1185 
1186 		if(s->hf == 0 && s->encryptalg == Noencryption)
1187 			error("bad algorithm");
1188 
1189 		if(s->blocklen != 1){
1190 			s->max = (1<<15) - s->diglen - 1;
1191 			s->max -= s->max % s->blocklen;
1192 			s->maxpad = (1<<14) - s->diglen - 1;
1193 			s->maxpad -= s->maxpad % s->blocklen;
1194 		} else
1195 			s->maxpad = s->max = (1<<15) - s->diglen - 1;
1196 	} else if(strcmp(buf, "secretin") == 0 && p != 0) {
1197 		l = (strlen(p)*3)/2;
1198 		x = smalloc(l);
1199 		t = dec64(x, l, p, strlen(p));
1200 		if(t <= 0)
1201 			error(Ebadarg);
1202 		setsecret(&s->in, x, t);
1203 	} else if(strcmp(buf, "secretout") == 0 && p != 0) {
1204 		l = (strlen(p)*3)/2 + 1;
1205 		x = smalloc(l);
1206 		t = dec64(x, l, p, strlen(p));
1207 		if(t <= 0)
1208 			error(Ebadarg);
1209 		setsecret(&s->out, x, t);
1210 	} else
1211 		error(Ebadarg);
1212 outx:
1213 	free(x);
1214 	poperror();
1215 out:
1216 	qunlock(&s->in.ctlq);
1217 	qunlock(&s->out.q);
1218 	poperror();
1219 	return n;
1220 }
1221 
1222 static void
sslinit(void)1223 sslinit(void)
1224 {
1225 	struct Encalg *e;
1226 	struct Hashalg *h;
1227 	int n;
1228 	char *cp;
1229 
1230 	n = 1;
1231 	for(e = encrypttab; e->name != nil; e++)
1232 		n += strlen(e->name) + 1;
1233 	cp = encalgs = smalloc(n);
1234 	for(e = encrypttab;;){
1235 		strcpy(cp, e->name);
1236 		cp += strlen(e->name);
1237 		e++;
1238 		if(e->name == nil)
1239 			break;
1240 		*cp++ = ' ';
1241 	}
1242 	*cp = 0;
1243 
1244 	n = 1;
1245 	for(h = hashtab; h->name != nil; h++)
1246 		n += strlen(h->name) + 1;
1247 	cp = hashalgs = smalloc(n);
1248 	for(h = hashtab;;){
1249 		strcpy(cp, h->name);
1250 		cp += strlen(h->name);
1251 		h++;
1252 		if(h->name == nil)
1253 			break;
1254 		*cp++ = ' ';
1255 	}
1256 	*cp = 0;
1257 }
1258 
1259 Dev ssldevtab = {
1260 	'D',
1261 	"ssl",
1262 
1263 	devreset,
1264 	sslinit,
1265 	devshutdown,
1266 	sslattach,
1267 	sslwalk,
1268 	sslstat,
1269 	sslopen,
1270 	devcreate,
1271 	sslclose,
1272 	sslread,
1273 	sslbread,
1274 	sslwrite,
1275 	sslbwrite,
1276 	devremove,
1277 	sslwstat,
1278 };
1279 
1280 static Block*
encryptb(Dstate * s,Block * b,int offset)1281 encryptb(Dstate *s, Block *b, int offset)
1282 {
1283 	uchar *p, *ep, *p2, *ip, *eip;
1284 	DESstate *ds;
1285 
1286 	switch(s->encryptalg){
1287 	case DESECB:
1288 		ds = s->out.state;
1289 		ep = b->rp + BLEN(b);
1290 		for(p = b->rp + offset; p < ep; p += 8)
1291 			block_cipher(ds->expanded, p, 0);
1292 		break;
1293 	case DESCBC:
1294 		ds = s->out.state;
1295 		ep = b->rp + BLEN(b);
1296 		for(p = b->rp + offset; p < ep; p += 8){
1297 			p2 = p;
1298 			ip = ds->ivec;
1299 			for(eip = ip+8; ip < eip; )
1300 				*p2++ ^= *ip++;
1301 			block_cipher(ds->expanded, p, 0);
1302 			memmove(ds->ivec, p, 8);
1303 		}
1304 		break;
1305 	case RC4:
1306 		rc4(s->out.state, b->rp + offset, BLEN(b) - offset);
1307 		break;
1308 	}
1309 	return b;
1310 }
1311 
1312 static Block*
decryptb(Dstate * s,Block * bin)1313 decryptb(Dstate *s, Block *bin)
1314 {
1315 	Block *b, **l;
1316 	uchar *p, *ep, *tp, *ip, *eip;
1317 	DESstate *ds;
1318 	uchar tmp[8];
1319 	int i;
1320 
1321 	l = &bin;
1322 	for(b = bin; b; b = b->next){
1323 		/* make sure we have a multiple of s->blocklen */
1324 		if(s->blocklen > 1){
1325 			i = BLEN(b);
1326 			if(i % s->blocklen){
1327 				*l = b = pullupblock(b, i + s->blocklen - (i%s->blocklen));
1328 				if(b == 0)
1329 					error("ssl encrypted message too short");
1330 			}
1331 		}
1332 		l = &b->next;
1333 
1334 		/* decrypt */
1335 		switch(s->encryptalg){
1336 		case DESECB:
1337 			ds = s->in.state;
1338 			ep = b->rp + BLEN(b);
1339 			for(p = b->rp; p < ep; p += 8)
1340 				block_cipher(ds->expanded, p, 1);
1341 			break;
1342 		case DESCBC:
1343 			ds = s->in.state;
1344 			ep = b->rp + BLEN(b);
1345 			for(p = b->rp; p < ep;){
1346 				memmove(tmp, p, 8);
1347 				block_cipher(ds->expanded, p, 1);
1348 				tp = tmp;
1349 				ip = ds->ivec;
1350 				for(eip = ip+8; ip < eip; ){
1351 					*p++ ^= *ip;
1352 					*ip++ = *tp++;
1353 				}
1354 			}
1355 			break;
1356 		case RC4:
1357 			rc4(s->in.state, b->rp, BLEN(b));
1358 			break;
1359 		}
1360 	}
1361 	return bin;
1362 }
1363 
1364 static Block*
digestb(Dstate * s,Block * b,int offset)1365 digestb(Dstate *s, Block *b, int offset)
1366 {
1367 	uchar *p;
1368 	DigestState ss;
1369 	uchar msgid[4];
1370 	ulong n, h;
1371 	OneWay *w;
1372 
1373 	w = &s->out;
1374 
1375 	memset(&ss, 0, sizeof(ss));
1376 	h = s->diglen + offset;
1377 	n = BLEN(b) - h;
1378 
1379 	/* hash secret + message */
1380 	(*s->hf)(w->secret, w->slen, 0, &ss);
1381 	(*s->hf)(b->rp + h, n, 0, &ss);
1382 
1383 	/* hash message id */
1384 	p = msgid;
1385 	n = w->mid;
1386 	*p++ = n>>24;
1387 	*p++ = n>>16;
1388 	*p++ = n>>8;
1389 	*p = n;
1390 	(*s->hf)(msgid, 4, b->rp + offset, &ss);
1391 
1392 	return b;
1393 }
1394 
1395 static void
checkdigestb(Dstate * s,Block * bin)1396 checkdigestb(Dstate *s, Block *bin)
1397 {
1398 	uchar *p;
1399 	DigestState ss;
1400 	uchar msgid[4];
1401 	int n, h;
1402 	OneWay *w;
1403 	uchar digest[128];
1404 	Block *b;
1405 
1406 	w = &s->in;
1407 
1408 	memset(&ss, 0, sizeof(ss));
1409 
1410 	/* hash secret */
1411 	(*s->hf)(w->secret, w->slen, 0, &ss);
1412 
1413 	/* hash message */
1414 	h = s->diglen;
1415 	for(b = bin; b; b = b->next){
1416 		n = BLEN(b) - h;
1417 		if(n < 0)
1418 			panic("checkdigestb");
1419 		(*s->hf)(b->rp + h, n, 0, &ss);
1420 		h = 0;
1421 	}
1422 
1423 	/* hash message id */
1424 	p = msgid;
1425 	n = w->mid;
1426 	*p++ = n>>24;
1427 	*p++ = n>>16;
1428 	*p++ = n>>8;
1429 	*p = n;
1430 	(*s->hf)(msgid, 4, digest, &ss);
1431 
1432 	if(memcmp(digest, bin->rp, s->diglen) != 0)
1433 		error("bad digest");
1434 }
1435 
1436 /* get channel associated with an fd */
1437 static Chan*
buftochan(char * p)1438 buftochan(char *p)
1439 {
1440 	Chan *c;
1441 	int fd;
1442 
1443 	if(p == 0)
1444 		error(Ebadarg);
1445 	fd = strtoul(p, 0, 0);
1446 	if(fd < 0)
1447 		error(Ebadarg);
1448 	c = fdtochan(fd, -1, 0, 1);	/* error check and inc ref */
1449 	if(c->dev == &ssldevtab){
1450 		cclose(c);
1451 		error("cannot ssl encrypt devssl files");
1452 	}
1453 	return c;
1454 }
1455 
1456 /* hand up a digest connection */
1457 static void
sslhangup(Dstate * s)1458 sslhangup(Dstate *s)
1459 {
1460 	Block *b;
1461 
1462 	qlock(&s->in.q);
1463 	for(b = s->processed; b; b = s->processed){
1464 		s->processed = b->next;
1465 		freeb(b);
1466 	}
1467 	if(s->unprocessed){
1468 		freeb(s->unprocessed);
1469 		s->unprocessed = 0;
1470 	}
1471 	s->state = Sincomplete;
1472 	qunlock(&s->in.q);
1473 }
1474 
1475 static Dstate*
dsclone(Chan * ch)1476 dsclone(Chan *ch)
1477 {
1478 	int i;
1479 	Dstate *ret;
1480 
1481 	if(waserror()) {
1482 		unlock(&dslock);
1483 		nexterror();
1484 	}
1485 	lock(&dslock);
1486 	ret = nil;
1487 	for(i=0; i<Maxdstate; i++){
1488 		if(dstate[i] == nil){
1489 			dsnew(ch, &dstate[i]);
1490 			ret = dstate[i];
1491 			break;
1492 		}
1493 	}
1494 	unlock(&dslock);
1495 	poperror();
1496 	return ret;
1497 }
1498 
1499 static void
dsnew(Chan * ch,Dstate ** pp)1500 dsnew(Chan *ch, Dstate **pp)
1501 {
1502 	Dstate *s;
1503 	int t;
1504 
1505 	*pp = s = malloc(sizeof(*s));
1506 	if(!s)
1507 		error(Enomem);
1508 	if(pp - dstate >= dshiwat)
1509 		dshiwat++;
1510 	memset(s, 0, sizeof(*s));
1511 	s->state = Sincomplete;
1512 	s->ref = 1;
1513 	kstrdup(&s->user, up->user);
1514 	s->perm = 0660;
1515 	t = TYPE(ch->qid);
1516 	if(t == Qclonus)
1517 		t = Qctl;
1518 	ch->qid.path = QID(pp - dstate, t);
1519 	ch->qid.vers = 0;
1520 }
1521