1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
// The main groups of functions are:
//	client/server - main handshake protocol definition
//	message functions - formatting and parsing handshake messages
//	cipher choices - catalog of digest and encrypt algorithms
//	security functions - PKCS#1, sslHMAC, session keygen
//	general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
17 enum {
18 TLSFinishedLen = 12,
19 SSL3FinishedLen = MD5dlen+SHA1dlen,
20 MaxKeyData = 136, // amount of secret we may need
21 MaxChunk = 1<<14,
22 RandomSize = 32,
23 SidSize = 32,
24 MasterSecretSize = 48,
25 AQueue = 0,
26 AFlush = 1,
27 };
28
typedef struct TlsSec TlsSec;

// A counted byte string.  data is declared with one element but is used as
// data[len]; presumably newbytes allocates the extra space — TODO confirm.
typedef struct Bytes{
	int len;
	uchar data[1];  // [len]
} Bytes;

// A counted int array with the same layout trick as Bytes (data used as data[len]).
typedef struct Ints{
	int len;
	int data[1];  // [len]
} Ints;

// One entry of the cipherAlgs catalog.  Field order matters: cipherAlgs is
// initialized positionally.
typedef struct Algs{
	char *enc;	// encryption algorithm name, e.g. "rc4_128"
	char *digest;	// digest algorithm name, e.g. "sha1"
	int nsecret;	// bytes of key material (kd) the suite needs
	int tlsid;	// TLS cipher suite number (TLS_* enum)
	int ok;		// not given by the cipherAlgs initializers; presumably set by initCiphers — verify
} Algs;

// Verify data carried by a Finished message; n is the number of bytes of
// verify[] in use.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];
	int n;
} Finished;

// Per-handshake state shared by the client and server code.
typedef struct TlsConnection{
	TlsSec *sec;	// security management goo
	int hand, ctl;	// record layer file descriptors
	int erred;		// set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging; may be nil
	int version;	// protocol we are speaking
	int verset;		// version has been set
	int ver2hi;		// server got a version 2 hello
	int isClient;	// is this the client or server?
	Bytes *sid;		// SessionID
	Bytes *cert;	// client side: certs[0] of the server's Certificate message; the rest of the chain is dropped

	Lock statelk;
	int state;		// must be set using setstate

	// input buffer for handshake messages; rp/ep delimit unread data (see tlsReadN)
	uchar buf[MaxChunk+2048];
	uchar *rp, *ep;

	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVersion;	// version in ClientHello
	char *digest;	// name of digest algorithm to use
	char *enc;		// name of encryption algorithm to use
	int nsecret;	// amount of secret data to init keys

	// for finished messages
	MD5state	hsmd5;	// running MD5 of handshake messages
	SHAstate	hssha1;	// running SHA1 of handshake messages
	Finished	finished;
} TlsConnection;

// A decoded handshake message.  tag is an H* handshake type and selects
// which member of u is meaningful.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;
			Bytes*	compressors;
		} clientHello;
		struct {
			int version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			int cipher;
			int compressor;
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// ncert entries
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;	// nca entries
		} certificateRequest;
		struct {
			Bytes *key;	// encrypted pre-master secret
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;

// Secrets and key-derivation state for one session.
typedef struct TlsSec{
	char *server;	// name of remote; nil for server
	int ok;	// <0 killed; == 0 in progress; >0 reusable
	RSApub *rsapub;
	AuthRpc *rpc;	// factotum for rsa private key
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVers;		// version in ClientHello
	int vers;			// final version
	// byte generation and handshake checksum;
	// prf(out, nout, key, nkey, label, seed0, nseed0, seed1, nseed1), cf. sslPRF
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;	// presumably the expected Finished length for vers — verify
} TlsSec;
134
135
// Protocol version numbers, major<<8 | minor.
enum {
	SSL3Version	= 0x0300,
	TLSVersion	= 0x0301,
	MinProtoVersion	= 0x0300,	// lowest version we accept
	ProtocolVersion	= 0x0301,	// maximum version we speak
	MaxProtoVersion	= 0x03ff,	// highest version we accept
};

// Handshake message types.
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention; see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax,			// one past the last type
};

// Alert descriptions.
enum {
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,
	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,
	EMax			= 256,
};

// Cipher suite numbers.
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000a,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000b,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000c,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000d,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000e,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000f,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001a,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001b,

	// aes, aka rijndael with 128 bit blocks
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003a,
	CipherMax,				// one past the last suite
};

// Compression methods.
enum {
	CompressionNull = 0,
	CompressionMax,
};
240
// Cipher suites this implementation can use (all RSA key exchange).
// nsecret is the number of bytes of key material the suite needs.  The
// expressions read as 2*(key + IV + MAC secret), one set per direction —
// presumably; confirm against setSecrets.  The ok field is not initialized
// here (so it starts as 0); presumably initCiphers sets it.
static Algs cipherAlgs[] = {
	{"rc4_128", "md5",	2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1",	2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}
};

// Compression methods offered in our ClientHello: only the null method.
static uchar compressors[] = {
	CompressionNull,
};
252
// handshake drivers
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));

// handshake message encoding/decoding and connection bookkeeping
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck argpos	tlsError	3
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher and compression selection
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(void);

// session security: master secret, key material, Finished hashes, PKCS#1
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void	tlsSecOk(TlsSec *sec);
static void	tlsSecKill(TlsSec *sec);
static void	tlsSecClose(TlsSec *sec);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

// RSA private-key operations delegated to factotum
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc*rpc);

