xref: /plan9-contrib/sys/src/libsec/port/tlshand.c (revision e7c1227a38335dd6dd7ddd89d0dacb1a4d6356b8)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7 
// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encrypt algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16 
17 enum {
18 	TLSFinishedLen = 12,
19 	SSL3FinishedLen = MD5dlen+SHA1dlen,
20 	MaxKeyData = 136,	// amount of secret we may need
21 	MaxChunk = 1<<14,
22 	RandomSize = 32,
23 	SidSize = 32,
24 	MasterSecretSize = 48,
25 	AQueue = 0,
26 	AFlush = 1,
27 };
28 
// forward declaration; TlsConnection points at a TlsSec before it is defined
typedef struct TlsSec TlsSec;
30 
31 typedef struct Bytes{
32 	int len;
33 	uchar data[1];  // [len]
34 } Bytes;
35 
// Counted array of ints.  data is declared with one element but is
// used as an array of len ints.
typedef struct Ints{
	int len;
	int data[1];  // [len]
} Ints;
40 
// One supported cipher suite; see the cipherAlgs table.
typedef struct Algs{
	char *enc;	// name of encryption algorithm
	char *digest;	// name of digest algorithm
	int nsecret;	// amount of secret data to init keys
	int tlsid;	// cipher suite number (a TLS_* constant)
	int ok;	// not set by the static table; presumably filled in by initCiphers - TODO confirm
} Algs;
48 
49 typedef struct Finished{
50 	uchar verify[SSL3FinishedLen];
51 	int n;
52 } Finished;
53 
54 typedef struct HandHash{
55 	MD5state	md5;
56 	SHAstate	sha1;
57 	SHA2_256state	sha2_256;
58 } HandHash;
59 
60 typedef struct TlsConnection{
61 	TlsSec *sec;	// security management goo
62 	int hand, ctl;	// record layer file descriptors
63 	int erred;		// set when tlsError called
64 	int (*trace)(char*fmt, ...); // for debugging
65 	int version;	// protocol we are speaking
66 	int verset;		// version has been set
67 	int ver2hi;		// server got a version 2 hello
68 	int isClient;	// is this the client or server?
69 	Bytes *sid;		// SessionID
70 	Bytes *cert;	// only last - no chain
71 
72 	Lock statelk;
73 	int state;		// must be set using setstate
74 
75 	// input buffer for handshake messages
76 	uchar buf[MaxChunk+8*1024];
77 	uchar *rp, *ep;
78 
79 	uchar crandom[RandomSize];	// client random
80 	uchar srandom[RandomSize];	// server random
81 	int clientVersion;	// version in ClientHello
82 	char *digest;	// name of digest algorithm to use
83 	char *enc;		// name of encryption algorithm to use
84 	int nsecret;	// amount of secret data to init keys
85 
86 	// for finished messages
87 	HandHash	hs;	// handshake hash
88 	Finished	finished;
89 } TlsConnection;
90 
91 typedef struct Msg{
92 	int tag;
93 	union {
94 		struct {
95 			int version;
96 			uchar 	random[RandomSize];
97 			Bytes*	sid;
98 			Ints*	ciphers;
99 			Bytes*	compressors;
100 			Ints*	sigAlgs;
101 		} clientHello;
102 		struct {
103 			int version;
104 			uchar 	random[RandomSize];
105 			Bytes*	sid;
106 			int cipher;
107 			int compressor;
108 		} serverHello;
109 		struct {
110 			int ncert;
111 			Bytes **certs;
112 		} certificate;
113 		struct {
114 			Bytes *types;
115 			int nca;
116 			Bytes **cas;
117 		} certificateRequest;
118 		struct {
119 			Bytes *key;
120 		} clientKeyExchange;
121 		Finished finished;
122 	} u;
123 } Msg;
124 
125 typedef struct TlsSec{
126 	char *server;	// name of remote; nil for server
127 	int ok;	// <0 killed; == 0 in progress; >0 reusable
128 	RSApub *rsapub;
129 	AuthRpc *rpc;	// factotum for rsa private key
130 	uchar sec[MasterSecretSize];	// master secret
131 	uchar crandom[RandomSize];	// client random
132 	uchar srandom[RandomSize];	// server random
133 	int clientVers;		// version in ClientHello
134 	int vers;			// final version
135 	// byte generation and handshake checksum
136 	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
137 	void (*setFinished)(TlsSec*, HandHash, uchar*, int);
138 	int nfin;
139 } TlsSec;
140 
141 
// Protocol versions, as the 16-bit value carried in the hello messages.
enum {
	SSL3Version  = 0x0300,
	TLS10Version = 0x0301,
	TLS11Version = 0x0302,
	TLS12Version = 0x0303,
	ProtocolVersion = TLS12Version,	// maximum version we speak
	MinProtoVersion = 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};
