1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
// The main groups of functions are:
//	client/server - main handshake protocol definition
//	message functions - formatting and parsing handshake messages
//	cipher choices - catalog of digest and encryption algorithms
//	security functions - PKCS#1, sslHMAC, session key generation
//	general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
// Sizes and flags used throughout the handshake code.
enum {
	TLSFinishedLen = 12,	// Finished verify length for TLS
	SSL3FinishedLen = MD5dlen+SHA1dlen,	// Finished verify length for SSL3: MD5 plus SHA1 digest sizes
	MaxKeyData = 104,	// amount of secret we may need; the kd[MaxKeyData] buffers receive nsecret bytes, so this must cover the largest nsecret in cipherAlgs
	MaxChunk = 1<<14,	// handshake input buffer (TlsConnection.buf) is MaxChunk+2048 bytes
	RandomSize = 32,	// size of client and server randoms
	SidSize = 32,	// session id size issued by the server; the client accepts only 0 or SidSize
	MasterSecretSize = 48,	// size of TlsSec.sec
	AQueue = 0,	// msgSend: queue the message only
	AFlush = 1,	// msgSend: queue the message, then write everything queued
};
28
typedef struct TlsSec TlsSec;

// Length-prefixed byte string.  data is declared with one element but
// holds len bytes (the "[len]" convention); newbytes/makebytes
// presumably allocate the extra room — confirm.
typedef struct Bytes{
	int len;
	uchar data[1];	// [len]
} Bytes;

// Length-prefixed array of ints, laid out like Bytes.
typedef struct Ints{
	int len;
	int data[1];	// [len]
} Ints;

// One entry of the cipherAlgs catalog.
typedef struct Algs{
	char *enc;	// encryption algorithm name (presumably copied to TlsConnection.enc by setAlgs)
	char *digest;	// digest algorithm name (presumably copied to TlsConnection.digest by setAlgs)
	int nsecret;	// bytes of key material this suite needs
	int tlsid;	// cipher suite number (TLS_* constant)
	int ok;	// NOTE(review): left 0 in cipherAlgs; presumably set by initCiphers when usable — confirm
} Algs;

