xref: /openbsd-src/usr.bin/ssh/kex.c (revision 9b9d2a55a62c8e82206c25f94fcc7f4e2765250e)
1 /* $OpenBSD: kex.c,v 1.110 2015/08/21 23:57:48 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include <sys/param.h>	/* MAX roundup */
27 
28 #include <signal.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 
33 #ifdef WITH_OPENSSL
34 #include <openssl/crypto.h>
35 #endif
36 
37 #include "ssh2.h"
38 #include "packet.h"
39 #include "compat.h"
40 #include "cipher.h"
41 #include "sshkey.h"
42 #include "kex.h"
43 #include "log.h"
44 #include "mac.h"
45 #include "match.h"
46 #include "misc.h"
47 #include "dispatch.h"
48 #include "monitor.h"
49 #include "roaming.h"
50 
51 #include "ssherr.h"
52 #include "sshbuf.h"
53 #include "digest.h"
54 
55 /* prototype */
56 static int kex_choose_conf(struct ssh *);
57 static int kex_input_newkeys(int, u_int32_t, void *);
58 
59 static const char *proposal_names[PROPOSAL_MAX] = {
60 	"KEX algorithms",
61 	"host key algorithms",
62 	"ciphers ctos",
63 	"ciphers stoc",
64 	"MACs ctos",
65 	"MACs stoc",
66 	"compression ctos",
67 	"compression stoc",
68 	"languages ctos",
69 	"languages stoc",
70 };
71 
72 struct kexalg {
73 	char *name;
74 	u_int type;
75 	int ec_nid;
76 	int hash_alg;
77 };
78 static const struct kexalg kexalgs[] = {
79 #ifdef WITH_OPENSSL
80 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
81 	{ KEX_DH14, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
82 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
83 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
84 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
85 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
86 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
87 	    SSH_DIGEST_SHA384 },
88 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
89 	    SSH_DIGEST_SHA512 },
90 #endif
91 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
92 	{ NULL, -1, -1, -1},
93 };
94 
95 char *
96 kex_alg_list(char sep)
97 {
98 	char *ret = NULL, *tmp;
99 	size_t nlen, rlen = 0;
100 	const struct kexalg *k;
101 
102 	for (k = kexalgs; k->name != NULL; k++) {
103 		if (ret != NULL)
104 			ret[rlen++] = sep;
105 		nlen = strlen(k->name);
106 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
107 			free(ret);
108 			return NULL;
109 		}
110 		ret = tmp;
111 		memcpy(ret + rlen, k->name, nlen + 1);
112 		rlen += nlen;
113 	}
114 	return ret;
115 }
116 
117 static const struct kexalg *
118 kex_alg_by_name(const char *name)
119 {
120 	const struct kexalg *k;
121 
122 	for (k = kexalgs; k->name != NULL; k++) {
123 		if (strcmp(k->name, name) == 0)
124 			return k;
125 	}
126 	return NULL;
127 }
128 
/*
 * Validate a comma-separated KEX method name list.  Scanning stops at the
 * first empty element, the same way kex_names_cat() walks its lists.
 * Returns 1 if the list is acceptable, 0 if it is NULL or empty, names an
 * unsupported algorithm, or cannot be copied.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
151 
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 * Returns a newly allocated string that the caller must free: 'a'
 * followed by each name from 'b' that is not already present.  Returns
 * strdup(a) when 'b' is NULL or empty, and NULL when 'a' is NULL or empty,
 * 'b' is longer than 1024*1024 bytes, or allocation fails.  'b' is
 * scanned only up to its first empty element.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	if (a == NULL || *a == '\0')
		return NULL;
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: a, one comma, all of b (commas included), NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		/*
		 * match_list() returns an allocated copy of the matching
		 * name; free it, otherwise every duplicate leaks.
		 */
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
188 
189 /*
190  * Assemble a list of algorithms from a default list and a string from a
191  * configuration file. The user-provided string may begin with '+' to
192  * indicate that it should be appended to the default.
193  */
194 int
195 kex_assemble_names(const char *def, char **list)
196 {
197 	char *ret;
198 
199 	if (list == NULL || *list == NULL || **list == '\0') {
200 		*list = strdup(def);
201 		return 0;
202 	}
203 	if (**list != '+') {
204 		return 0;
205 	}
206 
207 	if ((ret = kex_names_cat(def, *list + 1)) == NULL)
208 		return SSH_ERR_ALLOC_FAIL;
209 	free(*list);
210 	*list = ret;
211 	return 0;
212 }
213 
214 /* put algorithm proposal into buffer */
215 int
216 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
217 {
218 	u_int i;
219 	int r;
220 
221 	sshbuf_reset(b);
222 
223 	/*
224 	 * add a dummy cookie, the cookie will be overwritten by
225 	 * kex_send_kexinit(), each time a kexinit is set
226 	 */
227 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
228 		if ((r = sshbuf_put_u8(b, 0)) != 0)
229 			return r;
230 	}
231 	for (i = 0; i < PROPOSAL_MAX; i++) {
232 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
233 			return r;
234 	}
235 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
236 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
237 		return r;
238 	return 0;
239 }
240 
241 /* parse buffer and return algorithm proposal */
242 int
243 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
244 {
245 	struct sshbuf *b = NULL;
246 	u_char v;
247 	u_int i;
248 	char **proposal = NULL;
249 	int r;
250 
251 	*propp = NULL;
252 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
253 		return SSH_ERR_ALLOC_FAIL;
254 	if ((b = sshbuf_fromb(raw)) == NULL) {
255 		r = SSH_ERR_ALLOC_FAIL;
256 		goto out;
257 	}
258 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
259 		goto out;
260 	/* extract kex init proposal strings */
261 	for (i = 0; i < PROPOSAL_MAX; i++) {
262 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
263 			goto out;
264 		debug2("%s: %s", proposal_names[i], proposal[i]);
265 	}
266 	/* first kex follows / reserved */
267 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||
268 	    (r = sshbuf_get_u32(b, &i)) != 0)
269 		goto out;
270 	if (first_kex_follows != NULL)
271 		*first_kex_follows = i;
272 	debug2("first_kex_follows %d ", v);
273 	debug2("reserved %u ", i);
274 	r = 0;
275 	*propp = proposal;
276  out:
277 	if (r != 0 && proposal != NULL)
278 		kex_prop_free(proposal);
279 	sshbuf_free(b);
280 	return r;
281 }
282 
283 void
284 kex_prop_free(char **proposal)
285 {
286 	u_int i;
287 
288 	if (proposal == NULL)
289 		return;
290 	for (i = 0; i < PROPOSAL_MAX; i++)
291 		free(proposal[i]);
292 	free(proposal);
293 }
294 
295 /* ARGSUSED */
296 static int
297 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
298 {
299 	error("Hm, kex protocol error: type %d seq %u", type, seq);
300 	return 0;
301 }
302 
303 static void
304 kex_reset_dispatch(struct ssh *ssh)
305 {
306 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
307 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
308 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
309 }
310 
311 int
312 kex_send_newkeys(struct ssh *ssh)
313 {
314 	int r;
315 
316 	kex_reset_dispatch(ssh);
317 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
318 	    (r = sshpkt_send(ssh)) != 0)
319 		return r;
320 	debug("SSH2_MSG_NEWKEYS sent");
321 	debug("expecting SSH2_MSG_NEWKEYS");
322 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
323 	return 0;
324 }
325 
326 static int
327 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
328 {
329 	struct ssh *ssh = ctxt;
330 	struct kex *kex = ssh->kex;
331 	int r;
332 
333 	debug("SSH2_MSG_NEWKEYS received");
334 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
335 	if ((r = sshpkt_get_end(ssh)) != 0)
336 		return r;
337 	kex->done = 1;
338 	sshbuf_reset(kex->peer);
339 	/* sshbuf_reset(kex->my); */
340 	kex->flags &= ~KEX_INIT_SENT;
341 	free(kex->name);
342 	kex->name = NULL;
343 	return 0;
344 }
345 
346 int
347 kex_send_kexinit(struct ssh *ssh)
348 {
349 	u_char *cookie;
350 	struct kex *kex = ssh->kex;
351 	int r;
352 
353 	if (kex == NULL)
354 		return SSH_ERR_INTERNAL_ERROR;
355 	if (kex->flags & KEX_INIT_SENT)
356 		return 0;
357 	kex->done = 0;
358 
359 	/* generate a random cookie */
360 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
361 		return SSH_ERR_INVALID_FORMAT;
362 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
363 		return SSH_ERR_INTERNAL_ERROR;
364 	arc4random_buf(cookie, KEX_COOKIE_LEN);
365 
366 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
367 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
368 	    (r = sshpkt_send(ssh)) != 0)
369 		return r;
370 	debug("SSH2_MSG_KEXINIT sent");
371 	kex->flags |= KEX_INIT_SENT;
372 	return 0;
373 }
374 
375 /* ARGSUSED */
376 int
377 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
378 {
379 	struct ssh *ssh = ctxt;
380 	struct kex *kex = ssh->kex;
381 	const u_char *ptr;
382 	u_int i;
383 	size_t dlen;
384 	int r;
385 
386 	debug("SSH2_MSG_KEXINIT received");
387 	if (kex == NULL)
388 		return SSH_ERR_INVALID_ARGUMENT;
389 
390 	ptr = sshpkt_ptr(ssh, &dlen);
391 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
392 		return r;
393 
394 	/* discard packet */
395 	for (i = 0; i < KEX_COOKIE_LEN; i++)
396 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
397 			return r;
398 	for (i = 0; i < PROPOSAL_MAX; i++)
399 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
400 			return r;
401 	/*
402 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
403 	 * KEX method has the server move first, but a server might be using
404 	 * a custom method or one that we otherwise don't support. We should
405 	 * be prepared to remember first_kex_follows here so we can eat a
406 	 * packet later.
407 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
408 	 * for cases where the server *doesn't* go first. I guess we should
409 	 * ignore it when it is set for these cases, which is what we do now.
410 	 */
411 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
412 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
413 	    (r = sshpkt_get_end(ssh)) != 0)
414 			return r;
415 
416 	if (!(kex->flags & KEX_INIT_SENT))
417 		if ((r = kex_send_kexinit(ssh)) != 0)
418 			return r;
419 	if ((r = kex_choose_conf(ssh)) != 0)
420 		return r;
421 
422 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
423 		return (kex->kex[kex->kex_type])(ssh);
424 
425 	return SSH_ERR_INTERNAL_ERROR;
426 }
427 
428 int
429 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
430 {
431 	struct kex *kex;
432 	int r;
433 
434 	*kexp = NULL;
435 	if ((kex = calloc(1, sizeof(*kex))) == NULL)
436 		return SSH_ERR_ALLOC_FAIL;
437 	if ((kex->peer = sshbuf_new()) == NULL ||
438 	    (kex->my = sshbuf_new()) == NULL) {
439 		r = SSH_ERR_ALLOC_FAIL;
440 		goto out;
441 	}
442 	if ((r = kex_prop2buf(kex->my, proposal)) != 0)
443 		goto out;
444 	kex->done = 0;
445 	kex_reset_dispatch(ssh);
446 	r = 0;
447 	*kexp = kex;
448  out:
449 	if (r != 0)
450 		kex_free(kex);
451 	return r;
452 }
453 
454 void
455 kex_free_newkeys(struct newkeys *newkeys)
456 {
457 	if (newkeys == NULL)
458 		return;
459 	if (newkeys->enc.key) {
460 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
461 		free(newkeys->enc.key);
462 		newkeys->enc.key = NULL;
463 	}
464 	if (newkeys->enc.iv) {
465 		explicit_bzero(newkeys->enc.iv, newkeys->enc.block_size);
466 		free(newkeys->enc.iv);
467 		newkeys->enc.iv = NULL;
468 	}
469 	free(newkeys->enc.name);
470 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
471 	free(newkeys->comp.name);
472 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
473 	mac_clear(&newkeys->mac);
474 	if (newkeys->mac.key) {
475 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
476 		free(newkeys->mac.key);
477 		newkeys->mac.key = NULL;
478 	}
479 	free(newkeys->mac.name);
480 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
481 	explicit_bzero(newkeys, sizeof(*newkeys));
482 	free(newkeys);
483 }
484 
485 void
486 kex_free(struct kex *kex)
487 {
488 	u_int mode;
489 
490 #ifdef WITH_OPENSSL
491 	if (kex->dh)
492 		DH_free(kex->dh);
493 	if (kex->ec_client_key)
494 		EC_KEY_free(kex->ec_client_key);
495 #endif
496 	for (mode = 0; mode < MODE_MAX; mode++) {
497 		kex_free_newkeys(kex->newkeys[mode]);
498 		kex->newkeys[mode] = NULL;
499 	}
500 	sshbuf_free(kex->peer);
501 	sshbuf_free(kex->my);
502 	free(kex->session_id);
503 	free(kex->client_version_string);
504 	free(kex->server_version_string);
505 	free(kex->failed_choice);
506 	free(kex);
507 }
508 
509 int
510 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
511 {
512 	int r;
513 
514 	if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
515 		return r;
516 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
517 		kex_free(ssh->kex);
518 		ssh->kex = NULL;
519 		return r;
520 	}
521 	return 0;
522 }
523 
524 static int
525 choose_enc(struct sshenc *enc, char *client, char *server)
526 {
527 	char *name = match_list(client, server, NULL);
528 
529 	if (name == NULL)
530 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
531 	if ((enc->cipher = cipher_by_name(name)) == NULL)
532 		return SSH_ERR_INTERNAL_ERROR;
533 	enc->name = name;
534 	enc->enabled = 0;
535 	enc->iv = NULL;
536 	enc->iv_len = cipher_ivlen(enc->cipher);
537 	enc->key = NULL;
538 	enc->key_len = cipher_keylen(enc->cipher);
539 	enc->block_size = cipher_blocksize(enc->cipher);
540 	return 0;
541 }
542 
543 static int
544 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
545 {
546 	char *name = match_list(client, server, NULL);
547 
548 	if (name == NULL)
549 		return SSH_ERR_NO_MAC_ALG_MATCH;
550 	if (mac_setup(mac, name) < 0)
551 		return SSH_ERR_INTERNAL_ERROR;
552 	/* truncate the key */
553 	if (ssh->compat & SSH_BUG_HMAC)
554 		mac->key_len = 16;
555 	mac->name = name;
556 	mac->key = NULL;
557 	mac->enabled = 0;
558 	return 0;
559 }
560 
561 static int
562 choose_comp(struct sshcomp *comp, char *client, char *server)
563 {
564 	char *name = match_list(client, server, NULL);
565 
566 	if (name == NULL)
567 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
568 	if (strcmp(name, "zlib@openssh.com") == 0) {
569 		comp->type = COMP_DELAYED;
570 	} else if (strcmp(name, "zlib") == 0) {
571 		comp->type = COMP_ZLIB;
572 	} else if (strcmp(name, "none") == 0) {
573 		comp->type = COMP_NONE;
574 	} else {
575 		return SSH_ERR_INTERNAL_ERROR;
576 	}
577 	comp->name = name;
578 	return 0;
579 }
580 
581 static int
582 choose_kex(struct kex *k, char *client, char *server)
583 {
584 	const struct kexalg *kexalg;
585 
586 	k->name = match_list(client, server, NULL);
587 
588 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
589 	if (k->name == NULL)
590 		return SSH_ERR_NO_KEX_ALG_MATCH;
591 	if ((kexalg = kex_alg_by_name(k->name)) == NULL)
592 		return SSH_ERR_INTERNAL_ERROR;
593 	k->kex_type = kexalg->type;
594 	k->hash_alg = kexalg->hash_alg;
595 	k->ec_nid = kexalg->ec_nid;
596 	return 0;
597 }
598 
599 static int
600 choose_hostkeyalg(struct kex *k, char *client, char *server)
601 {
602 	char *hostkeyalg = match_list(client, server, NULL);
603 
604 	debug("kex: host key algorithm: %s",
605 	    hostkeyalg ? hostkeyalg : "(no match)");
606 	if (hostkeyalg == NULL)
607 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
608 	k->hostkey_type = sshkey_type_from_name(hostkeyalg);
609 	if (k->hostkey_type == KEY_UNSPEC)
610 		return SSH_ERR_INTERNAL_ERROR;
611 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(hostkeyalg);
612 	free(hostkeyalg);
613 	return 0;
614 }
615 
616 static int
617 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
618 {
619 	static int check[] = {
620 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
621 	};
622 	int *idx;
623 	char *p;
624 
625 	for (idx = &check[0]; *idx != -1; idx++) {
626 		if ((p = strchr(my[*idx], ',')) != NULL)
627 			*p = '\0';
628 		if ((p = strchr(peer[*idx], ',')) != NULL)
629 			*p = '\0';
630 		if (strcmp(my[*idx], peer[*idx]) != 0) {
631 			debug2("proposal mismatch: my %s peer %s",
632 			    my[*idx], peer[*idx]);
633 			return (0);
634 		}
635 	}
636 	debug2("proposals match");
637 	return (1);
638 }
639 
640 static int
641 kex_choose_conf(struct ssh *ssh)
642 {
643 	struct kex *kex = ssh->kex;
644 	struct newkeys *newkeys;
645 	char **my = NULL, **peer = NULL;
646 	char **cprop, **sprop;
647 	int nenc, nmac, ncomp;
648 	u_int mode, ctos, need, dh_need, authlen;
649 	int r, first_kex_follows;
650 
651 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
652 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
653 		goto out;
654 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
655 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
656 		goto out;
657 
658 	if (kex->server) {
659 		cprop=peer;
660 		sprop=my;
661 	} else {
662 		cprop=my;
663 		sprop=peer;
664 	}
665 
666 	/* Check whether server offers roaming */
667 	if (!kex->server) {
668 		char *roaming = match_list(KEX_RESUME,
669 		    peer[PROPOSAL_KEX_ALGS], NULL);
670 
671 		if (roaming) {
672 			kex->roaming = 1;
673 			free(roaming);
674 		}
675 	}
676 
677 	/* Algorithm Negotiation */
678 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
679 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
680 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
681 		peer[PROPOSAL_KEX_ALGS] = NULL;
682 		goto out;
683 	}
684 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
685 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
686 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
687 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
688 		goto out;
689 	}
690 	for (mode = 0; mode < MODE_MAX; mode++) {
691 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
692 			r = SSH_ERR_ALLOC_FAIL;
693 			goto out;
694 		}
695 		kex->newkeys[mode] = newkeys;
696 		ctos = (!kex->server && mode == MODE_OUT) ||
697 		    (kex->server && mode == MODE_IN);
698 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
699 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
700 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
701 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
702 		    sprop[nenc])) != 0) {
703 			kex->failed_choice = peer[nenc];
704 			peer[nenc] = NULL;
705 			goto out;
706 		}
707 		authlen = cipher_authlen(newkeys->enc.cipher);
708 		/* ignore mac for authenticated encryption */
709 		if (authlen == 0 &&
710 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
711 		    sprop[nmac])) != 0) {
712 			kex->failed_choice = peer[nmac];
713 			peer[nmac] = NULL;
714 			goto out;
715 		}
716 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
717 		    sprop[ncomp])) != 0) {
718 			kex->failed_choice = peer[ncomp];
719 			peer[ncomp] = NULL;
720 			goto out;
721 		}
722 		debug("kex: %s cipher: %s MAC: %s compression: %s",
723 		    ctos ? "client->server" : "server->client",
724 		    newkeys->enc.name,
725 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
726 		    newkeys->comp.name);
727 	}
728 	need = dh_need = 0;
729 	for (mode = 0; mode < MODE_MAX; mode++) {
730 		newkeys = kex->newkeys[mode];
731 		need = MAX(need, newkeys->enc.key_len);
732 		need = MAX(need, newkeys->enc.block_size);
733 		need = MAX(need, newkeys->enc.iv_len);
734 		need = MAX(need, newkeys->mac.key_len);
735 		dh_need = MAX(dh_need, cipher_seclen(newkeys->enc.cipher));
736 		dh_need = MAX(dh_need, newkeys->enc.block_size);
737 		dh_need = MAX(dh_need, newkeys->enc.iv_len);
738 		dh_need = MAX(dh_need, newkeys->mac.key_len);
739 	}
740 	/* XXX need runden? */
741 	kex->we_need = need;
742 	kex->dh_need = dh_need;
743 
744 	/* ignore the next message if the proposals do not match */
745 	if (first_kex_follows && !proposals_match(my, peer) &&
746 	    !(ssh->compat & SSH_BUG_FIRSTKEX))
747 		ssh->dispatch_skip_packets = 1;
748 	r = 0;
749  out:
750 	kex_prop_free(my);
751 	kex_prop_free(peer);
752 	return r;
753 }
754 
755 static int
756 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
757     const struct sshbuf *shared_secret, u_char **keyp)
758 {
759 	struct kex *kex = ssh->kex;
760 	struct ssh_digest_ctx *hashctx = NULL;
761 	char c = id;
762 	u_int have;
763 	size_t mdsz;
764 	u_char *digest;
765 	int r;
766 
767 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
768 		return SSH_ERR_INVALID_ARGUMENT;
769 	if ((digest = calloc(1, roundup(need, mdsz))) == NULL) {
770 		r = SSH_ERR_ALLOC_FAIL;
771 		goto out;
772 	}
773 
774 	/* K1 = HASH(K || H || "A" || session_id) */
775 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
776 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
777 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
778 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
779 	    ssh_digest_update(hashctx, kex->session_id,
780 	    kex->session_id_len) != 0 ||
781 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
782 		r = SSH_ERR_LIBCRYPTO_ERROR;
783 		goto out;
784 	}
785 	ssh_digest_free(hashctx);
786 	hashctx = NULL;
787 
788 	/*
789 	 * expand key:
790 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
791 	 * Key = K1 || K2 || ... || Kn
792 	 */
793 	for (have = mdsz; need > have; have += mdsz) {
794 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
795 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
796 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
797 		    ssh_digest_update(hashctx, digest, have) != 0 ||
798 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
799 			r = SSH_ERR_LIBCRYPTO_ERROR;
800 			goto out;
801 		}
802 		ssh_digest_free(hashctx);
803 		hashctx = NULL;
804 	}
805 #ifdef DEBUG_KEX
806 	fprintf(stderr, "key '%c'== ", c);
807 	dump_digest("key", digest, need);
808 #endif
809 	*keyp = digest;
810 	digest = NULL;
811 	r = 0;
812  out:
813 	if (digest)
814 		free(digest);
815 	ssh_digest_free(hashctx);
816 	return r;
817 }
818 
819 #define NKEYS	6
820 int
821 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
822     const struct sshbuf *shared_secret)
823 {
824 	struct kex *kex = ssh->kex;
825 	u_char *keys[NKEYS];
826 	u_int i, j, mode, ctos;
827 	int r;
828 
829 	for (i = 0; i < NKEYS; i++) {
830 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
831 		    shared_secret, &keys[i])) != 0) {
832 			for (j = 0; j < i; j++)
833 				free(keys[j]);
834 			return r;
835 		}
836 	}
837 	for (mode = 0; mode < MODE_MAX; mode++) {
838 		ctos = (!kex->server && mode == MODE_OUT) ||
839 		    (kex->server && mode == MODE_IN);
840 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
841 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
842 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
843 	}
844 	return 0;
845 }
846 
#ifdef WITH_OPENSSL
/*
 * Same as kex_derive_keys(), but the shared secret is given as a BIGNUM.
 * It is serialized with sshbuf_put_bignum2() into a temporary buffer,
 * which is released before returning.
 */