151 
// handshake type: the first byte of every handshake message
enum {
	HHelloRequest,	// 0
	HClientHello,	// 1
	HServerHello,	// 2
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,	// 12
	HCertificateRequest,	// 13
	HServerHelloDone,	// 14
	HCertificateVerify,	// 15
	HClientKeyExchange,	// 16
	HFinished = 20,
	HMax
};
167 
// alerts: the codes passed as the err argument of tlsError
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};
196 
// cipher suites: 16-bit identifiers as they appear in the hello messages.
// Only the suites listed in cipherAlgs are actually offered or accepted.
enum {
	TLS_NULL_WITH_NULL_NULL	 		= 0x0000,
	TLS_RSA_WITH_NULL_MD5 			= 0x0001,
	TLS_RSA_WITH_NULL_SHA 			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA 		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA 		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001B,

	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003A,
	CipherMax	// one more than the last suite above
};
242 
// compression methods: only the null method is defined here
enum {
	CompressionNull = 0,
	CompressionMax
};
248 
// extensions: type codes used in the ClientHello extension list
enum {
	ExtSigalgs = 0xd,	// carries the sigAlgs list (see msgSend)
};
253 
// signature algorithms: 16-bit codes sent in the ExtSigalgs extension
enum {
	RSA_PKCS1_SHA1   = 0x0201,
	RSA_PKCS1_SHA256 = 0x0401,
	RSA_PKCS1_SHA384 = 0x0501,
	RSA_PKCS1_SHA512 = 0x0601
};
261 
262 static Algs cipherAlgs[] = {
263 	{"rc4_128", "md5", 2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
264 	{"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
265 	{"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
266 	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
267 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}
268 };
269 
270 static uchar compressors[] = {
271 	CompressionNull,
272 };
273 
274 static int sigAlgs[] = {
275 	RSA_PKCS1_SHA256,
276 	RSA_PKCS1_SHA1,
277 };
278 
279 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
280 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
281 
282 static void	msgClear(Msg *m);
283 static char* msgPrint(char *buf, int n, Msg *m);
284 static int	msgRecv(TlsConnection *c, Msg *m);
285 static int	msgSend(TlsConnection *c, Msg *m, int act);
286 static void	tlsError(TlsConnection *c, int err, char *msg, ...);
287 #pragma	varargck argpos	tlsError 3
288 static int setVersion(TlsConnection *c, int version);
289 static int finishedMatch(TlsConnection *c, Finished *f);
290 static void tlsConnectionFree(TlsConnection *c);
291 
292 static int setAlgs(TlsConnection *c, int a);
293 static int okCipher(Ints *cv);
294 static int okCompression(Bytes *cv);
295 static int initCiphers(void);
296 static Ints* makeciphers(void);
297 
298 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
299 static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
300 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
301 static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
302 static int	tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient);
303 static void	tlsSecOk(TlsSec *sec);
304 static void	tlsSecKill(TlsSec *sec);
305 static void	tlsSecClose(TlsSec *sec);
306 static void	setMasterSecret(TlsSec *sec, Bytes *pm);
307 static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
308 static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
309 static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
310 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
311 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
312 static void	tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
313 static void	tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
314 static void	sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
315 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
316 			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
317 static int setVers(TlsSec *sec, int version);
318 
319 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
320 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
321 static void factotum_rsa_close(AuthRpc*rpc);
322 
323 static void* emalloc(int);
324 static void* erealloc(void*, int);
325 static void put32(uchar *p, u32int);
326 static void put24(uchar *p, int);
327 static void put16(uchar *p, int);
328 static u32int get32(uchar *p);
329 static int get24(uchar *p);
330 static int get16(uchar *p);
331 static Bytes* newbytes(int len);
332 static Bytes* makebytes(uchar* buf, int len);
333 static void freebytes(Bytes* b);
334 static Ints* newints(int len);
335 static Ints* makeints(int* buf, int len);
336 static void freeints(Ints* b);
337 
338 //================= client/server ========================
339 
340 //	push TLS onto fd, returning new (application) file descriptor
341 //		or -1 if error.
342 int
tlsServer(int fd,TLSconn * conn)343 tlsServer(int fd, TLSconn *conn)
344 {
345 	char buf[8];
346 	char dname[64];
347 	int n, data, ctl, hand;
348 	TlsConnection *tls;
349 
350 	if(conn == nil)
351 		return -1;
352 	ctl = open("#a/tls/clone", ORDWR);
353 	if(ctl < 0)
354 		return -1;
355 	n = read(ctl, buf, sizeof(buf)-1);
356 	if(n < 0){
357 		close(ctl);
358 		return -1;
359 	}
360 	buf[n] = 0;
361 	snprint(conn->dir, sizeof conn->dir, "#a/tls/%s", buf);
362 	snprint(dname, sizeof dname, "#a/tls/%s/hand", buf);
363 	hand = open(dname, ORDWR);
364 	if(hand < 0){
365 		close(ctl);
366 		return -1;
367 	}
368 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
369 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
370 	snprint(dname, sizeof dname, "#a/tls/%s/data", buf);
371 	data = open(dname, ORDWR);
372 	close(fd);
373 	close(hand);
374 	close(ctl);
375 	if(data < 0){
376 		return -1;
377 	}
378 	if(tls == nil){
379 		close(data);
380 		return -1;
381 	}
382 	if(conn->cert)
383 		free(conn->cert);
384 	conn->cert = 0;  // client certificates are not yet implemented
385 	conn->certlen = 0;
386 	conn->sessionIDlen = tls->sid->len;
387 	conn->sessionID = emalloc(conn->sessionIDlen);
388 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
389 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
390 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
391 	tlsConnectionFree(tls);
392 	return data;
393 }
394 
395 //	push TLS onto fd, returning new (application) file descriptor
396 //		or -1 if error.
397 int
tlsClient(int fd,TLSconn * conn)398 tlsClient(int fd, TLSconn *conn)
399 {
400 	char buf[8];
401 	char dname[64];
402 	int n, data, ctl, hand;
403 	TlsConnection *tls;
404 
405 	if(!conn)
406 		return -1;
407 	ctl = open("#a/tls/clone", ORDWR);
408 	if(ctl < 0)
409 		return -1;
410 	n = read(ctl, buf, sizeof(buf)-1);
411 	if(n < 0){
412 		close(ctl);
413 		return -1;
414 	}
415 	buf[n] = 0;
416 	snprint(conn->dir, sizeof conn->dir, "#a/tls/%s", buf);
417 	snprint(dname, sizeof dname, "#a/tls/%s/hand", buf);
418 	hand = open(dname, ORDWR);
419 	if(hand < 0){
420 		close(ctl);
421 		return -1;
422 	}
423 	snprint(dname, sizeof dname, "#a/tls/%s/data", buf);
424 	data = open(dname, ORDWR);
425 	if(data < 0) {
426 		close(hand);
427 		close(ctl);
428 		return -1;
429 	}
430 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
431 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
432 	close(fd);
433 	close(hand);
434 	close(ctl);
435 	if(tls == nil){
436 		close(data);
437 		return -1;
438 	}
439 	conn->certlen = tls->cert->len;
440 	conn->cert = emalloc(conn->certlen);
441 	memcpy(conn->cert, tls->cert->data, conn->certlen);
442 	conn->sessionIDlen = tls->sid->len;
443 	conn->sessionID = emalloc(conn->sessionIDlen);
444 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
445 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
446 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
447 	tlsConnectionFree(tls);
448 	return data;
449 }
450 
451 static int
countchain(PEMChain * p)452 countchain(PEMChain *p)
453 {
454 	int i = 0;
455 
456 	while (p) {
457 		i++;
458 		p = p->next;
459 	}
460 	return i;
461 }
462 
463 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...),PEMChain * chp)464 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
465 {
466 	TlsConnection *c;
467 	Msg m;
468 	Bytes *csid;
469 	uchar sid[SidSize], kd[MaxKeyData];
470 	char *secrets;
471 	int cipher, compressor, nsid, rv, numcerts, i;
472 
473 	if(trace)
474 		trace("tlsServer2\n");
475 	if(!initCiphers())
476 		return nil;
477 	c = emalloc(sizeof(TlsConnection));
478 	c->ctl = ctl;
479 	c->hand = hand;
480 	c->trace = trace;
481 	c->version = ProtocolVersion;
482 
483 	memset(&m, 0, sizeof(m));
484 	if(!msgRecv(c, &m)){
485 		if(trace)
486 			trace("initial msgRecv failed\n");
487 		goto Err;
488 	}
489 	if(m.tag != HClientHello) {
490 		tlsError(c, EUnexpectedMessage, "expected a client hello");
491 		goto Err;
492 	}
493 	c->clientVersion = m.u.clientHello.version;
494 	if(trace)
495 		trace("ClientHello version %x\n", c->clientVersion);
496 	if(setVersion(c, m.u.clientHello.version) < 0) {
497 		tlsError(c, EIllegalParameter, "incompatible version");
498 		goto Err;
499 	}
500 
501 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
502 	cipher = okCipher(m.u.clientHello.ciphers);
503 	if(cipher < 0) {
504 		// reply with EInsufficientSecurity if we know that's the case
505 		if(cipher == -2)
506 			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
507 		else
508 			tlsError(c, EHandshakeFailure, "no matching cipher suite");
509 		goto Err;
510 	}
511 	if(!setAlgs(c, cipher)){
512 		tlsError(c, EHandshakeFailure, "no matching cipher suite");
513 		goto Err;
514 	}
515 	compressor = okCompression(m.u.clientHello.compressors);
516 	if(compressor < 0) {
517 		tlsError(c, EHandshakeFailure, "no matching compressor");
518 		goto Err;
519 	}
520 
521 	csid = m.u.clientHello.sid;
522 	if(trace)
523 		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
524 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
525 	if(c->sec == nil){
526 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
527 		goto Err;
528 	}
529 	c->sec->rpc = factotum_rsa_open(cert, ncert);
530 	if(c->sec->rpc == nil){
531 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
532 		goto Err;
533 	}
534 	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
535 	msgClear(&m);
536 
537 	m.tag = HServerHello;
538 	m.u.serverHello.version = c->version;
539 	memmove(m.u.serverHello.random, c->srandom, RandomSize);
540 	m.u.serverHello.cipher = cipher;
541 	m.u.serverHello.compressor = compressor;
542 	c->sid = makebytes(sid, nsid);
543 	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
544 	if(!msgSend(c, &m, AQueue))
545 		goto Err;
546 	msgClear(&m);
547 
548 	m.tag = HCertificate;
549 	numcerts = countchain(chp);
550 	m.u.certificate.ncert = 1 + numcerts;
551 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
552 	m.u.certificate.certs[0] = makebytes(cert, ncert);
553 	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
554 		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
555 	if(!msgSend(c, &m, AQueue))
556 		goto Err;
557 	msgClear(&m);
558 
559 	m.tag = HServerHelloDone;
560 	if(!msgSend(c, &m, AFlush))
561 		goto Err;
562 	msgClear(&m);
563 
564 	if(!msgRecv(c, &m))
565 		goto Err;
566 	if(m.tag != HClientKeyExchange) {
567 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
568 		goto Err;
569 	}
570 	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
571 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
572 		goto Err;
573 	}
574 	if(trace)
575 		trace("tls secrets\n");
576 	secrets = (char*)emalloc(2*c->nsecret);
577 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
578 	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
579 	memset(secrets, 0, 2*c->nsecret);
580 	free(secrets);
581 	memset(kd, 0, c->nsecret);
582 	if(rv < 0){
583 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
584 		goto Err;
585 	}
586 	msgClear(&m);
587 
588 	/* no CertificateVerify; skip to Finished */
589 	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
590 		tlsError(c, EInternalError, "can't set finished: %r");
591 		goto Err;
592 	}
593 	if(!msgRecv(c, &m))
594 		goto Err;
595 	if(m.tag != HFinished) {
596 		tlsError(c, EUnexpectedMessage, "expected a finished");
597 		goto Err;
598 	}
599 	if(!finishedMatch(c, &m.u.finished)) {
600 		tlsError(c, EHandshakeFailure, "finished verification failed");
601 		goto Err;
602 	}
603 	msgClear(&m);
604 
605 	/* change cipher spec */
606 	if(fprint(c->ctl, "changecipher") < 0){
607 		tlsError(c, EInternalError, "can't enable cipher: %r");
608 		goto Err;
609 	}
610 
611 	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
612 		tlsError(c, EInternalError, "can't set finished: %r");
613 		goto Err;
614 	}
615 	m.tag = HFinished;
616 	m.u.finished = c->finished;
617 	if(!msgSend(c, &m, AFlush))
618 		goto Err;
619 	msgClear(&m);
620 	if(trace)
621 		trace("tls finished\n");
622 
623 	if(fprint(c->ctl, "opened") < 0)
624 		goto Err;
625 	tlsSecOk(c->sec);
626 	return c;
627 
628 Err:
629 	msgClear(&m);
630 	tlsConnectionFree(c);
631 	return 0;
632 }
633 
634 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))635 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
636 {
637 	TlsConnection *c;
638 	Msg m;
639 	uchar kd[MaxKeyData], *epm;
640 	char *secrets;
641 	int creq, nepm, rv;
642 
643 	if(!initCiphers())
644 		return nil;
645 	epm = nil;
646 	c = emalloc(sizeof(TlsConnection));
647 	c->version = ProtocolVersion;
648 	c->ctl = ctl;
649 	c->hand = hand;
650 	c->trace = trace;
651 	c->isClient = 1;
652 	c->clientVersion = c->version;
653 
654 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
655 	if(c->sec == nil)
656 		goto Err;
657 
658 	/* client hello */
659 	memset(&m, 0, sizeof(m));
660 	m.tag = HClientHello;
661 	m.u.clientHello.version = c->clientVersion;
662 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
663 	m.u.clientHello.sid = makebytes(csid, ncsid);
664 	m.u.clientHello.ciphers = makeciphers();
665 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
666 	if(c->clientVersion >= TLS12Version)
667 		m.u.clientHello.sigAlgs = makeints(sigAlgs, nelem(sigAlgs));
668 	if(!msgSend(c, &m, AFlush))
669 		goto Err;
670 	msgClear(&m);
671 
672 	/* server hello */
673 	if(!msgRecv(c, &m))
674 		goto Err;
675 	if(m.tag != HServerHello) {
676 		tlsError(c, EUnexpectedMessage, "expected a server hello");
677 		goto Err;
678 	}
679 	if(setVersion(c, m.u.serverHello.version) < 0) {
680 		tlsError(c, EIllegalParameter, "incompatible version %r");
681 		goto Err;
682 	}
683 	memmove(c->srandom, m.u.serverHello.random, RandomSize);
684 	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
685 	if(c->sid->len != 0 && c->sid->len != SidSize) {
686 		tlsError(c, EIllegalParameter, "invalid server session identifier");
687 		goto Err;
688 	}
689 	if(!setAlgs(c, m.u.serverHello.cipher)) {
690 		tlsError(c, EIllegalParameter, "invalid cipher suite");
691 		goto Err;
692 	}
693 	if(m.u.serverHello.compressor != CompressionNull) {
694 		tlsError(c, EIllegalParameter, "invalid compression");
695 		goto Err;
696 	}
697 	msgClear(&m);
698 
699 	/* certificate */
700 	if(!msgRecv(c, &m) || m.tag != HCertificate) {
701 		tlsError(c, EUnexpectedMessage, "expected a certificate");
702 		goto Err;
703 	}
704 	if(m.u.certificate.ncert < 1) {
705 		tlsError(c, EIllegalParameter, "runt certificate");
706 		goto Err;
707 	}
708 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
709 	msgClear(&m);
710 
711 	/* server key exchange (optional) */
712 	if(!msgRecv(c, &m))
713 		goto Err;
714 	if(m.tag == HServerKeyExchange) {
715 		tlsError(c, EUnexpectedMessage, "got an server key exchange");
716 		goto Err;
717 		// If implementing this later, watch out for rollback attack
718 		// described in Wagner Schneier 1996, section 4.4.
719 	}
720 
721 	/* certificate request (optional) */
722 	creq = 0;
723 	if(m.tag == HCertificateRequest) {
724 		creq = 1;
725 		msgClear(&m);
726 		if(!msgRecv(c, &m))
727 			goto Err;
728 	}
729 
730 	if(m.tag != HServerHelloDone) {
731 		tlsError(c, EUnexpectedMessage, "expected a server hello done");
732 		goto Err;
733 	}
734 	msgClear(&m);
735 
736 	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
737 			c->cert->data, c->cert->len, c->version, &epm, &nepm,
738 			kd, c->nsecret) < 0){
739 		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
740 		goto Err;
741 	}
742 	secrets = (char*)emalloc(2*c->nsecret);
743 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
744 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
745 	memset(secrets, 0, 2*c->nsecret);
746 	free(secrets);
747 	memset(kd, 0, c->nsecret);
748 	if(rv < 0){
749 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
750 		goto Err;
751 	}
752 
753 	if(creq) {
754 		/* send a zero length certificate */
755 		m.tag = HCertificate;
756 		if(!msgSend(c, &m, AFlush))
757 			goto Err;
758 		msgClear(&m);
759 	}
760 
761 	/* client key exchange */
762 	m.tag = HClientKeyExchange;
763 	m.u.clientKeyExchange.key = makebytes(epm, nepm);
764 	free(epm);
765 	epm = nil;
766 	if(m.u.clientKeyExchange.key == nil) {
767 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
768 		goto Err;
769 	}
770 	if(!msgSend(c, &m, AFlush))
771 		goto Err;
772 	msgClear(&m);
773 
774 	/* change cipher spec */
775 	if(fprint(c->ctl, "changecipher") < 0){
776 		tlsError(c, EInternalError, "can't enable cipher: %r");
777 		goto Err;
778 	}
779 
780 	// Cipherchange must occur immediately before Finished to avoid
781 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
782 	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
783 		tlsError(c, EInternalError, "can't set finished 1: %r");
784 		goto Err;
785 	}
786 	m.tag = HFinished;
787 	m.u.finished = c->finished;
788 
789 	if(!msgSend(c, &m, AFlush)) {
790 		fprint(2, "tlsClient nepm=%d\n", nepm);
791 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
792 		goto Err;
793 	}
794 	msgClear(&m);
795 
796 	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
797 		fprint(2, "tlsClient nepm=%d\n", nepm);
798 		tlsError(c, EInternalError, "can't set finished 0: %r");
799 		goto Err;
800 	}
801 	if(!msgRecv(c, &m)) {
802 		fprint(2, "tlsClient nepm=%d\n", nepm);
803 		tlsError(c, EInternalError, "can't read server Finished: %r");
804 		goto Err;
805 	}
806 	if(m.tag != HFinished) {
807 		fprint(2, "tlsClient nepm=%d\n", nepm);
808 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
809 		goto Err;
810 	}
811 
812 	if(!finishedMatch(c, &m.u.finished)) {
813 		tlsError(c, EHandshakeFailure, "finished verification failed");
814 		goto Err;
815 	}
816 	msgClear(&m);
817 
818 	if(fprint(c->ctl, "opened") < 0){
819 		if(trace)
820 			trace("unable to do final open: %r\n");
821 		goto Err;
822 	}
823 	tlsSecOk(c->sec);
824 	return c;
825 
826 Err:
827 	free(epm);
828 	msgClear(&m);
829 	tlsConnectionFree(c);
830 	return 0;
831 }
832 
833 
834 //================= message functions ========================
835 
836 static uchar sendbuf[9000], *sendp;
837 
838 static void
msgHash(TlsConnection * c,uchar * p,int n)839 msgHash(TlsConnection *c, uchar *p, int n)
840 {
841 	md5(p, n, 0, &c->hs.md5);
842 	sha1(p, n, 0, &c->hs.sha1);
843 	if(c->version >= TLS12Version)
844 		sha2_256(p, n, 0, &c->hs.sha2_256);
845 	else
846 		memset(&c->hs.sha2_256, 0, sizeof c->hs.sha2_256);
847 }
848 
849 static int
msgSend(TlsConnection * c,Msg * m,int act)850 msgSend(TlsConnection *c, Msg *m, int act)
851 {
852 	uchar *p; // sendp = start of new message;  p = write pointer
853 	int nn, n, i;
854 
855 	if(sendp == nil)
856 		sendp = sendbuf;
857 	p = sendp;
858 	if(c->trace)
859 		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
860 
861 	p[0] = m->tag;	// header - fill in size later
862 	p += 4;
863 
864 	switch(m->tag) {
865 	default:
866 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
867 		goto Err;
868 	case HClientHello:
869 		// version
870 		put16(p, m->u.clientHello.version);
871 		p += 2;
872 
873 		// random
874 		memmove(p, m->u.clientHello.random, RandomSize);
875 		p += RandomSize;
876 
877 		// sid
878 		n = m->u.clientHello.sid->len;
879 		assert(n < 256);
880 		p[0] = n;
881 		memmove(p+1, m->u.clientHello.sid->data, n);
882 		p += n+1;
883 
884 		n = m->u.clientHello.ciphers->len;
885 		assert(n > 0 && n < 200);
886 		put16(p, n*2);
887 		p += 2;
888 		for(i=0; i<n; i++) {
889 			put16(p, m->u.clientHello.ciphers->data[i]);
890 			p += 2;
891 		}
892 
893 		n = m->u.clientHello.compressors->len;
894 		assert(n > 0);
895 		p[0] = n;
896 		memmove(p+1, m->u.clientHello.compressors->data, n);
897 		p += n+1;
898 
899 		if(m->u.clientHello.sigAlgs != nil) {
900 			n = m->u.clientHello.sigAlgs->len;
901 			put16(p, 6 + 2*n);   /* length of extensions */
902 			put16(p+2, ExtSigalgs);
903 			put16(p+4, 2 + 2*n); /* length of extension content */
904 			put16(p+6, 2*n);     /* length of algorithm list */
905 			p += 8;
906 			for(i = 0; i < n; i++) {
907 				put16(p, m->u.clientHello.sigAlgs->data[i]);
908 				p += 2;
909 			}
910 		}
911 		break;
912 	case HServerHello:
913 		put16(p, m->u.serverHello.version);
914 		p += 2;
915 
916 		// random
917 		memmove(p, m->u.serverHello.random, RandomSize);
918 		p += RandomSize;
919 
920 		// sid
921 		n = m->u.serverHello.sid->len;
922 		assert(n < 256);
923 		p[0] = n;
924 		memmove(p+1, m->u.serverHello.sid->data, n);
925 		p += n+1;
926 
927 		put16(p, m->u.serverHello.cipher);
928 		p += 2;
929 		p[0] = m->u.serverHello.compressor;
930 		p += 1;
931 		break;
932 	case HServerHelloDone:
933 		break;
934 	case HCertificate:
935 		nn = 0;
936 		for(i = 0; i < m->u.certificate.ncert; i++)
937 			nn += 3 + m->u.certificate.certs[i]->len;
938 		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
939 			tlsError(c, EInternalError, "output buffer too small for certificate");
940 			goto Err;
941 		}
942 		put24(p, nn);
943 		p += 3;
944 		for(i = 0; i < m->u.certificate.ncert; i++){
945 			put24(p, m->u.certificate.certs[i]->len);
946 			p += 3;
947 			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
948 			p += m->u.certificate.certs[i]->len;
949 		}
950 		break;
951 	case HClientKeyExchange:
952 		n = m->u.clientKeyExchange.key->len;
953 		if(c->version != SSL3Version){
954 			put16(p, n);
955 			p += 2;
956 		}
957 		memmove(p, m->u.clientKeyExchange.key->data, n);
958 		p += n;
959 		break;
960 	case HFinished:
961 		memmove(p, m->u.finished.verify, m->u.finished.n);
962 		p += m->u.finished.n;
963 		break;
964 	}
965 
966 	// go back and fill in size
967 	n = p-sendp;
968 	assert(p <= sendbuf+sizeof(sendbuf));
969 	put24(sendp+1, n-4);
970 
971 	// remember hash of Handshake messages
972 	if(m->tag != HHelloRequest) {
973 		msgHash(c, sendp, n);
974 	}
975 
976 	sendp = p;
977 	if(act == AFlush){
978 		sendp = sendbuf;
979 		if(write(c->hand, sendbuf, p-sendbuf) < 0){
980 			fprint(2, "write error: %r\n");
981 			goto Err;
982 		}
983 	}
984 	msgClear(m);
985 	return 1;
986 Err:
987 	msgClear(m);
988 	return 0;
989 }
990 
991 static uchar*
tlsReadN(TlsConnection * c,int n)992 tlsReadN(TlsConnection *c, int n)
993 {
994 	uchar *p;
995 	int nn, nr;
996 
997 	nn = c->ep - c->rp;
998 	if(nn < n){
999 		if(c->rp != c->buf){
1000 			memmove(c->buf, c->rp, nn);
1001 			c->rp = c->buf;
1002 			c->ep = &c->buf[nn];
1003 		}
1004 		for(; nn < n; nn += nr) {
1005 			nr = read(c->hand, &c->rp[nn], n - nn);
1006 			if(nr <= 0)
1007 				return nil;
1008 			c->ep += nr;
1009 		}
1010 	}
1011 	p = c->rp;
1012 	c->rp += n;
1013 	return p;
1014 }
1015 
1016 static int
msgRecv(TlsConnection * c,Msg * m)1017 msgRecv(TlsConnection *c, Msg *m)
1018 {
1019 	uchar *p;
1020 	int type, n, nn, nx, i, nsid, nrandom, nciph;
1021 
1022 	for(;;) {
1023 		p = tlsReadN(c, 4);
1024 		if(p == nil)
1025 			return 0;
1026 		type = p[0];
1027 		n = get24(p+1);
1028 
1029 		if(type != HHelloRequest)
1030 			break;
1031 		if(n != 0) {
1032 			tlsError(c, EDecodeError, "invalid hello request during handshake");
1033 			return 0;
1034 		}
1035 	}
1036 
1037 	if(n > sizeof(c->buf)) {
1038 		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
1039 		return 0;
1040 	}
1041 
1042 	if(type == HSSL2ClientHello){
1043 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1044 			This is sent by some clients that we must interoperate
1045 			with, such as Java's JSSE and Microsoft's Internet Explorer. */
1046 		p = tlsReadN(c, n);
1047 		if(p == nil)
1048 			return 0;
1049 		msgHash(c, p, n);
1050 		m->tag = HClientHello;
1051 		if(n < 22)
1052 			goto Short;
1053 		m->u.clientHello.version = get16(p+1);
1054 		p += 3;
1055 		n -= 3;
1056 		nn = get16(p); /* cipher_spec_len */
1057 		nsid = get16(p + 2);
1058 		nrandom = get16(p + 4);
1059 		p += 6;
1060 		n -= 6;
1061 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
1062 				|| nrandom < 16 || nn % 3)
1063 			goto Err;
1064 		if(c->trace && (n - nrandom != nn))
1065 			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1066 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1067 		nciph = 0;
1068 		for(i = 0; i < nn; i += 3)
1069 			if(p[i] == 0)
1070 				nciph++;
1071 		m->u.clientHello.ciphers = newints(nciph);
1072 		nciph = 0;
1073 		for(i = 0; i < nn; i += 3)
1074 			if(p[i] == 0)
1075 				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1076 		p += nn;
1077 		m->u.clientHello.sid = makebytes(nil, 0);
1078 		if(nrandom > RandomSize)
1079 			nrandom = RandomSize;
1080 		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1081 		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1082 		m->u.clientHello.compressors = newbytes(1);
1083 		m->u.clientHello.compressors->data[0] = CompressionNull;
1084 		goto Ok;
1085 	}
1086 
1087 	msgHash(c, p, 4);
1088 
1089 	p = tlsReadN(c, n);
1090 	if(p == nil)
1091 		return 0;
1092 
1093 	msgHash(c, p, n);
1094 
1095 	m->tag = type;
1096 
1097 	switch(type) {
1098 	default:
1099 		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1100 		goto Err;
1101 	case HClientHello:
1102 		if(n < 2)
1103 			goto Short;
1104 		m->u.clientHello.version = get16(p);
1105 		p += 2;
1106 		n -= 2;
1107 
1108 		if(n < RandomSize)
1109 			goto Short;
1110 		memmove(m->u.clientHello.random, p, RandomSize);
1111 		p += RandomSize;
1112 		n -= RandomSize;
1113 		if(n < 1 || n < p[0]+1)
1114 			goto Short;
1115 		m->u.clientHello.sid = makebytes(p+1, p[0]);
1116 		p += m->u.clientHello.sid->len+1;
1117 		n -= m->u.clientHello.sid->len+1;
1118 
1119 		if(n < 2)
1120 			goto Short;
1121 		nn = get16(p);
1122 		p += 2;
1123 		n -= 2;
1124 
1125 		if((nn & 1) || n < nn || nn < 2)
1126 			goto Short;
1127 		m->u.clientHello.ciphers = newints(nn >> 1);
1128 		for(i = 0; i < nn; i += 2)
1129 			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1130 		p += nn;
1131 		n -= nn;
1132 
1133 		if(n < 1 || n < p[0]+1 || p[0] == 0)
1134 			goto Short;
1135 		nn = p[0];
1136 		m->u.clientHello.compressors = newbytes(nn);
1137 		memmove(m->u.clientHello.compressors->data, p+1, nn);
1138 		p += nn + 1;
1139 		n -= nn + 1;
1140 
1141 		/* extensions */
1142 		if(n == 0)
1143 			break;
1144 		if(n < 2)
1145 			goto Short;
1146 		nx = get16(p);
1147 		p += 2;
1148 		n -= 2;
1149 		while(nx > 0){
1150 			if(n < nx || nx < 4)
1151 				goto Short;
1152 			i = get16(p);
1153 			nn = get16(p+2);
1154 			if(nx < nn+4)
1155 				goto Short;
1156 			nx -= nn+4;
1157 			p += 4;
1158 			n -= 4;
1159 			if(i == ExtSigalgs){
1160 				if(get16(p) != nn-2)
1161 					goto Short;
1162 				p += 2;
1163 				n -= 2;
1164 				nn -= 2;
1165 				m->u.clientHello.sigAlgs = newints(nn/2);
1166 				for(i = 0; i < nn; i += 2)
1167 					m->u.clientHello.sigAlgs->data[i >> 1] = get16(&p[i]);
1168 			}
1169 			p += nn;
1170 			n -= nn;
1171 		}
1172 		break;
1173 	case HServerHello:
1174 		if(n < 2)
1175 			goto Short;
1176 		m->u.serverHello.version = get16(p);
1177 		p += 2;
1178 		n -= 2;
1179 
1180 		if(n < RandomSize)
1181 			goto Short;
1182 		memmove(m->u.serverHello.random, p, RandomSize);
1183 		p += RandomSize;
1184 		n -= RandomSize;
1185 
1186 		if(n < 1 || n < p[0]+1)
1187 			goto Short;
1188 		m->u.serverHello.sid = makebytes(p+1, p[0]);
1189 		p += m->u.serverHello.sid->len+1;
1190 		n -= m->u.serverHello.sid->len+1;
1191 
1192 		if(n < 3)
1193 			goto Short;
1194 		m->u.serverHello.cipher = get16(p);
1195 		m->u.serverHello.compressor = p[2];
1196 		n -= 3;
1197 		break;
1198 	case HCertificate:
1199 		if(n < 3)
1200 			goto Short;
1201 		nn = get24(p);
1202 		p += 3;
1203 		n -= 3;
1204 		if(n != nn)
1205 			goto Short;
1206 		/* certs */
1207 		i = 0;
1208 		while(n > 0) {
1209 			if(n < 3)
1210 				goto Short;
1211 			nn = get24(p);
1212 			p += 3;
1213 			n -= 3;
1214 			if(nn > n)
1215 				goto Short;
1216 			m->u.certificate.ncert = i+1;
1217 			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1218 			m->u.certificate.certs[i] = makebytes(p, nn);
1219 			p += nn;
1220 			n -= nn;
1221 			i++;
1222 		}
1223 		break;
1224 	case HCertificateRequest:
1225 		if(n < 1)
1226 			goto Short;
1227 		nn = p[0];
1228 		p += 1;
1229 		n -= 1;
1230 		if(nn < 1 || nn > n)
1231 			goto Short;
1232 		m->u.certificateRequest.types = makebytes(p, nn);
1233 		p += nn;
1234 		n -= nn;
1235 		if(n < 2)
1236 			goto Short;
1237 		nn = get16(p);
1238 		p += 2;
1239 		n -= 2;
1240 		/* nn == 0 can happen; yahoo's servers do it */
1241 		if(nn != n)
1242 			goto Short;
1243 		/* cas */
1244 		i = 0;
1245 		while(n > 0) {
1246 			if(n < 2)
1247 				goto Short;
1248 			nn = get16(p);
1249 			p += 2;
1250 			n -= 2;
1251 			if(nn < 1 || nn > n)
1252 				goto Short;
1253 			m->u.certificateRequest.nca = i+1;
1254 			m->u.certificateRequest.cas = erealloc(
1255 				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1256 			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1257 			p += nn;
1258 			n -= nn;
1259 			i++;
1260 		}
1261 		break;
1262 	case HServerHelloDone:
1263 		break;
1264 	case HClientKeyExchange:
1265 		/*
1266 		 * this message depends upon the encryption selected
1267 		 * assume rsa.
1268 		 */
1269 		if(c->version == SSL3Version)
1270 			nn = n;
1271 		else{
1272 			if(n < 2)
1273 				goto Short;
1274 			nn = get16(p);
1275 			p += 2;
1276 			n -= 2;
1277 		}
1278 		if(n < nn)
1279 			goto Short;
1280 		m->u.clientKeyExchange.key = makebytes(p, nn);
1281 		n -= nn;
1282 		break;
1283 	case HFinished:
1284 		m->u.finished.n = c->finished.n;
1285 		if(n < m->u.finished.n)
1286 			goto Short;
1287 		memmove(m->u.finished.verify, p, m->u.finished.n);
1288 		n -= m->u.finished.n;
1289 		break;
1290 	}
1291 
1292 	if(type != HClientHello && n != 0)
1293 		goto Short;
1294 Ok:
1295 	if(c->trace){
1296 		char *buf;
1297 		buf = emalloc(8000);
1298 		c->trace("recv %s", msgPrint(buf, 8000, m));
1299 		free(buf);
1300 	}
1301 	return 1;
1302 Short:
1303 	tlsError(c, EDecodeError, "handshake message has invalid length");
1304 Err:
1305 	msgClear(m);
1306 	return 0;
1307 }
1308 
1309 static void
msgClear(Msg * m)1310 msgClear(Msg *m)
1311 {
1312 	int i;
1313 
1314 	switch(m->tag) {
1315 	default:
1316 		sysfatal("msgClear: unknown message type: %d", m->tag);
1317 	case HHelloRequest:
1318 		break;
1319 	case HClientHello:
1320 		freebytes(m->u.clientHello.sid);
1321 		freeints(m->u.clientHello.ciphers);
1322 		m->u.clientHello.ciphers = nil;
1323 		freebytes(m->u.clientHello.compressors);
1324 		freeints(m->u.clientHello.sigAlgs);
1325 		break;
1326 	case HServerHello:
1327 		freebytes(m->u.clientHello.sid);
1328 		break;
1329 	case HCertificate:
1330 		for(i=0; i<m->u.certificate.ncert; i++)
1331 			freebytes(m->u.certificate.certs[i]);
1332 		free(m->u.certificate.certs);
1333 		break;
1334 	case HCertificateRequest:
1335 		freebytes(m->u.certificateRequest.types);
1336 		for(i=0; i<m->u.certificateRequest.nca; i++)
1337 			freebytes(m->u.certificateRequest.cas[i]);
1338 		free(m->u.certificateRequest.cas);
1339 		break;
1340 	case HServerHelloDone:
1341 		break;
1342 	case HClientKeyExchange:
1343 		freebytes(m->u.clientKeyExchange.key);
1344 		break;
1345 	case HFinished:
1346 		break;
1347 	}
1348 	memset(m, 0, sizeof(Msg));
1349 }
1350 
1351 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1352 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1353 {
1354 	int i;
1355 
1356 	if(s0)
1357 		bs = seprint(bs, be, "%s", s0);
1358 	bs = seprint(bs, be, "[");
1359 	if(b == nil)
1360 		bs = seprint(bs, be, "nil");
1361 	else
1362 		for(i=0; i<b->len; i++)
1363 			bs = seprint(bs, be, "%.2x ", b->data[i]);
1364 	bs = seprint(bs, be, "]");
1365 	if(s1)
1366 		bs = seprint(bs, be, "%s", s1);
1367 	return bs;
1368 }
1369 
1370 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1371 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1372 {
1373 	int i;
1374 
1375 	if(s0)
1376 		bs = seprint(bs, be, "%s", s0);
1377 	bs = seprint(bs, be, "[");
1378 	if(b == nil)
1379 		bs = seprint(bs, be, "nil");
1380 	else
1381 		for(i=0; i<b->len; i++)
1382 			bs = seprint(bs, be, "%x ", b->data[i]);
1383 	bs = seprint(bs, be, "]");
1384 	if(s1)
1385 		bs = seprint(bs, be, "%s", s1);
1386 	return bs;
1387 }
1388 
1389 static char*
msgPrint(char * buf,int n,Msg * m)1390 msgPrint(char *buf, int n, Msg *m)
1391 {
1392 	int i;
1393 	char *bs = buf, *be = buf+n;
1394 
1395 	switch(m->tag) {
1396 	default:
1397 		bs = seprint(bs, be, "unknown %d\n", m->tag);
1398 		break;
1399 	case HClientHello:
1400 		bs = seprint(bs, be, "ClientHello\n");
1401 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1402 		bs = seprint(bs, be, "\trandom: ");
1403 		for(i=0; i<RandomSize; i++)
1404 			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1405 		bs = seprint(bs, be, "\n");
1406 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1407 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1408 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1409 		if(m->u.clientHello.sigAlgs != nil)
1410 			bs = intsPrint(bs, be, "\tsigAlgs: ", m->u.clientHello.sigAlgs, "\n");
1411 		break;
1412 	case HServerHello:
1413 		bs = seprint(bs, be, "ServerHello\n");
1414 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1415 		bs = seprint(bs, be, "\trandom: ");
1416 		for(i=0; i<RandomSize; i++)
1417 			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1418 		bs = seprint(bs, be, "\n");
1419 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1420 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1421 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1422 		break;
1423 	case HCertificate:
1424 		bs = seprint(bs, be, "Certificate\n");
1425 		for(i=0; i<m->u.certificate.ncert; i++)
1426 			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1427 		break;
1428 	case HCertificateRequest:
1429 		bs = seprint(bs, be, "CertificateRequest\n");
1430 		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1431 		bs = seprint(bs, be, "\tcertificateauthorities\n");
1432 		for(i=0; i<m->u.certificateRequest.nca; i++)
1433 			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1434 		break;
1435 	case HServerHelloDone:
1436 		bs = seprint(bs, be, "ServerHelloDone\n");
1437 		break;
1438 	case HClientKeyExchange:
1439 		bs = seprint(bs, be, "HClientKeyExchange\n");
1440 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1441 		break;
1442 	case HFinished:
1443 		bs = seprint(bs, be, "HFinished\n");
1444 		for(i=0; i<m->u.finished.n; i++)
1445 			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1446 		bs = seprint(bs, be, "\n");
1447 		break;
1448 	}
1449 	USED(bs);
1450 	return buf;
1451 }
1452 
1453 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1454 tlsError(TlsConnection *c, int err, char *fmt, ...)
1455 {
1456 	char msg[512];
1457 	va_list arg;
1458 
1459 	va_start(arg, fmt);
1460 	vseprint(msg, msg+sizeof(msg), fmt, arg);
1461 	va_end(arg);
1462 	if(c->trace)
1463 		c->trace("tlsError: %s\n", msg);
1464 	else if(c->erred)
1465 		fprint(2, "double error: %r, %s", msg);
1466 	else
1467 		werrstr("tls: local %s", msg);
1468 	c->erred = 1;
1469 	fprint(c->ctl, "alert %d", err);
1470 }
1471 
1472 // commit to specific version number
1473 static int
setVersion(TlsConnection * c,int version)1474 setVersion(TlsConnection *c, int version)
1475 {
1476 	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1477 		return -1;
1478 	if(version > c->version)
1479 		version = c->version;
1480 	switch(version) {
1481 	case SSL3Version:
1482 		c->finished.n = SSL3FinishedLen;
1483 		return -1;
1484 	case TLS10Version:
1485 	case TLS11Version:
1486 	case TLS12Version:
1487 		c->finished.n = TLSFinishedLen;
1488 		break;
1489 	default:
1490 		return -1;
1491 	}
1492 	c->version = version;
1493 	c->verset = 1;
1494 	return fprint(c->ctl, "version 0x%x", version);
1495 }
1496 
1497 // confirm that received Finished message matches the expected value
1498 static int
finishedMatch(TlsConnection * c,Finished * f)1499 finishedMatch(TlsConnection *c, Finished *f)
1500 {
1501 	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1502 }
1503 
1504 // free memory associated with TlsConnection struct
1505 //		(but don't close the TLS channel itself)
1506 static void
tlsConnectionFree(TlsConnection * c)1507 tlsConnectionFree(TlsConnection *c)
1508 {
1509 	tlsSecClose(c->sec);
1510 	freebytes(c->sid);
1511 	freebytes(c->cert);
1512 	memset(c, 0, sizeof *c);
1513 	free(c);
1514 }
1515 
1516 
1517 //================= cipher choices ========================
1518 
1519 static int weakCipher[CipherMax] =
1520 {
1521 	1,	/* TLS_NULL_WITH_NULL_NULL */
1522 	1,	/* TLS_RSA_WITH_NULL_MD5 */
1523 	1,	/* TLS_RSA_WITH_NULL_SHA */
1524 	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1525 	1,	/* TLS_RSA_WITH_RC4_128_MD5 */
1526 	1,	/* TLS_RSA_WITH_RC4_128_SHA */
1527 	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1528 	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
1529 	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1530 	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
1531 	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1532 	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1533 	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
1534 	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1535 	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1536 	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
1537 	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1538 	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1539 	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1540 	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1541 	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1542 	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1543 	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1544 	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1545 	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
1546 	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1547 	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
1548 	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1549 };
1550 
1551 static int
setAlgs(TlsConnection * c,int a)1552 setAlgs(TlsConnection *c, int a)
1553 {
1554 	int i;
1555 
1556 	for(i = 0; i < nelem(cipherAlgs); i++){
1557 		if(cipherAlgs[i].tlsid == a){
1558 			c->enc = cipherAlgs[i].enc;
1559 			c->digest = cipherAlgs[i].digest;
1560 			c->nsecret = cipherAlgs[i].nsecret;
1561 			if(c->nsecret > MaxKeyData)
1562 				return 0;
1563 			return 1;
1564 		}
1565 	}
1566 	return 0;
1567 }
1568 
1569 static int
okCipher(Ints * cv)1570 okCipher(Ints *cv)
1571 {
1572 	int weak, i, j, c;
1573 
1574 	weak = 1;
1575 	for(i = 0; i < cv->len; i++) {
1576 		c = cv->data[i];
1577 		if(c >= CipherMax)
1578 			weak = 0;
1579 		else
1580 			weak &= weakCipher[c];
1581 		for(j = 0; j < nelem(cipherAlgs); j++)
1582 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1583 				return c;
1584 	}
1585 	if(weak)
1586 		return -2;
1587 	return -1;
1588 }
1589 
1590 static int
okCompression(Bytes * cv)1591 okCompression(Bytes *cv)
1592 {
1593 	int i, j, c;
1594 
1595 	for(i = 0; i < cv->len; i++) {
1596 		c = cv->data[i];
1597 		for(j = 0; j < nelem(compressors); j++) {
1598 			if(compressors[j] == c)
1599 				return c;
1600 		}
1601 	}
1602 	return -1;
1603 }
1604 
1605 static Lock	ciphLock;
1606 static int	nciphers;
1607 
1608 static int
initCiphers(void)1609 initCiphers(void)
1610 {
1611 	enum {MaxAlgF = 1024, MaxAlgs = 10};
1612 	char s[MaxAlgF], *flds[MaxAlgs];
1613 	int i, j, n, ok;
1614 
1615 	lock(&ciphLock);
1616 	if(nciphers){
1617 		unlock(&ciphLock);
1618 		return nciphers;
1619 	}
1620 	j = open("#a/tls/encalgs", OREAD);
1621 	if(j < 0){
1622 		werrstr("can't open #a/tls/encalgs: %r");
1623 		unlock(&ciphLock);
1624 		return 0;
1625 	}
1626 	n = read(j, s, MaxAlgF-1);
1627 	close(j);
1628 	if(n <= 0){
1629 		werrstr("nothing in #a/tls/encalgs: %r");
1630 		unlock(&ciphLock);
1631 		return 0;
1632 	}
1633 	s[n] = 0;
1634 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1635 	for(i = 0; i < nelem(cipherAlgs); i++){
1636 		ok = 0;
1637 		for(j = 0; j < n; j++){
1638 			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1639 				ok = 1;
1640 				break;
1641 			}
1642 		}
1643 		cipherAlgs[i].ok = ok;
1644 	}
1645 
1646 	j = open("#a/tls/hashalgs", OREAD);
1647 	if(j < 0){
1648 		werrstr("can't open #a/tls/hashalgs: %r");
1649 		unlock(&ciphLock);
1650 		return 0;
1651 	}
1652 	n = read(j, s, MaxAlgF-1);
1653 	close(j);
1654 	if(n <= 0){
1655 		werrstr("nothing in #a/tls/hashalgs: %r");
1656 		unlock(&ciphLock);
1657 		return 0;
1658 	}
1659 	s[n] = 0;
1660 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1661 	for(i = 0; i < nelem(cipherAlgs); i++){
1662 		ok = 0;
1663 		for(j = 0; j < n; j++){
1664 			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1665 				ok = 1;
1666 				break;
1667 			}
1668 		}
1669 		cipherAlgs[i].ok &= ok;
1670 		if(cipherAlgs[i].ok)
1671 			nciphers++;
1672 	}
1673 	unlock(&ciphLock);
1674 	return nciphers;
1675 }
1676 
1677 static Ints*
makeciphers(void)1678 makeciphers(void)
1679 {
1680 	Ints *is;
1681 	int i, j;
1682 
1683 	is = newints(nciphers);
1684 	j = 0;
1685 	for(i = 0; i < nelem(cipherAlgs); i++){
1686 		if(cipherAlgs[i].ok)
1687 			is->data[j++] = cipherAlgs[i].tlsid;
1688 	}
1689 	return is;
1690 }
1691 
1692 
1693 
1694 //================= security functions ========================
1695 
1696 // given X.509 certificate, set up connection to factotum
1697 //	for using corresponding private key
1698 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1699 factotum_rsa_open(uchar *cert, int certlen)
1700 {
1701 	int afd;
1702 	char *s;
1703 	mpint *pub = nil;
1704 	RSApub *rsapub;
1705 	AuthRpc *rpc;
1706 
1707 	// start talking to factotum
1708 	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1709 		return nil;
1710 	if((rpc = auth_allocrpc(afd)) == nil){
1711 		close(afd);
1712 		return nil;
1713 	}
1714 	s = "proto=rsa service=tls role=client";
1715 	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1716 		factotum_rsa_close(rpc);
1717 		return nil;
1718 	}
1719 
1720 	// roll factotum keyring around to match certificate
1721 	rsapub = X509toRSApub(cert, certlen, nil, 0);
1722 	while(1){
1723 		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1724 			factotum_rsa_close(rpc);
1725 			rpc = nil;
1726 			goto done;
1727 		}
1728 		pub = strtomp(rpc->arg, nil, 16, nil);
1729 		assert(pub != nil);
1730 		if(mpcmp(pub,rsapub->n) == 0)
1731 			break;
1732 	}
1733 done:
1734 	mpfree(pub);
1735 	rsapubfree(rsapub);
1736 	return rpc;
1737 }
1738 
1739 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1740 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1741 {
1742 	char *p;
1743 	int rv;
1744 
1745 	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1746 		return nil;
1747 	rv = auth_rpc(rpc, "write", p, strlen(p));
1748 	free(p);
1749 	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1750 		return nil;
1751 	mpfree(cipher);
1752 	return strtomp(rpc->arg, nil, 16, nil);
1753 }
1754 
1755 static void
factotum_rsa_close(AuthRpc * rpc)1756 factotum_rsa_close(AuthRpc*rpc)
1757 {
1758 	if(!rpc)
1759 		return;
1760 	close(rpc->afd);
1761 	auth_freerpc(rpc);
1762 }
1763 
1764 static void
tlsPmd5(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1765 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1766 {
1767 	uchar ai[MD5dlen], tmp[MD5dlen];
1768 	int i, n;
1769 	MD5state *s;
1770 
1771 	// generate a1
1772 	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1773 	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1774 	hmac_md5(seed1, nseed1, key, nkey, ai, s);
1775 
1776 	while(nbuf > 0) {
1777 		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1778 		s = hmac_md5(label, nlabel, key, nkey, nil, s);
1779 		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1780 		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1781 		n = MD5dlen;
1782 		if(n > nbuf)
1783 			n = nbuf;
1784 		for(i = 0; i < n; i++)
1785 			buf[i] ^= tmp[i];
1786 		buf += n;
1787 		nbuf -= n;
1788 		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1789 		memmove(ai, tmp, MD5dlen);
1790 	}
1791 }
1792 
1793 static void
tlsPsha1(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1794 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1795 {
1796 	uchar ai[SHA1dlen], tmp[SHA1dlen];
1797 	int i, n;
1798 	SHAstate *s;
1799 
1800 	// generate a1
1801 	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1802 	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1803 	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1804 
1805 	while(nbuf > 0) {
1806 		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1807 		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1808 		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1809 		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1810 		n = SHA1dlen;
1811 		if(n > nbuf)
1812 			n = nbuf;
1813 		for(i = 0; i < n; i++)
1814 			buf[i] ^= tmp[i];
1815 		buf += n;
1816 		nbuf -= n;
1817 		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1818 		memmove(ai, tmp, SHA1dlen);
1819 	}
1820 }
1821 
1822 static void
tlsPsha2_256(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed,int nseed)1823 tlsPsha2_256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
1824 {
1825 	uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
1826 	int n;
1827 	SHAstate *s;
1828 
1829 	// generate a1
1830 	s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
1831 	hmac_sha2_256(seed, nseed, key, nkey, ai, s);
1832 
1833 	while(nbuf > 0) {
1834 		s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
1835 		s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
1836 		hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
1837 		n = SHA2_256dlen;
1838 		if(n > nbuf)
1839 			n = nbuf;
1840 		memmove(buf, tmp, n);
1841 		buf += n;
1842 		nbuf -= n;
1843 		hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
1844 		memmove(ai, tmp, SHA2_256dlen);
1845 	}
1846 }
1847 
1848 // fill buf with md5(args)^sha1(args)
1849 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1850 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1851 {
1852 	int i;
1853 	int nlabel = strlen(label);
1854 	int n = (nkey + 1) >> 1;
1855 
1856 	for(i = 0; i < nbuf; i++)
1857 		buf[i] = 0;
1858 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1859 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1860 }
1861 
1862 void
tls12PRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1863 tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1864 {
1865 	uchar seed[2*RandomSize];
1866 	int nlabel = strlen(label);
1867 
1868 	memmove(seed, seed0, nseed0);
1869 	memmove(seed+nseed0, seed1, nseed1);
1870 	tlsPsha2_256(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed0+nseed1);
1871 }
1872 
1873 /*
1874  * for setting server session id's
1875  */
1876 static Lock	sidLock;
1877 static long	maxSid = 1;
1878 
1879 /* the keys are verified to have the same public components
1880  * and to function correctly with pkcs 1 encryption and decryption. */
1881 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1882 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1883 {
1884 	TlsSec *sec = emalloc(sizeof(*sec));
1885 
1886 	USED(csid); USED(ncsid);  // ignore csid for now
1887 
1888 	memmove(sec->crandom, crandom, RandomSize);
1889 	sec->clientVers = cvers;
1890 
1891 	put32(sec->srandom, time(0));
1892 	genrandom(sec->srandom+4, RandomSize-4);
1893 	memmove(srandom, sec->srandom, RandomSize);
1894 
1895 	/*
1896 	 * make up a unique sid: use our pid, and and incrementing id
1897 	 * can signal no sid by setting nssid to 0.
1898 	 */
1899 	memset(ssid, 0, SidSize);
1900 	put32(ssid, getpid());
1901 	lock(&sidLock);
1902 	put32(ssid+4, maxSid++);
1903 	unlock(&sidLock);
1904 	*nssid = SidSize;
1905 	return sec;
1906 }
1907 
1908 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1909 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1910 {
1911 	if(epm != nil){
1912 		if(setVers(sec, vers) < 0)
1913 			goto Err;
1914 		serverMasterSecret(sec, epm, nepm);
1915 	}else if(sec->vers != vers){
1916 		werrstr("mismatched session versions");
1917 		goto Err;
1918 	}
1919 	setSecrets(sec, kd, nkd);
1920 	return 0;
1921 Err:
1922 	sec->ok = -1;
1923 	return -1;
1924 }
1925 
1926 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1927 tlsSecInitc(int cvers, uchar *crandom)
1928 {
1929 	TlsSec *sec = emalloc(sizeof(*sec));
1930 	sec->clientVers = cvers;
1931 	put32(sec->crandom, time(0));
1932 	genrandom(sec->crandom+4, RandomSize-4);
1933 	memmove(crandom, sec->crandom, RandomSize);
1934 	return sec;
1935 }
1936 
1937 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1938 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1939 {
1940 	RSApub *pub;
1941 
1942 	pub = nil;
1943 
1944 	USED(sid);
1945 	USED(nsid);
1946 
1947 	memmove(sec->srandom, srandom, RandomSize);
1948 
1949 	if(setVers(sec, vers) < 0)
1950 		goto Err;
1951 
1952 	pub = X509toRSApub(cert, ncert, nil, 0);
1953 	if(pub == nil){
1954 		werrstr("invalid x509/rsa certificate");
1955 		goto Err;
1956 	}
1957 	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1958 		goto Err;
1959 	rsapubfree(pub);
1960 	setSecrets(sec, kd, nkd);
1961 	return 0;
1962 
1963 Err:
1964 	if(pub != nil)
1965 		rsapubfree(pub);
1966 	sec->ok = -1;
1967 	return -1;
1968 }
1969 
1970 static int
tlsSecFinished(TlsSec * sec,HandHash hs,uchar * fin,int nfin,int isclient)1971 tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient)
1972 {
1973 	if(sec->nfin != nfin){
1974 		sec->ok = -1;
1975 		werrstr("invalid finished exchange");
1976 		return -1;
1977 	}
1978 	hs.md5.malloced = 0;
1979 	hs.sha1.malloced = 0;
1980 	if(sec->setFinished == nil ){
1981 		sec->ok = -1;
1982 		werrstr("nil sec->setFinished in tlsSecFinished");
1983 		return -1;
1984 	}
1985 	(*sec->setFinished)(sec, hs, fin, isclient);
1986 	return 1;
1987 }
1988 
1989 static void
tlsSecOk(TlsSec * sec)1990 tlsSecOk(TlsSec *sec)
1991 {
1992 	if(sec->ok == 0)
1993 		sec->ok = 1;
1994 }
1995 
1996 static void
tlsSecKill(TlsSec * sec)1997 tlsSecKill(TlsSec *sec)
1998 {
1999 	if(!sec)
2000 		return;
2001 	factotum_rsa_close(sec->rpc);
2002 	sec->ok = -1;
2003 }
2004 
2005 static void
tlsSecClose(TlsSec * sec)2006 tlsSecClose(TlsSec *sec)
2007 {
2008 	if(!sec)
2009 		return;
2010 	factotum_rsa_close(sec->rpc);
2011 	free(sec->server);
2012 	free(sec);
2013 }
2014 
2015 static int
setVers(TlsSec * sec,int v)2016 setVers(TlsSec *sec, int v)
2017 {
2018 	switch(v){
2019 	case SSL3Version:
2020 		sec->setFinished = sslSetFinished;
2021 		sec->nfin = SSL3FinishedLen;
2022 		sec->prf = sslPRF;
2023 		break;
2024 	case TLS10Version:
2025 	case TLS11Version:
2026 		sec->setFinished = tlsSetFinished;
2027 		sec->nfin = TLSFinishedLen;
2028 		sec->prf = tlsPRF;
2029 		break;
2030 	case TLS12Version:
2031 		sec->setFinished = tls12SetFinished;
2032 		sec->nfin = TLSFinishedLen;
2033 		sec->prf = tls12PRF;
2034 		break;
2035 	default:
2036 		werrstr("invalid version");
2037 		sec->setFinished = nil;
2038 		sec->prf = nil;
2039 		return -1;
2040 	}
2041 	sec->vers = v;
2042 	return 0;
2043 }
2044 
2045 /*
2046  * generate secret keys from the master secret.
2047  *
2048  * different crypto selections will require different amounts
2049  * of key expansion and use of key expansion data,
2050  * but it's all generated using the same function.
2051  */
2052 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)2053 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2054 {
2055 	if (sec->prf == nil) {
2056 		werrstr("nil sec->prf in setSecrets");
2057 		return;
2058 	}
2059 	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2060 			sec->srandom, RandomSize, sec->crandom, RandomSize);
2061 }
2062 
2063 /*
2064  * set the master secret from the pre-master secret.
2065  */
2066 static void
setMasterSecret(TlsSec * sec,Bytes * pm)2067 setMasterSecret(TlsSec *sec, Bytes *pm)
2068 {
2069 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
2070 			sec->crandom, RandomSize, sec->srandom, RandomSize);
2071 }
2072 
2073 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)2074 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
2075 {
2076 	Bytes *pm;
2077 
2078 	pm = pkcs1_decrypt(sec, epm, nepm);
2079 
2080 	// if the client messed up, just continue as if everything is ok,
2081 	// to prevent attacks to check for correctly formatted messages.
2082 	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2083 	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2084 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2085 			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
2086 		sec->ok = -1;
2087 		if(pm != nil)
2088 			freebytes(pm);
2089 		pm = newbytes(MasterSecretSize);
2090 		genrandom(pm->data, MasterSecretSize);
2091 	}
2092 	setMasterSecret(sec, pm);
2093 	memset(pm->data, 0, pm->len);
2094 	freebytes(pm);
2095 }
2096 
2097 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)2098 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
2099 {
2100 	Bytes *pm, *key;
2101 
2102 	pm = newbytes(MasterSecretSize);
2103 	put16(pm->data, sec->clientVers);
2104 	genrandom(pm->data+2, MasterSecretSize - 2);
2105 
2106 	setMasterSecret(sec, pm);
2107 
2108 	key = pkcs1_encrypt(pm, pub, 2);
2109 	memset(pm->data, 0, pm->len);
2110 	freebytes(pm);
2111 	if(key == nil){
2112 		werrstr("tls pkcs1_encrypt failed");
2113 		return -1;
2114 	}
2115 
2116 	*nepm = key->len;
2117 	*epm = malloc(*nepm);
2118 	if(*epm == nil){
2119 		freebytes(key);
2120 		werrstr("out of memory");
2121 		return -1;
2122 	}
2123 	memmove(*epm, key->data, *nepm);
2124 
2125 	freebytes(key);
2126 
2127 	return 1;
2128 }
2129 
2130 static void
sslSetFinished(TlsSec * sec,HandHash hs,uchar * finished,int isClient)2131 sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
2132 {
2133 	DigestState *s;
2134 	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
2135 	char *label;
2136 
2137 	if(isClient)
2138 		label = "CLNT";
2139 	else
2140 		label = "SRVR";
2141 
2142 	md5((uchar*)label, 4, nil, &hs.md5);
2143 	md5(sec->sec, MasterSecretSize, nil, &hs.md5);
2144 	memset(pad, 0x36, 48);
2145 	md5(pad, 48, nil, &hs.md5);
2146 	md5(nil, 0, h0, &hs.md5);
2147 	memset(pad, 0x5C, 48);
2148 	s = md5(sec->sec, MasterSecretSize, nil, nil);
2149 	s = md5(pad, 48, nil, s);
2150 	md5(h0, MD5dlen, finished, s);
2151 
2152 	sha1((uchar*)label, 4, nil, &hs.sha1);
2153 	sha1(sec->sec, MasterSecretSize, nil, &hs.sha1);
2154 	memset(pad, 0x36, 40);
2155 	sha1(pad, 40, nil, &hs.sha1);
2156 	sha1(nil, 0, h1, &hs.sha1);
2157 	memset(pad, 0x5C, 40);
2158 	s = sha1(sec->sec, MasterSecretSize, nil, nil);
2159 	s = sha1(pad, 40, nil, s);
2160 	sha1(h1, SHA1dlen, finished + MD5dlen, s);
2161 }
2162 
2163 // fill "finished" arg with md5(args)^sha1(args)
2164 static void
tlsSetFinished(TlsSec * sec,HandHash hs,uchar * finished,int isClient)2165 tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
2166 {
2167 	uchar h0[MD5dlen], h1[SHA1dlen];
2168 	char *label;
2169 
2170 	// get current hash value, but allow further messages to be hashed in
2171 	md5(nil, 0, h0, &hs.md5);
2172 	sha1(nil, 0, h1, &hs.sha1);
2173 
2174 	if(isClient)
2175 		label = "client finished";
2176 	else
2177 		label = "server finished";
2178 	if (sec->prf == nil) {
2179 		werrstr("nil sec->prf in tlsSetFinished");
2180 		return;
2181 	}
2182 	(*sec->prf)(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2183 }
2184 
2185 // fill "finished" arg with sha256(args)
2186 static void
tls12SetFinished(TlsSec * sec,HandHash hs,uchar * finished,int isClient)2187 tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
2188 {
2189 	uchar h[SHA2_256dlen];
2190 	char *label;
2191 
2192 	// get current hash value, but allow further messages to be hashed in
2193 	sha2_256(nil, 0, h, &hs.sha2_256);
2194 
2195 	if(isClient)
2196 		label = "client finished";
2197 	else
2198 		label = "server finished";
2199 	tlsPsha2_256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), h, SHA2_256dlen);
2200 }
2201 
2202 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)2203 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2204 {
2205 	DigestState *s;
2206 	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2207 	int i, n, len;
2208 
2209 	USED(label);
2210 	len = 1;
2211 	while(nbuf > 0){
2212 		if(len > 26)
2213 			return;
2214 		for(i = 0; i < len; i++)
2215 			tmp[i] = 'A' - 1 + len;
2216 		s = sha1(tmp, len, nil, nil);
2217 		s = sha1(key, nkey, nil, s);
2218 		s = sha1(seed0, nseed0, nil, s);
2219 		sha1(seed1, nseed1, sha1dig, s);
2220 		s = md5(key, nkey, nil, nil);
2221 		md5(sha1dig, SHA1dlen, md5dig, s);
2222 		n = MD5dlen;
2223 		if(n > nbuf)
2224 			n = nbuf;
2225 		memmove(buf, md5dig, n);
2226 		buf += n;
2227 		nbuf -= n;
2228 		len++;
2229 	}
2230 }
2231 
2232 static mpint*
bytestomp(Bytes * bytes)2233 bytestomp(Bytes* bytes)
2234 {
2235 	mpint* ans;
2236 
2237 	ans = betomp(bytes->data, bytes->len, nil);
2238 	return ans;
2239 }
2240 
2241 /*
2242  * Convert mpint* to Bytes, putting high order byte first.
2243  */
2244 static Bytes*
mptobytes(mpint * big)2245 mptobytes(mpint* big)
2246 {
2247 	int n, m;
2248 	uchar *a;
2249 	Bytes* ans;
2250 
2251 	a = nil;
2252 	n = (mpsignif(big)+7)/8;
2253 	m = mptobe(big, nil, n, &a);
2254 	ans = makebytes(a, m);
2255 	if(a != nil)
2256 		free(a);
2257 	return ans;
2258 }
2259 
2260 // Do RSA computation on block according to key, and pad
2261 // result on left with zeros to make it modlen long.
2262 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2263 rsacomp(Bytes* block, RSApub* key, int modlen)
2264 {
2265 	mpint *x, *y;
2266 	Bytes *a, *ybytes;
2267 	int ylen;
2268 
2269 	x = bytestomp(block);
2270 	y = rsaencrypt(key, x, nil);
2271 	mpfree(x);
2272 	ybytes = mptobytes(y);
2273 	ylen = ybytes->len;
2274 
2275 	if(ylen < modlen) {
2276 		a = newbytes(modlen);
2277 		memset(a->data, 0, modlen-ylen);
2278 		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2279 		freebytes(ybytes);
2280 		ybytes = a;
2281 	}
2282 	else if(ylen > modlen) {
2283 		// assume it has leading zeros (mod should make it so)
2284 		a = newbytes(modlen);
2285 		memmove(a->data, ybytes->data, modlen);
2286 		freebytes(ybytes);
2287 		ybytes = a;
2288 	}
2289 	mpfree(y);
2290 	return ybytes;
2291 }
2292 
2293 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2294 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2295 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2296 {
2297 	Bytes *pad, *eb, *ans;
2298 	int i, dlen, padlen, modlen;
2299 
2300 	modlen = (mpsignif(key->n)+7)/8;
2301 	dlen = data->len;
2302 	if(modlen < 12 || dlen > modlen - 11)
2303 		return nil;
2304 	padlen = modlen - 3 - dlen;
2305 	pad = newbytes(padlen);
2306 	genrandom(pad->data, padlen);
2307 	for(i = 0; i < padlen; i++) {
2308 		if(blocktype == 0)
2309 			pad->data[i] = 0;
2310 		else if(blocktype == 1)
2311 			pad->data[i] = 255;
2312 		else if(pad->data[i] == 0)
2313 			pad->data[i] = 1;
2314 	}
2315 	eb = newbytes(modlen);
2316 	eb->data[0] = 0;
2317 	eb->data[1] = blocktype;
2318 	memmove(eb->data+2, pad->data, padlen);
2319 	eb->data[padlen+2] = 0;
2320 	memmove(eb->data+padlen+3, data->data, dlen);
2321 	ans = rsacomp(eb, key, modlen);
2322 	freebytes(eb);
2323 	freebytes(pad);
2324 	return ans;
2325 }
2326 
2327 // decrypt data according to PKCS#1, with given key.
2328 // expect a block type of 2.
2329 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2330 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2331 {
2332 	Bytes *eb, *ans = nil;
2333 	int i, modlen;
2334 	mpint *x, *y;
2335 
2336 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2337 	if(nepm != modlen)
2338 		return nil;
2339 	x = betomp(epm, nepm, nil);
2340 	y = factotum_rsa_decrypt(sec->rpc, x);
2341 	if(y == nil)
2342 		return nil;
2343 	eb = mptobytes(y);
2344 	if(eb->len < modlen){ // pad on left with zeros
2345 		ans = newbytes(modlen);
2346 		memset(ans->data, 0, modlen-eb->len);
2347 		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2348 		freebytes(eb);
2349 		eb = ans;
2350 	}
2351 	if(eb->data[0] == 0 && eb->data[1] == 2) {
2352 		for(i = 2; i < modlen; i++)
2353 			if(eb->data[i] == 0)
2354 				break;
2355 		if(i < modlen - 1)
2356 			ans = makebytes(eb->data+i+1, modlen-(i+1));
2357 	}
2358 	if (eb != ans)			/* not freed above? */
2359 		freebytes(eb);
2360 	return ans;
2361 }
2362 
2363 
2364 //================= general utility functions ========================
2365 
2366 static void *
emalloc(int n)2367 emalloc(int n)
2368 {
2369 	void *p;
2370 	if(n==0)
2371 		n=1;
2372 	p = malloc(n);
2373 	if(p == nil){
2374 		exits("out of memory");
2375 	}
2376 	memset(p, 0, n);
2377 	return p;
2378 }
2379 
2380 static void *
erealloc(void * ReallocP,int ReallocN)2381 erealloc(void *ReallocP, int ReallocN)
2382 {
2383 	if(ReallocN == 0)
2384 		ReallocN = 1;
2385 	if(!ReallocP)
2386 		ReallocP = emalloc(ReallocN);
2387 	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2388 		exits("out of memory");
2389 	}
2390 	return(ReallocP);
2391 }
2392 
2393 static void
put32(uchar * p,u32int x)2394 put32(uchar *p, u32int x)
2395 {
2396 	p[0] = x>>24;
2397 	p[1] = x>>16;
2398 	p[2] = x>>8;
2399 	p[3] = x;
2400 }
2401 
2402 static void
put24(uchar * p,int x)2403 put24(uchar *p, int x)
2404 {
2405 	p[0] = x>>16;
2406 	p[1] = x>>8;
2407 	p[2] = x;
2408 }
2409 
2410 static void
put16(uchar * p,int x)2411 put16(uchar *p, int x)
2412 {
2413 	p[0] = x>>8;
2414 	p[1] = x;
2415 }
2416 
2417 static u32int
get32(uchar * p)2418 get32(uchar *p)
2419 {
2420 	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2421 }
2422 
2423 static int
get24(uchar * p)2424 get24(uchar *p)
2425 {
2426 	return (p[0]<<16)|(p[1]<<8)|p[2];
2427 }
2428 
2429 static int
get16(uchar * p)2430 get16(uchar *p)
2431 {
2432 	return (p[0]<<8)|p[1];
2433 }
2434 
2435 #define OFFSET(x, s) offsetof(s, x)
2436 
2437 /*
2438  * malloc and return a new Bytes structure capable of
2439  * holding len bytes. (len >= 0)
2440  * Used to use crypt_malloc, which aborts if malloc fails.
2441  */
2442 static Bytes*
newbytes(int len)2443 newbytes(int len)
2444 {
2445 	Bytes* ans;
2446 
2447 	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2448 	ans->len = len;
2449 	return ans;
2450 }
2451 
2452 /*
2453  * newbytes(len), with data initialized from buf
2454  */
2455 static Bytes*
makebytes(uchar * buf,int len)2456 makebytes(uchar* buf, int len)
2457 {
2458 	Bytes* ans;
2459 
2460 	ans = newbytes(len);
2461 	memmove(ans->data, buf, len);
2462 	return ans;
2463 }
2464 
2465 static void
freebytes(Bytes * b)2466 freebytes(Bytes* b)
2467 {
2468 	if(b != nil)
2469 		free(b);
2470 }
2471 
2472 /* len is number of ints */
2473 static Ints*
newints(int len)2474 newints(int len)
2475 {
2476 	Ints* ans;
2477 
2478 	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2479 	ans->len = len;
2480 	return ans;
2481 }
2482 
2483 static Ints*
makeints(int * buf,int len)2484 makeints(int* buf, int len)
2485 {
2486 	Ints* ans;
2487 
2488 	ans = newints(len);
2489 	if(len > 0)
2490 		memmove(ans->data, buf, len*sizeof(int));
2491 	return ans;
2492 }
2493 
2494 static void
freeints(Ints * b)2495 freeints(Ints* b)
2496 {
2497 	if(b != nil)
2498 		free(b);
2499 }
2500