xref: /openbsd-src/usr.bin/ssh/kex.c (revision c7e8ea31cd41a963f06f0a8ba93948b06aa6b4a4)
1 /* $OpenBSD: kex.c,v 1.134 2017/06/13 12:13:59 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <signal.h>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 
32 #ifdef WITH_OPENSSL
33 #include <openssl/crypto.h>
34 #endif
35 
36 #include "ssh2.h"
37 #include "packet.h"
38 #include "compat.h"
39 #include "cipher.h"
40 #include "sshkey.h"
41 #include "kex.h"
42 #include "log.h"
43 #include "mac.h"
44 #include "match.h"
45 #include "misc.h"
46 #include "dispatch.h"
47 #include "monitor.h"
48 
49 #include "ssherr.h"
50 #include "sshbuf.h"
51 #include "digest.h"
52 
/* forward declarations of static functions defined later in this file */
54 static int kex_choose_conf(struct ssh *);
55 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
56 
57 static const char *proposal_names[PROPOSAL_MAX] = {
58 	"KEX algorithms",
59 	"host key algorithms",
60 	"ciphers ctos",
61 	"ciphers stoc",
62 	"MACs ctos",
63 	"MACs stoc",
64 	"compression ctos",
65 	"compression stoc",
66 	"languages ctos",
67 	"languages stoc",
68 };
69 
70 struct kexalg {
71 	char *name;
72 	u_int type;
73 	int ec_nid;
74 	int hash_alg;
75 };
76 static const struct kexalg kexalgs[] = {
77 #ifdef WITH_OPENSSL
78 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
79 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
80 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
81 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
82 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
83 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
84 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
85 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
86 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
87 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
88 	    SSH_DIGEST_SHA384 },
89 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
90 	    SSH_DIGEST_SHA512 },
91 #endif
92 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
93 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ NULL, -1, -1, -1},
95 };
96 
97 char *
98 kex_alg_list(char sep)
99 {
100 	char *ret = NULL, *tmp;
101 	size_t nlen, rlen = 0;
102 	const struct kexalg *k;
103 
104 	for (k = kexalgs; k->name != NULL; k++) {
105 		if (ret != NULL)
106 			ret[rlen++] = sep;
107 		nlen = strlen(k->name);
108 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
109 			free(ret);
110 			return NULL;
111 		}
112 		ret = tmp;
113 		memcpy(ret + rlen, k->name, nlen + 1);
114 		rlen += nlen;
115 	}
116 	return ret;
117 }
118 
119 static const struct kexalg *
120 kex_alg_by_name(const char *name)
121 {
122 	const struct kexalg *k;
123 
124 	for (k = kexalgs; k->name != NULL; k++) {
125 		if (strcmp(k->name, name) == 0)
126 			return k;
127 	}
128 	return NULL;
129 }
130 
/*
 * Validate a comma-separated list of KEX method names.
 * Returns 1 if every name examined is present in kexalgs[], otherwise 0
 * (also 0 for a NULL or empty list, or if the working copy cannot be
 * allocated). Scanning stops at the first empty element.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *tok;
	int ok = 1;

	if (names == NULL || strcmp(names, "") == 0)
		return 0;
	/* strsep() modifies its input, so work on a private copy. */
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((tok = strsep(&rest, ",")) != NULL && *tok != '\0') {
		if (kex_alg_by_name(tok) != NULL)
			continue;
		error("Unsupported KEX algorithm \"%.100s\"", tok);
		ok = 0;
		break;
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
153 
/*
 * Append the comma-separated names in 'b' to the list 'a', skipping any
 * name for which match_list() reports a match against the result so far.
 * Returns a newly allocated string that the caller must free, or NULL if
 * 'a' is NULL or empty, if 'b' is longer than 1024*1024 bytes, or if an
 * allocation fails. A NULL or empty 'b' yields a copy of 'a'.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *out, *bcopy, *rest, *name, *dup;
	size_t outsz;

	if (a == NULL || *a == '\0')
		return NULL;
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;

	/* Worst case: all of a, one comma, all of b, and the NUL. */
	outsz = strlen(a) + strlen(b) + 2;
	if ((bcopy = strdup(b)) == NULL)
		return NULL;
	if ((out = calloc(1, outsz)) == NULL) {
		free(bcopy);
		return NULL;
	}
	strlcpy(out, a, outsz);

	rest = bcopy;
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		dup = match_list(out, name, NULL);
		if (dup != NULL) {
			/* Algorithm already present */
			free(dup);
			continue;
		}
		if (strlcat(out, ",", outsz) >= outsz ||
		    strlcat(out, name, outsz) >= outsz) {
			/* Shouldn't happen */
			free(out);
			out = NULL;
			break;
		}
	}
	free(bcopy);
	return out;
}
192 
193 /*
194  * Assemble a list of algorithms from a default list and a string from a
195  * configuration file. The user-provided string may begin with '+' to
196  * indicate that it should be appended to the default or '-' that the
197  * specified names should be removed.
198  */
199 int
200 kex_assemble_names(const char *def, char **list)
201 {
202 	char *ret;
203 
204 	if (list == NULL || *list == NULL || **list == '\0') {
205 		*list = strdup(def);
206 		return 0;
207 	}
208 	if (**list == '+') {
209 		if ((ret = kex_names_cat(def, *list + 1)) == NULL)
210 			return SSH_ERR_ALLOC_FAIL;
211 		free(*list);
212 		*list = ret;
213 	} else if (**list == '-') {
214 		if ((ret = match_filter_list(def, *list + 1)) == NULL)
215 			return SSH_ERR_ALLOC_FAIL;
216 		free(*list);
217 		*list = ret;
218 	}
219 
220 	return 0;
221 }
222 
223 /* put algorithm proposal into buffer */
224 int
225 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
226 {
227 	u_int i;
228 	int r;
229 
230 	sshbuf_reset(b);
231 
232 	/*
233 	 * add a dummy cookie, the cookie will be overwritten by
234 	 * kex_send_kexinit(), each time a kexinit is set
235 	 */
236 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
237 		if ((r = sshbuf_put_u8(b, 0)) != 0)
238 			return r;
239 	}
240 	for (i = 0; i < PROPOSAL_MAX; i++) {
241 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
242 			return r;
243 	}
244 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
245 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
246 		return r;
247 	return 0;
248 }
249 
250 /* parse buffer and return algorithm proposal */
251 int
252 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
253 {
254 	struct sshbuf *b = NULL;
255 	u_char v;
256 	u_int i;
257 	char **proposal = NULL;
258 	int r;
259 
260 	*propp = NULL;
261 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
262 		return SSH_ERR_ALLOC_FAIL;
263 	if ((b = sshbuf_fromb(raw)) == NULL) {
264 		r = SSH_ERR_ALLOC_FAIL;
265 		goto out;
266 	}
267 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
268 		goto out;
269 	/* extract kex init proposal strings */
270 	for (i = 0; i < PROPOSAL_MAX; i++) {
271 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
272 			goto out;
273 		debug2("%s: %s", proposal_names[i], proposal[i]);
274 	}
275 	/* first kex follows / reserved */
276 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
277 	    (r = sshbuf_get_u32(b, &i)) != 0)	/* reserved */
278 		goto out;
279 	if (first_kex_follows != NULL)
280 		*first_kex_follows = v;
281 	debug2("first_kex_follows %d ", v);
282 	debug2("reserved %u ", i);
283 	r = 0;
284 	*propp = proposal;
285  out:
286 	if (r != 0 && proposal != NULL)
287 		kex_prop_free(proposal);
288 	sshbuf_free(b);
289 	return r;
290 }
291 
292 void
293 kex_prop_free(char **proposal)
294 {
295 	u_int i;
296 
297 	if (proposal == NULL)
298 		return;
299 	for (i = 0; i < PROPOSAL_MAX; i++)
300 		free(proposal[i]);
301 	free(proposal);
302 }
303 
304 /* ARGSUSED */
305 static int
306 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
307 {
308 	int r;
309 
310 	error("kex protocol error: type %d seq %u", type, seq);
311 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
312 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
313 	    (r = sshpkt_send(ssh)) != 0)
314 		return r;
315 	return 0;
316 }
317 
318 static void
319 kex_reset_dispatch(struct ssh *ssh)
320 {
321 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
322 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
323 }
324 
325 static int
326 kex_send_ext_info(struct ssh *ssh)
327 {
328 	int r;
329 	char *algs;
330 
331 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
332 		return SSH_ERR_ALLOC_FAIL;
333 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
334 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
335 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
336 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
337 	    (r = sshpkt_send(ssh)) != 0)
338 		goto out;
339 	/* success */
340 	r = 0;
341  out:
342 	free(algs);
343 	return r;
344 }
345 
346 int
347 kex_send_newkeys(struct ssh *ssh)
348 {
349 	int r;
350 
351 	kex_reset_dispatch(ssh);
352 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
353 	    (r = sshpkt_send(ssh)) != 0)
354 		return r;
355 	debug("SSH2_MSG_NEWKEYS sent");
356 	debug("expecting SSH2_MSG_NEWKEYS");
357 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
358 	if (ssh->kex->ext_info_c)
359 		if ((r = kex_send_ext_info(ssh)) != 0)
360 			return r;
361 	return 0;
362 }
363 
364 int
365 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
366 {
367 	struct kex *kex = ssh->kex;
368 	u_int32_t i, ninfo;
369 	char *name, *found;
370 	u_char *val;
371 	size_t vlen;
372 	int r;
373 
374 	debug("SSH2_MSG_EXT_INFO received");
375 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
376 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
377 		return r;
378 	for (i = 0; i < ninfo; i++) {
379 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
380 			return r;
381 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
382 			free(name);
383 			return r;
384 		}
385 		if (strcmp(name, "server-sig-algs") == 0) {
386 			/* Ensure no \0 lurking in value */
387 			if (memchr(val, '\0', vlen) != NULL) {
388 				error("%s: nul byte in %s", __func__, name);
389 				return SSH_ERR_INVALID_FORMAT;
390 			}
391 			debug("%s: %s=<%s>", __func__, name, val);
392 			found = match_list("rsa-sha2-256", val, NULL);
393 			if (found) {
394 				kex->rsa_sha2 = 256;
395 				free(found);
396 			}
397 			found = match_list("rsa-sha2-512", val, NULL);
398 			if (found) {
399 				kex->rsa_sha2 = 512;
400 				free(found);
401 			}
402 		} else
403 			debug("%s: %s (unrecognised)", __func__, name);
404 		free(name);
405 		free(val);
406 	}
407 	return sshpkt_get_end(ssh);
408 }
409 
410 static int
411 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
412 {
413 	struct kex *kex = ssh->kex;
414 	int r;
415 
416 	debug("SSH2_MSG_NEWKEYS received");
417 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
418 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
419 	if ((r = sshpkt_get_end(ssh)) != 0)
420 		return r;
421 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
422 		return r;
423 	kex->done = 1;
424 	sshbuf_reset(kex->peer);
425 	/* sshbuf_reset(kex->my); */
426 	kex->flags &= ~KEX_INIT_SENT;
427 	free(kex->name);
428 	kex->name = NULL;
429 	return 0;
430 }
431 
432 int
433 kex_send_kexinit(struct ssh *ssh)
434 {
435 	u_char *cookie;
436 	struct kex *kex = ssh->kex;
437 	int r;
438 
439 	if (kex == NULL)
440 		return SSH_ERR_INTERNAL_ERROR;
441 	if (kex->flags & KEX_INIT_SENT)
442 		return 0;
443 	kex->done = 0;
444 
445 	/* generate a random cookie */
446 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
447 		return SSH_ERR_INVALID_FORMAT;
448 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
449 		return SSH_ERR_INTERNAL_ERROR;
450 	arc4random_buf(cookie, KEX_COOKIE_LEN);
451 
452 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
453 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
454 	    (r = sshpkt_send(ssh)) != 0)
455 		return r;
456 	debug("SSH2_MSG_KEXINIT sent");
457 	kex->flags |= KEX_INIT_SENT;
458 	return 0;
459 }
460 
461 /* ARGSUSED */
462 int
463 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
464 {
465 	struct kex *kex = ssh->kex;
466 	const u_char *ptr;
467 	u_int i;
468 	size_t dlen;
469 	int r;
470 
471 	debug("SSH2_MSG_KEXINIT received");
472 	if (kex == NULL)
473 		return SSH_ERR_INVALID_ARGUMENT;
474 
475 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
476 	ptr = sshpkt_ptr(ssh, &dlen);
477 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
478 		return r;
479 
480 	/* discard packet */
481 	for (i = 0; i < KEX_COOKIE_LEN; i++)
482 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
483 			return r;
484 	for (i = 0; i < PROPOSAL_MAX; i++)
485 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
486 			return r;
487 	/*
488 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
489 	 * KEX method has the server move first, but a server might be using
490 	 * a custom method or one that we otherwise don't support. We should
491 	 * be prepared to remember first_kex_follows here so we can eat a
492 	 * packet later.
493 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
494 	 * for cases where the server *doesn't* go first. I guess we should
495 	 * ignore it when it is set for these cases, which is what we do now.
496 	 */
497 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
498 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
499 	    (r = sshpkt_get_end(ssh)) != 0)
500 			return r;
501 
502 	if (!(kex->flags & KEX_INIT_SENT))
503 		if ((r = kex_send_kexinit(ssh)) != 0)
504 			return r;
505 	if ((r = kex_choose_conf(ssh)) != 0)
506 		return r;
507 
508 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
509 		return (kex->kex[kex->kex_type])(ssh);
510 
511 	return SSH_ERR_INTERNAL_ERROR;
512 }
513 
514 int
515 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
516 {
517 	struct kex *kex;
518 	int r;
519 
520 	*kexp = NULL;
521 	if ((kex = calloc(1, sizeof(*kex))) == NULL)
522 		return SSH_ERR_ALLOC_FAIL;
523 	if ((kex->peer = sshbuf_new()) == NULL ||
524 	    (kex->my = sshbuf_new()) == NULL) {
525 		r = SSH_ERR_ALLOC_FAIL;
526 		goto out;
527 	}
528 	if ((r = kex_prop2buf(kex->my, proposal)) != 0)
529 		goto out;
530 	kex->done = 0;
531 	kex_reset_dispatch(ssh);
532 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
533 	r = 0;
534 	*kexp = kex;
535  out:
536 	if (r != 0)
537 		kex_free(kex);
538 	return r;
539 }
540 
541 void
542 kex_free_newkeys(struct newkeys *newkeys)
543 {
544 	if (newkeys == NULL)
545 		return;
546 	if (newkeys->enc.key) {
547 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
548 		free(newkeys->enc.key);
549 		newkeys->enc.key = NULL;
550 	}
551 	if (newkeys->enc.iv) {
552 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
553 		free(newkeys->enc.iv);
554 		newkeys->enc.iv = NULL;
555 	}
556 	free(newkeys->enc.name);
557 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
558 	free(newkeys->comp.name);
559 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
560 	mac_clear(&newkeys->mac);
561 	if (newkeys->mac.key) {
562 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
563 		free(newkeys->mac.key);
564 		newkeys->mac.key = NULL;
565 	}
566 	free(newkeys->mac.name);
567 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
568 	explicit_bzero(newkeys, sizeof(*newkeys));
569 	free(newkeys);
570 }
571 
572 void
573 kex_free(struct kex *kex)
574 {
575 	u_int mode;
576 
577 #ifdef WITH_OPENSSL
578 	if (kex->dh)
579 		DH_free(kex->dh);
580 	if (kex->ec_client_key)
581 		EC_KEY_free(kex->ec_client_key);
582 #endif
583 	for (mode = 0; mode < MODE_MAX; mode++) {
584 		kex_free_newkeys(kex->newkeys[mode]);
585 		kex->newkeys[mode] = NULL;
586 	}
587 	sshbuf_free(kex->peer);
588 	sshbuf_free(kex->my);
589 	free(kex->session_id);
590 	free(kex->client_version_string);
591 	free(kex->server_version_string);
592 	free(kex->failed_choice);
593 	free(kex->hostkey_alg);
594 	free(kex->name);
595 	free(kex);
596 }
597 
598 int
599 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
600 {
601 	int r;
602 
603 	if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
604 		return r;
605 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
606 		kex_free(ssh->kex);
607 		ssh->kex = NULL;
608 		return r;
609 	}
610 	return 0;
611 }
612 
613 /*
614  * Request key re-exchange, returns 0 on success or a ssherr.h error
615  * code otherwise. Must not be called if KEX is incomplete or in-progress.
616  */
617 int
618 kex_start_rekex(struct ssh *ssh)
619 {
620 	if (ssh->kex == NULL) {
621 		error("%s: no kex", __func__);
622 		return SSH_ERR_INTERNAL_ERROR;
623 	}
624 	if (ssh->kex->done == 0) {
625 		error("%s: requested twice", __func__);
626 		return SSH_ERR_INTERNAL_ERROR;
627 	}
628 	ssh->kex->done = 0;
629 	return kex_send_kexinit(ssh);
630 }
631 
632 static int
633 choose_enc(struct sshenc *enc, char *client, char *server)
634 {
635 	char *name = match_list(client, server, NULL);
636 
637 	if (name == NULL)
638 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
639 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
640 		free(name);
641 		return SSH_ERR_INTERNAL_ERROR;
642 	}
643 	enc->name = name;
644 	enc->enabled = 0;
645 	enc->iv = NULL;
646 	enc->iv_len = cipher_ivlen(enc->cipher);
647 	enc->key = NULL;
648 	enc->key_len = cipher_keylen(enc->cipher);
649 	enc->block_size = cipher_blocksize(enc->cipher);
650 	return 0;
651 }
652 
653 static int
654 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
655 {
656 	char *name = match_list(client, server, NULL);
657 
658 	if (name == NULL)
659 		return SSH_ERR_NO_MAC_ALG_MATCH;
660 	if (mac_setup(mac, name) < 0) {
661 		free(name);
662 		return SSH_ERR_INTERNAL_ERROR;
663 	}
664 	/* truncate the key */
665 	if (ssh->compat & SSH_BUG_HMAC)
666 		mac->key_len = 16;
667 	mac->name = name;
668 	mac->key = NULL;
669 	mac->enabled = 0;
670 	return 0;
671 }
672 
673 static int
674 choose_comp(struct sshcomp *comp, char *client, char *server)
675 {
676 	char *name = match_list(client, server, NULL);
677 
678 	if (name == NULL)
679 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
680 	if (strcmp(name, "zlib@openssh.com") == 0) {
681 		comp->type = COMP_DELAYED;
682 	} else if (strcmp(name, "zlib") == 0) {
683 		comp->type = COMP_ZLIB;
684 	} else if (strcmp(name, "none") == 0) {
685 		comp->type = COMP_NONE;
686 	} else {
687 		free(name);
688 		return SSH_ERR_INTERNAL_ERROR;
689 	}
690 	comp->name = name;
691 	return 0;
692 }
693 
694 static int
695 choose_kex(struct kex *k, char *client, char *server)
696 {
697 	const struct kexalg *kexalg;
698 
699 	k->name = match_list(client, server, NULL);
700 
701 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
702 	if (k->name == NULL)
703 		return SSH_ERR_NO_KEX_ALG_MATCH;
704 	if ((kexalg = kex_alg_by_name(k->name)) == NULL)
705 		return SSH_ERR_INTERNAL_ERROR;
706 	k->kex_type = kexalg->type;
707 	k->hash_alg = kexalg->hash_alg;
708 	k->ec_nid = kexalg->ec_nid;
709 	return 0;
710 }
711 
712 static int
713 choose_hostkeyalg(struct kex *k, char *client, char *server)
714 {
715 	k->hostkey_alg = match_list(client, server, NULL);
716 
717 	debug("kex: host key algorithm: %s",
718 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
719 	if (k->hostkey_alg == NULL)
720 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
721 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
722 	if (k->hostkey_type == KEY_UNSPEC)
723 		return SSH_ERR_INTERNAL_ERROR;
724 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
725 	return 0;
726 }
727 
728 static int
729 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
730 {
731 	static int check[] = {
732 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
733 	};
734 	int *idx;
735 	char *p;
736 
737 	for (idx = &check[0]; *idx != -1; idx++) {
738 		if ((p = strchr(my[*idx], ',')) != NULL)
739 			*p = '\0';
740 		if ((p = strchr(peer[*idx], ',')) != NULL)
741 			*p = '\0';
742 		if (strcmp(my[*idx], peer[*idx]) != 0) {
743 			debug2("proposal mismatch: my %s peer %s",
744 			    my[*idx], peer[*idx]);
745 			return (0);
746 		}
747 	}
748 	debug2("proposals match");
749 	return (1);
750 }
751 
752 static int
753 kex_choose_conf(struct ssh *ssh)
754 {
755 	struct kex *kex = ssh->kex;
756 	struct newkeys *newkeys;
757 	char **my = NULL, **peer = NULL;
758 	char **cprop, **sprop;
759 	int nenc, nmac, ncomp;
760 	u_int mode, ctos, need, dh_need, authlen;
761 	int r, first_kex_follows;
762 
763 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
764 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
765 		goto out;
766 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
767 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
768 		goto out;
769 
770 	if (kex->server) {
771 		cprop=peer;
772 		sprop=my;
773 	} else {
774 		cprop=my;
775 		sprop=peer;
776 	}
777 
778 	/* Check whether client supports ext_info_c */
779 	if (kex->server) {
780 		char *ext;
781 
782 		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
783 		kex->ext_info_c = (ext != NULL);
784 		free(ext);
785 	}
786 
787 	/* Algorithm Negotiation */
788 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
789 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
790 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
791 		peer[PROPOSAL_KEX_ALGS] = NULL;
792 		goto out;
793 	}
794 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
795 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
796 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
797 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
798 		goto out;
799 	}
800 	for (mode = 0; mode < MODE_MAX; mode++) {
801 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
802 			r = SSH_ERR_ALLOC_FAIL;
803 			goto out;
804 		}
805 		kex->newkeys[mode] = newkeys;
806 		ctos = (!kex->server && mode == MODE_OUT) ||
807 		    (kex->server && mode == MODE_IN);
808 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
809 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
810 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
811 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
812 		    sprop[nenc])) != 0) {
813 			kex->failed_choice = peer[nenc];
814 			peer[nenc] = NULL;
815 			goto out;
816 		}
817 		authlen = cipher_authlen(newkeys->enc.cipher);
818 		/* ignore mac for authenticated encryption */
819 		if (authlen == 0 &&
820 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
821 		    sprop[nmac])) != 0) {
822 			kex->failed_choice = peer[nmac];
823 			peer[nmac] = NULL;
824 			goto out;
825 		}
826 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
827 		    sprop[ncomp])) != 0) {
828 			kex->failed_choice = peer[ncomp];
829 			peer[ncomp] = NULL;
830 			goto out;
831 		}
832 		debug("kex: %s cipher: %s MAC: %s compression: %s",
833 		    ctos ? "client->server" : "server->client",
834 		    newkeys->enc.name,
835 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
836 		    newkeys->comp.name);
837 	}
838 	need = dh_need = 0;
839 	for (mode = 0; mode < MODE_MAX; mode++) {
840 		newkeys = kex->newkeys[mode];
841 		need = MAXIMUM(need, newkeys->enc.key_len);
842 		need = MAXIMUM(need, newkeys->enc.block_size);
843 		need = MAXIMUM(need, newkeys->enc.iv_len);
844 		need = MAXIMUM(need, newkeys->mac.key_len);
845 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
846 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
847 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
848 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
849 	}
850 	/* XXX need runden? */
851 	kex->we_need = need;
852 	kex->dh_need = dh_need;
853 
854 	/* ignore the next message if the proposals do not match */
855 	if (first_kex_follows && !proposals_match(my, peer) &&
856 	    !(ssh->compat & SSH_BUG_FIRSTKEX))
857 		ssh->dispatch_skip_packets = 1;
858 	r = 0;
859  out:
860 	kex_prop_free(my);
861 	kex_prop_free(peer);
862 	return r;
863 }
864 
865 static int
866 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
867     const struct sshbuf *shared_secret, u_char **keyp)
868 {
869 	struct kex *kex = ssh->kex;
870 	struct ssh_digest_ctx *hashctx = NULL;
871 	char c = id;
872 	u_int have;
873 	size_t mdsz;
874 	u_char *digest;
875 	int r;
876 
877 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
878 		return SSH_ERR_INVALID_ARGUMENT;
879 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
880 		r = SSH_ERR_ALLOC_FAIL;
881 		goto out;
882 	}
883 
884 	/* K1 = HASH(K || H || "A" || session_id) */
885 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
886 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
887 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
888 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
889 	    ssh_digest_update(hashctx, kex->session_id,
890 	    kex->session_id_len) != 0 ||
891 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
892 		r = SSH_ERR_LIBCRYPTO_ERROR;
893 		goto out;
894 	}
895 	ssh_digest_free(hashctx);
896 	hashctx = NULL;
897 
898 	/*
899 	 * expand key:
900 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
901 	 * Key = K1 || K2 || ... || Kn
902 	 */
903 	for (have = mdsz; need > have; have += mdsz) {
904 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
905 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
906 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
907 		    ssh_digest_update(hashctx, digest, have) != 0 ||
908 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
909 			r = SSH_ERR_LIBCRYPTO_ERROR;
910 			goto out;
911 		}
912 		ssh_digest_free(hashctx);
913 		hashctx = NULL;
914 	}
915 #ifdef DEBUG_KEX
916 	fprintf(stderr, "key '%c'== ", c);
917 	dump_digest("key", digest, need);
918 #endif
919 	*keyp = digest;
920 	digest = NULL;
921 	r = 0;
922  out:
923 	free(digest);
924 	ssh_digest_free(hashctx);
925 	return r;
926 }
927 
928 #define NKEYS	6
929 int
930 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
931     const struct sshbuf *shared_secret)
932 {
933 	struct kex *kex = ssh->kex;
934 	u_char *keys[NKEYS];
935 	u_int i, j, mode, ctos;
936 	int r;
937 
938 	for (i = 0; i < NKEYS; i++) {
939 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
940 		    shared_secret, &keys[i])) != 0) {
941 			for (j = 0; j < i; j++)
942 				free(keys[j]);
943 			return r;
944 		}
945 	}
946 	for (mode = 0; mode < MODE_MAX; mode++) {
947 		ctos = (!kex->server && mode == MODE_OUT) ||
948 		    (kex->server && mode == MODE_IN);
949 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
950 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
951 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
952 	}
953 	return 0;
954 }
955 
#ifdef WITH_OPENSSL
/*
 * Variant of kex_derive_keys() for a shared secret held in a BIGNUM: the
 * secret is first written to a temporary buffer with
 * sshbuf_put_bignum2(), which is released again before returning.
 * Returns 0 on success, SSH_ERR_ALLOC_FAIL if the buffer cannot be
 * created, or the error from the failing call.
 */
int
kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
    const BIGNUM *secret)
{
	struct sshbuf *k;
	int r;

	k = sshbuf_new();
	if (k == NULL)
		return SSH_ERR_ALLOC_FAIL;
	r = sshbuf_put_bignum2(k, secret);
	if (r == 0)
		r = kex_derive_keys(ssh, hash, hashlen, k);
	sshbuf_free(k);
	return r;
}
#endif
972 
973 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper, only built when one of the DEBUG_KEX* macros is defined:
 * writes 'msg' on a line of its own to stderr and then hands the 'len'
 * bytes at 'digest' to sshbuf_dump_data() for output on stderr.
 */
void
dump_digest(char *msg, u_char *digest, int len)
{
	FILE *out = stderr;

	fprintf(out, "%s\n", msg);
	sshbuf_dump_data(digest, len, out);
}
#endif
982