xref: /plan9/sys/src/libsec/port/tlshand.c (revision ec59a3ddbfceee0efe34584c2c9981a5e5ff1ec4)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7 
8 // The main groups of functions are:
9 //		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
11 //		cipher choices - catalog of digest and encrypt algorithms
12 //		security functions - PKCS#1, sslHMAC, session keygen
13 //		general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16 
// Byte sizes used by the handshake, plus the action codes for msgSend.
enum {
	TLSFinishedLen = 12,	// length of the verify data in a TLS Finished message
	SSL3FinishedLen = MD5dlen+SHA1dlen,	// SSL3 Finished carries an MD5 and a SHA1 digest
	MaxKeyData = 104,	// amount of secret we may need; must be >= every nsecret in cipherAlgs
	MaxChunk = 1<<14,	// sizes TlsConnection.buf, the handshake input buffer
	RandomSize = 32,	// client/server hello random
	SidSize = 32,	// session id length (the client accepts only 0 or SidSize)
	MasterSecretSize = 48,
	AQueue = 0,	// msgSend: append the message to the output buffer only
	AFlush = 1,	// msgSend: append, then write everything buffered to c->hand
};
28 
typedef struct TlsSec TlsSec;

// A counted byte string.  Only data[0..len-1] is meaningful; presumably
// newbytes/makebytes allocate room for len bytes past the header.
typedef struct Bytes{
	int len;
	uchar data[1];  // [len]
} Bytes;

// A counted array of ints, laid out the same way as Bytes
// (presumably allocated by newints/makeints).
typedef struct Ints{
	int len;
	int data[1];  // [len]
} Ints;
40 
// One entry of the cipher-suite table (see cipherAlgs).
typedef struct Algs{
	char *enc;	// encryption algorithm name (cf. c->enc written to the ctl file)
	char *digest;	// MAC digest name (cf. c->digest written to the ctl file)
	int nsecret;	// bytes of key material this suite needs
	int tlsid;	// TLS cipher suite identifier
	int ok;	// presumably set by initCiphers when the suite is usable - TODO confirm
} Algs;

// Payload of a Finished message; only verify[0..n-1] is meaningful.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;
53 
// State for one handshake, used by both tlsClient2 and tlsServer2.
typedef struct TlsConnection{
	TlsSec *sec;	// security management goo
	int hand, ctl;	// record layer file descriptors
	int erred;		// set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging
	int version;	// protocol we are speaking
	int verset;		// version has been set
	int ver2hi;		// server got a version 2 hello
	int isClient;	// is this the client or server?
	Bytes *sid;		// SessionID
	Bytes *cert;	// peer certificate: the client keeps only the first one received (certs[0]) - no chain

	Lock statelk;
	int state;		// must be set using setstate

	// input buffer for handshake messages; tlsReadN keeps the unread bytes in [rp, ep)
	uchar buf[MaxChunk+2048];
	uchar *rp, *ep;

	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVersion;	// version in ClientHello
	char *digest;	// name of digest algorithm to use
	char *enc;		// name of encryption algorithm to use
	int nsecret;	// amount of secret data to init keys

	// for finished messages: running hashes over the handshake messages
	// sent (msgSend) and received (msgRecv)
	MD5state	hsmd5;	// handshake hash
	SHAstate	hssha1;	// handshake hash
	Finished	finished;
} TlsConnection;
85 
// One handshake message.  tag is an H* handshake type and selects which
// member of u is meaningful.  msgSend encodes a Msg, msgRecv decodes into
// one, and msgClear is used to reset it between messages.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;	// offered cipher suite ids
			Bytes*	compressors;	// offered compression methods
		} clientHello;
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			int cipher;	// chosen cipher suite id
			int compressor;	// chosen compression method
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// array of ncert certificates
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *key;	// encrypted key exchange payload
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;
118 
// Key-management state for one session (created by tlsSecInitc /
// tlsSecInits, released through tlsSecClose / tlsSecKill).
typedef struct TlsSec{
	char *server;	// name of remote; nil for server
	int ok;	// <0 killed; == 0 in progress; >0 reusable
	RSApub *rsapub;
	AuthRpc *rpc;	// factotum for rsa private key
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVers;		// version in ClientHello
	int vers;			// final version
	// byte generation and handshake checksum
	// prf(buf, nbuf, key, nkey, label, seed0, nseed0, seed1, nseed1) fills buf
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;
} TlsSec;
134 
135 
// Protocol version numbers as they appear on the wire, and the range
// of versions this implementation accepts.
enum {
	TLSVersion = 0x0301,
	SSL3Version = 0x0300,
	ProtocolVersion = 0x0301,	// maximum version we speak
	MinProtoVersion = 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};
143 
// handshake type (the first byte of every handshake message header)
// Note: values without an initializer follow their predecessor, so order matters.
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};
159 
// alerts (the err codes passed to tlsError)
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};
188 
// cipher suites (16-bit identifiers as sent in hello messages)
// Only the suites listed in cipherAlgs below are implemented here.
enum {
	TLS_NULL_WITH_NULL_NULL	 		= 0x0000,
	TLS_RSA_WITH_NULL_MD5 			= 0x0001,
	TLS_RSA_WITH_NULL_SHA 			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA 		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0X0006,
	TLS_RSA_WITH_IDEA_CBC_SHA 		= 0X0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0X0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0X000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0X000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0X000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0X000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0X0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0X0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0X0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0X0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0X0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0X0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0X001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0X001B,

	TLS_RSA_WITH_AES_128_CBC_SHA		= 0X002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0X0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0X0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0X0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0X0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0X0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0X0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0X0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0X0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0X0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0X0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0X003A,
	CipherMax
};
234 
// compression methods
enum {
	CompressionNull = 0,
	CompressionMax
};

// Cipher suites implemented here.  nsecret is the number of bytes of key
// material the suite needs; it is written as 2 * (cipher material + MAC
// digest length), presumably one set per direction - TODO confirm
// against setSecrets.  The largest entry must not exceed MaxKeyData.
static Algs cipherAlgs[] = {
	{"rc4_128", "md5",	2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1",	2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
};

// Compression methods offered in the ClientHello (null only).
static uchar compressors[] = {
	CompressionNull,
};
250 
// handshake drivers
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));

// handshake messages and connection state
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck argpos	tlsError 3
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher-suite and compression negotiation
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(void);

// security (TlsSec) operations: key exchange, secrets, Finished hashes
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void	tlsSecOk(TlsSec *sec);
static void	tlsSecKill(TlsSec *sec);
static void	tlsSecClose(TlsSec *sec);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

// RSA private-key operations delegated to factotum
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc*rpc);

