/*
 *  devssl - secure sockets layer
 *
 *  source: /plan9/sys/src/9/port/devssl.c
 *  (revision 4e3613ab15c331a9ada113286cc0f2a35bc0373d)
 */
#include	"u.h"
#include	"../port/lib.h"
#include	"mem.h"
#include	"dat.h"
#include	"fns.h"
#include	"../port/error.h"

#include	<libsec.h>
12 
13 #define NOSPOOKS 1
14 
15 typedef struct OneWay OneWay;
16 struct OneWay
17 {
18 	QLock	q;
19 	QLock	ctlq;
20 
21 	void	*state;		/* encryption state */
22 	int	slen;		/* hash data length */
23 	uchar	*secret;	/* secret */
24 	ulong	mid;		/* message id */
25 };
26 
27 enum
28 {
29 	/* connection states */
30 	Sincomplete=	0,
31 	Sclear=		1,
32 	Sencrypting=	2,
33 	Sdigesting=	4,
34 	Sdigenc=	Sencrypting|Sdigesting,
35 
36 	/* encryption algorithms */
37 	Noencryption=	0,
38 	DESCBC=		1,
39 	DESECB=		2,
40 	RC4=		3
41 };
42 
43 typedef struct Dstate Dstate;
44 struct Dstate
45 {
46 	Chan	*c;		/* io channel */
47 	uchar	state;		/* state of connection */
48 	int	ref;		/* serialized by dslock for atomic destroy */
49 
50 	uchar	encryptalg;	/* encryption algorithm */
51 	ushort	blocklen;	/* blocking length */
52 
53 	ushort	diglen;		/* length of digest */
54 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);	/* hash func */
55 
56 	/* for SSL format */
57 	int	max;		/* maximum unpadded data per msg */
58 	int	maxpad;		/* maximum padded data per msg */
59 
60 	/* input side */
61 	OneWay	in;
62 	Block	*processed;
63 	Block	*unprocessed;
64 
65 	/* output side */
66 	OneWay	out;
67 
68 	/* protections */
69 	char	*user;
70 	int	perm;
71 };
72 
73 enum
74 {
75 	Maxdmsg=	1<<16,
76 	Maxdstate=	512,	/* max. open ssl conn's; must be a power of 2 */
77 };
78 
79 static	Lock	dslock;
80 static	int	dshiwat;
81 static	char	*dsname[Maxdstate];
82 static	Dstate	*dstate[Maxdstate];
83 static	char	*encalgs;
84 static	char	*hashalgs;
85 
86 enum{
87 	Qtopdir		= 1,	/* top level directory */
88 	Qprotodir,
89 	Qclonus,
90 	Qconvdir,		/* directory for a conversation */
91 	Qdata,
92 	Qctl,
93 	Qsecretin,
94 	Qsecretout,
95 	Qencalgs,
96 	Qhashalgs,
97 };
98 
99 #define TYPE(x) 	((x).path & 0xf)
100 #define CONV(x) 	(((x).path >> 5)&(Maxdstate-1))
101 #define QID(c, y) 	(((c)<<5) | (y))
102 
103 static void	ensure(Dstate*, Block**, int);
104 static void	consume(Block**, uchar*, int);
105 static void	setsecret(OneWay*, uchar*, int);
106 static Block*	encryptb(Dstate*, Block*, int);
107 static Block*	decryptb(Dstate*, Block*);
108 static Block*	digestb(Dstate*, Block*, int);
109 static void	checkdigestb(Dstate*, Block*);
110 static Chan*	buftochan(char*);
111 static void	sslhangup(Dstate*);
112 static Dstate*	dsclone(Chan *c);
113 static void	dsnew(Chan *c, Dstate **);
114 static long	sslput(Dstate *s, Block * volatile b);
115 
116 char *sslnames[] = {
117 [Qclonus]	"clone",
118 [Qdata]		"data",
119 [Qctl]		"ctl",
120 [Qsecretin]	"secretin",
121 [Qsecretout]	"secretout",
122 [Qencalgs]	"encalgs",
123 [Qhashalgs]	"hashalgs",
124 };
125 
126 static int
sslgen(Chan * c,char *,Dirtab * d,int nd,int s,Dir * dp)127 sslgen(Chan *c, char*, Dirtab *d, int nd, int s, Dir *dp)
128 {
129 	Qid q;
130 	Dstate *ds;
131 	char name[16], *p, *nm;
132 	int ft;
133 
134 	USED(nd);
135 	USED(d);
136 
137 	q.type = QTFILE;
138 	q.vers = 0;
139 
140 	ft = TYPE(c->qid);
141 	switch(ft) {
142 	case Qtopdir:
143 		if(s == DEVDOTDOT){
144 			q.path = QID(0, Qtopdir);
145 			q.type = QTDIR;
146 			devdir(c, q, "#D", 0, eve, 0555, dp);
147 			return 1;
148 		}
149 		if(s > 0)
150 			return -1;
151 		q.path = QID(0, Qprotodir);
152 		q.type = QTDIR;
153 		devdir(c, q, "ssl", 0, eve, 0555, dp);
154 		return 1;
155 	case Qprotodir:
156 		if(s == DEVDOTDOT){
157 			q.path = QID(0, Qtopdir);
158 			q.type = QTDIR;
159 			devdir(c, q, ".", 0, eve, 0555, dp);
160 			return 1;
161 		}
162 		if(s < dshiwat) {
163 			q.path = QID(s, Qconvdir);
164 			q.type = QTDIR;
165 			ds = dstate[s];
166 			if(ds != 0)
167 				nm = ds->user;
168 			else
169 				nm = eve;
170 			if(dsname[s] == nil){
171 				snprint(name, sizeof name, "%d", s);
172 				kstrdup(&dsname[s], name);
173 			}
174 			devdir(c, q, dsname[s], 0, nm, 0555, dp);
175 			return 1;
176 		}
177 		if(s > dshiwat)
178 			return -1;
179 		q.path = QID(0, Qclonus);
180 		devdir(c, q, "clone", 0, eve, 0555, dp);
181 		return 1;
182 	case Qconvdir:
183 		if(s == DEVDOTDOT){
184 			q.path = QID(0, Qprotodir);
185 			q.type = QTDIR;
186 			devdir(c, q, "ssl", 0, eve, 0555, dp);
187 			return 1;
188 		}
189 		ds = dstate[CONV(c->qid)];
190 		if(ds != 0)
191 			nm = ds->user;
192 		else
193 			nm = eve;
194 		switch(s) {
195 		default:
196 			return -1;
197 		case 0:
198 			q.path = QID(CONV(c->qid), Qctl);
199 			p = "ctl";
200 			break;
201 		case 1:
202 			q.path = QID(CONV(c->qid), Qdata);
203 			p = "data";
204 			break;
205 		case 2:
206 			q.path = QID(CONV(c->qid), Qsecretin);
207 			p = "secretin";
208 			break;
209 		case 3:
210 			q.path = QID(CONV(c->qid), Qsecretout);
211 			p = "secretout";
212 			break;
213 		case 4:
214 			q.path = QID(CONV(c->qid), Qencalgs);
215 			p = "encalgs";
216 			break;
217 		case 5:
218 			q.path = QID(CONV(c->qid), Qhashalgs);
219 			p = "hashalgs";
220 			break;
221 		}
222 		devdir(c, q, p, 0, nm, 0660, dp);
223 		return 1;
224 	case Qclonus:
225 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, eve, 0555, dp);
226 		return 1;
227 	default:
228 		ds = dstate[CONV(c->qid)];
229 		if(ds != 0)
230 			nm = ds->user;
231 		else
232 			nm = eve;
233 		devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, nm, 0660, dp);
234 		return 1;
235 	}
236 }
237 
238 static Chan*
sslattach(char * spec)239 sslattach(char *spec)
240 {
241 	Chan *c;
242 
243 	c = devattach('D', spec);
244 	c->qid.path = QID(0, Qtopdir);
245 	c->qid.vers = 0;
246 	c->qid.type = QTDIR;
247 	return c;
248 }
249 
250 static Walkqid*
sslwalk(Chan * c,Chan * nc,char ** name,int nname)251 sslwalk(Chan *c, Chan *nc, char **name, int nname)
252 {
253 	return devwalk(c, nc, name, nname, nil, 0, sslgen);
254 }
255 
256 static int
sslstat(Chan * c,uchar * db,int n)257 sslstat(Chan *c, uchar *db, int n)
258 {
259 	return devstat(c, db, n, nil, 0, sslgen);
260 }
261 
262 static Chan*
sslopen(Chan * c,int omode)263 sslopen(Chan *c, int omode)
264 {
265 	Dstate *s, **pp;
266 	int perm;
267 	int ft;
268 
269 	perm = 0;
270 	omode &= 3;
271 	switch(omode) {
272 	case OREAD:
273 		perm = 4;
274 		break;
275 	case OWRITE:
276 		perm = 2;
277 		break;
278 	case ORDWR:
279 		perm = 6;
280 		break;
281 	}
282 
283 	ft = TYPE(c->qid);
284 	switch(ft) {
285 	default:
286 		panic("sslopen");
287 	case Qtopdir:
288 	case Qprotodir:
289 	case Qconvdir:
290 		if(omode != OREAD)
291 			error(Eperm);
292 		break;
293 	case Qclonus:
294 		s = dsclone(c);
295 		if(s == 0)
296 			error(Enodev);
297 		break;
298 	case Qctl:
299 	case Qdata:
300 	case Qsecretin:
301 	case Qsecretout:
302 		if(waserror()) {
303 			unlock(&dslock);
304 			nexterror();
305 		}
306 		lock(&dslock);
307 		pp = &dstate[CONV(c->qid)];
308 		s = *pp;
309 		if(s == 0)
310 			dsnew(c, pp);
311 		else {
312 			if((perm & (s->perm>>6)) != perm
313 			   && (strcmp(up->user, s->user) != 0
314 			     || (perm & s->perm) != perm))
315 				error(Eperm);
316 
317 			s->ref++;
318 		}
319 		unlock(&dslock);
320 		poperror();
321 		break;
322 	case Qencalgs:
323 	case Qhashalgs:
324 		if(omode != OREAD)
325 			error(Eperm);
326 		break;
327 	}
328 	c->mode = openmode(omode);
329 	c->flag |= COPEN;
330 	c->offset = 0;
331 	return c;
332 }
333 
334 static int
sslwstat(Chan * c,uchar * db,int n)335 sslwstat(Chan *c, uchar *db, int n)
336 {
337 	Dir *dir;
338 	Dstate *s;
339 	int m;
340 
341 	s = dstate[CONV(c->qid)];
342 	if(s == 0)
343 		error(Ebadusefd);
344 	if(strcmp(s->user, up->user) != 0)
345 		error(Eperm);
346 
347 	dir = smalloc(sizeof(Dir)+n);
348 	m = convM2D(db, n, &dir[0], (char*)&dir[1]);
349 	if(m == 0){
350 		free(dir);
351 		error(Eshortstat);
352 	}
353 
354 	if(!emptystr(dir->uid))
355 		kstrdup(&s->user, dir->uid);
356 	if(dir->mode != ~0UL)
357 		s->perm = dir->mode;
358 
359 	free(dir);
360 	return m;
361 }
362 
363 static void
sslclose(Chan * c)364 sslclose(Chan *c)
365 {
366 	Dstate *s;
367 	int ft;
368 
369 	ft = TYPE(c->qid);
370 	switch(ft) {
371 	case Qctl:
372 	case Qdata:
373 	case Qsecretin:
374 	case Qsecretout:
375 		if((c->flag & COPEN) == 0)
376 			break;
377 
378 		s = dstate[CONV(c->qid)];
379 		if(s == 0)
380 			break;
381 
382 		lock(&dslock);
383 		if(--s->ref > 0) {
384 			unlock(&dslock);
385 			break;
386 		}
387 		dstate[CONV(c->qid)] = 0;
388 		unlock(&dslock);
389 
390 		if(s->user != nil)
391 			free(s->user);
392 		sslhangup(s);
393 		if(s->c)
394 			cclose(s->c);
395 		if(s->in.secret)
396 			free(s->in.secret);
397 		if(s->out.secret)
398 			free(s->out.secret);
399 		if(s->in.state)
400 			free(s->in.state);
401 		if(s->out.state)
402 			free(s->out.state);
403 		free(s);
404 
405 	}
406 }
407 
408 /*
409  *  make sure we have at least 'n' bytes in list 'l'
410  */
411 static void
ensure(Dstate * s,Block ** l,int n)412 ensure(Dstate *s, Block **l, int n)
413 {
414 	int sofar, i;
415 	Block *b, *bl;
416 
417 	sofar = 0;
418 	for(b = *l; b; b = b->next){
419 		sofar += BLEN(b);
420 		if(sofar >= n)
421 			return;
422 		l = &b->next;
423 	}
424 
425 	while(sofar < n){
426 		bl = devtab[s->c->type]->bread(s->c, Maxdmsg, 0);
427 		if(bl == 0)
428 			nexterror();
429 		*l = bl;
430 		i = 0;
431 		for(b = bl; b; b = b->next){
432 			i += BLEN(b);
433 			l = &b->next;
434 		}
435 		if(i == 0)
436 			error(Ehungup);
437 		sofar += i;
438 	}
439 }
440 
441 /*
442  *  copy 'n' bytes from 'l' into 'p' and free
443  *  the bytes in 'l'
444  */
445 static void
consume(Block ** l,uchar * p,int n)446 consume(Block **l, uchar *p, int n)
447 {
448 	Block *b;
449 	int i;
450 
451 	for(; *l && n > 0; n -= i){
452 		b = *l;
453 		i = BLEN(b);
454 		if(i > n)
455 			i = n;
456 		memmove(p, b->rp, i);
457 		b->rp += i;
458 		p += i;
459 		if(BLEN(b) < 0)
460 			panic("consume");
461 		if(BLEN(b))
462 			break;
463 		*l = b->next;
464 		freeb(b);
465 	}
466 }
467 
/*
 *  give back n bytes (currently unused; kept for reference)
static void
regurgitate(Dstate *s, uchar *p, int n)
{
	Block *b;

	if(n <= 0)
		return;
	b = s->unprocessed;
	if(s->unprocessed == nil || b->rp - b->base < n) {
		b = allocb(n);
		memmove(b->wp, p, n);
		b->wp += n;
		b->next = s->unprocessed;
		s->unprocessed = b;
	} else {
		b->rp -= n;
		memmove(b->rp, p, n);
	}
}
 */
