xref: /openbsd-src/usr.bin/ssh/kex.c (revision f2a19305cfc49ea4d1a5feb55cd6c283c6f1e031)
1 /* $OpenBSD: kex.c,v 1.185 2024/01/08 00:34:33 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <sys/types.h>
28 #include <errno.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34 #include <poll.h>
35 
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #endif
39 
40 #include "ssh.h"
41 #include "ssh2.h"
42 #include "atomicio.h"
43 #include "version.h"
44 #include "packet.h"
45 #include "compat.h"
46 #include "cipher.h"
47 #include "sshkey.h"
48 #include "kex.h"
49 #include "log.h"
50 #include "mac.h"
51 #include "match.h"
52 #include "misc.h"
53 #include "dispatch.h"
54 #include "monitor.h"
55 #include "myproposal.h"
56 
57 #include "ssherr.h"
58 #include "sshbuf.h"
59 #include "digest.h"
60 #include "xmalloc.h"
61 
62 /* prototype */
63 static int kex_choose_conf(struct ssh *, uint32_t seq);
64 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
65 
66 static const char * const proposal_names[PROPOSAL_MAX] = {
67 	"KEX algorithms",
68 	"host key algorithms",
69 	"ciphers ctos",
70 	"ciphers stoc",
71 	"MACs ctos",
72 	"MACs stoc",
73 	"compression ctos",
74 	"compression stoc",
75 	"languages ctos",
76 	"languages stoc",
77 };
78 
79 struct kexalg {
80 	char *name;
81 	u_int type;
82 	int ec_nid;
83 	int hash_alg;
84 };
85 static const struct kexalg kexalgs[] = {
86 #ifdef WITH_OPENSSL
87 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
88 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
89 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
90 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
91 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
92 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
95 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
96 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
97 	    SSH_DIGEST_SHA384 },
98 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
99 	    SSH_DIGEST_SHA512 },
100 #endif
101 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
102 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
103 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
104 	    SSH_DIGEST_SHA512 },
105 	{ NULL, 0, -1, -1},
106 };
107 
108 char *
109 kex_alg_list(char sep)
110 {
111 	char *ret = NULL, *tmp;
112 	size_t nlen, rlen = 0;
113 	const struct kexalg *k;
114 
115 	for (k = kexalgs; k->name != NULL; k++) {
116 		if (ret != NULL)
117 			ret[rlen++] = sep;
118 		nlen = strlen(k->name);
119 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
120 			free(ret);
121 			return NULL;
122 		}
123 		ret = tmp;
124 		memcpy(ret + rlen, k->name, nlen + 1);
125 		rlen += nlen;
126 	}
127 	return ret;
128 }
129 
130 static const struct kexalg *
131 kex_alg_by_name(const char *name)
132 {
133 	const struct kexalg *k;
134 
135 	for (k = kexalgs; k->name != NULL; k++) {
136 		if (strcmp(k->name, name) == 0)
137 			return k;
138 	}
139 	return NULL;
140 }
141 
/*
 * Validate a comma-separated KEX method name list.
 * Returns 1 if every name up to the first empty element is a supported
 * method. Returns 0 for a NULL or empty list, for an unsupported name
 * (which is logged), or on allocation failure.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) != NULL)
			continue;
		error("Unsupported KEX algorithm \"%.100s\"", name);
		ok = 0;
		break;
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
164 
/*
 * Return 1 if at least one algorithm in 'proposal' also appears in 'algs'
 * (as decided by match_list()), otherwise 0.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
176 
/*
 * Concatenate algorithm name-lists 'a' and 'b', skipping any name from 'b'
 * that is already present in the result.
 * If 'a' is NULL or empty, a copy of 'b' is returned (an empty string when
 * 'b' is NULL too). If 'b' is NULL or empty, a copy of 'a' is returned.
 * Returns NULL on allocation failure, or when both lists are non-empty and
 * 'b' is longer than 1 MiB.
 * Caller must free returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);	/* never strdup(NULL) */
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of b appended, plus one ',' and the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
213 
214 /*
215  * Assemble a list of algorithms from a default list and a string from a
216  * configuration file. The user-provided string may begin with '+' to
217  * indicate that it should be appended to the default, '-' that the
218  * specified names should be removed, or '^' that they should be placed
219  * at the head.
220  */
221 int
222 kex_assemble_names(char **listp, const char *def, const char *all)
223 {
224 	char *cp, *tmp, *patterns;
225 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
226 	int r = SSH_ERR_INTERNAL_ERROR;
227 
228 	if (listp == NULL || def == NULL || all == NULL)
229 		return SSH_ERR_INVALID_ARGUMENT;
230 
231 	if (*listp == NULL || **listp == '\0') {
232 		if ((*listp = strdup(def)) == NULL)
233 			return SSH_ERR_ALLOC_FAIL;
234 		return 0;
235 	}
236 
237 	list = *listp;
238 	*listp = NULL;
239 	if (*list == '+') {
240 		/* Append names to default list */
241 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
242 			r = SSH_ERR_ALLOC_FAIL;
243 			goto fail;
244 		}
245 		free(list);
246 		list = tmp;
247 	} else if (*list == '-') {
248 		/* Remove names from default list */
249 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
250 			r = SSH_ERR_ALLOC_FAIL;
251 			goto fail;
252 		}
253 		free(list);
254 		/* filtering has already been done */
255 		return 0;
256 	} else if (*list == '^') {
257 		/* Place names at head of default list */
258 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
259 			r = SSH_ERR_ALLOC_FAIL;
260 			goto fail;
261 		}
262 		free(list);
263 		list = tmp;
264 	} else {
265 		/* Explicit list, overrides default - just use "list" as is */
266 	}
267 
268 	/*
269 	 * The supplied names may be a pattern-list. For the -list case,
270 	 * the patterns are applied above. For the +list and explicit list
271 	 * cases we need to do it now.
272 	 */
273 	ret = NULL;
274 	if ((patterns = opatterns = strdup(list)) == NULL) {
275 		r = SSH_ERR_ALLOC_FAIL;
276 		goto fail;
277 	}
278 	/* Apply positive (i.e. non-negated) patterns from the list */
279 	while ((cp = strsep(&patterns, ",")) != NULL) {
280 		if (*cp == '!') {
281 			/* negated matches are not supported here */
282 			r = SSH_ERR_INVALID_ARGUMENT;
283 			goto fail;
284 		}
285 		free(matching);
286 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
287 			r = SSH_ERR_ALLOC_FAIL;
288 			goto fail;
289 		}
290 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
291 			r = SSH_ERR_ALLOC_FAIL;
292 			goto fail;
293 		}
294 		free(ret);
295 		ret = tmp;
296 	}
297 	if (ret == NULL || *ret == '\0') {
298 		/* An empty name-list is an error */
299 		/* XXX better error code? */
300 		r = SSH_ERR_INVALID_ARGUMENT;
301 		goto fail;
302 	}
303 
304 	/* success */
305 	*listp = ret;
306 	ret = NULL;
307 	r = 0;
308 
309  fail:
310 	free(matching);
311 	free(opatterns);
312 	free(list);
313 	free(ret);
314 	return r;
315 }
316 
317 /*
318  * Fill out a proposal array with dynamically allocated values, which may
319  * be modified as required for compatibility reasons.
320  * Any of the options may be NULL, in which case the default is used.
321  * Array contents must be freed by calling kex_proposal_free_entries.
322  */
323 void
324 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
325     const char *kexalgos, const char *ciphers, const char *macs,
326     const char *comp, const char *hkalgs)
327 {
328 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
329 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
330 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
331 	u_int i;
332 	char *cp;
333 
334 	if (prop == NULL)
335 		fatal_f("proposal missing");
336 
337 	/* Append EXT_INFO signalling to KexAlgorithms */
338 	if (kexalgos == NULL)
339 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
340 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
341 	    "ext-info-s,kex-strict-s-v00@openssh.com" :
342 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
343 		fatal_f("kex_names_cat");
344 
345 	for (i = 0; i < PROPOSAL_MAX; i++) {
346 		switch(i) {
347 		case PROPOSAL_KEX_ALGS:
348 			prop[i] = compat_kex_proposal(ssh, cp);
349 			break;
350 		case PROPOSAL_ENC_ALGS_CTOS:
351 		case PROPOSAL_ENC_ALGS_STOC:
352 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
353 			break;
354 		case PROPOSAL_MAC_ALGS_CTOS:
355 		case PROPOSAL_MAC_ALGS_STOC:
356 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
357 			break;
358 		case PROPOSAL_COMP_ALGS_CTOS:
359 		case PROPOSAL_COMP_ALGS_STOC:
360 			prop[i] = xstrdup(comp ? comp : defprop[i]);
361 			break;
362 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
363 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
364 			break;
365 		default:
366 			prop[i] = xstrdup(defprop[i]);
367 		}
368 	}
369 	free(cp);
370 }
371 
372 void
373 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
374 {
375 	u_int i;
376 
377 	for (i = 0; i < PROPOSAL_MAX; i++)
378 		free(prop[i]);
379 }
380 
381 /* put algorithm proposal into buffer */
382 int
383 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
384 {
385 	u_int i;
386 	int r;
387 
388 	sshbuf_reset(b);
389 
390 	/*
391 	 * add a dummy cookie, the cookie will be overwritten by
392 	 * kex_send_kexinit(), each time a kexinit is set
393 	 */
394 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
395 		if ((r = sshbuf_put_u8(b, 0)) != 0)
396 			return r;
397 	}
398 	for (i = 0; i < PROPOSAL_MAX; i++) {
399 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
400 			return r;
401 	}
402 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
403 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
404 		return r;
405 	return 0;
406 }
407 
408 /* parse buffer and return algorithm proposal */
409 int
410 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
411 {
412 	struct sshbuf *b = NULL;
413 	u_char v;
414 	u_int i;
415 	char **proposal = NULL;
416 	int r;
417 
418 	*propp = NULL;
419 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
420 		return SSH_ERR_ALLOC_FAIL;
421 	if ((b = sshbuf_fromb(raw)) == NULL) {
422 		r = SSH_ERR_ALLOC_FAIL;
423 		goto out;
424 	}
425 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
426 		error_fr(r, "consume cookie");
427 		goto out;
428 	}
429 	/* extract kex init proposal strings */
430 	for (i = 0; i < PROPOSAL_MAX; i++) {
431 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
432 			error_fr(r, "parse proposal %u", i);
433 			goto out;
434 		}
435 		debug2("%s: %s", proposal_names[i], proposal[i]);
436 	}
437 	/* first kex follows / reserved */
438 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
439 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
440 		error_fr(r, "parse");
441 		goto out;
442 	}
443 	if (first_kex_follows != NULL)
444 		*first_kex_follows = v;
445 	debug2("first_kex_follows %d ", v);
446 	debug2("reserved %u ", i);
447 	r = 0;
448 	*propp = proposal;
449  out:
450 	if (r != 0 && proposal != NULL)
451 		kex_prop_free(proposal);
452 	sshbuf_free(b);
453 	return r;
454 }
455 
456 void
457 kex_prop_free(char **proposal)
458 {
459 	u_int i;
460 
461 	if (proposal == NULL)
462 		return;
463 	for (i = 0; i < PROPOSAL_MAX; i++)
464 		free(proposal[i]);
465 	free(proposal);
466 }
467 
468 int
469 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
470 {
471 	int r;
472 
473 	/* If in strict mode, any unexpected message is an error */
474 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
475 		ssh_packet_disconnect(ssh, "strict KEX violation: "
476 		    "unexpected packet type %u (seqnr %u)", type, seq);
477 	}
478 	error_f("type %u seq %u", type, seq);
479 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
480 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
481 	    (r = sshpkt_send(ssh)) != 0)
482 		return r;
483 	return 0;
484 }
485 
486 static void
487 kex_reset_dispatch(struct ssh *ssh)
488 {
489 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
490 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
491 }
492 
493 void
494 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
495 {
496 	char *alg, *oalgs, *algs, *sigalgs;
497 	const char *sigalg;
498 
499 	/*
500 	 * NB. allowed algorithms may contain certificate algorithms that
501 	 * map to a specific plain signature type, e.g.
502 	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
503 	 * We need to be careful here to match these, retain the mapping
504 	 * and only add each signature algorithm once.
505 	 */
506 	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
507 		fatal_f("sshkey_alg_list failed");
508 	oalgs = algs = xstrdup(allowed_algs);
509 	free(ssh->kex->server_sig_algs);
510 	ssh->kex->server_sig_algs = NULL;
511 	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
512 	    (alg = strsep(&algs, ","))) {
513 		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
514 			continue;
515 		if (!has_any_alg(sigalg, sigalgs))
516 			continue;
517 		/* Don't add an algorithm twice. */
518 		if (ssh->kex->server_sig_algs != NULL &&
519 		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
520 			continue;
521 		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
522 	}
523 	free(oalgs);
524 	free(sigalgs);
525 	if (ssh->kex->server_sig_algs == NULL)
526 		ssh->kex->server_sig_algs = xstrdup("");
527 }
528 
529 static int
530 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
531 {
532 	int r;
533 
534 	if (ssh->kex->server_sig_algs == NULL &&
535 	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
536 		return SSH_ERR_ALLOC_FAIL;
537 	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
538 	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
539 	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
540 	    (r = sshbuf_put_cstring(m,
541 	    "publickey-hostbound@openssh.com")) != 0 ||
542 	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
543 	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
544 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
545 		error_fr(r, "compose");
546 		return r;
547 	}
548 	return 0;
549 }
550 
/*
 * Append the client's SSH2_MSG_EXT_INFO payload to 'm': exactly one
 * extension, ext-info-in-auth@openssh.com with value "0".
 * Returns 0 or an ssherr.h error code.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	if ((r = sshbuf_put_u32(m, 1)) == 0 &&
	    (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) == 0 &&
	    (r = sshbuf_put_cstring(m, "0")) == 0)
		return 0;
	error_fr(r, "compose");
	return r;
}
567 
568 static int
569 kex_maybe_send_ext_info(struct ssh *ssh)
570 {
571 	int r;
572 	struct sshbuf *m = NULL;
573 
574 	if ((ssh->kex->flags & KEX_INITIAL) == 0)
575 		return 0;
576 	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
577 		return 0;
578 
579 	/* Compose EXT_INFO packet. */
580 	if ((m = sshbuf_new()) == NULL)
581 		fatal_f("sshbuf_new failed");
582 	if (ssh->kex->ext_info_c &&
583 	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
584 		goto fail;
585 	if (ssh->kex->ext_info_s &&
586 	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
587 		goto fail;
588 
589 	/* Send the actual KEX_INFO packet */
590 	debug("Sending SSH2_MSG_EXT_INFO");
591 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
592 	    (r = sshpkt_putb(ssh, m)) != 0 ||
593 	    (r = sshpkt_send(ssh)) != 0) {
594 		error_f("send EXT_INFO");
595 		goto fail;
596 	}
597 
598 	r = 0;
599 
600  fail:
601 	sshbuf_free(m);
602 	return r;
603 }
604 
605 int
606 kex_server_update_ext_info(struct ssh *ssh)
607 {
608 	int r;
609 
610 	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
611 		return 0;
612 
613 	debug_f("Sending SSH2_MSG_EXT_INFO");
614 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
615 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
616 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
617 	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
618 	    (r = sshpkt_send(ssh)) != 0) {
619 		error_f("send EXT_INFO");
620 		return r;
621 	}
622 	return 0;
623 }
624 
625 int
626 kex_send_newkeys(struct ssh *ssh)
627 {
628 	int r;
629 
630 	kex_reset_dispatch(ssh);
631 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
632 	    (r = sshpkt_send(ssh)) != 0)
633 		return r;
634 	debug("SSH2_MSG_NEWKEYS sent");
635 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
636 	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
637 		return r;
638 	debug("expecting SSH2_MSG_NEWKEYS");
639 	return 0;
640 }
641 
642 /* Check whether an ext_info value contains the expected version string */
643 static int
644 kex_ext_info_check_ver(struct kex *kex, const char *name,
645     const u_char *val, size_t len, const char *want_ver, u_int flag)
646 {
647 	if (memchr(val, '\0', len) != NULL) {
648 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
649 		return SSH_ERR_INVALID_FORMAT;
650 	}
651 	debug_f("%s=<%s>", name, val);
652 	if (strcmp(val, want_ver) == 0)
653 		kex->flags |= flag;
654 	else
655 		debug_f("unsupported version of %s extension", name);
656 	return 0;
657 }
658 
659 static int
660 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
661     const u_char *value, size_t vlen)
662 {
663 	int r;
664 
665 	/* NB. some messages are only accepted in the initial EXT_INFO */
666 	if (strcmp(name, "server-sig-algs") == 0) {
667 		/* Ensure no \0 lurking in value */
668 		if (memchr(value, '\0', vlen) != NULL) {
669 			error_f("nul byte in %s", name);
670 			return SSH_ERR_INVALID_FORMAT;
671 		}
672 		debug_f("%s=<%s>", name, value);
673 		free(ssh->kex->server_sig_algs);
674 		ssh->kex->server_sig_algs = xstrdup((const char *)value);
675 	} else if (ssh->kex->ext_info_received == 1 &&
676 	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
677 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
678 		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
679 			return r;
680 		}
681 	} else if (ssh->kex->ext_info_received == 1 &&
682 	    strcmp(name, "ping@openssh.com") == 0) {
683 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
684 		    "0", KEX_HAS_PING)) != 0) {
685 			return r;
686 		}
687 	} else
688 		debug_f("%s (unrecognised)", name);
689 
690 	return 0;
691 }
692 
693 static int
694 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
695     const u_char *value, size_t vlen)
696 {
697 	int r;
698 
699 	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
700 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
701 		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
702 			return r;
703 		}
704 	} else
705 		debug_f("%s (unrecognised)", name);
706 	return 0;
707 }
708 
709 int
710 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
711 {
712 	struct kex *kex = ssh->kex;
713 	const int max_ext_info = kex->server ? 1 : 2;
714 	u_int32_t i, ninfo;
715 	char *name;
716 	u_char *val;
717 	size_t vlen;
718 	int r;
719 
720 	debug("SSH2_MSG_EXT_INFO received");
721 	if (++kex->ext_info_received > max_ext_info) {
722 		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
723 		return dispatch_protocol_error(type, seq, ssh);
724 	}
725 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
726 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
727 		return r;
728 	if (ninfo >= 1024) {
729 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
730 		    "<=1024, received %u", ninfo);
731 		return dispatch_protocol_error(type, seq, ssh);
732 	}
733 	for (i = 0; i < ninfo; i++) {
734 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
735 			return r;
736 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
737 			free(name);
738 			return r;
739 		}
740 		debug3_f("extension %s", name);
741 		if (kex->server) {
742 			if ((r = kex_ext_info_server_parse(ssh, name,
743 			    val, vlen)) != 0)
744 				return r;
745 		} else {
746 			if ((r = kex_ext_info_client_parse(ssh, name,
747 			    val, vlen)) != 0)
748 				return r;
749 		}
750 		free(name);
751 		free(val);
752 	}
753 	return sshpkt_get_end(ssh);
754 }
755 
756 static int
757 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
758 {
759 	struct kex *kex = ssh->kex;
760 	int r, initial = (kex->flags & KEX_INITIAL) != 0;
761 	char *cp, **prop;
762 
763 	debug("SSH2_MSG_NEWKEYS received");
764 	if (kex->ext_info_c && initial)
765 		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
766 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
767 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
768 	if ((r = sshpkt_get_end(ssh)) != 0)
769 		return r;
770 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
771 		return r;
772 	if (initial) {
773 		/* Remove initial KEX signalling from proposal for rekeying */
774 		if ((r = kex_buf2prop(kex->my, NULL, &prop)) != 0)
775 			return r;
776 		if ((cp = match_filter_denylist(prop[PROPOSAL_KEX_ALGS],
777 		    kex->server ?
778 		    "ext-info-s,kex-strict-s-v00@openssh.com" :
779 		    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL) {
780 			error_f("match_filter_denylist failed");
781 			goto fail;
782 		}
783 		free(prop[PROPOSAL_KEX_ALGS]);
784 		prop[PROPOSAL_KEX_ALGS] = cp;
785 		if ((r = kex_prop2buf(ssh->kex->my, prop)) != 0) {
786 			error_f("kex_prop2buf failed");
787  fail:
788 			kex_proposal_free_entries(prop);
789 			free(prop);
790 			return SSH_ERR_INTERNAL_ERROR;
791 		}
792 		kex_proposal_free_entries(prop);
793 		free(prop);
794 	}
795 	kex->done = 1;
796 	kex->flags &= ~KEX_INITIAL;
797 	sshbuf_reset(kex->peer);
798 	kex->flags &= ~KEX_INIT_SENT;
799 	free(kex->name);
800 	kex->name = NULL;
801 	return 0;
802 }
803 
804 int
805 kex_send_kexinit(struct ssh *ssh)
806 {
807 	u_char *cookie;
808 	struct kex *kex = ssh->kex;
809 	int r;
810 
811 	if (kex == NULL) {
812 		error_f("no kex");
813 		return SSH_ERR_INTERNAL_ERROR;
814 	}
815 	if (kex->flags & KEX_INIT_SENT)
816 		return 0;
817 	kex->done = 0;
818 
819 	/* generate a random cookie */
820 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
821 		error_f("bad kex length: %zu < %d",
822 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
823 		return SSH_ERR_INVALID_FORMAT;
824 	}
825 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
826 		error_f("buffer error");
827 		return SSH_ERR_INTERNAL_ERROR;
828 	}
829 	arc4random_buf(cookie, KEX_COOKIE_LEN);
830 
831 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
832 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
833 	    (r = sshpkt_send(ssh)) != 0) {
834 		error_fr(r, "compose reply");
835 		return r;
836 	}
837 	debug("SSH2_MSG_KEXINIT sent");
838 	kex->flags |= KEX_INIT_SENT;
839 	return 0;
840 }
841 
842 int
843 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
844 {
845 	struct kex *kex = ssh->kex;
846 	const u_char *ptr;
847 	u_int i;
848 	size_t dlen;
849 	int r;
850 
851 	debug("SSH2_MSG_KEXINIT received");
852 	if (kex == NULL) {
853 		error_f("no kex");
854 		return SSH_ERR_INTERNAL_ERROR;
855 	}
856 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
857 	ptr = sshpkt_ptr(ssh, &dlen);
858 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
859 		return r;
860 
861 	/* discard packet */
862 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
863 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
864 			error_fr(r, "discard cookie");
865 			return r;
866 		}
867 	}
868 	for (i = 0; i < PROPOSAL_MAX; i++) {
869 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
870 			error_fr(r, "discard proposal");
871 			return r;
872 		}
873 	}
874 	/*
875 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
876 	 * KEX method has the server move first, but a server might be using
877 	 * a custom method or one that we otherwise don't support. We should
878 	 * be prepared to remember first_kex_follows here so we can eat a
879 	 * packet later.
880 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
881 	 * for cases where the server *doesn't* go first. I guess we should
882 	 * ignore it when it is set for these cases, which is what we do now.
883 	 */
884 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
885 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
886 	    (r = sshpkt_get_end(ssh)) != 0)
887 			return r;
888 
889 	if (!(kex->flags & KEX_INIT_SENT))
890 		if ((r = kex_send_kexinit(ssh)) != 0)
891 			return r;
892 	if ((r = kex_choose_conf(ssh, seq)) != 0)
893 		return r;
894 
895 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
896 		return (kex->kex[kex->kex_type])(ssh);
897 
898 	error_f("unknown kex type %u", kex->kex_type);
899 	return SSH_ERR_INTERNAL_ERROR;
900 }
901 
902 struct kex *
903 kex_new(void)
904 {
905 	struct kex *kex;
906 
907 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
908 	    (kex->peer = sshbuf_new()) == NULL ||
909 	    (kex->my = sshbuf_new()) == NULL ||
910 	    (kex->client_version = sshbuf_new()) == NULL ||
911 	    (kex->server_version = sshbuf_new()) == NULL ||
912 	    (kex->session_id = sshbuf_new()) == NULL) {
913 		kex_free(kex);
914 		return NULL;
915 	}
916 	return kex;
917 }
918 
919 void
920 kex_free_newkeys(struct newkeys *newkeys)
921 {
922 	if (newkeys == NULL)
923 		return;
924 	if (newkeys->enc.key) {
925 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
926 		free(newkeys->enc.key);
927 		newkeys->enc.key = NULL;
928 	}
929 	if (newkeys->enc.iv) {
930 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
931 		free(newkeys->enc.iv);
932 		newkeys->enc.iv = NULL;
933 	}
934 	free(newkeys->enc.name);
935 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
936 	free(newkeys->comp.name);
937 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
938 	mac_clear(&newkeys->mac);
939 	if (newkeys->mac.key) {
940 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
941 		free(newkeys->mac.key);
942 		newkeys->mac.key = NULL;
943 	}
944 	free(newkeys->mac.name);
945 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
946 	freezero(newkeys, sizeof(*newkeys));
947 }
948 
949 void
950 kex_free(struct kex *kex)
951 {
952 	u_int mode;
953 
954 	if (kex == NULL)
955 		return;
956 
957 #ifdef WITH_OPENSSL
958 	DH_free(kex->dh);
959 	EC_KEY_free(kex->ec_client_key);
960 #endif
961 	for (mode = 0; mode < MODE_MAX; mode++) {
962 		kex_free_newkeys(kex->newkeys[mode]);
963 		kex->newkeys[mode] = NULL;
964 	}
965 	sshbuf_free(kex->peer);
966 	sshbuf_free(kex->my);
967 	sshbuf_free(kex->client_version);
968 	sshbuf_free(kex->server_version);
969 	sshbuf_free(kex->client_pub);
970 	sshbuf_free(kex->session_id);
971 	sshbuf_free(kex->initial_sig);
972 	sshkey_free(kex->initial_hostkey);
973 	free(kex->failed_choice);
974 	free(kex->hostkey_alg);
975 	free(kex->name);
976 	free(kex);
977 }
978 
979 int
980 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
981 {
982 	int r;
983 
984 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
985 		return r;
986 	ssh->kex->flags = KEX_INITIAL;
987 	kex_reset_dispatch(ssh);
988 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
989 	return 0;
990 }
991 
992 int
993 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
994 {
995 	int r;
996 
997 	if ((r = kex_ready(ssh, proposal)) != 0)
998 		return r;
999 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
1000 		kex_free(ssh->kex);
1001 		ssh->kex = NULL;
1002 		return r;
1003 	}
1004 	return 0;
1005 }
1006 
1007 /*
1008  * Request key re-exchange, returns 0 on success or a ssherr.h error
1009  * code otherwise. Must not be called if KEX is incomplete or in-progress.
1010  */
1011 int
1012 kex_start_rekex(struct ssh *ssh)
1013 {
1014 	if (ssh->kex == NULL) {
1015 		error_f("no kex");
1016 		return SSH_ERR_INTERNAL_ERROR;
1017 	}
1018 	if (ssh->kex->done == 0) {
1019 		error_f("requested twice");
1020 		return SSH_ERR_INTERNAL_ERROR;
1021 	}
1022 	ssh->kex->done = 0;
1023 	return kex_send_kexinit(ssh);
1024 }
1025 
1026 static int
1027 choose_enc(struct sshenc *enc, char *client, char *server)
1028 {
1029 	char *name = match_list(client, server, NULL);
1030 
1031 	if (name == NULL)
1032 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1033 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1034 		error_f("unsupported cipher %s", name);
1035 		free(name);
1036 		return SSH_ERR_INTERNAL_ERROR;
1037 	}
1038 	enc->name = name;
1039 	enc->enabled = 0;
1040 	enc->iv = NULL;
1041 	enc->iv_len = cipher_ivlen(enc->cipher);
1042 	enc->key = NULL;
1043 	enc->key_len = cipher_keylen(enc->cipher);
1044 	enc->block_size = cipher_blocksize(enc->cipher);
1045 	return 0;
1046 }
1047 
1048 static int
1049 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1050 {
1051 	char *name = match_list(client, server, NULL);
1052 
1053 	if (name == NULL)
1054 		return SSH_ERR_NO_MAC_ALG_MATCH;
1055 	if (mac_setup(mac, name) < 0) {
1056 		error_f("unsupported MAC %s", name);
1057 		free(name);
1058 		return SSH_ERR_INTERNAL_ERROR;
1059 	}
1060 	mac->name = name;
1061 	mac->key = NULL;
1062 	mac->enabled = 0;
1063 	return 0;
1064 }
1065 
1066 static int
1067 choose_comp(struct sshcomp *comp, char *client, char *server)
1068 {
1069 	char *name = match_list(client, server, NULL);
1070 
1071 	if (name == NULL)
1072 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1073 #ifdef WITH_ZLIB
1074 	if (strcmp(name, "zlib@openssh.com") == 0) {
1075 		comp->type = COMP_DELAYED;
1076 	} else if (strcmp(name, "zlib") == 0) {
1077 		comp->type = COMP_ZLIB;
1078 	} else
1079 #endif	/* WITH_ZLIB */
1080 	if (strcmp(name, "none") == 0) {
1081 		comp->type = COMP_NONE;
1082 	} else {
1083 		error_f("unsupported compression scheme %s", name);
1084 		free(name);
1085 		return SSH_ERR_INTERNAL_ERROR;
1086 	}
1087 	comp->name = name;
1088 	return 0;
1089 }
1090 
1091 static int
1092 choose_kex(struct kex *k, char *client, char *server)
1093 {
1094 	const struct kexalg *kexalg;
1095 
1096 	k->name = match_list(client, server, NULL);
1097 
1098 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1099 	if (k->name == NULL)
1100 		return SSH_ERR_NO_KEX_ALG_MATCH;
1101 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1102 		error_f("unsupported KEX method %s", k->name);
1103 		return SSH_ERR_INTERNAL_ERROR;
1104 	}
1105 	k->kex_type = kexalg->type;
1106 	k->hash_alg = kexalg->hash_alg;
1107 	k->ec_nid = kexalg->ec_nid;
1108 	return 0;
1109 }
1110 
1111 static int
1112 choose_hostkeyalg(struct kex *k, char *client, char *server)
1113 {
1114 	free(k->hostkey_alg);
1115 	k->hostkey_alg = match_list(client, server, NULL);
1116 
1117 	debug("kex: host key algorithm: %s",
1118 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1119 	if (k->hostkey_alg == NULL)
1120 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1121 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1122 	if (k->hostkey_type == KEY_UNSPEC) {
1123 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1124 		return SSH_ERR_INTERNAL_ERROR;
1125 	}
1126 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1127 	return 0;
1128 }
1129 
1130 static int
1131 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1132 {
1133 	static int check[] = {
1134 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1135 	};
1136 	int *idx;
1137 	char *p;
1138 
1139 	for (idx = &check[0]; *idx != -1; idx++) {
1140 		if ((p = strchr(my[*idx], ',')) != NULL)
1141 			*p = '\0';
1142 		if ((p = strchr(peer[*idx], ',')) != NULL)
1143 			*p = '\0';
1144 		if (strcmp(my[*idx], peer[*idx]) != 0) {
1145 			debug2("proposal mismatch: my %s peer %s",
1146 			    my[*idx], peer[*idx]);
1147 			return (0);
1148 		}
1149 	}
1150 	debug2("proposals match");
1151 	return (1);
1152 }
1153 
1154 static int
1155 kexalgs_contains(char **peer, const char *ext)
1156 {
1157 	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1158 }
1159 
/*
 * Negotiate the algorithms for this key exchange from our KEXINIT
 * proposal (kex->my) and the peer's (kex->peer).
 *
 * seq is the sequence number of the peer's KEXINIT packet. On the
 * initial exchange with strict KEX in effect, a non-zero seq causes a
 * disconnect.
 *
 * On success this fills in:
 *   - kex->name and the host key algorithm fields;
 *   - kex->newkeys[], one per direction, with cipher/MAC/compression;
 *   - the key-material sizes kex->we_need and kex->dh_need.
 * On the initial exchange it also records peer support for ext-info,
 * strict KEX and (server side) rsa-sha2 signatures.
 *
 * If a negotiation step fails, the peer's offending list string is moved
 * to kex->failed_choice and the ssherr.h code is returned. Both parsed
 * proposals are released before returning.
 */
static int
kex_choose_conf(struct ssh *ssh, uint32_t seq)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	/* Parse both KEXINIT payloads into per-field proposal strings. */
	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* Negotiation helpers always take (client list, server list). */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether peer supports ext_info/kex_strict (initial KEX only) */
	if ((kex->flags & KEX_INITIAL) != 0) {
		if (kex->server) {
			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-c-v00@openssh.com");
		} else {
			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-s-v00@openssh.com");
		}
		if (kex->kex_strict) {
			debug3_f("will use strict KEX ordering");
			/* Under strict KEX, KEXINIT must be packet number 0. */
			if (seq != 0)
				ssh_packet_disconnect(ssh,
				    "strict KEX violation: "
				    "KEXINIT was not the first packet");
		}
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/*
	 * Algorithm Negotiation.
	 * On each failure, the peer's list is handed to kex->failed_choice
	 * and its slot is cleared, so the string survives the
	 * kex_prop_free() at "out".
	 */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	/* Cipher, MAC and compression are negotiated per direction. */
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/* ctos: this direction carries client-to-server traffic. */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * Size the key material over both directions.
	 * - we_need is the largest key, block, IV or MAC-key length.
	 * - dh_need is the same, except that cipher_seclen() is used in
	 *   place of the key length.
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX does need have to be rounded up? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/*
	 * The peer sent a guessed KEX packet (first_kex_follows). If its
	 * guess (the first KEX and host key algorithms) differs from ours,
	 * mark the next packet to be skipped.
	 */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1292 
1293 static int
1294 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1295     const struct sshbuf *shared_secret, u_char **keyp)
1296 {
1297 	struct kex *kex = ssh->kex;
1298 	struct ssh_digest_ctx *hashctx = NULL;
1299 	char c = id;
1300 	u_int have;
1301 	size_t mdsz;
1302 	u_char *digest;
1303 	int r;
1304 
1305 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1306 		return SSH_ERR_INVALID_ARGUMENT;
1307 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1308 		r = SSH_ERR_ALLOC_FAIL;
1309 		goto out;
1310 	}
1311 
1312 	/* K1 = HASH(K || H || "A" || session_id) */
1313 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1314 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1315 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1316 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1317 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1318 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1319 		r = SSH_ERR_LIBCRYPTO_ERROR;
1320 		error_f("KEX hash failed");
1321 		goto out;
1322 	}
1323 	ssh_digest_free(hashctx);
1324 	hashctx = NULL;
1325 
1326 	/*
1327 	 * expand key:
1328 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1329 	 * Key = K1 || K2 || ... || Kn
1330 	 */
1331 	for (have = mdsz; need > have; have += mdsz) {
1332 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1333 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1334 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1335 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1336 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1337 			error_f("KDF failed");
1338 			r = SSH_ERR_LIBCRYPTO_ERROR;
1339 			goto out;
1340 		}
1341 		ssh_digest_free(hashctx);
1342 		hashctx = NULL;
1343 	}
1344 #ifdef DEBUG_KEX
1345 	fprintf(stderr, "key '%c'== ", c);
1346 	dump_digest("key", digest, need);
1347 #endif
1348 	*keyp = digest;
1349 	digest = NULL;
1350 	r = 0;
1351  out:
1352 	free(digest);
1353 	ssh_digest_free(hashctx);
1354 	return r;
1355 }
1356 
1357 #define NKEYS	6
1358 int
1359 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1360     const struct sshbuf *shared_secret)
1361 {
1362 	struct kex *kex = ssh->kex;
1363 	u_char *keys[NKEYS];
1364 	u_int i, j, mode, ctos;
1365 	int r;
1366 
1367 	/* save initial hash as session id */
1368 	if ((kex->flags & KEX_INITIAL) != 0) {
1369 		if (sshbuf_len(kex->session_id) != 0) {
1370 			error_f("already have session ID at kex");
1371 			return SSH_ERR_INTERNAL_ERROR;
1372 		}
1373 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1374 			return r;
1375 	} else if (sshbuf_len(kex->session_id) == 0) {
1376 		error_f("no session ID in rekex");
1377 		return SSH_ERR_INTERNAL_ERROR;
1378 	}
1379 	for (i = 0; i < NKEYS; i++) {
1380 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1381 		    shared_secret, &keys[i])) != 0) {
1382 			for (j = 0; j < i; j++)
1383 				free(keys[j]);
1384 			return r;
1385 		}
1386 	}
1387 	for (mode = 0; mode < MODE_MAX; mode++) {
1388 		ctos = (!kex->server && mode == MODE_OUT) ||
1389 		    (kex->server && mode == MODE_IN);
1390 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1391 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1392 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1393 	}
1394 	return 0;
1395 }
1396 
1397 int
1398 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1399 {
1400 	struct kex *kex = ssh->kex;
1401 
1402 	*pubp = NULL;
1403 	*prvp = NULL;
1404 	if (kex->load_host_public_key == NULL ||
1405 	    kex->load_host_private_key == NULL) {
1406 		error_f("missing hostkey loader");
1407 		return SSH_ERR_INVALID_ARGUMENT;
1408 	}
1409 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1410 	    kex->hostkey_nid, ssh);
1411 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1412 	    kex->hostkey_nid, ssh);
1413 	if (*pubp == NULL)
1414 		return SSH_ERR_NO_HOSTKEY_LOADED;
1415 	return 0;
1416 }
1417 
1418 int
1419 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1420 {
1421 	struct kex *kex = ssh->kex;
1422 
1423 	if (kex->verify_host_key == NULL) {
1424 		error_f("missing hostkey verifier");
1425 		return SSH_ERR_INVALID_ARGUMENT;
1426 	}
1427 	if (server_host_key->type != kex->hostkey_type ||
1428 	    (kex->hostkey_type == KEY_ECDSA &&
1429 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1430 		return SSH_ERR_KEY_TYPE_MISMATCH;
1431 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1432 		return  SSH_ERR_SIGNATURE_INVALID;
1433 	return 0;
1434 }
1435 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: print msg on its own line, then dump len bytes of digest. */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1444 
1445 /*
1446  * Send a plaintext error message to the peer, suffixed by \r\n.
1447  * Only used during banner exchange, and there only for the server.
1448  */
1449 static void
1450 send_error(struct ssh *ssh, char *msg)
1451 {
1452 	char *crnl = "\r\n";
1453 
1454 	if (!ssh->kex->server)
1455 		return;
1456 
1457 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1458 	    msg, strlen(msg)) != strlen(msg) ||
1459 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1460 	    crnl, strlen(crnl)) != strlen(crnl))
1461 		error_f("write: %.100s", strerror(errno));
1462 }
1463 
1464 /*
1465  * Sends our identification string and waits for the peer's. Will block for
1466  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1467  * Returns on 0 success or a ssherr.h code on failure.
1468  */
1469 int
1470 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1471     const char *version_addendum)
1472 {
1473 	int remote_major, remote_minor, mismatch, oerrno = 0;
1474 	size_t len, n;
1475 	int r, expect_nl;
1476 	u_char c;
1477 	struct sshbuf *our_version = ssh->kex->server ?
1478 	    ssh->kex->server_version : ssh->kex->client_version;
1479 	struct sshbuf *peer_version = ssh->kex->server ?
1480 	    ssh->kex->client_version : ssh->kex->server_version;
1481 	char *our_version_string = NULL, *peer_version_string = NULL;
1482 	char *cp, *remote_version = NULL;
1483 
1484 	/* Prepare and send our banner */
1485 	sshbuf_reset(our_version);
1486 	if (version_addendum != NULL && *version_addendum == '\0')
1487 		version_addendum = NULL;
1488 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1489 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1490 	    version_addendum == NULL ? "" : " ",
1491 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1492 		oerrno = errno;
1493 		error_fr(r, "sshbuf_putf");
1494 		goto out;
1495 	}
1496 
1497 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1498 	    sshbuf_mutable_ptr(our_version),
1499 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1500 		oerrno = errno;
1501 		debug_f("write: %.100s", strerror(errno));
1502 		r = SSH_ERR_SYSTEM_ERROR;
1503 		goto out;
1504 	}
1505 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1506 		oerrno = errno;
1507 		error_fr(r, "sshbuf_consume_end");
1508 		goto out;
1509 	}
1510 	our_version_string = sshbuf_dup_string(our_version);
1511 	if (our_version_string == NULL) {
1512 		error_f("sshbuf_dup_string failed");
1513 		r = SSH_ERR_ALLOC_FAIL;
1514 		goto out;
1515 	}
1516 	debug("Local version string %.100s", our_version_string);
1517 
1518 	/* Read other side's version identification. */
1519 	for (n = 0; ; n++) {
1520 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1521 			send_error(ssh, "No SSH identification string "
1522 			    "received.");
1523 			error_f("No SSH version received in first %u lines "
1524 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1525 			r = SSH_ERR_INVALID_FORMAT;
1526 			goto out;
1527 		}
1528 		sshbuf_reset(peer_version);
1529 		expect_nl = 0;
1530 		for (;;) {
1531 			if (timeout_ms > 0) {
1532 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1533 				    &timeout_ms, NULL);
1534 				if (r == -1 && errno == ETIMEDOUT) {
1535 					send_error(ssh, "Timed out waiting "
1536 					    "for SSH identification string.");
1537 					error("Connection timed out during "
1538 					    "banner exchange");
1539 					r = SSH_ERR_CONN_TIMEOUT;
1540 					goto out;
1541 				} else if (r == -1) {
1542 					oerrno = errno;
1543 					error_f("%s", strerror(errno));
1544 					r = SSH_ERR_SYSTEM_ERROR;
1545 					goto out;
1546 				}
1547 			}
1548 
1549 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1550 			    &c, 1);
1551 			if (len != 1 && errno == EPIPE) {
1552 				verbose_f("Connection closed by remote host");
1553 				r = SSH_ERR_CONN_CLOSED;
1554 				goto out;
1555 			} else if (len != 1) {
1556 				oerrno = errno;
1557 				error_f("read: %.100s", strerror(errno));
1558 				r = SSH_ERR_SYSTEM_ERROR;
1559 				goto out;
1560 			}
1561 			if (c == '\r') {
1562 				expect_nl = 1;
1563 				continue;
1564 			}
1565 			if (c == '\n')
1566 				break;
1567 			if (c == '\0' || expect_nl) {
1568 				verbose_f("banner line contains invalid "
1569 				    "characters");
1570 				goto invalid;
1571 			}
1572 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1573 				oerrno = errno;
1574 				error_fr(r, "sshbuf_put");
1575 				goto out;
1576 			}
1577 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1578 				verbose_f("banner line too long");
1579 				goto invalid;
1580 			}
1581 		}
1582 		/* Is this an actual protocol banner? */
1583 		if (sshbuf_len(peer_version) > 4 &&
1584 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1585 			break;
1586 		/* If not, then just log the line and continue */
1587 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1588 			error_f("sshbuf_dup_string failed");
1589 			r = SSH_ERR_ALLOC_FAIL;
1590 			goto out;
1591 		}
1592 		/* Do not accept lines before the SSH ident from a client */
1593 		if (ssh->kex->server) {
1594 			verbose_f("client sent invalid protocol identifier "
1595 			    "\"%.256s\"", cp);
1596 			free(cp);
1597 			goto invalid;
1598 		}
1599 		debug_f("banner line %zu: %s", n, cp);
1600 		free(cp);
1601 	}
1602 	peer_version_string = sshbuf_dup_string(peer_version);
1603 	if (peer_version_string == NULL)
1604 		fatal_f("sshbuf_dup_string failed");
1605 	/* XXX must be same size for sscanf */
1606 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1607 		error_f("calloc failed");
1608 		r = SSH_ERR_ALLOC_FAIL;
1609 		goto out;
1610 	}
1611 
1612 	/*
1613 	 * Check that the versions match.  In future this might accept
1614 	 * several versions and set appropriate flags to handle them.
1615 	 */
1616 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1617 	    &remote_major, &remote_minor, remote_version) != 3) {
1618 		error("Bad remote protocol version identification: '%.100s'",
1619 		    peer_version_string);
1620  invalid:
1621 		send_error(ssh, "Invalid SSH identification string.");
1622 		r = SSH_ERR_INVALID_FORMAT;
1623 		goto out;
1624 	}
1625 	debug("Remote protocol version %d.%d, remote software version %.100s",
1626 	    remote_major, remote_minor, remote_version);
1627 	compat_banner(ssh, remote_version);
1628 
1629 	mismatch = 0;
1630 	switch (remote_major) {
1631 	case 2:
1632 		break;
1633 	case 1:
1634 		if (remote_minor != 99)
1635 			mismatch = 1;
1636 		break;
1637 	default:
1638 		mismatch = 1;
1639 		break;
1640 	}
1641 	if (mismatch) {
1642 		error("Protocol major versions differ: %d vs. %d",
1643 		    PROTOCOL_MAJOR_2, remote_major);
1644 		send_error(ssh, "Protocol major versions differ.");
1645 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1646 		goto out;
1647 	}
1648 
1649 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1650 		logit("probed from %s port %d with %s.  Don't panic.",
1651 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1652 		    peer_version_string);
1653 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1654 		goto out;
1655 	}
1656 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1657 		logit("scanned from %s port %d with %s.  Don't panic.",
1658 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1659 		    peer_version_string);
1660 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1661 		goto out;
1662 	}
1663 	/* success */
1664 	r = 0;
1665  out:
1666 	free(our_version_string);
1667 	free(peer_version_string);
1668 	free(remote_version);
1669 	if (r == SSH_ERR_SYSTEM_ERROR)
1670 		errno = oerrno;
1671 	return r;
1672 }
1673 
1674