// memory allocation and big-endian serialization helpers
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static u32int get32(uchar *p);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static Ints* makeints(int* buf, int len);
static void freeints(Ints* b);
308 
309 //================= client/server ========================
310 
311 //	push TLS onto fd, returning new (application) file descriptor
312 //		or -1 if error.
313 int
314 tlsServer(int fd, TLSconn *conn)
315 {
316 	char buf[8];
317 	char dname[64];
318 	int n, data, ctl, hand;
319 	TlsConnection *tls;
320 
321 	if(conn == nil)
322 		return -1;
323 	ctl = open("#a/tls/clone", ORDWR);
324 	if(ctl < 0)
325 		return -1;
326 	n = read(ctl, buf, sizeof(buf)-1);
327 	if(n < 0){
328 		close(ctl);
329 		return -1;
330 	}
331 	buf[n] = 0;
332 	sprint(conn->dir, "#a/tls/%s", buf);
333 	sprint(dname, "#a/tls/%s/hand", buf);
334 	hand = open(dname, ORDWR);
335 	if(hand < 0){
336 		close(ctl);
337 		return -1;
338 	}
339 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
341 	sprint(dname, "#a/tls/%s/data", buf);
342 	data = open(dname, ORDWR);
343 	close(fd);
344 	close(hand);
345 	close(ctl);
346 	if(data < 0){
347 		return -1;
348 	}
349 	if(tls == nil){
350 		close(data);
351 		return -1;
352 	}
353 	if(conn->cert)
354 		free(conn->cert);
355 	conn->cert = 0;  // client certificates are not yet implemented
356 	conn->certlen = 0;
357 	conn->sessionIDlen = tls->sid->len;
358 	conn->sessionID = emalloc(conn->sessionIDlen);
359 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
361 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
362 	tlsConnectionFree(tls);
363 	return data;
364 }
365 
366 //	push TLS onto fd, returning new (application) file descriptor
367 //		or -1 if error.
368 int
369 tlsClient(int fd, TLSconn *conn)
370 {
371 	char buf[8];
372 	char dname[64];
373 	int n, data, ctl, hand;
374 	TlsConnection *tls;
375 
376 	if(!conn)
377 		return -1;
378 	ctl = open("#a/tls/clone", ORDWR);
379 	if(ctl < 0)
380 		return -1;
381 	n = read(ctl, buf, sizeof(buf)-1);
382 	if(n < 0){
383 		close(ctl);
384 		return -1;
385 	}
386 	buf[n] = 0;
387 	sprint(conn->dir, "#a/tls/%s", buf);
388 	sprint(dname, "#a/tls/%s/hand", buf);
389 	hand = open(dname, ORDWR);
390 	if(hand < 0){
391 		close(ctl);
392 		return -1;
393 	}
394 	sprint(dname, "#a/tls/%s/data", buf);
395 	data = open(dname, ORDWR);
396 	if(data < 0)
397 		return -1;
398 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
399 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
400 	close(fd);
401 	close(hand);
402 	close(ctl);
403 	if(tls == nil){
404 		close(data);
405 		return -1;
406 	}
407 	conn->certlen = tls->cert->len;
408 	conn->cert = emalloc(conn->certlen);
409 	memcpy(conn->cert, tls->cert->data, conn->certlen);
410 	conn->sessionIDlen = tls->sid->len;
411 	conn->sessionID = emalloc(conn->sessionIDlen);
412 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
413 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
414 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
415 	tlsConnectionFree(tls);
416 	return data;
417 }
418 
419 static int
420 countchain(PEMChain *p)
421 {
422 	int i = 0;
423 
424 	while (p) {
425 		i++;
426 		p = p->next;
427 	}
428 	return i;
429 }
430 
431 static TlsConnection *
432 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
433 {
434 	TlsConnection *c;
435 	Msg m;
436 	Bytes *csid;
437 	uchar sid[SidSize], kd[MaxKeyData];
438 	char *secrets;
439 	int cipher, compressor, nsid, rv, numcerts, i;
440 
441 	if(trace)
442 		trace("tlsServer2\n");
443 	if(!initCiphers())
444 		return nil;
445 	c = emalloc(sizeof(TlsConnection));
446 	c->ctl = ctl;
447 	c->hand = hand;
448 	c->trace = trace;
449 	c->version = ProtocolVersion;
450 
451 	memset(&m, 0, sizeof(m));
452 	if(!msgRecv(c, &m)){
453 		if(trace)
454 			trace("initial msgRecv failed\n");
455 		goto Err;
456 	}
457 	if(m.tag != HClientHello) {
458 		tlsError(c, EUnexpectedMessage, "expected a client hello");
459 		goto Err;
460 	}
461 	c->clientVersion = m.u.clientHello.version;
462 	if(trace)
463 		trace("ClientHello version %x\n", c->clientVersion);
464 	if(setVersion(c, m.u.clientHello.version) < 0) {
465 		tlsError(c, EIllegalParameter, "incompatible version");
466 		goto Err;
467 	}
468 
469 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
470 	cipher = okCipher(m.u.clientHello.ciphers);
471 	if(cipher < 0) {
472 		// reply with EInsufficientSecurity if we know that's the case
473 		if(cipher == -2)
474 			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
475 		else
476 			tlsError(c, EHandshakeFailure, "no matching cipher suite");
477 		goto Err;
478 	}
479 	if(!setAlgs(c, cipher)){
480 		tlsError(c, EHandshakeFailure, "no matching cipher suite");
481 		goto Err;
482 	}
483 	compressor = okCompression(m.u.clientHello.compressors);
484 	if(compressor < 0) {
485 		tlsError(c, EHandshakeFailure, "no matching compressor");
486 		goto Err;
487 	}
488 
489 	csid = m.u.clientHello.sid;
490 	if(trace)
491 		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
492 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
493 	if(c->sec == nil){
494 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
495 		goto Err;
496 	}
497 	c->sec->rpc = factotum_rsa_open(cert, ncert);
498 	if(c->sec->rpc == nil){
499 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
500 		goto Err;
501 	}
502 	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
503 	msgClear(&m);
504 
505 	m.tag = HServerHello;
506 	m.u.serverHello.version = c->version;
507 	memmove(m.u.serverHello.random, c->srandom, RandomSize);
508 	m.u.serverHello.cipher = cipher;
509 	m.u.serverHello.compressor = compressor;
510 	c->sid = makebytes(sid, nsid);
511 	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
512 	if(!msgSend(c, &m, AQueue))
513 		goto Err;
514 	msgClear(&m);
515 
516 	m.tag = HCertificate;
517 	numcerts = countchain(chp);
518 	m.u.certificate.ncert = 1 + numcerts;
519 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
520 	m.u.certificate.certs[0] = makebytes(cert, ncert);
521 	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
522 		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
523 	if(!msgSend(c, &m, AQueue))
524 		goto Err;
525 	msgClear(&m);
526 
527 	m.tag = HServerHelloDone;
528 	if(!msgSend(c, &m, AFlush))
529 		goto Err;
530 	msgClear(&m);
531 
532 	if(!msgRecv(c, &m))
533 		goto Err;
534 	if(m.tag != HClientKeyExchange) {
535 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
536 		goto Err;
537 	}
538 	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
539 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
540 		goto Err;
541 	}
542 	if(trace)
543 		trace("tls secrets\n");
544 	secrets = (char*)emalloc(2*c->nsecret);
545 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
546 	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
547 	memset(secrets, 0, 2*c->nsecret);
548 	free(secrets);
549 	memset(kd, 0, c->nsecret);
550 	if(rv < 0){
551 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
552 		goto Err;
553 	}
554 	msgClear(&m);
555 
556 	/* no CertificateVerify; skip to Finished */
557 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
558 		tlsError(c, EInternalError, "can't set finished: %r");
559 		goto Err;
560 	}
561 	if(!msgRecv(c, &m))
562 		goto Err;
563 	if(m.tag != HFinished) {
564 		tlsError(c, EUnexpectedMessage, "expected a finished");
565 		goto Err;
566 	}
567 	if(!finishedMatch(c, &m.u.finished)) {
568 		tlsError(c, EHandshakeFailure, "finished verification failed");
569 		goto Err;
570 	}
571 	msgClear(&m);
572 
573 	/* change cipher spec */
574 	if(fprint(c->ctl, "changecipher") < 0){
575 		tlsError(c, EInternalError, "can't enable cipher: %r");
576 		goto Err;
577 	}
578 
579 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
580 		tlsError(c, EInternalError, "can't set finished: %r");
581 		goto Err;
582 	}
583 	m.tag = HFinished;
584 	m.u.finished = c->finished;
585 	if(!msgSend(c, &m, AFlush))
586 		goto Err;
587 	msgClear(&m);
588 	if(trace)
589 		trace("tls finished\n");
590 
591 	if(fprint(c->ctl, "opened") < 0)
592 		goto Err;
593 	tlsSecOk(c->sec);
594 	return c;
595 
596 Err:
597 	msgClear(&m);
598 	tlsConnectionFree(c);
599 	return 0;
600 }
601 
602 static TlsConnection *
603 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
604 {
605 	TlsConnection *c;
606 	Msg m;
607 	uchar kd[MaxKeyData], *epm;
608 	char *secrets;
609 	int creq, nepm, rv;
610 
611 	if(!initCiphers())
612 		return nil;
613 	epm = nil;
614 	c = emalloc(sizeof(TlsConnection));
615 	c->version = ProtocolVersion;
616 	c->ctl = ctl;
617 	c->hand = hand;
618 	c->trace = trace;
619 	c->isClient = 1;
620 	c->clientVersion = c->version;
621 
622 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
623 	if(c->sec == nil)
624 		goto Err;
625 
626 	/* client hello */
627 	memset(&m, 0, sizeof(m));
628 	m.tag = HClientHello;
629 	m.u.clientHello.version = c->clientVersion;
630 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
631 	m.u.clientHello.sid = makebytes(csid, ncsid);
632 	m.u.clientHello.ciphers = makeciphers();
633 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
634 	if(!msgSend(c, &m, AFlush))
635 		goto Err;
636 	msgClear(&m);
637 
638 	/* server hello */
639 	if(!msgRecv(c, &m))
640 		goto Err;
641 	if(m.tag != HServerHello) {
642 		tlsError(c, EUnexpectedMessage, "expected a server hello");
643 		goto Err;
644 	}
645 	if(setVersion(c, m.u.serverHello.version) < 0) {
646 		tlsError(c, EIllegalParameter, "incompatible version %r");
647 		goto Err;
648 	}
649 	memmove(c->srandom, m.u.serverHello.random, RandomSize);
650 	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
651 	if(c->sid->len != 0 && c->sid->len != SidSize) {
652 		tlsError(c, EIllegalParameter, "invalid server session identifier");
653 		goto Err;
654 	}
655 	if(!setAlgs(c, m.u.serverHello.cipher)) {
656 		tlsError(c, EIllegalParameter, "invalid cipher suite");
657 		goto Err;
658 	}
659 	if(m.u.serverHello.compressor != CompressionNull) {
660 		tlsError(c, EIllegalParameter, "invalid compression");
661 		goto Err;
662 	}
663 	msgClear(&m);
664 
665 	/* certificate */
666 	if(!msgRecv(c, &m) || m.tag != HCertificate) {
667 		tlsError(c, EUnexpectedMessage, "expected a certificate");
668 		goto Err;
669 	}
670 	if(m.u.certificate.ncert < 1) {
671 		tlsError(c, EIllegalParameter, "runt certificate");
672 		goto Err;
673 	}
674 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
675 	msgClear(&m);
676 
677 	/* server key exchange (optional) */
678 	if(!msgRecv(c, &m))
679 		goto Err;
680 	if(m.tag == HServerKeyExchange) {
681 		tlsError(c, EUnexpectedMessage, "got an server key exchange");
682 		goto Err;
683 		// If implementing this later, watch out for rollback attack
684 		// described in Wagner Schneier 1996, section 4.4.
685 	}
686 
687 	/* certificate request (optional) */
688 	creq = 0;
689 	if(m.tag == HCertificateRequest) {
690 		creq = 1;
691 		msgClear(&m);
692 		if(!msgRecv(c, &m))
693 			goto Err;
694 	}
695 
696 	if(m.tag != HServerHelloDone) {
697 		tlsError(c, EUnexpectedMessage, "expected a server hello done");
698 		goto Err;
699 	}
700 	msgClear(&m);
701 
702 	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
703 			c->cert->data, c->cert->len, c->version, &epm, &nepm,
704 			kd, c->nsecret) < 0){
705 		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
706 		goto Err;
707 	}
708 	secrets = (char*)emalloc(2*c->nsecret);
709 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
710 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
711 	memset(secrets, 0, 2*c->nsecret);
712 	free(secrets);
713 	memset(kd, 0, c->nsecret);
714 	if(rv < 0){
715 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
716 		goto Err;
717 	}
718 
719 	if(creq) {
720 		/* send a zero length certificate */
721 		m.tag = HCertificate;
722 		if(!msgSend(c, &m, AFlush))
723 			goto Err;
724 		msgClear(&m);
725 	}
726 
727 	/* client key exchange */
728 	m.tag = HClientKeyExchange;
729 	m.u.clientKeyExchange.key = makebytes(epm, nepm);
730 	free(epm);
731 	epm = nil;
732 	if(m.u.clientKeyExchange.key == nil) {
733 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
734 		goto Err;
735 	}
736 	if(!msgSend(c, &m, AFlush))
737 		goto Err;
738 	msgClear(&m);
739 
740 	/* change cipher spec */
741 	if(fprint(c->ctl, "changecipher") < 0){
742 		tlsError(c, EInternalError, "can't enable cipher: %r");
743 		goto Err;
744 	}
745 
746 	// Cipherchange must occur immediately before Finished to avoid
747 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
748 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
749 		tlsError(c, EInternalError, "can't set finished 1: %r");
750 		goto Err;
751 	}
752 	m.tag = HFinished;
753 	m.u.finished = c->finished;
754 
755 	if(!msgSend(c, &m, AFlush)) {
756 		fprint(2, "tlsClient nepm=%d\n", nepm);
757 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
758 		goto Err;
759 	}
760 	msgClear(&m);
761 
762 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
763 		fprint(2, "tlsClient nepm=%d\n", nepm);
764 		tlsError(c, EInternalError, "can't set finished 0: %r");
765 		goto Err;
766 	}
767 	if(!msgRecv(c, &m)) {
768 		fprint(2, "tlsClient nepm=%d\n", nepm);
769 		tlsError(c, EInternalError, "can't read server Finished: %r");
770 		goto Err;
771 	}
772 	if(m.tag != HFinished) {
773 		fprint(2, "tlsClient nepm=%d\n", nepm);
774 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
775 		goto Err;
776 	}
777 
778 	if(!finishedMatch(c, &m.u.finished)) {
779 		tlsError(c, EHandshakeFailure, "finished verification failed");
780 		goto Err;
781 	}
782 	msgClear(&m);
783 
784 	if(fprint(c->ctl, "opened") < 0){
785 		if(trace)
786 			trace("unable to do final open: %r\n");
787 		goto Err;
788 	}
789 	tlsSecOk(c->sec);
790 	return c;
791 
792 Err:
793 	free(epm);
794 	msgClear(&m);
795 	tlsConnectionFree(c);
796 	return 0;
797 }
798 
799 
800 //================= message functions ========================
801 
802 static uchar sendbuf[9000], *sendp;
803 
804 static int
805 msgSend(TlsConnection *c, Msg *m, int act)
806 {
807 	uchar *p; // sendp = start of new message;  p = write pointer
808 	int nn, n, i;
809 
810 	if(sendp == nil)
811 		sendp = sendbuf;
812 	p = sendp;
813 	if(c->trace)
814 		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
815 
816 	p[0] = m->tag;	// header - fill in size later
817 	p += 4;
818 
819 	switch(m->tag) {
820 	default:
821 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
822 		goto Err;
823 	case HClientHello:
824 		// version
825 		put16(p, m->u.clientHello.version);
826 		p += 2;
827 
828 		// random
829 		memmove(p, m->u.clientHello.random, RandomSize);
830 		p += RandomSize;
831 
832 		// sid
833 		n = m->u.clientHello.sid->len;
834 		assert(n < 256);
835 		p[0] = n;
836 		memmove(p+1, m->u.clientHello.sid->data, n);
837 		p += n+1;
838 
839 		n = m->u.clientHello.ciphers->len;
840 		assert(n > 0 && n < 200);
841 		put16(p, n*2);
842 		p += 2;
843 		for(i=0; i<n; i++) {
844 			put16(p, m->u.clientHello.ciphers->data[i]);
845 			p += 2;
846 		}
847 
848 		n = m->u.clientHello.compressors->len;
849 		assert(n > 0);
850 		p[0] = n;
851 		memmove(p+1, m->u.clientHello.compressors->data, n);
852 		p += n+1;
853 		break;
854 	case HServerHello:
855 		put16(p, m->u.serverHello.version);
856 		p += 2;
857 
858 		// random
859 		memmove(p, m->u.serverHello.random, RandomSize);
860 		p += RandomSize;
861 
862 		// sid
863 		n = m->u.serverHello.sid->len;
864 		assert(n < 256);
865 		p[0] = n;
866 		memmove(p+1, m->u.serverHello.sid->data, n);
867 		p += n+1;
868 
869 		put16(p, m->u.serverHello.cipher);
870 		p += 2;
871 		p[0] = m->u.serverHello.compressor;
872 		p += 1;
873 		break;
874 	case HServerHelloDone:
875 		break;
876 	case HCertificate:
877 		nn = 0;
878 		for(i = 0; i < m->u.certificate.ncert; i++)
879 			nn += 3 + m->u.certificate.certs[i]->len;
880 		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
881 			tlsError(c, EInternalError, "output buffer too small for certificate");
882 			goto Err;
883 		}
884 		put24(p, nn);
885 		p += 3;
886 		for(i = 0; i < m->u.certificate.ncert; i++){
887 			put24(p, m->u.certificate.certs[i]->len);
888 			p += 3;
889 			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
890 			p += m->u.certificate.certs[i]->len;
891 		}
892 		break;
893 	case HClientKeyExchange:
894 		n = m->u.clientKeyExchange.key->len;
895 		if(c->version != SSL3Version){
896 			put16(p, n);
897 			p += 2;
898 		}
899 		memmove(p, m->u.clientKeyExchange.key->data, n);
900 		p += n;
901 		break;
902 	case HFinished:
903 		memmove(p, m->u.finished.verify, m->u.finished.n);
904 		p += m->u.finished.n;
905 		break;
906 	}
907 
908 	// go back and fill in size
909 	n = p-sendp;
910 	assert(p <= sendbuf+sizeof(sendbuf));
911 	put24(sendp+1, n-4);
912 
913 	// remember hash of Handshake messages
914 	if(m->tag != HHelloRequest) {
915 		md5(sendp, n, 0, &c->hsmd5);
916 		sha1(sendp, n, 0, &c->hssha1);
917 	}
918 
919 	sendp = p;
920 	if(act == AFlush){
921 		sendp = sendbuf;
922 		if(write(c->hand, sendbuf, p-sendbuf) < 0){
923 			fprint(2, "write error: %r\n");
924 			goto Err;
925 		}
926 	}
927 	msgClear(m);
928 	return 1;
929 Err:
930 	msgClear(m);
931 	return 0;
932 }
933 
934 static uchar*
935 tlsReadN(TlsConnection *c, int n)
936 {
937 	uchar *p;
938 	int nn, nr;
939 
940 	nn = c->ep - c->rp;
941 	if(nn < n){
942 		if(c->rp != c->buf){
943 			memmove(c->buf, c->rp, nn);
944 			c->rp = c->buf;
945 			c->ep = &c->buf[nn];
946 		}
947 		for(; nn < n; nn += nr) {
948 			nr = read(c->hand, &c->rp[nn], n - nn);
949 			if(nr <= 0)
950 				return nil;
951 			c->ep += nr;
952 		}
953 	}
954 	p = c->rp;
955 	c->rp += n;
956 	return p;
957 }
958 
959 static int
960 msgRecv(TlsConnection *c, Msg *m)
961 {
962 	uchar *p;
963 	int type, n, nn, i, nsid, nrandom, nciph;
964 
965 	for(;;) {
966 		p = tlsReadN(c, 4);
967 		if(p == nil)
968 			return 0;
969 		type = p[0];
970 		n = get24(p+1);
971 
972 		if(type != HHelloRequest)
973 			break;
974 		if(n != 0) {
975 			tlsError(c, EDecodeError, "invalid hello request during handshake");
976 			return 0;
977 		}
978 	}
979 
980 	if(n > sizeof(c->buf)) {
981 		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
982 		return 0;
983 	}
984 
985 	if(type == HSSL2ClientHello){
986 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
987 			This is sent by some clients that we must interoperate
988 			with, such as Java's JSSE and Microsoft's Internet Explorer. */
989 		p = tlsReadN(c, n);
990 		if(p == nil)
991 			return 0;
992 		md5(p, n, 0, &c->hsmd5);
993 		sha1(p, n, 0, &c->hssha1);
994 		m->tag = HClientHello;
995 		if(n < 22)
996 			goto Short;
997 		m->u.clientHello.version = get16(p+1);
998 		p += 3;
999 		n -= 3;
1000 		nn = get16(p); /* cipher_spec_len */
1001 		nsid = get16(p + 2);
1002 		nrandom = get16(p + 4);
1003 		p += 6;
1004 		n -= 6;
1005 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
1006 				|| nrandom < 16 || nn % 3)
1007 			goto Err;
1008 		if(c->trace && (n - nrandom != nn))
1009 			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1010 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1011 		nciph = 0;
1012 		for(i = 0; i < nn; i += 3)
1013 			if(p[i] == 0)
1014 				nciph++;
1015 		m->u.clientHello.ciphers = newints(nciph);
1016 		nciph = 0;
1017 		for(i = 0; i < nn; i += 3)
1018 			if(p[i] == 0)
1019 				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1020 		p += nn;
1021 		m->u.clientHello.sid = makebytes(nil, 0);
1022 		if(nrandom > RandomSize)
1023 			nrandom = RandomSize;
1024 		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1025 		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1026 		m->u.clientHello.compressors = newbytes(1);
1027 		m->u.clientHello.compressors->data[0] = CompressionNull;
1028 		goto Ok;
1029 	}
1030 
1031 	md5(p, 4, 0, &c->hsmd5);
1032 	sha1(p, 4, 0, &c->hssha1);
1033 
1034 	p = tlsReadN(c, n);
1035 	if(p == nil)
1036 		return 0;
1037 
1038 	md5(p, n, 0, &c->hsmd5);
1039 	sha1(p, n, 0, &c->hssha1);
1040 
1041 	m->tag = type;
1042 
1043 	switch(type) {
1044 	default:
1045 		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1046 		goto Err;
1047 	case HClientHello:
1048 		if(n < 2)
1049 			goto Short;
1050 		m->u.clientHello.version = get16(p);
1051 		p += 2;
1052 		n -= 2;
1053 
1054 		if(n < RandomSize)
1055 			goto Short;
1056 		memmove(m->u.clientHello.random, p, RandomSize);
1057 		p += RandomSize;
1058 		n -= RandomSize;
1059 		if(n < 1 || n < p[0]+1)
1060 			goto Short;
1061 		m->u.clientHello.sid = makebytes(p+1, p[0]);
1062 		p += m->u.clientHello.sid->len+1;
1063 		n -= m->u.clientHello.sid->len+1;
1064 
1065 		if(n < 2)
1066 			goto Short;
1067 		nn = get16(p);
1068 		p += 2;
1069 		n -= 2;
1070 
1071 		if((nn & 1) || n < nn || nn < 2)
1072 			goto Short;
1073 		m->u.clientHello.ciphers = newints(nn >> 1);
1074 		for(i = 0; i < nn; i += 2)
1075 			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1076 		p += nn;
1077 		n -= nn;
1078 
1079 		if(n < 1 || n < p[0]+1 || p[0] == 0)
1080 			goto Short;
1081 		nn = p[0];
1082 		m->u.clientHello.compressors = newbytes(nn);
1083 		memmove(m->u.clientHello.compressors->data, p+1, nn);
1084 		n -= nn + 1;
1085 		break;
1086 	case HServerHello:
1087 		if(n < 2)
1088 			goto Short;
1089 		m->u.serverHello.version = get16(p);
1090 		p += 2;
1091 		n -= 2;
1092 
1093 		if(n < RandomSize)
1094 			goto Short;
1095 		memmove(m->u.serverHello.random, p, RandomSize);
1096 		p += RandomSize;
1097 		n -= RandomSize;
1098 
1099 		if(n < 1 || n < p[0]+1)
1100 			goto Short;
1101 		m->u.serverHello.sid = makebytes(p+1, p[0]);
1102 		p += m->u.serverHello.sid->len+1;
1103 		n -= m->u.serverHello.sid->len+1;
1104 
1105 		if(n < 3)
1106 			goto Short;
1107 		m->u.serverHello.cipher = get16(p);
1108 		m->u.serverHello.compressor = p[2];
1109 		n -= 3;
1110 		break;
1111 	case HCertificate:
1112 		if(n < 3)
1113 			goto Short;
1114 		nn = get24(p);
1115 		p += 3;
1116 		n -= 3;
1117 		if(n != nn)
1118 			goto Short;
1119 		/* certs */
1120 		i = 0;
1121 		while(n > 0) {
1122 			if(n < 3)
1123 				goto Short;
1124 			nn = get24(p);
1125 			p += 3;
1126 			n -= 3;
1127 			if(nn > n)
1128 				goto Short;
1129 			m->u.certificate.ncert = i+1;
1130 			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1131 			m->u.certificate.certs[i] = makebytes(p, nn);
1132 			p += nn;
1133 			n -= nn;
1134 			i++;
1135 		}
1136 		break;
1137 	case HCertificateRequest:
1138 		if(n < 1)
1139 			goto Short;
1140 		nn = p[0];
1141 		p += 1;
1142 		n -= 1;
1143 		if(nn < 1 || nn > n)
1144 			goto Short;
1145 		m->u.certificateRequest.types = makebytes(p, nn);
1146 		p += nn;
1147 		n -= nn;
1148 		if(n < 2)
1149 			goto Short;
1150 		nn = get16(p);
1151 		p += 2;
1152 		n -= 2;
1153 		if(nn == 0 || n != nn)
1154 			goto Short;
1155 		/* cas */
1156 		i = 0;
1157 		while(n > 0) {
1158 			if(n < 2)
1159 				goto Short;
1160 			nn = get16(p);
1161 			p += 2;
1162 			n -= 2;
1163 			if(nn < 1 || nn > n)
1164 				goto Short;
1165 			m->u.certificateRequest.nca = i+1;
1166 			m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1167 			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1168 			p += nn;
1169 			n -= nn;
1170 			i++;
1171 		}
1172 		break;
1173 	case HServerHelloDone:
1174 		break;
1175 	case HClientKeyExchange:
1176 		/*
1177 		 * this message depends upon the encryption selected
1178 		 * assume rsa.
1179 		 */
1180 		if(c->version == SSL3Version)
1181 			nn = n;
1182 		else{
1183 			if(n < 2)
1184 				goto Short;
1185 			nn = get16(p);
1186 			p += 2;
1187 			n -= 2;
1188 		}
1189 		if(n < nn)
1190 			goto Short;
1191 		m->u.clientKeyExchange.key = makebytes(p, nn);
1192 		n -= nn;
1193 		break;
1194 	case HFinished:
1195 		m->u.finished.n = c->finished.n;
1196 		if(n < m->u.finished.n)
1197 			goto Short;
1198 		memmove(m->u.finished.verify, p, m->u.finished.n);
1199 		n -= m->u.finished.n;
1200 		break;
1201 	}
1202 
1203 	if(type != HClientHello && n != 0)
1204 		goto Short;
1205 Ok:
1206 	if(c->trace){
1207 		char *buf;
1208 		buf = emalloc(8000);
1209 		c->trace("recv %s", msgPrint(buf, 8000, m));
1210 		free(buf);
1211 	}
1212 	return 1;
1213 Short:
1214 	tlsError(c, EDecodeError, "handshake message has invalid length");
1215 Err:
1216 	msgClear(m);
1217 	return 0;
1218 }
1219 
1220 static void
1221 msgClear(Msg *m)
1222 {
1223 	int i;
1224 
1225 	switch(m->tag) {
1226 	default:
1227 		sysfatal("msgClear: unknown message type: %d\n", m->tag);
1228 	case HHelloRequest:
1229 		break;
1230 	case HClientHello:
1231 		freebytes(m->u.clientHello.sid);
1232 		freeints(m->u.clientHello.ciphers);
1233 		freebytes(m->u.clientHello.compressors);
1234 		break;
1235 	case HServerHello:
1236 		freebytes(m->u.clientHello.sid);
1237 		break;
1238 	case HCertificate:
1239 		for(i=0; i<m->u.certificate.ncert; i++)
1240 			freebytes(m->u.certificate.certs[i]);
1241 		free(m->u.certificate.certs);
1242 		break;
1243 	case HCertificateRequest:
1244 		freebytes(m->u.certificateRequest.types);
1245 		for(i=0; i<m->u.certificateRequest.nca; i++)
1246 			freebytes(m->u.certificateRequest.cas[i]);
1247 		free(m->u.certificateRequest.cas);
1248 		break;
1249 	case HServerHelloDone:
1250 		break;
1251 	case HClientKeyExchange:
1252 		freebytes(m->u.clientKeyExchange.key);
1253 		break;
1254 	case HFinished:
1255 		break;
1256 	}
1257 	memset(m, 0, sizeof(Msg));
1258 }
1259 
1260 static char *
1261 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1262 {
1263 	int i;
1264 
1265 	if(s0)
1266 		bs = seprint(bs, be, "%s", s0);
1267 	bs = seprint(bs, be, "[");
1268 	if(b == nil)
1269 		bs = seprint(bs, be, "nil");
1270 	else
1271 		for(i=0; i<b->len; i++)
1272 			bs = seprint(bs, be, "%.2x ", b->data[i]);
1273 	bs = seprint(bs, be, "]");
1274 	if(s1)
1275 		bs = seprint(bs, be, "%s", s1);
1276 	return bs;
1277 }
1278 
1279 static char *
1280 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1281 {
1282 	int i;
1283 
1284 	if(s0)
1285 		bs = seprint(bs, be, "%s", s0);
1286 	bs = seprint(bs, be, "[");
1287 	if(b == nil)
1288 		bs = seprint(bs, be, "nil");
1289 	else
1290 		for(i=0; i<b->len; i++)
1291 			bs = seprint(bs, be, "%x ", b->data[i]);
1292 	bs = seprint(bs, be, "]");
1293 	if(s1)
1294 		bs = seprint(bs, be, "%s", s1);
1295 	return bs;
1296 }
1297 
1298 static char*
1299 msgPrint(char *buf, int n, Msg *m)
1300 {
1301 	int i;
1302 	char *bs = buf, *be = buf+n;
1303 
1304 	switch(m->tag) {
1305 	default:
1306 		bs = seprint(bs, be, "unknown %d\n", m->tag);
1307 		break;
1308 	case HClientHello:
1309 		bs = seprint(bs, be, "ClientHello\n");
1310 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1311 		bs = seprint(bs, be, "\trandom: ");
1312 		for(i=0; i<RandomSize; i++)
1313 			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1314 		bs = seprint(bs, be, "\n");
1315 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1316 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1317 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1318 		break;
1319 	case HServerHello:
1320 		bs = seprint(bs, be, "ServerHello\n");
1321 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1322 		bs = seprint(bs, be, "\trandom: ");
1323 		for(i=0; i<RandomSize; i++)
1324 			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1325 		bs = seprint(bs, be, "\n");
1326 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1327 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1328 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1329 		break;
1330 	case HCertificate:
1331 		bs = seprint(bs, be, "Certificate\n");
1332 		for(i=0; i<m->u.certificate.ncert; i++)
1333 			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1334 		break;
1335 	case HCertificateRequest:
1336 		bs = seprint(bs, be, "CertificateRequest\n");
1337 		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1338 		bs = seprint(bs, be, "\tcertificateauthorities\n");
1339 		for(i=0; i<m->u.certificateRequest.nca; i++)
1340 			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1341 		break;
1342 	case HServerHelloDone:
1343 		bs = seprint(bs, be, "ServerHelloDone\n");
1344 		break;
1345 	case HClientKeyExchange:
1346 		bs = seprint(bs, be, "HClientKeyExchange\n");
1347 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1348 		break;
1349 	case HFinished:
1350 		bs = seprint(bs, be, "HFinished\n");
1351 		for(i=0; i<m->u.finished.n; i++)
1352 			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1353 		bs = seprint(bs, be, "\n");
1354 		break;
1355 	}
1356 	USED(bs);
1357 	return buf;
1358 }
1359 
1360 static void
1361 tlsError(TlsConnection *c, int err, char *fmt, ...)
1362 {
1363 	char msg[512];
1364 	va_list arg;
1365 
1366 	va_start(arg, fmt);
1367 	vseprint(msg, msg+sizeof(msg), fmt, arg);
1368 	va_end(arg);
1369 	if(c->trace)
1370 		c->trace("tlsError: %s\n", msg);
1371 	else if(c->erred)
1372 		fprint(2, "double error: %r, %s", msg);
1373 	else
1374 		werrstr("tls: local %s", msg);
1375 	c->erred = 1;
1376 	fprint(c->ctl, "alert %d", err);
1377 }
1378 
1379 // commit to specific version number
1380 static int
1381 setVersion(TlsConnection *c, int version)
1382 {
1383 	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1384 		return -1;
1385 	if(version > c->version)
1386 		version = c->version;
1387 	if(version == SSL3Version) {
1388 		c->version = version;
1389 		c->finished.n = SSL3FinishedLen;
1390 	}else if(version == TLSVersion){
1391 		c->version = version;
1392 		c->finished.n = TLSFinishedLen;
1393 	}else
1394 		return -1;
1395 	c->verset = 1;
1396 	return fprint(c->ctl, "version 0x%x", version);
1397 }
1398 
1399 // confirm that received Finished message matches the expected value
1400 static int
1401 finishedMatch(TlsConnection *c, Finished *f)
1402 {
1403 	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1404 }
1405 
1406 // free memory associated with TlsConnection struct
1407 //		(but don't close the TLS channel itself)
1408 static void
1409 tlsConnectionFree(TlsConnection *c)
1410 {
1411 	tlsSecClose(c->sec);
1412 	freebytes(c->sid);
1413 	freebytes(c->cert);
1414 	memset(c, 0, sizeof(c));
1415 	free(c);
1416 }
1417 
1418 
1419 //================= cipher choices ========================
1420 
/*
 * weakCipher[id] is 1 for each cipher suite id below CipherMax that is
 * considered weak.  These are the NULL-encryption, export-grade
 * (40-bit) and anonymous Diffie-Hellman suites.
 * okCipher uses the table to tell "client offered only weak suites"
 * (-2) apart from "no usable suite in common" (-1).
 * Entries are indexed by TLS cipher suite id, in order 0x00..0x1b.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1452 
1453 static int
1454 setAlgs(TlsConnection *c, int a)
1455 {
1456 	int i;
1457 
1458 	for(i = 0; i < nelem(cipherAlgs); i++){
1459 		if(cipherAlgs[i].tlsid == a){
1460 			c->enc = cipherAlgs[i].enc;
1461 			c->digest = cipherAlgs[i].digest;
1462 			c->nsecret = cipherAlgs[i].nsecret;
1463 			if(c->nsecret > MaxKeyData)
1464 				return 0;
1465 			return 1;
1466 		}
1467 	}
1468 	return 0;
1469 }
1470 
1471 static int
1472 okCipher(Ints *cv)
1473 {
1474 	int weak, i, j, c;
1475 
1476 	weak = 1;
1477 	for(i = 0; i < cv->len; i++) {
1478 		c = cv->data[i];
1479 		if(c >= CipherMax)
1480 			weak = 0;
1481 		else
1482 			weak &= weakCipher[c];
1483 		for(j = 0; j < nelem(cipherAlgs); j++)
1484 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1485 				return c;
1486 	}
1487 	if(weak)
1488 		return -2;
1489 	return -1;
1490 }
1491 
1492 static int
1493 okCompression(Bytes *cv)
1494 {
1495 	int i, j, c;
1496 
1497 	for(i = 0; i < cv->len; i++) {
1498 		c = cv->data[i];
1499 		for(j = 0; j < nelem(compressors); j++) {
1500 			if(compressors[j] == c)
1501 				return c;
1502 		}
1503 	}
1504 	return -1;
1505 }
1506 
// ciphLock serialises initCiphers.  nciphers caches the number of
// cipherAlgs entries found usable; while it is 0, initCiphers
// probes the kernel again.
static Lock	ciphLock;
static int	nciphers;
1509 
1510 static int
1511 initCiphers(void)
1512 {
1513 	enum {MaxAlgF = 1024, MaxAlgs = 10};
1514 	char s[MaxAlgF], *flds[MaxAlgs];
1515 	int i, j, n, ok;
1516 
1517 	lock(&ciphLock);
1518 	if(nciphers){
1519 		unlock(&ciphLock);
1520 		return nciphers;
1521 	}
1522 	j = open("#a/tls/encalgs", OREAD);
1523 	if(j < 0){
1524 		werrstr("can't open #a/tls/encalgs: %r");
1525 		return 0;
1526 	}
1527 	n = read(j, s, MaxAlgF-1);
1528 	close(j);
1529 	if(n <= 0){
1530 		werrstr("nothing in #a/tls/encalgs: %r");
1531 		return 0;
1532 	}
1533 	s[n] = 0;
1534 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1535 	for(i = 0; i < nelem(cipherAlgs); i++){
1536 		ok = 0;
1537 		for(j = 0; j < n; j++){
1538 			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1539 				ok = 1;
1540 				break;
1541 			}
1542 		}
1543 		cipherAlgs[i].ok = ok;
1544 	}
1545 
1546 	j = open("#a/tls/hashalgs", OREAD);
1547 	if(j < 0){
1548 		werrstr("can't open #a/tls/hashalgs: %r");
1549 		return 0;
1550 	}
1551 	n = read(j, s, MaxAlgF-1);
1552 	close(j);
1553 	if(n <= 0){
1554 		werrstr("nothing in #a/tls/hashalgs: %r");
1555 		return 0;
1556 	}
1557 	s[n] = 0;
1558 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1559 	for(i = 0; i < nelem(cipherAlgs); i++){
1560 		ok = 0;
1561 		for(j = 0; j < n; j++){
1562 			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1563 				ok = 1;
1564 				break;
1565 			}
1566 		}
1567 		cipherAlgs[i].ok &= ok;
1568 		if(cipherAlgs[i].ok)
1569 			nciphers++;
1570 	}
1571 	unlock(&ciphLock);
1572 	return nciphers;
1573 }
1574 
1575 static Ints*
1576 makeciphers(void)
1577 {
1578 	Ints *is;
1579 	int i, j;
1580 
1581 	is = newints(nciphers);
1582 	j = 0;
1583 	for(i = 0; i < nelem(cipherAlgs); i++){
1584 		if(cipherAlgs[i].ok)
1585 			is->data[j++] = cipherAlgs[i].tlsid;
1586 	}
1587 	return is;
1588 }
1589 
1590 
1591 
1592 //================= security functions ========================
1593 
1594 // given X.509 certificate, set up connection to factotum
1595 //	for using corresponding private key
1596 static AuthRpc*
1597 factotum_rsa_open(uchar *cert, int certlen)
1598 {
1599 	int afd;
1600 	char *s;
1601 	mpint *pub = nil;
1602 	RSApub *rsapub;
1603 	AuthRpc *rpc;
1604 
1605 	// start talking to factotum
1606 	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1607 		return nil;
1608 	if((rpc = auth_allocrpc(afd)) == nil){
1609 		close(afd);
1610 		return nil;
1611 	}
1612 	s = "proto=rsa service=tls role=client";
1613 	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1614 		factotum_rsa_close(rpc);
1615 		return nil;
1616 	}
1617 
1618 	// roll factotum keyring around to match certificate
1619 	rsapub = X509toRSApub(cert, certlen, nil, 0);
1620 	while(1){
1621 		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1622 			factotum_rsa_close(rpc);
1623 			rpc = nil;
1624 			goto done;
1625 		}
1626 		pub = strtomp(rpc->arg, nil, 16, nil);
1627 		assert(pub != nil);
1628 		if(mpcmp(pub,rsapub->n) == 0)
1629 			break;
1630 	}
1631 done:
1632 	mpfree(pub);
1633 	rsapubfree(rsapub);
1634 	return rpc;
1635 }
1636 
1637 static mpint*
1638 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1639 {
1640 	char *p;
1641 	int rv;
1642 
1643 	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1644 		return nil;
1645 	rv = auth_rpc(rpc, "write", p, strlen(p));
1646 	free(p);
1647 	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1648 		return nil;
1649 	mpfree(cipher);
1650 	return strtomp(rpc->arg, nil, 16, nil);
1651 }
1652 
1653 static void
1654 factotum_rsa_close(AuthRpc*rpc)
1655 {
1656 	if(!rpc)
1657 		return;
1658 	close(rpc->afd);
1659 	auth_freerpc(rpc);
1660 }
1661 
/*
 * tlsPmd5 XORs nbuf bytes of the HMAC-MD5 data expansion into buf,
 * using seed = label + seed0 + seed1:
 *	A(1) = HMAC_MD5(key, seed),  A(i+1) = HMAC_MD5(key, A(i))
 *	out  = HMAC_MD5(key, A(1) + seed) || HMAC_MD5(key, A(2) + seed) || ...
 * The output is XORed into buf's existing contents rather than stored,
 * so the caller must initialise buf (tlsPRF zeroes it first).
 */
