xref: /plan9/sys/src/libsec/port/tlshand.c (revision 3468a4915d661daa200976acc4f80f51aae144b2)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7 
8 // The main groups of functions are:
9 //		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
11 //		cipher choices - catalog of digest and encrypt algorithms
12 //		security functions - PKCS#1, sslHMAC, session keygen
13 //		general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16 
// Sizes and limits used throughout the handshake, plus the act argument
// values accepted by msgSend.
enum {
	TLSFinishedLen = 12,	// verify data length of a TLS Finished message
	SSL3FinishedLen = MD5dlen+SHA1dlen,	// SSL3 Finished carries an MD5 and a SHA1 hash; also sizes Finished.verify
	MaxKeyData = 104,	// amount of secret we may need: the largest nsecret in cipherAlgs, 2*(4*8+SHA1dlen)
	MaxChunk = 1<<14,	// sizes the handshake input buffer TlsConnection.buf (plus 2048 bytes slack)
	RandomSize = 32,	// size of the client and server randoms
	SidSize = 32,	// size of a session id we accept from a server
	MasterSecretSize = 48,	// size of TlsSec.sec
	AQueue = 0,	// msgSend: append the message to the send buffer only
	AFlush = 1,	// msgSend: append, then write the whole buffer to the hand file
};
28 
typedef struct TlsSec TlsSec;	// defined below; TlsConnection holds a pointer to one

// Bytes is a length-counted byte string.  data is declared with one
// element but is used as len bytes; presumably newbytes/makebytes
// allocate the extra room (not shown here) — confirm there.
typedef struct Bytes{
	int len;
	uchar data[1];  // [len]
} Bytes;

// Ints is the int counterpart of Bytes, used e.g. for the cipher suite
// list of a ClientHello.
typedef struct Ints{
	int len;
	int data[1];  // [len]
} Ints;
40 
// Algs describes one cipher suite we implement (see cipherAlgs).
typedef struct Algs{
	char *enc;	// encryption algorithm name, e.g. "rc4_128"
	char *digest;	// MAC digest algorithm name, e.g. "sha1"
	int nsecret;	// bytes of key material the suite needs
	int tlsid;	// TLS cipher suite number (one of the TLS_* values)
	int ok;	// NOTE(review): presumably set by initCiphers when the suite is usable — confirm
} Algs;

// Finished holds the verify data of a Finished message: the first n bytes
// of verify are meaningful (at most SSL3FinishedLen).
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;
53 
// TlsConnection is the state of one handshake in progress, on either the
// client (tlsClient2) or the server (tlsServer2) side.
typedef struct TlsConnection{
	TlsSec *sec;	// security management goo
	int hand, ctl;	// record layer file descriptors
	int erred;		// set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging
	int version;	// protocol we are speaking
	int verset;		// version has been set
	int ver2hi;		// server got a version 2 hello
	int isClient;	// is this the client or server? (nonzero on the client)
	Bytes *sid;		// SessionID
	Bytes *cert;	// peer certificate: only the first one received (certs[0]) is kept - no chain

	Lock statelk;	// NOTE(review): presumably guards state; setstate is not in this view
	int state;		// must be set using setstate

	// input buffer for handshake messages
	uchar buf[MaxChunk+2048];
	uchar *rp, *ep;	// unread input is rp..ep within buf (managed by tlsReadN)

	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVersion;	// version in ClientHello
	char *digest;	// name of digest algorithm to use
	char *enc;		// name of encryption algorithm to use
	int nsecret;	// amount of secret data to init keys

	// for finished messages
	MD5state	hsmd5;	// handshake hash: running hash of every handshake message sent and received
	SHAstate	hssha1;	// handshake hash
	Finished	finished;	// verify data filled in by tlsSecFinished
} TlsConnection;
85 
// Msg is one decoded handshake message.  tag is the handshake type (H*)
// and selects which member of u is valid; callers reset a Msg with
// msgClear between messages.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;	// offered cipher suite numbers
			Bytes*	compressors;	// offered compression methods
		} clientHello;
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			int cipher;	// chosen cipher suite
			int compressor;	// chosen compression method
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// ncert certificates, own certificate first
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *key;	// encrypted pre-master secret
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;
118 
// TlsSec holds the secret state of a session: the master secret together
// with the randoms and versions used to derive keys and Finished hashes.
typedef struct TlsSec{
	char *server;	// name of remote; nil for server
	int ok;	// <0 killed; == 0 in progress; >0 reusable
	RSApub *rsapub;	// public key from the certificate
	AuthRpc *rpc;	// factotum for rsa private key
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVers;		// version in ClientHello
	int vers;			// final version
	// byte generation and handshake checksum
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;	// length of the Finished verify data for vers
} TlsSec;
134 
135 
// Protocol version numbers as they appear on the wire (major<<8 | minor).
enum {
	TLSVersion = 0x0301,	// TLS 1.0
	SSL3Version = 0x0300,	// SSL 3.0
	ProtocolVersion = 0x0301,	// maximum version we speak
	MinProtoVersion = 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};
143 
// handshake type: the first byte of each handshake message header and
// the value of Msg.tag (see msgSend and msgRecv)
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};
159 
// alerts: the err codes passed to tlsError
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256	// alert codes fit in one byte
};
188 
// cipher suites: TLS cipher suite numbers.  Only the suites listed in
// cipherAlgs below are implemented; the rest are named for reference
// (presumably so peers' offers can be recognised — confirm in okCipher).
enum {
	TLS_NULL_WITH_NULL_NULL	 		= 0x0000,
	TLS_RSA_WITH_NULL_MD5 			= 0x0001,
	TLS_RSA_WITH_NULL_SHA 			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA 		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0X0006,
	TLS_RSA_WITH_IDEA_CBC_SHA 		= 0X0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0X0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0X000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0X000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0X000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0X000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0X0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0X0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0X0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0X0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0X001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0X001B,

	TLS_RSA_WITH_AES_128_CBC_SHA		= 0X002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0X0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0X0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0X0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0X0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0X0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0X0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0X0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0X0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0X0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0X0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0X003A,
	CipherMax
};
234 
// compression methods: only null compression is supported; the client
// rejects any other choice from the server (see tlsClient2)
enum {
	CompressionNull = 0,
	CompressionMax
};
240 
// Cipher suites we implement: {enc, digest, nsecret, tlsid}; ok is left 0
// here.  nsecret is twice a per-direction amount of secret (cipher key
// material plus MAC key); the largest entry determines MaxKeyData.
static Algs cipherAlgs[] = {
	{"rc4_128", "md5",	2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1",	2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
};

// Compression methods offered in our ClientHello.
static uchar compressors[] = {
	CompressionNull,
};
250 
// handshake drivers for each side
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));

// handshake message I/O and connection helpers
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
// have the compiler check tlsError's format (argument 3) against its arguments
#pragma	varargck argpos	tlsError 3
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher suite and compression negotiation
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(void);

// TlsSec: session secrets, key derivation and Finished computation
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void	tlsSecOk(TlsSec *sec);
static void	tlsSecKill(TlsSec *sec);
static void	tlsSecClose(TlsSec *sec);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

// RSA private-key operations delegated to factotum
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc*rpc);

