xref: /openbsd-src/usr.bin/ssh/kex.c (revision 8dfe214903ce3625c937d5fad2469e8a0d1d4d71)
1 /* $OpenBSD: kex.c,v 1.184 2023/12/18 14:45:49 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <sys/types.h>
28 #include <errno.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34 #include <poll.h>
35 
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #endif
39 
40 #include "ssh.h"
41 #include "ssh2.h"
42 #include "atomicio.h"
43 #include "version.h"
44 #include "packet.h"
45 #include "compat.h"
46 #include "cipher.h"
47 #include "sshkey.h"
48 #include "kex.h"
49 #include "log.h"
50 #include "mac.h"
51 #include "match.h"
52 #include "misc.h"
53 #include "dispatch.h"
54 #include "monitor.h"
55 #include "myproposal.h"
56 
57 #include "ssherr.h"
58 #include "sshbuf.h"
59 #include "digest.h"
60 #include "xmalloc.h"
61 
62 /* prototype */
63 static int kex_choose_conf(struct ssh *, uint32_t seq);
64 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
65 
/*
 * Human-readable labels for each KEXINIT name-list, indexed by the
 * PROPOSAL_* enumeration; used for debug output (e.g. in kex_buf2prop()).
 * NOTE(review): the order here must match the PROPOSAL_* values declared
 * in kex.h (not visible here) - keep the two in sync.
 */
static const char * const proposal_names[PROPOSAL_MAX] = {
	"KEX algorithms",
	"host key algorithms",
	"ciphers ctos",
	"ciphers stoc",
	"MACs ctos",
	"MACs stoc",
	"compression ctos",
	"compression stoc",
	"languages ctos",
	"languages stoc",
};
78 
/*
 * Description of one supported key exchange method.  choose_kex() copies
 * type, ec_nid and hash_alg into the negotiated struct kex.
 */
struct kexalg {
	char *name;	/* method name as it appears in KEXINIT name-lists */
	u_int type;	/* KEX_* type; later used to index kex->kex[] */
	int ec_nid;	/* curve NID for the ECDH methods, 0 otherwise */
	int hash_alg;	/* SSH_DIGEST_* associated with the method */
};
/*
 * Table of supported KEX methods, terminated by an entry with a NULL name.
 * The DH and ECDH methods are only compiled in WITH_OPENSSL.  The table
 * order is the order reported by kex_alg_list().
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
#endif
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
	{ NULL, 0, -1, -1},	/* sentinel */
};
107 
108 char *
109 kex_alg_list(char sep)
110 {
111 	char *ret = NULL, *tmp;
112 	size_t nlen, rlen = 0;
113 	const struct kexalg *k;
114 
115 	for (k = kexalgs; k->name != NULL; k++) {
116 		if (ret != NULL)
117 			ret[rlen++] = sep;
118 		nlen = strlen(k->name);
119 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
120 			free(ret);
121 			return NULL;
122 		}
123 		ret = tmp;
124 		memcpy(ret + rlen, k->name, nlen + 1);
125 		rlen += nlen;
126 	}
127 	return ret;
128 }
129 
130 static const struct kexalg *
131 kex_alg_by_name(const char *name)
132 {
133 	const struct kexalg *k;
134 
135 	for (k = kexalgs; k->name != NULL; k++) {
136 		if (strcmp(k->name, name) == 0)
137 			return k;
138 	}
139 	return NULL;
140 }
141 
/*
 * Validate a comma-separated list of KEX method names.  Returns 1 if every
 * name (up to the first empty element) is supported, 0 if the list is NULL
 * or empty, a name is unknown, or memory could not be allocated.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	/* NB. scanning stops at the first empty element */
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
164 
/*
 * Returns non-zero if the comma-separated list "proposal" shares at least
 * one algorithm name with "algs" (as determined by match_list()).
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
176 
/*
 * Concatenate two comma-separated algorithm lists: the result is "a"
 * followed by every name from "b" that is not already present (names
 * appended earlier from "b" are also checked, so "b" is de-duplicated
 * too).  Parsing of "b" stops at its first empty element.
 *
 * A NULL argument is treated like an empty list.  Returns a newly
 * allocated string the caller must free, or NULL on allocation failure
 * or if "b" is longer than 1MB.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	/* Avoid strdup(NULL) (undefined) when both lists are missing. */
	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/*
	 * Worst case is every name of "b" appended with a leading comma:
	 * strlen(a) + strlen(b) + 1, plus the terminating NUL.
	 */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
213 
/*
 * Assemble a list of algorithms from a default list and a string from a
 * configuration file. The user-provided string may begin with '+' to
 * indicate that it should be appended to the default, '-' that the
 * specified names should be removed, or '^' that they should be placed
 * at the head.
 *
 * listp: in/out.  On entry *listp is the user-supplied string (NULL or
 *   empty means "use def").  A non-empty *listp is consumed (freed) on
 *   every path; on return *listp holds a newly allocated result on
 *   success, or NULL on failure.
 *   NOTE(review): a non-NULL but empty *listp is overwritten by a copy of
 *   def without being freed.
 * def: the default list.  all: every supported name; the '+', '^' and
 *   explicit forms are filtered against it.
 *
 * Returns 0 on success; SSH_ERR_INVALID_ARGUMENT for NULL arguments, a
 * negated ('!') pattern or an empty result; SSH_ERR_ALLOC_FAIL when an
 * allocation or list operation fails.
 */