// Verify data of a Finished message.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];	// sized for the larger SSL3 form
	int n;	// number of bytes of verify in use
} Finished;
53
// Per-connection handshake state, allocated by tlsServer2/tlsClient2
// and released with tlsConnectionFree.
typedef struct TlsConnection{
	TlsSec *sec;	// security management goo
	int hand, ctl;	// record layer file descriptors
	int erred;	// set when tlsError called
	int (*trace)(char*fmt, ...);	// for debugging
	int version;	// protocol we are speaking
	int verset;	// version has been set
	int ver2hi;	// server got a version 2 hello
	int isClient;	// is this the client or server?
	Bytes *sid;	// SessionID
	Bytes *cert;	// client side: first certificate of the server's chain; the rest of the chain is discarded

	Lock statelk;
	int state;	// must be set using setstate

	// input buffer for handshake messages; rp..ep is data read
	// from hand but not yet consumed (see tlsReadN)
	uchar buf[MaxChunk+2048];
	uchar *rp, *ep;

	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVersion;	// version in ClientHello
	char *digest;	// name of digest algorithm to use
	char *enc;	// name of encryption algorithm to use
	int nsecret;	// amount of secret data to init keys

	// for finished messages: running hashes of every handshake
	// message sent (msgSend) and received (msgRecv)
	MD5state hsmd5;	// handshake hash
	SHAstate hssha1;	// handshake hash
	Finished finished;
} TlsConnection;
85
// A decoded handshake message.  tag is the handshake type (H* values);
// the member of u selected by tag holds the message body.  Pointer
// fields belong to the Msg and are released by msgClear.
typedef struct Msg{
	int tag;
	union {
		struct {
			int version;
			uchar random[RandomSize];
			Bytes* sid;
			Ints* ciphers;	// offered cipher suite numbers
			Bytes* compressors;	// offered compression methods
		} clientHello;
		struct {
			int version;
			uchar random[RandomSize];
			Bytes* sid;
			int cipher;	// chosen cipher suite number
			int compressor;	// chosen compression method
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;	// array of ncert certificates
		} certificate;
		struct {
			Bytes *types;
			int nca;
			Bytes **cas;	// array of nca distinguished names
		} certificateRequest;
		struct {
			Bytes *key;	// encrypted pre-master secret (RSA assumed)
		} clientKeyExchange;
		Finished finished;
	} u;
} Msg;
118
// Key-management state of one handshake, kept apart from the message
// layer: master secret, randoms, negotiated version, and the
// version-specific PRF and Finished computations (presumably installed
// by setVers — confirm).
typedef struct TlsSec{
	char *server;	// name of remote; nil for server
	int ok;	// <0 killed; ==0 in progress; >0 reusable
	RSApub *rsapub;
	AuthRpc *rpc;	// factotum for rsa private key
	uchar sec[MasterSecretSize];	// master secret
	uchar crandom[RandomSize];	// client random
	uchar srandom[RandomSize];	// server random
	int clientVers;	// version in ClientHello
	int vers;	// final version
	// byte generation and handshake checksum
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
	int nfin;
} TlsSec;
134
135
// protocol version numbers (major<<8 | minor)
enum {
	TLSVersion = 0x0301,
	SSL3Version = 0x0300,
	ProtocolVersion = 0x0301,	// maximum version we speak
	MinProtoVersion = 0x0300,	// limits on version we accept
	MaxProtoVersion = 0x03ff,
};
143
// handshake type: the first byte of every handshake message header,
// stored in Msg.tag
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,	/* local convention; see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};
159
// alerts: alert description codes as numbered in RFC 2246; passed as
// the err argument of tlsError
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};
188
// cipher suites: the 16-bit identifiers carried in hello messages.
// Only the suites listed in cipherAlgs are actually implemented here;
// the rest are named for reference and tracing.
enum {
	TLS_NULL_WITH_NULL_NULL = 0x0000,
	TLS_RSA_WITH_NULL_MD5 = 0x0001,
	TLS_RSA_WITH_NULL_SHA = 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 = 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 = 0x0004,
	TLS_RSA_WITH_RC4_128_SHA = 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 = 0X0006,
	TLS_RSA_WITH_IDEA_CBC_SHA = 0X0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X0008,
	TLS_RSA_WITH_DES_CBC_SHA = 0X0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA = 0X000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA = 0X000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA = 0X000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA = 0X000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA = 0X000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA = 0X0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA = 0X0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA = 0X0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA = 0X0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA = 0X0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA = 0X0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA = 0X0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 = 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 = 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA = 0X0019,
	TLS_DH_anon_WITH_DES_CBC_SHA = 0X001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA = 0X001B,

	TLS_RSA_WITH_AES_128_CBC_SHA = 0X002f,	// aes, aka rijndael with 128 bit blocks
	TLS_DH_DSS_WITH_AES_128_CBC_SHA = 0X0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA = 0X0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0X0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0X0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA = 0X0034,
	TLS_RSA_WITH_AES_256_CBC_SHA = 0X0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA = 0X0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA = 0X0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0X0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0X0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA = 0X003A,
	CipherMax	// one past the last identifier above
};
234
// compression methods
enum {
	CompressionNull = 0,
	CompressionMax
};

// Supported cipher suites: {enc, digest, nsecret, tlsid}.  nsecret is
// the key material for both directions: 2 * (cipher key [+ IV] bytes
// + MAC key bytes).  For 3DES that is a 24-byte key plus an 8-byte IV
// (4*8).  The largest entry must not exceed MaxKeyData.
static Algs cipherAlgs[] = {
	{"rc4_128", "md5",	2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
	{"rc4_128", "sha1",	2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
	{"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
};

// compression methods offered in our ClientHello
static uchar compressors[] = {
	CompressionNull,
};
250
// handshake drivers
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...));
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));

// message layer
static void	msgClear(Msg *m);
static char*	msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck	argpos	tlsError	3
static int setVersion(TlsConnection *c, int version);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

// cipher choices
static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(void);

// security functions
static TlsSec*	tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
static void	tlsSecOk(TlsSec *sec);
static void	tlsSecKill(TlsSec *sec);
static void	tlsSecClose(TlsSec *sec);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);

// factotum interface for the server's RSA private key
static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc*rpc);

// general utility functions
static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static u32int get32(uchar *p);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static Ints* makeints(int* buf, int len);
static void freeints(Ints* b);
308
309 //================= client/server ========================
310
311 // push TLS onto fd, returning new (application) file descriptor
312 // or -1 if error.
313 int
tlsServer(int fd,TLSconn * conn)314 tlsServer(int fd, TLSconn *conn)
315 {
316 char buf[8];
317 char dname[64];
318 int n, data, ctl, hand;
319 TlsConnection *tls;
320
321 if(conn == nil)
322 return -1;
323 ctl = open("#a/tls/clone", ORDWR);
324 if(ctl < 0)
325 return -1;
326 n = read(ctl, buf, sizeof(buf)-1);
327 if(n < 0){
328 close(ctl);
329 return -1;
330 }
331 buf[n] = 0;
332 sprint(conn->dir, "#a/tls/%s", buf);
333 sprint(dname, "#a/tls/%s/hand", buf);
334 hand = open(dname, ORDWR);
335 if(hand < 0){
336 close(ctl);
337 return -1;
338 }
339 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace);
341 sprint(dname, "#a/tls/%s/data", buf);
342 data = open(dname, ORDWR);
343 close(fd);
344 close(hand);
345 close(ctl);
346 if(data < 0){
347 return -1;
348 }
349 if(tls == nil){
350 close(data);
351 return -1;
352 }
353 if(conn->cert)
354 free(conn->cert);
355 conn->cert = 0; // client certificates are not yet implemented
356 conn->certlen = 0;
357 conn->sessionIDlen = tls->sid->len;
358 conn->sessionID = emalloc(conn->sessionIDlen);
359 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 tlsConnectionFree(tls);
361 return data;
362 }
363
364 // push TLS onto fd, returning new (application) file descriptor
365 // or -1 if error.
366 int
tlsClient(int fd,TLSconn * conn)367 tlsClient(int fd, TLSconn *conn)
368 {
369 char buf[8];
370 char dname[64];
371 int n, data, ctl, hand;
372 TlsConnection *tls;
373
374 if(!conn)
375 return -1;
376 ctl = open("#a/tls/clone", ORDWR);
377 if(ctl < 0)
378 return -1;
379 n = read(ctl, buf, sizeof(buf)-1);
380 if(n < 0){
381 close(ctl);
382 return -1;
383 }
384 buf[n] = 0;
385 sprint(conn->dir, "#a/tls/%s", buf);
386 sprint(dname, "#a/tls/%s/hand", buf);
387 hand = open(dname, ORDWR);
388 if(hand < 0){
389 close(ctl);
390 return -1;
391 }
392 sprint(dname, "#a/tls/%s/data", buf);
393 data = open(dname, ORDWR);
394 if(data < 0)
395 return -1;
396 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
397 tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
398 close(fd);
399 close(hand);
400 close(ctl);
401 if(tls == nil){
402 close(data);
403 return -1;
404 }
405 conn->certlen = tls->cert->len;
406 conn->cert = emalloc(conn->certlen);
407 memcpy(conn->cert, tls->cert->data, conn->certlen);
408 conn->sessionIDlen = tls->sid->len;
409 conn->sessionID = emalloc(conn->sessionIDlen);
410 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
411 tlsConnectionFree(tls);
412 return data;
413 }
414
415 static TlsConnection *
tlsServer2(int ctl,int hand,uchar * cert,int ncert,int (* trace)(char * fmt,...))416 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...))
417 {
418 TlsConnection *c;
419 Msg m;
420 Bytes *csid;
421 uchar sid[SidSize], kd[MaxKeyData];
422 char *secrets;
423 int cipher, compressor, nsid, rv;
424
425 if(trace)
426 trace("tlsServer2\n");
427 if(!initCiphers())
428 return nil;
429 c = emalloc(sizeof(TlsConnection));
430 c->ctl = ctl;
431 c->hand = hand;
432 c->trace = trace;
433 c->version = ProtocolVersion;
434
435 memset(&m, 0, sizeof(m));
436 if(!msgRecv(c, &m)){
437 if(trace)
438 trace("initial msgRecv failed\n");
439 goto Err;
440 }
441 if(m.tag != HClientHello) {
442 tlsError(c, EUnexpectedMessage, "expected a client hello");
443 goto Err;
444 }
445 c->clientVersion = m.u.clientHello.version;
446 if(trace)
447 trace("ClientHello version %x\n", c->clientVersion);
448 if(setVersion(c, m.u.clientHello.version) < 0) {
449 tlsError(c, EIllegalParameter, "incompatible version");
450 goto Err;
451 }
452
453 memmove(c->crandom, m.u.clientHello.random, RandomSize);
454 cipher = okCipher(m.u.clientHello.ciphers);
455 if(cipher < 0) {
456 // reply with EInsufficientSecurity if we know that's the case
457 if(cipher == -2)
458 tlsError(c, EInsufficientSecurity, "cipher suites too weak");
459 else
460 tlsError(c, EHandshakeFailure, "no matching cipher suite");
461 goto Err;
462 }
463 if(!setAlgs(c, cipher)){
464 tlsError(c, EHandshakeFailure, "no matching cipher suite");
465 goto Err;
466 }
467 compressor = okCompression(m.u.clientHello.compressors);
468 if(compressor < 0) {
469 tlsError(c, EHandshakeFailure, "no matching compressor");
470 goto Err;
471 }
472
473 csid = m.u.clientHello.sid;
474 if(trace)
475 trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
476 c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
477 if(c->sec == nil){
478 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
479 goto Err;
480 }
481 c->sec->rpc = factotum_rsa_open(cert, ncert);
482 if(c->sec->rpc == nil){
483 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
484 goto Err;
485 }
486 c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
487 msgClear(&m);
488
489 m.tag = HServerHello;
490 m.u.serverHello.version = c->version;
491 memmove(m.u.serverHello.random, c->srandom, RandomSize);
492 m.u.serverHello.cipher = cipher;
493 m.u.serverHello.compressor = compressor;
494 c->sid = makebytes(sid, nsid);
495 m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
496 if(!msgSend(c, &m, AQueue))
497 goto Err;
498 msgClear(&m);
499
500 m.tag = HCertificate;
501 m.u.certificate.ncert = 1;
502 m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
503 m.u.certificate.certs[0] = makebytes(cert, ncert);
504 if(!msgSend(c, &m, AQueue))
505 goto Err;
506 msgClear(&m);
507
508 m.tag = HServerHelloDone;
509 if(!msgSend(c, &m, AFlush))
510 goto Err;
511 msgClear(&m);
512
513 if(!msgRecv(c, &m))
514 goto Err;
515 if(m.tag != HClientKeyExchange) {
516 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
517 goto Err;
518 }
519 if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
520 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
521 goto Err;
522 }
523 if(trace)
524 trace("tls secrets\n");
525 secrets = (char*)emalloc(2*c->nsecret);
526 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
527 rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
528 memset(secrets, 0, 2*c->nsecret);
529 free(secrets);
530 memset(kd, 0, c->nsecret);
531 if(rv < 0){
532 tlsError(c, EHandshakeFailure, "can't set keys: %r");
533 goto Err;
534 }
535 msgClear(&m);
536
537 /* no CertificateVerify; skip to Finished */
538 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
539 tlsError(c, EInternalError, "can't set finished: %r");
540 goto Err;
541 }
542 if(!msgRecv(c, &m))
543 goto Err;
544 if(m.tag != HFinished) {
545 tlsError(c, EUnexpectedMessage, "expected a finished");
546 goto Err;
547 }
548 if(!finishedMatch(c, &m.u.finished)) {
549 tlsError(c, EHandshakeFailure, "finished verification failed");
550 goto Err;
551 }
552 msgClear(&m);
553
554 /* change cipher spec */
555 if(fprint(c->ctl, "changecipher") < 0){
556 tlsError(c, EInternalError, "can't enable cipher: %r");
557 goto Err;
558 }
559
560 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
561 tlsError(c, EInternalError, "can't set finished: %r");
562 goto Err;
563 }
564 m.tag = HFinished;
565 m.u.finished = c->finished;
566 if(!msgSend(c, &m, AFlush))
567 goto Err;
568 msgClear(&m);
569 if(trace)
570 trace("tls finished\n");
571
572 if(fprint(c->ctl, "opened") < 0)
573 goto Err;
574 tlsSecOk(c->sec);
575 return c;
576
577 Err:
578 msgClear(&m);
579 tlsConnectionFree(c);
580 return 0;
581 }
582
583 static TlsConnection *
tlsClient2(int ctl,int hand,uchar * csid,int ncsid,int (* trace)(char * fmt,...))584 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
585 {
586 TlsConnection *c;
587 Msg m;
588 uchar kd[MaxKeyData], *epm;
589 char *secrets;
590 int creq, nepm, rv;
591
592 if(!initCiphers())
593 return nil;
594 epm = nil;
595 c = emalloc(sizeof(TlsConnection));
596 c->version = ProtocolVersion;
597 c->ctl = ctl;
598 c->hand = hand;
599 c->trace = trace;
600 c->isClient = 1;
601 c->clientVersion = c->version;
602
603 c->sec = tlsSecInitc(c->clientVersion, c->crandom);
604 if(c->sec == nil)
605 goto Err;
606
607 /* client hello */
608 memset(&m, 0, sizeof(m));
609 m.tag = HClientHello;
610 m.u.clientHello.version = c->clientVersion;
611 memmove(m.u.clientHello.random, c->crandom, RandomSize);
612 m.u.clientHello.sid = makebytes(csid, ncsid);
613 m.u.clientHello.ciphers = makeciphers();
614 m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
615 if(!msgSend(c, &m, AFlush))
616 goto Err;
617 msgClear(&m);
618
619 /* server hello */
620 if(!msgRecv(c, &m))
621 goto Err;
622 if(m.tag != HServerHello) {
623 tlsError(c, EUnexpectedMessage, "expected a server hello");
624 goto Err;
625 }
626 if(setVersion(c, m.u.serverHello.version) < 0) {
627 tlsError(c, EIllegalParameter, "incompatible version %r");
628 goto Err;
629 }
630 memmove(c->srandom, m.u.serverHello.random, RandomSize);
631 c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
632 if(c->sid->len != 0 && c->sid->len != SidSize) {
633 tlsError(c, EIllegalParameter, "invalid server session identifier");
634 goto Err;
635 }
636 if(!setAlgs(c, m.u.serverHello.cipher)) {
637 tlsError(c, EIllegalParameter, "invalid cipher suite");
638 goto Err;
639 }
640 if(m.u.serverHello.compressor != CompressionNull) {
641 tlsError(c, EIllegalParameter, "invalid compression");
642 goto Err;
643 }
644 msgClear(&m);
645
646 /* certificate */
647 if(!msgRecv(c, &m) || m.tag != HCertificate) {
648 tlsError(c, EUnexpectedMessage, "expected a certificate");
649 goto Err;
650 }
651 if(m.u.certificate.ncert < 1) {
652 tlsError(c, EIllegalParameter, "runt certificate");
653 goto Err;
654 }
655 c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
656 msgClear(&m);
657
658 /* server key exchange (optional) */
659 if(!msgRecv(c, &m))
660 goto Err;
661 if(m.tag == HServerKeyExchange) {
662 tlsError(c, EUnexpectedMessage, "got an server key exchange");
663 goto Err;
664 // If implementing this later, watch out for rollback attack
665 // described in Wagner Schneier 1996, section 4.4.
666 }
667
668 /* certificate request (optional) */
669 creq = 0;
670 if(m.tag == HCertificateRequest) {
671 creq = 1;
672 msgClear(&m);
673 if(!msgRecv(c, &m))
674 goto Err;
675 }
676
677 if(m.tag != HServerHelloDone) {
678 tlsError(c, EUnexpectedMessage, "expected a server hello done");
679 goto Err;
680 }
681 msgClear(&m);
682
683 if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
684 c->cert->data, c->cert->len, c->version, &epm, &nepm,
685 kd, c->nsecret) < 0){
686 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
687 goto Err;
688 }
689 secrets = (char*)emalloc(2*c->nsecret);
690 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
691 rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
692 memset(secrets, 0, 2*c->nsecret);
693 free(secrets);
694 memset(kd, 0, c->nsecret);
695 if(rv < 0){
696 tlsError(c, EHandshakeFailure, "can't set keys: %r");
697 goto Err;
698 }
699
700 if(creq) {
701 /* send a zero length certificate */
702 m.tag = HCertificate;
703 if(!msgSend(c, &m, AFlush))
704 goto Err;
705 msgClear(&m);
706 }
707
708 /* client key exchange */
709 m.tag = HClientKeyExchange;
710 m.u.clientKeyExchange.key = makebytes(epm, nepm);
711 free(epm);
712 epm = nil;
713 if(m.u.clientKeyExchange.key == nil) {
714 tlsError(c, EHandshakeFailure, "can't set secret: %r");
715 goto Err;
716 }
717 if(!msgSend(c, &m, AFlush))
718 goto Err;
719 msgClear(&m);
720
721 /* change cipher spec */
722 if(fprint(c->ctl, "changecipher") < 0){
723 tlsError(c, EInternalError, "can't enable cipher: %r");
724 goto Err;
725 }
726
727 // Cipherchange must occur immediately before Finished to avoid
728 // potential hole; see section 4.3 of Wagner Schneier 1996.
729 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
730 tlsError(c, EInternalError, "can't set finished 1: %r");
731 goto Err;
732 }
733 m.tag = HFinished;
734 m.u.finished = c->finished;
735
736 if(!msgSend(c, &m, AFlush)) {
737 fprint(2, "tlsClient nepm=%d\n", nepm);
738 tlsError(c, EInternalError, "can't flush after client Finished: %r");
739 goto Err;
740 }
741 msgClear(&m);
742
743 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
744 fprint(2, "tlsClient nepm=%d\n", nepm);
745 tlsError(c, EInternalError, "can't set finished 0: %r");
746 goto Err;
747 }
748 if(!msgRecv(c, &m)) {
749 fprint(2, "tlsClient nepm=%d\n", nepm);
750 tlsError(c, EInternalError, "can't read server Finished: %r");
751 goto Err;
752 }
753 if(m.tag != HFinished) {
754 fprint(2, "tlsClient nepm=%d\n", nepm);
755 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
756 goto Err;
757 }
758
759 if(!finishedMatch(c, &m.u.finished)) {
760 tlsError(c, EHandshakeFailure, "finished verification failed");
761 goto Err;
762 }
763 msgClear(&m);
764
765 if(fprint(c->ctl, "opened") < 0){
766 if(trace)
767 trace("unable to do final open: %r\n");
768 goto Err;
769 }
770 tlsSecOk(c->sec);
771 return c;
772
773 Err:
774 free(epm);
775 msgClear(&m);
776 tlsConnectionFree(c);
777 return 0;
778 }
779
780
781 //================= message functions ========================
782
783 static uchar sendbuf[9000], *sendp;
784
785 static int
msgSend(TlsConnection * c,Msg * m,int act)786 msgSend(TlsConnection *c, Msg *m, int act)
787 {
788 uchar *p; // sendp = start of new message; p = write pointer
789 int nn, n, i;
790
791 if(sendp == nil)
792 sendp = sendbuf;
793 p = sendp;
794 if(c->trace)
795 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
796
797 p[0] = m->tag; // header - fill in size later
798 p += 4;
799
800 switch(m->tag) {
801 default:
802 tlsError(c, EInternalError, "can't encode a %d", m->tag);
803 goto Err;
804 case HClientHello:
805 // version
806 put16(p, m->u.clientHello.version);
807 p += 2;
808
809 // random
810 memmove(p, m->u.clientHello.random, RandomSize);
811 p += RandomSize;
812
813 // sid
814 n = m->u.clientHello.sid->len;
815 assert(n < 256);
816 p[0] = n;
817 memmove(p+1, m->u.clientHello.sid->data, n);
818 p += n+1;
819
820 n = m->u.clientHello.ciphers->len;
821 assert(n > 0 && n < 200);
822 put16(p, n*2);
823 p += 2;
824 for(i=0; i<n; i++) {
825 put16(p, m->u.clientHello.ciphers->data[i]);
826 p += 2;
827 }
828
829 n = m->u.clientHello.compressors->len;
830 assert(n > 0);
831 p[0] = n;
832 memmove(p+1, m->u.clientHello.compressors->data, n);
833 p += n+1;
834 break;
835 case HServerHello:
836 put16(p, m->u.serverHello.version);
837 p += 2;
838
839 // random
840 memmove(p, m->u.serverHello.random, RandomSize);
841 p += RandomSize;
842
843 // sid
844 n = m->u.serverHello.sid->len;
845 assert(n < 256);
846 p[0] = n;
847 memmove(p+1, m->u.serverHello.sid->data, n);
848 p += n+1;
849
850 put16(p, m->u.serverHello.cipher);
851 p += 2;
852 p[0] = m->u.serverHello.compressor;
853 p += 1;
854 break;
855 case HServerHelloDone:
856 break;
857 case HCertificate:
858 nn = 0;
859 for(i = 0; i < m->u.certificate.ncert; i++)
860 nn += 3 + m->u.certificate.certs[i]->len;
861 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
862 tlsError(c, EInternalError, "output buffer too small for certificate");
863 goto Err;
864 }
865 put24(p, nn);
866 p += 3;
867 for(i = 0; i < m->u.certificate.ncert; i++){
868 put24(p, m->u.certificate.certs[i]->len);
869 p += 3;
870 memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
871 p += m->u.certificate.certs[i]->len;
872 }
873 break;
874 case HClientKeyExchange:
875 n = m->u.clientKeyExchange.key->len;
876 if(c->version != SSL3Version){
877 put16(p, n);
878 p += 2;
879 }
880 memmove(p, m->u.clientKeyExchange.key->data, n);
881 p += n;
882 break;
883 case HFinished:
884 memmove(p, m->u.finished.verify, m->u.finished.n);
885 p += m->u.finished.n;
886 break;
887 }
888
889 // go back and fill in size
890 n = p-sendp;
891 assert(p <= sendbuf+sizeof(sendbuf));
892 put24(sendp+1, n-4);
893
894 // remember hash of Handshake messages
895 if(m->tag != HHelloRequest) {
896 md5(sendp, n, 0, &c->hsmd5);
897 sha1(sendp, n, 0, &c->hssha1);
898 }
899
900 sendp = p;
901 if(act == AFlush){
902 sendp = sendbuf;
903 if(write(c->hand, sendbuf, p-sendbuf) < 0){
904 fprint(2, "write error: %r\n");
905 goto Err;
906 }
907 }
908 msgClear(m);
909 return 1;
910 Err:
911 msgClear(m);
912 return 0;
913 }
914
915 static uchar*
tlsReadN(TlsConnection * c,int n)916 tlsReadN(TlsConnection *c, int n)
917 {
918 uchar *p;
919 int nn, nr;
920
921 nn = c->ep - c->rp;
922 if(nn < n){
923 if(c->rp != c->buf){
924 memmove(c->buf, c->rp, nn);
925 c->rp = c->buf;
926 c->ep = &c->buf[nn];
927 }
928 for(; nn < n; nn += nr) {
929 nr = read(c->hand, &c->rp[nn], n - nn);
930 if(nr <= 0)
931 return nil;
932 c->ep += nr;
933 }
934 }
935 p = c->rp;
936 c->rp += n;
937 return p;
938 }
939
940 static int
msgRecv(TlsConnection * c,Msg * m)941 msgRecv(TlsConnection *c, Msg *m)
942 {
943 uchar *p;
944 int type, n, nn, i, nsid, nrandom, nciph;
945
946 for(;;) {
947 p = tlsReadN(c, 4);
948 if(p == nil)
949 return 0;
950 type = p[0];
951 n = get24(p+1);
952
953 if(type != HHelloRequest)
954 break;
955 if(n != 0) {
956 tlsError(c, EDecodeError, "invalid hello request during handshake");
957 return 0;
958 }
959 }
960
961 if(n > sizeof(c->buf)) {
962 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
963 return 0;
964 }
965
966 if(type == HSSL2ClientHello){
967 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
968 This is sent by some clients that we must interoperate
969 with, such as Java's JSSE and Microsoft's Internet Explorer. */
970 p = tlsReadN(c, n);
971 if(p == nil)
972 return 0;
973 md5(p, n, 0, &c->hsmd5);
974 sha1(p, n, 0, &c->hssha1);
975 m->tag = HClientHello;
976 if(n < 22)
977 goto Short;
978 m->u.clientHello.version = get16(p+1);
979 p += 3;
980 n -= 3;
981 nn = get16(p); /* cipher_spec_len */
982 nsid = get16(p + 2);
983 nrandom = get16(p + 4);
984 p += 6;
985 n -= 6;
986 if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
987 || nrandom < 16 || nn % 3)
988 goto Err;
989 if(c->trace && (n - nrandom != nn))
990 c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
991 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
992 nciph = 0;
993 for(i = 0; i < nn; i += 3)
994 if(p[i] == 0)
995 nciph++;
996 m->u.clientHello.ciphers = newints(nciph);
997 nciph = 0;
998 for(i = 0; i < nn; i += 3)
999 if(p[i] == 0)
1000 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1001 p += nn;
1002 m->u.clientHello.sid = makebytes(nil, 0);
1003 if(nrandom > RandomSize)
1004 nrandom = RandomSize;
1005 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1006 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1007 m->u.clientHello.compressors = newbytes(1);
1008 m->u.clientHello.compressors->data[0] = CompressionNull;
1009 goto Ok;
1010 }
1011
1012 md5(p, 4, 0, &c->hsmd5);
1013 sha1(p, 4, 0, &c->hssha1);
1014
1015 p = tlsReadN(c, n);
1016 if(p == nil)
1017 return 0;
1018
1019 md5(p, n, 0, &c->hsmd5);
1020 sha1(p, n, 0, &c->hssha1);
1021
1022 m->tag = type;
1023
1024 switch(type) {
1025 default:
1026 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1027 goto Err;
1028 case HClientHello:
1029 if(n < 2)
1030 goto Short;
1031 m->u.clientHello.version = get16(p);
1032 p += 2;
1033 n -= 2;
1034
1035 if(n < RandomSize)
1036 goto Short;
1037 memmove(m->u.clientHello.random, p, RandomSize);
1038 p += RandomSize;
1039 n -= RandomSize;
1040 if(n < 1 || n < p[0]+1)
1041 goto Short;
1042 m->u.clientHello.sid = makebytes(p+1, p[0]);
1043 p += m->u.clientHello.sid->len+1;
1044 n -= m->u.clientHello.sid->len+1;
1045
1046 if(n < 2)
1047 goto Short;
1048 nn = get16(p);
1049 p += 2;
1050 n -= 2;
1051
1052 if((nn & 1) || n < nn || nn < 2)
1053 goto Short;
1054 m->u.clientHello.ciphers = newints(nn >> 1);
1055 for(i = 0; i < nn; i += 2)
1056 m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1057 p += nn;
1058 n -= nn;
1059
1060 if(n < 1 || n < p[0]+1 || p[0] == 0)
1061 goto Short;
1062 nn = p[0];
1063 m->u.clientHello.compressors = newbytes(nn);
1064 memmove(m->u.clientHello.compressors->data, p+1, nn);
1065 n -= nn + 1;
1066 break;
1067 case HServerHello:
1068 if(n < 2)
1069 goto Short;
1070 m->u.serverHello.version = get16(p);
1071 p += 2;
1072 n -= 2;
1073
1074 if(n < RandomSize)
1075 goto Short;
1076 memmove(m->u.serverHello.random, p, RandomSize);
1077 p += RandomSize;
1078 n -= RandomSize;
1079
1080 if(n < 1 || n < p[0]+1)
1081 goto Short;
1082 m->u.serverHello.sid = makebytes(p+1, p[0]);
1083 p += m->u.serverHello.sid->len+1;
1084 n -= m->u.serverHello.sid->len+1;
1085
1086 if(n < 3)
1087 goto Short;
1088 m->u.serverHello.cipher = get16(p);
1089 m->u.serverHello.compressor = p[2];
1090 n -= 3;
1091 break;
1092 case HCertificate:
1093 if(n < 3)
1094 goto Short;
1095 nn = get24(p);
1096 p += 3;
1097 n -= 3;
1098 if(n != nn)
1099 goto Short;
1100 /* certs */
1101 i = 0;
1102 while(n > 0) {
1103 if(n < 3)
1104 goto Short;
1105 nn = get24(p);
1106 p += 3;
1107 n -= 3;
1108 if(nn > n)
1109 goto Short;
1110 m->u.certificate.ncert = i+1;
1111 m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1112 m->u.certificate.certs[i] = makebytes(p, nn);
1113 p += nn;
1114 n -= nn;
1115 i++;
1116 }
1117 break;
1118 case HCertificateRequest:
1119 if(n < 2)
1120 goto Short;
1121 nn = get16(p);
1122 p += 2;
1123 n -= 2;
1124 if(nn < 1 || nn > n)
1125 goto Short;
1126 m->u.certificateRequest.types = makebytes(p, nn);
1127 nn = get24(p);
1128 p += 3;
1129 n -= 3;
1130 if(nn == 0 || n != nn)
1131 goto Short;
1132 /* cas */
1133 i = 0;
1134 while(n > 0) {
1135 if(n < 2)
1136 goto Short;
1137 nn = get16(p);
1138 p += 2;
1139 n -= 2;
1140 if(nn < 1 || nn > n)
1141 goto Short;
1142 m->u.certificateRequest.nca = i+1;
1143 m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1144 m->u.certificateRequest.cas[i] = makebytes(p, nn);
1145 p += nn;
1146 n -= nn;
1147 i++;
1148 }
1149 break;
1150 case HServerHelloDone:
1151 break;
1152 case HClientKeyExchange:
1153 /*
1154 * this message depends upon the encryption selected
1155 * assume rsa.
1156 */
1157 if(c->version == SSL3Version)
1158 nn = n;
1159 else{
1160 if(n < 2)
1161 goto Short;
1162 nn = get16(p);
1163 p += 2;
1164 n -= 2;
1165 }
1166 if(n < nn)
1167 goto Short;
1168 m->u.clientKeyExchange.key = makebytes(p, nn);
1169 n -= nn;
1170 break;
1171 case HFinished:
1172 m->u.finished.n = c->finished.n;
1173 if(n < m->u.finished.n)
1174 goto Short;
1175 memmove(m->u.finished.verify, p, m->u.finished.n);
1176 n -= m->u.finished.n;
1177 break;
1178 }
1179
1180 if(type != HClientHello && n != 0)
1181 goto Short;
1182 Ok:
1183 if(c->trace){
1184 char buf[8000];
1185 c->trace("recv %s", msgPrint(buf, sizeof buf, m));
1186 }
1187 return 1;
1188 Short:
1189 tlsError(c, EDecodeError, "handshake message has invalid length");
1190 Err:
1191 msgClear(m);
1192 return 0;
1193 }
1194
1195 static void
msgClear(Msg * m)1196 msgClear(Msg *m)
1197 {
1198 int i;
1199
1200 switch(m->tag) {
1201 default:
1202 sysfatal("msgClear: unknown message type: %d", m->tag);
1203 case HHelloRequest:
1204 break;
1205 case HClientHello:
1206 freebytes(m->u.clientHello.sid);
1207 freeints(m->u.clientHello.ciphers);
1208 freebytes(m->u.clientHello.compressors);
1209 break;
1210 case HServerHello:
1211 freebytes(m->u.clientHello.sid);
1212 break;
1213 case HCertificate:
1214 for(i=0; i<m->u.certificate.ncert; i++)
1215 freebytes(m->u.certificate.certs[i]);
1216 free(m->u.certificate.certs);
1217 break;
1218 case HCertificateRequest:
1219 freebytes(m->u.certificateRequest.types);
1220 for(i=0; i<m->u.certificateRequest.nca; i++)
1221 freebytes(m->u.certificateRequest.cas[i]);
1222 free(m->u.certificateRequest.cas);
1223 break;
1224 case HServerHelloDone:
1225 break;
1226 case HClientKeyExchange:
1227 freebytes(m->u.clientKeyExchange.key);
1228 break;
1229 case HFinished:
1230 break;
1231 }
1232 memset(m, 0, sizeof(Msg));
1233 }
1234
1235 static char *
bytesPrint(char * bs,char * be,char * s0,Bytes * b,char * s1)1236 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1237 {
1238 int i;
1239
1240 if(s0)
1241 bs = seprint(bs, be, "%s", s0);
1242 bs = seprint(bs, be, "[");
1243 if(b == nil)
1244 bs = seprint(bs, be, "nil");
1245 else
1246 for(i=0; i<b->len; i++)
1247 bs = seprint(bs, be, "%.2x ", b->data[i]);
1248 bs = seprint(bs, be, "]");
1249 if(s1)
1250 bs = seprint(bs, be, "%s", s1);
1251 return bs;
1252 }
1253
1254 static char *
intsPrint(char * bs,char * be,char * s0,Ints * b,char * s1)1255 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1256 {
1257 int i;
1258
1259 if(s0)
1260 bs = seprint(bs, be, "%s", s0);
1261 bs = seprint(bs, be, "[");
1262 if(b == nil)
1263 bs = seprint(bs, be, "nil");
1264 else
1265 for(i=0; i<b->len; i++)
1266 bs = seprint(bs, be, "%x ", b->data[i]);
1267 bs = seprint(bs, be, "]");
1268 if(s1)
1269 bs = seprint(bs, be, "%s", s1);
1270 return bs;
1271 }
1272
1273 static char*
msgPrint(char * buf,int n,Msg * m)1274 msgPrint(char *buf, int n, Msg *m)
1275 {
1276 int i;
1277 char *bs = buf, *be = buf+n;
1278
1279 switch(m->tag) {
1280 default:
1281 bs = seprint(bs, be, "unknown %d\n", m->tag);
1282 break;
1283 case HClientHello:
1284 bs = seprint(bs, be, "ClientHello\n");
1285 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1286 bs = seprint(bs, be, "\trandom: ");
1287 for(i=0; i<RandomSize; i++)
1288 bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1289 bs = seprint(bs, be, "\n");
1290 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1291 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1292 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1293 break;
1294 case HServerHello:
1295 bs = seprint(bs, be, "ServerHello\n");
1296 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1297 bs = seprint(bs, be, "\trandom: ");
1298 for(i=0; i<RandomSize; i++)
1299 bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1300 bs = seprint(bs, be, "\n");
1301 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1302 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1303 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1304 break;
1305 case HCertificate:
1306 bs = seprint(bs, be, "Certificate\n");
1307 for(i=0; i<m->u.certificate.ncert; i++)
1308 bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1309 break;
1310 case HCertificateRequest:
1311 bs = seprint(bs, be, "CertificateRequest\n");
1312 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1313 bs = seprint(bs, be, "\tcertificateauthorities\n");
1314 for(i=0; i<m->u.certificateRequest.nca; i++)
1315 bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1316 break;
1317 case HServerHelloDone:
1318 bs = seprint(bs, be, "ServerHelloDone\n");
1319 break;
1320 case HClientKeyExchange:
1321 bs = seprint(bs, be, "HClientKeyExchange\n");
1322 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1323 break;
1324 case HFinished:
1325 bs = seprint(bs, be, "HFinished\n");
1326 for(i=0; i<m->u.finished.n; i++)
1327 bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1328 bs = seprint(bs, be, "\n");
1329 break;
1330 }
1331 USED(bs);
1332 return buf;
1333 }
1334
1335 static void
tlsError(TlsConnection * c,int err,char * fmt,...)1336 tlsError(TlsConnection *c, int err, char *fmt, ...)
1337 {
1338 char msg[512];
1339 va_list arg;
1340
1341 va_start(arg, fmt);
1342 vseprint(msg, msg+sizeof(msg), fmt, arg);
1343 va_end(arg);
1344 if(c->trace)
1345 c->trace("tlsError: %s\n", msg);
1346 else if(c->erred)
1347 fprint(2, "double error: %r, %s", msg);
1348 else
1349 werrstr("tls: local %s", msg);
1350 c->erred = 1;
1351 fprint(c->ctl, "alert %d", err);
1352 }
1353
1354 // commit to specific version number
1355 static int
setVersion(TlsConnection * c,int version)1356 setVersion(TlsConnection *c, int version)
1357 {
1358 if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1359 return -1;
1360 if(version > c->version)
1361 version = c->version;
1362 if(version == SSL3Version) {
1363 c->version = version;
1364 c->finished.n = SSL3FinishedLen;
1365 }else if(version == TLSVersion){
1366 c->version = version;
1367 c->finished.n = TLSFinishedLen;
1368 }else
1369 return -1;
1370 c->verset = 1;
1371 return fprint(c->ctl, "version 0x%x", version);
1372 }
1373
1374 // confirm that received Finished message matches the expected value
1375 static int
finishedMatch(TlsConnection * c,Finished * f)1376 finishedMatch(TlsConnection *c, Finished *f)
1377 {
1378 return memcmp(f->verify, c->finished.verify, f->n) == 0;
1379 }
1380
1381 // free memory associated with TlsConnection struct
1382 // (but don't close the TLS channel itself)
1383 static void
tlsConnectionFree(TlsConnection * c)1384 tlsConnectionFree(TlsConnection *c)
1385 {
1386 tlsSecClose(c->sec);
1387 freebytes(c->sid);
1388 freebytes(c->cert);
1389 memset(c, 0, sizeof(c));
1390 free(c);
1391 }
1392
1393
1394 //================= cipher choices ========================
1395
/*
 * weakCipher[id] is 1 when the TLS cipher suite with that id (named in
 * each entry's comment) is considered weak: null encryption, export
 * strength, or anonymous Diffie-Hellman. okCipher uses it to detect a
 * peer that offered only weak suites. Ids >= CipherMax are not listed.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1427
1428 static int
setAlgs(TlsConnection * c,int a)1429 setAlgs(TlsConnection *c, int a)
1430 {
1431 int i;
1432
1433 for(i = 0; i < nelem(cipherAlgs); i++){
1434 if(cipherAlgs[i].tlsid == a){
1435 c->enc = cipherAlgs[i].enc;
1436 c->digest = cipherAlgs[i].digest;
1437 c->nsecret = cipherAlgs[i].nsecret;
1438 if(c->nsecret > MaxKeyData)
1439 return 0;
1440 return 1;
1441 }
1442 }
1443 return 0;
1444 }
1445
1446 static int
okCipher(Ints * cv)1447 okCipher(Ints *cv)
1448 {
1449 int weak, i, j, c;
1450
1451 weak = 1;
1452 for(i = 0; i < cv->len; i++) {
1453 c = cv->data[i];
1454 if(c >= CipherMax)
1455 weak = 0;
1456 else
1457 weak &= weakCipher[c];
1458 for(j = 0; j < nelem(cipherAlgs); j++)
1459 if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1460 return c;
1461 }
1462 if(weak)
1463 return -2;
1464 return -1;
1465 }
1466
1467 static int
okCompression(Bytes * cv)1468 okCompression(Bytes *cv)
1469 {
1470 int i, j, c;
1471
1472 for(i = 0; i < cv->len; i++) {
1473 c = cv->data[i];
1474 for(j = 0; j < nelem(compressors); j++) {
1475 if(compressors[j] == c)
1476 return c;
1477 }
1478 }
1479 return -1;
1480 }
1481
// ciphLock serializes initCiphers. nciphers counts the cipherAlgs entries
// whose encryption and digest algorithms were both listed by #a/tls;
// it stays 0 until initCiphers has found at least one.
static Lock ciphLock;
static int nciphers;
1484
1485 static int
initCiphers(void)1486 initCiphers(void)
1487 {
1488 enum {MaxAlgF = 1024, MaxAlgs = 10};
1489 char s[MaxAlgF], *flds[MaxAlgs];
1490 int i, j, n, ok;
1491
1492 lock(&ciphLock);
1493 if(nciphers){
1494 unlock(&ciphLock);
1495 return nciphers;
1496 }
1497 j = open("#a/tls/encalgs", OREAD);
1498 if(j < 0){
1499 werrstr("can't open #a/tls/encalgs: %r");
1500 return 0;
1501 }
1502 n = read(j, s, MaxAlgF-1);
1503 close(j);
1504 if(n <= 0){
1505 werrstr("nothing in #a/tls/encalgs: %r");
1506 return 0;
1507 }
1508 s[n] = 0;
1509 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1510 for(i = 0; i < nelem(cipherAlgs); i++){
1511 ok = 0;
1512 for(j = 0; j < n; j++){
1513 if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1514 ok = 1;
1515 break;
1516 }
1517 }
1518 cipherAlgs[i].ok = ok;
1519 }
1520
1521 j = open("#a/tls/hashalgs", OREAD);
1522 if(j < 0){
1523 werrstr("can't open #a/tls/hashalgs: %r");
1524 return 0;
1525 }
1526 n = read(j, s, MaxAlgF-1);
1527 close(j);
1528 if(n <= 0){
1529 werrstr("nothing in #a/tls/hashalgs: %r");
1530 return 0;
1531 }
1532 s[n] = 0;
1533 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1534 for(i = 0; i < nelem(cipherAlgs); i++){
1535 ok = 0;
1536 for(j = 0; j < n; j++){
1537 if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1538 ok = 1;
1539 break;
1540 }
1541 }
1542 cipherAlgs[i].ok &= ok;
1543 if(cipherAlgs[i].ok)
1544 nciphers++;
1545 }
1546 unlock(&ciphLock);
1547 return nciphers;
1548 }
1549
1550 static Ints*
makeciphers(void)1551 makeciphers(void)
1552 {
1553 Ints *is;
1554 int i, j;
1555
1556 is = newints(nciphers);
1557 j = 0;
1558 for(i = 0; i < nelem(cipherAlgs); i++){
1559 if(cipherAlgs[i].ok)
1560 is->data[j++] = cipherAlgs[i].tlsid;
1561 }
1562 return is;
1563 }
1564
1565
1566
1567 //================= security functions ========================
1568
1569 // given X.509 certificate, set up connection to factotum
1570 // for using corresponding private key
1571 static AuthRpc*
factotum_rsa_open(uchar * cert,int certlen)1572 factotum_rsa_open(uchar *cert, int certlen)
1573 {
1574 int afd;
1575 char *s;
1576 mpint *pub = nil;
1577 RSApub *rsapub;
1578 AuthRpc *rpc;
1579
1580 // start talking to factotum
1581 if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1582 return nil;
1583 if((rpc = auth_allocrpc(afd)) == nil){
1584 close(afd);
1585 return nil;
1586 }
1587 s = "proto=rsa service=tls role=client";
1588 if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1589 factotum_rsa_close(rpc);
1590 return nil;
1591 }
1592
1593 // roll factotum keyring around to match certificate
1594 rsapub = X509toRSApub(cert, certlen, nil, 0);
1595 while(1){
1596 if(auth_rpc(rpc, "read", nil, 0) != ARok){
1597 factotum_rsa_close(rpc);
1598 rpc = nil;
1599 goto done;
1600 }
1601 pub = strtomp(rpc->arg, nil, 16, nil);
1602 assert(pub != nil);
1603 if(mpcmp(pub,rsapub->n) == 0)
1604 break;
1605 }
1606 done:
1607 mpfree(pub);
1608 rsapubfree(rsapub);
1609 return rpc;
1610 }
1611
1612 static mpint*
factotum_rsa_decrypt(AuthRpc * rpc,mpint * cipher)1613 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1614 {
1615 char *p;
1616 int rv;
1617
1618 if((p = mptoa(cipher, 16, nil, 0)) == nil)
1619 return nil;
1620 rv = auth_rpc(rpc, "write", p, strlen(p));
1621 free(p);
1622 if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1623 return nil;
1624 mpfree(cipher);
1625 return strtomp(rpc->arg, nil, 16, nil);
1626 }
1627
1628 static void
factotum_rsa_close(AuthRpc * rpc)1629 factotum_rsa_close(AuthRpc*rpc)
1630 {
1631 if(!rpc)
1632 return;
1633 close(rpc->afd);
1634 auth_freerpc(rpc);
1635 }
1636
1637 static void
tlsPmd5(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1638 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1639 {
1640 uchar ai[MD5dlen], tmp[MD5dlen];
1641 int i, n;
1642 MD5state *s;
1643
1644 // generate a1
1645 s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1646 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1647 hmac_md5(seed1, nseed1, key, nkey, ai, s);
1648
1649 while(nbuf > 0) {
1650 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1651 s = hmac_md5(label, nlabel, key, nkey, nil, s);
1652 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1653 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1654 n = MD5dlen;
1655 if(n > nbuf)
1656 n = nbuf;
1657 for(i = 0; i < n; i++)
1658 buf[i] ^= tmp[i];
1659 buf += n;
1660 nbuf -= n;
1661 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1662 memmove(ai, tmp, MD5dlen);
1663 }
1664 }
1665
1666 static void
tlsPsha1(uchar * buf,int nbuf,uchar * key,int nkey,uchar * label,int nlabel,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1667 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1668 {
1669 uchar ai[SHA1dlen], tmp[SHA1dlen];
1670 int i, n;
1671 SHAstate *s;
1672
1673 // generate a1
1674 s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1675 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1676 hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1677
1678 while(nbuf > 0) {
1679 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1680 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1681 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1682 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1683 n = SHA1dlen;
1684 if(n > nbuf)
1685 n = nbuf;
1686 for(i = 0; i < n; i++)
1687 buf[i] ^= tmp[i];
1688 buf += n;
1689 nbuf -= n;
1690 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1691 memmove(ai, tmp, SHA1dlen);
1692 }
1693 }
1694
1695 // fill buf with md5(args)^sha1(args)
1696 static void
tlsPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1697 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1698 {
1699 int i;
1700 int nlabel = strlen(label);
1701 int n = (nkey + 1) >> 1;
1702
1703 for(i = 0; i < nbuf; i++)
1704 buf[i] = 0;
1705 tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1706 tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1707 }
1708
/*
 * for setting server session id's:
 * sidLock guards maxSid, the per-process counter that tlsSecInits
 * embeds (and then increments) in each server session id.
 */
