xref: /plan9-contrib/sys/src/cmd/ip/httpd/websocket.c (revision 5367413097a46c6f6fc69aad7c8df91da69cc758)
1 /* Copyright © 2013-2014 David Hoskin <root@davidrhoskin.com> */
2 
3 #include <u.h>
4 #include <libc.h>
5 #include <thread.h>
6 #include <bio.h>
7 #include <mp.h>
8 #include <libsec.h>
9 #include <auth.h>
10 #include "httpd.h"
11 #include "httpsrv.h"
12 
/*
 * Tunable sizes, followed by the websocket frame opcodes
 * this server knows about.
 */
enum
{
	/* misc parameters */
	MAXHDRS	= 64,		/* most header lines parseheaders will split */
	STACKSZ	= 32*1024,	/* stack size given to each helper proc */
	BUFSZ	= 16*1024,	/* largest pipe read and largest frame payload accepted */
	CHANBUF	= 8,		/* elements buffered per channel */

	/* packet types: standard non-control frames */
	Cont	= 0,
	Text	= 1,
	Binary	= 2,
	/* reserved non-control frames follow */

	/* packet types: standard control frames */
	Close	= 8,
	Ping	= 9,
	Pong	= 10,
	/* reserved control frames follow */
};
33 
34 typedef struct Procio Procio;
35 struct Procio
36 {
37 	Channel *c;
38 	Biobuf *b;
39 	int fd;
40 	char **argv;
41 };
42 
43 typedef struct Buf Buf;
44 struct Buf
45 {
46 	uchar *buf;
47 	long n;
48 };
49 
50 typedef struct Wspkt Wspkt;
51 struct Wspkt
52 {
53 	Buf;
54 	int type;
55 };
56 
/*
 * Stack size for the main thread.
 * XXX The default was not enough, but this is just a guess;
 * it needs to be at least 2*sizeof Biobuf.
 */
int mainstacksize = 128*1024;

/*
 * Fixed string appended to the client's Sec-WebSocket-Key before
 * hashing (see wscheckhdr).
 */