int
kex_assemble_names(char **listp, const char *def, const char *all)
{
	char *cp, *tmp, *patterns;
	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (listp == NULL || def == NULL || all == NULL)
		return SSH_ERR_INVALID_ARGUMENT;

	if (*listp == NULL || **listp == '\0') {
		if ((*listp = strdup(def)) == NULL)
			return SSH_ERR_ALLOC_FAIL;
		return 0;
	}

	/* Take ownership of the input; *listp stays NULL on failure paths */
	list = *listp;
	*listp = NULL;
	if (*list == '+') {
		/* Append names to default list */
		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else if (*list == '-') {
		/* Remove names from default list */
		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		/* filtering has already been done */
		return 0;
	} else if (*list == '^') {
		/* Place names at head of default list */
		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else {
		/* Explicit list, overrides default - just use "list" as is */
	}

	/*
	 * The supplied names may be a pattern-list. For the -list case,
	 * the patterns are applied above. For the +list, ^list and explicit
	 * list cases we need to do it now: each pattern is expanded against
	 * "all" and merged into ret in order.
	 */
	ret = NULL;
	if ((patterns = opatterns = strdup(list)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto fail;
	}
	/* Apply positive (i.e. non-negated) patterns from the list */
	while ((cp = strsep(&patterns, ",")) != NULL) {
		if (*cp == '!') {
			/* negated matches are not supported here */
			r = SSH_ERR_INVALID_ARGUMENT;
			goto fail;
		}
		free(matching);
		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		/* kex_names_cat() skips names already present in ret */
		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(ret);
		ret = tmp;
	}
	if (ret == NULL || *ret == '\0') {
		/* An empty name-list is an error */
		/* XXX better error code? */
		r = SSH_ERR_INVALID_ARGUMENT;
		goto fail;
	}

	/* success: hand the result to the caller */
	*listp = ret;
	ret = NULL;
	r = 0;

 fail:
	/* Common cleanup; every pointer here is either owned or NULL */
	free(matching);
	free(opatterns);
	free(list);
	free(ret);
	return r;
}
316 
317 /*
318  * Fill out a proposal array with dynamically allocated values, which may
319  * be modified as required for compatibility reasons.
320  * Any of the options may be NULL, in which case the default is used.
321  * Array contents must be freed by calling kex_proposal_free_entries.
322  */
323 void
324 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
325     const char *kexalgos, const char *ciphers, const char *macs,
326     const char *comp, const char *hkalgs)
327 {
328 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
329 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
330 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
331 	u_int i;
332 	char *cp;
333 
334 	if (prop == NULL)
335 		fatal_f("proposal missing");
336 
337 	/* Append EXT_INFO signalling to KexAlgorithms */
338 	if (kexalgos == NULL)
339 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
340 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
341 	    "ext-info-s,kex-strict-s-v00@openssh.com" :
342 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
343 		fatal_f("kex_names_cat");
344 
345 	for (i = 0; i < PROPOSAL_MAX; i++) {
346 		switch(i) {
347 		case PROPOSAL_KEX_ALGS:
348 			prop[i] = compat_kex_proposal(ssh, cp);
349 			break;
350 		case PROPOSAL_ENC_ALGS_CTOS:
351 		case PROPOSAL_ENC_ALGS_STOC:
352 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
353 			break;
354 		case PROPOSAL_MAC_ALGS_CTOS:
355 		case PROPOSAL_MAC_ALGS_STOC:
356 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
357 			break;
358 		case PROPOSAL_COMP_ALGS_CTOS:
359 		case PROPOSAL_COMP_ALGS_STOC:
360 			prop[i] = xstrdup(comp ? comp : defprop[i]);
361 			break;
362 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
363 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
364 			break;
365 		default:
366 			prop[i] = xstrdup(defprop[i]);
367 		}
368 	}
369 	free(cp);
370 }
371 
372 void
373 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
374 {
375 	u_int i;
376 
377 	for (i = 0; i < PROPOSAL_MAX; i++)
378 		free(prop[i]);
379 }
380 
381 /* put algorithm proposal into buffer */
382 int
383 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
384 {
385 	u_int i;
386 	int r;
387 
388 	sshbuf_reset(b);
389 
390 	/*
391 	 * add a dummy cookie, the cookie will be overwritten by
392 	 * kex_send_kexinit(), each time a kexinit is set
393 	 */
394 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
395 		if ((r = sshbuf_put_u8(b, 0)) != 0)
396 			return r;
397 	}
398 	for (i = 0; i < PROPOSAL_MAX; i++) {
399 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
400 			return r;
401 	}
402 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
403 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
404 		return r;
405 	return 0;
406 }
407 
408 /* parse buffer and return algorithm proposal */
409 int
410 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
411 {
412 	struct sshbuf *b = NULL;
413 	u_char v;
414 	u_int i;
415 	char **proposal = NULL;
416 	int r;
417 
418 	*propp = NULL;
419 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
420 		return SSH_ERR_ALLOC_FAIL;
421 	if ((b = sshbuf_fromb(raw)) == NULL) {
422 		r = SSH_ERR_ALLOC_FAIL;
423 		goto out;
424 	}
425 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
426 		error_fr(r, "consume cookie");
427 		goto out;
428 	}
429 	/* extract kex init proposal strings */
430 	for (i = 0; i < PROPOSAL_MAX; i++) {
431 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
432 			error_fr(r, "parse proposal %u", i);
433 			goto out;
434 		}
435 		debug2("%s: %s", proposal_names[i], proposal[i]);
436 	}
437 	/* first kex follows / reserved */
438 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
439 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
440 		error_fr(r, "parse");
441 		goto out;
442 	}
443 	if (first_kex_follows != NULL)
444 		*first_kex_follows = v;
445 	debug2("first_kex_follows %d ", v);
446 	debug2("reserved %u ", i);
447 	r = 0;
448 	*propp = proposal;
449  out:
450 	if (r != 0 && proposal != NULL)
451 		kex_prop_free(proposal);
452 	sshbuf_free(b);
453 	return r;
454 }
455 
456 void
457 kex_prop_free(char **proposal)
458 {
459 	u_int i;
460 
461 	if (proposal == NULL)
462 		return;
463 	for (i = 0; i < PROPOSAL_MAX; i++)
464 		free(proposal[i]);
465 	free(proposal);
466 }
467 
468 int
469 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
470 {
471 	int r;
472 
473 	/* If in strict mode, any unexpected message is an error */
474 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
475 		ssh_packet_disconnect(ssh, "strict KEX violation: "
476 		    "unexpected packet type %u (seqnr %u)", type, seq);
477 	}
478 	error_f("type %u seq %u", type, seq);
479 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
480 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
481 	    (r = sshpkt_send(ssh)) != 0)
482 		return r;
483 	return 0;
484 }
485 
486 static void
487 kex_reset_dispatch(struct ssh *ssh)
488 {
489 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
490 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
491 }
492 
493 void
494 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
495 {
496 	char *alg, *oalgs, *algs, *sigalgs;
497 	const char *sigalg;
498 
499 	/*
500 	 * NB. allowed algorithms may contain certificate algorithms that
501 	 * map to a specific plain signature type, e.g.
502 	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
503 	 * We need to be careful here to match these, retain the mapping
504 	 * and only add each signature algorithm once.
505 	 */
506 	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
507 		fatal_f("sshkey_alg_list failed");
508 	oalgs = algs = xstrdup(allowed_algs);
509 	free(ssh->kex->server_sig_algs);
510 	ssh->kex->server_sig_algs = NULL;
511 	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
512 	    (alg = strsep(&algs, ","))) {
513 		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
514 			continue;
515 		if (!has_any_alg(sigalg, sigalgs))
516 			continue;
517 		/* Don't add an algorithm twice. */
518 		if (ssh->kex->server_sig_algs != NULL &&
519 		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
520 			continue;
521 		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
522 	}
523 	free(oalgs);
524 	free(sigalgs);
525 	if (ssh->kex->server_sig_algs == NULL)
526 		ssh->kex->server_sig_algs = xstrdup("");
527 }
528 
529 static int
530 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
531 {
532 	int r;
533 
534 	if (ssh->kex->server_sig_algs == NULL &&
535 	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
536 		return SSH_ERR_ALLOC_FAIL;
537 	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
538 	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
539 	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
540 	    (r = sshbuf_put_cstring(m,
541 	    "publickey-hostbound@openssh.com")) != 0 ||
542 	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
543 	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
544 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
545 		error_fr(r, "compose");
546 		return r;
547 	}
548 	return 0;
549 }
550 
/*
 * Append the client's EXT_INFO extensions to "m": a single
 * ext-info-in-auth@openssh.com entry with value "0".
 * Returns 0 or a ssherr.h error code.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	if ((r = sshbuf_put_u32(m, 1)) == 0 &&
	    (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) == 0 &&
	    (r = sshbuf_put_cstring(m, "0")) == 0)
		return 0;
	error_fr(r, "compose");
	return r;
}
567 
568 static int
569 kex_maybe_send_ext_info(struct ssh *ssh)
570 {
571 	int r;
572 	struct sshbuf *m = NULL;
573 
574 	if ((ssh->kex->flags & KEX_INITIAL) == 0)
575 		return 0;
576 	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
577 		return 0;
578 
579 	/* Compose EXT_INFO packet. */
580 	if ((m = sshbuf_new()) == NULL)
581 		fatal_f("sshbuf_new failed");
582 	if (ssh->kex->ext_info_c &&
583 	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
584 		goto fail;
585 	if (ssh->kex->ext_info_s &&
586 	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
587 		goto fail;
588 
589 	/* Send the actual KEX_INFO packet */
590 	debug("Sending SSH2_MSG_EXT_INFO");
591 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
592 	    (r = sshpkt_putb(ssh, m)) != 0 ||
593 	    (r = sshpkt_send(ssh)) != 0) {
594 		error_f("send EXT_INFO");
595 		goto fail;
596 	}
597 
598 	r = 0;
599 
600  fail:
601 	sshbuf_free(m);
602 	return r;
603 }
604 
605 int
606 kex_server_update_ext_info(struct ssh *ssh)
607 {
608 	int r;
609 
610 	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
611 		return 0;
612 
613 	debug_f("Sending SSH2_MSG_EXT_INFO");
614 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
615 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
616 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
617 	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
618 	    (r = sshpkt_send(ssh)) != 0) {
619 		error_f("send EXT_INFO");
620 		return r;
621 	}
622 	return 0;
623 }
624 
625 int
626 kex_send_newkeys(struct ssh *ssh)
627 {
628 	int r;
629 
630 	kex_reset_dispatch(ssh);
631 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
632 	    (r = sshpkt_send(ssh)) != 0)
633 		return r;
634 	debug("SSH2_MSG_NEWKEYS sent");
635 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
636 	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
637 		return r;
638 	debug("expecting SSH2_MSG_NEWKEYS");
639 	return 0;
640 }
641 
642 /* Check whether an ext_info value contains the expected version string */
643 static int
644 kex_ext_info_check_ver(struct kex *kex, const char *name,
645     const u_char *val, size_t len, const char *want_ver, u_int flag)
646 {
647 	if (memchr(val, '\0', len) != NULL) {
648 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
649 		return SSH_ERR_INVALID_FORMAT;
650 	}
651 	debug_f("%s=<%s>", name, val);
652 	if (strcmp(val, want_ver) == 0)
653 		kex->flags |= flag;
654 	else
655 		debug_f("unsupported version of %s extension", name);
656 	return 0;
657 }
658 
659 static int
660 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
661     const u_char *value, size_t vlen)
662 {
663 	int r;
664 
665 	/* NB. some messages are only accepted in the initial EXT_INFO */
666 	if (strcmp(name, "server-sig-algs") == 0) {
667 		/* Ensure no \0 lurking in value */
668 		if (memchr(value, '\0', vlen) != NULL) {
669 			error_f("nul byte in %s", name);
670 			return SSH_ERR_INVALID_FORMAT;
671 		}
672 		debug_f("%s=<%s>", name, value);
673 		free(ssh->kex->server_sig_algs);
674 		ssh->kex->server_sig_algs = xstrdup((const char *)value);
675 	} else if (ssh->kex->ext_info_received == 1 &&
676 	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
677 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
678 		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
679 			return r;
680 		}
681 	} else if (ssh->kex->ext_info_received == 1 &&
682 	    strcmp(name, "ping@openssh.com") == 0) {
683 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
684 		    "0", KEX_HAS_PING)) != 0) {
685 			return r;
686 		}
687 	} else
688 		debug_f("%s (unrecognised)", name);
689 
690 	return 0;
691 }
692 
693 static int
694 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
695     const u_char *value, size_t vlen)
696 {
697 	int r;
698 
699 	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
700 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
701 		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
702 			return r;
703 		}
704 	} else
705 		debug_f("%s (unrecognised)", name);
706 	return 0;
707 }
708 
709 int
710 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
711 {
712 	struct kex *kex = ssh->kex;
713 	const int max_ext_info = kex->server ? 1 : 2;
714 	u_int32_t i, ninfo;
715 	char *name;
716 	u_char *val;
717 	size_t vlen;
718 	int r;
719 
720 	debug("SSH2_MSG_EXT_INFO received");
721 	if (++kex->ext_info_received > max_ext_info) {
722 		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
723 		return dispatch_protocol_error(type, seq, ssh);
724 	}
725 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
726 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
727 		return r;
728 	if (ninfo >= 1024) {
729 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
730 		    "<=1024, received %u", ninfo);
731 		return dispatch_protocol_error(type, seq, ssh);
732 	}
733 	for (i = 0; i < ninfo; i++) {
734 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
735 			return r;
736 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
737 			free(name);
738 			return r;
739 		}
740 		debug3_f("extension %s", name);
741 		if (kex->server) {
742 			if ((r = kex_ext_info_server_parse(ssh, name,
743 			    val, vlen)) != 0)
744 				return r;
745 		} else {
746 			if ((r = kex_ext_info_client_parse(ssh, name,
747 			    val, vlen)) != 0)
748 				return r;
749 		}
750 		free(name);
751 		free(val);
752 	}
753 	return sshpkt_get_end(ssh);
754 }
755 
/*
 * Dispatch handler for the peer's SSH2_MSG_NEWKEYS.  Activates the
 * incoming keys via ssh_set_newkeys(MODE_IN), marks the exchange done
 * and clears per-exchange state (KEX_INITIAL, KEX_INIT_SENT, kex->peer,
 * kex->name) so a later KEXINIT can start a re-exchange.
 * Returns 0 or a ssherr.h error code.
 */