490 
491 /*
492  *  remove at most n bytes from the queue, if discard is set
493  *  dump the remainder
494  */
495 static Block*
qtake(Block ** l,int n,int discard)496 qtake(Block **l, int n, int discard)
497 {
498 	Block *nb, *b, *first;
499 	int i;
500 
501 	first = *l;
502 	for(b = first; b; b = b->next){
503 		i = BLEN(b);
504 		if(i == n){
505 			if(discard){
506 				freeblist(b->next);
507 				*l = 0;
508 			} else
509 				*l = b->next;
510 			b->next = 0;
511 			return first;
512 		} else if(i > n){
513 			i -= n;
514 			if(discard){
515 				freeblist(b->next);
516 				b->wp -= i;
517 				*l = 0;
518 			} else {
519 				nb = allocb(i);
520 				memmove(nb->wp, b->rp+n, i);
521 				nb->wp += i;
522 				b->wp -= i;
523 				nb->next = b->next;
524 				*l = nb;
525 			}
526 			b->next = 0;
527 			if(BLEN(b) < 0)
528 				panic("qtake");
529 			return first;
530 		} else
531 			n -= i;
532 		if(BLEN(b) < 0)
533 			panic("qtake");
534 	}
535 	*l = 0;
536 	return first;
537 }
538 
539 /*
540  *  We can't let Eintr's lose data since the program
541  *  doing the read may be able to handle it.  The only
542  *  places Eintr is possible is during the read's in consume.
543  *  Therefore, we make sure we can always put back the bytes
544  *  consumed before the last ensure.
545  */
546 static Block*
sslbread(Chan * c,long n,ulong)547 sslbread(Chan *c, long n, ulong)
548 {
549 	Dstate * volatile s;
550 	Block *b;
551 	uchar consumed[3], *p;
552 	int toconsume;
553 	int len, pad;
554 
555 	s = dstate[CONV(c->qid)];
556 	if(s == 0)
557 		panic("sslbread");
558 	if(s->state == Sincomplete)
559 		error(Ebadusefd);
560 
561 	qlock(&s->in.q);
562 	if(waserror()){
563 		qunlock(&s->in.q);
564 		nexterror();
565 	}
566 
567 	if(s->processed == 0){
568 		/*
569 		 * Read in the whole message.  Until we've got it all,
570 		 * it stays on s->unprocessed, so that if we get Eintr,
571 		 * we'll pick up where we left off.
572 		 */
573 		ensure(s, &s->unprocessed, 3);
574 		s->unprocessed = pullupblock(s->unprocessed, 2);
575 		p = s->unprocessed->rp;
576 		if(p[0] & 0x80){
577 			len = ((p[0] & 0x7f)<<8) | p[1];
578 			ensure(s, &s->unprocessed, len);
579 			pad = 0;
580 			toconsume = 2;
581 		} else {
582 			s->unprocessed = pullupblock(s->unprocessed, 3);
583 			len = ((p[0] & 0x3f)<<8) | p[1];
584 			pad = p[2];
585 			if(pad > len){
586 				print("pad %d buf len %d\n", pad, len);
587 				error("bad pad in ssl message");
588 			}
589 			toconsume = 3;
590 		}
591 		ensure(s, &s->unprocessed, toconsume+len);
592 
593 		/* skip header */
594 		consume(&s->unprocessed, consumed, toconsume);
595 
596 		/* grab the next message and decode/decrypt it */
597 		b = qtake(&s->unprocessed, len, 0);
598 
599 		if(blocklen(b) != len)
600 			print("devssl: sslbread got wrong count %d != %d", blocklen(b), len);
601 
602 		if(waserror()){
603 			qunlock(&s->in.ctlq);
604 			if(b != nil)
605 				freeb(b);
606 			nexterror();
607 		}
608 		qlock(&s->in.ctlq);
609 		switch(s->state){
610 		case Sencrypting:
611 			if(b == nil)
612 				error("ssl message too short (encrypting)");
613 			b = decryptb(s, b);
614 			break;
615 		case Sdigesting:
616 			b = pullupblock(b, s->diglen);
617 			if(b == nil)
618 				error("ssl message too short (digesting)");
619 			checkdigestb(s, b);
620 			pullblock(&b, s->diglen);
621 			len -= s->diglen;
622 			break;
623 		case Sdigenc:
624 			b = decryptb(s, b);
625 			b = pullupblock(b, s->diglen);
626 			if(b == nil)
627 				error("ssl message too short (dig+enc)");
628 			checkdigestb(s, b);
629 			pullblock(&b, s->diglen);
630 			len -= s->diglen;
631 			break;
632 		}
633 
634 		/* remove pad */
635 		if(pad)
636 			s->processed = qtake(&b, len - pad, 1);
637 		else
638 			s->processed = b;
639 		b = nil;
640 		s->in.mid++;
641 		qunlock(&s->in.ctlq);
642 		poperror();
643 	}
644 
645 	/* return at most what was asked for */
646 	b = qtake(&s->processed, n, 0);
647 
648 	qunlock(&s->in.q);
649 	poperror();
650 
651 	return b;
652 }
653 
654 static long
sslread(Chan * c,void * a,long n,vlong off)655 sslread(Chan *c, void *a, long n, vlong off)
656 {
657 	Block * volatile b;
658 	Block *nb;
659 	uchar *va;
660 	int i;
661 	char buf[128];
662 	ulong offset = off;
663 	int ft;
664 
665 	if(c->qid.type & QTDIR)
666 		return devdirread(c, a, n, 0, 0, sslgen);
667 
668 	ft = TYPE(c->qid);
669 	switch(ft) {
670 	default:
671 		error(Ebadusefd);
672 	case Qctl:
673 		ft = CONV(c->qid);
674 		snprint(buf, sizeof buf, "%d", ft);
675 		return readstr(offset, a, n, buf);
676 	case Qdata:
677 		b = sslbread(c, n, offset);
678 		break;
679 	case Qencalgs:
680 		return readstr(offset, a, n, encalgs);
681 		break;
682 	case Qhashalgs:
683 		return readstr(offset, a, n, hashalgs);
684 		break;
685 	}
686 
687 	if(waserror()){
688 		freeblist(b);
689 		nexterror();
690 	}
691 
692 	n = 0;
693 	va = a;
694 	for(nb = b; nb; nb = nb->next){
695 		i = BLEN(nb);
696 		memmove(va+n, nb->rp, i);
697 		n += i;
698 	}
699 
700 	freeblist(b);
701 	poperror();
702 
703 	return n;
704 }
705 
706 /*
707  *  this algorithm doesn't have to be great since we're just
708  *  trying to obscure the block fill
709  */
710 static void
randfill(uchar * buf,int len)711 randfill(uchar *buf, int len)
712 {
713 	while(len-- > 0)
714 		*buf++ = nrand(256);
715 }
716 
717 static long
sslbwrite(Chan * c,Block * b,ulong)718 sslbwrite(Chan *c, Block *b, ulong)
719 {
720 	Dstate * volatile s;
721 	long rv;
722 
723 	s = dstate[CONV(c->qid)];
724 	if(s == nil)
725 		panic("sslbwrite");
726 
727 	if(s->state == Sincomplete){
728 		freeb(b);
729 		error(Ebadusefd);
730 	}
731 
732 	/* lock so split writes won't interleave */
733 	if(waserror()){
734 		qunlock(&s->out.q);
735 		nexterror();
736 	}
737 	qlock(&s->out.q);
738 
739 	rv = sslput(s, b);
740 
741 	poperror();
742 	qunlock(&s->out.q);
743 
744 	return rv;
745 }
746 
747 /*
748  *  use SSL record format, add in count, digest and/or encrypt.
749  *  the write is interruptable.  if it is interrupted, we'll
750  *  get out of sync with the far side.  not much we can do about
751  *  it since we don't know if any bytes have been written.
752  */
753 static long
sslput(Dstate * s,Block * volatile b)754 sslput(Dstate *s, Block * volatile b)
755 {
756 	Block *nb;
757 	int h, n, m, pad, rv;
758 	uchar *p;
759 	int offset;
760 
761 	if(waserror()){
762 		if(b != nil)
763 			freeb(b);
764 		nexterror();
765 	}
766 
767 	rv = 0;
768 	while(b != nil){
769 		m = n = BLEN(b);
770 		h = s->diglen + 2;
771 
772 		/* trim to maximum block size */
773 		pad = 0;
774 		if(m > s->max){
775 			m = s->max;
776 		} else if(s->blocklen != 1){
777 			pad = (m + s->diglen)%s->blocklen;
778 			if(pad){
779 				if(m > s->maxpad){
780 					pad = 0;
781 					m = s->maxpad;
782 				} else {
783 					pad = s->blocklen - pad;
784 					h++;
785 				}
786 			}
787 		}
788 
789 		rv += m;
790 		if(m != n){
791 			nb = allocb(m + h + pad);
792 			memmove(nb->wp + h, b->rp, m);
793 			nb->wp += m + h;
794 			b->rp += m;
795 		} else {
796 			/* add header space */
797 			nb = padblock(b, h);
798 			b = 0;
799 		}
800 		m += s->diglen;
801 
802 		/* SSL style count */
803 		if(pad){
804 			nb = padblock(nb, -pad);
805 			randfill(nb->wp, pad);
806 			nb->wp += pad;
807 			m += pad;
808 
809 			p = nb->rp;
810 			p[0] = (m>>8);
811 			p[1] = m;
812 			p[2] = pad;
813 			offset = 3;
814 		} else {
815 			p = nb->rp;
816 			p[0] = (m>>8) | 0x80;
817 			p[1] = m;
818 			offset = 2;
819 		}
820 
821 		switch(s->state){
822 		case Sencrypting:
823 			nb = encryptb(s, nb, offset);
824 			break;
825 		case Sdigesting:
826 			nb = digestb(s, nb, offset);
827 			break;
828 		case Sdigenc:
829 			nb = digestb(s, nb, offset);
830 			nb = encryptb(s, nb, offset);
831 			break;
832 		}
833 
834 		s->out.mid++;
835 
836 		m = BLEN(nb);
837 		devtab[s->c->type]->bwrite(s->c, nb, s->c->offset);
838 		s->c->offset += m;
839 	}
840 
841 	poperror();
842 	return rv;
843 }
844 
845 static void
setsecret(OneWay * w,uchar * secret,int n)846 setsecret(OneWay *w, uchar *secret, int n)
847 {
848 	if(w->secret)
849 		free(w->secret);
850 
851 	w->secret = smalloc(n);
852 	memmove(w->secret, secret, n);
853 	w->slen = n;
854 }
855 
856 static void
initDESkey(OneWay * w)857 initDESkey(OneWay *w)
858 {
859 	if(w->state){
860 		free(w->state);
861 		w->state = 0;
862 	}
863 
864 	w->state = smalloc(sizeof(DESstate));
865 	if(w->slen >= 16)
866 		setupDESstate(w->state, w->secret, w->secret+8);
867 	else if(w->slen >= 8)
868 		setupDESstate(w->state, w->secret, 0);
869 	else
870 		error("secret too short");
871 }
872 
873 /*
874  *  40 bit DES is the same as 56 bit DES.  However,
875  *  16 bits of the key are masked to zero.
876  */
877 static void
initDESkey_40(OneWay * w)878 initDESkey_40(OneWay *w)
879 {
880 	uchar key[8];
881 
882 	if(w->state){
883 		free(w->state);
884 		w->state = 0;
885 	}
886 
887 	if(w->slen >= 8){
888 		memmove(key, w->secret, 8);
889 		key[0] &= 0x0f;
890 		key[2] &= 0x0f;
891 		key[4] &= 0x0f;
892 		key[6] &= 0x0f;
893 	}
894 
895 	w->state = malloc(sizeof(DESstate));
896 	if(w->state == nil)
897 		error(Enomem);
898 	if(w->slen >= 16)
899 		setupDESstate(w->state, key, w->secret+8);
900 	else if(w->slen >= 8)
901 		setupDESstate(w->state, key, 0);
902 	else
903 		error("secret too short");
904 }
905 
906 static void
initRC4key(OneWay * w)907 initRC4key(OneWay *w)
908 {
909 	if(w->state){
910 		free(w->state);
911 		w->state = 0;
912 	}
913 
914 	w->state = smalloc(sizeof(RC4state));
915 	setupRC4state(w->state, w->secret, w->slen);
916 }
917 
918 /*
919  *  40 bit RC4 is the same as n-bit RC4.  However,
920  *  we ignore all but the first 40 bits of the key.
921  */
922 static void
initRC4key_40(OneWay * w)923 initRC4key_40(OneWay *w)
924 {
925 	if(w->state){
926 		free(w->state);
927 		w->state = 0;
928 	}
929 
930 	if(w->slen > 5)
931 		w->slen = 5;
932 
933 	w->state = malloc(sizeof(RC4state));
934 	if(w->state == nil)
935 		error(Enomem);
936 	setupRC4state(w->state, w->secret, w->slen);
937 }
938 
939 /*
940  *  128 bit RC4 is the same as n-bit RC4.  However,
941  *  we ignore all but the first 128 bits of the key.
942  */
943 static void
initRC4key_128(OneWay * w)944 initRC4key_128(OneWay *w)
945 {
946 	if(w->state){
947 		free(w->state);
948 		w->state = 0;
949 	}
950 
951 	if(w->slen > 16)
952 		w->slen = 16;
953 
954 	w->state = malloc(sizeof(RC4state));
955 	if(w->state == nil)
956 		error(Enomem);
957 	setupRC4state(w->state, w->secret, w->slen);
958 }
959 
960 
961 typedef struct Hashalg Hashalg;
962 struct Hashalg
963 {
964 	char	*name;
965 	int	diglen;
966 	DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);
967 };
968 
969 Hashalg hashtab[] =
970 {
971 	{ "md4", MD4dlen, md4, },
972 	{ "md5", MD5dlen, md5, },
973 	{ "sha1", SHA1dlen, sha1, },
974 	{ "sha", SHA1dlen, sha1, },
975 	{ 0 }
976 };
977 
978 static int
parsehashalg(char * p,Dstate * s)979 parsehashalg(char *p, Dstate *s)
980 {
981 	Hashalg *ha;
982 
983 	for(ha = hashtab; ha->name; ha++){
984 		if(strcmp(p, ha->name) == 0){
985 			s->hf = ha->hf;
986 			s->diglen = ha->diglen;
987 			s->state &= ~Sclear;
988 			s->state |= Sdigesting;
989 			return 0;
990 		}
991 	}
992 	return -1;
993 }
994 
995 typedef struct Encalg Encalg;
996 struct Encalg
997 {
998 	char	*name;
999 	int	blocklen;
1000 	int	alg;
1001 	void	(*keyinit)(OneWay*);
1002 };
1003 
1004 #ifdef NOSPOOKS
1005 Encalg encrypttab[] =
1006 {
1007 	{ "descbc", 8, DESCBC, initDESkey, },           /* DEPRECATED -- use des_56_cbc */
1008 	{ "desecb", 8, DESECB, initDESkey, },           /* DEPRECATED -- use des_56_ecb */
1009 	{ "des_56_cbc", 8, DESCBC, initDESkey, },
1010 	{ "des_56_ecb", 8, DESECB, initDESkey, },
1011 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1012 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1013 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1014 	{ "rc4_256", 1, RC4, initRC4key, },
1015 	{ "rc4_128", 1, RC4, initRC4key_128, },
1016 	{ "rc4_40", 1, RC4, initRC4key_40, },
1017 	{ 0 }
1018 };
1019 #else
1020 Encalg encrypttab[] =
1021 {
1022 	{ "des_40_cbc", 8, DESCBC, initDESkey_40, },
1023 	{ "des_40_ecb", 8, DESECB, initDESkey_40, },
1024 	{ "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1025 	{ "rc4_40", 1, RC4, initRC4key_40, },
1026 	{ 0 }
1027 };
1028 #endif NOSPOOKS
1029 
1030 static int
parseencryptalg(char * p,Dstate * s)1031 parseencryptalg(char *p, Dstate *s)
1032 {
1033 	Encalg *ea;
1034 
1035 	for(ea = encrypttab; ea->name; ea++){
1036 		if(strcmp(p, ea->name) == 0){
1037 			s->encryptalg = ea->alg;
1038 			s->blocklen = ea->blocklen;
1039 			(*ea->keyinit)(&s->in);
1040 			(*ea->keyinit)(&s->out);
1041 			s->state &= ~Sclear;
1042 			s->state |= Sencrypting;
1043 			return 0;
1044 		}
1045 	}
1046 	return -1;
1047 }
1048 
1049 static long
sslwrite(Chan * c,void * a,long n,vlong)1050 sslwrite(Chan *c, void *a, long n, vlong)
1051 {
1052 	Dstate * volatile s;
1053 	Block * volatile b;
1054 	int m, t;
1055 	char *p, *np, *e, buf[128];
1056 	uchar *x;
1057 
1058 	x = nil;
1059 	s = dstate[CONV(c->qid)];
1060 	if(s == 0)
1061 		panic("sslwrite");
1062 
1063 	t = TYPE(c->qid);
1064 	if(t == Qdata){
1065 		if(s->state == Sincomplete)
1066 			error(Ebadusefd);
1067 
1068 		/* lock should a write gets split over multiple records */
1069 		if(waserror()){
1070 			qunlock(&s->out.q);
1071 			nexterror();
1072 		}
1073 		qlock(&s->out.q);
1074 
1075 		p = a;
1076 		e = p + n;
1077 		do {
1078 			m = e - p;
1079 			if(m > s->max)
1080 				m = s->max;
1081 
1082 			b = allocb(m);
1083 			if(waserror()){
1084 				freeb(b);
1085 				nexterror();
1086 			}
1087 			memmove(b->wp, p, m);
1088 			poperror();
1089 			b->wp += m;
1090 
1091 			sslput(s, b);
1092 
1093 			p += m;
1094 		} while(p < e);
1095 
1096 		poperror();
1097 		qunlock(&s->out.q);
1098 		return n;
1099 	}
1100 
1101 	/* mutex with operations using what we're about to change */
1102 	if(waserror()){
1103 		qunlock(&s->in.ctlq);
1104 		qunlock(&s->out.q);
1105 		nexterror();
1106 	}
1107 	qlock(&s->in.ctlq);
1108 	qlock(&s->out.q);
1109 
1110 	switch(t){
1111 	default:
1112 		panic("sslwrite");
1113 	case Qsecretin:
1114 		setsecret(&s->in, a, n);
1115 		goto out;
1116 	case Qsecretout:
1117 		setsecret(&s->out, a, n);
1118 		goto out;
1119 	case Qctl:
1120 		break;
1121 	}
1122 
1123 	if(n >= sizeof(buf))
1124 		error("arg too long");
1125 	strncpy(buf, a, n);
1126 	buf[n] = 0;
1127 	p = strchr(buf, '\n');
1128 	if(p)
1129 		*p = 0;
1130 	p = strchr(buf, ' ');
1131 	if(p)
1132 		*p++ = 0;
1133 
1134 	if(waserror()){
1135 		free(x);
1136 		nexterror();
1137 	}
1138 	if(strcmp(buf, "fd") == 0){
1139 		s->c = buftochan(p);
1140 
1141 		/* default is clear (msg delimiters only) */
1142 		s->state = Sclear;
1143 		s->blocklen = 1;
1144 		s->diglen = 0;
1145 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1146 		s->in.mid = 0;
1147 		s->out.mid = 0;
1148 	} else if(strcmp(buf, "alg") == 0 && p != 0){
1149 		s->blocklen = 1;
1150 		s->diglen = 0;
1151 
1152 		if(s->c == 0)
1153 			error("must set fd before algorithm");
1154 
1155 		s->state = Sclear;
1156 		s->maxpad = s->max = (1<<15) - s->diglen - 1;
1157 		if(strcmp(p, "clear") == 0)
1158 			goto outx;
1159 
1160 		if(s->in.secret && s->out.secret == 0)
1161 			setsecret(&s->out, s->in.secret, s->in.slen);
1162 		if(s->out.secret && s->in.secret == 0)
1163 			setsecret(&s->in, s->out.secret, s->out.slen);
1164 		if(s->in.secret == 0 || s->out.secret == 0)
1165 			error("algorithm but no secret");
1166 
1167 		s->hf = 0;
1168 		s->encryptalg = Noencryption;
1169 		s->blocklen = 1;
1170 
1171 		for(;;){
1172 			np = strchr(p, ' ');
1173 			if(np)
1174 				*np++ = 0;
1175 
1176 			if(parsehashalg(p, s) < 0)
1177 			if(parseencryptalg(p, s) < 0)
1178 				error("bad algorithm");
1179 
1180 			if(np == 0)
1181 				break;
1182 			p = np;
1183 		}
1184 
1185 		if(s->hf == 0 && s->encryptalg == Noencryption)
1186 			error("bad algorithm");
1187 
1188 		if(s->blocklen != 1){
1189 			s->max = (1<<15) - s->diglen - 1;
1190 			s->max -= s->max % s->blocklen;
1191 			s->maxpad = (1<<14) - s->diglen - 1;
1192 			s->maxpad -= s->maxpad % s->blocklen;
1193 		} else
1194 			s->maxpad = s->max = (1<<15) - s->diglen - 1;
1195 	} else if(strcmp(buf, "secretin") == 0 && p != 0) {
1196 		m = (strlen(p)*3)/2;
1197 		x = smalloc(m);
1198 		t = dec64(x, m, p, strlen(p));
1199 		if(t <= 0)
1200 			error(Ebadarg);
1201 		setsecret(&s->in, x, t);
1202 	} else if(strcmp(buf, "secretout") == 0 && p != 0) {
1203 		m = (strlen(p)*3)/2 + 1;
1204 		x = smalloc(m);
1205 		t = dec64(x, m, p, strlen(p));
1206 		if(t <= 0)
1207 			error(Ebadarg);
1208 		setsecret(&s->out, x, t);
1209 	} else
1210 		error(Ebadarg);
1211 outx:
1212 	free(x);
1213 	poperror();
1214 out:
1215 	qunlock(&s->in.ctlq);
1216 	qunlock(&s->out.q);
1217 	poperror();
1218 	return n;
1219 }
1220 
1221 static void
sslinit(void)1222 sslinit(void)
1223 {
1224 	struct Encalg *e;
1225 	struct Hashalg *h;
1226 	int n;
1227 	char *cp;
1228 
1229 	n = 1;
1230 	for(e = encrypttab; e->name != nil; e++)
1231 		n += strlen(e->name) + 1;
1232 	cp = encalgs = smalloc(n);
1233 	for(e = encrypttab;;){
1234 		strcpy(cp, e->name);
1235 		cp += strlen(e->name);
1236 		e++;
1237 		if(e->name == nil)
1238 			break;
1239 		*cp++ = ' ';
1240 	}
1241 	*cp = 0;
1242 
1243 	n = 1;
1244 	for(h = hashtab; h->name != nil; h++)
1245 		n += strlen(h->name) + 1;
1246 	cp = hashalgs = smalloc(n);
1247 	for(h = hashtab;;){
1248 		strcpy(cp, h->name);
1249 		cp += strlen(h->name);
1250 		h++;
1251 		if(h->name == nil)
1252 			break;
1253 		*cp++ = ' ';
1254 	}
1255 	*cp = 0;
1256 }
1257 
1258 Dev ssldevtab = {
1259 	'D',
1260 	"ssl",
1261 
1262 	devreset,
1263 	sslinit,
1264 	devshutdown,
1265 	sslattach,
1266 	sslwalk,
1267 	sslstat,
1268 	sslopen,
1269 	devcreate,
1270 	sslclose,
1271 	sslread,
1272 	sslbread,
1273 	sslwrite,
1274 	sslbwrite,
1275 	devremove,
1276 	sslwstat,
1277 };
1278 
1279 static Block*
encryptb(Dstate * s,Block * b,int offset)1280 encryptb(Dstate *s, Block *b, int offset)
1281 {
1282 	uchar *p, *ep, *p2, *ip, *eip;
1283 	DESstate *ds;
1284 
1285 	switch(s->encryptalg){
1286 	case DESECB:
1287 		ds = s->out.state;
1288 		ep = b->rp + BLEN(b);
1289 		for(p = b->rp + offset; p < ep; p += 8)
1290 			block_cipher(ds->expanded, p, 0);
1291 		break;
1292 	case DESCBC:
1293 		ds = s->out.state;
1294 		ep = b->rp + BLEN(b);
1295 		for(p = b->rp + offset; p < ep; p += 8){
1296 			p2 = p;
1297 			ip = ds->ivec;
1298 			for(eip = ip+8; ip < eip; )
1299 				*p2++ ^= *ip++;
1300 			block_cipher(ds->expanded, p, 0);
1301 			memmove(ds->ivec, p, 8);
1302 		}
1303 		break;
1304 	case RC4:
1305 		rc4(s->out.state, b->rp + offset, BLEN(b) - offset);
1306 		break;
1307 	}
1308 	return b;
1309 }
1310 
1311 static Block*
decryptb(Dstate * s,Block * bin)1312 decryptb(Dstate *s, Block *bin)
1313 {
1314 	Block *b, **l;
1315 	uchar *p, *ep, *tp, *ip, *eip;
1316 	DESstate *ds;
1317 	uchar tmp[8];
1318 	int i;
1319 
1320 	l = &bin;
1321 	for(b = bin; b; b = b->next){
1322 		/* make sure we have a multiple of s->blocklen */
1323 		if(s->blocklen > 1){
1324 			i = BLEN(b);
1325 			if(i % s->blocklen){
1326 				*l = b = pullupblock(b, i + s->blocklen - (i%s->blocklen));
1327 				if(b == 0)
1328 					error("ssl encrypted message too short");
1329 			}
1330 		}
1331 		l = &b->next;
1332 
1333 		/* decrypt */
1334 		switch(s->encryptalg){
1335 		case DESECB:
1336 			ds = s->in.state;
1337 			ep = b->rp + BLEN(b);
1338 			for(p = b->rp; p < ep; p += 8)
1339 				block_cipher(ds->expanded, p, 1);
1340 			break;
1341 		case DESCBC:
1342 			ds = s->in.state;
1343 			ep = b->rp + BLEN(b);
1344 			for(p = b->rp; p < ep;){
1345 				memmove(tmp, p, 8);
1346 				block_cipher(ds->expanded, p, 1);
1347 				tp = tmp;
1348 				ip = ds->ivec;
1349 				for(eip = ip+8; ip < eip; ){
1350 					*p++ ^= *ip;
1351 					*ip++ = *tp++;
1352 				}
1353 			}
1354 			break;
1355 		case RC4:
1356 			rc4(s->in.state, b->rp, BLEN(b));
1357 			break;
1358 		}
1359 	}
1360 	return bin;
1361 }
1362 
1363 static Block*
digestb(Dstate * s,Block * b,int offset)1364 digestb(Dstate *s, Block *b, int offset)
1365 {
1366 	uchar *p;
1367 	DigestState ss;
1368 	uchar msgid[4];
1369 	ulong n, h;
1370 	OneWay *w;
1371 
1372 	w = &s->out;
1373 
1374 	memset(&ss, 0, sizeof(ss));
1375 	h = s->diglen + offset;
1376 	n = BLEN(b) - h;
1377 
1378 	/* hash secret + message */
1379 	(*s->hf)(w->secret, w->slen, 0, &ss);
1380 	(*s->hf)(b->rp + h, n, 0, &ss);
1381 
1382 	/* hash message id */
1383 	p = msgid;
1384 	n = w->mid;
1385 	*p++ = n>>24;
1386 	*p++ = n>>16;
1387 	*p++ = n>>8;
1388 	*p = n;
1389 	(*s->hf)(msgid, 4, b->rp + offset, &ss);
1390 
1391 	return b;
1392 }
1393 
1394 static void
checkdigestb(Dstate * s,Block * bin)1395 checkdigestb(Dstate *s, Block *bin)
1396 {
1397 	uchar *p;
1398 	DigestState ss;
1399 	uchar msgid[4];
1400 	int n, h;
1401 	OneWay *w;
1402 	uchar digest[128];
1403 	Block *b;
1404 
1405 	w = &s->in;
1406 
1407 	memset(&ss, 0, sizeof(ss));
1408 
1409 	/* hash secret */
1410 	(*s->hf)(w->secret, w->slen, 0, &ss);
1411 
1412 	/* hash message */
1413 	h = s->diglen;
1414 	for(b = bin; b; b = b->next){
1415 		n = BLEN(b) - h;
1416 		if(n < 0)
1417 			panic("checkdigestb");
1418 		(*s->hf)(b->rp + h, n, 0, &ss);
1419 		h = 0;
1420 	}
1421 
1422 	/* hash message id */
1423 	p = msgid;
1424 	n = w->mid;
1425 	*p++ = n>>24;
1426 	*p++ = n>>16;
1427 	*p++ = n>>8;
1428 	*p = n;
1429 	(*s->hf)(msgid, 4, digest, &ss);
1430 
1431 	if(memcmp(digest, bin->rp, s->diglen) != 0)
1432 		error("bad digest");
1433 }
1434 
1435 /* get channel associated with an fd */
1436 static Chan*
buftochan(char * p)1437 buftochan(char *p)
1438 {
1439 	Chan *c;
1440 	int fd;
1441 
1442 	if(p == 0)
1443 		error(Ebadarg);
1444 	fd = strtoul(p, 0, 0);
1445 	if(fd < 0)
1446 		error(Ebadarg);
1447 	c = fdtochan(fd, -1, 0, 1);	/* error check and inc ref */
1448 	if(devtab[c->type] == &ssldevtab){
1449 		cclose(c);
1450 		error("cannot ssl encrypt devssl files");
1451 	}
1452 	return c;
1453 }
1454 
1455 /* hand up a digest connection */
1456 static void
sslhangup(Dstate * s)1457 sslhangup(Dstate *s)
1458 {
1459 	Block *b;
1460 
1461 	qlock(&s->in.q);
1462 	for(b = s->processed; b; b = s->processed){
1463 		s->processed = b->next;
1464 		freeb(b);
1465 	}
1466 	if(s->unprocessed){
1467 		freeb(s->unprocessed);
1468 		s->unprocessed = 0;
1469 	}
1470 	s->state = Sincomplete;
1471 	qunlock(&s->in.q);
1472 }
1473 
1474 static Dstate*
dsclone(Chan * ch)1475 dsclone(Chan *ch)
1476 {
1477 	int i;
1478 	Dstate *ret;
1479 
1480 	if(waserror()) {
1481 		unlock(&dslock);
1482 		nexterror();
1483 	}
1484 	lock(&dslock);
1485 	ret = nil;
1486 	for(i=0; i<Maxdstate; i++){
1487 		if(dstate[i] == nil){
1488 			dsnew(ch, &dstate[i]);
1489 			ret = dstate[i];
1490 			break;
1491 		}
1492 	}
1493 	unlock(&dslock);
1494 	poperror();
1495 	return ret;
1496 }
1497 
1498 static void
dsnew(Chan * ch,Dstate ** pp)1499 dsnew(Chan *ch, Dstate **pp)
1500 {
1501 	Dstate *s;
1502 	int t;
1503 
1504 	*pp = s = malloc(sizeof(*s));
1505 	if(!s)
1506 		error(Enomem);
1507 	if(pp - dstate >= dshiwat)
1508 		dshiwat++;
1509 	memset(s, 0, sizeof(*s));
1510 	s->state = Sincomplete;
1511 	s->ref = 1;
1512 	kstrdup(&s->user, up->user);
1513 	s->perm = 0660;
1514 	t = TYPE(ch->qid);
1515 	if(t == Qclonus)
1516 		t = Qctl;
1517 	ch->qid.path = QID(pp - dstate, t);
1518 	ch->qid.vers = 0;
1519 }
1520