const char wsnoncekey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/* The one Sec-WebSocket-Version value testwsversion accepts. */
const char wsversion[] = "13";
62 
63 HSPairs *
parseheaders(char * headers)64 parseheaders(char *headers)
65 {
66 	char *hdrlines[MAXHDRS], *kv[2];
67 	HSPairs *h, *t, *tmp;
68 	int nhdr;
69 	int i;
70 
71 	h = t = nil;
72 
73 	nhdr = getfields(headers, hdrlines, MAXHDRS, 1, "\r\n");
74 
75 	/*
76 	* XXX I think leading whitespaces signifies a continuation line.
77 	* Skip the first line, or else getfields(..., " ") picks up the GET.
78 	*/
79 	for(i = 1; i < nhdr; ++i){
80 
81 		if(hdrlines[i] == nil)
82 			continue;
83 
84 		getfields(hdrlines[i], kv, 2, 1, ": \t");
85 
86 		tmp = malloc(sizeof(HSPairs));
87 		if(tmp == nil)
88 			goto cleanup;
89 
90 		tmp->s = kv[0];
91 		tmp->t = kv[1];
92 
93 		if(h == nil){
94 			h = t = tmp;
95 		}else{
96 			t->next = tmp;
97 			t = tmp;
98 		}
99 		tmp->next = nil;
100 	}
101 
102 	return h;
103 
104 cleanup:
105 	for(t = h->next; h != nil; h = t, t = h->next)
106 		free(h);
107 	return nil;
108 }
109 
110 char *
getheader(HSPairs * h,const char * k)111 getheader(HSPairs *h, const char *k)
112 {
113 	for(; h != nil; h = h->next)
114 		if(cistrcmp(h->s, k) == 0)
115 			return h->t;
116 	return nil;
117 }
118 
119 int
failhdr(HConnect * c,int code,const char * status,const char * message)120 failhdr(HConnect *c, int code, const char *status, const char *message)
121 {
122 	Hio *o;
123 
124 	o = &c->hout;
125 	hprint(o, "%s %d %s\r\n", hversion, code, status);
126 	hprint(o, "Server: Plan9\r\n");
127 	hprint(o, "Date: %D\r\n", time(nil));
128 	hprint(o, "Content-type: text/html\r\n");
129 	hprint(o, "\r\n");
130 	hprint(o, "<html><head><title>%d %s</title></head>\n", code, status);
131 	hprint(o, "<body><h1>%d %s</h1>\n", code, status);
132 	hprint(o, "<p>Failed to establish websocket connection: %s\n", message);
133 	hprint(o, "</body></html>\n");
134 	hflush(o);
135 	return 0;
136 }
137 
138 void
okhdr(HConnect * c,const char * wshashedkey,const char * proto)139 okhdr(HConnect *c, const char *wshashedkey, const char *proto)
140 {
141 	Hio *o;
142 
143 	o = &c->hout;
144 	hprint(o, "%s 101 Switching Protocols\r\n", hversion);
145 	hprint(o, "Upgrade: websocket\r\n");
146 	hprint(o, "Connection: upgrade\r\n");
147 	hprint(o, "Sec-WebSocket-Accept: %s\r\n", wshashedkey);
148 	if(proto != nil)
149 		hprint(o, "Sec-WebSocket-Protocol: %s\r\n", proto);
150 	/* we don't handle extensions */
151 	hprint(o, "\r\n");
152 	hflush(o);
153 }
154 
155 int
testwsversion(const char * vs)156 testwsversion(const char *vs)
157 {
158 	int i, n;
159 	char *v[16];
160 
161 	n = getfields(vs, v, 16, 1, "\t ,");
162 	for(i = 0; i < n; ++i)
163 		if(strcmp(v[i], wsversion) == 0)
164 			return 1;
165 	return 0;
166 }
167 
168 uvlong
getbe(uchar * t,int w)169 getbe(uchar *t, int w)
170 {
171 	uint i;
172 	uvlong r;
173 
174 	r = 0;
175 	for(i = 0; i < w; i++)
176 		r = r<<8 | t[i];
177 	return r;
178 }
179 
180 int
Bgetbe(Biobuf * b,uvlong * u,int sz)181 Bgetbe(Biobuf *b, uvlong *u, int sz)
182 {
183 	uchar buf[8];
184 
185 	if(Bread(b, buf, sz) != sz)
186 		return -1;
187 
188 	*u = getbe(buf, sz);
189 	return 1;
190 }
191 
192 int
sendpkt(Biobuf * b,Wspkt * pkt)193 sendpkt(Biobuf *b, Wspkt *pkt)
194 {
195 	uchar hdr[2+8];
196 	long hdrsz, len;
197 
198 	hdr[0] = 0x80 | pkt->type;
199 	len = pkt->n;
200 
201 	/* XXX should use putbe(). */
202 	if(len >= (1 << 16)){
203 		hdrsz = 2 + 8;
204 		hdr[1] = 127;
205 		hdr[2] = hdr[3] = hdr[4] = hdr[5] = 0;
206 		hdr[6] = len >> 24;
207 		hdr[7] = len >> 16;
208 		hdr[8] = len >> 8;
209 		hdr[9] = len >> 0;
210 	}else if(len >= 126){
211 		hdrsz = 2 + 2;
212 		hdr[1] = 126;
213 		hdr[2] = len >> 8;
214 		hdr[3]= len >> 0;
215 	}else{
216 		hdrsz = 2;
217 		hdr[1] = len;
218 	}
219 
220 	if(Bwrite(b, hdr, hdrsz) != hdrsz)
221 		return -1;
222 	if(Bwrite(b, pkt->buf, len) != len)
223 		return -1;
224 	if(Bflush(b) < 0)
225 		return -1;
226 
227 	return 0;
228 }
229 
230 int
recvpkt(Wspkt * pkt,Biobuf * b)231 recvpkt(Wspkt *pkt, Biobuf *b)
232 {
233 	long x;
234 	int masked;
235 	uchar mask[4];
236 
237 	pkt->type = Bgetc(b);
238 	if(pkt->type < 0){
239 		return -1;
240 	}
241 	/* Strip FIN/continuation bit. */
242 	pkt->type &= 0x0F;
243 
244 	pkt->n = Bgetc(b);
245 	if(pkt->n < 0){
246 		return -1;
247 	}
248 	masked = pkt->n & 0x80;
249 	pkt->n &= 0x7F;
250 
251 	if(pkt->n >= 127){
252 		if(Bgetbe(b, (uvlong *)&pkt->n, 8) != 1)
253 			return -1;
254 	}else if(pkt->n == 126){
255 		if(Bgetbe(b, (uvlong *)&pkt->n, 2) != 1)
256 			return -1;
257 	}
258 
259 	if(masked){
260 		if(Bread(b, mask, 4) != 4)
261 			return -1;
262 	}
263 	/* allocate appropriate buffer */
264 	if(pkt->n > BUFSZ){
265 		/*
266 		* buffer is unacceptably large!
267 		* XXX this should close the connection with a specific error code.
268 		* See websocket spec.
269 		*/
270 		return -1;
271 	}else if(pkt->n == 0){
272 		pkt->buf = nil;
273 		return 1;
274 	}else{
275 		pkt->buf = malloc(pkt->n);
276 		if(pkt->buf == nil)
277 			return -1;
278 
279 		if(Bread(b, pkt->buf, pkt->n) != pkt->n){
280 			free(pkt->buf);
281 			return -1;
282 		}
283 
284 		if(masked)
285 			for(x = 0; x < pkt->n; ++x)
286 				pkt->buf[x] ^= mask[x % 4];
287 
288 		return 1;
289 	}
290 }
291 
292 void
wsreadproc(void * arg)293 wsreadproc(void *arg)
294 {
295 	Procio *pio;
296 	Channel *c;
297 	Biobuf *b;
298 	Wspkt pkt;
299 
300 	pio = (Procio *)arg;
301 	c = pio->c;
302 	b = pio->b;
303 
304 	for(;;){
305 		if(recvpkt(&pkt, b) < 0)
306 			break;
307 		if(send(c, &pkt) < 0){
308 			free(pkt.buf);
309 			break;
310 		}
311 	}
312 
313 	chanclose(c);
314 	threadexits(nil);
315 }
316 
317 void
wswriteproc(void * arg)318 wswriteproc(void *arg)
319 {
320 	Procio *pio;
321 	Channel *c;
322 	Biobuf *b;
323 	Wspkt pkt;
324 
325 	pio = (Procio *)arg;
326 	c = pio->c;
327 	b = pio->b;
328 
329 	for(;;){
330 		if(recv(c, &pkt) < 0)
331 			break;
332 		if(sendpkt(b, &pkt) < 0){
333 			free(pkt.buf);
334 			break;
335 		}
336 		free(pkt.buf);
337 	}
338 
339 	chanclose(c);
340 	threadexits(nil);
341 }
342 
343 void
pipereadproc(void * arg)344 pipereadproc(void *arg)
345 {
346 	Procio *pio;
347 	Channel *c;
348 	int fd;
349 	Buf b;
350 
351 	pio = (Procio *)arg;
352 	c = pio->c;
353 	fd = pio->fd;
354 
355 	for(;;){
356 		b.buf = malloc(BUFSZ);
357 		if(b.buf == nil)
358 			break;
359 		b.n = read(fd, b.buf, BUFSZ);
360 		if(b.n < 1)
361 			break;
362 		if(send(c, &b) < 0)
363 			break;
364 	}
365 
366 	free(b.buf);
367 	chanclose(c);
368 	threadexits(nil);
369 }
370 
371 void
pipewriteproc(void * arg)372 pipewriteproc(void *arg)
373 {
374 	Procio *pio;
375 	Channel *c;
376 	int fd;
377 	Buf b;
378 
379 	pio = (Procio *)arg;
380 	c = pio->c;
381 	fd = pio->fd;
382 
383 	for(;;){
384 		if(recv(c, &b) != 1)
385 			break;
386 		if(write(fd, b.buf, b.n) != b.n){
387 			free(b.buf);
388 			break;
389 		}
390 		free(b.buf);
391 	}
392 
393 	chanclose(c);
394 	threadexits(nil);
395 }
396 
397 void
mountproc(void * arg)398 mountproc(void *arg)
399 {
400 	Procio *pio;
401 	int fd, i;
402 	char **argv;
403 
404 	pio = (Procio *)arg;
405 	fd = pio->fd;
406 	argv = pio->argv;
407 
408 	for(i = 0; i < 20; ++i){
409 		if(i != fd)
410 			close(i);
411 	}
412 
413 	newns("none", nil);
414 
415 	if(mount(fd, -1, "/dev/", MBEFORE, "") == -1)
416 		sysfatal("mount failed: %r");
417 
418 	procexec(nil, argv[0], argv);
419 }
420 
421 void
echoproc(void * arg)422 echoproc(void *arg)
423 {
424 	Procio *pio;
425 	int fd;
426 	char buf[1024];
427 	int n;
428 
429 	pio = (Procio *)arg;
430 	fd = pio->fd;
431 
432 	for(;;){
433 		n = read(fd, buf, 1024);
434 		if(n > 0)
435 			write(fd, buf, n);
436 	}
437 }
438 
439 int
wscheckhdr(HConnect * c)440 wscheckhdr(HConnect *c)
441 {
442 	HSPairs *hdrs;
443 	char *s, *wsclientkey;
444 	char *rawproto;
445 	char *proto;
446 	char wscatkey[64];
447 	uchar wshashedkey[SHA1dlen];
448 	char wsencoded[32];
449 
450 	if(strcmp(c->req.meth, "GET") != 0)
451 		return hunallowed(c, "GET");
452 
453 	//return failhdr(c, 403, "Forbidden", "my hair is on fire");
454 
455 	hdrs = parseheaders((char *)c->header);
456 
457 	s = getheader(hdrs, "upgrade");
458 	if(s == nil || !cistrstr(s, "websocket"))
459 		return failhdr(c, 400, "Bad Request", "no <code>upgrade: websocket</code> header.");
460 	s = getheader(hdrs, "connection");
461 	if(s == nil || !cistrstr(s, "upgrade"))
462 		return failhdr(c, 400, "Bad Request", "no <code>connection: upgrade</code> header.");
463 	wsclientkey = getheader(hdrs, "sec-websocket-key");
464 	if(wsclientkey == nil || strlen(wsclientkey) != 24)
465 		return failhdr(c, 400, "Bad Request", "invalid websocket nonce key.");
466 	s = getheader(hdrs, "sec-websocket-version");
467 	if(s == nil || !testwsversion(s))
468 		return failhdr(c, 426, "Upgrade Required", "could not match websocket version.");
469 	/* XXX should get resource name */
470 	rawproto = getheader(hdrs, "sec-websocket-protocol");
471 	proto = rawproto;
472 	/* XXX should test if proto is acceptable" */
473 	/* should get sec-websocket-extensions */
474 
475 	/* OK, we seem to have a valid Websocket request. */
476 
477 	/* Hash websocket key. */
478 	strcpy(wscatkey, wsclientkey);
479 	strcat(wscatkey, wsnoncekey);
480 	sha1((uchar *)wscatkey, strlen(wscatkey), wshashedkey, nil);
481 	enc64(wsencoded, 32, wshashedkey, SHA1dlen);
482 
483 	okhdr(c, wsencoded, proto);
484 	hflush(&c->hout);
485 
486 	/* We should now have an open Websocket connection. */
487 
488 	return 1;
489 }
490 
491 int
dowebsock(void)492 dowebsock(void)
493 {
494 	Biobuf bin, bout;
495 	Wspkt pkt;
496 	Buf buf;
497 	int p[2];
498 	Alt a[] = {
499 	/*	c	v	op */
500 		{nil, &pkt, CHANRCV},
501 		{nil, &buf, CHANRCV},
502 		{nil, nil, CHANEND},
503 	};
504 	Procio fromws, tows, frompipe, topipe;
505 	Procio mountp, echop;
506 	char *argv[] = {"/bin/rc", "-c", "ramfs && exec acme", nil};
507 
508 	fromws.c = chancreate(sizeof(Wspkt), CHANBUF);
509 	tows.c = chancreate(sizeof(Wspkt), CHANBUF);
510 	frompipe.c = chancreate(sizeof(Buf), CHANBUF);
511 	topipe.c = chancreate(sizeof(Buf), CHANBUF);
512 
513 	a[0].c = fromws.c;
514 	a[1].c = frompipe.c;
515 
516 	Binit(&bin, 0, OREAD);
517 	Binit(&bout, 1, OWRITE);
518 	fromws.b = &bin;
519 	tows.b = &bout;
520 
521 	pipe(p);
522 	//fd = create("/srv/weebtest", OWRITE, 0666);
523 	//fprint(fd, "%d", p[0]);
524 	//close(fd);
525 	//close(p[0]);
526 
527 	frompipe.fd = p[1];
528 	topipe.fd = p[1];
529 
530 	mountp.fd = echop.fd = p[0];
531 	mountp.argv = argv;
532 
533 	proccreate(wsreadproc, &fromws, STACKSZ);
534 	proccreate(wswriteproc, &tows, STACKSZ);
535 	proccreate(pipereadproc, &frompipe, STACKSZ);
536 	proccreate(pipewriteproc, &topipe, STACKSZ);
537 
538 	//proccreate(echoproc, &echop, STACKSZ);
539 	procrfork(mountproc, &mountp, STACKSZ, RFNAMEG|RFFDG);
540 
541 	for(;;){
542 		int i;
543 
544 		i = alt(a);
545 		if(chanclosing(a[i].c) >= 0){
546 			a[i].op = CHANNOP;
547 			pkt.type = Close;
548 			pkt.buf = nil;
549 			pkt.n = 0;
550 			send(tows.c, &pkt);
551 			goto done;
552 		}
553 
554 		switch(i){
555 		case 0: /* from socket */
556 			if(pkt.type == Ping){
557 				pkt.type = Pong;
558 				send(tows.c, &pkt);
559 			}else if(pkt.type == Close){
560 				send(tows.c, &pkt);
561 				goto done;
562 			}else{
563 				send(topipe.c, &pkt.Buf);
564 			}
565 			break;
566 		case 1: /* from pipe */
567 			pkt.type = Binary;
568 			pkt.Buf = buf;
569 			send(tows.c, &pkt);
570 			break;
571 		default:
572 			sysfatal("can't happen");
573 		}
574 	}
575 done:
576 	return 1;
577 }
578 
579 void
threadmain(int argc,char ** argv)580 threadmain(int argc, char **argv)
581 {
582 	HConnect *c;
583 
584 	c = init(argc, argv);
585 	if(hparseheaders(c, HSTIMEOUT) >= 0)
586 		if(wscheckhdr(c) >= 0)
587 			dowebsock();
588 
589 	threadexitsall(nil);
590 }
591