static int
kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	int r;

	debug("SSH2_MSG_NEWKEYS received");
	/*
	 * After the initial exchange, accept EXT_INFO when ext_info_c is set
	 * (presumably: the client advertised ext-info-c - confirm in
	 * kex_choose_conf()).
	 */
	if (kex->ext_info_c && (kex->flags & KEX_INITIAL) != 0)
		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
	/* Another NEWKEYS is now unexpected; a new KEXINIT starts a rekey */
	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
	if ((r = sshpkt_get_end(ssh)) != 0)
		return r;
	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
		return r;
	kex->done = 1;
	kex->flags &= ~KEX_INITIAL;
	sshbuf_reset(kex->peer);
	/* sshbuf_reset(kex->my); */
	kex->flags &= ~KEX_INIT_SENT;
	free(kex->name);
	kex->name = NULL;
	return 0;
}
780 
781 int
782 kex_send_kexinit(struct ssh *ssh)
783 {
784 	u_char *cookie;
785 	struct kex *kex = ssh->kex;
786 	int r;
787 
788 	if (kex == NULL) {
789 		error_f("no kex");
790 		return SSH_ERR_INTERNAL_ERROR;
791 	}
792 	if (kex->flags & KEX_INIT_SENT)
793 		return 0;
794 	kex->done = 0;
795 
796 	/* generate a random cookie */
797 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
798 		error_f("bad kex length: %zu < %d",
799 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
800 		return SSH_ERR_INVALID_FORMAT;
801 	}
802 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
803 		error_f("buffer error");
804 		return SSH_ERR_INTERNAL_ERROR;
805 	}
806 	arc4random_buf(cookie, KEX_COOKIE_LEN);
807 
808 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
809 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
810 	    (r = sshpkt_send(ssh)) != 0) {
811 		error_fr(r, "compose reply");
812 		return r;
813 	}
814 	debug("SSH2_MSG_KEXINIT sent");
815 	kex->flags |= KEX_INIT_SENT;
816 	return 0;
817 }
818 
/*
 * Dispatch handler for the peer's SSH2_MSG_KEXINIT.  Appends the raw
 * packet payload to kex->peer (presumably parsed later by
 * kex_choose_conf()), walks and discards the packet fields so that a
 * malformed KEXINIT is rejected, sends our own KEXINIT if not yet sent,
 * negotiates algorithms and finally runs the method-specific handler from
 * kex->kex[kex->kex_type].  Returns 0 or a ssherr.h error code.
 */