// allocation, big-endian integer (de)serialization, Bytes/Ints helpers
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static u32int get32(uchar *p);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static Ints* makeints(int* buf, int len);
static void freeints(Ints* b);
310
311 //================= client/server ========================
312
313 // push TLS onto fd, returning new (application) file descriptor
314 // or -1 if error.
315 int
tlsServer(int fd,TLSconn * conn)316 tlsServer(int fd, TLSconn *conn)
317 {
318 char buf[8];
319 char dname[64];
320 int n, data, ctl, hand;
321 TlsConnection *tls;
322
323 if(conn == nil)
324 return -1;
325 ctl = open("#a/tls/clone", ORDWR);
326 if(ctl < 0)
327 return -1;
328 n = read(ctl, buf, sizeof(buf)-1);
329 if(n < 0){
330 close(ctl);
331 return -1;
332 }
333 buf[n] = 0;
334 sprint(conn->dir, "#a/tls/%s", buf);
335 sprint(dname, "#a/tls/%s/hand", buf);
336 hand = open(dname, ORDWR);
337 if(hand < 0){
338 close(ctl);
339 return -1;
340 }
341 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
342 tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
343 sprint(dname, "#a/tls/%s/data", buf);
344 data = open(dname, ORDWR);
345 close(fd);
346 close(hand);
347 close(ctl);
348 if(data < 0){
349 return -1;
350 }
351 if(tls == nil){
352 close(data);
353 return -1;
354 }
355 if(conn->cert)
356 free(conn->cert);
357 conn->cert = 0; // client certificates are not yet implemented
358 conn->certlen = 0;
359 conn->sessionIDlen = tls->sid->len;
360 conn->sessionID = emalloc(conn->sessionIDlen);
361 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
362 if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
363 tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
364 tlsConnectionFree(tls);
365 return data;
366 }
367
368 // push TLS onto fd, returning new (application) file descriptor
369 // or -1 if error.
370 int
tlsClient(int fd,TLSconn * conn)371 tlsClient(int fd, TLSconn *conn)
372 {
373 char buf[8];
374 char dname[64];
375 int n, data, ctl, hand;
376 TlsConnection *tls;
377
378 if(!conn)
379 return -1;
380 ctl = open("#a/tls/clone", ORDWR);
381 if(ctl < 0)
382 return -1;
383 n = read(ctl, buf, sizeof(buf)-1);
384 if(n < 0){
385 close(ctl);
386 return -1;
387 }
388 buf[n] = 0;
389 sprint(conn->dir, "#a/tls/%s", buf);
390 sprint(dname, "#a/tls/%s/hand", buf);
391 hand = open(dname, ORDWR);
392 if(hand < 0){
393 close(ctl);
394 return -1;
395 }
396 sprint(dname, "#a/tls/%s/data", buf);
397 data = open(dname, ORDWR);
398 if(data < 0)
399 return -1;
400 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
401 tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
402 close(fd);
403 close(hand);
404 close(ctl);
405 if(tls == nil){
406 close(data);
407 return -1;
408 }
409 conn->certlen = tls->cert->len;
410 conn->cert = emalloc(conn->certlen);
411 memcpy(conn->cert, tls->cert->data, conn->certlen);
412 conn->sessionIDlen = tls->sid->len;
413 conn->sessionID = emalloc(conn->sessionIDlen);
414 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
415 if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0)
416 tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize);
417 tlsConnectionFree(tls);
418 return data;
419 }
420
421 static int
countchain(PEMChain * p)422 countchain(PEMChain *p)
423 {
424 int i = 0;
425
426 while (p) {
427 i++;
428 p = p->next;
429 }
430 return i;
431 }
432
433 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...),PEMChain * chp)434 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
435 {
436 TlsConnection *c;
437 Msg m;
438 Bytes *csid;
439 uchar sid[SidSize], kd[MaxKeyData];
440 char *secrets;
441 int cipher, compressor, nsid, rv, numcerts, i;
442
443 if(trace)
444 trace("tlsServer2\n");
445 if(!initCiphers())
446 return nil;
447 c = emalloc(sizeof(TlsConnection));
448 c->ctl = ctl;
449 c->hand = hand;
450 c->trace = trace;
451 c->version = ProtocolVersion;
452
453 memset(&m, 0, sizeof(m));
454 if(!msgRecv(c, &m)){
455 if(trace)
456 trace("initial msgRecv failed\n");
457 goto Err;
458 }
459 if(m.tag != HClientHello) {
460 tlsError(c, EUnexpectedMessage, "expected a client hello");
461 goto Err;
462 }
463 c->clientVersion = m.u.clientHello.version;
464 if(trace)
465 trace("ClientHello version %x\n", c->clientVersion);
466 if(setVersion(c, m.u.clientHello.version) < 0) {
467 tlsError(c, EIllegalParameter, "incompatible version");
468 goto Err;
469 }
470
471 memmove(c->crandom, m.u.clientHello.random, RandomSize);
472 cipher = okCipher(m.u.clientHello.ciphers);
473 if(cipher < 0) {
474 // reply with EInsufficientSecurity if we know that's the case
475 if(cipher == -2)
476 tlsError(c, EInsufficientSecurity, "cipher suites too weak");
477 else
478 tlsError(c, EHandshakeFailure, "no matching cipher suite");
479 goto Err;
480 }
481 if(!setAlgs(c, cipher)){
482 tlsError(c, EHandshakeFailure, "no matching cipher suite");
483 goto Err;
484 }
485 compressor = okCompression(m.u.clientHello.compressors);
486 if(compressor < 0) {
487 tlsError(c, EHandshakeFailure, "no matching compressor");
488 goto Err;
489 }
490
491 csid = m.u.clientHello.sid;
492 if(trace)
493 trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
494 c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
495 if(c->sec == nil){
496 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
497 goto Err;
498 }
499 c->sec->rpc = factotum_rsa_open(cert, ncert);
500 if(c->sec->rpc == nil){
501 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
502 goto Err;
503 }
504 c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
505 msgClear(&m);
506
507 m.tag = HServerHello;
508 m.u.serverHello.version = c->version;
509 memmove(m.u.serverHello.random, c->srandom, RandomSize);
510 m.u.serverHello.cipher = cipher;
511 m.u.serverHello.compressor = compressor;
512 c->sid = makebytes(sid, nsid);
513 m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
514 if(!msgSend(c, &m, AQueue))
515 goto Err;
516 msgClear(&m);
517
518 m.tag = HCertificate;
519 numcerts = countchain(chp);
520 m.u.certificate.ncert = 1 + numcerts;
521 m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
522 m.u.certificate.certs[0] = makebytes(cert, ncert);
523 for (i = 0; i < numcerts && chp; i++, chp = chp->next)
524 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
525 if(!msgSend(c, &m, AQueue))
526 goto Err;
527 msgClear(&m);
528
529 m.tag = HServerHelloDone;
530 if(!msgSend(c, &m, AFlush))
531 goto Err;
532 msgClear(&m);
533
534 if(!msgRecv(c, &m))
535 goto Err;
536 if(m.tag != HClientKeyExchange) {
537 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
538 goto Err;
539 }
540 if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
541 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
542 goto Err;
543 }
544 if(trace)
545 trace("tls secrets\n");
546 secrets = (char*)emalloc(2*c->nsecret);
547 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
548 rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
549 memset(secrets, 0, 2*c->nsecret);
550 free(secrets);
551 memset(kd, 0, c->nsecret);
552 if(rv < 0){
553 tlsError(c, EHandshakeFailure, "can't set keys: %r");
554 goto Err;
555 }
556 msgClear(&m);
557
558 /* no CertificateVerify; skip to Finished */
559 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
560 tlsError(c, EInternalError, "can't set finished: %r");
561 goto Err;
562 }
563 if(!msgRecv(c, &m))
564 goto Err;
565 if(m.tag != HFinished) {
566 tlsError(c, EUnexpectedMessage, "expected a finished");
567 goto Err;
568 }
569 if(!finishedMatch(c, &m.u.finished)) {
570 tlsError(c, EHandshakeFailure, "finished verification failed");
571 goto Err;
572 }
573 msgClear(&m);
574
575 /* change cipher spec */
576 if(fprint(c->ctl, "changecipher") < 0){
577 tlsError(c, EInternalError, "can't enable cipher: %r");
578 goto Err;
579 }
580
581 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
582 tlsError(c, EInternalError, "can't set finished: %r");
583 goto Err;
584 }
585 m.tag = HFinished;
586 m.u.finished = c->finished;
587 if(!msgSend(c, &m, AFlush))
588 goto Err;
589 msgClear(&m);
590 if(trace)
591 trace("tls finished\n");
592
593 if(fprint(c->ctl, "opened") < 0)
594 goto Err;
595 tlsSecOk(c->sec);
596 return c;
597
598 Err:
599 msgClear(&m);
600 tlsConnectionFree(c);
601 return 0;
602 }
603
604 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))605 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
606 {
607 TlsConnection *c;
608 Msg m;
609 uchar kd[MaxKeyData], *epm;
610 char *secrets;
611 int creq, nepm, rv;
612
613 if(!initCiphers())
614 return nil;
615 epm = nil;
616 c = emalloc(sizeof(TlsConnection));
617 c->version = ProtocolVersion;
618 c->ctl = ctl;
619 c->hand = hand;
620 c->trace = trace;
621 c->isClient = 1;
622 c->clientVersion = c->version;
623
624 c->sec = tlsSecInitc(c->clientVersion, c->crandom);
625 if(c->sec == nil)
626 goto Err;
627
628 /* client hello */
629 memset(&m, 0, sizeof(m));
630 m.tag = HClientHello;
631 m.u.clientHello.version = c->clientVersion;
632 memmove(m.u.clientHello.random, c->crandom, RandomSize);
633 m.u.clientHello.sid = makebytes(csid, ncsid);
634 m.u.clientHello.ciphers = makeciphers();
635 m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
636 if(!msgSend(c, &m, AFlush))
637 goto Err;
638 msgClear(&m);
639
640 /* server hello */
641 if(!msgRecv(c, &m))
642 goto Err;
643 if(m.tag != HServerHello) {
644 tlsError(c, EUnexpectedMessage, "expected a server hello");
645 goto Err;
646 }
647 if(setVersion(c, m.u.serverHello.version) < 0) {
648 tlsError(c, EIllegalParameter, "incompatible version %r");
649 goto Err;
650 }
651 memmove(c->srandom, m.u.serverHello.random, RandomSize);
652 c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
653 if(c->sid->len != 0 && c->sid->len != SidSize) {
654 tlsError(c, EIllegalParameter, "invalid server session identifier");
655 goto Err;
656 }
657 if(!setAlgs(c, m.u.serverHello.cipher)) {
658 tlsError(c, EIllegalParameter, "invalid cipher suite");
659 goto Err;
660 }
661 if(m.u.serverHello.compressor != CompressionNull) {
662 tlsError(c, EIllegalParameter, "invalid compression");
663 goto Err;
664 }
665 msgClear(&m);
666
667 /* certificate */
668 if(!msgRecv(c, &m) || m.tag != HCertificate) {
669 tlsError(c, EUnexpectedMessage, "expected a certificate");
670 goto Err;
671 }
672 if(m.u.certificate.ncert < 1) {
673 tlsError(c, EIllegalParameter, "runt certificate");
674 goto Err;
675 }
676 c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
677 msgClear(&m);
678
679 /* server key exchange (optional) */
680 if(!msgRecv(c, &m))
681 goto Err;
682 if(m.tag == HServerKeyExchange) {
683 tlsError(c, EUnexpectedMessage, "got an server key exchange");
684 goto Err;
685 // If implementing this later, watch out for rollback attack
686 // described in Wagner Schneier 1996, section 4.4.
687 }
688
689 /* certificate request (optional) */
690 creq = 0;
691 if(m.tag == HCertificateRequest) {
692 creq = 1;
693 msgClear(&m);
694 if(!msgRecv(c, &m))
695 goto Err;
696 }
697
698 if(m.tag != HServerHelloDone) {
699 tlsError(c, EUnexpectedMessage, "expected a server hello done");
700 goto Err;
701 }
702 msgClear(&m);
703
704 if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
705 c->cert->data, c->cert->len, c->version, &epm, &nepm,
706 kd, c->nsecret) < 0){
707 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
708 goto Err;
709 }
710 secrets = (char*)emalloc(2*c->nsecret);
711 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
712 rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
713 memset(secrets, 0, 2*c->nsecret);
714 free(secrets);
715 memset(kd, 0, c->nsecret);
716 if(rv < 0){
717 tlsError(c, EHandshakeFailure, "can't set keys: %r");
718 goto Err;
719 }
720
721 if(creq) {
722 /* send a zero length certificate */
723 m.tag = HCertificate;
724 if(!msgSend(c, &m, AFlush))
725 goto Err;
726 msgClear(&m);
727 }
728
729 /* client key exchange */
730 m.tag = HClientKeyExchange;
731 m.u.clientKeyExchange.key = makebytes(epm, nepm);
732 free(epm);
733 epm = nil;
734 if(m.u.clientKeyExchange.key == nil) {
735 tlsError(c, EHandshakeFailure, "can't set secret: %r");
736 goto Err;
737 }
738 if(!msgSend(c, &m, AFlush))
739 goto Err;
740 msgClear(&m);
741
742 /* change cipher spec */
743 if(fprint(c->ctl, "changecipher") < 0){
744 tlsError(c, EInternalError, "can't enable cipher: %r");
745 goto Err;
746 }
747
748 // Cipherchange must occur immediately before Finished to avoid
749 // potential hole; see section 4.3 of Wagner Schneier 1996.
750 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
751 tlsError(c, EInternalError, "can't set finished 1: %r");
752 goto Err;
753 }
754 m.tag = HFinished;
755 m.u.finished = c->finished;
756
757 if(!msgSend(c, &m, AFlush)) {
758 fprint(2, "tlsClient nepm=%d\n", nepm);
759 tlsError(c, EInternalError, "can't flush after client Finished: %r");
760 goto Err;
761 }
762 msgClear(&m);
763
764 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
765 fprint(2, "tlsClient nepm=%d\n", nepm);
766 tlsError(c, EInternalError, "can't set finished 0: %r");
767 goto Err;
768 }
769 if(!msgRecv(c, &m)) {
770 fprint(2, "tlsClient nepm=%d\n", nepm);
771 tlsError(c, EInternalError, "can't read server Finished: %r");
772 goto Err;
773 }
774 if(m.tag != HFinished) {
775 fprint(2, "tlsClient nepm=%d\n", nepm);
776 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
777 goto Err;
778 }
779
780 if(!finishedMatch(c, &m.u.finished)) {
781 tlsError(c, EHandshakeFailure, "finished verification failed");
782 goto Err;
783 }
784 msgClear(&m);
785
786 if(fprint(c->ctl, "opened") < 0){
787 if(trace)
788 trace("unable to do final open: %r\n");
789 goto Err;
790 }
791 tlsSecOk(c->sec);
792 return c;
793
794 Err:
795 free(epm);
796 msgClear(&m);
797 tlsConnectionFree(c);
798 return 0;
799 }
800
801
802 //================= message functions ========================
803
804 static uchar sendbuf[9000], *sendp;
805
806 static int
msgSend(TlsConnection * c,Msg * m,int act)807 msgSend(TlsConnection *c, Msg *m, int act)
808 {
809 uchar *p; // sendp = start of new message; p = write pointer
810 int nn, n, i;
811
812 if(sendp == nil)
813 sendp = sendbuf;
814 p = sendp;
815 if(c->trace)
816 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
817
818 p[0] = m->tag; // header - fill in size later
819 p += 4;
820
821 switch(m->tag) {
822 default:
823 tlsError(c, EInternalError, "can't encode a %d", m->tag);
824 goto Err;
825 case HClientHello:
826 // version
827 put16(p, m->u.clientHello.version);
828 p += 2;
829
830 // random
831 memmove(p, m->u.clientHello.random, RandomSize);
832 p += RandomSize;
833
834 // sid
835 n = m->u.clientHello.sid->len;
836 assert(n < 256);
837 p[0] = n;
838 memmove(p+1, m->u.clientHello.sid->data, n);
839 p += n+1;
840
841 n = m->u.clientHello.ciphers->len;
842 assert(n > 0 && n < 200);
843 put16(p, n*2);
844 p += 2;
845 for(i=0; i<n; i++) {
846 put16(p, m->u.clientHello.ciphers->data[i]);
847 p += 2;
848 }
849
850 n = m->u.clientHello.compressors->len;
851 assert(n > 0);
852 p[0] = n;
853 memmove(p+1, m->u.clientHello.compressors->data, n);
854 p += n+1;
855 break;
856 case HServerHello:
857 put16(p, m->u.serverHello.version);
858 p += 2;
859
860 // random
861 memmove(p, m->u.serverHello.random, RandomSize);
862 p += RandomSize;
863
864 // sid
865 n = m->u.serverHello.sid->len;
866 assert(n < 256);
867 p[0] = n;
868 memmove(p+1, m->u.serverHello.sid->data, n);
869 p += n+1;
870
871 put16(p, m->u.serverHello.cipher);
872 p += 2;
873 p[0] = m->u.serverHello.compressor;
874 p += 1;
875 break;
876 case HServerHelloDone:
877 break;
878 case HCertificate:
879 nn = 0;
880 for(i = 0; i < m->u.certificate.ncert; i++)
881 nn += 3 + m->u.certificate.certs[i]->len;
882 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
883 tlsError(c, EInternalError, "output buffer too small for certificate");
884 goto Err;
885 }
886 put24(p, nn);
887 p += 3;
888 for(i = 0; i < m->u.certificate.ncert; i++){
889 put24(p, m->u.certificate.certs[i]->len);
890 p += 3;
891 memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
892 p += m->u.certificate.certs[i]->len;
893 }
894 break;
895 case HClientKeyExchange:
896 n = m->u.clientKeyExchange.key->len;
897 if(c->version != SSL3Version){
898 put16(p, n);
899 p += 2;
900 }
901 memmove(p, m->u.clientKeyExchange.key->data, n);
902 p += n;
903 break;
904 case HFinished:
905 memmove(p, m->u.finished.verify, m->u.finished.n);
906 p += m->u.finished.n;
907 break;
908 }
909
910 // go back and fill in size
911 n = p-sendp;
912 assert(p <= sendbuf+sizeof(sendbuf));
913 put24(sendp+1, n-4);
914
915 // remember hash of Handshake messages
916 if(m->tag != HHelloRequest) {
917 md5(sendp, n, 0, &c->hsmd5);
918 sha1(sendp, n, 0, &c->hssha1);
919 }
920
921 sendp = p;
922 if(act == AFlush){
923 sendp = sendbuf;
924 if(write(c->hand, sendbuf, p-sendbuf) < 0){
925 fprint(2, "write error: %r\n");
926 goto Err;
927 }
928 }
929 msgClear(m);
930 return 1;
931 Err:
932 msgClear(m);
933 return 0;
934 }
935
936 static uchar*
tlsReadN(TlsConnection * c,int n)937 tlsReadN(TlsConnection *c, int n)
938 {
939 uchar *p;
940 int nn, nr;
941
942 nn = c->ep - c->rp;
943 if(nn < n){
944 if(c->rp != c->buf){
945 memmove(c->buf, c->rp, nn);
946 c->rp = c->buf;
947 c->ep = &c->buf[nn];
948 }
949 for(; nn < n; nn += nr) {
950 nr = read(c->hand, &c->rp[nn], n - nn);
951 if(nr <= 0)
952 return nil;
953 c->ep += nr;
954 }
955 }
956 p = c->rp;
957 c->rp += n;
958 return p;
959 }
960
961 static int
msgRecv(TlsConnection * c,Msg * m)962 msgRecv(TlsConnection *c, Msg *m)
963 {
964 uchar *p;
965 int type, n, nn, i, nsid, nrandom, nciph;
966
967 for(;;) {
968 p = tlsReadN(c, 4);
969 if(p == nil)
970 return 0;
971 type = p[0];
972 n = get24(p+1);
973
974 if(type != HHelloRequest)
975 break;
976 if(n != 0) {
977 tlsError(c, EDecodeError, "invalid hello request during handshake");
978 return 0;
979 }
980 }
981
982 if(n > sizeof(c->buf)) {
983 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
984 return 0;
985 }
986
987 if(type == HSSL2ClientHello){
988 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
989 This is sent by some clients that we must interoperate
990 with, such as Java's JSSE and Microsoft's Internet Explorer. */
991 p = tlsReadN(c, n);
992 if(p == nil)
993 return 0;
994 md5(p, n, 0, &c->hsmd5);
995 sha1(p, n, 0, &c->hssha1);
996 m->tag = HClientHello;
997 if(n < 22)
998 goto Short;
999 m->u.clientHello.version = get16(p+1);
1000 p += 3;
1001 n -= 3;
1002 nn = get16(p); /* cipher_spec_len */
1003 nsid = get16(p + 2);
1004 nrandom = get16(p + 4);
1005 p += 6;
1006 n -= 6;
1007 if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
1008 || nrandom < 16 || nn % 3)
1009 goto Err;
1010 if(c->trace && (n - nrandom != nn))
1011 c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1012 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1013 nciph = 0;
1014 for(i = 0; i < nn; i += 3)
1015 if(p[i] == 0)
1016 nciph++;
1017 m->u.clientHello.ciphers = newints(nciph);
1018 nciph = 0;
1019 for(i = 0; i < nn; i += 3)
1020 if(p[i] == 0)
1021 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1022 p += nn;
1023 m->u.clientHello.sid = makebytes(nil, 0);
1024 if(nrandom > RandomSize)
1025 nrandom = RandomSize;
1026 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1027 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1028 m->u.clientHello.compressors = newbytes(1);
1029 m->u.clientHello.compressors->data[0] = CompressionNull;
1030 goto Ok;
1031 }
1032
1033 md5(p, 4, 0, &c->hsmd5);
1034 sha1(p, 4, 0, &c->hssha1);
1035
1036 p = tlsReadN(c, n);
1037 if(p == nil)
1038 return 0;
1039
1040 md5(p, n, 0, &c->hsmd5);
1041 sha1(p, n, 0, &c->hssha1);
1042
1043 m->tag = type;
1044
1045 switch(type) {
1046 default:
1047 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1048 goto Err;
1049 case HClientHello:
1050 if(n < 2)
1051 goto Short;
1052 m->u.clientHello.version = get16(p);
1053 p += 2;
1054 n -= 2;
1055
1056 if(n < RandomSize)
1057 goto Short;
1058 memmove(m->u.clientHello.random, p, RandomSize);
1059 p += RandomSize;
1060 n -= RandomSize;
1061 if(n < 1 || n < p[0]+1)
1062 goto Short;
1063 m->u.clientHello.sid = makebytes(p+1, p[0]);
1064 p += m->u.clientHello.sid->len+1;
1065 n -= m->u.clientHello.sid->len+1;
1066
1067 if(n < 2)
1068 goto Short;
1069 nn = get16(p);
1070 p += 2;
1071 n -= 2;
1072
1073 if((nn & 1) || n < nn || nn < 2)
1074 goto Short;
1075 m->u.clientHello.ciphers = newints(nn >> 1);
1076 for(i = 0; i < nn; i += 2)
1077 m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1078 p += nn;
1079 n -= nn;
1080
1081 if(n < 1 || n < p[0]+1 || p[0] == 0)
1082 goto Short;
1083 nn = p[0];
1084 m->u.clientHello.compressors = newbytes(nn);
1085 memmove(m->u.clientHello.compressors->data, p+1, nn);
1086 n -= nn + 1;
1087 break;
1088 case HServerHello:
1089 if(n < 2)
1090 goto Short;
1091 m->u.serverHello.version = get16(p);
1092 p += 2;
1093 n -= 2;
1094
1095 if(n < RandomSize)
1096 goto Short;
1097 memmove(m->u.serverHello.random, p, RandomSize);
1098 p += RandomSize;
1099 n -= RandomSize;
1100
1101 if(n < 1 || n < p[0]+1)
1102 goto Short;
1103 m->u.serverHello.sid = makebytes(p+1, p[0]);
1104 p += m->u.serverHello.sid->len+1;
1105 n -= m->u.serverHello.sid->len+1;
1106
1107 if(n < 3)
1108 goto Short;
1109 m->u.serverHello.cipher = get16(p);
1110 m->u.serverHello.compressor = p[2];
1111 n -= 3;
1112 break;
1113 case HCertificate:
1114 if(n < 3)
1115 goto Short;
1116 nn = get24(p);
1117 p += 3;
1118 n -= 3;
1119 if(n != nn)
1120 goto Short;
1121 /* certs */
1122 i = 0;
1123 while(n > 0) {
1124 if(n < 3)
1125 goto Short;
1126 nn = get24(p);
1127 p += 3;
1128 n -= 3;
1129 if(nn > n)
1130 goto Short;
1131 m->u.certificate.ncert = i+1;
1132 m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1133 m->u.certificate.certs[i] = makebytes(p, nn);
1134 p += nn;
1135 n -= nn;
1136 i++;
1137 }
1138 break;
1139 case HCertificateRequest:
1140 if(n < 1)
1141 goto Short;
1142 nn = p[0];
1143 p += 1;
1144 n -= 1;
1145 if(nn < 1 || nn > n)
1146 goto Short;
1147 m->u.certificateRequest.types = makebytes(p, nn);
1148 p += nn;
1149 n -= nn;
1150 if(n < 2)
1151 goto Short;
1152 nn = get16(p);
1153 p += 2;
1154 n -= 2;
1155 /* nn == 0 can happen; yahoo's servers do it */
1156 if(nn != n)
1157 goto Short;
1158 /* cas */
1159 i = 0;
1160 while(n > 0) {
1161 if(n < 2)
1162 goto Short;
1163 nn = get16(p);
1164 p += 2;
1165 n -= 2;
1166 if(nn < 1 || nn > n)
1167 goto Short;
1168 m->u.certificateRequest.nca = i+1;
1169 m->u.certificateRequest.cas = erealloc(
1170 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1171 m->u.certificateRequest.cas[i] = makebytes(p, nn);
1172 p += nn;
1173 n -= nn;
1174 i++;
1175 }
1176 break;
1177 case HServerHelloDone:
1178 break;
1179 case HClientKeyExchange:
1180 /*
1181 * this message depends upon the encryption selected
1182 * assume rsa.
1183 */
1184 if(c->version == SSL3Version)
1185 nn = n;
1186 else{
1187 if(n < 2)
1188 goto Short;
1189 nn = get16(p);
1190 p += 2;
1191 n -= 2;
1192 }
1193 if(n < nn)
1194 goto Short;
1195 m->u.clientKeyExchange.key = makebytes(p, nn);
1196 n -= nn;
1197 break;
1198 case HFinished:
1199 m->u.finished.n = c->finished.n;
1200 if(n < m->u.finished.n)
1201 goto Short;
1202 memmove(m->u.finished.verify, p, m->u.finished.n);
1203 n -= m->u.finished.n;
1204 break;
1205 }
1206
1207 if(type != HClientHello && n != 0)
1208 goto Short;
1209 Ok:
1210 if(c->trace){
1211 char *buf;
1212 buf = emalloc(8000);
1213 c->trace("recv %s", msgPrint(buf, 8000, m));
1214 free(buf);
1215 }
1216 return 1;
1217 Short:
1218 tlsError(c, EDecodeError, "handshake message has invalid length");
1219 Err:
1220 msgClear(m);
1221 return 0;
1222 }
1223
1224 static void
msgClear(Msg * m)1225 msgClear(Msg *m)
1226 {
1227 int i;
1228
1229 switch(m->tag) {
1230 default:
1231 sysfatal("msgClear: unknown message type: %d", m->tag);
1232 case HHelloRequest:
1233 break;
1234 case HClientHello:
1235 freebytes(m->u.clientHello.sid);
1236 freeints(m->u.clientHello.ciphers);
1237 freebytes(m->u.clientHello.compressors);
1238 break;
1239 case HServerHello:
1240 freebytes(m->u.clientHello.sid);
1241 break;
1242 case HCertificate:
1243 for(i=0; i<m->u.certificate.ncert; i++)
1244 freebytes(m->u.certificate.certs[i]);
1245 free(m->u.certificate.certs);
1246 break;
1247 case HCertificateRequest:
1248 freebytes(m->u.certificateRequest.types);
1249 for(i=0; i<m->u.certificateRequest.nca; i++)
1250 freebytes(m->u.certificateRequest.cas[i]);
1251 free(m->u.certificateRequest.cas);
1252 break;
1253 case HServerHelloDone:
1254 break;
1255 case HClientKeyExchange:
1256 freebytes(m->u.clientKeyExchange.key);
1257 break;
1258 case HFinished:
1259 break;
1260 }
1261 memset(m, 0, sizeof(Msg));
1262 }
1263
1264 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1265 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1266 {
1267 int i;
1268
1269 if(s0)
1270 bs = seprint(bs, be, "%s", s0);
1271 bs = seprint(bs, be, "[");
1272 if(b == nil)
1273 bs = seprint(bs, be, "nil");
1274 else
1275 for(i=0; i<b->len; i++)
1276 bs = seprint(bs, be, "%.2x ", b->data[i]);
1277 bs = seprint(bs, be, "]");
1278 if(s1)
1279 bs = seprint(bs, be, "%s", s1);
1280 return bs;
1281 }
1282
1283 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1284 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1285 {
1286 int i;
1287
1288 if(s0)
1289 bs = seprint(bs, be, "%s", s0);
1290 bs = seprint(bs, be, "[");
1291 if(b == nil)
1292 bs = seprint(bs, be, "nil");
1293 else
1294 for(i=0; i<b->len; i++)
1295 bs = seprint(bs, be, "%x ", b->data[i]);
1296 bs = seprint(bs, be, "]");
1297 if(s1)
1298 bs = seprint(bs, be, "%s", s1);
1299 return bs;
1300 }
1301
1302 static char*
msgPrint(char * buf,int n,Msg * m)1303 msgPrint(char *buf, int n, Msg *m)
1304 {
1305 int i;
1306 char *bs = buf, *be = buf+n;
1307
1308 switch(m->tag) {
1309 default:
1310 bs = seprint(bs, be, "unknown %d\n", m->tag);
1311 break;
1312 case HClientHello:
1313 bs = seprint(bs, be, "ClientHello\n");
1314 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1315 bs = seprint(bs, be, "\trandom: ");
1316 for(i=0; i<RandomSize; i++)
1317 bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1318 bs = seprint(bs, be, "\n");
1319 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1320 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1321 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1322 break;
1323 case HServerHello:
1324 bs = seprint(bs, be, "ServerHello\n");
1325 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1326 bs = seprint(bs, be, "\trandom: ");
1327 for(i=0; i<RandomSize; i++)
1328 bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1329 bs = seprint(bs, be, "\n");
1330 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1331 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1332 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1333 break;
1334 case HCertificate:
1335 bs = seprint(bs, be, "Certificate\n");
1336 for(i=0; i<m->u.certificate.ncert; i++)
1337 bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1338 break;
1339 case HCertificateRequest:
1340 bs = seprint(bs, be, "CertificateRequest\n");
1341 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1342 bs = seprint(bs, be, "\tcertificateauthorities\n");
1343 for(i=0; i<m->u.certificateRequest.nca; i++)
1344 bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1345 break;
1346 case HServerHelloDone:
1347 bs = seprint(bs, be, "ServerHelloDone\n");
1348 break;
1349 case HClientKeyExchange:
1350 bs = seprint(bs, be, "HClientKeyExchange\n");
1351 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1352 break;
1353 case HFinished:
1354 bs = seprint(bs, be, "HFinished\n");
1355 for(i=0; i<m->u.finished.n; i++)
1356 bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1357 bs = seprint(bs, be, "\n");
1358 break;
1359 }
1360 USED(bs);
1361 return buf;
1362 }
1363
1364 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1365 tlsError(TlsConnection *c, int err, char *fmt, ...)
1366 {
1367 char msg[512];
1368 va_list arg;
1369
1370 va_start(arg, fmt);
1371 vseprint(msg, msg+sizeof(msg), fmt, arg);
1372 va_end(arg);
1373 if(c->trace)
1374 c->trace("tlsError: %s\n", msg);
1375 else if(c->erred)
1376 fprint(2, "double error: %r, %s", msg);
1377 else
1378 werrstr("tls: local %s", msg);
1379 c->erred = 1;
1380 fprint(c->ctl, "alert %d", err);
1381 }
1382
1383 // commit to specific version number
1384 static int
setVersion(TlsConnection * c,int version)1385 setVersion(TlsConnection *c, int version)
1386 {
1387 if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1388 return -1;
1389 if(version > c->version)
1390 version = c->version;
1391 if(version == SSL3Version) {
1392 c->version = version;
1393 c->finished.n = SSL3FinishedLen;
1394 }else if(version == TLSVersion){
1395 c->version = version;
1396 c->finished.n = TLSFinishedLen;
1397 }else
1398 return -1;
1399 c->verset = 1;
1400 return fprint(c->ctl, "version 0x%x", version);
1401 }
1402
1403 // confirm that received Finished message matches the expected value
1404 static int
finishedMatch(TlsConnection * c,Finished * f)1405 finishedMatch(TlsConnection *c, Finished *f)
1406 {
1407 return memcmp(f->verify, c->finished.verify, f->n) == 0;
1408 }
1409
1410 // free memory associated with TlsConnection struct
1411 // (but don't close the TLS channel itself)
1412 static void
tlsConnectionFree(TlsConnection * c)1413 tlsConnectionFree(TlsConnection *c)
1414 {
1415 tlsSecClose(c->sec);
1416 freebytes(c->sid);
1417 freebytes(c->cert);
1418 memset(c, 0, sizeof(c));
1419 free(c);
1420 }
1421
1422
1423 //================= cipher choices ========================
1424
/*
 * weakCipher[id] is 1 for the cipher suites, indexed by TLS cipher-suite
 * id (0x0000 through 0x001B, listed in order), that are treated as weak:
 * NULL encryption, export-grade key sizes, or anonymous Diffie-Hellman.
 * okCipher uses it to return -2 (client offered only weak suites)
 * rather than -1 when no offered suite is usable.
 * Any entries beyond this list (if CipherMax is larger) are zero.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1456
1457 static int
setAlgs(TlsConnection * c,int a)1458 setAlgs(TlsConnection *c, int a)
1459 {
1460 int i;
1461
1462 for(i = 0; i < nelem(cipherAlgs); i++){
1463 if(cipherAlgs[i].tlsid == a){
1464 c->enc = cipherAlgs[i].enc;
1465 c->digest = cipherAlgs[i].digest;
1466 c->nsecret = cipherAlgs[i].nsecret;
1467 if(c->nsecret > MaxKeyData)
1468 return 0;
1469 return 1;
1470 }
1471 }
1472 return 0;
1473 }
1474
1475 static int
okCipher(Ints * cv)1476 okCipher(Ints *cv)
1477 {
1478 int weak, i, j, c;
1479
1480 weak = 1;
1481 for(i = 0; i < cv->len; i++) {
1482 c = cv->data[i];
1483 if(c >= CipherMax)
1484 weak = 0;
1485 else
1486 weak &= weakCipher[c];
1487 for(j = 0; j < nelem(cipherAlgs); j++)
1488 if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1489 return c;
1490 }
1491 if(weak)
1492 return -2;
1493 return -1;
1494 }
1495
1496 static int
okCompression(Bytes * cv)1497 okCompression(Bytes *cv)
1498 {
1499 int i, j, c;
1500
1501 for(i = 0; i < cv->len; i++) {
1502 c = cv->data[i];
1503 for(j = 0; j < nelem(compressors); j++) {
1504 if(compressors[j] == c)
1505 return c;
1506 }
1507 }
1508 return -1;
1509 }
1510
1511 static Lock ciphLock;
1512 static int nciphers;
1513
1514 static int
initCiphers(void)1515 initCiphers(void)
1516 {
1517 enum {MaxAlgF = 1024, MaxAlgs = 10};
1518 char s[MaxAlgF], *flds[MaxAlgs];
1519 int i, j, n, ok;
1520
1521 lock(&ciphLock);
1522 if(nciphers){
1523 unlock(&ciphLock);
1524 return nciphers;
1525 }
1526 j = open("#a/tls/encalgs", OREAD);
1527 if(j < 0){
1528 werrstr("can't open #a/tls/encalgs: %r");
1529 return 0;
1530 }
1531 n = read(j, s, MaxAlgF-1);
1532 close(j);
1533 if(n <= 0){
1534 werrstr("nothing in #a/tls/encalgs: %r");
1535 return 0;
1536 }
1537 s[n] = 0;
1538 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1539 for(i = 0; i < nelem(cipherAlgs); i++){
1540 ok = 0;
1541 for(j = 0; j < n; j++){
1542 if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1543 ok = 1;
1544 break;
1545 }
1546 }
1547 cipherAlgs[i].ok = ok;
1548 }
1549
1550 j = open("#a/tls/hashalgs", OREAD);
1551 if(j < 0){
1552 werrstr("can't open #a/tls/hashalgs: %r");
1553 return 0;
1554 }
1555 n = read(j, s, MaxAlgF-1);
1556 close(j);
1557 if(n <= 0){
1558 werrstr("nothing in #a/tls/hashalgs: %r");
1559 return 0;
1560 }
1561 s[n] = 0;
1562 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1563 for(i = 0; i < nelem(cipherAlgs); i++){
1564 ok = 0;
1565 for(j = 0; j < n; j++){
1566 if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1567 ok = 1;
1568 break;
1569 }
1570 }
1571 cipherAlgs[i].ok &= ok;
1572 if(cipherAlgs[i].ok)
1573 nciphers++;
1574 }
1575 unlock(&ciphLock);
1576 return nciphers;
1577 }
1578
1579 static Ints*
makeciphers(void)1580 makeciphers(void)
1581 {
1582 Ints *is;
1583 int i, j;
1584
1585 is = newints(nciphers);
1586 j = 0;
1587 for(i = 0; i < nelem(cipherAlgs); i++){
1588 if(cipherAlgs[i].ok)
1589 is->data[j++] = cipherAlgs[i].tlsid;
1590 }
1591 return is;
1592 }
1593
1594
1595
1596 //================= security functions ========================
1597
1598 // given X.509 certificate, set up connection to factotum
1599 // for using corresponding private key
1600 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1601 factotum_rsa_open(uchar *cert, int certlen)
1602 {
1603 int afd;
1604 char *s;
1605 mpint *pub = nil;
1606 RSApub *rsapub;
1607 AuthRpc *rpc;
1608
1609 // start talking to factotum
1610 if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1611 return nil;
1612 if((rpc = auth_allocrpc(afd)) == nil){
1613 close(afd);
1614 return nil;
1615 }
1616 s = "proto=rsa service=tls role=client";
1617 if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1618 factotum_rsa_close(rpc);
1619 return nil;
1620 }
1621
1622 // roll factotum keyring around to match certificate
1623 rsapub = X509toRSApub(cert, certlen, nil, 0);
1624 while(1){
1625 if(auth_rpc(rpc, "read", nil, 0) != ARok){
1626 factotum_rsa_close(rpc);
1627 rpc = nil;
1628 goto done;
1629 }
1630 pub = strtomp(rpc->arg, nil, 16, nil);
1631 assert(pub != nil);
1632 if(mpcmp(pub,rsapub->n) == 0)
1633 break;
1634 }
1635 done:
1636 mpfree(pub);
1637 rsapubfree(rsapub);
1638 return rpc;
1639 }
1640
1641 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1642 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1643 {
1644 char *p;
1645 int rv;
1646
1647 if((p = mptoa(cipher, 16, nil, 0)) == nil)
1648 return nil;
1649 rv = auth_rpc(rpc, "write", p, strlen(p));
1650 free(p);
1651 if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1652 return nil;
1653 mpfree(cipher);
1654 return strtomp(rpc->arg, nil, 16, nil);
1655 }
1656
1657 static void
factotum_rsa_close(AuthRpc * rpc)1658 factotum_rsa_close(AuthRpc*rpc)
1659 {
1660 if(!rpc)
1661 return;
1662 close(rpc->afd);
1663 auth_freerpc(rpc);
1664 }
1665
/*
 * tlsPmd5 computes the TLS 1.0 data-expansion function P_MD5
 * (RFC 2246 section 5), using secret key and seed label||seed0||seed1.
 * The first nbuf bytes of its output are XORed into buf, not stored,
 * so that tlsPRF can combine it with tlsPsha1 in the same buffer.
 */
