xref: /plan9/sys/src/libsec/port/tlshand.c (revision ad6ca847b1a6a504acb0003cd6c5c6d92687369b)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7 
// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encrypt algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16 
17 enum {
18 	TLSFinishedLen = 12,
19 	SSL3FinishedLen = MD5dlen+SHA1dlen,
20 	MaxKeyData = 136,	// amount of secret we may need
21 	MaxChunk = 1<<14,
22 	RandomSize = 32,
23 	SidSize = 32,
24 	MasterSecretSize = 48,
25 	AQueue = 0,
26 	AFlush = 1,
27 };
28 
29 typedef struct TlsSec TlsSec;
30 
31 typedef struct Bytes{
32 	int len;
33 	uchar data[1];  // [len]
34 } Bytes;
35 
// counted vector of ints (used for cipher suite lists)
typedef struct Ints {
	int	len;
	int	data[1];	// [len]; presumably over-allocated to len entries by the allocator
} Ints;
40 
// one supported cipher suite: algorithm names plus key material needs
typedef struct Algs {
	char	*enc;		// encryption algorithm name
	char	*digest;	// digest algorithm name
	int	nsecret;	// bytes of secret data needed to initialize keys
	int	tlsid;		// cipher suite number
	int	ok;
} Algs;
48 
49 typedef struct Finished{
50 	uchar verify[SSL3FinishedLen];
51 	int n;
52 } Finished;
53 
54 typedef struct TlsConnection{
55 	TlsSec *sec;	// security management goo
56 	int hand, ctl;	// record layer file descriptors
57 	int erred;		// set when tlsError called
58 	int (*trace)(char*fmt, ...); // for debugging
59 	int version;	// protocol we are speaking
60 	int verset;		// version has been set
61 	int ver2hi;		// server got a version 2 hello
62 	int isClient;	// is this the client or server?
63 	Bytes *sid;		// SessionID
64 	Bytes *cert;	// only last - no chain
65 
66 	Lock statelk;
67 	int state;		// must be set using setstate
68 
69 	// input buffer for handshake messages
70 	uchar buf[MaxChunk+2048];
71 	uchar *rp, *ep;
72 
73 	uchar crandom[RandomSize];	// client random
74 	uchar srandom[RandomSize];	// server random
75 	int clientVersion;	// version in ClientHello
76 	char *digest;	// name of digest algorithm to use
77 	char *enc;		// name of encryption algorithm to use
78 	int nsecret;	// amount of secret data to init keys
79 
80 	// for finished messages
81 	MD5state	hsmd5;	// handshake hash
82 	SHAstate	hssha1;	// handshake hash
83 	Finished	finished;
84 } TlsConnection;
85 
86 typedef struct Msg{
87 	int tag;
88 	union {
89 		struct {
90 			int version;
91 			uchar 	random[RandomSize];
92 			Bytes*	sid;
93 			Ints*	ciphers;
94 			Bytes*	compressors;
95 		} clientHello;
96 		struct {
97 			int version;
98 			uchar 	random[RandomSize];
99 			Bytes*	sid;
100 			int cipher;
101 			int compressor;
102 		} serverHello;
103 		struct {
104 			int ncert;
105 			Bytes **certs;
106 		} certificate;
107 		struct {
108 			Bytes *types;
109 			int nca;
110 			Bytes **cas;
111 		} certificateRequest;
112 		struct {
113 			Bytes *key;
114 		} clientKeyExchange;
115 		Finished finished;
116 	} u;
117 } Msg;
118 
119 typedef struct TlsSec{
120 	char *server;	// name of remote; nil for server
121 	int ok;	// <0 killed; == 0 in progress; >0 reusable
122 	RSApub *rsapub;
123 	AuthRpc *rpc;	// factotum for rsa private key
124 	uchar sec[MasterSecretSize];	// master secret
125 	uchar crandom[RandomSize];	// client random
126 	uchar srandom[RandomSize];	// server random
127 	int clientVers;		// version in ClientHello
128 	int vers;			// final version
129 	// byte generation and handshake checksum
130 	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131 	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132 	int nfin;
133 } TlsSec;
134 
135 
// protocol version numbers
enum {
	TLSVersion	= 0x0301,
	SSL3Version	= 0x0300,
	ProtocolVersion	= 0x0301,	// maximum version we speak
	MinProtoVersion	= 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};
143 
// handshake type (first byte of every handshake message)
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21
};
159 
// alerts (codes passed to tlsError)
enum {
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,

	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,

	EDecompressionFailure	= 30,

	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,

	EDecodeError		= 50,
	EDecryptError		= 51,

	EExportRestriction	= 60,

	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,

	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,

