xref: /plan9/sys/src/cmd/unix/drawterm/libsec/tlshand.c (revision 14cc0f535177405a84c5b73603a98e5db6674719)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7 
// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encryption algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16 
17 enum {
18 	TLSFinishedLen = 12,
19 	SSL3FinishedLen = MD5dlen+SHA1dlen,
20 	MaxKeyData = 104,	// amount of secret we may need
21 	MaxChunk = 1<<14,
22 	RandomSize = 32,
23 	SidSize = 32,
24 	MasterSecretSize = 48,
25 	AQueue = 0,
26 	AFlush = 1,
27 };
28 
29 typedef struct TlsSec TlsSec;
30 
31 typedef struct Bytes{
32 	int len;
33 	uchar data[1];  // [len]
34 } Bytes;
35 
// Counted vector of ints.  data is declared with one element but is
// meant to hold len ints.
typedef struct Ints Ints;
struct Ints {
	int	len;
	int	data[1];	// [len]
};
40 
// One supported cipher suite: the algorithm names, the amount of key
// material the suite needs, and its cipher suite number.
typedef struct Algs Algs;
struct Algs {
	char	*enc;		// encryption algorithm name
	char	*digest;	// digest algorithm name
	int	nsecret;	// amount of secret data needed
	int	tlsid;		// cipher suite number
	int	ok;		// NOTE(review): presumably set by initCiphers - confirm
};
48 
49 typedef struct Finished{
50 	uchar verify[SSL3FinishedLen];
51 	int n;
52 } Finished;
53 
54 typedef struct TlsConnection{
55 	TlsSec *sec;	// security management goo
56 	int hand, ctl;	// record layer file descriptors
57 	int erred;		// set when tlsError called
58 	int (*trace)(char*fmt, ...); // for debugging
59 	int version;	// protocol we are speaking
60 	int verset;		// version has been set
61 	int ver2hi;		// server got a version 2 hello
62 	int isClient;	// is this the client or server?
63 	Bytes *sid;		// SessionID
64 	Bytes *cert;	// only last - no chain
65 
66 	Lock statelk;
67 	int state;		// must be set using setstate
68 
69 	// input buffer for handshake messages
70 	uchar buf[MaxChunk+2048];
71 	uchar *rp, *ep;
72 
73 	uchar crandom[RandomSize];	// client random
74 	uchar srandom[RandomSize];	// server random
75 	int clientVersion;	// version in ClientHello
76 	char *digest;	// name of digest algorithm to use
77 	char *enc;		// name of encryption algorithm to use
78 	int nsecret;	// amount of secret data to init keys
79 
80 	// for finished messages
81 	MD5state	hsmd5;	// handshake hash
82 	SHAstate	hssha1;	// handshake hash
83 	Finished	finished;
84 } TlsConnection;
85 
86 typedef struct Msg{
87 	int tag;
88 	union {
89 		struct {
90 			int version;
91 			uchar 	random[RandomSize];
92 			Bytes*	sid;
93 			Ints*	ciphers;
94 			Bytes*	compressors;
95 		} clientHello;
96 		struct {
97 			int version;
98 			uchar 	random[RandomSize];
99 			Bytes*	sid;
100 			int cipher;
101 			int compressor;
102 		} serverHello;
103 		struct {
104 			int ncert;
105 			Bytes **certs;
106 		} certificate;
107 		struct {
108 			Bytes *types;
109 			int nca;
110 			Bytes **cas;
111 		} certificateRequest;
112 		struct {
113 			Bytes *key;
114 		} clientKeyExchange;
115 		Finished finished;
116 	} u;
117 } Msg;
118 
119 typedef struct TlsSec{
120 	char *server;	// name of remote; nil for server
121 	int ok;	// <0 killed; ==0 in progress; >0 reusable
122 	RSApub *rsapub;
123 	AuthRpc *rpc;	// factotum for rsa private key
124 	uchar sec[MasterSecretSize];	// master secret
125 	uchar crandom[RandomSize];	// client random
126 	uchar srandom[RandomSize];	// server random
127 	int clientVers;		// version in ClientHello
128 	int vers;			// final version
129 	// byte generation and handshake checksum
130 	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131 	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132 	int nfin;
133 } TlsSec;
134 
135 
// protocol version numbers
enum {
	SSL3Version	= 0x0300,
	TLSVersion	= 0x0301,

	ProtocolVersion	= TLSVersion,	// maximum version we speak
	MinProtoVersion	= SSL3Version,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};
143 
// handshake message types, with every value spelled out
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21
};
159 
// alert codes passed to tlsError
enum {
	ECloseNotify		= 0,

	EUnexpectedMessage	= 10,

	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,

	EDecompressionFailure	= 30,

	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,

	EDecodeError		= 50,
	EDecryptError		= 51,

	EExportRestriction	= 60,

	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,

	EInternalError		= 80,

	EUserCanceled		= 90,

	ENoRenegotiation	= 100,