static void
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[MD5dlen], tmp[MD5dlen];
	int i, n;
	MD5state *s;

	// generate a1
	// (the seed is fed in three pieces, chaining the HMAC state s)
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
	hmac_md5(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// tmp = HMAC(key, A(i) + label + seed0 + seed1)
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
		// XOR up to one digest's worth into the output
		n = MD5dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, MD5dlen);
	}
}
1690 
/*
 * tlsPsha1 is tlsPmd5 with HMAC-SHA1: it XORs nbuf bytes of the
 * HMAC-SHA1 data expansion into buf, using seed = label + seed0 + seed1:
 *	A(1) = HMAC_SHA1(key, seed),  A(i+1) = HMAC_SHA1(key, A(i))
 *	out  = HMAC_SHA1(key, A(1) + seed) || HMAC_SHA1(key, A(2) + seed) || ...
 * The output is XORed into buf's existing contents.
 */
static void
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[SHA1dlen], tmp[SHA1dlen];
	int i, n;
	SHAstate *s;

	// generate a1
	// (the seed is fed in three pieces, chaining the HMAC state s)
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// tmp = HMAC(key, A(i) + label + seed0 + seed1)
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
		// XOR up to one digest's worth into the output
		n = SHA1dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, SHA1dlen);
	}
}
1719 
1720 // fill buf with md5(args)^sha1(args)
1721 static void
1722 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1723 {
1724 	int i;
1725 	int nlabel = strlen(label);
1726 	int n = (nkey + 1) >> 1;
1727 
1728 	for(i = 0; i < nbuf; i++)
1729 		buf[i] = 0;
1730 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1731 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1732 }
1733 
1734 /*
1735  * for setting server session id's
1736  */
1737 static Lock	sidLock;
1738 static long	maxSid = 1;
1739 
1740 /* the keys are verified to have the same public components
1741  * and to function correctly with pkcs 1 encryption and decryption. */
1742 static TlsSec*
1743 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1744 {
1745 	TlsSec *sec = emalloc(sizeof(*sec));
1746 
1747 	USED(csid); USED(ncsid);  // ignore csid for now
1748 
1749 	memmove(sec->crandom, crandom, RandomSize);
1750 	sec->clientVers = cvers;
1751 
1752 	put32(sec->srandom, time(0));
1753 	genrandom(sec->srandom+4, RandomSize-4);
1754 	memmove(srandom, sec->srandom, RandomSize);
1755 
1756 	/*
1757 	 * make up a unique sid: use our pid, and and incrementing id
1758 	 * can signal no sid by setting nssid to 0.
1759 	 */
1760 	memset(ssid, 0, SidSize);
1761 	put32(ssid, getpid());
1762 	lock(&sidLock);
1763 	put32(ssid+4, maxSid++);
1764 	unlock(&sidLock);
1765 	*nssid = SidSize;
1766 	return sec;
1767 }
1768 
1769 static int
1770 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1771 {
1772 	if(epm != nil){
1773 		if(setVers(sec, vers) < 0)
1774 			goto Err;
1775 		serverMasterSecret(sec, epm, nepm);
1776 	}else if(sec->vers != vers){
1777 		werrstr("mismatched session versions");
1778 		goto Err;
1779 	}
1780 	setSecrets(sec, kd, nkd);
1781 	return 0;
1782 Err:
1783 	sec->ok = -1;
1784 	return -1;
1785 }
1786 
1787 static TlsSec*
1788 tlsSecInitc(int cvers, uchar *crandom)
1789 {
1790 	TlsSec *sec = emalloc(sizeof(*sec));
1791 	sec->clientVers = cvers;
1792 	put32(sec->crandom, time(0));
1793 	genrandom(sec->crandom+4, RandomSize-4);
1794 	memmove(crandom, sec->crandom, RandomSize);
1795 	return sec;
1796 }
1797 
1798 static int
1799 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1800 {
1801 	RSApub *pub;
1802 
1803 	pub = nil;
1804 
1805 	USED(sid);
1806 	USED(nsid);
1807 
1808 	memmove(sec->srandom, srandom, RandomSize);
1809 
1810 	if(setVers(sec, vers) < 0)
1811 		goto Err;
1812 
1813 	pub = X509toRSApub(cert, ncert, nil, 0);
1814 	if(pub == nil){
1815 		werrstr("invalid x509/rsa certificate");
1816 		goto Err;
1817 	}
1818 	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1819 		goto Err;
1820 	rsapubfree(pub);
1821 	setSecrets(sec, kd, nkd);
1822 	return 0;
1823 
1824 Err:
1825 	if(pub != nil)
1826 		rsapubfree(pub);
1827 	sec->ok = -1;
1828 	return -1;
1829 }
1830 
1831 static int
1832 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1833 {
1834 	if(sec->nfin != nfin){
1835 		sec->ok = -1;
1836 		werrstr("invalid finished exchange");
1837 		return -1;
1838 	}
1839 	md5.malloced = 0;
1840 	sha1.malloced = 0;
1841 	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1842 	return 1;
1843 }
1844 
1845 static void
1846 tlsSecOk(TlsSec *sec)
1847 {
1848 	if(sec->ok == 0)
1849 		sec->ok = 1;
1850 }
1851 
1852 static void
1853 tlsSecKill(TlsSec *sec)
1854 {
1855 	if(!sec)
1856 		return;
1857 	factotum_rsa_close(sec->rpc);
1858 	sec->ok = -1;
1859 }
1860 
1861 static void
1862 tlsSecClose(TlsSec *sec)
1863 {
1864 	if(!sec)
1865 		return;
1866 	factotum_rsa_close(sec->rpc);
1867 	free(sec->server);
1868 	free(sec);
1869 }
1870 
1871 static int
1872 setVers(TlsSec *sec, int v)
1873 {
1874 	if(v == SSL3Version){
1875 		sec->setFinished = sslSetFinished;
1876 		sec->nfin = SSL3FinishedLen;
1877 		sec->prf = sslPRF;
1878 	}else if(v == TLSVersion){
1879 		sec->setFinished = tlsSetFinished;
1880 		sec->nfin = TLSFinishedLen;
1881 		sec->prf = tlsPRF;
1882 	}else{
1883 		werrstr("invalid version");
1884 		return -1;
1885 	}
1886 	sec->vers = v;
1887 	return 0;
1888 }
1889 
1890 /*
1891  * generate secret keys from the master secret.
1892  *
1893  * different crypto selections will require different amounts
1894  * of key expansion and use of key expansion data,
1895  * but it's all generated using the same function.
1896  */
1897 static void
1898 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1899 {
1900 	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1901 			sec->srandom, RandomSize, sec->crandom, RandomSize);
1902 }
1903 
1904 /*
1905  * set the master secret from the pre-master secret.
1906  */
1907 static void
1908 setMasterSecret(TlsSec *sec, Bytes *pm)
1909 {
1910 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1911 			sec->crandom, RandomSize, sec->srandom, RandomSize);
1912 }
1913 
1914 static void
1915 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1916 {
1917 	Bytes *pm;
1918 
1919 	pm = pkcs1_decrypt(sec, epm, nepm);
1920 
1921 	// if the client messed up, just continue as if everything is ok,
1922 	// to prevent attacks to check for correctly formatted messages.
1923 	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1924 	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1925 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1926 			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1927 		sec->ok = -1;
1928 		if(pm != nil)
1929 			freebytes(pm);
1930 		pm = newbytes(MasterSecretSize);
1931 		genrandom(pm->data, MasterSecretSize);
1932 	}
1933 	setMasterSecret(sec, pm);
1934 	memset(pm->data, 0, pm->len);
1935 	freebytes(pm);
1936 }
1937 
1938 static int
1939 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1940 {
1941 	Bytes *pm, *key;
1942 
1943 	pm = newbytes(MasterSecretSize);
1944 	put16(pm->data, sec->clientVers);
1945 	genrandom(pm->data+2, MasterSecretSize - 2);
1946 
1947 	setMasterSecret(sec, pm);
1948 
1949 	key = pkcs1_encrypt(pm, pub, 2);
1950 	memset(pm->data, 0, pm->len);
1951 	freebytes(pm);
1952 	if(key == nil){
1953 		werrstr("tls pkcs1_encrypt failed");
1954 		return -1;
1955 	}
1956 
1957 	*nepm = key->len;
1958 	*epm = malloc(*nepm);
1959 	if(*epm == nil){
1960 		freebytes(key);
1961 		werrstr("out of memory");
1962 		return -1;
1963 	}
1964 	memmove(*epm, key->data, *nepm);
1965 
1966 	freebytes(key);
1967 
1968 	return 1;
1969 }
1970 
/*
 * sslSetFinished computes the SSL 3.0 Finished verify data into
 * finished (MD5dlen + SHA1dlen bytes).  The label is "CLNT" when
 * isClient is set, "SRVR" otherwise:
 *	md5 part  = MD5(master + pad2 + MD5(handshake + label + master + pad1))
 *	sha1 part = SHA1(master + pad2 + SHA1(handshake + label + master + pad1))
 * pad1 is bytes of 0x36 and pad2 bytes of 0x5C: 48 of each for MD5,
 * 40 for SHA1.  hsmd5/hssha1 are by-value copies of the running
 * handshake hashes and are finalized here.
 */
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// inner MD5: handshake + label + master + pad1(48)
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	// outer MD5: master + pad2(48) + inner
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// inner SHA1: handshake + label + master + pad1(40)
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	// outer SHA1: master + pad2(40) + inner, stored after the MD5 part
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
2003 
2004 // fill "finished" arg with md5(args)^sha1(args)
2005 static void
2006 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2007 {
2008 	uchar h0[MD5dlen], h1[SHA1dlen];
2009 	char *label;
2010 
2011 	// get current hash value, but allow further messages to be hashed in
2012 	md5(nil, 0, h0, &hsmd5);
2013 	sha1(nil, 0, h1, &hssha1);
2014 
2015 	if(isClient)
2016 		label = "client finished";
2017 	else
2018 		label = "server finished";
2019 	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2020 }
2021 
/*
 * sslPRF is the SSL 3.0 key-derivation function.  It fills buf with
 * successive 16-byte blocks
 *	MD5(key + SHA1(L + key + seed0 + seed1))
 * where L is "A", "BB", "CCC", ... for the 1st, 2nd, 3rd block.
 * label is unused; it is kept only to share tlsPRF's signature.
 * At most 26 blocks (416 bytes) are produced.  Any of buf beyond that
 * is left unwritten.
 */