int
kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
    const BIGNUM *secret)
{
	struct sshbuf *buf;
	int r;

	buf = sshbuf_new();
	if (buf == NULL)
		return SSH_ERR_ALLOC_FAIL;
	r = sshbuf_put_bignum2(buf, secret);
	if (r == 0)
		r = kex_derive_keys(ssh, hash, hashlen, buf);
	sshbuf_free(buf);
	return r;
}
#endif
863 
#ifdef WITH_SSH1
/*
 * Compute the SSH1 session id:
 *   MD5(host modulus || server modulus || 8-byte cookie)
 * Each modulus is converted with BN_bn2bin().  Both moduli must be between
 * 512/8 bytes and the 2048-byte scratch buffers, otherwise
 * SSH_ERR_KEY_BITS_MISMATCH is returned.  The MD5 output is copied to
 * 'id'.  The scratch buffers are wiped once the digest has been attempted.
 */
int
derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
    u_int8_t cookie[8], u_int8_t id[16])
{
	u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
	struct ssh_digest_ctx *ctx = NULL;
	size_t hlen = BN_num_bytes(host_modulus);
	size_t slen = BN_num_bytes(server_modulus);
	int r = SSH_ERR_LIBCRYPTO_ERROR;

	if (hlen < (512 / 8) || (u_int)hlen > sizeof(hbuf) ||
	    slen < (512 / 8) || (u_int)slen > sizeof(sbuf))
		return SSH_ERR_KEY_BITS_MISMATCH;
	if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
	    BN_bn2bin(server_modulus, sbuf) <= 0)
		goto out;	/* r is still SSH_ERR_LIBCRYPTO_ERROR */
	if ((ctx = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}
	if (ssh_digest_update(ctx, hbuf, hlen) != 0 ||
	    ssh_digest_update(ctx, sbuf, slen) != 0 ||
	    ssh_digest_update(ctx, cookie, 8) != 0 ||
	    ssh_digest_final(ctx, obuf, sizeof(obuf)) != 0)
		goto out;	/* r is still SSH_ERR_LIBCRYPTO_ERROR */
	memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
	r = 0;
 out:
	ssh_digest_free(ctx);
	explicit_bzero(hbuf, sizeof(hbuf));
	explicit_bzero(sbuf, sizeof(sbuf));
	explicit_bzero(obuf, sizeof(obuf));
	return r;
}
#endif
905 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: write 'msg' on its own line, then dump 'len' bytes, to stderr. */
void
dump_digest(char *msg, u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
914