int
kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	const u_char *ptr;
	u_int i;
	size_t dlen;
	int r;

	debug("SSH2_MSG_KEXINIT received");
	if (kex == NULL) {
		error_f("no kex");
		return SSH_ERR_INTERNAL_ERROR;
	}
	/* Further KEXINITs are unexpected until kex_input_newkeys() re-arms */
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
	/* Keep a copy of the peer's KEXINIT payload */
	ptr = sshpkt_ptr(ssh, &dlen);
	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
		return r;

	/* discard packet */
	for (i = 0; i < KEX_COOKIE_LEN; i++) {
		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
			error_fr(r, "discard cookie");
			return r;
		}
	}
	for (i = 0; i < PROPOSAL_MAX; i++) {
		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
			error_fr(r, "discard proposal");
			return r;
		}
	}
	/*
	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
	 * KEX method has the server move first, but a server might be using
	 * a custom method or one that we otherwise don't support. We should
	 * be prepared to remember first_kex_follows here so we can eat a
	 * packet later.
	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
	 * for cases where the server *doesn't* go first. I guess we should
	 * ignore it when it is set for these cases, which is what we do now.
	 */
	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
	    (r = sshpkt_get_end(ssh)) != 0)
			return r;

	/* Peer initiated the exchange: reply with our own KEXINIT */
	if (!(kex->flags & KEX_INIT_SENT))
		if ((r = kex_send_kexinit(ssh)) != 0)
			return r;
	if ((r = kex_choose_conf(ssh, seq)) != 0)
		return r;

	/* Hand off to the negotiated method's implementation */
	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
		return (kex->kex[kex->kex_type])(ssh);

	error_f("unknown kex type %u", kex->kex_type);
	return SSH_ERR_INTERNAL_ERROR;
}
878 
879 struct kex *
880 kex_new(void)
881 {
882 	struct kex *kex;
883 
884 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
885 	    (kex->peer = sshbuf_new()) == NULL ||
886 	    (kex->my = sshbuf_new()) == NULL ||
887 	    (kex->client_version = sshbuf_new()) == NULL ||
888 	    (kex->server_version = sshbuf_new()) == NULL ||
889 	    (kex->session_id = sshbuf_new()) == NULL) {
890 		kex_free(kex);
891 		return NULL;
892 	}
893 	return kex;
894 }
895 
896 void
897 kex_free_newkeys(struct newkeys *newkeys)
898 {
899 	if (newkeys == NULL)
900 		return;
901 	if (newkeys->enc.key) {
902 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
903 		free(newkeys->enc.key);
904 		newkeys->enc.key = NULL;
905 	}
906 	if (newkeys->enc.iv) {
907 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
908 		free(newkeys->enc.iv);
909 		newkeys->enc.iv = NULL;
910 	}
911 	free(newkeys->enc.name);
912 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
913 	free(newkeys->comp.name);
914 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
915 	mac_clear(&newkeys->mac);
916 	if (newkeys->mac.key) {
917 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
918 		free(newkeys->mac.key);
919 		newkeys->mac.key = NULL;
920 	}
921 	free(newkeys->mac.name);
922 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
923 	freezero(newkeys, sizeof(*newkeys));
924 }
925 
926 void
927 kex_free(struct kex *kex)
928 {
929 	u_int mode;
930 
931 	if (kex == NULL)
932 		return;
933 
934 #ifdef WITH_OPENSSL
935 	DH_free(kex->dh);
936 	EC_KEY_free(kex->ec_client_key);
937 #endif
938 	for (mode = 0; mode < MODE_MAX; mode++) {
939 		kex_free_newkeys(kex->newkeys[mode]);
940 		kex->newkeys[mode] = NULL;
941 	}
942 	sshbuf_free(kex->peer);
943 	sshbuf_free(kex->my);
944 	sshbuf_free(kex->client_version);
945 	sshbuf_free(kex->server_version);
946 	sshbuf_free(kex->client_pub);
947 	sshbuf_free(kex->session_id);
948 	sshbuf_free(kex->initial_sig);
949 	sshkey_free(kex->initial_hostkey);
950 	free(kex->failed_choice);
951 	free(kex->hostkey_alg);
952 	free(kex->name);
953 	free(kex);
954 }
955 
956 int
957 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
958 {
959 	int r;
960 
961 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
962 		return r;
963 	ssh->kex->flags = KEX_INITIAL;
964 	kex_reset_dispatch(ssh);
965 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
966 	return 0;
967 }
968 
969 int
970 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
971 {
972 	int r;
973 
974 	if ((r = kex_ready(ssh, proposal)) != 0)
975 		return r;
976 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
977 		kex_free(ssh->kex);
978 		ssh->kex = NULL;
979 		return r;
980 	}
981 	return 0;
982 }
983 
984 /*
985  * Request key re-exchange, returns 0 on success or a ssherr.h error
986  * code otherwise. Must not be called if KEX is incomplete or in-progress.
987  */
988 int
989 kex_start_rekex(struct ssh *ssh)
990 {
991 	if (ssh->kex == NULL) {
992 		error_f("no kex");
993 		return SSH_ERR_INTERNAL_ERROR;
994 	}
995 	if (ssh->kex->done == 0) {
996 		error_f("requested twice");
997 		return SSH_ERR_INTERNAL_ERROR;
998 	}
999 	ssh->kex->done = 0;
1000 	return kex_send_kexinit(ssh);
1001 }
1002 
1003 static int
1004 choose_enc(struct sshenc *enc, char *client, char *server)
1005 {
1006 	char *name = match_list(client, server, NULL);
1007 
1008 	if (name == NULL)
1009 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1010 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1011 		error_f("unsupported cipher %s", name);
1012 		free(name);
1013 		return SSH_ERR_INTERNAL_ERROR;
1014 	}
1015 	enc->name = name;
1016 	enc->enabled = 0;
1017 	enc->iv = NULL;
1018 	enc->iv_len = cipher_ivlen(enc->cipher);
1019 	enc->key = NULL;
1020 	enc->key_len = cipher_keylen(enc->cipher);
1021 	enc->block_size = cipher_blocksize(enc->cipher);
1022 	return 0;
1023 }
1024 
1025 static int
1026 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1027 {
1028 	char *name = match_list(client, server, NULL);
1029 
1030 	if (name == NULL)
1031 		return SSH_ERR_NO_MAC_ALG_MATCH;
1032 	if (mac_setup(mac, name) < 0) {
1033 		error_f("unsupported MAC %s", name);
1034 		free(name);
1035 		return SSH_ERR_INTERNAL_ERROR;
1036 	}
1037 	mac->name = name;
1038 	mac->key = NULL;
1039 	mac->enabled = 0;
1040 	return 0;
1041 }
1042 
1043 static int
1044 choose_comp(struct sshcomp *comp, char *client, char *server)
1045 {
1046 	char *name = match_list(client, server, NULL);
1047 
1048 	if (name == NULL)
1049 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1050 #ifdef WITH_ZLIB
1051 	if (strcmp(name, "zlib@openssh.com") == 0) {
1052 		comp->type = COMP_DELAYED;
1053 	} else if (strcmp(name, "zlib") == 0) {
1054 		comp->type = COMP_ZLIB;
1055 	} else
1056 #endif	/* WITH_ZLIB */
1057 	if (strcmp(name, "none") == 0) {
1058 		comp->type = COMP_NONE;
1059 	} else {
1060 		error_f("unsupported compression scheme %s", name);
1061 		free(name);
1062 		return SSH_ERR_INTERNAL_ERROR;
1063 	}
1064 	comp->name = name;
1065 	return 0;
1066 }
1067 
1068 static int
1069 choose_kex(struct kex *k, char *client, char *server)
1070 {
1071 	const struct kexalg *kexalg;
1072 
1073 	k->name = match_list(client, server, NULL);
1074 
1075 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1076 	if (k->name == NULL)
1077 		return SSH_ERR_NO_KEX_ALG_MATCH;
1078 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1079 		error_f("unsupported KEX method %s", k->name);
1080 		return SSH_ERR_INTERNAL_ERROR;
1081 	}
1082 	k->kex_type = kexalg->type;
1083 	k->hash_alg = kexalg->hash_alg;
1084 	k->ec_nid = kexalg->ec_nid;
1085 	return 0;
1086 }
1087 
1088 static int
1089 choose_hostkeyalg(struct kex *k, char *client, char *server)
1090 {
1091 	free(k->hostkey_alg);
1092 	k->hostkey_alg = match_list(client, server, NULL);
1093 
1094 	debug("kex: host key algorithm: %s",
1095 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1096 	if (k->hostkey_alg == NULL)
1097 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1098 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1099 	if (k->hostkey_type == KEY_UNSPEC) {
1100 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1101 		return SSH_ERR_INTERNAL_ERROR;
1102 	}
1103 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1104 	return 0;
1105 }
1106 
1107 static int
1108 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1109 {
1110 	static int check[] = {
1111 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1112 	};
1113 	int *idx;
1114 	char *p;
1115 
1116 	for (idx = &check[0]; *idx != -1; idx++) {
1117 		if ((p = strchr(my[*idx], ',')) != NULL)
1118 			*p = '\0';
1119 		if ((p = strchr(peer[*idx], ',')) != NULL)
1120 			*p = '\0';
1121 		if (strcmp(my[*idx], peer[*idx]) != 0) {
1122 			debug2("proposal mismatch: my %s peer %s",
1123 			    my[*idx], peer[*idx]);
1124 			return (0);
1125 		}
1126 	}
1127 	debug2("proposals match");
1128 	return (1);
1129 }
1130 
1131 static int
1132 kexalgs_contains(char **peer, const char *ext)
1133 {
1134 	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1135 }
1136 
1137 static int
1138 kex_choose_conf(struct ssh *ssh, uint32_t seq)
1139 {
1140 	struct kex *kex = ssh->kex;
1141 	struct newkeys *newkeys;
1142 	char **my = NULL, **peer = NULL;
1143 	char **cprop, **sprop;
1144 	int nenc, nmac, ncomp;
1145 	u_int mode, ctos, need, dh_need, authlen;
1146 	int r, first_kex_follows;
1147 
1148 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1149 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1150 		goto out;
1151 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1152 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1153 		goto out;
1154 
1155 	if (kex->server) {
1156 		cprop=peer;
1157 		sprop=my;
1158 	} else {
1159 		cprop=my;
1160 		sprop=peer;
1161 	}
1162 
1163 	/* Check whether peer supports ext_info/kex_strict */
1164 	if ((kex->flags & KEX_INITIAL) != 0) {
1165 		if (kex->server) {
1166 			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1167 			kex->kex_strict = kexalgs_contains(peer,
1168 			    "kex-strict-c-v00@openssh.com");
1169 		} else {
1170 			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
1171 			kex->kex_strict = kexalgs_contains(peer,
1172 			    "kex-strict-s-v00@openssh.com");
1173 		}
1174 		if (kex->kex_strict) {
1175 			debug3_f("will use strict KEX ordering");
1176 			if (seq != 0)
1177 				ssh_packet_disconnect(ssh,
1178 				    "strict KEX violation: "
1179 				    "KEXINIT was not the first packet");
1180 		}
1181 	}
1182 
1183 	/* Check whether client supports rsa-sha2 algorithms */
1184 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1185 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1186 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1187 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1188 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1189 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1190 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1191 	}
1192 
1193 	/* Algorithm Negotiation */
1194 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1195 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1196 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1197 		peer[PROPOSAL_KEX_ALGS] = NULL;
1198 		goto out;
1199 	}
1200 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1201 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1202 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1203 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1204 		goto out;
1205 	}
1206 	for (mode = 0; mode < MODE_MAX; mode++) {
1207 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1208 			r = SSH_ERR_ALLOC_FAIL;
1209 			goto out;
1210 		}
1211 		kex->newkeys[mode] = newkeys;
1212 		ctos = (!kex->server && mode == MODE_OUT) ||
1213 		    (kex->server && mode == MODE_IN);
1214 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1215 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1216 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1217 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1218 		    sprop[nenc])) != 0) {
1219 			kex->failed_choice = peer[nenc];
1220 			peer[nenc] = NULL;
1221 			goto out;
1222 		}
1223 		authlen = cipher_authlen(newkeys->enc.cipher);
1224 		/* ignore mac for authenticated encryption */
1225 		if (authlen == 0 &&
1226 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1227 		    sprop[nmac])) != 0) {
1228 			kex->failed_choice = peer[nmac];
1229 			peer[nmac] = NULL;
1230 			goto out;
1231 		}
1232 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1233 		    sprop[ncomp])) != 0) {
1234 			kex->failed_choice = peer[ncomp];
1235 			peer[ncomp] = NULL;
1236 			goto out;
1237 		}
1238 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1239 		    ctos ? "client->server" : "server->client",
1240 		    newkeys->enc.name,
1241 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1242 		    newkeys->comp.name);
1243 	}
1244 	need = dh_need = 0;
1245 	for (mode = 0; mode < MODE_MAX; mode++) {
1246 		newkeys = kex->newkeys[mode];
1247 		need = MAXIMUM(need, newkeys->enc.key_len);
1248 		need = MAXIMUM(need, newkeys->enc.block_size);
1249 		need = MAXIMUM(need, newkeys->enc.iv_len);
1250 		need = MAXIMUM(need, newkeys->mac.key_len);
1251 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1252 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1253 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1254 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1255 	}
1256 	/* XXX need runden? */
1257 	kex->we_need = need;
1258 	kex->dh_need = dh_need;
1259 
1260 	/* ignore the next message if the proposals do not match */
1261 	if (first_kex_follows && !proposals_match(my, peer))
1262 		ssh->dispatch_skip_packets = 1;
1263 	r = 0;
1264  out:
1265 	kex_prop_free(my);
1266 	kex_prop_free(peer);
1267 	return r;
1268 }
1269 
1270 static int
1271 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1272     const struct sshbuf *shared_secret, u_char **keyp)
1273 {
1274 	struct kex *kex = ssh->kex;
1275 	struct ssh_digest_ctx *hashctx = NULL;
1276 	char c = id;
1277 	u_int have;
1278 	size_t mdsz;
1279 	u_char *digest;
1280 	int r;
1281 
1282 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1283 		return SSH_ERR_INVALID_ARGUMENT;
1284 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1285 		r = SSH_ERR_ALLOC_FAIL;
1286 		goto out;
1287 	}
1288 
1289 	/* K1 = HASH(K || H || "A" || session_id) */
1290 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1291 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1292 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1293 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1294 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1295 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1296 		r = SSH_ERR_LIBCRYPTO_ERROR;
1297 		error_f("KEX hash failed");
1298 		goto out;
1299 	}
1300 	ssh_digest_free(hashctx);
1301 	hashctx = NULL;
1302 
1303 	/*
1304 	 * expand key:
1305 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1306 	 * Key = K1 || K2 || ... || Kn
1307 	 */
1308 	for (have = mdsz; need > have; have += mdsz) {
1309 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1310 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1311 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1312 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1313 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1314 			error_f("KDF failed");
1315 			r = SSH_ERR_LIBCRYPTO_ERROR;
1316 			goto out;
1317 		}
1318 		ssh_digest_free(hashctx);
1319 		hashctx = NULL;
1320 	}
1321 #ifdef DEBUG_KEX
1322 	fprintf(stderr, "key '%c'== ", c);
1323 	dump_digest("key", digest, need);
1324 #endif
1325 	*keyp = digest;
1326 	digest = NULL;
1327 	r = 0;
1328  out:
1329 	free(digest);
1330 	ssh_digest_free(hashctx);
1331 	return r;
1332 }
1333 
1334 #define NKEYS	6
1335 int
1336 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1337     const struct sshbuf *shared_secret)
1338 {
1339 	struct kex *kex = ssh->kex;
1340 	u_char *keys[NKEYS];
1341 	u_int i, j, mode, ctos;
1342 	int r;
1343 
1344 	/* save initial hash as session id */
1345 	if ((kex->flags & KEX_INITIAL) != 0) {
1346 		if (sshbuf_len(kex->session_id) != 0) {
1347 			error_f("already have session ID at kex");
1348 			return SSH_ERR_INTERNAL_ERROR;
1349 		}
1350 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1351 			return r;
1352 	} else if (sshbuf_len(kex->session_id) == 0) {
1353 		error_f("no session ID in rekex");
1354 		return SSH_ERR_INTERNAL_ERROR;
1355 	}
1356 	for (i = 0; i < NKEYS; i++) {
1357 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1358 		    shared_secret, &keys[i])) != 0) {
1359 			for (j = 0; j < i; j++)
1360 				free(keys[j]);
1361 			return r;
1362 		}
1363 	}
1364 	for (mode = 0; mode < MODE_MAX; mode++) {
1365 		ctos = (!kex->server && mode == MODE_OUT) ||
1366 		    (kex->server && mode == MODE_IN);
1367 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1368 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1369 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1370 	}
1371 	return 0;
1372 }
1373 
1374 int
1375 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1376 {
1377 	struct kex *kex = ssh->kex;
1378 
1379 	*pubp = NULL;
1380 	*prvp = NULL;
1381 	if (kex->load_host_public_key == NULL ||
1382 	    kex->load_host_private_key == NULL) {
1383 		error_f("missing hostkey loader");
1384 		return SSH_ERR_INVALID_ARGUMENT;
1385 	}
1386 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1387 	    kex->hostkey_nid, ssh);
1388 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1389 	    kex->hostkey_nid, ssh);
1390 	if (*pubp == NULL)
1391 		return SSH_ERR_NO_HOSTKEY_LOADED;
1392 	return 0;
1393 }
1394 
1395 int
1396 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1397 {
1398 	struct kex *kex = ssh->kex;
1399 
1400 	if (kex->verify_host_key == NULL) {
1401 		error_f("missing hostkey verifier");
1402 		return SSH_ERR_INVALID_ARGUMENT;
1403 	}
1404 	if (server_host_key->type != kex->hostkey_type ||
1405 	    (kex->hostkey_type == KEY_ECDSA &&
1406 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1407 		return SSH_ERR_KEY_TYPE_MISMATCH;
1408 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1409 		return  SSH_ERR_SIGNATURE_INVALID;
1410 	return 0;
1411 }
1412 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: print a label line, then dump len bytes of digest to stderr. */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1421 
1422 /*
1423  * Send a plaintext error message to the peer, suffixed by \r\n.
1424  * Only used during banner exchange, and there only for the server.
1425  */
1426 static void
1427 send_error(struct ssh *ssh, char *msg)
1428 {
1429 	char *crnl = "\r\n";
1430 
1431 	if (!ssh->kex->server)
1432 		return;
1433 
1434 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1435 	    msg, strlen(msg)) != strlen(msg) ||
1436 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1437 	    crnl, strlen(crnl)) != strlen(crnl))
1438 		error_f("write: %.100s", strerror(errno));
1439 }
1440 
1441 /*
1442  * Sends our identification string and waits for the peer's. Will block for
1443  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1444  * Returns on 0 success or a ssherr.h code on failure.
1445  */
1446 int
1447 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1448     const char *version_addendum)
1449 {
1450 	int remote_major, remote_minor, mismatch, oerrno = 0;
1451 	size_t len, n;
1452 	int r, expect_nl;
1453 	u_char c;
1454 	struct sshbuf *our_version = ssh->kex->server ?
1455 	    ssh->kex->server_version : ssh->kex->client_version;
1456 	struct sshbuf *peer_version = ssh->kex->server ?
1457 	    ssh->kex->client_version : ssh->kex->server_version;
1458 	char *our_version_string = NULL, *peer_version_string = NULL;
1459 	char *cp, *remote_version = NULL;
1460 
1461 	/* Prepare and send our banner */
1462 	sshbuf_reset(our_version);
1463 	if (version_addendum != NULL && *version_addendum == '\0')
1464 		version_addendum = NULL;
1465 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1466 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1467 	    version_addendum == NULL ? "" : " ",
1468 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1469 		oerrno = errno;
1470 		error_fr(r, "sshbuf_putf");
1471 		goto out;
1472 	}
1473 
1474 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1475 	    sshbuf_mutable_ptr(our_version),
1476 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1477 		oerrno = errno;
1478 		debug_f("write: %.100s", strerror(errno));
1479 		r = SSH_ERR_SYSTEM_ERROR;
1480 		goto out;
1481 	}
1482 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1483 		oerrno = errno;
1484 		error_fr(r, "sshbuf_consume_end");
1485 		goto out;
1486 	}
1487 	our_version_string = sshbuf_dup_string(our_version);
1488 	if (our_version_string == NULL) {
1489 		error_f("sshbuf_dup_string failed");
1490 		r = SSH_ERR_ALLOC_FAIL;
1491 		goto out;
1492 	}
1493 	debug("Local version string %.100s", our_version_string);
1494 
1495 	/* Read other side's version identification. */
1496 	for (n = 0; ; n++) {
1497 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1498 			send_error(ssh, "No SSH identification string "
1499 			    "received.");
1500 			error_f("No SSH version received in first %u lines "
1501 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1502 			r = SSH_ERR_INVALID_FORMAT;
1503 			goto out;
1504 		}
1505 		sshbuf_reset(peer_version);
1506 		expect_nl = 0;
1507 		for (;;) {
1508 			if (timeout_ms > 0) {
1509 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1510 				    &timeout_ms, NULL);
1511 				if (r == -1 && errno == ETIMEDOUT) {
1512 					send_error(ssh, "Timed out waiting "
1513 					    "for SSH identification string.");
1514 					error("Connection timed out during "
1515 					    "banner exchange");
1516 					r = SSH_ERR_CONN_TIMEOUT;
1517 					goto out;
1518 				} else if (r == -1) {
1519 					oerrno = errno;
1520 					error_f("%s", strerror(errno));
1521 					r = SSH_ERR_SYSTEM_ERROR;
1522 					goto out;
1523 				}
1524 			}
1525 
1526 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1527 			    &c, 1);
1528 			if (len != 1 && errno == EPIPE) {
1529 				verbose_f("Connection closed by remote host");
1530 				r = SSH_ERR_CONN_CLOSED;
1531 				goto out;
1532 			} else if (len != 1) {
1533 				oerrno = errno;
1534 				error_f("read: %.100s", strerror(errno));
1535 				r = SSH_ERR_SYSTEM_ERROR;
1536 				goto out;
1537 			}
1538 			if (c == '\r') {
1539 				expect_nl = 1;
1540 				continue;
1541 			}
1542 			if (c == '\n')
1543 				break;
1544 			if (c == '\0' || expect_nl) {
1545 				verbose_f("banner line contains invalid "
1546 				    "characters");
1547 				goto invalid;
1548 			}
1549 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1550 				oerrno = errno;
1551 				error_fr(r, "sshbuf_put");
1552 				goto out;
1553 			}
1554 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1555 				verbose_f("banner line too long");
1556 				goto invalid;
1557 			}
1558 		}
1559 		/* Is this an actual protocol banner? */
1560 		if (sshbuf_len(peer_version) > 4 &&
1561 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1562 			break;
1563 		/* If not, then just log the line and continue */
1564 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1565 			error_f("sshbuf_dup_string failed");
1566 			r = SSH_ERR_ALLOC_FAIL;
1567 			goto out;
1568 		}
1569 		/* Do not accept lines before the SSH ident from a client */
1570 		if (ssh->kex->server) {
1571 			verbose_f("client sent invalid protocol identifier "
1572 			    "\"%.256s\"", cp);
1573 			free(cp);
1574 			goto invalid;
1575 		}
1576 		debug_f("banner line %zu: %s", n, cp);
1577 		free(cp);
1578 	}
1579 	peer_version_string = sshbuf_dup_string(peer_version);
1580 	if (peer_version_string == NULL)
1581 		fatal_f("sshbuf_dup_string failed");
1582 	/* XXX must be same size for sscanf */
1583 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1584 		error_f("calloc failed");
1585 		r = SSH_ERR_ALLOC_FAIL;
1586 		goto out;
1587 	}
1588 
1589 	/*
1590 	 * Check that the versions match.  In future this might accept
1591 	 * several versions and set appropriate flags to handle them.
1592 	 */
1593 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1594 	    &remote_major, &remote_minor, remote_version) != 3) {
1595 		error("Bad remote protocol version identification: '%.100s'",
1596 		    peer_version_string);
1597  invalid:
1598 		send_error(ssh, "Invalid SSH identification string.");
1599 		r = SSH_ERR_INVALID_FORMAT;
1600 		goto out;
1601 	}
1602 	debug("Remote protocol version %d.%d, remote software version %.100s",
1603 	    remote_major, remote_minor, remote_version);
1604 	compat_banner(ssh, remote_version);
1605 
1606 	mismatch = 0;
1607 	switch (remote_major) {
1608 	case 2:
1609 		break;
1610 	case 1:
1611 		if (remote_minor != 99)
1612 			mismatch = 1;
1613 		break;
1614 	default:
1615 		mismatch = 1;
1616 		break;
1617 	}
1618 	if (mismatch) {
1619 		error("Protocol major versions differ: %d vs. %d",
1620 		    PROTOCOL_MAJOR_2, remote_major);
1621 		send_error(ssh, "Protocol major versions differ.");
1622 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1623 		goto out;
1624 	}
1625 
1626 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1627 		logit("probed from %s port %d with %s.  Don't panic.",
1628 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1629 		    peer_version_string);
1630 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1631 		goto out;
1632 	}
1633 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1634 		logit("scanned from %s port %d with %s.  Don't panic.",
1635 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1636 		    peer_version_string);
1637 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1638 		goto out;
1639 	}
1640 	/* success */
1641 	r = 0;
1642  out:
1643 	free(our_version_string);
1644 	free(peer_version_string);
1645 	free(remote_version);
1646 	if (r == SSH_ERR_SYSTEM_ERROR)
1647 		errno = oerrno;
1648 	return r;
1649 }
1650 
1651