1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
8 // The main groups of functions are:
9 // client/server - main handshake protocol definition
// message functions - formatting handshake messages
11 // cipher choices - catalog of digest and encrypt algorithms
12 // security functions - PKCS#1, sslHMAC, session keygen
13 // general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a. See also /lib/rfc/rfc2246.
16
// Sizes and flags used throughout the handshake code.
enum {
	TLSFinishedLen = 12,
	SSL3FinishedLen = MD5dlen+SHA1dlen,	// an MD5 digest followed by a SHA1 digest
	MaxKeyData = 136,	// amount of secret we may need; must cover the largest nsecret in cipherAlgs (aes_256_cbc/sha1: 136)
	MaxChunk = 1<<14,	// presumably the record-layer maximum payload; sizes TlsConnection.buf
	RandomSize = 32,	// length of client/server random
	SidSize = 32,	// length of a session id we generate/accept
	MasterSecretSize = 48,
	AQueue = 0,	// msgSend: append the message to sendbuf only
	AFlush = 1,	// msgSend: append, then write everything queued to c->hand
};
28
typedef struct TlsSec TlsSec;

// Counted byte string: data holds len bytes.
// NOTE(review): presumably over-allocated past the struct by newbytes/makebytes - confirm.
typedef struct Bytes{
	int len;
	uchar data[1]; // [len]
} Bytes;

// Counted int array, laid out like Bytes.
typedef struct Ints{
	int len;
	int data[1]; // [len]
} Ints;

// One entry of the cipher suite catalog (cipherAlgs).
typedef struct Algs{
	char *enc;	// encryption algorithm name, e.g. "aes_128_cbc"
	char *digest;	// MAC digest name, e.g. "sha1"
	int nsecret;	// bytes of key material the suite needs
	int tlsid;	// TLS cipher suite number
	int ok;	// not set by the cipherAlgs initializers; presumably set by initCiphers - confirm
} Algs;

// Contents of a Finished message: the first n bytes of verify are used.
// verify is sized for the longer SSL3 form.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;

// Running hashes over all handshake messages, updated by msgHash.
// sha2_256 is only maintained when speaking TLS 1.2.
typedef struct HandHash{
	MD5state md5;
	SHAstate sha1;
	SHA2_256state sha2_256;
} HandHash;

// Handshake state for one TLS connection.
typedef struct TlsConnection{
	TlsSec *sec; // security management goo
	int hand, ctl; // record layer file descriptors
	int erred; // set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging
	int version; // protocol we are speaking
	int verset; // version has been set
	int ver2hi; // server got a version 2 hello
	int isClient; // is this the client or server?
	Bytes *sid; // SessionID
	Bytes *cert; // client side: first certificate the server sent; the rest of the chain is not kept

	Lock statelk;
	int state; // must be set using setstate

	// input buffer for handshake messages;
	// unconsumed bytes are those in [rp, ep) (see tlsReadN)
	uchar buf[MaxChunk+8*1024];
	uchar *rp, *ep;

	uchar crandom[RandomSize]; // client random
	uchar srandom[RandomSize]; // server random
	int clientVersion; // version in ClientHello
	char *digest; // name of digest algorithm to use
	char *enc; // name of encryption algorithm to use
	int nsecret; // amount of secret data to init keys

	// for finished messages
	HandHash hs; // handshake hash
	Finished finished;
} TlsConnection;

// A handshake message being encoded (msgSend) or decoded (msgRecv).
// tag is the handshake type (H* enum) and selects the member of u.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar random[RandomSize];
			Bytes* sid;
			Ints* ciphers;
			Bytes* compressors;
			Ints* sigAlgs;	// nil unless the signature_algorithms extension is present
		} clientHello;
		struct {
			int version;
			uchar random[RandomSize];
			Bytes* sid;
			int cipher;
			int compressor;
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// ncert certificates, the first is the sender's own
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *key;	// encrypted pre-master secret
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;

// Secret state of a handshake.
typedef struct TlsSec{
	char *server; // name of remote; nil for server
	int ok; // <0 killed; == 0 in progress; >0 reusable
	RSApub *rsapub;
	AuthRpc *rpc; // factotum for rsa private key
	uchar sec[MasterSecretSize]; // master secret
	uchar crandom[RandomSize]; // client random
	uchar srandom[RandomSize]; // server random
	int clientVers; // version in ClientHello
	int vers; // final version
	// byte generation and handshake checksum
	// prf(buf, nbuf, key, nkey, label, seed0, nseed0, seed1, nseed1), cf. sslPRF
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, HandHash, uchar*, int);
	int nfin;	// presumably the Finished verify length for vers - confirm
} TlsSec;
140
141
// Protocol version numbers as sent on the wire (16 bits, major<<8 | minor).
enum {
	SSL3Version = 0x0300,
	TLS10Version = 0x0301,
	TLS11Version = 0x0302,
	TLS12Version = 0x0303,
	ProtocolVersion = TLS12Version, // maximum version we speak
	MinProtoVersion = 0x0300, // limits on version we accept
	MaxProtoVersion = 0x03ff,
};