static void
tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[MD5dlen], tmp[MD5dlen];
	int i, n;
	MD5state *s;

	// generate a1: A(1) = HMAC(key, label||seed0||seed1), fed in three pieces
	s = hmac_md5(label, nlabel, key, nkey, nil, nil);
	s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
	hmac_md5(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// next output block: HMAC(key, A(i)||label||seed0||seed1)
		s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
		s = hmac_md5(label, nlabel, key, nkey, nil, s);
		s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
		hmac_md5(seed1, nseed1, key, nkey, tmp, s);
		// XOR only as much of the block as buf still needs
		n = MD5dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, MD5dlen);
	}
}
1694
/*
 * tlsPsha1 computes the TLS 1.0 data-expansion function P_SHA-1
 * (RFC 2246 section 5), using secret key and seed label||seed0||seed1.
 * The first nbuf bytes of its output are XORed into buf, not stored,
 * so that tlsPRF can combine it with tlsPmd5 in the same buffer.
 */
static void
tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
{
	uchar ai[SHA1dlen], tmp[SHA1dlen];
	int i, n;
	SHAstate *s;

	// generate a1: A(1) = HMAC(key, label||seed0||seed1), fed in three pieces
	s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
	s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
	hmac_sha1(seed1, nseed1, key, nkey, ai, s);

	while(nbuf > 0) {
		// next output block: HMAC(key, A(i)||label||seed0||seed1)
		s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
		s = hmac_sha1(label, nlabel, key, nkey, nil, s);
		s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
		hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
		// XOR only as much of the block as buf still needs
		n = SHA1dlen;
		if(n > nbuf)
			n = nbuf;
		for(i = 0; i < n; i++)
			buf[i] ^= tmp[i];
		buf += n;
		nbuf -= n;
		// A(i+1) = HMAC(key, A(i))
		hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
		memmove(ai, tmp, SHA1dlen);
	}
}
1723
1724 // fill buf with md5(args)^sha1(args)
1725 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1726 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1727 {
1728 int i;
1729 int nlabel = strlen(label);
1730 int n = (nkey + 1) >> 1;
1731
1732 for(i = 0; i < nbuf; i++)
1733 buf[i] = 0;
1734 tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1735 tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1736 }
1737
1738 /*
1739 * for setting server session id's
1740 */
1741 static Lock sidLock;
1742 static long maxSid = 1;
1743
1744 /* the keys are verified to have the same public components
1745 * and to function correctly with pkcs 1 encryption and decryption. */
1746 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1747 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1748 {
1749 TlsSec *sec = emalloc(sizeof(*sec));
1750
1751 USED(csid); USED(ncsid); // ignore csid for now
1752
1753 memmove(sec->crandom, crandom, RandomSize);
1754 sec->clientVers = cvers;
1755
1756 put32(sec->srandom, time(0));
1757 genrandom(sec->srandom+4, RandomSize-4);
1758 memmove(srandom, sec->srandom, RandomSize);
1759
1760 /*
1761 * make up a unique sid: use our pid, and and incrementing id
1762 * can signal no sid by setting nssid to 0.
1763 */
1764 memset(ssid, 0, SidSize);
1765 put32(ssid, getpid());
1766 lock(&sidLock);
1767 put32(ssid+4, maxSid++);
1768 unlock(&sidLock);
1769 *nssid = SidSize;
1770 return sec;
1771 }
1772
1773 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1774 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1775 {
1776 if(epm != nil){
1777 if(setVers(sec, vers) < 0)
1778 goto Err;
1779 serverMasterSecret(sec, epm, nepm);
1780 }else if(sec->vers != vers){
1781 werrstr("mismatched session versions");
1782 goto Err;
1783 }
1784 setSecrets(sec, kd, nkd);
1785 return 0;
1786 Err:
1787 sec->ok = -1;
1788 return -1;
1789 }
1790
1791 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1792 tlsSecInitc(int cvers, uchar *crandom)
1793 {
1794 TlsSec *sec = emalloc(sizeof(*sec));
1795 sec->clientVers = cvers;
1796 put32(sec->crandom, time(0));
1797 genrandom(sec->crandom+4, RandomSize-4);
1798 memmove(crandom, sec->crandom, RandomSize);
1799 return sec;
1800 }
1801
1802 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1803 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1804 {
1805 RSApub *pub;
1806
1807 pub = nil;
1808
1809 USED(sid);
1810 USED(nsid);
1811
1812 memmove(sec->srandom, srandom, RandomSize);
1813
1814 if(setVers(sec, vers) < 0)
1815 goto Err;
1816
1817 pub = X509toRSApub(cert, ncert, nil, 0);
1818 if(pub == nil){
1819 werrstr("invalid x509/rsa certificate");
1820 goto Err;
1821 }
1822 if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1823 goto Err;
1824 rsapubfree(pub);
1825 setSecrets(sec, kd, nkd);
1826 return 0;
1827
1828 Err:
1829 if(pub != nil)
1830 rsapubfree(pub);
1831 sec->ok = -1;
1832 return -1;
1833 }
1834
1835 static int
tlsSecFinished(TlsSec * sec,MD5state md5,SHAstate sha1,uchar * fin,int nfin,int isclient)1836 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1837 {
1838 if(sec->nfin != nfin){
1839 sec->ok = -1;
1840 werrstr("invalid finished exchange");
1841 return -1;
1842 }
1843 md5.malloced = 0;
1844 sha1.malloced = 0;
1845 (*sec->setFinished)(sec, md5, sha1, fin, isclient);
1846 return 1;
1847 }
1848
1849 static void
tlsSecOk(TlsSec * sec)1850 tlsSecOk(TlsSec *sec)
1851 {
1852 if(sec->ok == 0)
1853 sec->ok = 1;
1854 }
1855
1856 static void
tlsSecKill(TlsSec * sec)1857 tlsSecKill(TlsSec *sec)
1858 {
1859 if(!sec)
1860 return;
1861 factotum_rsa_close(sec->rpc);
1862 sec->ok = -1;
1863 }
1864
1865 static void
tlsSecClose(TlsSec * sec)1866 tlsSecClose(TlsSec *sec)
1867 {
1868 if(!sec)
1869 return;
1870 factotum_rsa_close(sec->rpc);
1871 free(sec->server);
1872 free(sec);
1873 }
1874
1875 static int
setVers(TlsSec * sec,int v)1876 setVers(TlsSec *sec, int v)
1877 {
1878 if(v == SSL3Version){
1879 sec->setFinished = sslSetFinished;
1880 sec->nfin = SSL3FinishedLen;
1881 sec->prf = sslPRF;
1882 }else if(v == TLSVersion){
1883 sec->setFinished = tlsSetFinished;
1884 sec->nfin = TLSFinishedLen;
1885 sec->prf = tlsPRF;
1886 }else{
1887 werrstr("invalid version");
1888 return -1;
1889 }
1890 sec->vers = v;
1891 return 0;
1892 }
1893
1894 /*
1895 * generate secret keys from the master secret.
1896 *
1897 * different crypto selections will require different amounts
1898 * of key expansion and use of key expansion data,
1899 * but it's all generated using the same function.
1900 */
1901 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)1902 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1903 {
1904 (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1905 sec->srandom, RandomSize, sec->crandom, RandomSize);
1906 }
1907
1908 /*
1909 * set the master secret from the pre-master secret.
1910 */
1911 static void
setMasterSecret(TlsSec * sec,Bytes * pm)1912 setMasterSecret(TlsSec *sec, Bytes *pm)
1913 {
1914 (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1915 sec->crandom, RandomSize, sec->srandom, RandomSize);
1916 }
1917
1918 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)1919 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1920 {
1921 Bytes *pm;
1922
1923 pm = pkcs1_decrypt(sec, epm, nepm);
1924
1925 // if the client messed up, just continue as if everything is ok,
1926 // to prevent attacks to check for correctly formatted messages.
1927 // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1928 if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1929 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1930 sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1931 sec->ok = -1;
1932 if(pm != nil)
1933 freebytes(pm);
1934 pm = newbytes(MasterSecretSize);
1935 genrandom(pm->data, MasterSecretSize);
1936 }
1937 setMasterSecret(sec, pm);
1938 memset(pm->data, 0, pm->len);
1939 freebytes(pm);
1940 }
1941
1942 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)1943 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1944 {
1945 Bytes *pm, *key;
1946
1947 pm = newbytes(MasterSecretSize);
1948 put16(pm->data, sec->clientVers);
1949 genrandom(pm->data+2, MasterSecretSize - 2);
1950
1951 setMasterSecret(sec, pm);
1952
1953 key = pkcs1_encrypt(pm, pub, 2);
1954 memset(pm->data, 0, pm->len);
1955 freebytes(pm);
1956 if(key == nil){
1957 werrstr("tls pkcs1_encrypt failed");
1958 return -1;
1959 }
1960
1961 *nepm = key->len;
1962 *epm = malloc(*nepm);
1963 if(*epm == nil){
1964 freebytes(key);
1965 werrstr("out of memory");
1966 return -1;
1967 }
1968 memmove(*epm, key->data, *nepm);
1969
1970 freebytes(key);
1971
1972 return 1;
1973 }
1974
/*
 * sslSetFinished computes SSL 3.0 Finished verify data into finished:
 * MD5dlen bytes of MD5 followed by SHA1dlen bytes of SHA-1.
 *
 * For each hash H, with pad_1 = 0x36 and pad_2 = 0x5c repeated 48 times
 * for MD5 and 40 times for SHA-1:
 *	H(master || pad_2 || H(handshake || sender || master || pad_1))
 * where sender is "CLNT" or "SRVR".  hsmd5 and hssha1 carry the handshake
 * hash so far; they are passed by value, so finalizing them here does not
 * affect the caller's running hashes.
 */
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// inner MD5: handshake || sender || master || pad_1
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	// outer MD5: master || pad_2 || inner
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// inner SHA-1 (40-byte pads): handshake || sender || master || pad_1
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	// outer SHA-1, written after the MD5 part
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
2007
2008 // fill "finished" arg with md5(args)^sha1(args)
2009 static void
tlsSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)2010 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2011 {
2012 uchar h0[MD5dlen], h1[SHA1dlen];
2013 char *label;
2014
2015 // get current hash value, but allow further messages to be hashed in
2016 md5(nil, 0, h0, &hsmd5);
2017 sha1(nil, 0, h1, &hssha1);
2018
2019 if(isClient)
2020 label = "client finished";
2021 else
2022 label = "server finished";
2023 tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2024 }
2025
2026 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)2027 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2028 {
2029 DigestState *s;
2030 uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2031 int i, n, len;
2032
2033 USED(label);
2034 len = 1;
2035 while(nbuf > 0){
2036 if(len > 26)
2037 return;
2038 for(i = 0; i < len; i++)
2039 tmp[i] = 'A' - 1 + len;
2040 s = sha1(tmp, len, nil, nil);
2041 s = sha1(key, nkey, nil, s);
2042 s = sha1(seed0, nseed0, nil, s);
2043 sha1(seed1, nseed1, sha1dig, s);
2044 s = md5(key, nkey, nil, nil);
2045 md5(sha1dig, SHA1dlen, md5dig, s);
2046 n = MD5dlen;
2047 if(n > nbuf)
2048 n = nbuf;
2049 memmove(buf, md5dig, n);
2050 buf += n;
2051 nbuf -= n;
2052 len++;
2053 }
2054 }
2055
2056 static mpint*
bytestomp(Bytes * bytes)2057 bytestomp(Bytes* bytes)
2058 {
2059 mpint* ans;
2060
2061 ans = betomp(bytes->data, bytes->len, nil);
2062 return ans;
2063 }
2064
2065 /*
2066 * Convert mpint* to Bytes, putting high order byte first.
2067 */
2068 static Bytes*
mptobytes(mpint * big)2069 mptobytes(mpint* big)
2070 {
2071 int n, m;
2072 uchar *a;
2073 Bytes* ans;
2074
2075 a = nil;
2076 n = (mpsignif(big)+7)/8;
2077 m = mptobe(big, nil, n, &a);
2078 ans = makebytes(a, m);
2079 if(a != nil)
2080 free(a);
2081 return ans;
2082 }
2083
2084 // Do RSA computation on block according to key, and pad
2085 // result on left with zeros to make it modlen long.
2086 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2087 rsacomp(Bytes* block, RSApub* key, int modlen)
2088 {
2089 mpint *x, *y;
2090 Bytes *a, *ybytes;
2091 int ylen;
2092
2093 x = bytestomp(block);
2094 y = rsaencrypt(key, x, nil);
2095 mpfree(x);
2096 ybytes = mptobytes(y);
2097 ylen = ybytes->len;
2098
2099 if(ylen < modlen) {
2100 a = newbytes(modlen);
2101 memset(a->data, 0, modlen-ylen);
2102 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2103 freebytes(ybytes);
2104 ybytes = a;
2105 }
2106 else if(ylen > modlen) {
2107 // assume it has leading zeros (mod should make it so)
2108 a = newbytes(modlen);
2109 memmove(a->data, ybytes->data, modlen);
2110 freebytes(ybytes);
2111 ybytes = a;
2112 }
2113 mpfree(y);
2114 return ybytes;
2115 }
2116
2117 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2118 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2119 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2120 {
2121 Bytes *pad, *eb, *ans;
2122 int i, dlen, padlen, modlen;
2123
2124 modlen = (mpsignif(key->n)+7)/8;
2125 dlen = data->len;
2126 if(modlen < 12 || dlen > modlen - 11)
2127 return nil;
2128 padlen = modlen - 3 - dlen;
2129 pad = newbytes(padlen);
2130 genrandom(pad->data, padlen);
2131 for(i = 0; i < padlen; i++) {
2132 if(blocktype == 0)
2133 pad->data[i] = 0;
2134 else if(blocktype == 1)
2135 pad->data[i] = 255;
2136 else if(pad->data[i] == 0)
2137 pad->data[i] = 1;
2138 }
2139 eb = newbytes(modlen);
2140 eb->data[0] = 0;
2141 eb->data[1] = blocktype;
2142 memmove(eb->data+2, pad->data, padlen);
2143 eb->data[padlen+2] = 0;
2144 memmove(eb->data+padlen+3, data->data, dlen);
2145 ans = rsacomp(eb, key, modlen);
2146 freebytes(eb);
2147 freebytes(pad);
2148 return ans;
2149 }
2150
2151 // decrypt data according to PKCS#1, with given key.
2152 // expect a block type of 2.
2153 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2154 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2155 {
2156 Bytes *eb, *ans = nil;
2157 int i, modlen;
2158 mpint *x, *y;
2159
2160 modlen = (mpsignif(sec->rsapub->n)+7)/8;
2161 if(nepm != modlen)
2162 return nil;
2163 x = betomp(epm, nepm, nil);
2164 y = factotum_rsa_decrypt(sec->rpc, x);
2165 if(y == nil)
2166 return nil;
2167 eb = mptobytes(y);
2168 if(eb->len < modlen){ // pad on left with zeros
2169 ans = newbytes(modlen);
2170 memset(ans->data, 0, modlen-eb->len);
2171 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2172 freebytes(eb);
2173 eb = ans;
2174 }
2175 if(eb->data[0] == 0 && eb->data[1] == 2) {
2176 for(i = 2; i < modlen; i++)
2177 if(eb->data[i] == 0)
2178 break;
2179 if(i < modlen - 1)
2180 ans = makebytes(eb->data+i+1, modlen-(i+1));
2181 }
2182 freebytes(eb);
2183 return ans;
2184 }
2185
2186
2187 //================= general utility functions ========================
2188
2189 static void *
emalloc(int n)2190 emalloc(int n)
2191 {
2192 void *p;
2193 if(n==0)
2194 n=1;
2195 p = malloc(n);
2196 if(p == nil){
2197 exits("out of memory");
2198 }
2199 memset(p, 0, n);
2200 return p;
2201 }
2202
2203 static void *
erealloc(void * ReallocP,int ReallocN)2204 erealloc(void *ReallocP, int ReallocN)
2205 {
2206 if(ReallocN == 0)
2207 ReallocN = 1;
2208 if(!ReallocP)
2209 ReallocP = emalloc(ReallocN);
2210 else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2211 exits("out of memory");
2212 }
2213 return(ReallocP);
2214 }
2215
2216 static void
put32(uchar * p,u32int x)2217 put32(uchar *p, u32int x)
2218 {
2219 p[0] = x>>24;
2220 p[1] = x>>16;
2221 p[2] = x>>8;
2222 p[3] = x;
2223 }
2224
2225 static void
put24(uchar * p,int x)2226 put24(uchar *p, int x)
2227 {
2228 p[0] = x>>16;
2229 p[1] = x>>8;
2230 p[2] = x;
2231 }
2232
2233 static void
put16(uchar * p,int x)2234 put16(uchar *p, int x)
2235 {
2236 p[0] = x>>8;
2237 p[1] = x;
2238 }
2239
2240 static u32int
get32(uchar * p)2241 get32(uchar *p)
2242 {
2243 return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2244 }
2245
2246 static int
get24(uchar * p)2247 get24(uchar *p)
2248 {
2249 return (p[0]<<16)|(p[1]<<8)|p[2];
2250 }
2251
2252 static int
get16(uchar * p)2253 get16(uchar *p)
2254 {
2255 return (p[0]<<8)|p[1];
2256 }
2257
2258 #define OFFSET(x, s) offsetof(s, x)
2259
2260 /*
2261 * malloc and return a new Bytes structure capable of
2262 * holding len bytes. (len >= 0)
2263 * Used to use crypt_malloc, which aborts if malloc fails.
2264 */
2265 static Bytes*
newbytes(int len)2266 newbytes(int len)
2267 {
2268 Bytes* ans;
2269
2270 ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2271 ans->len = len;
2272 return ans;
2273 }
2274
2275 /*
2276 * newbytes(len), with data initialized from buf
2277 */
2278 static Bytes*
makebytes(uchar * buf,int len)2279 makebytes(uchar* buf, int len)
2280 {
2281 Bytes* ans;
2282
2283 ans = newbytes(len);
2284 memmove(ans->data, buf, len);
2285 return ans;
2286 }
2287
2288 static void
freebytes(Bytes * b)2289 freebytes(Bytes* b)
2290 {
2291 if(b != nil)
2292 free(b);
2293 }
2294
2295 /* len is number of ints */
2296 static Ints*
newints(int len)2297 newints(int len)
2298 {
2299 Ints* ans;
2300
2301 ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2302 ans->len = len;
2303 return ans;
2304 }
2305
2306 static Ints*
makeints(int * buf,int len)2307 makeints(int* buf, int len)
2308 {
2309 Ints* ans;
2310
2311 ans = newints(len);
2312 if(len > 0)
2313 memmove(ans->data, buf, len*sizeof(int));
2314 return ans;
2315 }
2316
2317 static void
freeints(Ints * b)2318 freeints(Ints* b)
2319 {
2320 if(b != nil)
2321 free(b);
2322 }
2323