// general utilities: allocation, fixed-width integer (de)serialization,
// Bytes/Ints constructors and destructors
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static u32int get32(uchar *p);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static Ints* makeints(int* buf, int len);
static void freeints(Ints* b);
308 
309 //================= client/server ========================
310 
311 //	push TLS onto fd, returning new (application) file descriptor
312 //		or -1 if error.
313 int
314 tlsServer(int fd, TLSconn *conn)
315 {
316 	char buf[8];
317 	char dname[64];
318 	int n, data, ctl, hand;
319 	TlsConnection *tls;
320 
321 	if(conn == nil)
322 		return -1;
323 	ctl = open("#a/tls/clone", ORDWR);
324 	if(ctl < 0)
325 		return -1;
326 	n = read(ctl, buf, sizeof(buf)-1);
327 	if(n < 0){
328 		close(ctl);
329 		return -1;
330 	}
331 	buf[n] = 0;
332 	sprint(conn->dir, "#a/tls/%s", buf);
333 	sprint(dname, "#a/tls/%s/hand", buf);
334 	hand = open(dname, ORDWR);
335 	if(hand < 0){
336 		close(ctl);
337 		return -1;
338 	}
339 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
341 	sprint(dname, "#a/tls/%s/data", buf);
342 	data = open(dname, ORDWR);
343 	close(fd);
344 	close(hand);
345 	close(ctl);
346 	if(data < 0){
347 		return -1;
348 	}
349 	if(tls == nil){
350 		close(data);
351 		return -1;
352 	}
353 	if(conn->cert)
354 		free(conn->cert);
355 	conn->cert = 0;  // client certificates are not yet implemented
356 	conn->certlen = 0;
357 	conn->sessionIDlen = tls->sid->len;
358 	conn->sessionID = emalloc(conn->sessionIDlen);
359 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
361 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
362 	tlsConnectionFree(tls);
363 	return data;
364 }
365 
366 //	push TLS onto fd, returning new (application) file descriptor
367 //		or -1 if error.
368 int
369 tlsClient(int fd, TLSconn *conn)
370 {
371 	char buf[8];
372 	char dname[64];
373 	int n, data, ctl, hand;
374 	TlsConnection *tls;
375 
376 	if(!conn)
377 		return -1;
378 	ctl = open("#a/tls/clone", ORDWR);
379 	if(ctl < 0)
380 		return -1;
381 	n = read(ctl, buf, sizeof(buf)-1);
382 	if(n < 0){
383 		close(ctl);
384 		return -1;
385 	}
386 	buf[n] = 0;
387 	sprint(conn->dir, "#a/tls/%s", buf);
388 	sprint(dname, "#a/tls/%s/hand", buf);
389 	hand = open(dname, ORDWR);
390 	if(hand < 0){
391 		close(ctl);
392 		return -1;
393 	}
394 	sprint(dname, "#a/tls/%s/data", buf);
395 	data = open(dname, ORDWR);
396 	if(data < 0)
397 		return -1;
398 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
399 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
400 	close(fd);
401 	close(hand);
402 	close(ctl);
403 	if(tls == nil){
404 		close(data);
405 		return -1;
406 	}
407 	conn->certlen = tls->cert->len;
408 	conn->cert = emalloc(conn->certlen);
409 	memcpy(conn->cert, tls->cert->data, conn->certlen);
410 	conn->sessionIDlen = tls->sid->len;
411 	conn->sessionID = emalloc(conn->sessionIDlen);
412 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
413 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
414 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
415 	tlsConnectionFree(tls);
416 	return data;
417 }
418 
419 static int
420 countchain(PEMChain *p)
421 {
422 	int i = 0;
423 
424 	while (p) {
425 		i++;
426 		p = p->next;
427 	}
428 	return i;
429 }
430 
431 static TlsConnection *
432 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
433 {
434 	TlsConnection *c;
435 	Msg m;
436 	Bytes *csid;
437 	uchar sid[SidSize], kd[MaxKeyData];
438 	char *secrets;
439 	int cipher, compressor, nsid, rv, numcerts, i;
440 
441 	if(trace)
442 		trace("tlsServer2\n");
443 	if(!initCiphers())
444 		return nil;
445 	c = emalloc(sizeof(TlsConnection));
446 	c->ctl = ctl;
447 	c->hand = hand;
448 	c->trace = trace;
449 	c->version = ProtocolVersion;
450 
451 	memset(&m, 0, sizeof(m));
452 	if(!msgRecv(c, &m)){
453 		if(trace)
454 			trace("initial msgRecv failed\n");
455 		goto Err;
456 	}
457 	if(m.tag != HClientHello) {
458 		tlsError(c, EUnexpectedMessage, "expected a client hello");
459 		goto Err;
460 	}
461 	c->clientVersion = m.u.clientHello.version;
462 	if(trace)
463 		trace("ClientHello version %x\n", c->clientVersion);
464 	if(setVersion(c, m.u.clientHello.version) < 0) {
465 		tlsError(c, EIllegalParameter, "incompatible version");
466 		goto Err;
467 	}
468 
469 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
470 	cipher = okCipher(m.u.clientHello.ciphers);
471 	if(cipher < 0) {
472 		// reply with EInsufficientSecurity if we know that's the case
473 		if(cipher == -2)
474 			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
475 		else
476 			tlsError(c, EHandshakeFailure, "no matching cipher suite");
477 		goto Err;
478 	}
479 	if(!setAlgs(c, cipher)){
480 		tlsError(c, EHandshakeFailure, "no matching cipher suite");
481 		goto Err;
482 	}
483 	compressor = okCompression(m.u.clientHello.compressors);
484 	if(compressor < 0) {
485 		tlsError(c, EHandshakeFailure, "no matching compressor");
486 		goto Err;
487 	}
488 
489 	csid = m.u.clientHello.sid;
490 	if(trace)
491 		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
492 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
493 	if(c->sec == nil){
494 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
495 		goto Err;
496 	}
497 	c->sec->rpc = factotum_rsa_open(cert, ncert);
498 	if(c->sec->rpc == nil){
499 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
500 		goto Err;
501 	}
502 	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
503 	msgClear(&m);
504 
505 	m.tag = HServerHello;
506 	m.u.serverHello.version = c->version;
507 	memmove(m.u.serverHello.random, c->srandom, RandomSize);
508 	m.u.serverHello.cipher = cipher;
509 	m.u.serverHello.compressor = compressor;
510 	c->sid = makebytes(sid, nsid);
511 	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
512 	if(!msgSend(c, &m, AQueue))
513 		goto Err;
514 	msgClear(&m);
515 
516 	m.tag = HCertificate;
517 	numcerts = countchain(chp);
518 	m.u.certificate.ncert = 1 + numcerts;
519 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
520 	m.u.certificate.certs[0] = makebytes(cert, ncert);
521 	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
522 		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
523 	if(!msgSend(c, &m, AQueue))
524 		goto Err;
525 	msgClear(&m);
526 
527 	m.tag = HServerHelloDone;
528 	if(!msgSend(c, &m, AFlush))
529 		goto Err;
530 	msgClear(&m);
531 
532 	if(!msgRecv(c, &m))
533 		goto Err;
534 	if(m.tag != HClientKeyExchange) {
535 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
536 		goto Err;
537 	}
538 	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
539 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
540 		goto Err;
541 	}
542 	if(trace)
543 		trace("tls secrets\n");
544 	secrets = (char*)emalloc(2*c->nsecret);
545 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
546 	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
547 	memset(secrets, 0, 2*c->nsecret);
548 	free(secrets);
549 	memset(kd, 0, c->nsecret);
550 	if(rv < 0){
551 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
552 		goto Err;
553 	}
554 	msgClear(&m);
555 
556 	/* no CertificateVerify; skip to Finished */
557 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
558 		tlsError(c, EInternalError, "can't set finished: %r");
559 		goto Err;
560 	}
561 	if(!msgRecv(c, &m))
562 		goto Err;
563 	if(m.tag != HFinished) {
564 		tlsError(c, EUnexpectedMessage, "expected a finished");
565 		goto Err;
566 	}
567 	if(!finishedMatch(c, &m.u.finished)) {
568 		tlsError(c, EHandshakeFailure, "finished verification failed");
569 		goto Err;
570 	}
571 	msgClear(&m);
572 
573 	/* change cipher spec */
574 	if(fprint(c->ctl, "changecipher") < 0){
575 		tlsError(c, EInternalError, "can't enable cipher: %r");
576 		goto Err;
577 	}
578 
579 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
580 		tlsError(c, EInternalError, "can't set finished: %r");
581 		goto Err;
582 	}
583 	m.tag = HFinished;
584 	m.u.finished = c->finished;
585 	if(!msgSend(c, &m, AFlush))
586 		goto Err;
587 	msgClear(&m);
588 	if(trace)
589 		trace("tls finished\n");
590 
591 	if(fprint(c->ctl, "opened") < 0)
592 		goto Err;
593 	tlsSecOk(c->sec);
594 	return c;
595 
596 Err:
597 	msgClear(&m);
598 	tlsConnectionFree(c);
599 	return 0;
600 }
601 
602 static TlsConnection *
603 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
604 {
605 	TlsConnection *c;
606 	Msg m;
607 	uchar kd[MaxKeyData], *epm;
608 	char *secrets;
609 	int creq, nepm, rv;
610 
611 	if(!initCiphers())
612 		return nil;
613 	epm = nil;
614 	c = emalloc(sizeof(TlsConnection));
615 	c->version = ProtocolVersion;
616 	c->ctl = ctl;
617 	c->hand = hand;
618 	c->trace = trace;
619 	c->isClient = 1;
620 	c->clientVersion = c->version;
621 
622 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
623 	if(c->sec == nil)
624 		goto Err;
625 
626 	/* client hello */
627 	memset(&m, 0, sizeof(m));
628 	m.tag = HClientHello;
629 	m.u.clientHello.version = c->clientVersion;
630 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
631 	m.u.clientHello.sid = makebytes(csid, ncsid);
632 	m.u.clientHello.ciphers = makeciphers();
633 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
634 	if(!msgSend(c, &m, AFlush))
635 		goto Err;
636 	msgClear(&m);
637 
638 	/* server hello */
639 	if(!msgRecv(c, &m))
640 		goto Err;
641 	if(m.tag != HServerHello) {
642 		tlsError(c, EUnexpectedMessage, "expected a server hello");
643 		goto Err;
644 	}
645 	if(setVersion(c, m.u.serverHello.version) < 0) {
646 		tlsError(c, EIllegalParameter, "incompatible version %r");
647 		goto Err;
648 	}
649 	memmove(c->srandom, m.u.serverHello.random, RandomSize);
650 	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
651 	if(c->sid->len != 0 && c->sid->len != SidSize) {
652 		tlsError(c, EIllegalParameter, "invalid server session identifier");
653 		goto Err;
654 	}
655 	if(!setAlgs(c, m.u.serverHello.cipher)) {
656 		tlsError(c, EIllegalParameter, "invalid cipher suite");
657 		goto Err;
658 	}
659 	if(m.u.serverHello.compressor != CompressionNull) {
660 		tlsError(c, EIllegalParameter, "invalid compression");
661 		goto Err;
662 	}
663 	msgClear(&m);
664 
665 	/* certificate */
666 	if(!msgRecv(c, &m) || m.tag != HCertificate) {
667 		tlsError(c, EUnexpectedMessage, "expected a certificate");
668 		goto Err;
669 	}
670 	if(m.u.certificate.ncert < 1) {
671 		tlsError(c, EIllegalParameter, "runt certificate");
672 		goto Err;
673 	}
674 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
675 	msgClear(&m);
676 
677 	/* server key exchange (optional) */
678 	if(!msgRecv(c, &m))
679 		goto Err;
680 	if(m.tag == HServerKeyExchange) {
681 		tlsError(c, EUnexpectedMessage, "got an server key exchange");
682 		goto Err;
683 		// If implementing this later, watch out for rollback attack
684 		// described in Wagner Schneier 1996, section 4.4.
685 	}
686 
687 	/* certificate request (optional) */
688 	creq = 0;
689 	if(m.tag == HCertificateRequest) {
690 		creq = 1;
691 		msgClear(&m);
692 		if(!msgRecv(c, &m))
693 			goto Err;
694 	}
695 
696 	if(m.tag != HServerHelloDone) {
697 		tlsError(c, EUnexpectedMessage, "expected a server hello done");
698 		goto Err;
699 	}
700 	msgClear(&m);
701 
702 	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
703 			c->cert->data, c->cert->len, c->version, &epm, &nepm,
704 			kd, c->nsecret) < 0){
705 		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
706 		goto Err;
707 	}
708 	secrets = (char*)emalloc(2*c->nsecret);
709 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
710 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
711 	memset(secrets, 0, 2*c->nsecret);
712 	free(secrets);
713 	memset(kd, 0, c->nsecret);
714 	if(rv < 0){
715 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
716 		goto Err;
717 	}
718 
719 	if(creq) {
720 		/* send a zero length certificate */
721 		m.tag = HCertificate;
722 		if(!msgSend(c, &m, AFlush))
723 			goto Err;
724 		msgClear(&m);
725 	}
726 
727 	/* client key exchange */
728 	m.tag = HClientKeyExchange;
729 	m.u.clientKeyExchange.key = makebytes(epm, nepm);
730 	free(epm);
731 	epm = nil;
732 	if(m.u.clientKeyExchange.key == nil) {
733 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
734 		goto Err;
735 	}
736 	if(!msgSend(c, &m, AFlush))
737 		goto Err;
738 	msgClear(&m);
739 
740 	/* change cipher spec */
741 	if(fprint(c->ctl, "changecipher") < 0){
742 		tlsError(c, EInternalError, "can't enable cipher: %r");
743 		goto Err;
744 	}
745 
746 	// Cipherchange must occur immediately before Finished to avoid
747 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
748 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
749 		tlsError(c, EInternalError, "can't set finished 1: %r");
750 		goto Err;
751 	}
752 	m.tag = HFinished;
753 	m.u.finished = c->finished;
754 
755 	if(!msgSend(c, &m, AFlush)) {
756 		fprint(2, "tlsClient nepm=%d\n", nepm);
757 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
758 		goto Err;
759 	}
760 	msgClear(&m);
761 
762 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
763 		fprint(2, "tlsClient nepm=%d\n", nepm);
764 		tlsError(c, EInternalError, "can't set finished 0: %r");
765 		goto Err;
766 	}
767 	if(!msgRecv(c, &m)) {
768 		fprint(2, "tlsClient nepm=%d\n", nepm);
769 		tlsError(c, EInternalError, "can't read server Finished: %r");
770 		goto Err;
771 	}
772 	if(m.tag != HFinished) {
773 		fprint(2, "tlsClient nepm=%d\n", nepm);
774 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
775 		goto Err;
776 	}
777 
778 	if(!finishedMatch(c, &m.u.finished)) {
779 		tlsError(c, EHandshakeFailure, "finished verification failed");
780 		goto Err;
781 	}
782 	msgClear(&m);
783 
784 	if(fprint(c->ctl, "opened") < 0){
785 		if(trace)
786 			trace("unable to do final open: %r\n");
787 		goto Err;
788 	}
789 	tlsSecOk(c->sec);
790 	return c;
791 
792 Err:
793 	free(epm);
794 	msgClear(&m);
795 	tlsConnectionFree(c);
796 	return 0;
797 }
798 
799 
800 //================= message functions ========================
801 
// Output buffer for msgSend.  Each message is encoded at sendp, the end of
// what is already queued; a send with AFlush writes sendbuf up to the end
// of the new message and resets sendp to sendbuf.  These are file-scope
// statics, so one buffer is shared by every connection in the process.
// NOTE(review): concurrent handshakes in one address space would interfere.
static uchar sendbuf[9000], *sendp;
803 
804 static int
805 msgSend(TlsConnection *c, Msg *m, int act)
806 {
807 	uchar *p; // sendp = start of new message;  p = write pointer
808 	int nn, n, i;
809 
810 	if(sendp == nil)
811 		sendp = sendbuf;
812 	p = sendp;
813 	if(c->trace)
814 		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
815 
816 	p[0] = m->tag;	// header - fill in size later
817 	p += 4;
818 
819 	switch(m->tag) {
820 	default:
821 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
822 		goto Err;
823 	case HClientHello:
824 		// version
825 		put16(p, m->u.clientHello.version);
826 		p += 2;
827 
828 		// random
829 		memmove(p, m->u.clientHello.random, RandomSize);
830 		p += RandomSize;
831 
832 		// sid
833 		n = m->u.clientHello.sid->len;
834 		assert(n < 256);
835 		p[0] = n;
836 		memmove(p+1, m->u.clientHello.sid->data, n);
837 		p += n+1;
838 
839 		n = m->u.clientHello.ciphers->len;
840 		assert(n > 0 && n < 200);
841 		put16(p, n*2);
842 		p += 2;
843 		for(i=0; i<n; i++) {
844 			put16(p, m->u.clientHello.ciphers->data[i]);
845 			p += 2;
846 		}
847 
848 		n = m->u.clientHello.compressors->len;
849 		assert(n > 0);
850 		p[0] = n;
851 		memmove(p+1, m->u.clientHello.compressors->data, n);
852 		p += n+1;
853 		break;
854 	case HServerHello:
855 		put16(p, m->u.serverHello.version);
856 		p += 2;
857 
858 		// random
859 		memmove(p, m->u.serverHello.random, RandomSize);
860 		p += RandomSize;
861 
862 		// sid
863 		n = m->u.serverHello.sid->len;
864 		assert(n < 256);
865 		p[0] = n;
866 		memmove(p+1, m->u.serverHello.sid->data, n);
867 		p += n+1;
868 
869 		put16(p, m->u.serverHello.cipher);
870 		p += 2;
871 		p[0] = m->u.serverHello.compressor;
872 		p += 1;
873 		break;
874 	case HServerHelloDone:
875 		break;
876 	case HCertificate:
877 		nn = 0;
878 		for(i = 0; i < m->u.certificate.ncert; i++)
879 			nn += 3 + m->u.certificate.certs[i]->len;
880 		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
881 			tlsError(c, EInternalError, "output buffer too small for certificate");
882 			goto Err;
883 		}
884 		put24(p, nn);
885 		p += 3;
886 		for(i = 0; i < m->u.certificate.ncert; i++){
887 			put24(p, m->u.certificate.certs[i]->len);
888 			p += 3;
889 			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
890 			p += m->u.certificate.certs[i]->len;
891 		}
892 		break;
893 	case HClientKeyExchange:
894 		n = m->u.clientKeyExchange.key->len;
895 		if(c->version != SSL3Version){
896 			put16(p, n);
897 			p += 2;
898 		}
899 		memmove(p, m->u.clientKeyExchange.key->data, n);
900 		p += n;
901 		break;
902 	case HFinished:
903 		memmove(p, m->u.finished.verify, m->u.finished.n);
904 		p += m->u.finished.n;
905 		break;
906 	}
907 
908 	// go back and fill in size
909 	n = p-sendp;
910 	assert(p <= sendbuf+sizeof(sendbuf));
911 	put24(sendp+1, n-4);
912 
913 	// remember hash of Handshake messages
914 	if(m->tag != HHelloRequest) {
915 		md5(sendp, n, 0, &c->hsmd5);
916 		sha1(sendp, n, 0, &c->hssha1);
917 	}
918 
919 	sendp = p;
920 	if(act == AFlush){
921 		sendp = sendbuf;
922 		if(write(c->hand, sendbuf, p-sendbuf) < 0){
923 			fprint(2, "write error: %r\n");
924 			goto Err;
925 		}
926 	}
927 	msgClear(m);
928 	return 1;
929 Err:
930 	msgClear(m);
931 	return 0;
932 }
933 
934 static uchar*
935 tlsReadN(TlsConnection *c, int n)
936 {
937 	uchar *p;
938 	int nn, nr;
939 
940 	nn = c->ep - c->rp;
941 	if(nn < n){
942 		if(c->rp != c->buf){
943 			memmove(c->buf, c->rp, nn);
944 			c->rp = c->buf;
945 			c->ep = &c->buf[nn];
946 		}
947 		for(; nn < n; nn += nr) {
948 			nr = read(c->hand, &c->rp[nn], n - nn);
949 			if(nr <= 0)
950 				return nil;
951 			c->ep += nr;
952 		}
953 	}
954 	p = c->rp;
955 	c->rp += n;
956 	return p;
957 }
958 
959 static int
960 msgRecv(TlsConnection *c, Msg *m)
961 {
962 	uchar *p;
963 	int type, n, nn, i, nsid, nrandom, nciph;
964 
965 	for(;;) {
966 		p = tlsReadN(c, 4);
967 		if(p == nil)
968 			return 0;
969 		type = p[0];
970 		n = get24(p+1);
971 
972 		if(type != HHelloRequest)
973 			break;
974 		if(n != 0) {
975 			tlsError(c, EDecodeError, "invalid hello request during handshake");
976 			return 0;
977 		}
978 	}
979 
980 	if(n > sizeof(c->buf)) {
981 		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
982 		return 0;
983 	}
984 
985 	if(type == HSSL2ClientHello){
986 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
987 			This is sent by some clients that we must interoperate
988 			with, such as Java's JSSE and Microsoft's Internet Explorer. */
989 		p = tlsReadN(c, n);
990 		if(p == nil)
991 			return 0;
992 		md5(p, n, 0, &c->hsmd5);
993 		sha1(p, n, 0, &c->hssha1);
994 		m->tag = HClientHello;
995 		if(n < 22)
996 			goto Short;
997 		m->u.clientHello.version = get16(p+1);
998 		p += 3;
999 		n -= 3;
1000 		nn = get16(p); /* cipher_spec_len */
1001 		nsid = get16(p + 2);
1002 		nrandom = get16(p + 4);
1003 		p += 6;
1004 		n -= 6;
1005 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
1006 				|| nrandom < 16 || nn % 3)
1007 			goto Err;
1008 		if(c->trace && (n - nrandom != nn))
1009 			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1010 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1011 		nciph = 0;
1012 		for(i = 0; i < nn; i += 3)
1013 			if(p[i] == 0)
1014 				nciph++;
1015 		m->u.clientHello.ciphers = newints(nciph);
1016 		nciph = 0;
1017 		for(i = 0; i < nn; i += 3)
1018 			if(p[i] == 0)
1019 				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1020 		p += nn;
1021 		m->u.clientHello.sid = makebytes(nil, 0);
1022 		if(nrandom > RandomSize)
1023 			nrandom = RandomSize;
1024 		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1025 		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1026 		m->u.clientHello.compressors = newbytes(1);
1027 		m->u.clientHello.compressors->data[0] = CompressionNull;
1028 		goto Ok;
1029 	}
1030 
1031 	md5(p, 4, 0, &c->hsmd5);
1032 	sha1(p, 4, 0, &c->hssha1);
1033 
1034 	p = tlsReadN(c, n);
1035 	if(p == nil)
1036 		return 0;
1037 
1038 	md5(p, n, 0, &c->hsmd5);
1039 	sha1(p, n, 0, &c->hssha1);
1040 
1041 	m->tag = type;
1042 
1043 	switch(type) {
1044 	default:
1045 		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1046 		goto Err;
1047 	case HClientHello:
1048 		if(n < 2)
1049 			goto Short;
1050 		m->u.clientHello.version = get16(p);
1051 		p += 2;
1052 		n -= 2;
1053 
1054 		if(n < RandomSize)
1055 			goto Short;
1056 		memmove(m->u.clientHello.random, p, RandomSize);
1057 		p += RandomSize;
1058 		n -= RandomSize;
1059 		if(n < 1 || n < p[0]+1)
1060 			goto Short;
1061 		m->u.clientHello.sid = makebytes(p+1, p[0]);
1062 		p += m->u.clientHello.sid->len+1;
1063 		n -= m->u.clientHello.sid->len+1;
1064 
1065 		if(n < 2)
1066 			goto Short;
1067 		nn = get16(p);
1068 		p += 2;
1069 		n -= 2;
1070 
1071 		if((nn & 1) || n < nn || nn < 2)
1072 			goto Short;
1073 		m->u.clientHello.ciphers = newints(nn >> 1);
1074 		for(i = 0; i < nn; i += 2)
1075 			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1076 		p += nn;
1077 		n -= nn;
1078 
1079 		if(n < 1 || n < p[0]+1 || p[0] == 0)
1080 			goto Short;
1081 		nn = p[0];
1082 		m->u.clientHello.compressors = newbytes(nn);
1083 		memmove(m->u.clientHello.compressors->data, p+1, nn);
1084 		n -= nn + 1;
1085 		break;
1086 	case HServerHello:
1087 		if(n < 2)
1088 			goto Short;
1089 		m->u.serverHello.version = get16(p);
1090 		p += 2;
1091 		n -= 2;
1092 
1093 		if(n < RandomSize)
1094 			goto Short;
1095 		memmove(m->u.serverHello.random, p, RandomSize);
1096 		p += RandomSize;
1097 		n -= RandomSize;
1098 
1099 		if(n < 1 || n < p[0]+1)
1100 			goto Short;
1101 		m->u.serverHello.sid = makebytes(p+1, p[0]);
1102 		p += m->u.serverHello.sid->len+1;
1103 		n -= m->u.serverHello.sid->len+1;
1104 
1105 		if(n < 3)
1106 			goto Short;
1107 		m->u.serverHello.cipher = get16(p);
1108 		m->u.serverHello.compressor = p[2];
1109 		n -= 3;
1110 		break;
1111 	case HCertificate:
1112 		if(n < 3)
1113 			goto Short;
1114 		nn = get24(p);
1115 		p += 3;
1116 		n -= 3;
1117 		if(n != nn)
1118 			goto Short;
1119 		/* certs */
1120 		i = 0;
1121 		while(n > 0) {
1122 			if(n < 3)
1123 				goto Short;
1124 			nn = get24(p);
1125 			p += 3;
1126 			n -= 3;
1127 			if(nn > n)
1128 				goto Short;
1129 			m->u.certificate.ncert = i+1;
1130 			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1131 			m->u.certificate.certs[i] = makebytes(p, nn);
1132 			p += nn;
1133 			n -= nn;
1134 			i++;
1135 		}
1136 		break;
1137 	case HCertificateRequest:
1138 		if(n < 1)
1139 			goto Short;
1140 		nn = p[0];
1141 		p += 1;
1142 		n -= 1;
1143 		if(nn < 1 || nn > n)
1144 			goto Short;
1145 		m->u.certificateRequest.types = makebytes(p, nn);
1146 		p += nn;
1147 		n -= nn;
1148 		if(n < 2)
1149 			goto Short;
1150 		nn = get16(p);
1151 		p += 2;
1152 		n -= 2;
1153 		/* nn == 0 can happen; yahoo's servers do it */
1154 		if(nn != n)
1155 			goto Short;
1156 		/* cas */
1157 		i = 0;
1158 		while(n > 0) {
1159 			if(n < 2)
1160 				goto Short;
1161 			nn = get16(p);
1162 			p += 2;
1163 			n -= 2;
1164 			if(nn < 1 || nn > n)
1165 				goto Short;
1166 			m->u.certificateRequest.nca = i+1;
1167 			m->u.certificateRequest.cas = erealloc(
1168 				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1169 			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1170 			p += nn;
1171 			n -= nn;
1172 			i++;
1173 		}
1174 		break;
1175 	case HServerHelloDone:
1176 		break;
1177 	case HClientKeyExchange:
1178 		/*
1179 		 * this message depends upon the encryption selected
1180 		 * assume rsa.
1181 		 */
1182 		if(c->version == SSL3Version)
1183 			nn = n;
1184 		else{
1185 			if(n < 2)
1186 				goto Short;
1187 			nn = get16(p);
1188 			p += 2;
1189 			n -= 2;
1190 		}
1191 		if(n < nn)
1192 			goto Short;
1193 		m->u.clientKeyExchange.key = makebytes(p, nn);
1194 		n -= nn;
1195 		break;
1196 	case HFinished:
1197 		m->u.finished.n = c->finished.n;
1198 		if(n < m->u.finished.n)
1199 			goto Short;
1200 		memmove(m->u.finished.verify, p, m->u.finished.n);
1201 		n -= m->u.finished.n;
1202 		break;
1203 	}
1204 
1205 	if(type != HClientHello && n != 0)
1206 		goto Short;
1207 Ok:
1208 	if(c->trace){
1209 		char *buf;
1210 		buf = emalloc(8000);
1211 		c->trace("recv %s", msgPrint(buf, 8000, m));
1212 		free(buf);
1213 	}
1214 	return 1;
1215 Short:
1216 	tlsError(c, EDecodeError, "handshake message has invalid length");
1217 Err:
1218 	msgClear(m);
1219 	return 0;
1220 }
1221 
1222 static void
1223 msgClear(Msg *m)
1224 {
1225 	int i;
1226 
1227 	switch(m->tag) {
1228 	default:
1229 		sysfatal("msgClear: unknown message type: %d", m->tag);
1230 	case HHelloRequest:
1231 		break;
1232 	case HClientHello:
1233 		freebytes(m->u.clientHello.sid);
1234 		freeints(m->u.clientHello.ciphers);
1235 		freebytes(m->u.clientHello.compressors);
1236 		break;
1237 	case HServerHello:
1238 		freebytes(m->u.clientHello.sid);
1239 		break;
1240 	case HCertificate:
1241 		for(i=0; i<m->u.certificate.ncert; i++)
1242 			freebytes(m->u.certificate.certs[i]);
1243 		free(m->u.certificate.certs);
1244 		break;
1245 	case HCertificateRequest:
1246 		freebytes(m->u.certificateRequest.types);
1247 		for(i=0; i<m->u.certificateRequest.nca; i++)
1248 			freebytes(m->u.certificateRequest.cas[i]);
1249 		free(m->u.certificateRequest.cas);
1250 		break;
1251 	case HServerHelloDone:
1252 		break;
1253 	case HClientKeyExchange:
1254 		freebytes(m->u.clientKeyExchange.key);
1255 		break;
1256 	case HFinished:
1257 		break;
1258 	}
1259 	memset(m, 0, sizeof(Msg));
1260 }
1261 
1262 static char *
1263 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1264 {
1265 	int i;
1266 
1267 	if(s0)
1268 		bs = seprint(bs, be, "%s", s0);
1269 	bs = seprint(bs, be, "[");
1270 	if(b == nil)
1271 		bs = seprint(bs, be, "nil");
1272 	else
1273 		for(i=0; i<b->len; i++)
1274 			bs = seprint(bs, be, "%.2x ", b->data[i]);
1275 	bs = seprint(bs, be, "]");
1276 	if(s1)
1277 		bs = seprint(bs, be, "%s", s1);
1278 	return bs;
1279 }
1280 
1281 static char *
1282 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1283 {
1284 	int i;
1285 
1286 	if(s0)
1287 		bs = seprint(bs, be, "%s", s0);
1288 	bs = seprint(bs, be, "[");
1289 	if(b == nil)
1290 		bs = seprint(bs, be, "nil");
1291 	else
1292 		for(i=0; i<b->len; i++)
1293 			bs = seprint(bs, be, "%x ", b->data[i]);
1294 	bs = seprint(bs, be, "]");
1295 	if(s1)
1296 		bs = seprint(bs, be, "%s", s1);
1297 	return bs;
1298 }
1299 
1300 static char*
1301 msgPrint(char *buf, int n, Msg *m)
1302 {
1303 	int i;
1304 	char *bs = buf, *be = buf+n;
1305 
1306 	switch(m->tag) {
1307 	default:
1308 		bs = seprint(bs, be, "unknown %d\n", m->tag);
1309 		break;
1310 	case HClientHello:
1311 		bs = seprint(bs, be, "ClientHello\n");
1312 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1313 		bs = seprint(bs, be, "\trandom: ");
1314 		for(i=0; i<RandomSize; i++)
1315 			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1316 		bs = seprint(bs, be, "\n");
1317 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1318 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1319 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1320 		break;
1321 	case HServerHello:
1322 		bs = seprint(bs, be, "ServerHello\n");
1323 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1324 		bs = seprint(bs, be, "\trandom: ");
1325 		for(i=0; i<RandomSize; i++)
1326 			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1327 		bs = seprint(bs, be, "\n");
1328 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1329 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1330 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1331 		break;
1332 	case HCertificate:
1333 		bs = seprint(bs, be, "Certificate\n");
1334 		for(i=0; i<m->u.certificate.ncert; i++)
1335 			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1336 		break;
1337 	case HCertificateRequest:
1338 		bs = seprint(bs, be, "CertificateRequest\n");
1339 		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1340 		bs = seprint(bs, be, "\tcertificateauthorities\n");
1341 		for(i=0; i<m->u.certificateRequest.nca; i++)
1342 			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1343 		break;
1344 	case HServerHelloDone:
1345 		bs = seprint(bs, be, "ServerHelloDone\n");
1346 		break;
1347 	case HClientKeyExchange:
1348 		bs = seprint(bs, be, "HClientKeyExchange\n");
1349 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1350 		break;
1351 	case HFinished:
1352 		bs = seprint(bs, be, "HFinished\n");
1353 		for(i=0; i<m->u.finished.n; i++)
1354 			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1355 		bs = seprint(bs, be, "\n");
1356 		break;
1357 	}
1358 	USED(bs);
1359 	return buf;
1360 }
1361 
1362 static void
1363 tlsError(TlsConnection *c, int err, char *fmt, ...)
1364 {
1365 	char msg[512];
1366 	va_list arg;
1367 
1368 	va_start(arg, fmt);
1369 	vseprint(msg, msg+sizeof(msg), fmt, arg);
1370 	va_end(arg);
1371 	if(c->trace)
1372 		c->trace("tlsError: %s\n", msg);
1373 	else if(c->erred)
1374 		fprint(2, "double error: %r, %s", msg);
1375 	else
1376 		werrstr("tls: local %s", msg);
1377 	c->erred = 1;
1378 	fprint(c->ctl, "alert %d", err);
1379 }
1380 
1381 // commit to specific version number
1382 static int
1383 setVersion(TlsConnection *c, int version)
1384 {
1385 	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1386 		return -1;
1387 	if(version > c->version)
1388 		version = c->version;
1389 	if(version == SSL3Version) {
1390 		c->version = version;
1391 		c->finished.n = SSL3FinishedLen;
1392 	}else if(version == TLSVersion){
1393 		c->version = version;
1394 		c->finished.n = TLSFinishedLen;
1395 	}else
1396 		return -1;
1397 	c->verset = 1;
1398 	return fprint(c->ctl, "version 0x%x", version);
1399 }
1400 
1401 // confirm that received Finished message matches the expected value
1402 static int
1403 finishedMatch(TlsConnection *c, Finished *f)
1404 {
1405 	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1406 }
1407 
1408 // free memory associated with TlsConnection struct
1409 //		(but don't close the TLS channel itself)
1410 static void
1411 tlsConnectionFree(TlsConnection *c)
1412 {
1413 	tlsSecClose(c->sec);
1414 	freebytes(c->sid);
1415 	freebytes(c->cert);
1416 	memset(c, 0, sizeof(c));
1417 	free(c);
1418 }
1419 
1420 
1421 //================= cipher choices ========================
1422 
/*
 * weakCipher[id] is 1 when TLS cipher suite id is considered weak.
 * Per the names below, the weak entries are the null-cipher suites, the
 * export (40-bit) suites, and the anonymous Diffie-Hellman suites.
 * Entries are indexed by suite id, starting at 0, in the order listed.
 * okCipher uses the table to return -2 ("only weak suites offered")
 * instead of -1, and treats any id >= CipherMax as not weak.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1454 
1455 static int
1456 setAlgs(TlsConnection *c, int a)
1457 {
1458 	int i;
1459 
1460 	for(i = 0; i < nelem(cipherAlgs); i++){
1461 		if(cipherAlgs[i].tlsid == a){
1462 			c->enc = cipherAlgs[i].enc;
1463 			c->digest = cipherAlgs[i].digest;
1464 			c->nsecret = cipherAlgs[i].nsecret;
1465 			if(c->nsecret > MaxKeyData)
1466 				return 0;
1467 			return 1;
1468 		}
1469 	}
1470 	return 0;
1471 }
1472 
1473 static int
1474 okCipher(Ints *cv)
1475 {
1476 	int weak, i, j, c;
1477 
1478 	weak = 1;
1479 	for(i = 0; i < cv->len; i++) {
1480 		c = cv->data[i];
1481 		if(c >= CipherMax)
1482 			weak = 0;
1483 		else
1484 			weak &= weakCipher[c];
1485 		for(j = 0; j < nelem(cipherAlgs); j++)
1486 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1487 				return c;
1488 	}
1489 	if(weak)
1490 		return -2;
1491 	return -1;
1492 }
1493 
1494 static int
1495 okCompression(Bytes *cv)
1496 {
1497 	int i, j, c;
1498 
1499 	for(i = 0; i < cv->len; i++) {
1500 		c = cv->data[i];
1501 		for(j = 0; j < nelem(compressors); j++) {
1502 			if(compressors[j] == c)
1503 				return c;
1504 		}
1505 	}
1506 	return -1;
1507 }
1508 
// serializes initCiphers' one-time setup of cipherAlgs[].ok and nciphers
static Lock	ciphLock;
// count of cipherAlgs entries marked ok; 0 until initCiphers succeeds
static int	nciphers;
1511 
1512 static int
1513 initCiphers(void)
1514 {
1515 	enum {MaxAlgF = 1024, MaxAlgs = 10};
1516 	char s[MaxAlgF], *flds[MaxAlgs];
1517 	int i, j, n, ok;
1518 
1519 	lock(&ciphLock);
1520 	if(nciphers){
1521 		unlock(&ciphLock);
1522 		return nciphers;
1523 	}
1524 	j = open("#a/tls/encalgs", OREAD);
1525 	if(j < 0){
1526 		werrstr("can't open #a/tls/encalgs: %r");
1527 		return 0;
1528 	}
1529 	n = read(j, s, MaxAlgF-1);
1530 	close(j);
1531 	if(n <= 0){
1532 		werrstr("nothing in #a/tls/encalgs: %r");
1533 		return 0;
1534 	}
1535 	s[n] = 0;
1536 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1537 	for(i = 0; i < nelem(cipherAlgs); i++){
1538 		ok = 0;
1539 		for(j = 0; j < n; j++){
1540 			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1541 				ok = 1;
1542 				break;
1543 			}
1544 		}
1545 		cipherAlgs[i].ok = ok;
1546 	}
1547 
1548 	j = open("#a/tls/hashalgs", OREAD);
1549 	if(j < 0){
1550 		werrstr("can't open #a/tls/hashalgs: %r");
1551 		return 0;
1552 	}
1553 	n = read(j, s, MaxAlgF-1);
1554 	close(j);
1555 	if(n <= 0){
1556 		werrstr("nothing in #a/tls/hashalgs: %r");
1557 		return 0;
1558 	}
1559 	s[n] = 0;
1560 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1561 	for(i = 0; i < nelem(cipherAlgs); i++){
1562 		ok = 0;
1563 		for(j = 0; j < n; j++){
1564 			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1565 				ok = 1;
1566 				break;
1567 			}
1568 		}
1569 		cipherAlgs[i].ok &= ok;
1570 		if(cipherAlgs[i].ok)
1571 			nciphers++;
1572 	}
1573 	unlock(&ciphLock);
1574 	return nciphers;
1575 }
1576 
1577 static Ints*
1578 makeciphers(void)
1579 {
1580 	Ints *is;
1581 	int i, j;
1582 
1583 	is = newints(nciphers);
1584 	j = 0;
1585 	for(i = 0; i < nelem(cipherAlgs); i++){
1586 		if(cipherAlgs[i].ok)
1587 			is->data[j++] = cipherAlgs[i].tlsid;
1588 	}
1589 	return is;
1590 }
1591 
1592 
1593 
1594 //================= security functions ========================
1595 
1596 // given X.509 certificate, set up connection to factotum
1597 //	for using corresponding private key
1598 static AuthRpc*
1599 factotum_rsa_open(uchar *cert, int certlen)
1600 {
1601 	int afd;
1602 	char *s;
1603 	mpint *pub = nil;
1604 	RSApub *rsapub;
1605 	AuthRpc *rpc;
1606 
1607 	// start talking to factotum
1608 	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1609 		return nil;
1610 	if((rpc = auth_allocrpc(afd)) == nil){
1611 		close(afd);
1612 		return nil;
1613 	}
1614 	s = "proto=rsa service=tls role=client";
1615 	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1616 		factotum_rsa_close(rpc);
1617 		return nil;
1618 	}
1619 
1620 	// roll factotum keyring around to match certificate
1621 	rsapub = X509toRSApub(cert, certlen, nil, 0);
1622 	while(1){
1623 		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1624 			factotum_rsa_close(rpc);
1625 			rpc = nil;
1626 			goto done;
1627 		}
1628 		pub = strtomp(rpc->arg, nil, 16, nil);
1629 		assert(pub != nil);
1630 		if(mpcmp(pub,rsapub->n) == 0)
1631 			break;
1632 	}
1633 done:
1634 	mpfree(pub);
1635 	rsapubfree(rsapub);
1636 	return rpc;
1637 }
1638 
1639 static mpint*
1640 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1641 {
1642 	char *p;
1643 	int rv;
1644 
1645 	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1646 		return nil;
1647 	rv = auth_rpc(rpc, "write", p, strlen(p));
1648 	free(p);
1649 	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1650 		return nil;
1651 	mpfree(cipher);
1652 	return strtomp(rpc->arg, nil, 16, nil);
1653 }
1654 
1655 static void
1656 factotum_rsa_close(AuthRpc*rpc)
1657 {
1658 	if(!rpc)
1659 		return;
1660 	close(rpc->afd);
1661 	auth_freerpc(rpc);
1662 }
1663 
/*
 * P_MD5 of the TLS 1.0 PRF (RFC 2246 section 5), XORed into buf[0..nbuf).
 * With seed = label || seed0 || seed1:
 *	A(1) = HMAC_MD5(key, seed),  A(i+1) = HMAC_MD5(key, A(i)),
 * and successive blocks HMAC_MD5(key, A(i) || seed) are XORed into buf,
 * MD5dlen bytes at a time (the last block truncated to fit).
 * buf is combined, not overwritten, so the caller must initialize it.
 */