static void
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	DigestState *s;
	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
	int i, n, len;

	USED(label);
	len = 1;
	while(nbuf > 0){
		// the letter prefix runs out after 'Z'
		if(len > 26)
			return;
		for(i = 0; i < len; i++)
			tmp[i] = 'A' - 1 + len;
		// inner: SHA1(letters + key + seed0 + seed1)
		s = sha1(tmp, len, nil, nil);
		s = sha1(key, nkey, nil, s);
		s = sha1(seed0, nseed0, nil, s);
		sha1(seed1, nseed1, sha1dig, s);
		// outer: MD5(key + inner)
		s = md5(key, nkey, nil, nil);
		md5(sha1dig, SHA1dlen, md5dig, s);
		n = MD5dlen;
		if(n > nbuf)
			n = nbuf;
		memmove(buf, md5dig, n);
		buf += n;
		nbuf -= n;
		len++;
	}
}
2051 
2052 static mpint*
2053 bytestomp(Bytes* bytes)
2054 {
2055 	mpint* ans;
2056 
2057 	ans = betomp(bytes->data, bytes->len, nil);
2058 	return ans;
2059 }
2060 
2061 /*
2062  * Convert mpint* to Bytes, putting high order byte first.
2063  */
2064 static Bytes*
2065 mptobytes(mpint* big)
2066 {
2067 	int n, m;
2068 	uchar *a;
2069 	Bytes* ans;
2070 
2071 	a = nil;
2072 	n = (mpsignif(big)+7)/8;
2073 	m = mptobe(big, nil, n, &a);
2074 	ans = makebytes(a, m);
2075 	if(a != nil)
2076 		free(a);
2077 	return ans;
2078 }
2079 
2080 // Do RSA computation on block according to key, and pad
2081 // result on left with zeros to make it modlen long.
2082 static Bytes*
2083 rsacomp(Bytes* block, RSApub* key, int modlen)
2084 {
2085 	mpint *x, *y;
2086 	Bytes *a, *ybytes;
2087 	int ylen;
2088 
2089 	x = bytestomp(block);
2090 	y = rsaencrypt(key, x, nil);
2091 	mpfree(x);
2092 	ybytes = mptobytes(y);
2093 	ylen = ybytes->len;
2094 
2095 	if(ylen < modlen) {
2096 		a = newbytes(modlen);
2097 		memset(a->data, 0, modlen-ylen);
2098 		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2099 		freebytes(ybytes);
2100 		ybytes = a;
2101 	}
2102 	else if(ylen > modlen) {
2103 		// assume it has leading zeros (mod should make it so)
2104 		a = newbytes(modlen);
2105 		memmove(a->data, ybytes->data, modlen);
2106 		freebytes(ybytes);
2107 		ybytes = a;
2108 	}
2109 	mpfree(y);
2110 	return ybytes;
2111 }
2112 
2113 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2114 static Bytes*
2115 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2116 {
2117 	Bytes *pad, *eb, *ans;
2118 	int i, dlen, padlen, modlen;
2119 
2120 	modlen = (mpsignif(key->n)+7)/8;
2121 	dlen = data->len;
2122 	if(modlen < 12 || dlen > modlen - 11)
2123 		return nil;
2124 	padlen = modlen - 3 - dlen;
2125 	pad = newbytes(padlen);
2126 	genrandom(pad->data, padlen);
2127 	for(i = 0; i < padlen; i++) {
2128 		if(blocktype == 0)
2129 			pad->data[i] = 0;
2130 		else if(blocktype == 1)
2131 			pad->data[i] = 255;
2132 		else if(pad->data[i] == 0)
2133 			pad->data[i] = 1;
2134 	}
2135 	eb = newbytes(modlen);
2136 	eb->data[0] = 0;
2137 	eb->data[1] = blocktype;
2138 	memmove(eb->data+2, pad->data, padlen);
2139 	eb->data[padlen+2] = 0;
2140 	memmove(eb->data+padlen+3, data->data, dlen);
2141 	ans = rsacomp(eb, key, modlen);
2142 	freebytes(eb);
2143 	freebytes(pad);
2144 	return ans;
2145 }
2146 
2147 // decrypt data according to PKCS#1, with given key.
2148 // expect a block type of 2.
2149 static Bytes*
2150 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2151 {
2152 	Bytes *eb, *ans = nil;
2153 	int i, modlen;
2154 	mpint *x, *y;
2155 
2156 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2157 	if(nepm != modlen)
2158 		return nil;
2159 	x = betomp(epm, nepm, nil);
2160 	y = factotum_rsa_decrypt(sec->rpc, x);
2161 	if(y == nil)
2162 		return nil;
2163 	eb = mptobytes(y);
2164 	if(eb->len < modlen){ // pad on left with zeros
2165 		ans = newbytes(modlen);
2166 		memset(ans->data, 0, modlen-eb->len);
2167 		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2168 		freebytes(eb);
2169 		eb = ans;
2170 	}
2171 	if(eb->data[0] == 0 && eb->data[1] == 2) {
2172 		for(i = 2; i < modlen; i++)
2173 			if(eb->data[i] == 0)
2174 				break;
2175 		if(i < modlen - 1)
2176 			ans = makebytes(eb->data+i+1, modlen-(i+1));
2177 	}
2178 	freebytes(eb);
2179 	return ans;
2180 }
2181 
2182 
2183 //================= general utility functions ========================
2184 
2185 static void *
2186 emalloc(int n)
2187 {
2188 	void *p;
2189 	if(n==0)
2190 		n=1;
2191 	p = malloc(n);
2192 	if(p == nil){
2193 		exits("out of memory");
2194 	}
2195 	memset(p, 0, n);
2196 	return p;
2197 }
2198 
2199 static void *
2200 erealloc(void *ReallocP, int ReallocN)
2201 {
2202 	if(ReallocN == 0)
2203 		ReallocN = 1;
2204 	if(!ReallocP)
2205 		ReallocP = emalloc(ReallocN);
2206 	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2207 		exits("out of memory");
2208 	}
2209 	return(ReallocP);
2210 }
2211 
2212 static void
2213 put32(uchar *p, u32int x)
2214 {
2215 	p[0] = x>>24;
2216 	p[1] = x>>16;
2217 	p[2] = x>>8;
2218 	p[3] = x;
2219 }
2220 
2221 static void
2222 put24(uchar *p, int x)
2223 {
2224 	p[0] = x>>16;
2225 	p[1] = x>>8;
2226 	p[2] = x;
2227 }
2228 
2229 static void
2230 put16(uchar *p, int x)
2231 {
2232 	p[0] = x>>8;
2233 	p[1] = x;
2234 }
2235 
2236 static u32int
2237 get32(uchar *p)
2238 {
2239 	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2240 }
2241 
2242 static int
2243 get24(uchar *p)
2244 {
2245 	return (p[0]<<16)|(p[1]<<8)|p[2];
2246 }
2247 
2248 static int
2249 get16(uchar *p)
2250 {
2251 	return (p[0]<<8)|p[1];
2252 }
2253 
/* ANSI offsetof(): byte offset of member x within struct type s.
 * newbytes and newints use it to size allocations that end in a
 * variable-length data[] array. */