static Lock sidLock;
static long maxSid = 1;
1714
1715 /* the keys are verified to have the same public components
1716 * and to function correctly with pkcs 1 encryption and decryption. */
1717 static TlsSec*
tlsSecInits(int cvers,uchar * csid,int ncsid,uchar * crandom,uchar * ssid,int * nssid,uchar * srandom)1718 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1719 {
1720 TlsSec *sec = emalloc(sizeof(*sec));
1721
1722 USED(csid); USED(ncsid); // ignore csid for now
1723
1724 memmove(sec->crandom, crandom, RandomSize);
1725 sec->clientVers = cvers;
1726
1727 put32(sec->srandom, time(0));
1728 genrandom(sec->srandom+4, RandomSize-4);
1729 memmove(srandom, sec->srandom, RandomSize);
1730
1731 /*
1732 * make up a unique sid: use our pid, and and incrementing id
1733 * can signal no sid by setting nssid to 0.
1734 */
1735 memset(ssid, 0, SidSize);
1736 put32(ssid, getpid());
1737 lock(&sidLock);
1738 put32(ssid+4, maxSid++);
1739 unlock(&sidLock);
1740 *nssid = SidSize;
1741 return sec;
1742 }
1743
1744 static int
tlsSecSecrets(TlsSec * sec,int vers,uchar * epm,int nepm,uchar * kd,int nkd)1745 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1746 {
1747 if(epm != nil){
1748 if(setVers(sec, vers) < 0)
1749 goto Err;
1750 serverMasterSecret(sec, epm, nepm);
1751 }else if(sec->vers != vers){
1752 werrstr("mismatched session versions");
1753 goto Err;
1754 }
1755 setSecrets(sec, kd, nkd);
1756 return 0;
1757 Err:
1758 sec->ok = -1;
1759 return -1;
1760 }
1761
1762 static TlsSec*
tlsSecInitc(int cvers,uchar * crandom)1763 tlsSecInitc(int cvers, uchar *crandom)
1764 {
1765 TlsSec *sec = emalloc(sizeof(*sec));
1766 sec->clientVers = cvers;
1767 put32(sec->crandom, time(0));
1768 genrandom(sec->crandom+4, RandomSize-4);
1769 memmove(crandom, sec->crandom, RandomSize);
1770 return sec;
1771 }
1772
1773 static int
tlsSecSecretc(TlsSec * sec,uchar * sid,int nsid,uchar * srandom,uchar * cert,int ncert,int vers,uchar ** epm,int * nepm,uchar * kd,int nkd)1774 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1775 {
1776 RSApub *pub;
1777
1778 pub = nil;
1779
1780 USED(sid);
1781 USED(nsid);
1782
1783 memmove(sec->srandom, srandom, RandomSize);
1784
1785 if(setVers(sec, vers) < 0)
1786 goto Err;
1787
1788 pub = X509toRSApub(cert, ncert, nil, 0);
1789 if(pub == nil){
1790 werrstr("invalid x509/rsa certificate");
1791 goto Err;
1792 }
1793 if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1794 goto Err;
1795 rsapubfree(pub);
1796 setSecrets(sec, kd, nkd);
1797 return 0;
1798
1799 Err:
1800 if(pub != nil)
1801 rsapubfree(pub);
1802 sec->ok = -1;
1803 return -1;
1804 }
1805
1806 static int
tlsSecFinished(TlsSec * sec,MD5state md5,SHAstate sha1,uchar * fin,int nfin,int isclient)1807 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1808 {
1809 if(sec->nfin != nfin){
1810 sec->ok = -1;
1811 werrstr("invalid finished exchange");
1812 return -1;
1813 }
1814 md5.malloced = 0;
1815 sha1.malloced = 0;
1816 (*sec->setFinished)(sec, md5, sha1, fin, isclient);
1817 return 1;
1818 }
1819
1820 static void
tlsSecOk(TlsSec * sec)1821 tlsSecOk(TlsSec *sec)
1822 {
1823 if(sec->ok == 0)
1824 sec->ok = 1;
1825 }
1826
1827 static void
tlsSecKill(TlsSec * sec)1828 tlsSecKill(TlsSec *sec)
1829 {
1830 if(!sec)
1831 return;
1832 factotum_rsa_close(sec->rpc);
1833 sec->ok = -1;
1834 }
1835
1836 static void
tlsSecClose(TlsSec * sec)1837 tlsSecClose(TlsSec *sec)
1838 {
1839 if(!sec)
1840 return;
1841 factotum_rsa_close(sec->rpc);
1842 free(sec->server);
1843 free(sec);
1844 }
1845
1846 static int
setVers(TlsSec * sec,int v)1847 setVers(TlsSec *sec, int v)
1848 {
1849 if(v == SSL3Version){
1850 sec->setFinished = sslSetFinished;
1851 sec->nfin = SSL3FinishedLen;
1852 sec->prf = sslPRF;
1853 }else if(v == TLSVersion){
1854 sec->setFinished = tlsSetFinished;
1855 sec->nfin = TLSFinishedLen;
1856 sec->prf = tlsPRF;
1857 }else{
1858 werrstr("invalid version");
1859 return -1;
1860 }
1861 sec->vers = v;
1862 return 0;
1863 }
1864
1865 /*
1866 * generate secret keys from the master secret.
1867 *
1868 * different crypto selections will require different amounts
1869 * of key expansion and use of key expansion data,
1870 * but it's all generated using the same function.
1871 */
1872 static void
setSecrets(TlsSec * sec,uchar * kd,int nkd)1873 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1874 {
1875 (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1876 sec->srandom, RandomSize, sec->crandom, RandomSize);
1877 }
1878
1879 /*
1880 * set the master secret from the pre-master secret.
1881 */
1882 static void
setMasterSecret(TlsSec * sec,Bytes * pm)1883 setMasterSecret(TlsSec *sec, Bytes *pm)
1884 {
1885 (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1886 sec->crandom, RandomSize, sec->srandom, RandomSize);
1887 }
1888
1889 static void
serverMasterSecret(TlsSec * sec,uchar * epm,int nepm)1890 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1891 {
1892 Bytes *pm;
1893
1894 pm = pkcs1_decrypt(sec, epm, nepm);
1895
1896 // if the client messed up, just continue as if everything is ok,
1897 // to prevent attacks to check for correctly formatted messages.
1898 // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1899 if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1900 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1901 sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1902 sec->ok = -1;
1903 if(pm != nil)
1904 freebytes(pm);
1905 pm = newbytes(MasterSecretSize);
1906 genrandom(pm->data, MasterSecretSize);
1907 }
1908 setMasterSecret(sec, pm);
1909 memset(pm->data, 0, pm->len);
1910 freebytes(pm);
1911 }
1912
1913 static int
clientMasterSecret(TlsSec * sec,RSApub * pub,uchar ** epm,int * nepm)1914 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1915 {
1916 Bytes *pm, *key;
1917
1918 pm = newbytes(MasterSecretSize);
1919 put16(pm->data, sec->clientVers);
1920 genrandom(pm->data+2, MasterSecretSize - 2);
1921
1922 setMasterSecret(sec, pm);
1923
1924 key = pkcs1_encrypt(pm, pub, 2);
1925 memset(pm->data, 0, pm->len);
1926 freebytes(pm);
1927 if(key == nil){
1928 werrstr("tls pkcs1_encrypt failed");
1929 return -1;
1930 }
1931
1932 *nepm = key->len;
1933 *epm = malloc(*nepm);
1934 if(*epm == nil){
1935 freebytes(key);
1936 werrstr("out of memory");
1937 return -1;
1938 }
1939 memmove(*epm, key->data, *nepm);
1940
1941 freebytes(key);
1942
1943 return 1;
1944 }
1945
/*
 * SSL 3.0 Finished computation. Writes MD5dlen bytes of MD5 output
 * followed by SHA1dlen bytes of SHA-1 output into finished.
 * For each hash H, continuing the by-value handshake state:
 *	inner = H(handshake || label || master || pad1)
 *	part  = H(master || pad2 || inner)
 * pad1 is 0x36 and pad2 is 0x5C, repeated 48 times for MD5 and 40
 * times for SHA-1; label is "CLNT" for the client, "SRVR" otherwise.
 */
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	// 4-byte sender tag
	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// MD5 inner hash extends the handshake hash; outer hash starts fresh
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// SHA-1 half uses 40-byte pads and is stored after the MD5 half
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
1978
1979 // fill "finished" arg with md5(args)^sha1(args)
1980 static void
tlsSetFinished(TlsSec * sec,MD5state hsmd5,SHAstate hssha1,uchar * finished,int isClient)1981 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1982 {
1983 uchar h0[MD5dlen], h1[SHA1dlen];
1984 char *label;
1985
1986 // get current hash value, but allow further messages to be hashed in
1987 md5(nil, 0, h0, &hsmd5);
1988 sha1(nil, 0, h1, &hssha1);
1989
1990 if(isClient)
1991 label = "client finished";
1992 else
1993 label = "server finished";
1994 tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
1995 }
1996
1997 static void
sslPRF(uchar * buf,int nbuf,uchar * key,int nkey,char * label,uchar * seed0,int nseed0,uchar * seed1,int nseed1)1998 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1999 {
2000 DigestState *s;
2001 uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2002 int i, n, len;
2003
2004 USED(label);
2005 len = 1;
2006 while(nbuf > 0){
2007 if(len > 26)
2008 return;
2009 for(i = 0; i < len; i++)
2010 tmp[i] = 'A' - 1 + len;
2011 s = sha1(tmp, len, nil, nil);
2012 s = sha1(key, nkey, nil, s);
2013 s = sha1(seed0, nseed0, nil, s);
2014 sha1(seed1, nseed1, sha1dig, s);
2015 s = md5(key, nkey, nil, nil);
2016 md5(sha1dig, SHA1dlen, md5dig, s);
2017 n = MD5dlen;
2018 if(n > nbuf)
2019 n = nbuf;
2020 memmove(buf, md5dig, n);
2021 buf += n;
2022 nbuf -= n;
2023 len++;
2024 }
2025 }
2026
2027 static mpint*
bytestomp(Bytes * bytes)2028 bytestomp(Bytes* bytes)
2029 {
2030 mpint* ans;
2031
2032 ans = betomp(bytes->data, bytes->len, nil);
2033 return ans;
2034 }
2035
2036 /*
2037 * Convert mpint* to Bytes, putting high order byte first.
2038 */
2039 static Bytes*
mptobytes(mpint * big)2040 mptobytes(mpint* big)
2041 {
2042 int n, m;
2043 uchar *a;
2044 Bytes* ans;
2045
2046 n = (mpsignif(big)+7)/8;
2047 m = mptobe(big, nil, n, &a);
2048 ans = makebytes(a, m);
2049 return ans;
2050 }
2051
2052 // Do RSA computation on block according to key, and pad
2053 // result on left with zeros to make it modlen long.
2054 static Bytes*
rsacomp(Bytes * block,RSApub * key,int modlen)2055 rsacomp(Bytes* block, RSApub* key, int modlen)
2056 {
2057 mpint *x, *y;
2058 Bytes *a, *ybytes;
2059 int ylen;
2060
2061 x = bytestomp(block);
2062 y = rsaencrypt(key, x, nil);
2063 mpfree(x);
2064 ybytes = mptobytes(y);
2065 ylen = ybytes->len;
2066
2067 if(ylen < modlen) {
2068 a = newbytes(modlen);
2069 memset(a->data, 0, modlen-ylen);
2070 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2071 freebytes(ybytes);
2072 ybytes = a;
2073 }
2074 else if(ylen > modlen) {
2075 // assume it has leading zeros (mod should make it so)
2076 a = newbytes(modlen);
2077 memmove(a->data, ybytes->data, modlen);
2078 freebytes(ybytes);
2079 ybytes = a;
2080 }
2081 mpfree(y);
2082 return ybytes;
2083 }
2084
2085 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2086 static Bytes*
pkcs1_encrypt(Bytes * data,RSApub * key,int blocktype)2087 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2088 {
2089 Bytes *pad, *eb, *ans;
2090 int i, dlen, padlen, modlen;
2091
2092 modlen = (mpsignif(key->n)+7)/8;
2093 dlen = data->len;
2094 if(modlen < 12 || dlen > modlen - 11)
2095 return nil;
2096 padlen = modlen - 3 - dlen;
2097 pad = newbytes(padlen);
2098 genrandom(pad->data, padlen);
2099 for(i = 0; i < padlen; i++) {
2100 if(blocktype == 0)
2101 pad->data[i] = 0;
2102 else if(blocktype == 1)
2103 pad->data[i] = 255;
2104 else if(pad->data[i] == 0)
2105 pad->data[i] = 1;
2106 }
2107 eb = newbytes(modlen);
2108 eb->data[0] = 0;
2109 eb->data[1] = blocktype;
2110 memmove(eb->data+2, pad->data, padlen);
2111 eb->data[padlen+2] = 0;
2112 memmove(eb->data+padlen+3, data->data, dlen);
2113 ans = rsacomp(eb, key, modlen);
2114 freebytes(eb);
2115 freebytes(pad);
2116 return ans;
2117 }
2118
2119 // decrypt data according to PKCS#1, with given key.
2120 // expect a block type of 2.
2121 static Bytes*
pkcs1_decrypt(TlsSec * sec,uchar * epm,int nepm)2122 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2123 {
2124 Bytes *eb, *ans = nil;
2125 int i, modlen;
2126 mpint *x, *y;
2127
2128 modlen = (mpsignif(sec->rsapub->n)+7)/8;
2129 if(nepm != modlen)
2130 return nil;
2131 x = betomp(epm, nepm, nil);
2132 y = factotum_rsa_decrypt(sec->rpc, x);
2133 if(y == nil)
2134 return nil;
2135 eb = mptobytes(y);
2136 if(eb->len < modlen){ // pad on left with zeros
2137 ans = newbytes(modlen);
2138 memset(ans->data, 0, modlen-eb->len);
2139 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2140 freebytes(eb);
2141 eb = ans;
2142 }
2143 if(eb->data[0] == 0 && eb->data[1] == 2) {
2144 for(i = 2; i < modlen; i++)
2145 if(eb->data[i] == 0)
2146 break;
2147 if(i < modlen - 1)
2148 ans = makebytes(eb->data+i+1, modlen-(i+1));
2149 }
2150 freebytes(eb);
2151 return ans;
2152 }
2153
2154
2155 //================= general utility functions ========================
2156
2157 static void *
emalloc(int n)2158 emalloc(int n)
2159 {
2160 void *p;
2161 if(n==0)
2162 n=1;
2163 p = malloc(n);
2164 if(p == nil){
2165 exits("out of memory");
2166 }
2167 memset(p, 0, n);
2168 return p;
2169 }
2170
2171 static void *
erealloc(void * ReallocP,int ReallocN)2172 erealloc(void *ReallocP, int ReallocN)
2173 {
2174 if(ReallocN == 0)
2175 ReallocN = 1;
2176 if(!ReallocP)
2177 ReallocP = emalloc(ReallocN);
2178 else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2179 exits("out of memory");
2180 }
2181 return(ReallocP);
2182 }
2183
2184 static void
put32(uchar * p,u32int x)2185 put32(uchar *p, u32int x)
2186 {
2187 p[0] = x>>24;
2188 p[1] = x>>16;
2189 p[2] = x>>8;
2190 p[3] = x;
2191 }
2192
2193 static void
put24(uchar * p,int x)2194 put24(uchar *p, int x)
2195 {
2196 p[0] = x>>16;
2197 p[1] = x>>8;
2198 p[2] = x;
2199 }
2200
2201 static void
put16(uchar * p,int x)2202 put16(uchar *p, int x)
2203 {
2204 p[0] = x>>8;
2205 p[1] = x;
2206 }
2207
2208 static u32int
get32(uchar * p)2209 get32(uchar *p)
2210 {
2211 return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2212 }
2213
2214 static int
get24(uchar * p)2215 get24(uchar *p)
2216 {
2217 return (p[0]<<16)|(p[1]<<8)|p[2];
2218 }
2219
2220 static int
get16(uchar * p)2221 get16(uchar *p)
2222 {
2223 return (p[0]<<8)|p[1];
2224 }
2225
/* ANSI offsetof(): byte offset of member x within struct type s, as an int.
 * NOTE(review): uses the classic null-pointer member-address idiom, which is
 * formally undefined in ISO C; confirm the compilers in use accept it. */