	EMax			= 256
};
188 
// cipher suites (16-bit identifiers carried in the hello messages)
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,

	// RSA key exchange
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000a,

	// fixed Diffie-Hellman
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000b,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000c,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000d,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000e,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000f,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,

	// ephemeral Diffie-Hellman
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,

	// anonymous Diffie-Hellman
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001a,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001b,

	// AES suites
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003a,
	CipherMax
};
234 
// compression methods
enum {
	CompressionNull	= 0,
	CompressionMax	= 1
};
240 
241 static Algs cipherAlgs[] = {
242 	{"rc4_128", "md5", 2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243 	{"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244 	{"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245 	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
246 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}
247 };
248 
249 static uchar compressors[] = {
250 	CompressionNull,
251 };
252 
253 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
254 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
255 
256 static void	msgClear(Msg *m);
257 static char* msgPrint(char *buf, int n, Msg *m);
258 static int	msgRecv(TlsConnection *c, Msg *m);
259 static int	msgSend(TlsConnection *c, Msg *m, int act);
260 static void	tlsError(TlsConnection *c, int err, char *msg, ...);
261 #pragma	varargck argpos	tlsError 3
262 static int setVersion(TlsConnection *c, int version);
263 static int finishedMatch(TlsConnection *c, Finished *f);
264 static void tlsConnectionFree(TlsConnection *c);
265 
266 static int setAlgs(TlsConnection *c, int a);
267 static int okCipher(Ints *cv);
268 static int okCompression(Bytes *cv);
269 static int initCiphers(void);
270 static Ints* makeciphers(void);
271 
272 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
273 static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
274 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
275 static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
276 static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
277 static void	tlsSecOk(TlsSec *sec);
278 static void	tlsSecKill(TlsSec *sec);
279 static void	tlsSecClose(TlsSec *sec);
280 static void	setMasterSecret(TlsSec *sec, Bytes *pm);
281 static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
282 static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
283 static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
284 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
285 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
286 static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
287 static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
288 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
289 			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
290 static int setVers(TlsSec *sec, int version);
291 
292 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
293 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
294 static void factotum_rsa_close(AuthRpc*rpc);
295 
296 static void* emalloc(int);
297 static void* erealloc(void*, int);
298 static void put32(uchar *p, u32int);
299 static void put24(uchar *p, int);
300 static void put16(uchar *p, int);
301 static u32int get32(uchar *p);
302 static int get24(uchar *p);
303 static int get16(uchar *p);
304 static Bytes* newbytes(int len);
305 static Bytes* makebytes(uchar* buf, int len);
306 static void freebytes(Bytes* b);
307 static Ints* newints(int len);
308 static Ints* makeints(int* buf, int len);
309 static void freeints(Ints* b);
310 
311 //================= client/server ========================
312 
313 //	push TLS onto fd, returning new (application) file descriptor
314 //		or -1 if error.
315 int
tlsServer(int fd,TLSconn * conn)316 tlsServer(int fd, TLSconn *conn)
317 {
318 	char buf[8];
319 	char dname[64];
320 	int n, data, ctl, hand;
321 	TlsConnection *tls;
322 
323 	if(conn == nil)
324 		return -1;
325 	ctl = open("#a/tls/clone", ORDWR);
326 	if(ctl < 0)
327 		return -1;
328 	n = read(ctl, buf, sizeof(buf)-1);
329 	if(n < 0){
330 		close(ctl);
331 		return -1;
332 	}
333 	buf[n] = 0;
334 	sprint(conn->dir, "#a/tls/%s", buf);
335 	sprint(dname, "#a/tls/%s/hand", buf);
336 	hand = open(dname, ORDWR);
337 	if(hand < 0){
338 		close(ctl);
339 		return -1;
340 	}
341 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
342 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
343 	sprint(dname, "#a/tls/%s/data", buf);
344 	data = open(dname, ORDWR);
345 	close(fd);
346 	close(hand);
347 	close(ctl);
348 	if(data < 0){
349 		return -1;
350 	}
351 	if(tls == nil){
352 		close(data);
353 		return -1;
354 	}
355 	if(conn->cert)
356 		free(conn->cert);
357 	conn->cert = 0;  // client certificates are not yet implemented
358 	conn->certlen = 0;
359 	conn->sessionIDlen = tls->sid->len;
360 	conn->sessionID = emalloc(conn->sessionIDlen);
361 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
362 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
363 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
364 	tlsConnectionFree(tls);
365 	return data;
366 }
367 
368 //	push TLS onto fd, returning new (application) file descriptor
369 //		or -1 if error.
370 int
tlsClient(int fd,TLSconn * conn)371 tlsClient(int fd, TLSconn *conn)
372 {
373 	char buf[8];
374 	char dname[64];
375 	int n, data, ctl, hand;
376 	TlsConnection *tls;
377 
378 	if(!conn)
379 		return -1;
380 	ctl = open("#a/tls/clone", ORDWR);
381 	if(ctl < 0)
382 		return -1;
383 	n = read(ctl, buf, sizeof(buf)-1);
384 	if(n < 0){
385 		close(ctl);
386 		return -1;
387 	}
388 	buf[n] = 0;
389 	sprint(conn->dir, "#a/tls/%s", buf);
390 	sprint(dname, "#a/tls/%s/hand", buf);
391 	hand = open(dname, ORDWR);
392 	if(hand < 0){
393 		close(ctl);
394 		return -1;
395 	}
396 	sprint(dname, "#a/tls/%s/data", buf);
397 	data = open(dname, ORDWR);
398 	if(data < 0)
399 		return -1;
400 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
401 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
402 	close(fd);
403 	close(hand);
404 	close(ctl);
405 	if(tls == nil){
406 		close(data);
407 		return -1;
408 	}
409 	conn->certlen = tls->cert->len;
410 	conn->cert = emalloc(conn->certlen);
411 	memcpy(conn->cert, tls->cert->data, conn->certlen);
412 	conn->sessionIDlen = tls->sid->len;
413 	conn->sessionID = emalloc(conn->sessionIDlen);
414 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
415 	if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
416 		tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst,  tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
417 	tlsConnectionFree(tls);
418 	return data;
419 }
420 
421 static int
countchain(PEMChain * p)422 countchain(PEMChain *p)
423 {
424 	int i = 0;
425 
426 	while (p) {
427 		i++;
428 		p = p->next;
429 	}
430 	return i;
431 }
432 
433 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...),PEMChain * chp)434 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
435 {
436 	TlsConnection *c;
437 	Msg m;
438 	Bytes *csid;
439 	uchar sid[SidSize], kd[MaxKeyData];
440 	char *secrets;
441 	int cipher, compressor, nsid, rv, numcerts, i;
442 
443 	if(trace)
444 		trace("tlsServer2\n");
445 	if(!initCiphers())
446 		return nil;
447 	c = emalloc(sizeof(TlsConnection));
448 	c->ctl = ctl;
449 	c->hand = hand;
450 	c->trace = trace;
451 	c->version = ProtocolVersion;
452 
453 	memset(&m, 0, sizeof(m));
454 	if(!msgRecv(c, &m)){
455 		if(trace)
456 			trace("initial msgRecv failed\n");
457 		goto Err;
458 	}
459 	if(m.tag != HClientHello) {
460 		tlsError(c, EUnexpectedMessage, "expected a client hello");
461 		goto Err;
462 	}
463 	c->clientVersion = m.u.clientHello.version;
464 	if(trace)
465 		trace("ClientHello version %x\n", c->clientVersion);
466 	if(setVersion(c, m.u.clientHello.version) < 0) {
467 		tlsError(c, EIllegalParameter, "incompatible version");
468 		goto Err;
469 	}
470 
471 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
472 	cipher = okCipher(m.u.clientHello.ciphers);
473 	if(cipher < 0) {
474 		// reply with EInsufficientSecurity if we know that's the case
475 		if(cipher == -2)
476 			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
477 		else
478 			tlsError(c, EHandshakeFailure, "no matching cipher suite");
479 		goto Err;
480 	}
481 	if(!setAlgs(c, cipher)){
482 		tlsError(c, EHandshakeFailure, "no matching cipher suite");
483 		goto Err;
484 	}
485 	compressor = okCompression(m.u.clientHello.compressors);
486 	if(compressor < 0) {
487 		tlsError(c, EHandshakeFailure, "no matching compressor");
488 		goto Err;
489 	}
490 
491 	csid = m.u.clientHello.sid;
492 	if(trace)
493 		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
494 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
495 	if(c->sec == nil){
496 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
497 		goto Err;
498 	}
499 	c->sec->rpc = factotum_rsa_open(cert, ncert);
500 	if(c->sec->rpc == nil){
501 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
502 		goto Err;
503 	}
504 	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
505 	msgClear(&m);
506 
507 	m.tag = HServerHello;
508 	m.u.serverHello.version = c->version;
509 	memmove(m.u.serverHello.random, c->srandom, RandomSize);
510 	m.u.serverHello.cipher = cipher;
511 	m.u.serverHello.compressor = compressor;
512 	c->sid = makebytes(sid, nsid);
513 	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
514 	if(!msgSend(c, &m, AQueue))
515 		goto Err;
516 	msgClear(&m);
517 
518 	m.tag = HCertificate;
519 	numcerts = countchain(chp);
520 	m.u.certificate.ncert = 1 + numcerts;
521 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
522 	m.u.certificate.certs[0] = makebytes(cert, ncert);
523 	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
524 		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
525 	if(!msgSend(c, &m, AQueue))
526 		goto Err;
527 	msgClear(&m);
528 
529 	m.tag = HServerHelloDone;
530 	if(!msgSend(c, &m, AFlush))
531 		goto Err;
532 	msgClear(&m);
533 
534 	if(!msgRecv(c, &m))
535 		goto Err;
536 	if(m.tag != HClientKeyExchange) {
537 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
538 		goto Err;
539 	}
540 	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
541 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
542 		goto Err;
543 	}
544 	if(trace)
545 		trace("tls secrets\n");
546 	secrets = (char*)emalloc(2*c->nsecret);
547 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
548 	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
549 	memset(secrets, 0, 2*c->nsecret);
550 	free(secrets);
551 	memset(kd, 0, c->nsecret);
552 	if(rv < 0){
553 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
554 		goto Err;
555 	}
556 	msgClear(&m);
557 
558 	/* no CertificateVerify; skip to Finished */
559 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
560 		tlsError(c, EInternalError, "can't set finished: %r");
561 		goto Err;
562 	}
563 	if(!msgRecv(c, &m))
564 		goto Err;
565 	if(m.tag != HFinished) {
566 		tlsError(c, EUnexpectedMessage, "expected a finished");
567 		goto Err;
568 	}
569 	if(!finishedMatch(c, &m.u.finished)) {
570 		tlsError(c, EHandshakeFailure, "finished verification failed");
571 		goto Err;
572 	}
573 	msgClear(&m);
574 
575 	/* change cipher spec */
576 	if(fprint(c->ctl, "changecipher") < 0){
577 		tlsError(c, EInternalError, "can't enable cipher: %r");
578 		goto Err;
579 	}
580 
581 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
582 		tlsError(c, EInternalError, "can't set finished: %r");
583 		goto Err;
584 	}
585 	m.tag = HFinished;
586 	m.u.finished = c->finished;
587 	if(!msgSend(c, &m, AFlush))
588 		goto Err;
589 	msgClear(&m);
590 	if(trace)
591 		trace("tls finished\n");
592 
593 	if(fprint(c->ctl, "opened") < 0)
594 		goto Err;
595 	tlsSecOk(c->sec);
596 	return c;
597 
598 Err:
599 	msgClear(&m);
600 	tlsConnectionFree(c);
601 	return 0;
602 }
603 
604 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))605 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
606 {
607 	TlsConnection *c;
608 	Msg m;
609 	uchar kd[MaxKeyData], *epm;
610 	char *secrets;
611 	int creq, nepm, rv;
612 
613 	if(!initCiphers())
614 		return nil;
615 	epm = nil;
616 	c = emalloc(sizeof(TlsConnection));
617 	c->version = ProtocolVersion;
618 	c->ctl = ctl;
619 	c->hand = hand;
620 	c->trace = trace;
621 	c->isClient = 1;
622 	c->clientVersion = c->version;
623 
624 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
625 	if(c->sec == nil)
626 		goto Err;
627 
628 	/* client hello */
629 	memset(&m, 0, sizeof(m));
630 	m.tag = HClientHello;
631 	m.u.clientHello.version = c->clientVersion;
632 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
633 	m.u.clientHello.sid = makebytes(csid, ncsid);
634 	m.u.clientHello.ciphers = makeciphers();
635 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
636 	if(!msgSend(c, &m, AFlush))
637 		goto Err;
638 	msgClear(&m);
639 
640 	/* server hello */
641 	if(!msgRecv(c, &m))
642 		goto Err;
643 	if(m.tag != HServerHello) {
644 		tlsError(c, EUnexpectedMessage, "expected a server hello");
645 		goto Err;
646 	}
647 	if(setVersion(c, m.u.serverHello.version) < 0) {
648 		tlsError(c, EIllegalParameter, "incompatible version %r");
649 		goto Err;
650 	}
651 	memmove(c->srandom, m.u.serverHello.random, RandomSize);
652 	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
653 	if(c->sid->len != 0 && c->sid->len != SidSize) {
654 		tlsError(c, EIllegalParameter, "invalid server session identifier");
655 		goto Err;
656 	}
657 	if(!setAlgs(c, m.u.serverHello.cipher)) {
658 		tlsError(c, EIllegalParameter, "invalid cipher suite");
659 		goto Err;
660 	}
661 	if(m.u.serverHello.compressor != CompressionNull) {
662 		tlsError(c, EIllegalParameter, "invalid compression");
663 		goto Err;
664 	}
665 	msgClear(&m);
666 
667 	/* certificate */
668 	if(!msgRecv(c, &m) || m.tag != HCertificate) {
669 		tlsError(c, EUnexpectedMessage, "expected a certificate");
670 		goto Err;
671 	}
672 	if(m.u.certificate.ncert < 1) {
673 		tlsError(c, EIllegalParameter, "runt certificate");
674 		goto Err;
675 	}
676 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
677 	msgClear(&m);
678 
679 	/* server key exchange (optional) */
680 	if(!msgRecv(c, &m))
681 		goto Err;
682 	if(m.tag == HServerKeyExchange) {
683 		tlsError(c, EUnexpectedMessage, "got an server key exchange");
684 		goto Err;
685 		// If implementing this later, watch out for rollback attack
686 		// described in Wagner Schneier 1996, section 4.4.
687 	}
688 
689 	/* certificate request (optional) */
690 	creq = 0;
691 	if(m.tag == HCertificateRequest) {
692 		creq = 1;
693 		msgClear(&m);
694 		if(!msgRecv(c, &m))
695 			goto Err;
696 	}
697 
698 	if(m.tag != HServerHelloDone) {
699 		tlsError(c, EUnexpectedMessage, "expected a server hello done");
700 		goto Err;
701 	}
702 	msgClear(&m);
703 
704 	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
705 			c->cert->data, c->cert->len, c->version, &epm, &nepm,
706 			kd, c->nsecret) < 0){
707 		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
708 		goto Err;
709 	}
710 	secrets = (char*)emalloc(2*c->nsecret);
711 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
712 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
713 	memset(secrets, 0, 2*c->nsecret);
714 	free(secrets);
715 	memset(kd, 0, c->nsecret);
716 	if(rv < 0){
717 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
718 		goto Err;
719 	}
720 
721 	if(creq) {
722 		/* send a zero length certificate */
723 		m.tag = HCertificate;
724 		if(!msgSend(c, &m, AFlush))
725 			goto Err;
726 		msgClear(&m);
727 	}
728 
729 	/* client key exchange */
730 	m.tag = HClientKeyExchange;
731 	m.u.clientKeyExchange.key = makebytes(epm, nepm);
732 	free(epm);
733 	epm = nil;
734 	if(m.u.clientKeyExchange.key == nil) {
735 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
736 		goto Err;
737 	}
738 	if(!msgSend(c, &m, AFlush))
739 		goto Err;
740 	msgClear(&m);
741 
742 	/* change cipher spec */
743 	if(fprint(c->ctl, "changecipher") < 0){
744 		tlsError(c, EInternalError, "can't enable cipher: %r");
745 		goto Err;
746 	}
747 
748 	// Cipherchange must occur immediately before Finished to avoid
749 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
750 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
751 		tlsError(c, EInternalError, "can't set finished 1: %r");
752 		goto Err;
753 	}
754 	m.tag = HFinished;
755 	m.u.finished = c->finished;
756 
757 	if(!msgSend(c, &m, AFlush)) {
758 		fprint(2, "tlsClient nepm=%d\n", nepm);
759 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
760 		goto Err;
761 	}
762 	msgClear(&m);
763 
764 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
765 		fprint(2, "tlsClient nepm=%d\n", nepm);
766 		tlsError(c, EInternalError, "can't set finished 0: %r");
767 		goto Err;
768 	}
769 	if(!msgRecv(c, &m)) {
770 		fprint(2, "tlsClient nepm=%d\n", nepm);
771 		tlsError(c, EInternalError, "can't read server Finished: %r");
772 		goto Err;
773 	}
774 	if(m.tag != HFinished) {
775 		fprint(2, "tlsClient nepm=%d\n", nepm);
776 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
777 		goto Err;
778 	}
779 
780 	if(!finishedMatch(c, &m.u.finished)) {
781 		tlsError(c, EHandshakeFailure, "finished verification failed");
782 		goto Err;
783 	}
784 	msgClear(&m);
785 
786 	if(fprint(c->ctl, "opened") < 0){
787 		if(trace)
788 			trace("unable to do final open: %r\n");
789 		goto Err;
790 	}
791 	tlsSecOk(c->sec);
792 	return c;
793 
794 Err:
795 	free(epm);
796 	msgClear(&m);
797 	tlsConnectionFree(c);
798 	return 0;
799 }
800 
801 
802 //================= message functions ========================
803 
804 static uchar sendbuf[9000], *sendp;
805 
806 static int
msgSend(TlsConnection * c,Msg * m,int act)807 msgSend(TlsConnection *c, Msg *m, int act)
808 {
809 	uchar *p; // sendp = start of new message;  p = write pointer
810 	int nn, n, i;
811 
812 	if(sendp == nil)
813 		sendp = sendbuf;
814 	p = sendp;
815 	if(c->trace)
816 		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
817 
818 	p[0] = m->tag;	// header - fill in size later
819 	p += 4;
820 
821 	switch(m->tag) {
822 	default:
823 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
824 		goto Err;
825 	case HClientHello:
826 		// version
827 		put16(p, m->u.clientHello.version);
828 		p += 2;
829 
830 		// random
831 		memmove(p, m->u.clientHello.random, RandomSize);
832 		p += RandomSize;
833 
834 		// sid
835 		n = m->u.clientHello.sid->len;
836 		assert(n < 256);
837 		p[0] = n;
838 		memmove(p+1, m->u.clientHello.sid->data, n);
839 		p += n+1;
840 
841 		n = m->u.clientHello.ciphers->len;
842 		assert(n > 0 && n < 200);
843 		put16(p, n*2);
844 		p += 2;
845 		for(i=0; i<n; i++) {
846 			put16(p, m->u.clientHello.ciphers->data[i]);
847 			p += 2;
848 		}
849 
850 		n = m->u.clientHello.compressors->len;
851 		assert(n > 0);
852 		p[0] = n;
853 		memmove(p+1, m->u.clientHello.compressors->data, n);
854 		p += n+1;
855 		break;
856 	case HServerHello:
857 		put16(p, m->u.serverHello.version);
858 		p += 2;
859 
860 		// random
861 		memmove(p, m->u.serverHello.random, RandomSize);
862 		p += RandomSize;
863 
864 		// sid
865 		n = m->u.serverHello.sid->len;
866 		assert(n < 256);
867 		p[0] = n;
868 		memmove(p+1, m->u.serverHello.sid->data, n);
869 		p += n+1;
870 
871 		put16(p, m->u.serverHello.cipher);
872 		p += 2;
873 		p[0] = m->u.serverHello.compressor;
874 		p += 1;
875 		break;
876 	case HServerHelloDone:
877 		break;
878 	case HCertificate:
879 		nn = 0;
880 		for(i = 0; i < m->u.certificate.ncert; i++)
881 			nn += 3 + m->u.certificate.certs[i]->len;
882 		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
883 			tlsError(c, EInternalError, "output buffer too small for certificate");
884 			goto Err;
885 		}
886 		put24(p, nn);
887 		p += 3;
888 		for(i = 0; i < m->u.certificate.ncert; i++){
889 			put24(p, m->u.certificate.certs[i]->len);
890 			p += 3;
891 			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
892 			p += m->u.certificate.certs[i]->len;
893 		}
894 		break;
895 	case HClientKeyExchange:
896 		n = m->u.clientKeyExchange.key->len;
897 		if(c->version != SSL3Version){
898 			put16(p, n);
899 			p += 2;
900 		}
901 		memmove(p, m->u.clientKeyExchange.key->data, n);
902 		p += n;
903 		break;
904 	case HFinished:
905 		memmove(p, m->u.finished.verify, m->u.finished.n);
906 		p += m->u.finished.n;
907 		break;
908 	}
909 
910 	// go back and fill in size
911 	n = p-sendp;
912 	assert(p <= sendbuf+sizeof(sendbuf));
913 	put24(sendp+1, n-4);
914 
915 	// remember hash of Handshake messages
916 	if(m->tag != HHelloRequest) {
917 		md5(sendp, n, 0, &c->hsmd5);
918 		sha1(sendp, n, 0, &c->hssha1);
919 	}
920 
921 	sendp = p;
922 	if(act == AFlush){
923 		sendp = sendbuf;
924 		if(write(c->hand, sendbuf, p-sendbuf) < 0){
925 			fprint(2, "write error: %r\n");
926 			goto Err;
927 		}
928 	}
929 	msgClear(m);
930 	return 1;
931 Err:
932 	msgClear(m);
933 	return 0;
934 }
935 
936 static uchar*
tlsReadN(TlsConnection * c,int n)937 tlsReadN(TlsConnection *c, int n)
938 {
939 	uchar *p;
940 	int nn, nr;
941 
942 	nn = c->ep - c->rp;
943 	if(nn < n){
944 		if(c->rp != c->buf){
945 			memmove(c->buf, c->rp, nn);
946 			c->rp = c->buf;
947 			c->ep = &c->buf[nn];
948 		}
949 		for(; nn < n; nn += nr) {
950 			nr = read(c->hand, &c->rp[nn], n - nn);
951 			if(nr <= 0)
952 				return nil;
953 			c->ep += nr;
954 		}
955 	}
956 	p = c->rp;
957 	c->rp += n;
958 	return p;
959 }
960 
961 static int
msgRecv(TlsConnection * c,Msg * m)962 msgRecv(TlsConnection *c, Msg *m)
963 {
964 	uchar *p;
965 	int type, n, nn, i, nsid, nrandom, nciph;
966 
967 	for(;;) {
968 		p = tlsReadN(c, 4);
969 		if(p == nil)
970 			return 0;
971 		type = p[0];
972 		n = get24(p+1);
973 
974 		if(type != HHelloRequest)
975 			break;
976 		if(n != 0) {
977 			tlsError(c, EDecodeError, "invalid hello request during handshake");
978 			return 0;
979 		}
980 	}
981 
982 	if(n > sizeof(c->buf)) {
983 		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
984 		return 0;
985 	}
986 
987 	if(type == HSSL2ClientHello){
988 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
989 			This is sent by some clients that we must interoperate
990 			with, such as Java's JSSE and Microsoft's Internet Explorer. */
991 		p = tlsReadN(c, n);
992 		if(p == nil)
993 			return 0;
994 		md5(p, n, 0, &c->hsmd5);
995 		sha1(p, n, 0, &c->hssha1);
996 		m->tag = HClientHello;
997 		if(n < 22)
998 			goto Short;
999 		m->u.clientHello.version = get16(p+1);
1000 		p += 3;
1001 		n -= 3;
1002 		nn = get16(p); /* cipher_spec_len */
1003 		nsid = get16(p + 2);
1004 		nrandom = get16(p + 4);
1005 		p += 6;
1006 		n -= 6;
1007 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
1008 				|| nrandom < 16 || nn % 3)
1009 			goto Err;
1010 		if(c->trace && (n - nrandom != nn))
1011 			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1012 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1013 		nciph = 0;
1014 		for(i = 0; i < nn; i += 3)
1015 			if(p[i] == 0)
1016 				nciph++;
1017 		m->u.clientHello.ciphers = newints(nciph);
1018 		nciph = 0;
1019 		for(i = 0; i < nn; i += 3)
1020 			if(p[i] == 0)
1021 				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1022 		p += nn;
1023 		m->u.clientHello.sid = makebytes(nil, 0);
1024 		if(nrandom > RandomSize)
1025 			nrandom = RandomSize;
1026 		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1027 		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1028 		m->u.clientHello.compressors = newbytes(1);
1029 		m->u.clientHello.compressors->data[0] = CompressionNull;
1030 		goto Ok;
1031 	}
1032 
1033 	md5(p, 4, 0, &c->hsmd5);
1034 	sha1(p, 4, 0, &c->hssha1);
1035 
1036 	p = tlsReadN(c, n);
1037 	if(p == nil)
1038 		return 0;
1039 
1040 	md5(p, n, 0, &c->hsmd5);
1041 	sha1(p, n, 0, &c->hssha1);
1042 
1043 	m->tag = type;
1044 
1045 	switch(type) {
1046 	default:
1047 		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1048 		goto Err;
1049 	case HClientHello:
1050 		if(n < 2)
1051 			goto Short;
1052 		m->u.clientHello.version = get16(p);
1053 		p += 2;
1054 		n -= 2;
1055 
1056 		if(n < RandomSize)
1057 			goto Short;
1058 		memmove(m->u.clientHello.random, p, RandomSize);
1059 		p += RandomSize;
1060 		n -= RandomSize;
1061 		if(n < 1 || n < p[0]+1)
1062 			goto Short;
1063 		m->u.clientHello.sid = makebytes(p+1, p[0]);
1064 		p += m->u.clientHello.sid->len+1;
1065 		n -= m->u.clientHello.sid->len+1;
1066 
1067 		if(n < 2)
1068 			goto Short;
1069 		nn = get16(p);
1070 		p += 2;
1071 		n -= 2;
1072 
1073 		if((nn & 1) || n < nn || nn < 2)
1074 			goto Short;
1075 		m->u.clientHello.ciphers = newints(nn >> 1);
1076 		for(i = 0; i < nn; i += 2)
1077 			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1078 		p += nn;
1079 		n -= nn;
1080 
1081 		if(n < 1 || n < p[0]+1 || p[0] == 0)
1082 			goto Short;
1083 		nn = p[0];
1084 		m->u.clientHello.compressors = newbytes(nn);
1085 		memmove(m->u.clientHello.compressors->data, p+1, nn);
1086 		n -= nn + 1;
1087 		break;
1088 	case HServerHello:
1089 		if(n < 2)
1090 			goto Short;
1091 		m->u.serverHello.version = get16(p);
1092 		p += 2;
1093 		n -= 2;
1094 
1095 		if(n < RandomSize)
1096 			goto Short;
1097 		memmove(m->u.serverHello.random, p, RandomSize);
1098 		p += RandomSize;
1099 		n -= RandomSize;
1100 
1101 		if(n < 1 || n < p[0]+1)
1102 			goto Short;
1103 		m->u.serverHello.sid = makebytes(p+1, p[0]);
1104 		p += m->u.serverHello.sid->len+1;
1105 		n -= m->u.serverHello.sid->len+1;
1106 
1107 		if(n < 3)
1108 			goto Short;
1109 		m->u.serverHello.cipher = get16(p);
1110 		m->u.serverHello.compressor = p[2];
1111 		n -= 3;
1112 		break;
1113 	case HCertificate:
1114 		if(n < 3)
1115 			goto Short;
1116 		nn = get24(p);
1117 		p += 3;
1118 		n -= 3;
1119 		if(n != nn)
1120 			goto Short;
1121 		/* certs */
1122 		i = 0;
1123 		while(n > 0) {
1124 			if(n < 3)
1125 				goto Short;
1126 			nn = get24(p);
1127 			p += 3;
1128 			n -= 3;
1129 			if(nn > n)
1130 				goto Short;
1131 			m->u.certificate.ncert = i+1;
1132 			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1133 			m->u.certificate.certs[i] = makebytes(p, nn);
1134 			p += nn;
1135 			n -= nn;
1136 			i++;
1137 		}
1138 		break;
1139 	case HCertificateRequest:
1140 		if(n < 1)
1141 			goto Short;
1142 		nn = p[0];
1143 		p += 1;
1144 		n -= 1;
1145 		if(nn < 1 || nn > n)
1146 			goto Short;
1147 		m->u.certificateRequest.types = makebytes(p, nn);
1148 		p += nn;
1149 		n -= nn;
1150 		if(n < 2)
1151 			goto Short;
1152 		nn = get16(p);
1153 		p += 2;
1154 		n -= 2;
1155 		/* nn == 0 can happen; yahoo's servers do it */
1156 		if(nn != n)
1157 			goto Short;
1158 		/* cas */
1159 		i = 0;
1160 		while(n > 0) {
1161 			if(n < 2)
1162 				goto Short;
1163 			nn = get16(p);
1164 			p += 2;
1165 			n -= 2;
1166 			if(nn < 1 || nn > n)
1167 				goto Short;
1168 			m->u.certificateRequest.nca = i+1;
1169 			m->u.certificateRequest.cas = erealloc(
1170 				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1171 			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1172 			p += nn;
1173 			n -= nn;
1174 			i++;
1175 		}
1176 		break;
1177 	case HServerHelloDone:
1178 		break;
1179 	case HClientKeyExchange:
1180 		/*
1181 		 * this message depends upon the encryption selected
1182 		 * assume rsa.
1183 		 */
1184 		if(c->version == SSL3Version)
1185 			nn = n;
1186 		else{
1187 			if(n < 2)
1188 				goto Short;
1189 			nn = get16(p);
1190 			p += 2;
1191 			n -= 2;
1192 		}
1193 		if(n < nn)
1194 			goto Short;
1195 		m->u.clientKeyExchange.key = makebytes(p, nn);
1196 		n -= nn;
1197 		break;
1198 	case HFinished:
1199 		m->u.finished.n = c->finished.n;
1200 		if(n < m->u.finished.n)
1201 			goto Short;
1202 		memmove(m->u.finished.verify, p, m->u.finished.n);
1203 		n -= m->u.finished.n;
1204 		break;
1205 	}
1206 
1207 	if(type != HClientHello && n != 0)
1208 		goto Short;
1209 Ok:
1210 	if(c->trace){
1211 		char *buf;
1212 		buf = emalloc(8000);
1213 		c->trace("recv %s", msgPrint(buf, 8000, m));
1214 		free(buf);
1215 	}
1216 	return 1;
1217 Short:
1218 	tlsError(c, EDecodeError, "handshake message has invalid length");
1219 Err:
1220 	msgClear(m);
1221 	return 0;
1222 }
1223 
1224 static void
msgClear(Msg * m)1225 msgClear(Msg *m)
1226 {
1227 	int i;
1228 
1229 	switch(m->tag) {
1230 	default:
1231 		sysfatal("msgClear: unknown message type: %d", m->tag);
1232 	case HHelloRequest:
1233 		break;
1234 	case HClientHello:
1235 		freebytes(m->u.clientHello.sid);
1236 		freeints(m->u.clientHello.ciphers);
1237 		freebytes(m->u.clientHello.compressors);
1238 		break;
1239 	case HServerHello:
1240 		freebytes(m->u.clientHello.sid);
1241 		break;
1242 	case HCertificate:
1243 		for(i=0; i<m->u.certificate.ncert; i++)
1244 			freebytes(m->u.certificate.certs[i]);
1245 		free(m->u.certificate.certs);
1246 		break;
1247 	case HCertificateRequest:
1248 		freebytes(m->u.certificateRequest.types);
1249 		for(i=0; i<m->u.certificateRequest.nca; i++)
1250 			freebytes(m->u.certificateRequest.cas[i]);
1251 		free(m->u.certificateRequest.cas);
1252 		break;
1253 	case HServerHelloDone:
1254 		break;
1255 	case HClientKeyExchange:
1256 		freebytes(m->u.clientKeyExchange.key);
1257 		break;
1258 	case HFinished:
1259 		break;
1260 	}
1261 	memset(m, 0, sizeof(Msg));
1262 }
1263 
1264 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1265 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1266 {
1267 	int i;
1268 
1269 	if(s0)
1270 		bs = seprint(bs, be, "%s", s0);
1271 	bs = seprint(bs, be, "[");
1272 	if(b == nil)
1273 		bs = seprint(bs, be, "nil");
1274 	else
1275 		for(i=0; i<b->len; i++)
1276 			bs = seprint(bs, be, "%.2x ", b->data[i]);
1277 	bs = seprint(bs, be, "]");
1278 	if(s1)
1279 		bs = seprint(bs, be, "%s", s1);
1280 	return bs;
1281 }
1282 
1283 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1284 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1285 {
1286 	int i;
1287 
1288 	if(s0)
1289 		bs = seprint(bs, be, "%s", s0);
1290 	bs = seprint(bs, be, "[");
1291 	if(b == nil)
1292 		bs = seprint(bs, be, "nil");
1293 	else
1294 		for(i=0; i<b->len; i++)
1295 			bs = seprint(bs, be, "%x ", b->data[i]);
1296 	bs = seprint(bs, be, "]");
1297 	if(s1)
1298 		bs = seprint(bs, be, "%s", s1);
1299 	return bs;
1300 }
1301 
1302 static char*
msgPrint(char * buf,int n,Msg * m)1303 msgPrint(char *buf, int n, Msg *m)
1304 {
1305 	int i;
1306 	char *bs = buf, *be = buf+n;
1307 
1308 	switch(m->tag) {
1309 	default:
1310 		bs = seprint(bs, be, "unknown %d\n", m->tag);
1311 		break;
1312 	case HClientHello:
1313 		bs = seprint(bs, be, "ClientHello\n");
1314 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1315 		bs = seprint(bs, be, "\trandom: ");
1316 		for(i=0; i<RandomSize; i++)
1317 			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1318 		bs = seprint(bs, be, "\n");
1319 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1320 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1321 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1322 		break;
1323 	case HServerHello:
1324 		bs = seprint(bs, be, "ServerHello\n");
1325 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1326 		bs = seprint(bs, be, "\trandom: ");
1327 		for(i=0; i<RandomSize; i++)
1328 			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1329 		bs = seprint(bs, be, "\n");
1330 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1331 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1332 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1333 		break;
1334 	case HCertificate:
1335 		bs = seprint(bs, be, "Certificate\n");
1336 		for(i=0; i<m->u.certificate.ncert; i++)
1337 			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1338 		break;
1339 	case HCertificateRequest:
1340 		bs = seprint(bs, be, "CertificateRequest\n");
1341 		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1342 		bs = seprint(bs, be, "\tcertificateauthorities\n");
1343 		for(i=0; i<m->u.certificateRequest.nca; i++)
1344 			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1345 		break;
1346 	case HServerHelloDone:
1347 		bs = seprint(bs, be, "ServerHelloDone\n");
1348 		break;
1349 	case HClientKeyExchange:
1350 		bs = seprint(bs, be, "HClientKeyExchange\n");
1351 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1352 		break;
1353 	case HFinished:
1354 		bs = seprint(bs, be, "HFinished\n");
1355 		for(i=0; i<m->u.finished.n; i++)
1356 			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1357 		bs = seprint(bs, be, "\n");
1358 		break;
1359 	}
1360 	USED(bs);
1361 	return buf;
1362 }
1363 
1364 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1365 tlsError(TlsConnection *c, int err, char *fmt, ...)
1366 {
1367 	char msg[512];
1368 	va_list arg;
1369 
1370 	va_start(arg, fmt);
1371 	vseprint(msg, msg+sizeof(msg), fmt, arg);
1372 	va_end(arg);
1373 	if(c->trace)
1374 		c->trace("tlsError: %s\n", msg);
1375 	else if(c->erred)
1376 		fprint(2, "double error: %r, %s", msg);
1377 	else
1378 		werrstr("tls: local %s", msg);
1379 	c->erred = 1;
1380 	fprint(c->ctl, "alert %d", err);
1381 }
1382 
1383 // commit to specific version number
1384 static int
setVersion(TlsConnection * c,int version)1385 setVersion(TlsConnection *c, int version)
1386 {
1387 	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1388 		return -1;
1389 	if(version > c->version)
1390 		version = c->version;
1391 	if(version == SSL3Version) {
1392 		c->version = version;
1393 		c->finished.n = SSL3FinishedLen;
1394 	}else if(version == TLSVersion){
1395 		c->version = version;
1396 		c->finished.n = TLSFinishedLen;
1397 	}else
1398 		return -1;
1399 	c->verset = 1;
1400 	return fprint(c->ctl, "version 0x%x", version);
1401 }
1402 
1403 // confirm that received Finished message matches the expected value
1404 static int
finishedMatch(TlsConnection * c,Finished * f)1405 finishedMatch(TlsConnection *c, Finished *f)
1406 {
1407 	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1408 }
1409 
1410 // free memory associated with TlsConnection struct
1411 //		(but don't close the TLS channel itself)
1412 static void
tlsConnectionFree(TlsConnection * c)1413 tlsConnectionFree(TlsConnection *c)
1414 {
1415 	tlsSecClose(c->sec);
1416 	freebytes(c->sid);
1417 	freebytes(c->cert);
1418 	memset(c, 0, sizeof(c));
1419 	free(c);
1420 }
1421 
1422 
1423 //================= cipher choices ========================
1424 
1425 static int weakCipher[CipherMax] =
1426 {
1427 	1,	/* TLS_NULL_WITH_NULL_NULL */
1428 	1,	/* TLS_RSA_WITH_NULL_MD5 */
1429 	1,	/* TLS_RSA_WITH_NULL_SHA */
1430 	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1431 	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
1432 	0,	/* TLS_RSA_WITH_RC4_128_SHA */
1433 	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1434 	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
1435 	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1436 	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
1437 	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1438 	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1439 	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
1440 	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1441 	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1442 	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
1443 	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1444 	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1445 	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1446 	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1447 	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1448 	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1449 	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1450 	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1451 	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
1452 	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1453 	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
1454 	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1455 };
1456 
1457 static int
setAlgs(TlsConnection * c,int a)1458 setAlgs(TlsConnection *c, int a)
1459 {
1460 	int i;
1461 
1462 	for(i = 0; i < nelem(cipherAlgs); i++){
1463 		if(cipherAlgs[i].tlsid == a){
1464 			c->enc = cipherAlgs[i].enc;
1465 			c->digest = cipherAlgs[i].digest;
1466 			c->nsecret = cipherAlgs[i].nsecret;
1467 			if(c->nsecret > MaxKeyData)
1468 				return 0;
1469 			return 1;
1470 		}
1471 	}
1472 	return 0;
1473 }
1474 
1475 static int
okCipher(Ints * cv)1476 okCipher(Ints *cv)
1477 {
1478 	int weak, i, j, c;
1479 
1480 	weak = 1;
1481 	for(i = 0; i < cv->len; i++) {
1482 		c = cv->data[i];
1483 		if(c >= CipherMax)
1484 			weak = 0;
1485 		else
1486 			weak &= weakCipher[c];
1487 		for(j = 0; j < nelem(cipherAlgs); j++)
1488 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1489 				return c;
1490 	}
1491 	if(weak)
1492 		return -2;
1493 	return -1;
1494 }
1495 
1496 static int
okCompression(Bytes * cv)1497 okCompression(Bytes *cv)
1498 {
1499 	int i, j, c;
1500 
1501 	for(i = 0; i < cv->len; i++) {
1502 		c = cv->data[i];
1503 		for(j = 0; j < nelem(compressors); j++) {
1504 			if(compressors[j] == c)
1505 				return c;
1506 		}
1507 	}
1508 	return -1;
1509 }
1510 
1511 static Lock	ciphLock;
1512 static int	nciphers;
1513 
1514 static int
initCiphers(void)1515 initCiphers(void)
1516 {
1517 	enum {MaxAlgF = 1024, MaxAlgs = 10};
1518 	char s[MaxAlgF], *flds[MaxAlgs];
1519 	int i, j, n, ok;
1520 
1521 	lock(&ciphLock);
1522 	if(nciphers){
1523 		unlock(&ciphLock);
1524 		return nciphers;
1525 	}
1526 	j = open("#a/tls/encalgs", OREAD);
1527 	if(j < 0){
1528 		werrstr("can't open #a/tls/encalgs: %r");
1529 		return 0;
1530 	}
1531 	n = read(j, s, MaxAlgF-1);
1532 	close(j);
1533 	if(n <= 0){
1534 		werrstr("nothing in #a/tls/encalgs: %r");
1535 		return 0;
1536 	}
1537 	s[n] = 0;
1538 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1539 	for(i = 0; i < nelem(cipherAlgs); i++){
1540 		ok = 0;
1541 		for(j = 0; j < n; j++){
1542 			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1543 				ok = 1;
1544 				break;
1545 			}
1546 		}
1547 		cipherAlgs[i].ok = ok;
1548 	}
1549 
1550 	j = open("#a/tls/hashalgs", OREAD);
1551 	if(j < 0){
1552 		werrstr("can't open #a/tls/hashalgs: %r");
1553 		return 0;
1554 	}
1555 	n = read(j, s, MaxAlgF-1);
1556 	close(j);
1557 	if(n <= 0){
1558 		werrstr("nothing in #a/tls/hashalgs: %r");
1559 		return 0;
1560 	}
1561 	s[n] = 0;
1562 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1563 	for(i = 0; i < nelem(cipherAlgs); i++){
1564 		ok = 0;
1565 		for(j = 0; j < n; j++){
1566 			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1567 				ok = 1;
1568 				break;
1569 			}
1570 		}
1571 		cipherAlgs[i].ok &= ok;
1572 		if(cipherAlgs[i].ok)
1573 			nciphers++;
1574 	}
1575 	unlock(&ciphLock);
1576 	return nciphers;
1577 }
1578 
1579 static Ints*
makeciphers(void)1580 makeciphers(void)
1581 {
1582 	Ints *is;
1583 	int i, j;
1584 
1585 	is = newints(nciphers);
1586 	j = 0;
1587 	for(i = 0; i < nelem(cipherAlgs); i++){
1588 		if(cipherAlgs[i].ok)
1589 			is->data[j++] = cipherAlgs[i].tlsid;
1590 	}
1591 	return is;
1592 }
1593 
1594 
1595 
1596 //================= security functions ========================
1597 
1598 // given X.509 certificate, set up connection to factotum
1599 //	for using corresponding private key
1600 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1601 factotum_rsa_open(uchar *cert, int certlen)
1602 {
1603 	int afd;
1604 	char *s;
1605 	mpint *pub = nil;
1606 	RSApub *rsapub;
1607 	AuthRpc *rpc;
1608 
1609 	// start talking to factotum
1610 	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1611 		return nil;
1612 	if((rpc = auth_allocrpc(afd)) == nil){
1613 		close(afd);
1614 		return nil;
1615 	}
1616 	s = "proto=rsa service=tls role=client";
1617 	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1618 		factotum_rsa_close(rpc);
1619 		return nil;
1620 	}
1621 
1622 	// roll factotum keyring around to match certificate
1623 	rsapub = X509toRSApub(cert, certlen, nil, 0);
1624 	while(1){
1625 		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1626 			factotum_rsa_close(rpc);
1627 			rpc = nil;
1628 			goto done;
1629 		}
1630 		pub = strtomp(rpc->arg, nil, 16, nil);
1631 		assert(pub != nil);
1632 		if(mpcmp(pub,rsapub->n) == 0)
1633 			break;
1634 	}
1635 done:
1636 	mpfree(pub);
1637 	rsapubfree(rsapub);
1638 	return rpc;
1639 }
1640 
1641 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1642 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1643 {
1644 	char *p;
1645 	int rv;
1646 
1647 	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1648 		return nil;
1649 	rv = auth_rpc(rpc, "write", p, strlen(p));
1650 	free(p);
1651 	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1652 		return nil;
1653 	mpfree(cipher);
1654 	return strtomp(rpc->arg, nil, 16, nil);
1655 }
1656 
1657 static void
factotum_rsa_close(AuthRpc * rpc)1658 factotum_rsa_close(AuthRpc*rpc)
1659 {
1660 	if(!rpc)
1661 		return;
1662 	close(rpc->afd);
1663 	auth_freerpc(rpc);
1664 }
1665 
1666 static void
tlsPmd5(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1667 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1668 {
1669 	uchar ai[MD5dlen], tmp[MD5dlen];
1670 	int i, n;
1671 	MD5state *s;
1672 
1673 	// generate a1
1674 	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1675 	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1676 	hmac_md5(seed1, nseed1, key, nkey, ai, s);
1677 
1678 	while(nbuf > 0) {
1679 		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1680 		s = hmac_md5(label, nlabel, key, nkey, nil, s);
1681 		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1682 		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1683 		n = MD5dlen;
1684 		if(n > nbuf)
1685 			n = nbuf;
1686 		for(i = 0; i < n; i++)
1687 			buf[i] ^= tmp[i];
1688 		buf += n;
1689 		nbuf -= n;
1690 		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1691 		memmove(ai, tmp, MD5dlen);
1692 	}
1693 }
1694 
1695 static void
tlsPsha1(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1696 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1697 {
1698 	uchar ai[SHA1dlen], tmp[SHA1dlen];
1699 	int i, n;
1700 	SHAstate *s;
1701 
1702 	// generate a1
1703 	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1704 	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1705 	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1706 
1707 	while(nbuf > 0) {
1708 		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1709 		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1710 		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1711 		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1712 		n = SHA1dlen;
1713 		if(n > nbuf)
1714 			n = nbuf;
1715 		for(i = 0; i < n; i++)
1716 			buf[i] ^= tmp[i];
1717 		buf += n;
1718 		nbuf -= n;
1719 		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1720 		memmove(ai, tmp, SHA1dlen);
1721 	}
1722 }
1723 
1724 // fill buf with md5(args)^sha1(args)
1725 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1726 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1727 {
1728 	int i;
1729 	int nlabel = strlen(label);
1730 	int n = (nkey + 1) >> 1;
1731 
1732 	for(i = 0; i < nbuf; i++)
1733 		buf[i] = 0;
1734 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1735 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1736 }
1737 
1738 /*
1739  * for setting server session id's
1740  */
1741 static Lock	sidLock;
1742 static long	maxSid = 1;
1743 
1744 /* the keys are verified to have the same public components
1745  * and to function correctly with pkcs 1 encryption and decryption. */
1746 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1747 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1748 {
1749 	TlsSec *sec = emalloc(sizeof(*sec));
1750 
1751 	USED(csid); USED(ncsid);  // ignore csid for now
1752 
1753 	memmove(sec->crandom, crandom, RandomSize);
1754 	sec->clientVers = cvers;
1755 
1756 	put32(sec->srandom, time(0));
1757 	genrandom(sec->srandom+4, RandomSize-4);
1758 	memmove(srandom, sec->srandom, RandomSize);
1759 
1760 	/*
1761 	 * make up a unique sid: use our pid, and and incrementing id
1762 	 * can signal no sid by setting nssid to 0.
1763 	 */
1764 	memset(ssid, 0, SidSize);
1765 	put32(ssid, getpid());
1766 	lock(&sidLock);
1767 	put32(ssid+4, maxSid++);
1768 	unlock(&sidLock);
1769 	*nssid = SidSize;
1770 	return sec;
1771 }
1772 
1773 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1774 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1775 {
1776 	if(epm != nil){
1777 		if(setVers(sec, vers) < 0)
1778 			goto Err;
1779 		serverMasterSecret(sec, epm, nepm);
1780 	}else if(sec->vers != vers){
1781 		werrstr("mismatched session versions");
1782 		goto Err;
1783 	}
1784 	setSecrets(sec, kd, nkd);
1785 	return 0;
1786 Err:
1787 	sec->ok = -1;
1788 	return -1;
1789 }
1790 
1791 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1792 tlsSecInitc(int cvers, uchar *crandom)
1793 {
1794 	TlsSec *sec = emalloc(sizeof(*sec));
1795 	sec->clientVers = cvers;
1796 	put32(sec->crandom, time(0));
1797 	genrandom(sec->crandom+4, RandomSize-4);
1798 	memmove(crandom, sec->crandom, RandomSize);
1799 	return sec;
1800 }
1801 
1802 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1803 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1804 {
1805 	RSApub *pub;
1806 
1807 	pub = nil;
1808 
1809 	USED(sid);
1810 	USED(nsid);
1811 
1812 	memmove(sec->srandom, srandom, RandomSize);
1813 
1814 	if(setVers(sec, vers) < 0)
1815 		goto Err;
1816 
1817 	pub = X509toRSApub(cert, ncert, nil, 0);
1818 	if(pub == nil){
1819 		werrstr("invalid x509/rsa certificate");
1820 		goto Err;
1821 	}
1822 	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1823 		goto Err;
1824 	rsapubfree(pub);
1825 	setSecrets(sec, kd, nkd);
1826 	return 0;
1827 
1828 Err:
1829 	if(pub != nil)
1830 		rsapubfree(pub);
1831 	sec->ok = -1;
1832 	return -1;
1833 }
1834 
1835 static int
tlsSecFinished(TlsSec * sec,MD5state md5,SHAstate sha1,uchar * fin,int nfin,int isclient)1836 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1837 {
1838 	if(sec->nfin != nfin){
1839 		sec->ok = -1;
1840 		werrstr("invalid finished exchange");
1841 		return -1;
1842 	}
1843 	md5.malloced = 0;
1844 	sha1.malloced = 0;
1845 	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1846 	return 1;
1847 }
1848 
1849 static void
tlsSecOk(TlsSec * sec)1850 tlsSecOk(TlsSec *sec)
1851 {
1852 	if(sec->ok == 0)
1853 		sec->ok = 1;
1854 }
1855 
1856 static void
tlsSecKill(TlsSec * sec)1857 tlsSecKill(TlsSec *sec)
1858 {
1859 	if(!sec)
1860 		return;
1861 	factotum_rsa_close(sec->rpc);
1862 	sec->ok = -1;
1863 }
1864 
1865 static void
tlsSecClose(TlsSec * sec)1866 tlsSecClose(TlsSec *sec)
1867 {
1868 	if(!sec)
1869 		return;
1870 	factotum_rsa_close(sec->rpc);
1871 	free(sec->server);
1872 	free(sec);
1873 }
1874 
1875 static int
setVers(TlsSec * sec,int v)1876 setVers(TlsSec *sec, int v)
1877 {
1878 	if(v == SSL3Version){
1879 		sec->setFinished = sslSetFinished;
1880 		sec->nfin = SSL3FinishedLen;
1881 		sec->prf = sslPRF;
1882 	}else if(v == TLSVersion){
1883 		sec->setFinished = tlsSetFinished;
1884 		sec->nfin = TLSFinishedLen;
1885 		sec->prf = tlsPRF;
1886 	}else{
1887 		werrstr("invalid version");
1888 		return -1;
1889 	}
1890 	sec->vers = v;
1891 	return 0;
1892 }
1893 
1894 /*
1895  * generate secret keys from the master secret.
1896  *
1897  * different crypto selections will require different amounts
1898  * of key expansion and use of key expansion data,
1899  * but it's all generated using the same function.
1900  */
1901 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)1902 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1903 {
1904 	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1905 			sec->srandom, RandomSize, sec->crandom, RandomSize);
1906 }
1907 
1908 /*
1909  * set the master secret from the pre-master secret.
1910  */
1911 static void
setMasterSecret(TlsSec * sec,Bytes * pm)1912 setMasterSecret(TlsSec *sec, Bytes *pm)
1913 {
1914 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1915 			sec->crandom, RandomSize, sec->srandom, RandomSize);
1916 }
1917 
1918 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)1919 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1920 {
1921 	Bytes *pm;
1922 
1923 	pm = pkcs1_decrypt(sec, epm, nepm);
1924 
1925 	// if the client messed up, just continue as if everything is ok,
1926 	// to prevent attacks to check for correctly formatted messages.
1927 	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1928 	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1929 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1930 			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1931 		sec->ok = -1;
1932 		if(pm != nil)
1933 			freebytes(pm);
1934 		pm = newbytes(MasterSecretSize);
1935 		genrandom(pm->data, MasterSecretSize);
1936 	}
1937 	setMasterSecret(sec, pm);
1938 	memset(pm->data, 0, pm->len);
1939 	freebytes(pm);
1940 }
1941 
1942 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)1943 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1944 {
1945 	Bytes *pm, *key;
1946 
1947 	pm = newbytes(MasterSecretSize);
1948 	put16(pm->data, sec->clientVers);
1949 	genrandom(pm->data+2, MasterSecretSize - 2);
1950 
1951 	setMasterSecret(sec, pm);
1952 
1953 	key = pkcs1_encrypt(pm, pub, 2);
1954 	memset(pm->data, 0, pm->len);
1955 	freebytes(pm);
1956 	if(key == nil){
1957 		werrstr("tls pkcs1_encrypt failed");
1958 		return -1;
1959 	}
1960 
1961 	*nepm = key->len;
1962 	*epm = malloc(*nepm);
1963 	if(*epm == nil){
1964 		freebytes(key);
1965 		werrstr("out of memory");
1966 		return -1;
1967 	}
1968 	memmove(*epm, key->data, *nepm);
1969 
1970 	freebytes(key);
1971 
1972 	return 1;
1973 }
1974 
1975 static void
sslSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)1976 sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1977 {
1978 	DigestState *s;
1979 	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
1980 	char *label;
1981 
1982 	if(isClient)
1983 		label = "CLNT";
1984 	else
1985 		label = "SRVR";
1986 
1987 	md5((uchar*)label, 4, nil, &hsmd5);
1988 	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
1989 	memset(pad, 0x36, 48);
1990 	md5(pad, 48, nil, &hsmd5);
1991 	md5(nil, 0, h0, &hsmd5);
1992 	memset(pad, 0x5C, 48);
1993 	s = md5(sec->sec, MasterSecretSize, nil, nil);
1994 	s = md5(pad, 48, nil, s);
1995 	md5(h0, MD5dlen, finished, s);
1996 
1997 	sha1((uchar*)label, 4, nil, &hssha1);
1998 	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
1999 	memset(pad, 0x36, 40);
2000 	sha1(pad, 40, nil, &hssha1);
2001 	sha1(nil, 0, h1, &hssha1);
2002 	memset(pad, 0x5C, 40);
2003 	s = sha1(sec->sec, MasterSecretSize, nil, nil);
2004 	s = sha1(pad, 40, nil, s);
2005 	sha1(h1, SHA1dlen, finished + MD5dlen, s);
2006 }
2007 
2008 // fill "finished" arg with md5(args)^sha1(args)
2009 static void
tlsSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)2010 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2011 {
2012 	uchar h0[MD5dlen], h1[SHA1dlen];
2013 	char *label;
2014 
2015 	// get current hash value, but allow further messages to be hashed in
2016 	md5(nil, 0, h0, &hsmd5);
2017 	sha1(nil, 0, h1, &hssha1);
2018 
2019 	if(isClient)
2020 		label = "client finished";
2021 	else
2022 		label = "server finished";
2023 	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2024 }
2025 
2026 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)2027 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2028 {
2029 	DigestState *s;
2030 	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2031 	int i, n, len;
2032 
2033 	USED(label);
2034 	len = 1;
2035 	while(nbuf > 0){
2036 		if(len > 26)
2037 			return;
2038 		for(i = 0; i < len; i++)
2039 			tmp[i] = 'A' - 1 + len;
2040 		s = sha1(tmp, len, nil, nil);
2041 		s = sha1(key, nkey, nil, s);
2042 		s = sha1(seed0, nseed0, nil, s);
2043 		sha1(seed1, nseed1, sha1dig, s);
2044 		s = md5(key, nkey, nil, nil);
2045 		md5(sha1dig, SHA1dlen, md5dig, s);
2046 		n = MD5dlen;
2047 		if(n > nbuf)
2048 			n = nbuf;
2049 		memmove(buf, md5dig, n);
2050 		buf += n;
2051 		nbuf -= n;
2052 		len++;
2053 	}
2054 }
2055 
2056 static mpint*
bytestomp(Bytes * bytes)2057 bytestomp(Bytes* bytes)
2058 {
2059 	mpint* ans;
2060 
2061 	ans = betomp(bytes->data, bytes->len, nil);
2062 	return ans;
2063 }
2064 
2065 /*
2066  * Convert mpint* to Bytes, putting high order byte first.
2067  */
2068 static Bytes*
mptobytes(mpint * big)2069 mptobytes(mpint* big)
2070 {
2071 	int n, m;
2072 	uchar *a;
2073 	Bytes* ans;
2074 
2075 	a = nil;
2076 	n = (mpsignif(big)+7)/8;
2077 	m = mptobe(big, nil, n, &a);
2078 	ans = makebytes(a, m);
2079 	if(a != nil)
2080 		free(a);
2081 	return ans;
2082 }
2083 
2084 // Do RSA computation on block according to key, and pad
2085 // result on left with zeros to make it modlen long.
2086 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2087 rsacomp(Bytes* block, RSApub* key, int modlen)
2088 {
2089 	mpint *x, *y;
2090 	Bytes *a, *ybytes;
2091 	int ylen;
2092 
2093 	x = bytestomp(block);
2094 	y = rsaencrypt(key, x, nil);
2095 	mpfree(x);
2096 	ybytes = mptobytes(y);
2097 	ylen = ybytes->len;
2098 
2099 	if(ylen < modlen) {
2100 		a = newbytes(modlen);
2101 		memset(a->data, 0, modlen-ylen);
2102 		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2103 		freebytes(ybytes);
2104 		ybytes = a;
2105 	}
2106 	else if(ylen > modlen) {
2107 		// assume it has leading zeros (mod should make it so)
2108 		a = newbytes(modlen);
2109 		memmove(a->data, ybytes->data, modlen);
2110 		freebytes(ybytes);
2111 		ybytes = a;
2112 	}
2113 	mpfree(y);
2114 	return ybytes;
2115 }
2116 
2117 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2118 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2119 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2120 {
2121 	Bytes *pad, *eb, *ans;
2122 	int i, dlen, padlen, modlen;
2123 
2124 	modlen = (mpsignif(key->n)+7)/8;
2125 	dlen = data->len;
2126 	if(modlen < 12 || dlen > modlen - 11)
2127 		return nil;
2128 	padlen = modlen - 3 - dlen;
2129 	pad = newbytes(padlen);
2130 	genrandom(pad->data, padlen);
2131 	for(i = 0; i < padlen; i++) {
2132 		if(blocktype == 0)
2133 			pad->data[i] = 0;
2134 		else if(blocktype == 1)
2135 			pad->data[i] = 255;
2136 		else if(pad->data[i] == 0)
2137 			pad->data[i] = 1;
2138 	}
2139 	eb = newbytes(modlen);
2140 	eb->data[0] = 0;
2141 	eb->data[1] = blocktype;
2142 	memmove(eb->data+2, pad->data, padlen);
2143 	eb->data[padlen+2] = 0;
2144 	memmove(eb->data+padlen+3, data->data, dlen);
2145 	ans = rsacomp(eb, key, modlen);
2146 	freebytes(eb);
2147 	freebytes(pad);
2148 	return ans;
2149 }
2150 
2151 // decrypt data according to PKCS#1, with given key.
2152 // expect a block type of 2.
2153 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2154 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2155 {
2156 	Bytes *eb, *ans = nil;
2157 	int i, modlen;
2158 	mpint *x, *y;
2159 
2160 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2161 	if(nepm != modlen)
2162 		return nil;
2163 	x = betomp(epm, nepm, nil);
2164 	y = factotum_rsa_decrypt(sec->rpc, x);
2165 	if(y == nil)
2166 		return nil;
2167 	eb = mptobytes(y);
2168 	if(eb->len < modlen){ // pad on left with zeros
2169 		ans = newbytes(modlen);
2170 		memset(ans->data, 0, modlen-eb->len);
2171 		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2172 		freebytes(eb);
2173 		eb = ans;
2174 	}
2175 	if(eb->data[0] == 0 && eb->data[1] == 2) {
2176 		for(i = 2; i < modlen; i++)
2177 			if(eb->data[i] == 0)
2178 				break;
2179 		if(i < modlen - 1)
2180 			ans = makebytes(eb->data+i+1, modlen-(i+1));
2181 	}
2182 	freebytes(eb);
2183 	return ans;
2184 }
2185 
2186 
2187 //================= general utility functions ========================
2188 
2189 static void *
emalloc(int n)2190 emalloc(int n)
2191 {
2192 	void *p;
2193 	if(n==0)
2194 		n=1;
2195 	p = malloc(n);
2196 	if(p == nil){
2197 		exits("out of memory");
2198 	}
2199 	memset(p, 0, n);
2200 	return p;
2201 }
2202 
2203 static void *
erealloc(void * ReallocP,int ReallocN)2204 erealloc(void *ReallocP, int ReallocN)
2205 {
2206 	if(ReallocN == 0)
2207 		ReallocN = 1;
2208 	if(!ReallocP)
2209 		ReallocP = emalloc(ReallocN);
2210 	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2211 		exits("out of memory");
2212 	}
2213 	return(ReallocP);
2214 }
2215 
2216 static void
put32(uchar * p,u32int x)2217 put32(uchar *p, u32int x)
2218 {
2219 	p[0] = x>>24;
2220 	p[1] = x>>16;
2221 	p[2] = x>>8;
2222 	p[3] = x;
2223 }
2224 
2225 static void
put24(uchar * p,int x)2226 put24(uchar *p, int x)
2227 {
2228 	p[0] = x>>16;
2229 	p[1] = x>>8;
2230 	p[2] = x;
2231 }
2232 
2233 static void
put16(uchar * p,int x)2234 put16(uchar *p, int x)
2235 {
2236 	p[0] = x>>8;
2237 	p[1] = x;
2238 }
2239 
2240 static u32int
get32(uchar * p)2241 get32(uchar *p)
2242 {
2243 	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2244 }
2245 
2246 static int
get24(uchar * p)2247 get24(uchar *p)
2248 {
2249 	return (p[0]<<16)|(p[1]<<8)|p[2];
2250 }
2251 
2252 static int
get16(uchar * p)2253 get16(uchar *p)
2254 {
2255 	return (p[0]<<8)|p[1];
2256 }
2257 
2258 #define OFFSET(x, s) offsetof(s, x)
2259 
2260 /*
2261  * malloc and return a new Bytes structure capable of
2262  * holding len bytes. (len >= 0)
2263  * Used to use crypt_malloc, which aborts if malloc fails.
2264  */
2265 static Bytes*
newbytes(int len)2266 newbytes(int len)
2267 {
2268 	Bytes* ans;
2269 
2270 	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2271 	ans->len = len;
2272 	return ans;
2273 }
2274 
2275 /*
2276  * newbytes(len), with data initialized from buf
2277  */
2278 static Bytes*
makebytes(uchar * buf,int len)2279 makebytes(uchar* buf, int len)
2280 {
2281 	Bytes* ans;
2282 
2283 	ans = newbytes(len);
2284 	memmove(ans->data, buf, len);
2285 	return ans;
2286 }
2287 
2288 static void
freebytes(Bytes * b)2289 freebytes(Bytes* b)
2290 {
2291 	if(b != nil)
2292 		free(b);
2293 }
2294 
2295 /* len is number of ints */
2296 static Ints*
newints(int len)2297 newints(int len)
2298 {
2299 	Ints* ans;
2300 
2301 	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2302 	ans->len = len;
2303 	return ans;
2304 }
2305 
2306 static Ints*
makeints(int * buf,int len)2307 makeints(int* buf, int len)
2308 {
2309 	Ints* ans;
2310 
2311 	ans = newints(len);
2312 	if(len > 0)
2313 		memmove(ans->data, buf, len*sizeof(int));
2314 	return ans;
2315 }
2316 
2317 static void
freeints(Ints * b)2318 freeints(Ints* b)
2319 {
2320 	if(b != nil)
2321 		free(b);
2322 }
2323