#define OFFSET(x, s) ((int)(&(((s*)0)->x)))
2256 
2257 /*
2258  * malloc and return a new Bytes structure capable of
2259  * holding len bytes. (len >= 0)
2260  * Used to use crypt_malloc, which aborts if malloc fails.
2261  */
2262 static Bytes*
2263 newbytes(int len)
2264 {
2265 	Bytes* ans;
2266 
2267 	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2268 	ans->len = len;
2269 	return ans;
2270 }
2271 
2272 /*
2273  * newbytes(len), with data initialized from buf
2274  */
2275 static Bytes*
2276 makebytes(uchar* buf, int len)
2277 {
2278 	Bytes* ans;
2279 
2280 	ans = newbytes(len);
2281 	memmove(ans->data, buf, len);
2282 	return ans;
2283 }
2284 
2285 static void
2286 freebytes(Bytes* b)
2287 {
2288 	if(b != nil)
2289 		free(b);
2290 }
2291 
2292 /* len is number of ints */
2293 static Ints*
2294 newints(int len)
2295 {
2296 	Ints* ans;
2297 
2298 	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2299 	ans->len = len;
2300 	return ans;
2301 }
2302 
2303 static Ints*
2304 makeints(int* buf, int len)
2305 {
2306 	Ints* ans;
2307 
2308 	ans = newints(len);
2309 	if(len > 0)
2310 		memmove(ans->data, buf, len*sizeof(int));
2311 	return ans;
2312 }
2313 
2314 static void
2315 freeints(Ints* b)
2316 {
2317 	if(b != nil)
2318 		free(b);
2319 }
2320