// handshake type (first byte of each handshake message header)
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9, /* local convention; see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};

// alerts (descriptions passed to tlsError)
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};

// cipher suites
// Only the RSA suites listed in cipherAlgs are implemented by this file.
enum {
	TLS_NULL_WITH_NULL_NULL = 0x0000,
	TLS_RSA_WITH_NULL_MD5 = 0x0001,
	TLS_RSA_WITH_NULL_SHA = 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 = 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 = 0x0004,
	TLS_RSA_WITH_RC4_128_SHA = 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 = 0X0006,
	TLS_RSA_WITH_IDEA_CBC_SHA = 0X0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X0008,
	TLS_RSA_WITH_DES_CBC_SHA = 0X0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA = 0X000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA = 0X000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA = 0X000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA = 0X000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA = 0X000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA = 0X0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA = 0X0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA = 0X0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA = 0X0013, // ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA = 0X0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA = 0X0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 = 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 = 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA = 0X0019,
	TLS_DH_anon_WITH_DES_CBC_SHA = 0X001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA = 0X001B,

	TLS_RSA_WITH_AES_128_CBC_SHA = 0X002f, // aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA = 0X0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA = 0X0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0X0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0X0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA = 0X0034,
	TLS_RSA_WITH_AES_256_CBC_SHA = 0X0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA = 0X0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA = 0X0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0X0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0X0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA = 0X003A,
	CipherMax
};

// compression methods
enum {
	CompressionNull = 0,
	CompressionMax
};

// extensions (ExtSigalgs is the signature_algorithms extension type)
enum {
	ExtSigalgs = 0xd,
};

// signature algorithms: hash in the high byte, signature in the low byte (1 = RSA)
enum {
	RSA_PKCS1_SHA1 = 0x0201,
	RSA_PKCS1_SHA256 = 0x0401,
	RSA_PKCS1_SHA384 = 0x0501,
	RSA_PKCS1_SHA512 = 0x0601
};
261
// Cipher suite catalog: {enc, digest, nsecret, tlsid}; ok is left 0 here.
// nsecret has the form 2*(key + IV + MAC-secret bytes), presumably one set
// per direction (RC4 has no IV; 3DES uses a 24-byte key plus 8-byte IV).
static Algs cipherAlgs[] = {
	{"rc4_128", "md5", 2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}
};

// Compression methods offered in ClientHello: only the null method.
static uchar compressors[] = {
	CompressionNull,
};

// Signature algorithms advertised in a TLS 1.2 ClientHello, most preferred first.
static int sigAlgs[] = {
	RSA_PKCS1_SHA256,
	RSA_PKCS1_SHA1,
};
278
279 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
280 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
281
282 static void msgClear(Msg *m);
283 static char* msgPrint(char *buf, int n, Msg *m);
284 static int msgRecv(TlsConnection *c, Msg *m);
285 static int msgSend(TlsConnection *c, Msg *m, int act);
286 static void tlsError(TlsConnection *c, int err, char *msg, ...);
287 #pragma varargck argpos tlsError 3
288 static int setVersion(TlsConnection *c, int version);
289 static int finishedMatch(TlsConnection *c, Finished *f);
290 static void tlsConnectionFree(TlsConnection *c);
291
292 static int setAlgs(TlsConnection *c, int a);
293 static int okCipher(Ints *cv);
294 static int okCompression(Bytes *cv);
295 static int initCiphers(void);
296 static Ints* makeciphers(void);
297
298 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
299 static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
300 static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
301 static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
302 static int tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient);
303 static void tlsSecOk(TlsSec *sec);
304 static void tlsSecKill(TlsSec *sec);
305 static void tlsSecClose(TlsSec *sec);
306 static void setMasterSecret(TlsSec *sec, Bytes *pm);
307 static void serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
308 static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
309 static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
310 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
311 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
312 static void tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
313 static void tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
314 static void sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
315 static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
316 uchar *seed0, int nseed0, uchar *seed1, int nseed1);
317 static int setVers(TlsSec *sec, int version);
318
319 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
320 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
321 static void factotum_rsa_close(AuthRpc*rpc);
322
323 static void* emalloc(int);
324 static void* erealloc(void*, int);
325 static void put32(uchar *p, u32int);
326 static void put24(uchar *p, int);
327 static void put16(uchar *p, int);
328 static u32int get32(uchar *p);
329 static int get24(uchar *p);
330 static int get16(uchar *p);
331 static Bytes* newbytes(int len);
332 static Bytes* makebytes(uchar* buf, int len);
333 static void freebytes(Bytes* b);
334 static Ints* newints(int len);
335 static Ints* makeints(int* buf, int len);
336 static void freeints(Ints* b);
337
338 //================= client/server ========================
339
340 // push TLS onto fd, returning new (application) file descriptor
341 // or -1 if error.
342 int
tlsServer(int fd,TLSconn * conn)343 tlsServer(int fd, TLSconn *conn)
344 {
345 char buf[8];
346 char dname[64];
347 int n, data, ctl, hand;
348 TlsConnection *tls;
349
350 if(conn == nil)
351 return -1;
352 ctl = open("#a/tls/clone", ORDWR);
353 if(ctl < 0)
354 return -1;
355 n = read(ctl, buf, sizeof(buf)-1);
356 if(n < 0){
357 close(ctl);
358 return -1;
359 }
360 buf[n] = 0;
361 snprint(conn->dir, sizeof conn->dir, "#a/tls/%s", buf);
362 snprint(dname, sizeof dname, "#a/tls/%s/hand", buf);
363 hand = open(dname, ORDWR);
364 if(hand < 0){
365 close(ctl);
366 return -1;
367 }
368 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
369 tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
370 snprint(dname, sizeof dname, "#a/tls/%s/data", buf);
371 data = open(dname, ORDWR);
372 close(fd);
373 close(hand);
374 close(ctl);
375 if(data < 0){
376 return -1;
377 }
378 if(tls == nil){
379 close(data);
380 return -1;
381 }
382 if(conn->cert)
383 free(conn->cert);
384 conn->cert = 0; // client certificates are not yet implemented
385 conn->certlen = 0;
386 conn->sessionIDlen = tls->sid->len;
387 conn->sessionID = emalloc(conn->sessionIDlen);
388 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
389 if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
390 tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
391 tlsConnectionFree(tls);
392 return data;
393 }
394
395 // push TLS onto fd, returning new (application) file descriptor
396 // or -1 if error.
397 int
tlsClient(int fd,TLSconn * conn)398 tlsClient(int fd, TLSconn *conn)
399 {
400 char buf[8];
401 char dname[64];
402 int n, data, ctl, hand;
403 TlsConnection *tls;
404
405 if(!conn)
406 return -1;
407 ctl = open("#a/tls/clone", ORDWR);
408 if(ctl < 0)
409 return -1;
410 n = read(ctl, buf, sizeof(buf)-1);
411 if(n < 0){
412 close(ctl);
413 return -1;
414 }
415 buf[n] = 0;
416 snprint(conn->dir, sizeof conn->dir, "#a/tls/%s", buf);
417 snprint(dname, sizeof dname, "#a/tls/%s/hand", buf);
418 hand = open(dname, ORDWR);
419 if(hand < 0){
420 close(ctl);
421 return -1;
422 }
423 snprint(dname, sizeof dname, "#a/tls/%s/data", buf);
424 data = open(dname, ORDWR);
425 if(data < 0) {
426 close(hand);
427 close(ctl);
428 return -1;
429 }
430 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
431 tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
432 close(fd);
433 close(hand);
434 close(ctl);
435 if(tls == nil){
436 close(data);
437 return -1;
438 }
439 conn->certlen = tls->cert->len;
440 conn->cert = emalloc(conn->certlen);
441 memcpy(conn->cert, tls->cert->data, conn->certlen);
442 conn->sessionIDlen = tls->sid->len;
443 conn->sessionID = emalloc(conn->sessionIDlen);
444 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
445 if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
446 tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
447 tlsConnectionFree(tls);
448 return data;
449 }
450
451 static int
countchain(PEMChain * p)452 countchain(PEMChain *p)
453 {
454 int i = 0;
455
456 while (p) {
457 i++;
458 p = p->next;
459 }
460 return i;
461 }
462
463 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...),PEMChain * chp)464 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
465 {
466 TlsConnection *c;
467 Msg m;
468 Bytes *csid;
469 uchar sid[SidSize], kd[MaxKeyData];
470 char *secrets;
471 int cipher, compressor, nsid, rv, numcerts, i;
472
473 if(trace)
474 trace("tlsServer2\n");
475 if(!initCiphers())
476 return nil;
477 c = emalloc(sizeof(TlsConnection));
478 c->ctl = ctl;
479 c->hand = hand;
480 c->trace = trace;
481 c->version = ProtocolVersion;
482
483 memset(&m, 0, sizeof(m));
484 if(!msgRecv(c, &m)){
485 if(trace)
486 trace("initial msgRecv failed\n");
487 goto Err;
488 }
489 if(m.tag != HClientHello) {
490 tlsError(c, EUnexpectedMessage, "expected a client hello");
491 goto Err;
492 }
493 c->clientVersion = m.u.clientHello.version;
494 if(trace)
495 trace("ClientHello version %x\n", c->clientVersion);
496 if(setVersion(c, m.u.clientHello.version) < 0) {
497 tlsError(c, EIllegalParameter, "incompatible version");
498 goto Err;
499 }
500
501 memmove(c->crandom, m.u.clientHello.random, RandomSize);
502 cipher = okCipher(m.u.clientHello.ciphers);
503 if(cipher < 0) {
504 // reply with EInsufficientSecurity if we know that's the case
505 if(cipher == -2)
506 tlsError(c, EInsufficientSecurity, "cipher suites too weak");
507 else
508 tlsError(c, EHandshakeFailure, "no matching cipher suite");
509 goto Err;
510 }
511 if(!setAlgs(c, cipher)){
512 tlsError(c, EHandshakeFailure, "no matching cipher suite");
513 goto Err;
514 }
515 compressor = okCompression(m.u.clientHello.compressors);
516 if(compressor < 0) {
517 tlsError(c, EHandshakeFailure, "no matching compressor");
518 goto Err;
519 }
520
521 csid = m.u.clientHello.sid;
522 if(trace)
523 trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
524 c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
525 if(c->sec == nil){
526 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
527 goto Err;
528 }
529 c->sec->rpc = factotum_rsa_open(cert, ncert);
530 if(c->sec->rpc == nil){
531 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
532 goto Err;
533 }
534 c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
535 msgClear(&m);
536
537 m.tag = HServerHello;
538 m.u.serverHello.version = c->version;
539 memmove(m.u.serverHello.random, c->srandom, RandomSize);
540 m.u.serverHello.cipher = cipher;
541 m.u.serverHello.compressor = compressor;
542 c->sid = makebytes(sid, nsid);
543 m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
544 if(!msgSend(c, &m, AQueue))
545 goto Err;
546 msgClear(&m);
547
548 m.tag = HCertificate;
549 numcerts = countchain(chp);
550 m.u.certificate.ncert = 1 + numcerts;
551 m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
552 m.u.certificate.certs[0] = makebytes(cert, ncert);
553 for (i = 0; i < numcerts && chp; i++, chp = chp->next)
554 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
555 if(!msgSend(c, &m, AQueue))
556 goto Err;
557 msgClear(&m);
558
559 m.tag = HServerHelloDone;
560 if(!msgSend(c, &m, AFlush))
561 goto Err;
562 msgClear(&m);
563
564 if(!msgRecv(c, &m))
565 goto Err;
566 if(m.tag != HClientKeyExchange) {
567 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
568 goto Err;
569 }
570 if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
571 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
572 goto Err;
573 }
574 if(trace)
575 trace("tls secrets\n");
576 secrets = (char*)emalloc(2*c->nsecret);
577 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
578 rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
579 memset(secrets, 0, 2*c->nsecret);
580 free(secrets);
581 memset(kd, 0, c->nsecret);
582 if(rv < 0){
583 tlsError(c, EHandshakeFailure, "can't set keys: %r");
584 goto Err;
585 }
586 msgClear(&m);
587
588 /* no CertificateVerify; skip to Finished */
589 if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
590 tlsError(c, EInternalError, "can't set finished: %r");
591 goto Err;
592 }
593 if(!msgRecv(c, &m))
594 goto Err;
595 if(m.tag != HFinished) {
596 tlsError(c, EUnexpectedMessage, "expected a finished");
597 goto Err;
598 }
599 if(!finishedMatch(c, &m.u.finished)) {
600 tlsError(c, EHandshakeFailure, "finished verification failed");
601 goto Err;
602 }
603 msgClear(&m);
604
605 /* change cipher spec */
606 if(fprint(c->ctl, "changecipher") < 0){
607 tlsError(c, EInternalError, "can't enable cipher: %r");
608 goto Err;
609 }
610
611 if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
612 tlsError(c, EInternalError, "can't set finished: %r");
613 goto Err;
614 }
615 m.tag = HFinished;
616 m.u.finished = c->finished;
617 if(!msgSend(c, &m, AFlush))
618 goto Err;
619 msgClear(&m);
620 if(trace)
621 trace("tls finished\n");
622
623 if(fprint(c->ctl, "opened") < 0)
624 goto Err;
625 tlsSecOk(c->sec);
626 return c;
627
628 Err:
629 msgClear(&m);
630 tlsConnectionFree(c);
631 return 0;
632 }
633
634 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))635 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
636 {
637 TlsConnection *c;
638 Msg m;
639 uchar kd[MaxKeyData], *epm;
640 char *secrets;
641 int creq, nepm, rv;
642
643 if(!initCiphers())
644 return nil;
645 epm = nil;
646 c = emalloc(sizeof(TlsConnection));
647 c->version = ProtocolVersion;
648 c->ctl = ctl;
649 c->hand = hand;
650 c->trace = trace;
651 c->isClient = 1;
652 c->clientVersion = c->version;
653
654 c->sec = tlsSecInitc(c->clientVersion, c->crandom);
655 if(c->sec == nil)
656 goto Err;
657
658 /* client hello */
659 memset(&m, 0, sizeof(m));
660 m.tag = HClientHello;
661 m.u.clientHello.version = c->clientVersion;
662 memmove(m.u.clientHello.random, c->crandom, RandomSize);
663 m.u.clientHello.sid = makebytes(csid, ncsid);
664 m.u.clientHello.ciphers = makeciphers();
665 m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
666 if(c->clientVersion >= TLS12Version)
667 m.u.clientHello.sigAlgs = makeints(sigAlgs, nelem(sigAlgs));
668 if(!msgSend(c, &m, AFlush))
669 goto Err;
670 msgClear(&m);
671
672 /* server hello */
673 if(!msgRecv(c, &m))
674 goto Err;
675 if(m.tag != HServerHello) {
676 tlsError(c, EUnexpectedMessage, "expected a server hello");
677 goto Err;
678 }
679 if(setVersion(c, m.u.serverHello.version) < 0) {
680 tlsError(c, EIllegalParameter, "incompatible version %r");
681 goto Err;
682 }
683 memmove(c->srandom, m.u.serverHello.random, RandomSize);
684 c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
685 if(c->sid->len != 0 && c->sid->len != SidSize) {
686 tlsError(c, EIllegalParameter, "invalid server session identifier");
687 goto Err;
688 }
689 if(!setAlgs(c, m.u.serverHello.cipher)) {
690 tlsError(c, EIllegalParameter, "invalid cipher suite");
691 goto Err;
692 }
693 if(m.u.serverHello.compressor != CompressionNull) {
694 tlsError(c, EIllegalParameter, "invalid compression");
695 goto Err;
696 }
697 msgClear(&m);
698
699 /* certificate */
700 if(!msgRecv(c, &m) || m.tag != HCertificate) {
701 tlsError(c, EUnexpectedMessage, "expected a certificate");
702 goto Err;
703 }
704 if(m.u.certificate.ncert < 1) {
705 tlsError(c, EIllegalParameter, "runt certificate");
706 goto Err;
707 }
708 c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
709 msgClear(&m);
710
711 /* server key exchange (optional) */
712 if(!msgRecv(c, &m))
713 goto Err;
714 if(m.tag == HServerKeyExchange) {
715 tlsError(c, EUnexpectedMessage, "got an server key exchange");
716 goto Err;
717 // If implementing this later, watch out for rollback attack
718 // described in Wagner Schneier 1996, section 4.4.
719 }
720
721 /* certificate request (optional) */
722 creq = 0;
723 if(m.tag == HCertificateRequest) {
724 creq = 1;
725 msgClear(&m);
726 if(!msgRecv(c, &m))
727 goto Err;
728 }
729
730 if(m.tag != HServerHelloDone) {
731 tlsError(c, EUnexpectedMessage, "expected a server hello done");
732 goto Err;
733 }
734 msgClear(&m);
735
736 if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
737 c->cert->data, c->cert->len, c->version, &epm, &nepm,
738 kd, c->nsecret) < 0){
739 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
740 goto Err;
741 }
742 secrets = (char*)emalloc(2*c->nsecret);
743 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
744 rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
745 memset(secrets, 0, 2*c->nsecret);
746 free(secrets);
747 memset(kd, 0, c->nsecret);
748 if(rv < 0){
749 tlsError(c, EHandshakeFailure, "can't set keys: %r");
750 goto Err;
751 }
752
753 if(creq) {
754 /* send a zero length certificate */
755 m.tag = HCertificate;
756 if(!msgSend(c, &m, AFlush))
757 goto Err;
758 msgClear(&m);
759 }
760
761 /* client key exchange */
762 m.tag = HClientKeyExchange;
763 m.u.clientKeyExchange.key = makebytes(epm, nepm);
764 free(epm);
765 epm = nil;
766 if(m.u.clientKeyExchange.key == nil) {
767 tlsError(c, EHandshakeFailure, "can't set secret: %r");
768 goto Err;
769 }
770 if(!msgSend(c, &m, AFlush))
771 goto Err;
772 msgClear(&m);
773
774 /* change cipher spec */
775 if(fprint(c->ctl, "changecipher") < 0){
776 tlsError(c, EInternalError, "can't enable cipher: %r");
777 goto Err;
778 }
779
780 // Cipherchange must occur immediately before Finished to avoid
781 // potential hole; see section 4.3 of Wagner Schneier 1996.
782 if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
783 tlsError(c, EInternalError, "can't set finished 1: %r");
784 goto Err;
785 }
786 m.tag = HFinished;
787 m.u.finished = c->finished;
788
789 if(!msgSend(c, &m, AFlush)) {
790 fprint(2, "tlsClient nepm=%d\n", nepm);
791 tlsError(c, EInternalError, "can't flush after client Finished: %r");
792 goto Err;
793 }
794 msgClear(&m);
795
796 if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
797 fprint(2, "tlsClient nepm=%d\n", nepm);
798 tlsError(c, EInternalError, "can't set finished 0: %r");
799 goto Err;
800 }
801 if(!msgRecv(c, &m)) {
802 fprint(2, "tlsClient nepm=%d\n", nepm);
803 tlsError(c, EInternalError, "can't read server Finished: %r");
804 goto Err;
805 }
806 if(m.tag != HFinished) {
807 fprint(2, "tlsClient nepm=%d\n", nepm);
808 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
809 goto Err;
810 }
811
812 if(!finishedMatch(c, &m.u.finished)) {
813 tlsError(c, EHandshakeFailure, "finished verification failed");
814 goto Err;
815 }
816 msgClear(&m);
817
818 if(fprint(c->ctl, "opened") < 0){
819 if(trace)
820 trace("unable to do final open: %r\n");
821 goto Err;
822 }
823 tlsSecOk(c->sec);
824 return c;
825
826 Err:
827 free(epm);
828 msgClear(&m);
829 tlsConnectionFree(c);
830 return 0;
831 }
832
833
834 //================= message functions ========================
835
// Output staging area for msgSend: encoded handshake messages are appended
// at sendp and sendbuf[0..sendp) is written to c->hand on AFlush.
// NOTE(review): these are file-static, so concurrent handshakes in one
// process would share them.
static uchar sendbuf[9000], *sendp;
837
838 static void
msgHash(TlsConnection * c,uchar * p,int n)839 msgHash(TlsConnection *c, uchar *p, int n)
840 {
841 md5(p, n, 0, &c->hs.md5);
842 sha1(p, n, 0, &c->hs.sha1);
843 if(c->version >= TLS12Version)
844 sha2_256(p, n, 0, &c->hs.sha2_256);
845 else
846 memset(&c->hs.sha2_256, 0, sizeof c->hs.sha2_256);
847 }
848
849 static int
msgSend(TlsConnection * c,Msg * m,int act)850 msgSend(TlsConnection *c, Msg *m, int act)
851 {
852 uchar *p; // sendp = start of new message; p = write pointer
853 int nn, n, i;
854
855 if(sendp == nil)
856 sendp = sendbuf;
857 p = sendp;
858 if(c->trace)
859 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
860
861 p[0] = m->tag; // header - fill in size later
862 p += 4;
863
864 switch(m->tag) {
865 default:
866 tlsError(c, EInternalError, "can't encode a %d", m->tag);
867 goto Err;
868 case HClientHello:
869 // version
870 put16(p, m->u.clientHello.version);
871 p += 2;
872
873 // random
874 memmove(p, m->u.clientHello.random, RandomSize);
875 p += RandomSize;
876
877 // sid
878 n = m->u.clientHello.sid->len;
879 assert(n < 256);
880 p[0] = n;
881 memmove(p+1, m->u.clientHello.sid->data, n);
882 p += n+1;
883
884 n = m->u.clientHello.ciphers->len;
885 assert(n > 0 && n < 200);
886 put16(p, n*2);
887 p += 2;
888 for(i=0; i<n; i++) {
889 put16(p, m->u.clientHello.ciphers->data[i]);
890 p += 2;
891 }
892
893 n = m->u.clientHello.compressors->len;
894 assert(n > 0);
895 p[0] = n;
896 memmove(p+1, m->u.clientHello.compressors->data, n);
897 p += n+1;
898
899 if(m->u.clientHello.sigAlgs != nil) {
900 n = m->u.clientHello.sigAlgs->len;
901 put16(p, 6 + 2*n); /* length of extensions */
902 put16(p+2, ExtSigalgs);
903 put16(p+4, 2 + 2*n); /* length of extension content */
904 put16(p+6, 2*n); /* length of algorithm list */
905 p += 8;
906 for(i = 0; i < n; i++) {
907 put16(p, m->u.clientHello.sigAlgs->data[i]);
908 p += 2;
909 }
910 }
911 break;
912 case HServerHello:
913 put16(p, m->u.serverHello.version);
914 p += 2;
915
916 // random
917 memmove(p, m->u.serverHello.random, RandomSize);
918 p += RandomSize;
919
920 // sid
921 n = m->u.serverHello.sid->len;
922 assert(n < 256);
923 p[0] = n;
924 memmove(p+1, m->u.serverHello.sid->data, n);
925 p += n+1;
926
927 put16(p, m->u.serverHello.cipher);
928 p += 2;
929 p[0] = m->u.serverHello.compressor;
930 p += 1;
931 break;
932 case HServerHelloDone:
933 break;
934 case HCertificate:
935 nn = 0;
936 for(i = 0; i < m->u.certificate.ncert; i++)
937 nn += 3 + m->u.certificate.certs[i]->len;
938 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
939 tlsError(c, EInternalError, "output buffer too small for certificate");
940 goto Err;
941 }
942 put24(p, nn);
943 p += 3;
944 for(i = 0; i < m->u.certificate.ncert; i++){
945 put24(p, m->u.certificate.certs[i]->len);
946 p += 3;
947 memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
948 p += m->u.certificate.certs[i]->len;
949 }
950 break;
951 case HClientKeyExchange:
952 n = m->u.clientKeyExchange.key->len;
953 if(c->version != SSL3Version){
954 put16(p, n);
955 p += 2;
956 }
957 memmove(p, m->u.clientKeyExchange.key->data, n);
958 p += n;
959 break;
960 case HFinished:
961 memmove(p, m->u.finished.verify, m->u.finished.n);
962 p += m->u.finished.n;
963 break;
964 }
965
966 // go back and fill in size
967 n = p-sendp;
968 assert(p <= sendbuf+sizeof(sendbuf));
969 put24(sendp+1, n-4);
970
971 // remember hash of Handshake messages
972 if(m->tag != HHelloRequest) {
973 msgHash(c, sendp, n);
974 }
975
976 sendp = p;
977 if(act == AFlush){
978 sendp = sendbuf;
979 if(write(c->hand, sendbuf, p-sendbuf) < 0){
980 fprint(2, "write error: %r\n");
981 goto Err;
982 }
983 }
984 msgClear(m);
985 return 1;
986 Err:
987 msgClear(m);
988 return 0;
989 }
990
991 static uchar*
tlsReadN(TlsConnection * c,int n)992 tlsReadN(TlsConnection *c, int n)
993 {
994 uchar *p;
995 int nn, nr;
996
997 nn = c->ep - c->rp;
998 if(nn < n){
999 if(c->rp != c->buf){
1000 memmove(c->buf, c->rp, nn);
1001 c->rp = c->buf;
1002 c->ep = &c->buf[nn];
1003 }
1004 for(; nn < n; nn += nr) {
1005 nr = read(c->hand, &c->rp[nn], n - nn);
1006 if(nr <= 0)
1007 return nil;
1008 c->ep += nr;
1009 }
1010 }
1011 p = c->rp;
1012 c->rp += n;
1013 return p;
1014 }
1015
1016 static int
msgRecv(TlsConnection * c,Msg * m)1017 msgRecv(TlsConnection *c, Msg *m)
1018 {
1019 uchar *p;
1020 int type, n, nn, nx, i, nsid, nrandom, nciph;
1021
1022 for(;;) {
1023 p = tlsReadN(c, 4);
1024 if(p == nil)
1025 return 0;
1026 type = p[0];
1027 n = get24(p+1);
1028
1029 if(type != HHelloRequest)
1030 break;
1031 if(n != 0) {
1032 tlsError(c, EDecodeError, "invalid hello request during handshake");
1033 return 0;
1034 }
1035 }
1036
1037 if(n > sizeof(c->buf)) {
1038 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
1039 return 0;
1040 }
1041
1042 if(type == HSSL2ClientHello){
1043 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1044 This is sent by some clients that we must interoperate
1045 with, such as Java's JSSE and Microsoft's Internet Explorer. */
1046 p = tlsReadN(c, n);
1047 if(p == nil)
1048 return 0;
1049 msgHash(c, p, n);
1050 m->tag = HClientHello;
1051 if(n < 22)
1052 goto Short;
1053 m->u.clientHello.version = get16(p+1);
1054 p += 3;
1055 n -= 3;
1056 nn = get16(p); /* cipher_spec_len */
1057 nsid = get16(p + 2);
1058 nrandom = get16(p + 4);
1059 p += 6;
1060 n -= 6;
1061 if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
1062 || nrandom < 16 || nn % 3)
1063 goto Err;
1064 if(c->trace && (n - nrandom != nn))
1065 c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1066 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1067 nciph = 0;
1068 for(i = 0; i < nn; i += 3)
1069 if(p[i] == 0)
1070 nciph++;
1071 m->u.clientHello.ciphers = newints(nciph);
1072 nciph = 0;
1073 for(i = 0; i < nn; i += 3)
1074 if(p[i] == 0)
1075 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1076 p += nn;
1077 m->u.clientHello.sid = makebytes(nil, 0);
1078 if(nrandom > RandomSize)
1079 nrandom = RandomSize;
1080 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1081 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1082 m->u.clientHello.compressors = newbytes(1);
1083 m->u.clientHello.compressors->data[0] = CompressionNull;
1084 goto Ok;
1085 }
1086
1087 msgHash(c, p, 4);
1088
1089 p = tlsReadN(c, n);
1090 if(p == nil)
1091 return 0;
1092
1093 msgHash(c, p, n);
1094
1095 m->tag = type;
1096
1097 switch(type) {
1098 default:
1099 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1100 goto Err;
1101 case HClientHello:
1102 if(n < 2)
1103 goto Short;
1104 m->u.clientHello.version = get16(p);
1105 p += 2;
1106 n -= 2;
1107
1108 if(n < RandomSize)
1109 goto Short;
1110 memmove(m->u.clientHello.random, p, RandomSize);
1111 p += RandomSize;
1112 n -= RandomSize;
1113 if(n < 1 || n < p[0]+1)
1114 goto Short;
1115 m->u.clientHello.sid = makebytes(p+1, p[0]);
1116 p += m->u.clientHello.sid->len+1;
1117 n -= m->u.clientHello.sid->len+1;
1118
1119 if(n < 2)
1120 goto Short;
1121 nn = get16(p);
1122 p += 2;
1123 n -= 2;
1124
1125 if((nn & 1) || n < nn || nn < 2)
1126 goto Short;
1127 m->u.clientHello.ciphers = newints(nn >> 1);
1128 for(i = 0; i < nn; i += 2)
1129 m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1130 p += nn;
1131 n -= nn;
1132
1133 if(n < 1 || n < p[0]+1 || p[0] == 0)
1134 goto Short;
1135 nn = p[0];
1136 m->u.clientHello.compressors = newbytes(nn);
1137 memmove(m->u.clientHello.compressors->data, p+1, nn);
1138 p += nn + 1;
1139 n -= nn + 1;
1140
1141 /* extensions */
1142 if(n == 0)
1143 break;
1144 if(n < 2)
1145 goto Short;
1146 nx = get16(p);
1147 p += 2;
1148 n -= 2;
1149 while(nx > 0){
1150 if(n < nx || nx < 4)
1151 goto Short;
1152 i = get16(p);
1153 nn = get16(p+2);
1154 if(nx < nn+4)
1155 goto Short;
1156 nx -= nn+4;
1157 p += 4;
1158 n -= 4;
1159 if(i == ExtSigalgs){
1160 if(get16(p) != nn-2)
1161 goto Short;
1162 p += 2;
1163 n -= 2;
1164 nn -= 2;
1165 m->u.clientHello.sigAlgs = newints(nn/2);
1166 for(i = 0; i < nn; i += 2)
1167 m->u.clientHello.sigAlgs->data[i >> 1] = get16(&p[i]);
1168 }
1169 p += nn;
1170 n -= nn;
1171 }
1172 break;
1173 case HServerHello:
1174 if(n < 2)
1175 goto Short;
1176 m->u.serverHello.version = get16(p);
1177 p += 2;
1178 n -= 2;
1179
1180 if(n < RandomSize)
1181 goto Short;
1182 memmove(m->u.serverHello.random, p, RandomSize);
1183 p += RandomSize;
1184 n -= RandomSize;
1185
1186 if(n < 1 || n < p[0]+1)
1187 goto Short;
1188 m->u.serverHello.sid = makebytes(p+1, p[0]);
1189 p += m->u.serverHello.sid->len+1;
1190 n -= m->u.serverHello.sid->len+1;
1191
1192 if(n < 3)
1193 goto Short;
1194 m->u.serverHello.cipher = get16(p);
1195 m->u.serverHello.compressor = p[2];
1196 n -= 3;
1197 break;
1198 case HCertificate:
1199 if(n < 3)
1200 goto Short;
1201 nn = get24(p);
1202 p += 3;
1203 n -= 3;
1204 if(n != nn)
1205 goto Short;
1206 /* certs */
1207 i = 0;
1208 while(n > 0) {
1209 if(n < 3)
1210 goto Short;
1211 nn = get24(p);
1212 p += 3;
1213 n -= 3;
1214 if(nn > n)
1215 goto Short;
1216 m->u.certificate.ncert = i+1;
1217 m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1218 m->u.certificate.certs[i] = makebytes(p, nn);
1219 p += nn;
1220 n -= nn;
1221 i++;
1222 }
1223 break;
1224 case HCertificateRequest:
1225 if(n < 1)
1226 goto Short;
1227 nn = p[0];
1228 p += 1;
1229 n -= 1;
1230 if(nn < 1 || nn > n)
1231 goto Short;
1232 m->u.certificateRequest.types = makebytes(p, nn);
1233 p += nn;
1234 n -= nn;
1235 if(n < 2)
1236 goto Short;
1237 nn = get16(p);
1238 p += 2;
1239 n -= 2;
1240 /* nn == 0 can happen; yahoo's servers do it */
1241 if(nn != n)
1242 goto Short;
1243 /* cas */
1244 i = 0;
1245 while(n > 0) {
1246 if(n < 2)
1247 goto Short;
1248 nn = get16(p);
1249 p += 2;
1250 n -= 2;
1251 if(nn < 1 || nn > n)
1252 goto Short;
1253 m->u.certificateRequest.nca = i+1;
1254 m->u.certificateRequest.cas = erealloc(
1255 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1256 m->u.certificateRequest.cas[i] = makebytes(p, nn);
1257 p += nn;
1258 n -= nn;
1259 i++;
1260 }
1261 break;
1262 case HServerHelloDone:
1263 break;
1264 case HClientKeyExchange:
1265 /*
1266 * this message depends upon the encryption selected
1267 * assume rsa.
1268 */
1269 if(c->version == SSL3Version)
1270 nn = n;
1271 else{
1272 if(n < 2)
1273 goto Short;
1274 nn = get16(p);
1275 p += 2;
1276 n -= 2;
1277 }
1278 if(n < nn)
1279 goto Short;
1280 m->u.clientKeyExchange.key = makebytes(p, nn);
1281 n -= nn;
1282 break;
1283 case HFinished:
1284 m->u.finished.n = c->finished.n;
1285 if(n < m->u.finished.n)
1286 goto Short;
1287 memmove(m->u.finished.verify, p, m->u.finished.n);
1288 n -= m->u.finished.n;
1289 break;
1290 }
1291
1292 if(type != HClientHello && n != 0)
1293 goto Short;
1294 Ok:
1295 if(c->trace){
1296 char *buf;
1297 buf = emalloc(8000);
1298 c->trace("recv %s", msgPrint(buf, 8000, m));
1299 free(buf);
1300 }
1301 return 1;
1302 Short:
1303 tlsError(c, EDecodeError, "handshake message has invalid length");
1304 Err:
1305 msgClear(m);
1306 return 0;
1307 }
1308
1309 static void
msgClear(Msg * m)1310 msgClear(Msg *m)
1311 {
1312 int i;
1313
1314 switch(m->tag) {
1315 default:
1316 sysfatal("msgClear: unknown message type: %d", m->tag);
1317 case HHelloRequest:
1318 break;
1319 case HClientHello:
1320 freebytes(m->u.clientHello.sid);
1321 freeints(m->u.clientHello.ciphers);
1322 m->u.clientHello.ciphers = nil;
1323 freebytes(m->u.clientHello.compressors);
1324 freeints(m->u.clientHello.sigAlgs);
1325 break;
1326 case HServerHello:
1327 freebytes(m->u.clientHello.sid);
1328 break;
1329 case HCertificate:
1330 for(i=0; i<m->u.certificate.ncert; i++)
1331 freebytes(m->u.certificate.certs[i]);
1332 free(m->u.certificate.certs);
1333 break;
1334 case HCertificateRequest:
1335 freebytes(m->u.certificateRequest.types);
1336 for(i=0; i<m->u.certificateRequest.nca; i++)
1337 freebytes(m->u.certificateRequest.cas[i]);
1338 free(m->u.certificateRequest.cas);
1339 break;
1340 case HServerHelloDone:
1341 break;
1342 case HClientKeyExchange:
1343 freebytes(m->u.clientKeyExchange.key);
1344 break;
1345 case HFinished:
1346 break;
1347 }
1348 memset(m, 0, sizeof(Msg));
1349 }
1350
1351 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1352 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1353 {
1354 int i;
1355
1356 if(s0)
1357 bs = seprint(bs, be, "%s", s0);
1358 bs = seprint(bs, be, "[");
1359 if(b == nil)
1360 bs = seprint(bs, be, "nil");
1361 else
1362 for(i=0; i<b->len; i++)
1363 bs = seprint(bs, be, "%.2x ", b->data[i]);
1364 bs = seprint(bs, be, "]");
1365 if(s1)
1366 bs = seprint(bs, be, "%s", s1);
1367 return bs;
1368 }
1369
1370 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1371 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1372 {
1373 int i;
1374
1375 if(s0)
1376 bs = seprint(bs, be, "%s", s0);
1377 bs = seprint(bs, be, "[");
1378 if(b == nil)
1379 bs = seprint(bs, be, "nil");
1380 else
1381 for(i=0; i<b->len; i++)
1382 bs = seprint(bs, be, "%x ", b->data[i]);
1383 bs = seprint(bs, be, "]");
1384 if(s1)
1385 bs = seprint(bs, be, "%s", s1);
1386 return bs;
1387 }
1388
1389 static char*
msgPrint(char * buf,int n,Msg * m)1390 msgPrint(char *buf, int n, Msg *m)
1391 {
1392 int i;
1393 char *bs = buf, *be = buf+n;
1394
1395 switch(m->tag) {
1396 default:
1397 bs = seprint(bs, be, "unknown %d\n", m->tag);
1398 break;
1399 case HClientHello:
1400 bs = seprint(bs, be, "ClientHello\n");
1401 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1402 bs = seprint(bs, be, "\trandom: ");
1403 for(i=0; i<RandomSize; i++)
1404 bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1405 bs = seprint(bs, be, "\n");
1406 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1407 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1408 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1409 if(m->u.clientHello.sigAlgs != nil)
1410 bs = intsPrint(bs, be, "\tsigAlgs: ", m->u.clientHello.sigAlgs, "\n");
1411 break;
1412 case HServerHello:
1413 bs = seprint(bs, be, "ServerHello\n");
1414 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1415 bs = seprint(bs, be, "\trandom: ");
1416 for(i=0; i<RandomSize; i++)
1417 bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1418 bs = seprint(bs, be, "\n");
1419 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1420 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1421 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1422 break;
1423 case HCertificate:
1424 bs = seprint(bs, be, "Certificate\n");
1425 for(i=0; i<m->u.certificate.ncert; i++)
1426 bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1427 break;
1428 case HCertificateRequest:
1429 bs = seprint(bs, be, "CertificateRequest\n");
1430 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1431 bs = seprint(bs, be, "\tcertificateauthorities\n");
1432 for(i=0; i<m->u.certificateRequest.nca; i++)
1433 bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1434 break;
1435 case HServerHelloDone:
1436 bs = seprint(bs, be, "ServerHelloDone\n");
1437 break;
1438 case HClientKeyExchange:
1439 bs = seprint(bs, be, "HClientKeyExchange\n");
1440 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1441 break;
1442 case HFinished:
1443 bs = seprint(bs, be, "HFinished\n");
1444 for(i=0; i<m->u.finished.n; i++)
1445 bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1446 bs = seprint(bs, be, "\n");
1447 break;
1448 }
1449 USED(bs);
1450 return buf;
1451 }
1452
1453 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1454 tlsError(TlsConnection *c, int err, char *fmt, ...)
1455 {
1456 char msg[512];
1457 va_list arg;
1458
1459 va_start(arg, fmt);
1460 vseprint(msg, msg+sizeof(msg), fmt, arg);
1461 va_end(arg);
1462 if(c->trace)
1463 c->trace("tlsError: %s\n", msg);
1464 else if(c->erred)
1465 fprint(2, "double error: %r, %s", msg);
1466 else
1467 werrstr("tls: local %s", msg);
1468 c->erred = 1;
1469 fprint(c->ctl, "alert %d", err);
1470 }
1471
1472 // commit to specific version number
1473 static int
setVersion(TlsConnection * c,int version)1474 setVersion(TlsConnection *c, int version)
1475 {
1476 if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1477 return -1;
1478 if(version > c->version)
1479 version = c->version;
1480 switch(version) {
1481 case SSL3Version:
1482 c->finished.n = SSL3FinishedLen;
1483 return -1;
1484 case TLS10Version:
1485 case TLS11Version:
1486 case TLS12Version:
1487 c->finished.n = TLSFinishedLen;
1488 break;
1489 default:
1490 return -1;
1491 }
1492 c->version = version;
1493 c->verset = 1;
1494 return fprint(c->ctl, "version 0x%x", version);
1495 }
1496
1497 // confirm that received Finished message matches the expected value
1498 static int
finishedMatch(TlsConnection * c,Finished * f)1499 finishedMatch(TlsConnection *c, Finished *f)
1500 {
1501 return memcmp(f->verify, c->finished.verify, f->n) == 0;
1502 }
1503
1504 // free memory associated with TlsConnection struct
1505 // (but don't close the TLS channel itself)
1506 static void
tlsConnectionFree(TlsConnection * c)1507 tlsConnectionFree(TlsConnection *c)
1508 {
1509 tlsSecClose(c->sec);
1510 freebytes(c->sid);
1511 freebytes(c->cert);
1512 memset(c, 0, sizeof *c);
1513 free(c);
1514 }
1515
1516
1517 //================= cipher choices ========================
1518
/*
 * weakCipher[id] is 1 when the TLS cipher suite with that id is treated
 * as weak and 0 otherwise.  The suites marked weak are the null,
 * export-grade, RC4 and anonymous Diffie-Hellman ones.
 *
 * The table is indexed directly by suite id, so entries must stay in id
 * order starting at 0 (TLS_NULL_WITH_NULL_NULL).  okCipher uses it to
 * tell a peer offering only weak suites (-2) from one offering nothing we
 * support (-1).  Ids >= CipherMax are never counted as weak.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_RSA_WITH_RC4_128_MD5 */
	1,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1550