	EMax			= 256
};
188 
// cipher suite numbers
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,

	// RSA key exchange
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000a,

	// DH_DSS / DH_RSA
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000b,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000c,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000d,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000e,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000f,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,

	// DHE_DSS / DHE_RSA
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,

	// DH_anon
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001a,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001b,

	// AES suites
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003a,

	CipherMax	// one past the last suite number above
};
234 
// compression methods; only the null method is defined
enum {
	CompressionNull	= 0,
	CompressionMax	= 1
};
240 
241 static Algs cipherAlgs[] = {
242 	{"rc4_128", "md5",	2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243 	{"rc4_128", "sha1",	2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244 	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245 };
246 
247 static uchar compressors[] = {
248 	CompressionNull,
249 };
250 
251 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...));
252 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
253 
254 static void	msgClear(Msg *m);
255 static char* msgPrint(char *buf, int n, Msg *m);
256 static int	msgRecv(TlsConnection *c, Msg *m);
257 static int	msgSend(TlsConnection *c, Msg *m, int act);
258 static void	tlsError(TlsConnection *c, int err, char *msg, ...);
259 #pragma	varargck argpos	tlsError 3
260 static int setVersion(TlsConnection *c, int version);
261 static int finishedMatch(TlsConnection *c, Finished *f);
262 static void tlsConnectionFree(TlsConnection *c);
263 
264 static int setAlgs(TlsConnection *c, int a);
265 static int okCipher(Ints *cv);
266 static int okCompression(Bytes *cv);
267 static int initCiphers(void);
268 static Ints* makeciphers(void);
269 
270 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
271 static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
272 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
273 static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
274 static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
275 static void	tlsSecOk(TlsSec *sec);
276 static void	tlsSecKill(TlsSec *sec);
277 static void	tlsSecClose(TlsSec *sec);
278 static void	setMasterSecret(TlsSec *sec, Bytes *pm);
279 static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
280 static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
281 static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
282 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
283 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
284 static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
285 static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
286 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
287 			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
288 static int setVers(TlsSec *sec, int version);
289 
290 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
291 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
292 static void factotum_rsa_close(AuthRpc*rpc);
293 
294 static void* emalloc(int);
295 static void* erealloc(void*, int);
296 static void put32(uchar *p, u32int);
297 static void put24(uchar *p, int);
298 static void put16(uchar *p, int);
299 static u32int get32(uchar *p);
300 static int get24(uchar *p);
301 static int get16(uchar *p);
302 static Bytes* newbytes(int len);
303 static Bytes* makebytes(uchar* buf, int len);
304 static void freebytes(Bytes* b);
305 static Ints* newints(int len);
306 static Ints* makeints(int* buf, int len);
307 static void freeints(Ints* b);
308 
309 //================= client/server ========================
310 
311 //	push TLS onto fd, returning new (application) file descriptor
312 //		or -1 if error.
313 int
tlsServer(int fd,TLSconn * conn)314 tlsServer(int fd, TLSconn *conn)
315 {
316 	char buf[8];
317 	char dname[64];
318 	int n, data, ctl, hand;
319 	TlsConnection *tls;
320 
321 	if(conn == nil)
322 		return -1;
323 	ctl = open("#a/tls/clone", ORDWR);
324 	if(ctl < 0)
325 		return -1;
326 	n = read(ctl, buf, sizeof(buf)-1);
327 	if(n < 0){
328 		close(ctl);
329 		return -1;
330 	}
331 	buf[n] = 0;
332 	sprint(conn->dir, "#a/tls/%s", buf);
333 	sprint(dname, "#a/tls/%s/hand", buf);
334 	hand = open(dname, ORDWR);
335 	if(hand < 0){
336 		close(ctl);
337 		return -1;
338 	}
339 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 	tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace);
341 	sprint(dname, "#a/tls/%s/data", buf);
342 	data = open(dname, ORDWR);
343 	close(fd);
344 	close(hand);
345 	close(ctl);
346 	if(data < 0){
347 		return -1;
348 	}
349 	if(tls == nil){
350 		close(data);
351 		return -1;
352 	}
353 	if(conn->cert)
354 		free(conn->cert);
355 	conn->cert = 0;  // client certificates are not yet implemented
356 	conn->certlen = 0;
357 	conn->sessionIDlen = tls->sid->len;
358 	conn->sessionID = emalloc(conn->sessionIDlen);
359 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 	tlsConnectionFree(tls);
361 	return data;
362 }
363 
364 //	push TLS onto fd, returning new (application) file descriptor
365 //		or -1 if error.
366 int
tlsClient(int fd,TLSconn * conn)367 tlsClient(int fd, TLSconn *conn)
368 {
369 	char buf[8];
370 	char dname[64];
371 	int n, data, ctl, hand;
372 	TlsConnection *tls;
373 
374 	if(!conn)
375 		return -1;
376 	ctl = open("#a/tls/clone", ORDWR);
377 	if(ctl < 0)
378 		return -1;
379 	n = read(ctl, buf, sizeof(buf)-1);
380 	if(n < 0){
381 		close(ctl);
382 		return -1;
383 	}
384 	buf[n] = 0;
385 	sprint(conn->dir, "#a/tls/%s", buf);
386 	sprint(dname, "#a/tls/%s/hand", buf);
387 	hand = open(dname, ORDWR);
388 	if(hand < 0){
389 		close(ctl);
390 		return -1;
391 	}
392 	sprint(dname, "#a/tls/%s/data", buf);
393 	data = open(dname, ORDWR);
394 	if(data < 0)
395 		return -1;
396 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
397 	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
398 	close(fd);
399 	close(hand);
400 	close(ctl);
401 	if(tls == nil){
402 		close(data);
403 		return -1;
404 	}
405 	conn->certlen = tls->cert->len;
406 	conn->cert = emalloc(conn->certlen);
407 	memcpy(conn->cert, tls->cert->data, conn->certlen);
408 	conn->sessionIDlen = tls->sid->len;
409 	conn->sessionID = emalloc(conn->sessionIDlen);
410 	memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
411 	tlsConnectionFree(tls);
412 	return data;
413 }
414 
415 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...))416 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...))
417 {
418 	TlsConnection *c;
419 	Msg m;
420 	Bytes *csid;
421 	uchar sid[SidSize], kd[MaxKeyData];
422 	char *secrets;
423 	int cipher, compressor, nsid, rv;
424 
425 	if(trace)
426 		trace("tlsServer2\n");
427 	if(!initCiphers())
428 		return nil;
429 	c = emalloc(sizeof(TlsConnection));
430 	c->ctl = ctl;
431 	c->hand = hand;
432 	c->trace = trace;
433 	c->version = ProtocolVersion;
434 
435 	memset(&m, 0, sizeof(m));
436 	if(!msgRecv(c, &m)){
437 		if(trace)
438 			trace("initial msgRecv failed\n");
439 		goto Err;
440 	}
441 	if(m.tag != HClientHello) {
442 		tlsError(c, EUnexpectedMessage, "expected a client hello");
443 		goto Err;
444 	}
445 	c->clientVersion = m.u.clientHello.version;
446 	if(trace)
447 		trace("ClientHello version %x\n", c->clientVersion);
448 	if(setVersion(c, m.u.clientHello.version) < 0) {
449 		tlsError(c, EIllegalParameter, "incompatible version");
450 		goto Err;
451 	}
452 
453 	memmove(c->crandom, m.u.clientHello.random, RandomSize);
454 	cipher = okCipher(m.u.clientHello.ciphers);
455 	if(cipher < 0) {
456 		// reply with EInsufficientSecurity if we know that's the case
457 		if(cipher == -2)
458 			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
459 		else
460 			tlsError(c, EHandshakeFailure, "no matching cipher suite");
461 		goto Err;
462 	}
463 	if(!setAlgs(c, cipher)){
464 		tlsError(c, EHandshakeFailure, "no matching cipher suite");
465 		goto Err;
466 	}
467 	compressor = okCompression(m.u.clientHello.compressors);
468 	if(compressor < 0) {
469 		tlsError(c, EHandshakeFailure, "no matching compressor");
470 		goto Err;
471 	}
472 
473 	csid = m.u.clientHello.sid;
474 	if(trace)
475 		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
476 	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
477 	if(c->sec == nil){
478 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
479 		goto Err;
480 	}
481 	c->sec->rpc = factotum_rsa_open(cert, ncert);
482 	if(c->sec->rpc == nil){
483 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
484 		goto Err;
485 	}
486 	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
487 	msgClear(&m);
488 
489 	m.tag = HServerHello;
490 	m.u.serverHello.version = c->version;
491 	memmove(m.u.serverHello.random, c->srandom, RandomSize);
492 	m.u.serverHello.cipher = cipher;
493 	m.u.serverHello.compressor = compressor;
494 	c->sid = makebytes(sid, nsid);
495 	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
496 	if(!msgSend(c, &m, AQueue))
497 		goto Err;
498 	msgClear(&m);
499 
500 	m.tag = HCertificate;
501 	m.u.certificate.ncert = 1;
502 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
503 	m.u.certificate.certs[0] = makebytes(cert, ncert);
504 	if(!msgSend(c, &m, AQueue))
505 		goto Err;
506 	msgClear(&m);
507 
508 	m.tag = HServerHelloDone;
509 	if(!msgSend(c, &m, AFlush))
510 		goto Err;
511 	msgClear(&m);
512 
513 	if(!msgRecv(c, &m))
514 		goto Err;
515 	if(m.tag != HClientKeyExchange) {
516 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
517 		goto Err;
518 	}
519 	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
520 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
521 		goto Err;
522 	}
523 	if(trace)
524 		trace("tls secrets\n");
525 	secrets = (char*)emalloc(2*c->nsecret);
526 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
527 	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
528 	memset(secrets, 0, 2*c->nsecret);
529 	free(secrets);
530 	memset(kd, 0, c->nsecret);
531 	if(rv < 0){
532 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
533 		goto Err;
534 	}
535 	msgClear(&m);
536 
537 	/* no CertificateVerify; skip to Finished */
538 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
539 		tlsError(c, EInternalError, "can't set finished: %r");
540 		goto Err;
541 	}
542 	if(!msgRecv(c, &m))
543 		goto Err;
544 	if(m.tag != HFinished) {
545 		tlsError(c, EUnexpectedMessage, "expected a finished");
546 		goto Err;
547 	}
548 	if(!finishedMatch(c, &m.u.finished)) {
549 		tlsError(c, EHandshakeFailure, "finished verification failed");
550 		goto Err;
551 	}
552 	msgClear(&m);
553 
554 	/* change cipher spec */
555 	if(fprint(c->ctl, "changecipher") < 0){
556 		tlsError(c, EInternalError, "can't enable cipher: %r");
557 		goto Err;
558 	}
559 
560 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
561 		tlsError(c, EInternalError, "can't set finished: %r");
562 		goto Err;
563 	}
564 	m.tag = HFinished;
565 	m.u.finished = c->finished;
566 	if(!msgSend(c, &m, AFlush))
567 		goto Err;
568 	msgClear(&m);
569 	if(trace)
570 		trace("tls finished\n");
571 
572 	if(fprint(c->ctl, "opened") < 0)
573 		goto Err;
574 	tlsSecOk(c->sec);
575 	return c;
576 
577 Err:
578 	msgClear(&m);
579 	tlsConnectionFree(c);
580 	return 0;
581 }
582 
583 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))584 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
585 {
586 	TlsConnection *c;
587 	Msg m;
588 	uchar kd[MaxKeyData], *epm;
589 	char *secrets;
590 	int creq, nepm, rv;
591 
592 	if(!initCiphers())
593 		return nil;
594 	epm = nil;
595 	c = emalloc(sizeof(TlsConnection));
596 	c->version = ProtocolVersion;
597 	c->ctl = ctl;
598 	c->hand = hand;
599 	c->trace = trace;
600 	c->isClient = 1;
601 	c->clientVersion = c->version;
602 
603 	c->sec = tlsSecInitc(c->clientVersion, c->crandom);
604 	if(c->sec == nil)
605 		goto Err;
606 
607 	/* client hello */
608 	memset(&m, 0, sizeof(m));
609 	m.tag = HClientHello;
610 	m.u.clientHello.version = c->clientVersion;
611 	memmove(m.u.clientHello.random, c->crandom, RandomSize);
612 	m.u.clientHello.sid = makebytes(csid, ncsid);
613 	m.u.clientHello.ciphers = makeciphers();
614 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
615 	if(!msgSend(c, &m, AFlush))
616 		goto Err;
617 	msgClear(&m);
618 
619 	/* server hello */
620 	if(!msgRecv(c, &m))
621 		goto Err;
622 	if(m.tag != HServerHello) {
623 		tlsError(c, EUnexpectedMessage, "expected a server hello");
624 		goto Err;
625 	}
626 	if(setVersion(c, m.u.serverHello.version) < 0) {
627 		tlsError(c, EIllegalParameter, "incompatible version %r");
628 		goto Err;
629 	}
630 	memmove(c->srandom, m.u.serverHello.random, RandomSize);
631 	c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
632 	if(c->sid->len != 0 && c->sid->len != SidSize) {
633 		tlsError(c, EIllegalParameter, "invalid server session identifier");
634 		goto Err;
635 	}
636 	if(!setAlgs(c, m.u.serverHello.cipher)) {
637 		tlsError(c, EIllegalParameter, "invalid cipher suite");
638 		goto Err;
639 	}
640 	if(m.u.serverHello.compressor != CompressionNull) {
641 		tlsError(c, EIllegalParameter, "invalid compression");
642 		goto Err;
643 	}
644 	msgClear(&m);
645 
646 	/* certificate */
647 	if(!msgRecv(c, &m) || m.tag != HCertificate) {
648 		tlsError(c, EUnexpectedMessage, "expected a certificate");
649 		goto Err;
650 	}
651 	if(m.u.certificate.ncert < 1) {
652 		tlsError(c, EIllegalParameter, "runt certificate");
653 		goto Err;
654 	}
655 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
656 	msgClear(&m);
657 
658 	/* server key exchange (optional) */
659 	if(!msgRecv(c, &m))
660 		goto Err;
661 	if(m.tag == HServerKeyExchange) {
662 		tlsError(c, EUnexpectedMessage, "got an server key exchange");
663 		goto Err;
664 		// If implementing this later, watch out for rollback attack
665 		// described in Wagner Schneier 1996, section 4.4.
666 	}
667 
668 	/* certificate request (optional) */
669 	creq = 0;
670 	if(m.tag == HCertificateRequest) {
671 		creq = 1;
672 		msgClear(&m);
673 		if(!msgRecv(c, &m))
674 			goto Err;
675 	}
676 
677 	if(m.tag != HServerHelloDone) {
678 		tlsError(c, EUnexpectedMessage, "expected a server hello done");
679 		goto Err;
680 	}
681 	msgClear(&m);
682 
683 	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
684 			c->cert->data, c->cert->len, c->version, &epm, &nepm,
685 			kd, c->nsecret) < 0){
686 		tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
687 		goto Err;
688 	}
689 	secrets = (char*)emalloc(2*c->nsecret);
690 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
691 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
692 	memset(secrets, 0, 2*c->nsecret);
693 	free(secrets);
694 	memset(kd, 0, c->nsecret);
695 	if(rv < 0){
696 		tlsError(c, EHandshakeFailure, "can't set keys: %r");
697 		goto Err;
698 	}
699 
700 	if(creq) {
701 		/* send a zero length certificate */
702 		m.tag = HCertificate;
703 		if(!msgSend(c, &m, AFlush))
704 			goto Err;
705 		msgClear(&m);
706 	}
707 
708 	/* client key exchange */
709 	m.tag = HClientKeyExchange;
710 	m.u.clientKeyExchange.key = makebytes(epm, nepm);
711 	free(epm);
712 	epm = nil;
713 	if(m.u.clientKeyExchange.key == nil) {
714 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
715 		goto Err;
716 	}
717 	if(!msgSend(c, &m, AFlush))
718 		goto Err;
719 	msgClear(&m);
720 
721 	/* change cipher spec */
722 	if(fprint(c->ctl, "changecipher") < 0){
723 		tlsError(c, EInternalError, "can't enable cipher: %r");
724 		goto Err;
725 	}
726 
727 	// Cipherchange must occur immediately before Finished to avoid
728 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
729 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
730 		tlsError(c, EInternalError, "can't set finished 1: %r");
731 		goto Err;
732 	}
733 	m.tag = HFinished;
734 	m.u.finished = c->finished;
735 
736 	if(!msgSend(c, &m, AFlush)) {
737 		fprint(2, "tlsClient nepm=%d\n", nepm);
738 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
739 		goto Err;
740 	}
741 	msgClear(&m);
742 
743 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
744 		fprint(2, "tlsClient nepm=%d\n", nepm);
745 		tlsError(c, EInternalError, "can't set finished 0: %r");
746 		goto Err;
747 	}
748 	if(!msgRecv(c, &m)) {
749 		fprint(2, "tlsClient nepm=%d\n", nepm);
750 		tlsError(c, EInternalError, "can't read server Finished: %r");
751 		goto Err;
752 	}
753 	if(m.tag != HFinished) {
754 		fprint(2, "tlsClient nepm=%d\n", nepm);
755 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
756 		goto Err;
757 	}
758 
759 	if(!finishedMatch(c, &m.u.finished)) {
760 		tlsError(c, EHandshakeFailure, "finished verification failed");
761 		goto Err;
762 	}
763 	msgClear(&m);
764 
765 	if(fprint(c->ctl, "opened") < 0){
766 		if(trace)
767 			trace("unable to do final open: %r\n");
768 		goto Err;
769 	}
770 	tlsSecOk(c->sec);
771 	return c;
772 
773 Err:
774 	free(epm);
775 	msgClear(&m);
776 	tlsConnectionFree(c);
777 	return 0;
778 }
779 
780 
781 //================= message functions ========================
782 
783 static uchar sendbuf[9000], *sendp;
784 
785 static int
msgSend(TlsConnection * c,Msg * m,int act)786 msgSend(TlsConnection *c, Msg *m, int act)
787 {
788 	uchar *p; // sendp = start of new message;  p = write pointer
789 	int nn, n, i;
790 
791 	if(sendp == nil)
792 		sendp = sendbuf;
793 	p = sendp;
794 	if(c->trace)
795 		c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
796 
797 	p[0] = m->tag;	// header - fill in size later
798 	p += 4;
799 
800 	switch(m->tag) {
801 	default:
802 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
803 		goto Err;
804 	case HClientHello:
805 		// version
806 		put16(p, m->u.clientHello.version);
807 		p += 2;
808 
809 		// random
810 		memmove(p, m->u.clientHello.random, RandomSize);
811 		p += RandomSize;
812 
813 		// sid
814 		n = m->u.clientHello.sid->len;
815 		assert(n < 256);
816 		p[0] = n;
817 		memmove(p+1, m->u.clientHello.sid->data, n);
818 		p += n+1;
819 
820 		n = m->u.clientHello.ciphers->len;
821 		assert(n > 0 && n < 200);
822 		put16(p, n*2);
823 		p += 2;
824 		for(i=0; i<n; i++) {
825 			put16(p, m->u.clientHello.ciphers->data[i]);
826 			p += 2;
827 		}
828 
829 		n = m->u.clientHello.compressors->len;
830 		assert(n > 0);
831 		p[0] = n;
832 		memmove(p+1, m->u.clientHello.compressors->data, n);
833 		p += n+1;
834 		break;
835 	case HServerHello:
836 		put16(p, m->u.serverHello.version);
837 		p += 2;
838 
839 		// random
840 		memmove(p, m->u.serverHello.random, RandomSize);
841 		p += RandomSize;
842 
843 		// sid
844 		n = m->u.serverHello.sid->len;
845 		assert(n < 256);
846 		p[0] = n;
847 		memmove(p+1, m->u.serverHello.sid->data, n);
848 		p += n+1;
849 
850 		put16(p, m->u.serverHello.cipher);
851 		p += 2;
852 		p[0] = m->u.serverHello.compressor;
853 		p += 1;
854 		break;
855 	case HServerHelloDone:
856 		break;
857 	case HCertificate:
858 		nn = 0;
859 		for(i = 0; i < m->u.certificate.ncert; i++)
860 			nn += 3 + m->u.certificate.certs[i]->len;
861 		if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
862 			tlsError(c, EInternalError, "output buffer too small for certificate");
863 			goto Err;
864 		}
865 		put24(p, nn);
866 		p += 3;
867 		for(i = 0; i < m->u.certificate.ncert; i++){
868 			put24(p, m->u.certificate.certs[i]->len);
869 			p += 3;
870 			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
871 			p += m->u.certificate.certs[i]->len;
872 		}
873 		break;
874 	case HClientKeyExchange:
875 		n = m->u.clientKeyExchange.key->len;
876 		if(c->version != SSL3Version){
877 			put16(p, n);
878 			p += 2;
879 		}
880 		memmove(p, m->u.clientKeyExchange.key->data, n);
881 		p += n;
882 		break;
883 	case HFinished:
884 		memmove(p, m->u.finished.verify, m->u.finished.n);
885 		p += m->u.finished.n;
886 		break;
887 	}
888 
889 	// go back and fill in size
890 	n = p-sendp;
891 	assert(p <= sendbuf+sizeof(sendbuf));
892 	put24(sendp+1, n-4);
893 
894 	// remember hash of Handshake messages
895 	if(m->tag != HHelloRequest) {
896 		md5(sendp, n, 0, &c->hsmd5);
897 		sha1(sendp, n, 0, &c->hssha1);
898 	}
899 
900 	sendp = p;
901 	if(act == AFlush){
902 		sendp = sendbuf;
903 		if(write(c->hand, sendbuf, p-sendbuf) < 0){
904 			fprint(2, "write error: %r\n");
905 			goto Err;
906 		}
907 	}
908 	msgClear(m);
909 	return 1;
910 Err:
911 	msgClear(m);
912 	return 0;
913 }
914 
915 static uchar*
tlsReadN(TlsConnection * c,int n)916 tlsReadN(TlsConnection *c, int n)
917 {
918 	uchar *p;
919 	int nn, nr;
920 
921 	nn = c->ep - c->rp;
922 	if(nn < n){
923 		if(c->rp != c->buf){
924 			memmove(c->buf, c->rp, nn);
925 			c->rp = c->buf;
926 			c->ep = &c->buf[nn];
927 		}
928 		for(; nn < n; nn += nr) {
929 			nr = read(c->hand, &c->rp[nn], n - nn);
930 			if(nr <= 0)
931 				return nil;
932 			c->ep += nr;
933 		}
934 	}
935 	p = c->rp;
936 	c->rp += n;
937 	return p;
938 }
939 
940 static int
msgRecv(TlsConnection * c,Msg * m)941 msgRecv(TlsConnection *c, Msg *m)
942 {
943 	uchar *p;
944 	int type, n, nn, i, nsid, nrandom, nciph;
945 
946 	for(;;) {
947 		p = tlsReadN(c, 4);
948 		if(p == nil)
949 			return 0;
950 		type = p[0];
951 		n = get24(p+1);
952 
953 		if(type != HHelloRequest)
954 			break;
955 		if(n != 0) {
956 			tlsError(c, EDecodeError, "invalid hello request during handshake");
957 			return 0;
958 		}
959 	}
960 
961 	if(n > sizeof(c->buf)) {
962 		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
963 		return 0;
964 	}
965 
966 	if(type == HSSL2ClientHello){
967 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
968 			This is sent by some clients that we must interoperate
969 			with, such as Java's JSSE and Microsoft's Internet Explorer. */
970 		p = tlsReadN(c, n);
971 		if(p == nil)
972 			return 0;
973 		md5(p, n, 0, &c->hsmd5);
974 		sha1(p, n, 0, &c->hssha1);
975 		m->tag = HClientHello;
976 		if(n < 22)
977 			goto Short;
978 		m->u.clientHello.version = get16(p+1);
979 		p += 3;
980 		n -= 3;
981 		nn = get16(p); /* cipher_spec_len */
982 		nsid = get16(p + 2);
983 		nrandom = get16(p + 4);
984 		p += 6;
985 		n -= 6;
986 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
987 				|| nrandom < 16 || nn % 3)
988 			goto Err;
989 		if(c->trace && (n - nrandom != nn))
990 			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
991 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
992 		nciph = 0;
993 		for(i = 0; i < nn; i += 3)
994 			if(p[i] == 0)
995 				nciph++;
996 		m->u.clientHello.ciphers = newints(nciph);
997 		nciph = 0;
998 		for(i = 0; i < nn; i += 3)
999 			if(p[i] == 0)
1000 				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1001 		p += nn;
1002 		m->u.clientHello.sid = makebytes(nil, 0);
1003 		if(nrandom > RandomSize)
1004 			nrandom = RandomSize;
1005 		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1006 		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1007 		m->u.clientHello.compressors = newbytes(1);
1008 		m->u.clientHello.compressors->data[0] = CompressionNull;
1009 		goto Ok;
1010 	}
1011 
1012 	md5(p, 4, 0, &c->hsmd5);
1013 	sha1(p, 4, 0, &c->hssha1);
1014 
1015 	p = tlsReadN(c, n);
1016 	if(p == nil)
1017 		return 0;
1018 
1019 	md5(p, n, 0, &c->hsmd5);
1020 	sha1(p, n, 0, &c->hssha1);
1021 
1022 	m->tag = type;
1023 
1024 	switch(type) {
1025 	default:
1026 		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1027 		goto Err;
1028 	case HClientHello:
1029 		if(n < 2)
1030 			goto Short;
1031 		m->u.clientHello.version = get16(p);
1032 		p += 2;
1033 		n -= 2;
1034 
1035 		if(n < RandomSize)
1036 			goto Short;
1037 		memmove(m->u.clientHello.random, p, RandomSize);
1038 		p += RandomSize;
1039 		n -= RandomSize;
1040 		if(n < 1 || n < p[0]+1)
1041 			goto Short;
1042 		m->u.clientHello.sid = makebytes(p+1, p[0]);
1043 		p += m->u.clientHello.sid->len+1;
1044 		n -= m->u.clientHello.sid->len+1;
1045 
1046 		if(n < 2)
1047 			goto Short;
1048 		nn = get16(p);
1049 		p += 2;
1050 		n -= 2;
1051 
1052 		if((nn & 1) || n < nn || nn < 2)
1053 			goto Short;
1054 		m->u.clientHello.ciphers = newints(nn >> 1);
1055 		for(i = 0; i < nn; i += 2)
1056 			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1057 		p += nn;
1058 		n -= nn;
1059 
1060 		if(n < 1 || n < p[0]+1 || p[0] == 0)
1061 			goto Short;
1062 		nn = p[0];
1063 		m->u.clientHello.compressors = newbytes(nn);
1064 		memmove(m->u.clientHello.compressors->data, p+1, nn);
1065 		n -= nn + 1;
1066 		break;
1067 	case HServerHello:
1068 		if(n < 2)
1069 			goto Short;
1070 		m->u.serverHello.version = get16(p);
1071 		p += 2;
1072 		n -= 2;
1073 
1074 		if(n < RandomSize)
1075 			goto Short;
1076 		memmove(m->u.serverHello.random, p, RandomSize);
1077 		p += RandomSize;
1078 		n -= RandomSize;
1079 
1080 		if(n < 1 || n < p[0]+1)
1081 			goto Short;
1082 		m->u.serverHello.sid = makebytes(p+1, p[0]);
1083 		p += m->u.serverHello.sid->len+1;
1084 		n -= m->u.serverHello.sid->len+1;
1085 
1086 		if(n < 3)
1087 			goto Short;
1088 		m->u.serverHello.cipher = get16(p);
1089 		m->u.serverHello.compressor = p[2];
1090 		n -= 3;
1091 		break;
1092 	case HCertificate:
1093 		if(n < 3)
1094 			goto Short;
1095 		nn = get24(p);
1096 		p += 3;
1097 		n -= 3;
1098 		if(n != nn)
1099 			goto Short;
1100 		/* certs */
1101 		i = 0;
1102 		while(n > 0) {
1103 			if(n < 3)
1104 				goto Short;
1105 			nn = get24(p);
1106 			p += 3;
1107 			n -= 3;
1108 			if(nn > n)
1109 				goto Short;
1110 			m->u.certificate.ncert = i+1;
1111 			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1112 			m->u.certificate.certs[i] = makebytes(p, nn);
1113 			p += nn;
1114 			n -= nn;
1115 			i++;
1116 		}
1117 		break;
1118 	case HCertificateRequest:
1119 		if(n < 2)
1120 			goto Short;
1121 		nn = get16(p);
1122 		p += 2;
1123 		n -= 2;
1124 		if(nn < 1 || nn > n)
1125 			goto Short;
1126 		m->u.certificateRequest.types = makebytes(p, nn);
1127 		nn = get24(p);
1128 		p += 3;
1129 		n -= 3;
1130 		if(nn == 0 || n != nn)
1131 			goto Short;
1132 		/* cas */
1133 		i = 0;
1134 		while(n > 0) {
1135 			if(n < 2)
1136 				goto Short;
1137 			nn = get16(p);
1138 			p += 2;
1139 			n -= 2;
1140 			if(nn < 1 || nn > n)
1141 				goto Short;
1142 			m->u.certificateRequest.nca = i+1;
1143 			m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1144 			m->u.certificateRequest.cas[i] = makebytes(p, nn);
1145 			p += nn;
1146 			n -= nn;
1147 			i++;
1148 		}
1149 		break;
1150 	case HServerHelloDone:
1151 		break;
1152 	case HClientKeyExchange:
1153 		/*
1154 		 * this message depends upon the encryption selected
1155 		 * assume rsa.
1156 		 */
1157 		if(c->version == SSL3Version)
1158 			nn = n;
1159 		else{
1160 			if(n < 2)
1161 				goto Short;
1162 			nn = get16(p);
1163 			p += 2;
1164 			n -= 2;
1165 		}
1166 		if(n < nn)
1167 			goto Short;
1168 		m->u.clientKeyExchange.key = makebytes(p, nn);
1169 		n -= nn;
1170 		break;
1171 	case HFinished:
1172 		m->u.finished.n = c->finished.n;
1173 		if(n < m->u.finished.n)
1174 			goto Short;
1175 		memmove(m->u.finished.verify, p, m->u.finished.n);
1176 		n -= m->u.finished.n;
1177 		break;
1178 	}
1179 
1180 	if(type != HClientHello && n != 0)
1181 		goto Short;
1182 Ok:
1183 	if(c->trace){
1184 		char buf[8000];
1185 		c->trace("recv %s", msgPrint(buf, sizeof buf, m));
1186 	}
1187 	return 1;
1188 Short:
1189 	tlsError(c, EDecodeError, "handshake message has invalid length");
1190 Err:
1191 	msgClear(m);
1192 	return 0;
1193 }
1194 
1195 static void
msgClear(Msg * m)1196 msgClear(Msg *m)
1197 {
1198 	int i;
1199 
1200 	switch(m->tag) {
1201 	default:
1202 		sysfatal("msgClear: unknown message type: %d", m->tag);
1203 	case HHelloRequest:
1204 		break;
1205 	case HClientHello:
1206 		freebytes(m->u.clientHello.sid);
1207 		freeints(m->u.clientHello.ciphers);
1208 		freebytes(m->u.clientHello.compressors);
1209 		break;
1210 	case HServerHello:
1211 		freebytes(m->u.clientHello.sid);
1212 		break;
1213 	case HCertificate:
1214 		for(i=0; i<m->u.certificate.ncert; i++)
1215 			freebytes(m->u.certificate.certs[i]);
1216 		free(m->u.certificate.certs);
1217 		break;
1218 	case HCertificateRequest:
1219 		freebytes(m->u.certificateRequest.types);
1220 		for(i=0; i<m->u.certificateRequest.nca; i++)
1221 			freebytes(m->u.certificateRequest.cas[i]);
1222 		free(m->u.certificateRequest.cas);
1223 		break;
1224 	case HServerHelloDone:
1225 		break;
1226 	case HClientKeyExchange:
1227 		freebytes(m->u.clientKeyExchange.key);
1228 		break;
1229 	case HFinished:
1230 		break;
1231 	}
1232 	memset(m, 0, sizeof(Msg));
1233 }
1234 
1235 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1236 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1237 {
1238 	int i;
1239 
1240 	if(s0)
1241 		bs = seprint(bs, be, "%s", s0);
1242 	bs = seprint(bs, be, "[");
1243 	if(b == nil)
1244 		bs = seprint(bs, be, "nil");
1245 	else
1246 		for(i=0; i<b->len; i++)
1247 			bs = seprint(bs, be, "%.2x ", b->data[i]);
1248 	bs = seprint(bs, be, "]");
1249 	if(s1)
1250 		bs = seprint(bs, be, "%s", s1);
1251 	return bs;
1252 }
1253 
1254 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1255 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1256 {
1257 	int i;
1258 
1259 	if(s0)
1260 		bs = seprint(bs, be, "%s", s0);
1261 	bs = seprint(bs, be, "[");
1262 	if(b == nil)
1263 		bs = seprint(bs, be, "nil");
1264 	else
1265 		for(i=0; i<b->len; i++)
1266 			bs = seprint(bs, be, "%x ", b->data[i]);
1267 	bs = seprint(bs, be, "]");
1268 	if(s1)
1269 		bs = seprint(bs, be, "%s", s1);
1270 	return bs;
1271 }
1272 
1273 static char*
msgPrint(char * buf,int n,Msg * m)1274 msgPrint(char *buf, int n, Msg *m)
1275 {
1276 	int i;
1277 	char *bs = buf, *be = buf+n;
1278 
1279 	switch(m->tag) {
1280 	default:
1281 		bs = seprint(bs, be, "unknown %d\n", m->tag);
1282 		break;
1283 	case HClientHello:
1284 		bs = seprint(bs, be, "ClientHello\n");
1285 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1286 		bs = seprint(bs, be, "\trandom: ");
1287 		for(i=0; i<RandomSize; i++)
1288 			bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1289 		bs = seprint(bs, be, "\n");
1290 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1291 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1292 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1293 		break;
1294 	case HServerHello:
1295 		bs = seprint(bs, be, "ServerHello\n");
1296 		bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1297 		bs = seprint(bs, be, "\trandom: ");
1298 		for(i=0; i<RandomSize; i++)
1299 			bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1300 		bs = seprint(bs, be, "\n");
1301 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1302 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1303 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1304 		break;
1305 	case HCertificate:
1306 		bs = seprint(bs, be, "Certificate\n");
1307 		for(i=0; i<m->u.certificate.ncert; i++)
1308 			bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1309 		break;
1310 	case HCertificateRequest:
1311 		bs = seprint(bs, be, "CertificateRequest\n");
1312 		bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1313 		bs = seprint(bs, be, "\tcertificateauthorities\n");
1314 		for(i=0; i<m->u.certificateRequest.nca; i++)
1315 			bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1316 		break;
1317 	case HServerHelloDone:
1318 		bs = seprint(bs, be, "ServerHelloDone\n");
1319 		break;
1320 	case HClientKeyExchange:
1321 		bs = seprint(bs, be, "HClientKeyExchange\n");
1322 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1323 		break;
1324 	case HFinished:
1325 		bs = seprint(bs, be, "HFinished\n");
1326 		for(i=0; i<m->u.finished.n; i++)
1327 			bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1328 		bs = seprint(bs, be, "\n");
1329 		break;
1330 	}
1331 	USED(bs);
1332 	return buf;
1333 }
1334 
1335 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1336 tlsError(TlsConnection *c, int err, char *fmt, ...)
1337 {
1338 	char msg[512];
1339 	va_list arg;
1340 
1341 	va_start(arg, fmt);
1342 	vseprint(msg, msg+sizeof(msg), fmt, arg);
1343 	va_end(arg);
1344 	if(c->trace)
1345 		c->trace("tlsError: %s\n", msg);
1346 	else if(c->erred)
1347 		fprint(2, "double error: %r, %s", msg);
1348 	else
1349 		werrstr("tls: local %s", msg);
1350 	c->erred = 1;
1351 	fprint(c->ctl, "alert %d", err);
1352 }
1353 
1354 // commit to specific version number
1355 static int
setVersion(TlsConnection * c,int version)1356 setVersion(TlsConnection *c, int version)
1357 {
1358 	if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1359 		return -1;
1360 	if(version > c->version)
1361 		version = c->version;
1362 	if(version == SSL3Version) {
1363 		c->version = version;
1364 		c->finished.n = SSL3FinishedLen;
1365 	}else if(version == TLSVersion){
1366 		c->version = version;
1367 		c->finished.n = TLSFinishedLen;
1368 	}else
1369 		return -1;
1370 	c->verset = 1;
1371 	return fprint(c->ctl, "version 0x%x", version);
1372 }
1373 
1374 // confirm that received Finished message matches the expected value
1375 static int
finishedMatch(TlsConnection * c,Finished * f)1376 finishedMatch(TlsConnection *c, Finished *f)
1377 {
1378 	return memcmp(f->verify, c->finished.verify, f->n) == 0;
1379 }
1380 
1381 // free memory associated with TlsConnection struct
1382 //		(but don't close the TLS channel itself)
1383 static void
tlsConnectionFree(TlsConnection * c)1384 tlsConnectionFree(TlsConnection *c)
1385 {
1386 	tlsSecClose(c->sec);
1387 	freebytes(c->sid);
1388 	freebytes(c->cert);
1389 	memset(c, 0, sizeof(c));
1390 	free(c);
1391 }
1392 
1393 
1394 //================= cipher choices ========================
1395 
1396 static int weakCipher[CipherMax] =
1397 {
1398 	1,	/* TLS_NULL_WITH_NULL_NULL */
1399 	1,	/* TLS_RSA_WITH_NULL_MD5 */
1400 	1,	/* TLS_RSA_WITH_NULL_SHA */
1401 	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1402 	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
1403 	0,	/* TLS_RSA_WITH_RC4_128_SHA */
1404 	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1405 	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
1406 	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1407 	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
1408 	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1409 	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1410 	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
1411 	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1412 	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1413 	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
1414 	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1415 	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1416 	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1417 	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1418 	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1419 	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1420 	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1421 	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1422 	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
1423 	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1424 	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
1425 	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1426 };
1427 
1428 static int
setAlgs(TlsConnection * c,int a)1429 setAlgs(TlsConnection *c, int a)
1430 {
1431 	int i;
1432 
1433 	for(i = 0; i < nelem(cipherAlgs); i++){
1434 		if(cipherAlgs[i].tlsid == a){
1435 			c->enc = cipherAlgs[i].enc;
1436 			c->digest = cipherAlgs[i].digest;
1437 			c->nsecret = cipherAlgs[i].nsecret;
1438 			if(c->nsecret > MaxKeyData)
1439 				return 0;
1440 			return 1;
1441 		}
1442 	}
1443 	return 0;
1444 }
1445 
1446 static int
okCipher(Ints * cv)1447 okCipher(Ints *cv)
1448 {
1449 	int weak, i, j, c;
1450 
1451 	weak = 1;
1452 	for(i = 0; i < cv->len; i++) {
1453 		c = cv->data[i];
1454 		if(c >= CipherMax)
1455 			weak = 0;
1456 		else
1457 			weak &= weakCipher[c];
1458 		for(j = 0; j < nelem(cipherAlgs); j++)
1459 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1460 				return c;
1461 	}
1462 	if(weak)
1463 		return -2;
1464 	return -1;
1465 }
1466 
1467 static int
okCompression(Bytes * cv)1468 okCompression(Bytes *cv)
1469 {
1470 	int i, j, c;
1471 
1472 	for(i = 0; i < cv->len; i++) {
1473 		c = cv->data[i];
1474 		for(j = 0; j < nelem(compressors); j++) {
1475 			if(compressors[j] == c)
1476 				return c;
1477 		}
1478 	}
1479 	return -1;
1480 }
1481 
1482 static Lock	ciphLock;
1483 static int	nciphers;
1484 
1485 static int
initCiphers(void)1486 initCiphers(void)
1487 {
1488 	enum {MaxAlgF = 1024, MaxAlgs = 10};
1489 	char s[MaxAlgF], *flds[MaxAlgs];
1490 	int i, j, n, ok;
1491 
1492 	lock(&ciphLock);
1493 	if(nciphers){
1494 		unlock(&ciphLock);
1495 		return nciphers;
1496 	}
1497 	j = open("#a/tls/encalgs", OREAD);
1498 	if(j < 0){
1499 		werrstr("can't open #a/tls/encalgs: %r");
1500 		return 0;
1501 	}
1502 	n = read(j, s, MaxAlgF-1);
1503 	close(j);
1504 	if(n <= 0){
1505 		werrstr("nothing in #a/tls/encalgs: %r");
1506 		return 0;
1507 	}
1508 	s[n] = 0;
1509 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1510 	for(i = 0; i < nelem(cipherAlgs); i++){
1511 		ok = 0;
1512 		for(j = 0; j < n; j++){
1513 			if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1514 				ok = 1;
1515 				break;
1516 			}
1517 		}
1518 		cipherAlgs[i].ok = ok;
1519 	}
1520 
1521 	j = open("#a/tls/hashalgs", OREAD);
1522 	if(j < 0){
1523 		werrstr("can't open #a/tls/hashalgs: %r");
1524 		return 0;
1525 	}
1526 	n = read(j, s, MaxAlgF-1);
1527 	close(j);
1528 	if(n <= 0){
1529 		werrstr("nothing in #a/tls/hashalgs: %r");
1530 		return 0;
1531 	}
1532 	s[n] = 0;
1533 	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1534 	for(i = 0; i < nelem(cipherAlgs); i++){
1535 		ok = 0;
1536 		for(j = 0; j < n; j++){
1537 			if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1538 				ok = 1;
1539 				break;
1540 			}
1541 		}
1542 		cipherAlgs[i].ok &= ok;
1543 		if(cipherAlgs[i].ok)
1544 			nciphers++;
1545 	}
1546 	unlock(&ciphLock);
1547 	return nciphers;
1548 }
1549 
1550 static Ints*
makeciphers(void)1551 makeciphers(void)
1552 {
1553 	Ints *is;
1554 	int i, j;
1555 
1556 	is = newints(nciphers);
1557 	j = 0;
1558 	for(i = 0; i < nelem(cipherAlgs); i++){
1559 		if(cipherAlgs[i].ok)
1560 			is->data[j++] = cipherAlgs[i].tlsid;
1561 	}
1562 	return is;
1563 }
1564 
1565 
1566 
1567 //================= security functions ========================
1568 
1569 // given X.509 certificate, set up connection to factotum
1570 //	for using corresponding private key
1571 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1572 factotum_rsa_open(uchar *cert, int certlen)
1573 {
1574 	int afd;
1575 	char *s;
1576 	mpint *pub = nil;
1577 	RSApub *rsapub;
1578 	AuthRpc *rpc;
1579 
1580 	// start talking to factotum
1581 	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1582 		return nil;
1583 	if((rpc = auth_allocrpc(afd)) == nil){
1584 		close(afd);
1585 		return nil;
1586 	}
1587 	s = "proto=rsa service=tls role=client";
1588 	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1589 		factotum_rsa_close(rpc);
1590 		return nil;
1591 	}
1592 
1593 	// roll factotum keyring around to match certificate
1594 	rsapub = X509toRSApub(cert, certlen, nil, 0);
1595 	while(1){
1596 		if(auth_rpc(rpc, "read", nil, 0) != ARok){
1597 			factotum_rsa_close(rpc);
1598 			rpc = nil;
1599 			goto done;
1600 		}
1601 		pub = strtomp(rpc->arg, nil, 16, nil);
1602 		assert(pub != nil);
1603 		if(mpcmp(pub,rsapub->n) == 0)
1604 			break;
1605 	}
1606 done:
1607 	mpfree(pub);
1608 	rsapubfree(rsapub);
1609 	return rpc;
1610 }
1611 
1612 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1613 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1614 {
1615 	char *p;
1616 	int rv;
1617 
1618 	if((p = mptoa(cipher, 16, nil, 0)) == nil)
1619 		return nil;
1620 	rv = auth_rpc(rpc, "write", p, strlen(p));
1621 	free(p);
1622 	if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1623 		return nil;
1624 	mpfree(cipher);
1625 	return strtomp(rpc->arg, nil, 16, nil);
1626 }
1627 
1628 static void
factotum_rsa_close(AuthRpc * rpc)1629 factotum_rsa_close(AuthRpc*rpc)
1630 {
1631 	if(!rpc)
1632 		return;
1633 	close(rpc->afd);
1634 	auth_freerpc(rpc);
1635 }
1636 
1637 static void
tlsPmd5(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1638 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1639 {
1640 	uchar ai[MD5dlen], tmp[MD5dlen];
1641 	int i, n;
1642 	MD5state *s;
1643 
1644 	// generate a1
1645 	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1646 	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1647 	hmac_md5(seed1, nseed1, key, nkey, ai, s);
1648 
1649 	while(nbuf > 0) {
1650 		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1651 		s = hmac_md5(label, nlabel, key, nkey, nil, s);
1652 		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1653 		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1654 		n = MD5dlen;
1655 		if(n > nbuf)
1656 			n = nbuf;
1657 		for(i = 0; i < n; i++)
1658 			buf[i] ^= tmp[i];
1659 		buf += n;
1660 		nbuf -= n;
1661 		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1662 		memmove(ai, tmp, MD5dlen);
1663 	}
1664 }
1665 
1666 static void
tlsPsha1(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1667 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1668 {
1669 	uchar ai[SHA1dlen], tmp[SHA1dlen];
1670 	int i, n;
1671 	SHAstate *s;
1672 
1673 	// generate a1
1674 	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1675 	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1676 	hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1677 
1678 	while(nbuf > 0) {
1679 		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1680 		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1681 		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1682 		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1683 		n = SHA1dlen;
1684 		if(n > nbuf)
1685 			n = nbuf;
1686 		for(i = 0; i < n; i++)
1687 			buf[i] ^= tmp[i];
1688 		buf += n;
1689 		nbuf -= n;
1690 		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1691 		memmove(ai, tmp, SHA1dlen);
1692 	}
1693 }
1694 
1695 // fill buf with md5(args)^sha1(args)
1696 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1697 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1698 {
1699 	int i;
1700 	int nlabel = strlen(label);
1701 	int n = (nkey + 1) >> 1;
1702 
1703 	for(i = 0; i < nbuf; i++)
1704 		buf[i] = 0;
1705 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1706 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1707 }
1708 
1709 /*
1710  * for setting server session id's
1711  */
1712 static Lock	sidLock;
1713 static long	maxSid = 1;
1714 
1715 /* the keys are verified to have the same public components
1716  * and to function correctly with pkcs 1 encryption and decryption. */
1717 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1718 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1719 {
1720 	TlsSec *sec = emalloc(sizeof(*sec));
1721 
1722 	USED(csid); USED(ncsid);  // ignore csid for now
1723 
1724 	memmove(sec->crandom, crandom, RandomSize);
1725 	sec->clientVers = cvers;
1726 
1727 	put32(sec->srandom, time(0));
1728 	genrandom(sec->srandom+4, RandomSize-4);
1729 	memmove(srandom, sec->srandom, RandomSize);
1730 
1731 	/*
1732 	 * make up a unique sid: use our pid, and and incrementing id
1733 	 * can signal no sid by setting nssid to 0.
1734 	 */
1735 	memset(ssid, 0, SidSize);
1736 	put32(ssid, getpid());
1737 	lock(&sidLock);
1738 	put32(ssid+4, maxSid++);
1739 	unlock(&sidLock);
1740 	*nssid = SidSize;
1741 	return sec;
1742 }
1743 
1744 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1745 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1746 {
1747 	if(epm != nil){
1748 		if(setVers(sec, vers) < 0)
1749 			goto Err;
1750 		serverMasterSecret(sec, epm, nepm);
1751 	}else if(sec->vers != vers){
1752 		werrstr("mismatched session versions");
1753 		goto Err;
1754 	}
1755 	setSecrets(sec, kd, nkd);
1756 	return 0;
1757 Err:
1758 	sec->ok = -1;
1759 	return -1;
1760 }
1761 
1762 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1763 tlsSecInitc(int cvers, uchar *crandom)
1764 {
1765 	TlsSec *sec = emalloc(sizeof(*sec));
1766 	sec->clientVers = cvers;
1767 	put32(sec->crandom, time(0));
1768 	genrandom(sec->crandom+4, RandomSize-4);
1769 	memmove(crandom, sec->crandom, RandomSize);
1770 	return sec;
1771 }
1772 
1773 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1774 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1775 {
1776 	RSApub *pub;
1777 
1778 	pub = nil;
1779 
1780 	USED(sid);
1781 	USED(nsid);
1782 
1783 	memmove(sec->srandom, srandom, RandomSize);
1784 
1785 	if(setVers(sec, vers) < 0)
1786 		goto Err;
1787 
1788 	pub = X509toRSApub(cert, ncert, nil, 0);
1789 	if(pub == nil){
1790 		werrstr("invalid x509/rsa certificate");
1791 		goto Err;
1792 	}
1793 	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1794 		goto Err;
1795 	rsapubfree(pub);
1796 	setSecrets(sec, kd, nkd);
1797 	return 0;
1798 
1799 Err:
1800 	if(pub != nil)
1801 		rsapubfree(pub);
1802 	sec->ok = -1;
1803 	return -1;
1804 }
1805 
1806 static int
tlsSecFinished(TlsSec * sec,MD5state md5,SHAstate sha1,uchar * fin,int nfin,int isclient)1807 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1808 {
1809 	if(sec->nfin != nfin){
1810 		sec->ok = -1;
1811 		werrstr("invalid finished exchange");
1812 		return -1;
1813 	}
1814 	md5.malloced = 0;
1815 	sha1.malloced = 0;
1816 	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
1817 	return 1;
1818 }
1819 
1820 static void
tlsSecOk(TlsSec * sec)1821 tlsSecOk(TlsSec *sec)
1822 {
1823 	if(sec->ok == 0)
1824 		sec->ok = 1;
1825 }
1826 
1827 static void
tlsSecKill(TlsSec * sec)1828 tlsSecKill(TlsSec *sec)
1829 {
1830 	if(!sec)
1831 		return;
1832 	factotum_rsa_close(sec->rpc);
1833 	sec->ok = -1;
1834 }
1835 
1836 static void
tlsSecClose(TlsSec * sec)1837 tlsSecClose(TlsSec *sec)
1838 {
1839 	if(!sec)
1840 		return;
1841 	factotum_rsa_close(sec->rpc);
1842 	free(sec->server);
1843 	free(sec);
1844 }
1845 
1846 static int
setVers(TlsSec * sec,int v)1847 setVers(TlsSec *sec, int v)
1848 {
1849 	if(v == SSL3Version){
1850 		sec->setFinished = sslSetFinished;
1851 		sec->nfin = SSL3FinishedLen;
1852 		sec->prf = sslPRF;
1853 	}else if(v == TLSVersion){
1854 		sec->setFinished = tlsSetFinished;
1855 		sec->nfin = TLSFinishedLen;
1856 		sec->prf = tlsPRF;
1857 	}else{
1858 		werrstr("invalid version");
1859 		return -1;
1860 	}
1861 	sec->vers = v;
1862 	return 0;
1863 }
1864 
1865 /*
1866  * generate secret keys from the master secret.
1867  *
1868  * different crypto selections will require different amounts
1869  * of key expansion and use of key expansion data,
1870  * but it's all generated using the same function.
1871  */
1872 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)1873 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1874 {
1875 	(*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1876 			sec->srandom, RandomSize, sec->crandom, RandomSize);
1877 }
1878 
1879 /*
1880  * set the master secret from the pre-master secret.
1881  */
1882 static void
setMasterSecret(TlsSec * sec,Bytes * pm)1883 setMasterSecret(TlsSec *sec, Bytes *pm)
1884 {
1885 	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1886 			sec->crandom, RandomSize, sec->srandom, RandomSize);
1887 }
1888 
1889 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)1890 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1891 {
1892 	Bytes *pm;
1893 
1894 	pm = pkcs1_decrypt(sec, epm, nepm);
1895 
1896 	// if the client messed up, just continue as if everything is ok,
1897 	// to prevent attacks to check for correctly formatted messages.
1898 	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1899 	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1900 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1901 			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1902 		sec->ok = -1;
1903 		if(pm != nil)
1904 			freebytes(pm);
1905 		pm = newbytes(MasterSecretSize);
1906 		genrandom(pm->data, MasterSecretSize);
1907 	}
1908 	setMasterSecret(sec, pm);
1909 	memset(pm->data, 0, pm->len);
1910 	freebytes(pm);
1911 }
1912 
1913 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)1914 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1915 {
1916 	Bytes *pm, *key;
1917 
1918 	pm = newbytes(MasterSecretSize);
1919 	put16(pm->data, sec->clientVers);
1920 	genrandom(pm->data+2, MasterSecretSize - 2);
1921 
1922 	setMasterSecret(sec, pm);
1923 
1924 	key = pkcs1_encrypt(pm, pub, 2);
1925 	memset(pm->data, 0, pm->len);
1926 	freebytes(pm);
1927 	if(key == nil){
1928 		werrstr("tls pkcs1_encrypt failed");
1929 		return -1;
1930 	}
1931 
1932 	*nepm = key->len;
1933 	*epm = malloc(*nepm);
1934 	if(*epm == nil){
1935 		freebytes(key);
1936 		werrstr("out of memory");
1937 		return -1;
1938 	}
1939 	memmove(*epm, key->data, *nepm);
1940 
1941 	freebytes(key);
1942 
1943 	return 1;
1944 }
1945 
1946 static void
sslSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)1947 sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1948 {
1949 	DigestState *s;
1950 	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
1951 	char *label;
1952 
1953 	if(isClient)
1954 		label = "CLNT";
1955 	else
1956 		label = "SRVR";
1957 
1958 	md5((uchar*)label, 4, nil, &hsmd5);
1959 	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
1960 	memset(pad, 0x36, 48);
1961 	md5(pad, 48, nil, &hsmd5);
1962 	md5(nil, 0, h0, &hsmd5);
1963 	memset(pad, 0x5C, 48);
1964 	s = md5(sec->sec, MasterSecretSize, nil, nil);
1965 	s = md5(pad, 48, nil, s);
1966 	md5(h0, MD5dlen, finished, s);
1967 
1968 	sha1((uchar*)label, 4, nil, &hssha1);
1969 	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
1970 	memset(pad, 0x36, 40);
1971 	sha1(pad, 40, nil, &hssha1);
1972 	sha1(nil, 0, h1, &hssha1);
1973 	memset(pad, 0x5C, 40);
1974 	s = sha1(sec->sec, MasterSecretSize, nil, nil);
1975 	s = sha1(pad, 40, nil, s);
1976 	sha1(h1, SHA1dlen, finished + MD5dlen, s);
1977 }
1978 
1979 // fill "finished" arg with md5(args)^sha1(args)
1980 static void
tlsSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)1981 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1982 {
1983 	uchar h0[MD5dlen], h1[SHA1dlen];
1984 	char *label;
1985 
1986 	// get current hash value, but allow further messages to be hashed in
1987 	md5(nil, 0, h0, &hsmd5);
1988 	sha1(nil, 0, h1, &hssha1);
1989 
1990 	if(isClient)
1991 		label = "client finished";
1992 	else
1993 		label = "server finished";
1994 	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
1995 }
1996 
1997 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1998 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1999 {
2000 	DigestState *s;
2001 	uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2002 	int i, n, len;
2003 
2004 	USED(label);
2005 	len = 1;
2006 	while(nbuf > 0){
2007 		if(len > 26)
2008 			return;
2009 		for(i = 0; i < len; i++)
2010 			tmp[i] = 'A' - 1 + len;
2011 		s = sha1(tmp, len, nil, nil);
2012 		s = sha1(key, nkey, nil, s);
2013 		s = sha1(seed0, nseed0, nil, s);
2014 		sha1(seed1, nseed1, sha1dig, s);
2015 		s = md5(key, nkey, nil, nil);
2016 		md5(sha1dig, SHA1dlen, md5dig, s);
2017 		n = MD5dlen;
2018 		if(n > nbuf)
2019 			n = nbuf;
2020 		memmove(buf, md5dig, n);
2021 		buf += n;
2022 		nbuf -= n;
2023 		len++;
2024 	}
2025 }
2026 
2027 static mpint*
bytestomp(Bytes * bytes)2028 bytestomp(Bytes* bytes)
2029 {
2030 	mpint* ans;
2031 
2032 	ans = betomp(bytes->data, bytes->len, nil);
2033 	return ans;
2034 }
2035 
2036 /*
2037  * Convert mpint* to Bytes, putting high order byte first.
2038  */
2039 static Bytes*
mptobytes(mpint * big)2040 mptobytes(mpint* big)
2041 {
2042 	int n, m;
2043 	uchar *a;
2044 	Bytes* ans;
2045 
2046 	n = (mpsignif(big)+7)/8;
2047 	m = mptobe(big, nil, n, &a);
2048 	ans = makebytes(a, m);
2049 	return ans;
2050 }
2051 
2052 // Do RSA computation on block according to key, and pad
2053 // result on left with zeros to make it modlen long.
2054 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2055 rsacomp(Bytes* block, RSApub* key, int modlen)
2056 {
2057 	mpint *x, *y;
2058 	Bytes *a, *ybytes;
2059 	int ylen;
2060 
2061 	x = bytestomp(block);
2062 	y = rsaencrypt(key, x, nil);
2063 	mpfree(x);
2064 	ybytes = mptobytes(y);
2065 	ylen = ybytes->len;
2066 
2067 	if(ylen < modlen) {
2068 		a = newbytes(modlen);
2069 		memset(a->data, 0, modlen-ylen);
2070 		memmove(a->data+modlen-ylen, ybytes->data, ylen);
2071 		freebytes(ybytes);
2072 		ybytes = a;
2073 	}
2074 	else if(ylen > modlen) {
2075 		// assume it has leading zeros (mod should make it so)
2076 		a = newbytes(modlen);
2077 		memmove(a->data, ybytes->data, modlen);
2078 		freebytes(ybytes);
2079 		ybytes = a;
2080 	}
2081 	mpfree(y);
2082 	return ybytes;
2083 }
2084 
2085 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2086 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2087 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2088 {
2089 	Bytes *pad, *eb, *ans;
2090 	int i, dlen, padlen, modlen;
2091 
2092 	modlen = (mpsignif(key->n)+7)/8;
2093 	dlen = data->len;
2094 	if(modlen < 12 || dlen > modlen - 11)
2095 		return nil;
2096 	padlen = modlen - 3 - dlen;
2097 	pad = newbytes(padlen);
2098 	genrandom(pad->data, padlen);
2099 	for(i = 0; i < padlen; i++) {
2100 		if(blocktype == 0)
2101 			pad->data[i] = 0;
2102 		else if(blocktype == 1)
2103 			pad->data[i] = 255;
2104 		else if(pad->data[i] == 0)
2105 			pad->data[i] = 1;
2106 	}
2107 	eb = newbytes(modlen);
2108 	eb->data[0] = 0;
2109 	eb->data[1] = blocktype;
2110 	memmove(eb->data+2, pad->data, padlen);
2111 	eb->data[padlen+2] = 0;
2112 	memmove(eb->data+padlen+3, data->data, dlen);
2113 	ans = rsacomp(eb, key, modlen);
2114 	freebytes(eb);
2115 	freebytes(pad);
2116 	return ans;
2117 }
2118 
2119 // decrypt data according to PKCS#1, with given key.
2120 // expect a block type of 2.
2121 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2122 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2123 {
2124 	Bytes *eb, *ans = nil;
2125 	int i, modlen;
2126 	mpint *x, *y;
2127 
2128 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
2129 	if(nepm != modlen)
2130 		return nil;
2131 	x = betomp(epm, nepm, nil);
2132 	y = factotum_rsa_decrypt(sec->rpc, x);
2133 	if(y == nil)
2134 		return nil;
2135 	eb = mptobytes(y);
2136 	if(eb->len < modlen){ // pad on left with zeros
2137 		ans = newbytes(modlen);
2138 		memset(ans->data, 0, modlen-eb->len);
2139 		memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2140 		freebytes(eb);
2141 		eb = ans;
2142 	}
2143 	if(eb->data[0] == 0 && eb->data[1] == 2) {
2144 		for(i = 2; i < modlen; i++)
2145 			if(eb->data[i] == 0)
2146 				break;
2147 		if(i < modlen - 1)
2148 			ans = makebytes(eb->data+i+1, modlen-(i+1));
2149 	}
2150 	freebytes(eb);
2151 	return ans;
2152 }
2153 
2154 
2155 //================= general utility functions ========================
2156 
2157 static void *
emalloc(int n)2158 emalloc(int n)
2159 {
2160 	void *p;
2161 	if(n==0)
2162 		n=1;
2163 	p = malloc(n);
2164 	if(p == nil){
2165 		exits("out of memory");
2166 	}
2167 	memset(p, 0, n);
2168 	return p;
2169 }
2170 
2171 static void *
erealloc(void * ReallocP,int ReallocN)2172 erealloc(void *ReallocP, int ReallocN)
2173 {
2174 	if(ReallocN == 0)
2175 		ReallocN = 1;
2176 	if(!ReallocP)
2177 		ReallocP = emalloc(ReallocN);
2178 	else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2179 		exits("out of memory");
2180 	}
2181 	return(ReallocP);
2182 }
2183 
2184 static void
put32(uchar * p,u32int x)2185 put32(uchar *p, u32int x)
2186 {
2187 	p[0] = x>>24;
2188 	p[1] = x>>16;
2189 	p[2] = x>>8;
2190 	p[3] = x;
2191 }
2192 
2193 static void
put24(uchar * p,int x)2194 put24(uchar *p, int x)
2195 {
2196 	p[0] = x>>16;
2197 	p[1] = x>>8;
2198 	p[2] = x;
2199 }
2200 
2201 static void
put16(uchar * p,int x)2202 put16(uchar *p, int x)
2203 {
2204 	p[0] = x>>8;
2205 	p[1] = x;
2206 }
2207 
2208 static u32int
get32(uchar * p)2209 get32(uchar *p)
2210 {
2211 	return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2212 }
2213 
2214 static int
get24(uchar * p)2215 get24(uchar *p)
2216 {
2217 	return (p[0]<<16)|(p[1]<<8)|p[2];
2218 }
2219 
2220 static int
get16(uchar * p)2221 get16(uchar *p)
2222 {
2223 	return (p[0]<<8)|p[1];
2224 }
2225 
/*
 * Like ANSI offsetof(), but with the arguments the other way round:
 * OFFSET(member, type) is the byte offset of member within type, as an int.
 */