static void
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[MD5dlen], tmp[MD5dlen];
	int i, n;
	MD5state *s;

	// generate a1 = HMAC(key, label || seed0 || seed1)
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
	hmac_md5(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// tmp = HMAC(key, A(i) || label || seed0 || seed1)
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
		n = MD5dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, MD5dlen);
	}
}
1692 
/*
 * P_SHA1 of the TLS 1.0 PRF (RFC 2246 section 5), XORed into buf[0..nbuf).
 * Same construction as tlsPmd5 but with HMAC_SHA1 and SHA1dlen-byte blocks:
 *	A(1) = HMAC_SHA1(key, seed),  A(i+1) = HMAC_SHA1(key, A(i)),
 * where seed = label || seed0 || seed1, and each HMAC_SHA1(key, A(i) || seed)
 * is XORed into the next bytes of buf.  buf must be initialized by the caller.
 */
static void
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[SHA1dlen], tmp[SHA1dlen];
	int i, n;
	SHAstate *s;

	// generate a1 = HMAC(key, label || seed0 || seed1)
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// tmp = HMAC(key, A(i) || label || seed0 || seed1)
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
		n = SHA1dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, SHA1dlen);
	}
}
1721 
1722 // fill buf with md5(args)^sha1(args)
1723 static void
1724 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1725 {
1726 	int i;
1727 	int nlabel = strlen(label);
1728 	int n = (nkey + 1) >> 1;
1729 
1730 	for(i = 0; i < nbuf; i++)
1731 		buf[i] = 0;
1732 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1733 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1734 }
1735 
1736 /*
1737  * for setting server session id's
1738  */
1739 static Lock	sidLock;
1740 static long	maxSid = 1;
1741 
1742 /* the keys are verified to have the same public components
1743  * and to function correctly with pkcs 1 encryption and decryption. */
1744 static TlsSec*
1745 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1746 {
1747 	TlsSec *sec = emalloc(sizeof(*sec));
1748 
1749 	USED(csid); USED(ncsid);  // ignore csid for now
1750 
1751 	memmove(sec->crandom, crandom, RandomSize);
1752 	sec->clientVers = cvers;
1753 
1754 	put32(sec->srandom, time(0));
1755 	genrandom(sec->srandom+4, RandomSize-4);
1756 	memmove(srandom, sec->srandom, RandomSize);
1757 
1758 	/*
1759 	 * make up a unique sid: use our pid, and and incrementing id
1760 	 * can signal no sid by setting nssid to 0.
1761 	 */
1762 	memset(ssid, 0, SidSize);
1763 	put32(ssid, getpid());
1764 	lock(&sidLock);
1765 	put32(ssid+4, maxSid++);
1766 	unlock(&sidLock);
1767 	*nssid = SidSize;
1768 	return sec;
1769 }
1770 
1771 static int
1772 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1773 {
1774 	if(epm != nil){
1775 		if(setVers(sec, vers) < 0)
1776 			goto Err;
1777 		serverMasterSecret(sec, epm, nepm);
1778 	}else if(sec->vers != vers){
1779 		werrstr("mismatched session versions");
1780 		goto Err;
1781 	}
1782 	setSecrets(sec, kd, nkd);
1783 	return 0;
1784 Err:
1785 	sec->ok = -1;
1786 	return -1;
1787 }
1788 
1789 static TlsSec*
1790 tlsSecInitc(int cvers, uchar *crandom)
1791 {
1792 	TlsSec *sec = emalloc(sizeof(*sec));
1793 	sec->clientVers = cvers;
1794 	put32(sec->crandom, time(0));
1795 	genrandom(sec->crandom+4, RandomSize-4);
1796 	memmove(crandom, sec->crandom, RandomSize);
1797 	return sec;
1798 }
1799 
1800 static int
1801 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1802 {
1803 	RSApub *pub;
1804 
1805 	pub = nil;
1806 
1807 	USED(sid);
1808 	USED(nsid);
1809 
1810 	memmove(sec->srandom, srandom, RandomSize);
1811 
1812 	if(setVers(sec, vers) < 0)
1813 		goto Err;
1814 
1815 	pub = X509toRSApub(cert, ncert, nil, 0);
1816 	if(pub == nil){
1817 		werrstr("invalid x509/rsa certificate");
1818 		goto Err;
1819 	}
1820 	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1821 		goto Err;
1822 	rsapubfree(pub);
1823 	setSecrets(sec, kd, nkd);
1824 	return 0;
1825 
1826 Err:
1827 	if(pub != nil)
1828 		rsapubfree(pub);
1829 	sec->ok = -1;
1830 	return -1;
1831 }
1832 
1833 static int
1834 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1835 {
1836 	if(sec->nfin != nfin){
1837 		sec->ok = -1;
1838 		werrstr("invalid finished exchange");
1839 		return -1;
1840 	}
1841 	md5.malloced = 0;
1842 	sha1.malloced = 0;
1843 	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1844 	return 1;
1845 }
1846 
1847 static void
1848 tlsSecOk(TlsSec *sec)
1849 {
1850 	if(sec->ok == 0)
1851 		sec->ok = 1;
1852 }
1853 
1854 static void
1855 tlsSecKill(TlsSec *sec)
1856 {
1857 	if(!sec)
1858 		return;
1859 	factotum_rsa_close(sec->rpc);
1860 	sec->ok = -1;
1861 }
1862 
1863 static void
1864 tlsSecClose(TlsSec *sec)
1865 {
1866 	if(!sec)
1867 		return;
1868 	factotum_rsa_close(sec->rpc);
1869 	free(sec->server);
1870 	free(sec);
1871 }
1872 
1873 static int
1874 setVers(TlsSec *sec, int v)
1875 {
1876 	if(v == SSL3Version){
1877 		sec->setFinished = sslSetFinished;
1878 		sec->nfin = SSL3FinishedLen;
1879 		sec->prf = sslPRF;
1880 	}else if(v == TLSVersion){
1881 		sec->setFinished = tlsSetFinished;
1882 		sec->nfin = TLSFinishedLen;
1883 		sec->prf = tlsPRF;
1884 	}else{
1885 		werrstr("invalid version");
1886 		return -1;
1887 	}
1888 	sec->vers = v;
1889 	return 0;
1890 }
1891 
1892 /*
1893  * generate secret keys from the master secret.
1894  *
1895  * different crypto selections will require different amounts
1896  * of key expansion and use of key expansion data,
1897  * but it's all generated using the same function.
1898  */
1899 static void
1900 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1901 {
1902 	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1903 			sec->srandom, RandomSize, sec->crandom, RandomSize);
1904 }
1905 
1906 /*
1907  * set the master secret from the pre-master secret.
1908  */
1909 static void
1910 setMasterSecret(TlsSec *sec, Bytes *pm)
1911 {
1912 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1913 			sec->crandom, RandomSize, sec->srandom, RandomSize);
1914 }
1915 
1916 static void
1917 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1918 {
1919 	Bytes *pm;
1920 
1921 	pm = pkcs1_decrypt(sec, epm, nepm);
1922 
1923 	// if the client messed up, just continue as if everything is ok,
1924 	// to prevent attacks to check for correctly formatted messages.
1925 	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1926 	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1927 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1928 			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1929 		sec->ok = -1;
1930 		if(pm != nil)
1931 			freebytes(pm);
1932 		pm = newbytes(MasterSecretSize);
1933 		genrandom(pm->data, MasterSecretSize);
1934 	}
1935 	setMasterSecret(sec, pm);
1936 	memset(pm->data, 0, pm->len);
1937 	freebytes(pm);
1938 }
1939 
1940 static int
1941 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1942 {
1943 	Bytes *pm, *key;
1944 
1945 	pm = newbytes(MasterSecretSize);
1946 	put16(pm->data, sec->clientVers);
1947 	genrandom(pm->data+2, MasterSecretSize - 2);
1948 
1949 	setMasterSecret(sec, pm);
1950 
1951 	key = pkcs1_encrypt(pm, pub, 2);
1952 	memset(pm->data, 0, pm->len);
1953 	freebytes(pm);
1954 	if(key == nil){
1955 		werrstr("tls pkcs1_encrypt failed");
1956 		return -1;
1957 	}
1958 
1959 	*nepm = key->len;
1960 	*epm = malloc(*nepm);
1961 	if(*epm == nil){
1962 		freebytes(key);
1963 		werrstr("out of memory");
1964 		return -1;
1965 	}
1966 	memmove(*epm, key->data, *nepm);
1967 
1968 	freebytes(key);
1969 
1970 	return 1;
1971 }
1972 
/*
 * SSL 3.0 Finished computation.  Writes MD5dlen + SHA1dlen bytes to finished:
 *	finished[0..MD5dlen)  = MD5(master || pad2 || MD5(handshake || sender || master || pad1))
 *	finished[MD5dlen..)   = SHA1(master || pad2 || SHA1(handshake || sender || master || pad1))
 * where sender is "CLNT" or "SRVR" and pad1/pad2 are 0x36/0x5C repeated
 * (48 bytes for MD5, 40 for SHA1).  hsmd5/hssha1 are by-value copies of
 * the running handshake hashes, so extending them here leaves the
 * caller's states unchanged.
 */
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// inner MD5 over handshake || sender || master || pad1
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	// outer MD5 over master || pad2 || inner
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// same again with SHA1 and 40-byte pads
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
2005 
2006 // fill "finished" arg with md5(args)^sha1(args)
2007 static void
2008 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2009 {
2010 	uchar h0[MD5dlen], h1[SHA1dlen];
2011 	char *label;
2012 
2013 	// get current hash value, but allow further messages to be hashed in
2014 	md5(nil, 0, h0, &hsmd5);
2015 	sha1(nil, 0, h1, &hssha1);
2016 
2017 	if(isClient)
2018 		label = "client finished";
2019 	else
2020 		label = "server finished";
2021 	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2022 }
2023 
/*
 * SSL 3.0 key derivation, used in place of tlsPRF for SSL3 sessions.
 * Block i (i = 1, 2, ...) is
 *	MD5(key || SHA1(P_i || key || seed0 || seed1))
 * where P_i is i copies of the letter 'A'+i-1 ("A", "BB", "CCC", ...).
 * The blocks are written consecutively into buf[0..nbuf).  label is
 * ignored.  At most 26 blocks (26*MD5dlen bytes) are produced; anything
 * beyond that in buf is left untouched.
 */