1551 static int
setAlgs(TlsConnection * c,int a)1552 setAlgs(TlsConnection *c, int a)
1553 {
1554 int i;
1555
1556 for(i = 0; i < nelem(cipherAlgs); i++){
1557 if(cipherAlgs[i].tlsid == a){
1558 c->enc = cipherAlgs[i].enc;
1559 c->digest = cipherAlgs[i].digest;
1560 c->nsecret = cipherAlgs[i].nsecret;
1561 if(c->nsecret > MaxKeyData)
1562 return 0;
1563 return 1;
1564 }
1565 }
1566 return 0;
1567 }
1568
1569 static int
okCipher(Ints * cv)1570 okCipher(Ints *cv)
1571 {
1572 int weak, i, j, c;
1573
1574 weak = 1;
1575 for(i = 0; i < cv->len; i++) {
1576 c = cv->data[i];
1577 if(c >= CipherMax)
1578 weak = 0;
1579 else
1580 weak &= weakCipher[c];
1581 for(j = 0; j < nelem(cipherAlgs); j++)
1582 if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1583 return c;
1584 }
1585 if(weak)
1586 return -2;
1587 return -1;
1588 }
1589
1590 static int
okCompression(Bytes * cv)1591 okCompression(Bytes *cv)
1592 {
1593 int i, j, c;
1594
1595 for(i = 0; i < cv->len; i++) {
1596 c = cv->data[i];
1597 for(j = 0; j < nelem(compressors); j++) {
1598 if(compressors[j] == c)
1599 return c;
1600 }
1601 }
1602 return -1;
1603 }
1604
// ciphLock serializes initCiphers.  nciphers counts the cipherAlgs entries
// whose encryption and digest algorithms the kernel's #a device lists.
// It stays 0 until initCiphers finds at least one usable suite, so a
// failed or empty probe is retried on the next call.
static Lock ciphLock;
static int nciphers;
1607
1608 static int
initCiphers(void)1609 initCiphers(void)
1610 {
1611 enum {MaxAlgF = 1024, MaxAlgs = 10};
1612 char s[MaxAlgF], *flds[MaxAlgs];
1613 int i, j, n, ok;
1614
1615 lock(&ciphLock);
1616 if(nciphers){
1617 unlock(&ciphLock);
1618 return nciphers;
1619 }
1620 j = open("#a/tls/encalgs", OREAD);
1621 if(j < 0){
1622 werrstr("can't open #a/tls/encalgs: %r");
1623 unlock(&ciphLock);
1624 return 0;
1625 }
1626 n = read(j, s, MaxAlgF-1);
1627 close(j);
1628 if(n <= 0){
1629 werrstr("nothing in #a/tls/encalgs: %r");
1630 unlock(&ciphLock);
1631 return 0;
1632 }
1633 s[n] = 0;
1634 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1635 for(i = 0; i < nelem(cipherAlgs); i++){
1636 ok = 0;
1637 for(j = 0; j < n; j++){
1638 if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1639 ok = 1;
1640 break;
1641 }
1642 }
1643 cipherAlgs[i].ok = ok;
1644 }
1645
1646 j = open("#a/tls/hashalgs", OREAD);
1647 if(j < 0){
1648 werrstr("can't open #a/tls/hashalgs: %r");
1649 unlock(&ciphLock);
1650 return 0;
1651 }
1652 n = read(j, s, MaxAlgF-1);
1653 close(j);
1654 if(n <= 0){
1655 werrstr("nothing in #a/tls/hashalgs: %r");
1656 unlock(&ciphLock);
1657 return 0;
1658 }
1659 s[n] = 0;
1660 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1661 for(i = 0; i < nelem(cipherAlgs); i++){
1662 ok = 0;
1663 for(j = 0; j < n; j++){
1664 if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1665 ok = 1;
1666 break;
1667 }
1668 }
1669 cipherAlgs[i].ok &= ok;
1670 if(cipherAlgs[i].ok)
1671 nciphers++;
1672 }
1673 unlock(&ciphLock);
1674 return nciphers;
1675 }
1676
1677 static Ints*
makeciphers(void)1678 makeciphers(void)
1679 {
1680 Ints *is;
1681 int i, j;
1682
1683 is = newints(nciphers);
1684 j = 0;
1685 for(i = 0; i < nelem(cipherAlgs); i++){
1686 if(cipherAlgs[i].ok)
1687 is->data[j++] = cipherAlgs[i].tlsid;
1688 }
1689 return is;
1690 }
1691
1692
1693
1694 //================= security functions ========================
1695
1696 // given X.509 certificate, set up connection to factotum
1697 // for using corresponding private key
1698 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1699 factotum_rsa_open(uchar *cert, int certlen)
1700 {
1701 int afd;
1702 char *s;
1703 mpint *pub = nil;
1704 RSApub *rsapub;
1705 AuthRpc *rpc;
1706
1707 // start talking to factotum
1708 if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1709 return nil;
1710 if((rpc = auth_allocrpc(afd)) == nil){
1711 close(afd);
1712 return nil;
1713 }
1714 s = "proto=rsa service=tls role=client";
1715 if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1716 factotum_rsa_close(rpc);
1717 return nil;
1718 }
1719
1720 // roll factotum keyring around to match certificate
1721 rsapub = X509toRSApub(cert, certlen, nil, 0);
1722 while(1){
1723 if(auth_rpc(rpc, "read", nil, 0) != ARok){
1724 factotum_rsa_close(rpc);
1725 rpc = nil;
1726 goto done;
1727 }
1728 pub = strtomp(rpc->arg, nil, 16, nil);
1729 assert(pub != nil);
1730 if(mpcmp(pub,rsapub->n) == 0)
1731 break;
1732 }
1733 done:
1734 mpfree(pub);
1735 rsapubfree(rsapub);
1736 return rpc;
1737 }
1738
1739 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1740 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1741 {
1742 char *p;
1743 int rv;
1744
1745 if((p = mptoa(cipher, 16, nil, 0)) == nil)
1746 return nil;
1747 rv = auth_rpc(rpc, "write", p, strlen(p));
1748 free(p);
1749 if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1750 return nil;
1751 mpfree(cipher);
1752 return strtomp(rpc->arg, nil, 16, nil);
1753 }
1754
1755 static void
factotum_rsa_close(AuthRpc * rpc)1756 factotum_rsa_close(AuthRpc*rpc)
1757 {
1758 if(!rpc)
1759 return;
1760 close(rpc->afd);
1761 auth_freerpc(rpc);
1762 }
1763
1764 static void
tlsPmd5(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1765 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1766 {
1767 uchar ai[MD5dlen], tmp[MD5dlen];
1768 int i, n;
1769 MD5state *s;
1770
1771 // generate a1
1772 s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1773 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1774 hmac_md5(seed1, nseed1, key, nkey, ai, s);
1775
1776 while(nbuf > 0) {
1777 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1778 s = hmac_md5(label, nlabel, key, nkey, nil, s);
1779 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1780 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1781 n = MD5dlen;
1782 if(n > nbuf)
1783 n = nbuf;
1784 for(i = 0; i < n; i++)
1785 buf[i] ^= tmp[i];
1786 buf += n;
1787 nbuf -= n;
1788 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1789 memmove(ai, tmp, MD5dlen);
1790 }
1791 }
1792
1793 static void
tlsPsha1(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1794 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1795 {
1796 uchar ai[SHA1dlen], tmp[SHA1dlen];
1797 int i, n;
1798 SHAstate *s;
1799
1800 // generate a1
1801 s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1802 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1803 hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1804
1805 while(nbuf > 0) {
1806 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1807 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1808 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1809 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1810 n = SHA1dlen;
1811 if(n > nbuf)
1812 n = nbuf;
1813 for(i = 0; i < n; i++)
1814 buf[i] ^= tmp[i];
1815 buf += n;
1816 nbuf -= n;
1817 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1818 memmove(ai, tmp, SHA1dlen);
1819 }
1820 }
1821
1822 static void
tlsPsha2_256(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed,int nseed)1823 tlsPsha2_256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
1824 {
1825 uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
1826 int n;
1827 SHAstate *s;
1828
1829 // generate a1
1830 s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
1831 hmac_sha2_256(seed, nseed, key, nkey, ai, s);
1832
1833 while(nbuf > 0) {
1834 s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
1835 s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
1836 hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
1837 n = SHA2_256dlen;
1838 if(n > nbuf)
1839 n = nbuf;
1840 memmove(buf, tmp, n);
1841 buf += n;
1842 nbuf -= n;
1843 hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
1844 memmove(ai, tmp, SHA2_256dlen);
1845 }
1846 }
1847
1848 // fill buf with md5(args)^sha1(args)
1849 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1850 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1851 {
1852 int i;
1853 int nlabel = strlen(label);
1854 int n = (nkey + 1) >> 1;
1855
1856 for(i = 0; i < nbuf; i++)
1857 buf[i] = 0;
1858 tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1859 tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1860 }
1861
1862 void
tls12PRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1863 tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1864 {
1865 uchar seed[2*RandomSize];
1866 int nlabel = strlen(label);
1867
1868 memmove(seed, seed0, nseed0);
1869 memmove(seed+nseed0, seed1, nseed1);
1870 tlsPsha2_256(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed0+nseed1);
1871 }
1872
1873 /*
1874 * for setting server session id's
1875 */
1876 static Lock sidLock;
1877 static long maxSid = 1;
1878
1879 /* the keys are verified to have the same public components
1880 * and to function correctly with pkcs 1 encryption and decryption. */
1881 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1882 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1883 {
1884 TlsSec *sec = emalloc(sizeof(*sec));
1885
1886 USED(csid); USED(ncsid); // ignore csid for now
1887
1888 memmove(sec->crandom, crandom, RandomSize);
1889 sec->clientVers = cvers;
1890
1891 put32(sec->srandom, time(0));
1892 genrandom(sec->srandom+4, RandomSize-4);
1893 memmove(srandom, sec->srandom, RandomSize);
1894
1895 /*
1896 * make up a unique sid: use our pid, and and incrementing id
1897 * can signal no sid by setting nssid to 0.
1898 */
1899 memset(ssid, 0, SidSize);
1900 put32(ssid, getpid());
1901 lock(&sidLock);
1902 put32(ssid+4, maxSid++);
1903 unlock(&sidLock);
1904 *nssid = SidSize;
1905 return sec;
1906 }
1907
1908 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1909 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1910 {
1911 if(epm != nil){
1912 if(setVers(sec, vers) < 0)
1913 goto Err;
1914 serverMasterSecret(sec, epm, nepm);
1915 }else if(sec->vers != vers){
1916 werrstr("mismatched session versions");
1917 goto Err;
1918 }
1919 setSecrets(sec, kd, nkd);
1920 return 0;
1921 Err:
1922 sec->ok = -1;
1923 return -1;
1924 }
1925
1926 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1927 tlsSecInitc(int cvers, uchar *crandom)
1928 {
1929 TlsSec *sec = emalloc(sizeof(*sec));
1930 sec->clientVers = cvers;
1931 put32(sec->crandom, time(0));
1932 genrandom(sec->crandom+4, RandomSize-4);
1933 memmove(crandom, sec->crandom, RandomSize);
1934 return sec;
1935 }
1936
1937 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1938 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1939 {
1940 RSApub *pub;
1941
1942 pub = nil;
1943
1944 USED(sid);
1945 USED(nsid);
1946
1947 memmove(sec->srandom, srandom, RandomSize);
1948
1949 if(setVers(sec, vers) < 0)
1950 goto Err;
1951
1952 pub = X509toRSApub(cert, ncert, nil, 0);
1953 if(pub == nil){
1954 werrstr("invalid x509/rsa certificate");
1955 goto Err;
1956 }
1957 if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1958 goto Err;
1959 rsapubfree(pub);
1960 setSecrets(sec, kd, nkd);
1961 return 0;
1962
1963 Err:
1964 if(pub != nil)
1965 rsapubfree(pub);
1966 sec->ok = -1;
1967 return -1;
1968 }
1969
1970 static int
tlsSecFinished(TlsSec * sec,HandHash hs,uchar * fin,int nfin,int isclient)1971 tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient)
1972 {
1973 if(sec->nfin != nfin){
1974 sec->ok = -1;
1975 werrstr("invalid finished exchange");
1976 return -1;
1977 }
1978 hs.md5.malloced = 0;
1979 hs.sha1.malloced = 0;
1980 if(sec->setFinished == nil ){
1981 sec->ok = -1;
1982 werrstr("nil sec->setFinished in tlsSecFinished");
1983 return -1;
1984 }
1985 (*sec->setFinished)(sec, hs, fin, isclient);
1986 return 1;
1987 }
1988
1989 static void
tlsSecOk(TlsSec * sec)1990 tlsSecOk(TlsSec *sec)
1991 {
1992 if(sec->ok == 0)
1993 sec->ok = 1;
1994 }
1995
1996 static void
tlsSecKill(TlsSec * sec)1997 tlsSecKill(TlsSec *sec)
1998 {
1999 if(!sec)
2000 return;
2001 factotum_rsa_close(sec->rpc);
2002 sec->ok = -1;
2003 }
2004
2005 static void
tlsSecClose(TlsSec * sec)2006 tlsSecClose(TlsSec *sec)
2007 {
2008 if(!sec)
2009 return;
2010 factotum_rsa_close(sec->rpc);
2011 free(sec->server);
2012 free(sec);
2013 }
2014
2015 static int
setVers(TlsSec * sec,int v)2016 setVers(TlsSec *sec, int v)
2017 {
2018 switch(v){
2019 case SSL3Version:
2020 sec->setFinished = sslSetFinished;
2021 sec->nfin = SSL3FinishedLen;
2022 sec->prf = sslPRF;
2023 break;
2024 case TLS10Version:
2025 case TLS11Version:
2026 sec->setFinished = tlsSetFinished;
2027 sec->nfin = TLSFinishedLen;
2028 sec->prf = tlsPRF;
2029 break;
2030 case TLS12Version:
2031 sec->setFinished = tls12SetFinished;
2032 sec->nfin = TLSFinishedLen;
2033 sec->prf = tls12PRF;
2034 break;
2035 default:
2036 werrstr("invalid version");
2037 sec->setFinished = nil;
2038 sec->prf = nil;
2039 return -1;
2040 }
2041 sec->vers = v;
2042 return 0;
2043 }
2044
2045 /*
2046 * generate secret keys from the master secret.
2047 *
2048 * different crypto selections will require different amounts
2049 * of key expansion and use of key expansion data,
2050 * but it's all generated using the same function.
2051 */
2052 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)2053 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2054 {
2055 if (sec->prf == nil) {
2056 werrstr("nil sec->prf in setSecrets");
2057 return;
2058 }
2059 (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2060 sec->srandom, RandomSize, sec->crandom, RandomSize);
2061 }
2062
2063 /*
2064 * set the master secret from the pre-master secret.
2065 */
2066 static void
setMasterSecret(TlsSec * sec,Bytes * pm)2067 setMasterSecret(TlsSec *sec, Bytes *pm)
2068 {
2069 (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
2070 sec->crandom, RandomSize, sec->srandom, RandomSize);
2071 }
2072
2073 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)2074 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
2075 {
2076 Bytes *pm;
2077
2078 pm = pkcs1_decrypt(sec, epm, nepm);
2079
2080 // if the client messed up, just continue as if everything is ok,
2081 // to prevent attacks to check for correctly formatted messages.
2082 // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2083 if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2084 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2085 sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
2086 sec->ok = -1;
2087 if(pm != nil)
2088 freebytes(pm);
2089 pm = newbytes(MasterSecretSize);
2090 genrandom(pm->data, MasterSecretSize);
2091 }
2092 setMasterSecret(sec, pm);
2093 memset(pm->data, 0, pm->len);
2094 freebytes(pm);
2095 }
2096
2097 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)2098 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
2099 {
2100 Bytes *pm, *key;
2101
2102 pm = newbytes(MasterSecretSize);
2103 put16(pm->data, sec->clientVers);
2104 genrandom(pm->data+2, MasterSecretSize - 2);
2105
2106 setMasterSecret(sec, pm);
2107
2108 key = pkcs1_encrypt(pm, pub, 2);
2109 memset(pm->data, 0, pm->len);
2110 freebytes(pm);
2111 if(key == nil){
2112 werrstr("tls pkcs1_encrypt failed");
2113 return -1;
2114 }
2115
2116 *nepm = key->len;
2117 *epm = malloc(*nepm);
2118 if(*epm == nil){
2119 freebytes(key);
2120 werrstr("out of memory");
2121 return -1;
2122 }
2123 memmove(*epm, key->data, *nepm);
2124
2125 freebytes(key);
2126
2127 return 1;
2128 }
2129
/*
 * sslSetFinished computes the SSL 3.0 Finished verify data into finished,
 * which is MD5dlen+SHA1dlen bytes:
 *	finished[0, MD5dlen) = MD5(master || pad2 || MD5(handshake || label || master || pad1))
 *	finished[MD5dlen, ...) = SHA1(master || pad2 || SHA1(handshake || label || master || pad1))
 *
 * label is "CLNT" for the client and "SRVR" for the server.  pad1 is 0x36
 * and pad2 is 0x5C, repeated 48 times for MD5 and 40 times for SHA1.
 * hs is a by-value copy, so extending and finalizing its hashes here
 * leaves the caller's running handshake hashes alone.
 */