#define OFFSET(x, s) ((int)&((s*)0)->x)
2228 
2229 /*
2230  * malloc and return a new Bytes structure capable of
2231  * holding len bytes. (len >= 0)
2232  * Used to use crypt_malloc, which aborts if malloc fails.
2233  */
2234 static Bytes*
newbytes(int len)2235 newbytes(int len)
2236 {
2237 	Bytes* ans;
2238 
2239 	ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2240 	ans->len = len;
2241 	return ans;
2242 }
2243 
2244 /*
2245  * newbytes(len), with data initialized from buf
2246  */
2247 static Bytes*
makebytes(uchar * buf,int len)2248 makebytes(uchar* buf, int len)
2249 {
2250 	Bytes* ans;
2251 
2252 	ans = newbytes(len);
2253 	memmove(ans->data, buf, len);
2254 	return ans;
2255 }
2256 
2257 static void
freebytes(Bytes * b)2258 freebytes(Bytes* b)
2259 {
2260 	if(b != nil)
2261 		free(b);
2262 }
2263 
2264 /* len is number of ints */
2265 static Ints*
newints(int len)2266 newints(int len)
2267 {
2268 	Ints* ans;
2269 
2270 	ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2271 	ans->len = len;
2272 	return ans;
2273 }
2274 
2275 static Ints*
makeints(int * buf,int len)2276 makeints(int* buf, int len)
2277 {
2278 	Ints* ans;
2279 
2280 	ans = newints(len);
2281 	if(len > 0)
2282 		memmove(ans->data, buf, len*sizeof(int));
2283 	return ans;
2284 }
2285 
2286 static void
freeints(Ints * b)2287 freeints(Ints* b)
2288 {
2289 	if(b != nil)
2290 		free(b);
2291 }
2292