#define OFFSET(x, s) ((int)(&(((s*)0)->x)))
2228
2229 /*
2230 * malloc and return a new Bytes structure capable of
2231 * holding len bytes. (len >= 0)
2232 * Used to use crypt_malloc, which aborts if malloc fails.
2233 */
2234 static Bytes*
newbytes(int len)2235 newbytes(int len)
2236 {
2237 Bytes* ans;
2238
2239 ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2240 ans->len = len;
2241 return ans;
2242 }
2243
2244 /*
2245 * newbytes(len), with data initialized from buf
2246 */
2247 static Bytes*
makebytes(uchar * buf,int len)2248 makebytes(uchar* buf, int len)
2249 {
2250 Bytes* ans;
2251
2252 ans = newbytes(len);
2253 memmove(ans->data, buf, len);
2254 return ans;
2255 }
2256
2257 static void
freebytes(Bytes * b)2258 freebytes(Bytes* b)
2259 {
2260 if(b != nil)
2261 free(b);
2262 }
2263
2264 /* len is number of ints */
2265 static Ints*
newints(int len)2266 newints(int len)
2267 {
2268 Ints* ans;
2269
2270 ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2271 ans->len = len;
2272 return ans;
2273 }
2274
2275 static Ints*
makeints(int * buf,int len)2276 makeints(int* buf, int len)
2277 {
2278 Ints* ans;
2279
2280 ans = newints(len);
2281 if(len > 0)
2282 memmove(ans->data, buf, len*sizeof(int));
2283 return ans;
2284 }
2285
2286 static void
freeints(Ints * b)2287 freeints(Ints* b)
2288 {
2289 if(b != nil)
2290 free(b);
2291 }
2292