static void
sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// inner MD5: handshake || label || master || pad1 (48 x 0x36)
	md5((uchar*)label, 4, nil, &hs.md5);
	md5(sec->sec, MasterSecretSize, nil, &hs.md5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hs.md5);
	md5(nil, 0, h0, &hs.md5);
	// outer MD5: master || pad2 (48 x 0x5C) || inner
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// same construction with SHA1, whose pads are only 40 bytes
	sha1((uchar*)label, 4, nil, &hs.sha1);
	sha1(sec->sec, MasterSecretSize, nil, &hs.sha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hs.sha1);
	sha1(nil, 0, h1, &hs.sha1);
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
2162
2163 // fill "finished" arg with md5(args)^sha1(args)
2164 static void
tlsSetFinished(TlsSec * sec,HandHash hs,uchar * finished,int isClient)2165 tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
2166 {
2167 uchar h0[MD5dlen], h1[SHA1dlen];
2168 char *label;
2169
2170 // get current hash value, but allow further messages to be hashed in
2171 md5(nil, 0, h0, &hs.md5);
2172 sha1(nil, 0, h1, &hs.sha1);
2173
2174 if(isClient)
2175 label = "client finished";
2176 else
2177 label = "server finished";
2178 if (sec->prf == nil) {
2179 werrstr("nil sec->prf in tlsSetFinished");
2180 return;
2181 }
2182 (*sec->prf)(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2183 }
2184
2185 // fill "finished" arg with sha256(args)
2186 static void
tls12SetFinished(TlsSec * sec,HandHash hs,uchar * finished,int isClient)2187 tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
2188 {
2189 uchar h[SHA2_256dlen];
2190 char *label;
2191
2192 // get current hash value, but allow further messages to be hashed in
2193 sha2_256(nil, 0, h, &hs.sha2_256);
2194
2195 if(isClient)
2196 label = "client finished";
2197 else
2198 label = "server finished";
2199 tlsPsha2_256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), h, SHA2_256dlen);
2200 }
2201
2202 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)2203 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2204 {
2205 DigestState *s;
2206 uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2207 int i, n, len;
2208
2209 USED(label);
2210 len = 1;
2211 while(nbuf > 0){
2212 if(len > 26)
2213 return;
2214 for(i = 0; i < len; i++)
2215 tmp[i] = 'A' - 1 + len;
2216 s = sha1(tmp, len, nil, nil);
2217 s = sha1(key, nkey, nil, s);
2218 s = sha1(seed0, nseed0, nil, s);
2219 sha1(seed1, nseed1, sha1dig, s);
2220 s = md5(key, nkey, nil, nil);
2221 md5(sha1dig, SHA1dlen, md5dig, s);
2222 n = MD5dlen;
2223 if(n > nbuf)
2224 n = nbuf;
2225 memmove(buf, md5dig, n);
2226 buf += n;
2227 nbuf -= n;
2228 len++;
2229 }
2230 }
2231
2232 static mpint*
bytestomp(Bytes * bytes)2233 bytestomp(Bytes* bytes)
2234 {
2235 mpint* ans;
2236
2237 ans = betomp(bytes->data, bytes->len, nil);
2238 return ans;
2239 }
2240
2241 /*
2242 * Convert mpint* to Bytes, putting high order byte first.
2243 */
2244 static Bytes*
mptobytes(mpint * big)2245 mptobytes(mpint* big)
2246 {
2247 int n, m;
2248 uchar *a;
2249 Bytes* ans;
2250
2251 a = nil;
2252 n = (mpsignif(big)+7)/8;
2253 m = mptobe(big, nil, n, &a);
2254 ans = makebytes(a, m);
2255 if(a != nil)
2256 free(a);
2257 return ans;
2258 }
2259
2260 // Do RSA computation on block according to key, and pad
2261 // result on left with zeros to make it modlen long.
2262 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2263 rsacomp(Bytes* block, RSApub* key, int modlen)
2264 {
2265 mpint *x, *y;
2266 Bytes *a, *ybytes;
2267 int ylen;
2268
2269 x = bytestomp(block);
2270 y = rsaencrypt(key, x, nil);
2271 mpfree(x);
2272 ybytes = mptobytes(y);
2273 ylen = ybytes->len;
2274
2275 if(ylen < modlen) {
2276 a = newbytes(modlen);
2277 memset(a->data, 0, modlen-ylen);
2278 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2279 freebytes(ybytes);
2280 ybytes = a;
2281 }
2282 else if(ylen > modlen) {
2283 // assume it has leading zeros (mod should make it so)
2284 a = newbytes(modlen);
2285 memmove(a->data, ybytes->data, modlen);
2286 freebytes(ybytes);
2287 ybytes = a;
2288 }
2289 mpfree(y);
2290 return ybytes;
2291 }
2292
2293 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2294 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2295 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2296 {
2297 Bytes *pad, *eb, *ans;
2298 int i, dlen, padlen, modlen;
2299
2300 modlen = (mpsignif(key->n)+7)/8;
2301 dlen = data->len;
2302 if(modlen < 12 || dlen > modlen - 11)
2303 return nil;
2304 padlen = modlen - 3 - dlen;
2305 pad = newbytes(padlen);
2306 genrandom(pad->data, padlen);
2307 for(i = 0; i < padlen; i++) {
2308 if(blocktype == 0)
2309 pad->data[i] = 0;
2310 else if(blocktype == 1)
2311 pad->data[i] = 255;
2312 else if(pad->data[i] == 0)
2313 pad->data[i] = 1;
2314 }
2315 eb = newbytes(modlen);
2316 eb->data[0] = 0;
2317 eb->data[1] = blocktype;
2318 memmove(eb->data+2, pad->data, padlen);
2319 eb->data[padlen+2] = 0;
2320 memmove(eb->data+padlen+3, data->data, dlen);
2321 ans = rsacomp(eb, key, modlen);
2322 freebytes(eb);
2323 freebytes(pad);
2324 return ans;
2325 }
2326
2327 // decrypt data according to PKCS#1, with given key.
2328 // expect a block type of 2.
2329 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2330 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2331 {
2332 Bytes *eb, *ans = nil;
2333 int i, modlen;
2334 mpint *x, *y;
2335
2336 modlen = (mpsignif(sec->rsapub->n)+7)/8;
2337 if(nepm != modlen)
2338 return nil;
2339 x = betomp(epm, nepm, nil);
2340 y = factotum_rsa_decrypt(sec->rpc, x);
2341 if(y == nil)
2342 return nil;
2343 eb = mptobytes(y);
2344 if(eb->len < modlen){ // pad on left with zeros
2345 ans = newbytes(modlen);
2346 memset(ans->data, 0, modlen-eb->len);
2347 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2348 freebytes(eb);
2349 eb = ans;
2350 }
2351 if(eb->data[0] == 0 && eb->data[1] == 2) {
2352 for(i = 2; i < modlen; i++)
2353 if(eb->data[i] == 0)
2354 break;
2355 if(i < modlen - 1)
2356 ans = makebytes(eb->data+i+1, modlen-(i+1));
2357 }
2358 if (eb != ans) /* not freed above? */
2359 freebytes(eb);
2360 return ans;
2361 }
2362
2363
2364 //================= general utility functions ========================
2365
2366 static void *
emalloc(int n)2367 emalloc(int n)
2368 {
2369 void *p;
2370 if(n==0)
2371 n=1;
2372 p = malloc(n);
2373 if(p == nil){
2374 exits("out of memory");
2375 }
2376 memset(p, 0, n);
2377 return p;
2378 }
2379
2380 static void *
erealloc(void * ReallocP,int ReallocN)2381 erealloc(void *ReallocP, int ReallocN)
2382 {
2383 if(ReallocN == 0)
2384 ReallocN = 1;
2385 if(!ReallocP)
2386 ReallocP = emalloc(ReallocN);
2387 else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2388 exits("out of memory");
2389 }
2390 return(ReallocP);
2391 }
2392
2393 static void
put32(uchar * p,u32int x)2394 put32(uchar *p, u32int x)
2395 {
2396 p[0] = x>>24;
2397 p[1] = x>>16;
2398 p[2] = x>>8;
2399 p[3] = x;
2400 }
2401
2402 static void
put24(uchar * p,int x)2403 put24(uchar *p, int x)
2404 {
2405 p[0] = x>>16;
2406 p[1] = x>>8;
2407 p[2] = x;
2408 }
2409
2410 static void
put16(uchar * p,int x)2411 put16(uchar *p, int x)
2412 {
2413 p[0] = x>>8;
2414 p[1] = x;
2415 }
2416
2417 static u32int
get32(uchar * p)2418 get32(uchar *p)
2419 {
2420 return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2421 }
2422
2423 static int
get24(uchar * p)2424 get24(uchar *p)
2425 {
2426 return (p[0]<<16)|(p[1]<<8)|p[2];
2427 }
2428
2429 static int
get16(uchar * p)2430 get16(uchar *p)
2431 {
2432 return (p[0]<<8)|p[1];
2433 }
2434
// OFFSET(x, s): byte offset of member x within struct type s.
// Note the argument order (member, type) is the reverse of offsetof's.
#define OFFSET(x, s) offsetof(s, x)
2436
2437 /*
2438 * malloc and return a new Bytes structure capable of
2439 * holding len bytes. (len >= 0)
2440 * Used to use crypt_malloc, which aborts if malloc fails.
2441 */
2442 static Bytes*
newbytes(int len)2443 newbytes(int len)
2444 {
2445 Bytes* ans;
2446
2447 ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2448 ans->len = len;
2449 return ans;
2450 }
2451
2452 /*
2453 * newbytes(len), with data initialized from buf
2454 */
2455 static Bytes*
makebytes(uchar * buf,int len)2456 makebytes(uchar* buf, int len)
2457 {
2458 Bytes* ans;
2459
2460 ans = newbytes(len);
2461 memmove(ans->data, buf, len);
2462 return ans;
2463 }
2464
2465 static void
freebytes(Bytes * b)2466 freebytes(Bytes* b)
2467 {
2468 if(b != nil)
2469 free(b);
2470 }
2471
2472 /* len is number of ints */
2473 static Ints*
newints(int len)2474 newints(int len)
2475 {
2476 Ints* ans;
2477
2478 ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2479 ans->len = len;
2480 return ans;
2481 }
2482
2483 static Ints*
makeints(int * buf,int len)2484 makeints(int* buf, int len)
2485 {
2486 Ints* ans;
2487
2488 ans = newints(len);
2489 if(len > 0)
2490 memmove(ans->data, buf, len*sizeof(int));
2491 return ans;
2492 }
2493
2494 static void
freeints(Ints * b)2495 freeints(Ints* b)
2496 {
2497 if(b != nil)
2498 free(b);
2499 }
2500