static void
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	DigestState *s;
	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
	int i, n, len;

	USED(label);
	len = 1;
	while(nbuf > 0){
		// the salt letters run out after 'Z'
		if(len > 26)
			return;
		for(i = 0; i < len; i++)
			tmp[i] = 'A' - 1 + len;
		s = sha1(tmp, len, nil, nil);
		s = sha1(key, nkey, nil, s);
		s = sha1(seed0, nseed0, nil, s);
		sha1(seed1, nseed1, sha1dig, s);
		s = md5(key, nkey, nil, nil);
		md5(sha1dig, SHA1dlen, md5dig, s);
		n = MD5dlen;
		if(n > nbuf)
			n = nbuf;
		memmove(buf, md5dig, n);
		buf += n;
		nbuf -= n;
		len++;
	}
}
2053 
2054 static mpint*
2055 bytestomp(Bytes* bytes)
2056 {
2057 	mpint* ans;
2058 
2059 	ans = betomp(bytes->data, bytes->len, nil);
2060 	return ans;
2061 }
2062 
2063 /*
2064  * Convert mpint* to Bytes, putting high order byte first.
2065  */
2066 static Bytes*
2067 mptobytes(mpint* big)
2068 {
2069 	int n, m;
2070 	uchar *a;
2071 	Bytes* ans;
2072 
2073 	a = nil;
2074 	n = (mpsignif(big)+7)/8;
2075 	m = mptobe(big, nil, n, &a);
2076 	ans = makebytes(a, m);
2077 	if(a != nil)
2078 		free(a);
2079 	return ans;
2080 }
2081 
2082 // Do RSA computation on block according to key, and pad
2083 // result on left with zeros to make it modlen long.
2084 static Bytes*
2085 rsacomp(Bytes* block, RSApub* key, int modlen)
2086 {
2087 	mpint *x, *y;
2088 	Bytes *a, *ybytes;
2089 	int ylen;
2090 
2091 	x = bytestomp(block);
2092 	y = rsaencrypt(key, x, nil);
2093 	mpfree(x);
2094 	ybytes = mptobytes(y);
2095 	ylen = ybytes->len;
2096 
2097 	if(ylen < modlen) {
2098 		a = newbytes(modlen);
2099 		memset(a->data, 0, modlen-ylen);
2100 		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2101 		freebytes(ybytes);
2102 		ybytes = a;
2103 	}
2104 	else if(ylen > modlen) {
2105 		// assume it has leading zeros (mod should make it so)
2106 		a = newbytes(modlen);
2107 		memmove(a->data, ybytes->data, modlen);
2108 		freebytes(ybytes);
2109 		ybytes = a;
2110 	}
2111 	mpfree(y);
2112 	return ybytes;
2113 }
2114 
2115 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2116 static Bytes*
2117 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2118 {
2119 	Bytes *pad, *eb, *ans;
2120 	int i, dlen, padlen, modlen;
2121 
2122 	modlen = (mpsignif(key->n)+7)/8;
2123 	dlen = data->len;
2124 	if(modlen < 12 || dlen > modlen - 11)
2125 		return nil;
2126 	padlen = modlen - 3 - dlen;
2127 	pad = newbytes(padlen);
2128 	genrandom(pad->data, padlen);
2129 	for(i = 0; i < padlen; i++) {
2130 		if(blocktype == 0)
2131 			pad->data[i] = 0;
2132 		else if(blocktype == 1)
2133 			pad->data[i] = 255;
2134 		else if(pad->data[i] == 0)
2135 			pad->data[i] = 1;
2136 	}
2137 	eb = newbytes(modlen);
2138 	eb->data[0] = 0;
2139 	eb->data[1] = blocktype;
2140 	memmove(eb->data+2, pad->data, padlen);
2141 	eb->data[padlen+2] = 0;
2142 	memmove(eb->data+padlen+3, data->data, dlen);
2143 	ans = rsacomp(eb, key, modlen);
2144 	freebytes(eb);
2145 	freebytes(pad);
2146 	return ans;
2147 }
2148 
2149 // decrypt data according to PKCS#1, with given key.
2150 // expect a block type of 2.
2151 static Bytes*
2152 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2153 {
2154 	Bytes *eb, *ans = nil;
2155 	int i, modlen;
2156 	mpint *x, *y;
2157 
2158 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2159 	if(nepm != modlen)
2160 		return nil;
2161 	x = betomp(epm, nepm, nil);
2162 	y = factotum_rsa_decrypt(sec->rpc, x);
2163 	if(y == nil)
2164 		return nil;
2165 	eb = mptobytes(y);
2166 	if(eb->len < modlen){ // pad on left with zeros
2167 		ans = newbytes(modlen);
2168 		memset(ans->data, 0, modlen-eb->len);
2169 		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2170 		freebytes(eb);
2171 		eb = ans;
2172 	}
2173 	if(eb->data[0] == 0 && eb->data[1] == 2) {
2174 		for(i = 2; i < modlen; i++)
2175 			if(eb->data[i] == 0)
2176 				break;
2177 		if(i < modlen - 1)
2178 			ans = makebytes(eb->data+i+1, modlen-(i+1));
2179 	}
2180 	freebytes(eb);
2181 	return ans;
2182 }
2183 
2184 
2185 //================= general utility functions ========================
2186 
2187 static void *
2188 emalloc(int n)
2189 {
2190 	void *p;
2191 	if(n==0)
2192 		n=1;
2193 	p = malloc(n);
2194 	if(p == nil){
2195 		exits("out of memory");
2196 	}
2197 	memset(p, 0, n);
2198 	return p;
2199 }
2200 
2201 static void *
2202 erealloc(void *ReallocP, int ReallocN)
2203 {
2204 	if(ReallocN == 0)
2205 		ReallocN = 1;
2206 	if(!ReallocP)
2207 		ReallocP = emalloc(ReallocN);
2208 	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2209 		exits("out of memory");
2210 	}
2211 	return(ReallocP);
2212 }
2213 
2214 static void
2215 put32(uchar *p, u32int x)
2216 {
2217 	p[0] = x>>24;
2218 	p[1] = x>>16;
2219 	p[2] = x>>8;
2220 	p[3] = x;
2221 }
2222 
2223 static void
2224 put24(uchar *p, int x)
2225 {
2226 	p[0] = x>>16;
2227 	p[1] = x>>8;
2228 	p[2] = x;
2229 }
2230 
2231 static void
2232 put16(uchar *p, int x)
2233 {
2234 	p[0] = x>>8;
2235 	p[1] = x;
2236 }
2237 
2238 static u32int
2239 get32(uchar *p)
2240 {
2241 	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2242 }
2243 
2244 static int
2245 get24(uchar *p)
2246 {
2247 	return (p[0]<<16)|(p[1]<<8)|p[2];
2248 }
2249 
2250 static int
2251 get16(uchar *p)
2252 {
2253 	return (p[0]<<8)|p[1];
2254 }
2255 
2256 #define OFFSET(x, s) offsetof(s, x)
2257 
2258 /*
2259  * malloc and return a new Bytes structure capable of
2260  * holding len bytes. (len >= 0)
2261  * Used to use crypt_malloc, which aborts if malloc fails.
2262  */
2263 static Bytes*
2264 newbytes(int len)
2265 {
2266 	Bytes* ans;
2267 
2268 	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2269 	ans->len = len;
2270 	return ans;
2271 }
2272 
2273 /*
2274  * newbytes(len), with data initialized from buf
2275  */
2276 static Bytes*
2277 makebytes(uchar* buf, int len)
2278 {
2279 	Bytes* ans;
2280 
2281 	ans = newbytes(len);
2282 	memmove(ans->data, buf, len);
2283 	return ans;
2284 }
2285 
2286 static void
2287 freebytes(Bytes* b)
2288 {
2289 	if(b != nil)
2290 		free(b);
2291 }
2292 
2293 /* len is number of ints */
2294 static Ints*
2295 newints(int len)
2296 {
2297 	Ints* ans;
2298 
2299 	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2300 	ans->len = len;
2301 	return ans;
2302 }
2303 
2304 static Ints*
2305 makeints(int* buf, int len)
2306 {
2307 	Ints* ans;
2308 
2309 	ans = newints(len);
2310 	if(len > 0)
2311 		memmove(ans->data, buf, len*sizeof(int));
2312 	return ans;
2313 }
2314 
2315 static void
2316 freeints(Ints* b)
2317 {
2318 	if(b != nil)
2319 		free(b);
2320 }
2321