xref: /openbsd-src/usr.bin/ssh/kex.c (revision ff0e7be1ebbcc809ea8ad2b6dafe215824da9e46)
1 /* $OpenBSD: kex.c,v 1.178 2023/03/12 10:40:39 dtucker Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <sys/types.h>
28 #include <errno.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34 #include <poll.h>
35 
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #endif
39 
40 #include "ssh.h"
41 #include "ssh2.h"
42 #include "atomicio.h"
43 #include "version.h"
44 #include "packet.h"
45 #include "compat.h"
46 #include "cipher.h"
47 #include "sshkey.h"
48 #include "kex.h"
49 #include "log.h"
50 #include "mac.h"
51 #include "match.h"
52 #include "misc.h"
53 #include "dispatch.h"
54 #include "monitor.h"
55 #include "myproposal.h"
56 
57 #include "ssherr.h"
58 #include "sshbuf.h"
59 #include "digest.h"
60 #include "xmalloc.h"
61 
/* Forward declarations for static functions defined later in this file. */
static int kex_choose_conf(struct ssh *);
static int kex_input_newkeys(int, u_int32_t, struct ssh *);
65 
66 static const char * const proposal_names[PROPOSAL_MAX] = {
67 	"KEX algorithms",
68 	"host key algorithms",
69 	"ciphers ctos",
70 	"ciphers stoc",
71 	"MACs ctos",
72 	"MACs stoc",
73 	"compression ctos",
74 	"compression stoc",
75 	"languages ctos",
76 	"languages stoc",
77 };
78 
79 struct kexalg {
80 	char *name;
81 	u_int type;
82 	int ec_nid;
83 	int hash_alg;
84 };
85 static const struct kexalg kexalgs[] = {
86 #ifdef WITH_OPENSSL
87 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
88 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
89 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
90 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
91 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
92 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
95 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
96 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
97 	    SSH_DIGEST_SHA384 },
98 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
99 	    SSH_DIGEST_SHA512 },
100 #endif
101 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
102 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
103 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
104 	    SSH_DIGEST_SHA512 },
105 	{ NULL, 0, -1, -1},
106 };
107 
108 char *
109 kex_alg_list(char sep)
110 {
111 	char *ret = NULL, *tmp;
112 	size_t nlen, rlen = 0;
113 	const struct kexalg *k;
114 
115 	for (k = kexalgs; k->name != NULL; k++) {
116 		if (ret != NULL)
117 			ret[rlen++] = sep;
118 		nlen = strlen(k->name);
119 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
120 			free(ret);
121 			return NULL;
122 		}
123 		ret = tmp;
124 		memcpy(ret + rlen, k->name, nlen + 1);
125 		rlen += nlen;
126 	}
127 	return ret;
128 }
129 
130 static const struct kexalg *
131 kex_alg_by_name(const char *name)
132 {
133 	const struct kexalg *k;
134 
135 	for (k = kexalgs; k->name != NULL; k++) {
136 		if (strcmp(k->name, name) == 0)
137 			return k;
138 	}
139 	return NULL;
140 }
141 
/*
 * Validate a comma-separated KEX method name list.
 * Returns 1 when every non-empty name is a method listed in kexalgs[],
 * and 0 otherwise (including NULL/empty input or allocation failure).
 * Empty elements such as those in "a,,b" or ",a" are skipped rather than
 * ending the scan, so every name that follows them is still checked.
 */
int
kex_names_valid(const char *names)
{
	char *s, *cp, *p;

	if (names == NULL || *names == '\0')
		return 0;
	if ((s = cp = strdup(names)) == NULL)
		return 0;
	while ((p = strsep(&cp, ",")) != NULL) {
		if (*p == '\0')
			continue;	/* tolerate empty elements, keep scanning */
		if (kex_alg_by_name(p) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", p);
			free(s);
			return 0;
		}
	}
	debug3("kex names ok: [%s]", names);
	free(s);
	return 1;
}
164 
/*
 * Concatenate algorithm name-lists a and b, appending each name from b
 * that is not already present in the result.
 * A NULL or empty argument is treated as an empty list. If a is empty, a
 * plain copy of b is returned, with no de-duplication. Empty elements in
 * b (e.g. "x,,y") are skipped instead of ending the scan.
 * Returns a newly-allocated string the caller must free, or NULL on
 * allocation failure. When both lists are non-empty it also returns NULL
 * if b is longer than 1MB.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);	/* avoid strdup(NULL) */
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of a, a comma, all of b, and the terminator. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	while ((p = strsep(&cp, ",")) != NULL) {
		if (*p == '\0')
			continue;	/* skip empty elements, keep scanning */
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
203 
/*
 * Assemble a list of algorithms from a default list and a string from a
 * configuration file. The user-provided string may begin with '+' to
 * indicate that it should be appended to the default, '-' that the
 * specified names should be removed, or '^' that they should be placed
 * at the head.
 *
 * listp: in/out. On entry *listp is the user string (may be NULL/empty,
 *        in which case a copy of def is stored). This function takes
 *        ownership of the incoming *listp and frees it. On success *listp
 *        receives a newly-allocated result; on failure it is left NULL.
 * def:   default algorithm list.
 * all:   every algorithm name; patterns are expanded against this list.
 *
 * Returns 0 on success, SSH_ERR_INVALID_ARGUMENT for NULL arguments, a
 * negated ('!') pattern or an empty resulting list, or SSH_ERR_ALLOC_FAIL.
 */
int
kex_assemble_names(char **listp, const char *def, const char *all)
{
	char *cp, *tmp, *patterns;
	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (listp == NULL || def == NULL || all == NULL)
		return SSH_ERR_INVALID_ARGUMENT;

	/* No user string: the result is simply a copy of the default. */
	if (*listp == NULL || **listp == '\0') {
		if ((*listp = strdup(def)) == NULL)
			return SSH_ERR_ALLOC_FAIL;
		return 0;
	}

	/* Take ownership of the caller's string; it is freed at "fail". */
	list = *listp;
	*listp = NULL;
	if (*list == '+') {
		/* Append names to default list */
		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else if (*list == '-') {
		/* Remove names from default list */
		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		/* filtering has already been done */
		return 0;
	} else if (*list == '^') {
		/* Place names at head of default list */
		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else {
		/* Explicit list, overrides default - just use "list" as is */
	}

	/*
	 * The supplied names may be a pattern-list. For the -list case,
	 * the patterns are applied above. For the +list and explicit list
	 * cases we need to do it now.
	 */
	ret = NULL;
	/* strsep() modifies its input, so iterate over a private copy. */
	if ((patterns = opatterns = strdup(list)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto fail;
	}
	/* Apply positive (i.e. non-negated) patterns from the list */
	while ((cp = strsep(&patterns, ",")) != NULL) {
		if (*cp == '!') {
			/* negated matches are not supported here */
			r = SSH_ERR_INVALID_ARGUMENT;
			goto fail;
		}
		/* Expand this pattern against "all", then merge into ret. */
		free(matching);
		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(ret);
		ret = tmp;
	}
	if (ret == NULL || *ret == '\0') {
		/* An empty name-list is an error */
		/* XXX better error code? */
		r = SSH_ERR_INVALID_ARGUMENT;
		goto fail;
	}

	/* success */
	*listp = ret;
	ret = NULL;
	r = 0;

	/* Shared cleanup for both success and failure paths. */
 fail:
	free(matching);
	free(opatterns);
	free(list);
	free(ret);
	return r;
}
306 
307 /*
308  * Fill out a proposal array with dynamically allocated values, which may
309  * be modified as required for compatibility reasons.
310  * Any of the options may be NULL, in which case the default is used.
311  * Array contents must be freed by calling kex_proposal_free_entries.
312  */
313 void
314 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
315     const char *kexalgos, const char *ciphers, const char *macs,
316     const char *comp, const char *hkalgs)
317 {
318 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
319 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
320 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
321 	u_int i;
322 
323 	if (prop == NULL)
324 		fatal_f("proposal missing");
325 
326 	for (i = 0; i < PROPOSAL_MAX; i++) {
327 		switch(i) {
328 		case PROPOSAL_KEX_ALGS:
329 			prop[i] = compat_kex_proposal(ssh,
330 			    kexalgos ? kexalgos : defprop[i]);
331 			break;
332 		case PROPOSAL_ENC_ALGS_CTOS:
333 		case PROPOSAL_ENC_ALGS_STOC:
334 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
335 			break;
336 		case PROPOSAL_MAC_ALGS_CTOS:
337 		case PROPOSAL_MAC_ALGS_STOC:
338 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
339 			break;
340 		case PROPOSAL_COMP_ALGS_CTOS:
341 		case PROPOSAL_COMP_ALGS_STOC:
342 			prop[i] = xstrdup(comp ? comp : defprop[i]);
343 			break;
344 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
345 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
346 			break;
347 		default:
348 			prop[i] = xstrdup(defprop[i]);
349 		}
350 	}
351 }
352 
353 void
354 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
355 {
356 	u_int i;
357 
358 	for (i = 0; i < PROPOSAL_MAX; i++)
359 		free(prop[i]);
360 }
361 
362 /* put algorithm proposal into buffer */
363 int
364 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
365 {
366 	u_int i;
367 	int r;
368 
369 	sshbuf_reset(b);
370 
371 	/*
372 	 * add a dummy cookie, the cookie will be overwritten by
373 	 * kex_send_kexinit(), each time a kexinit is set
374 	 */
375 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
376 		if ((r = sshbuf_put_u8(b, 0)) != 0)
377 			return r;
378 	}
379 	for (i = 0; i < PROPOSAL_MAX; i++) {
380 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
381 			return r;
382 	}
383 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
384 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
385 		return r;
386 	return 0;
387 }
388 
389 /* parse buffer and return algorithm proposal */
390 int
391 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
392 {
393 	struct sshbuf *b = NULL;
394 	u_char v;
395 	u_int i;
396 	char **proposal = NULL;
397 	int r;
398 
399 	*propp = NULL;
400 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
401 		return SSH_ERR_ALLOC_FAIL;
402 	if ((b = sshbuf_fromb(raw)) == NULL) {
403 		r = SSH_ERR_ALLOC_FAIL;
404 		goto out;
405 	}
406 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
407 		error_fr(r, "consume cookie");
408 		goto out;
409 	}
410 	/* extract kex init proposal strings */
411 	for (i = 0; i < PROPOSAL_MAX; i++) {
412 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
413 			error_fr(r, "parse proposal %u", i);
414 			goto out;
415 		}
416 		debug2("%s: %s", proposal_names[i], proposal[i]);
417 	}
418 	/* first kex follows / reserved */
419 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
420 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
421 		error_fr(r, "parse");
422 		goto out;
423 	}
424 	if (first_kex_follows != NULL)
425 		*first_kex_follows = v;
426 	debug2("first_kex_follows %d ", v);
427 	debug2("reserved %u ", i);
428 	r = 0;
429 	*propp = proposal;
430  out:
431 	if (r != 0 && proposal != NULL)
432 		kex_prop_free(proposal);
433 	sshbuf_free(b);
434 	return r;
435 }
436 
437 void
438 kex_prop_free(char **proposal)
439 {
440 	u_int i;
441 
442 	if (proposal == NULL)
443 		return;
444 	for (i = 0; i < PROPOSAL_MAX; i++)
445 		free(proposal[i]);
446 	free(proposal);
447 }
448 
449 int
450 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
451 {
452 	int r;
453 
454 	error("kex protocol error: type %d seq %u", type, seq);
455 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
456 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
457 	    (r = sshpkt_send(ssh)) != 0)
458 		return r;
459 	return 0;
460 }
461 
462 static void
463 kex_reset_dispatch(struct ssh *ssh)
464 {
465 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
466 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
467 }
468 
469 static int
470 kex_send_ext_info(struct ssh *ssh)
471 {
472 	int r;
473 	char *algs;
474 
475 	debug("Sending SSH2_MSG_EXT_INFO");
476 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
477 		return SSH_ERR_ALLOC_FAIL;
478 	/* XXX filter algs list by allowed pubkey/hostbased types */
479 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
480 	    (r = sshpkt_put_u32(ssh, 2)) != 0 ||
481 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
482 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
483 	    (r = sshpkt_put_cstring(ssh,
484 	    "publickey-hostbound@openssh.com")) != 0 ||
485 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
486 	    (r = sshpkt_send(ssh)) != 0) {
487 		error_fr(r, "compose");
488 		goto out;
489 	}
490 	/* success */
491 	r = 0;
492  out:
493 	free(algs);
494 	return r;
495 }
496 
497 int
498 kex_send_newkeys(struct ssh *ssh)
499 {
500 	int r;
501 
502 	kex_reset_dispatch(ssh);
503 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
504 	    (r = sshpkt_send(ssh)) != 0)
505 		return r;
506 	debug("SSH2_MSG_NEWKEYS sent");
507 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
508 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
509 		if ((r = kex_send_ext_info(ssh)) != 0)
510 			return r;
511 	debug("expecting SSH2_MSG_NEWKEYS");
512 	return 0;
513 }
514 
515 int
516 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
517 {
518 	struct kex *kex = ssh->kex;
519 	u_int32_t i, ninfo;
520 	char *name;
521 	u_char *val;
522 	size_t vlen;
523 	int r;
524 
525 	debug("SSH2_MSG_EXT_INFO received");
526 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
527 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
528 		return r;
529 	if (ninfo >= 1024) {
530 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
531 		    "<=1024, received %u", ninfo);
532 		return SSH_ERR_INVALID_FORMAT;
533 	}
534 	for (i = 0; i < ninfo; i++) {
535 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
536 			return r;
537 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
538 			free(name);
539 			return r;
540 		}
541 		if (strcmp(name, "server-sig-algs") == 0) {
542 			/* Ensure no \0 lurking in value */
543 			if (memchr(val, '\0', vlen) != NULL) {
544 				error_f("nul byte in %s", name);
545 				return SSH_ERR_INVALID_FORMAT;
546 			}
547 			debug_f("%s=<%s>", name, val);
548 			kex->server_sig_algs = val;
549 			val = NULL;
550 		} else if (strcmp(name,
551 		    "publickey-hostbound@openssh.com") == 0) {
552 			/* XXX refactor */
553 			/* Ensure no \0 lurking in value */
554 			if (memchr(val, '\0', vlen) != NULL) {
555 				error_f("nul byte in %s", name);
556 				return SSH_ERR_INVALID_FORMAT;
557 			}
558 			debug_f("%s=<%s>", name, val);
559 			if (strcmp(val, "0") == 0)
560 				kex->flags |= KEX_HAS_PUBKEY_HOSTBOUND;
561 			else {
562 				debug_f("unsupported version of %s extension",
563 				    name);
564 			}
565 		} else
566 			debug_f("%s (unrecognised)", name);
567 		free(name);
568 		free(val);
569 	}
570 	return sshpkt_get_end(ssh);
571 }
572 
573 static int
574 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
575 {
576 	struct kex *kex = ssh->kex;
577 	int r;
578 
579 	debug("SSH2_MSG_NEWKEYS received");
580 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
581 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
582 	if ((r = sshpkt_get_end(ssh)) != 0)
583 		return r;
584 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
585 		return r;
586 	kex->done = 1;
587 	kex->flags &= ~KEX_INITIAL;
588 	sshbuf_reset(kex->peer);
589 	/* sshbuf_reset(kex->my); */
590 	kex->flags &= ~KEX_INIT_SENT;
591 	free(kex->name);
592 	kex->name = NULL;
593 	return 0;
594 }
595 
596 int
597 kex_send_kexinit(struct ssh *ssh)
598 {
599 	u_char *cookie;
600 	struct kex *kex = ssh->kex;
601 	int r;
602 
603 	if (kex == NULL) {
604 		error_f("no kex");
605 		return SSH_ERR_INTERNAL_ERROR;
606 	}
607 	if (kex->flags & KEX_INIT_SENT)
608 		return 0;
609 	kex->done = 0;
610 
611 	/* generate a random cookie */
612 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
613 		error_f("bad kex length: %zu < %d",
614 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
615 		return SSH_ERR_INVALID_FORMAT;
616 	}
617 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
618 		error_f("buffer error");
619 		return SSH_ERR_INTERNAL_ERROR;
620 	}
621 	arc4random_buf(cookie, KEX_COOKIE_LEN);
622 
623 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
624 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
625 	    (r = sshpkt_send(ssh)) != 0) {
626 		error_fr(r, "compose reply");
627 		return r;
628 	}
629 	debug("SSH2_MSG_KEXINIT sent");
630 	kex->flags |= KEX_INIT_SENT;
631 	return 0;
632 }
633 
/*
 * Handle the peer's SSH2_MSG_KEXINIT. The steps are:
 *   1. Save the raw payload into kex->peer.
 *   2. Check that the packet parses (its fields are discarded here and
 *      re-parsed by kex_choose_conf()).
 *   3. Send our own KEXINIT if we have not already.
 *   4. Negotiate algorithms.
 *   5. Hand off to the kex->kex[] method for the chosen kex_type.
 * Returns that method's result or an ssherr.h code.
 */
int
kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	const u_char *ptr;
	u_int i;
	size_t dlen;
	int r;

	debug("SSH2_MSG_KEXINIT received");
	if (kex == NULL) {
		error_f("no kex");
		return SSH_ERR_INTERNAL_ERROR;
	}
	/* No further KEXINIT until kex_input_newkeys() re-arms this handler. */
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
	/* Keep a copy of the unread payload for kex_choose_conf(). */
	ptr = sshpkt_ptr(ssh, &dlen);
	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
		return r;

	/* discard packet */
	for (i = 0; i < KEX_COOKIE_LEN; i++) {
		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
			error_fr(r, "discard cookie");
			return r;
		}
	}
	for (i = 0; i < PROPOSAL_MAX; i++) {
		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
			error_fr(r, "discard proposal");
			return r;
		}
	}
	/*
	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
	 * KEX method has the server move first, but a server might be using
	 * a custom method or one that we otherwise don't support. We should
	 * be prepared to remember first_kex_follows here so we can eat a
	 * packet later.
	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
	 * for cases where the server *doesn't* go first. I guess we should
	 * ignore it when it is set for these cases, which is what we do now.
	 */
	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
	    (r = sshpkt_get_end(ssh)) != 0)
			return r;

	/* Peer-initiated exchange: reply with our KEXINIT first. */
	if (!(kex->flags & KEX_INIT_SENT))
		if ((r = kex_send_kexinit(ssh)) != 0)
			return r;
	if ((r = kex_choose_conf(ssh)) != 0)
		return r;

	/* Run the method-specific handler registered for the chosen type. */
	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
		return (kex->kex[kex->kex_type])(ssh);

	error_f("unknown kex type %u", kex->kex_type);
	return SSH_ERR_INTERNAL_ERROR;
}
693 
694 struct kex *
695 kex_new(void)
696 {
697 	struct kex *kex;
698 
699 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
700 	    (kex->peer = sshbuf_new()) == NULL ||
701 	    (kex->my = sshbuf_new()) == NULL ||
702 	    (kex->client_version = sshbuf_new()) == NULL ||
703 	    (kex->server_version = sshbuf_new()) == NULL ||
704 	    (kex->session_id = sshbuf_new()) == NULL) {
705 		kex_free(kex);
706 		return NULL;
707 	}
708 	return kex;
709 }
710 
711 void
712 kex_free_newkeys(struct newkeys *newkeys)
713 {
714 	if (newkeys == NULL)
715 		return;
716 	if (newkeys->enc.key) {
717 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
718 		free(newkeys->enc.key);
719 		newkeys->enc.key = NULL;
720 	}
721 	if (newkeys->enc.iv) {
722 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
723 		free(newkeys->enc.iv);
724 		newkeys->enc.iv = NULL;
725 	}
726 	free(newkeys->enc.name);
727 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
728 	free(newkeys->comp.name);
729 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
730 	mac_clear(&newkeys->mac);
731 	if (newkeys->mac.key) {
732 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
733 		free(newkeys->mac.key);
734 		newkeys->mac.key = NULL;
735 	}
736 	free(newkeys->mac.name);
737 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
738 	freezero(newkeys, sizeof(*newkeys));
739 }
740 
741 void
742 kex_free(struct kex *kex)
743 {
744 	u_int mode;
745 
746 	if (kex == NULL)
747 		return;
748 
749 #ifdef WITH_OPENSSL
750 	DH_free(kex->dh);
751 	EC_KEY_free(kex->ec_client_key);
752 #endif
753 	for (mode = 0; mode < MODE_MAX; mode++) {
754 		kex_free_newkeys(kex->newkeys[mode]);
755 		kex->newkeys[mode] = NULL;
756 	}
757 	sshbuf_free(kex->peer);
758 	sshbuf_free(kex->my);
759 	sshbuf_free(kex->client_version);
760 	sshbuf_free(kex->server_version);
761 	sshbuf_free(kex->client_pub);
762 	sshbuf_free(kex->session_id);
763 	sshbuf_free(kex->initial_sig);
764 	sshkey_free(kex->initial_hostkey);
765 	free(kex->failed_choice);
766 	free(kex->hostkey_alg);
767 	free(kex->name);
768 	free(kex);
769 }
770 
771 int
772 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
773 {
774 	int r;
775 
776 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
777 		return r;
778 	ssh->kex->flags = KEX_INITIAL;
779 	kex_reset_dispatch(ssh);
780 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
781 	return 0;
782 }
783 
784 int
785 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
786 {
787 	int r;
788 
789 	if ((r = kex_ready(ssh, proposal)) != 0)
790 		return r;
791 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
792 		kex_free(ssh->kex);
793 		ssh->kex = NULL;
794 		return r;
795 	}
796 	return 0;
797 }
798 
799 /*
800  * Request key re-exchange, returns 0 on success or a ssherr.h error
801  * code otherwise. Must not be called if KEX is incomplete or in-progress.
802  */
803 int
804 kex_start_rekex(struct ssh *ssh)
805 {
806 	if (ssh->kex == NULL) {
807 		error_f("no kex");
808 		return SSH_ERR_INTERNAL_ERROR;
809 	}
810 	if (ssh->kex->done == 0) {
811 		error_f("requested twice");
812 		return SSH_ERR_INTERNAL_ERROR;
813 	}
814 	ssh->kex->done = 0;
815 	return kex_send_kexinit(ssh);
816 }
817 
818 static int
819 choose_enc(struct sshenc *enc, char *client, char *server)
820 {
821 	char *name = match_list(client, server, NULL);
822 
823 	if (name == NULL)
824 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
825 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
826 		error_f("unsupported cipher %s", name);
827 		free(name);
828 		return SSH_ERR_INTERNAL_ERROR;
829 	}
830 	enc->name = name;
831 	enc->enabled = 0;
832 	enc->iv = NULL;
833 	enc->iv_len = cipher_ivlen(enc->cipher);
834 	enc->key = NULL;
835 	enc->key_len = cipher_keylen(enc->cipher);
836 	enc->block_size = cipher_blocksize(enc->cipher);
837 	return 0;
838 }
839 
840 static int
841 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
842 {
843 	char *name = match_list(client, server, NULL);
844 
845 	if (name == NULL)
846 		return SSH_ERR_NO_MAC_ALG_MATCH;
847 	if (mac_setup(mac, name) < 0) {
848 		error_f("unsupported MAC %s", name);
849 		free(name);
850 		return SSH_ERR_INTERNAL_ERROR;
851 	}
852 	mac->name = name;
853 	mac->key = NULL;
854 	mac->enabled = 0;
855 	return 0;
856 }
857 
858 static int
859 choose_comp(struct sshcomp *comp, char *client, char *server)
860 {
861 	char *name = match_list(client, server, NULL);
862 
863 	if (name == NULL)
864 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
865 #ifdef WITH_ZLIB
866 	if (strcmp(name, "zlib@openssh.com") == 0) {
867 		comp->type = COMP_DELAYED;
868 	} else if (strcmp(name, "zlib") == 0) {
869 		comp->type = COMP_ZLIB;
870 	} else
871 #endif	/* WITH_ZLIB */
872 	if (strcmp(name, "none") == 0) {
873 		comp->type = COMP_NONE;
874 	} else {
875 		error_f("unsupported compression scheme %s", name);
876 		free(name);
877 		return SSH_ERR_INTERNAL_ERROR;
878 	}
879 	comp->name = name;
880 	return 0;
881 }
882 
883 static int
884 choose_kex(struct kex *k, char *client, char *server)
885 {
886 	const struct kexalg *kexalg;
887 
888 	k->name = match_list(client, server, NULL);
889 
890 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
891 	if (k->name == NULL)
892 		return SSH_ERR_NO_KEX_ALG_MATCH;
893 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
894 		error_f("unsupported KEX method %s", k->name);
895 		return SSH_ERR_INTERNAL_ERROR;
896 	}
897 	k->kex_type = kexalg->type;
898 	k->hash_alg = kexalg->hash_alg;
899 	k->ec_nid = kexalg->ec_nid;
900 	return 0;
901 }
902 
903 static int
904 choose_hostkeyalg(struct kex *k, char *client, char *server)
905 {
906 	free(k->hostkey_alg);
907 	k->hostkey_alg = match_list(client, server, NULL);
908 
909 	debug("kex: host key algorithm: %s",
910 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
911 	if (k->hostkey_alg == NULL)
912 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
913 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
914 	if (k->hostkey_type == KEY_UNSPEC) {
915 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
916 		return SSH_ERR_INTERNAL_ERROR;
917 	}
918 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
919 	return 0;
920 }
921 
922 static int
923 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
924 {
925 	static int check[] = {
926 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
927 	};
928 	int *idx;
929 	char *p;
930 
931 	for (idx = &check[0]; *idx != -1; idx++) {
932 		if ((p = strchr(my[*idx], ',')) != NULL)
933 			*p = '\0';
934 		if ((p = strchr(peer[*idx], ',')) != NULL)
935 			*p = '\0';
936 		if (strcmp(my[*idx], peer[*idx]) != 0) {
937 			debug2("proposal mismatch: my %s peer %s",
938 			    my[*idx], peer[*idx]);
939 			return (0);
940 		}
941 	}
942 	debug2("proposals match");
943 	return (1);
944 }
945 
/* Return 1 if proposal and algs share at least one algorithm name, else 0. */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
957 
/*
 * Negotiate the algorithms for this exchange from our KEXINIT (kex->my)
 * and the peer's (kex->peer).
 *
 * On success:
 *   - kex->name, kex_type, hash_alg and ec_nid are set (choose_kex).
 *   - The host key algorithm fields are set (choose_hostkeyalg).
 *   - kex->newkeys[] gets one entry per mode, holding the chosen cipher,
 *     MAC and compression.
 *   - kex->we_need and kex->dh_need are computed.
 *
 * On the server's initial exchange it also records whether the client
 * advertised "ext-info-c" and rsa-sha2 host key support.
 *
 * If a category fails to match, the peer's name-list for it is moved
 * into kex->failed_choice. Returns 0 or an ssherr.h code.
 */
static int
kex_choose_conf(struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* Negotiation is always client-preference first: orient the lists. */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether client supports ext_info_c */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		char *ext;

		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
		kex->ext_info_c = (ext != NULL);
		free(ext);
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/* Algorithm Negotiation */
	/* On failure, ownership of the peer's list moves to failed_choice. */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/* Client->server direction: client's MODE_OUT, server's MODE_IN. */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * need: largest key/block/IV/MAC-key size over both directions.
	 * dh_need: the same, but using cipher_seclen() in place of the key
	 * length (presumably the cipher's security strength - confirm).
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX need runden? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/* ignore the next message if the proposals do not match */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1079 
1080 static int
1081 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1082     const struct sshbuf *shared_secret, u_char **keyp)
1083 {
1084 	struct kex *kex = ssh->kex;
1085 	struct ssh_digest_ctx *hashctx = NULL;
1086 	char c = id;
1087 	u_int have;
1088 	size_t mdsz;
1089 	u_char *digest;
1090 	int r;
1091 
1092 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1093 		return SSH_ERR_INVALID_ARGUMENT;
1094 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1095 		r = SSH_ERR_ALLOC_FAIL;
1096 		goto out;
1097 	}
1098 
1099 	/* K1 = HASH(K || H || "A" || session_id) */
1100 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1101 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1102 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1103 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1104 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1105 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1106 		r = SSH_ERR_LIBCRYPTO_ERROR;
1107 		error_f("KEX hash failed");
1108 		goto out;
1109 	}
1110 	ssh_digest_free(hashctx);
1111 	hashctx = NULL;
1112 
1113 	/*
1114 	 * expand key:
1115 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1116 	 * Key = K1 || K2 || ... || Kn
1117 	 */
1118 	for (have = mdsz; need > have; have += mdsz) {
1119 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1120 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1121 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1122 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1123 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1124 			error_f("KDF failed");
1125 			r = SSH_ERR_LIBCRYPTO_ERROR;
1126 			goto out;
1127 		}
1128 		ssh_digest_free(hashctx);
1129 		hashctx = NULL;
1130 	}
1131 #ifdef DEBUG_KEX
1132 	fprintf(stderr, "key '%c'== ", c);
1133 	dump_digest("key", digest, need);
1134 #endif
1135 	*keyp = digest;
1136 	digest = NULL;
1137 	r = 0;
1138  out:
1139 	free(digest);
1140 	ssh_digest_free(hashctx);
1141 	return r;
1142 }
1143 
1144 #define NKEYS	6
1145 int
1146 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1147     const struct sshbuf *shared_secret)
1148 {
1149 	struct kex *kex = ssh->kex;
1150 	u_char *keys[NKEYS];
1151 	u_int i, j, mode, ctos;
1152 	int r;
1153 
1154 	/* save initial hash as session id */
1155 	if ((kex->flags & KEX_INITIAL) != 0) {
1156 		if (sshbuf_len(kex->session_id) != 0) {
1157 			error_f("already have session ID at kex");
1158 			return SSH_ERR_INTERNAL_ERROR;
1159 		}
1160 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1161 			return r;
1162 	} else if (sshbuf_len(kex->session_id) == 0) {
1163 		error_f("no session ID in rekex");
1164 		return SSH_ERR_INTERNAL_ERROR;
1165 	}
1166 	for (i = 0; i < NKEYS; i++) {
1167 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1168 		    shared_secret, &keys[i])) != 0) {
1169 			for (j = 0; j < i; j++)
1170 				free(keys[j]);
1171 			return r;
1172 		}
1173 	}
1174 	for (mode = 0; mode < MODE_MAX; mode++) {
1175 		ctos = (!kex->server && mode == MODE_OUT) ||
1176 		    (kex->server && mode == MODE_IN);
1177 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1178 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1179 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1180 	}
1181 	return 0;
1182 }
1183 
1184 int
1185 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1186 {
1187 	struct kex *kex = ssh->kex;
1188 
1189 	*pubp = NULL;
1190 	*prvp = NULL;
1191 	if (kex->load_host_public_key == NULL ||
1192 	    kex->load_host_private_key == NULL) {
1193 		error_f("missing hostkey loader");
1194 		return SSH_ERR_INVALID_ARGUMENT;
1195 	}
1196 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1197 	    kex->hostkey_nid, ssh);
1198 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1199 	    kex->hostkey_nid, ssh);
1200 	if (*pubp == NULL)
1201 		return SSH_ERR_NO_HOSTKEY_LOADED;
1202 	return 0;
1203 }
1204 
1205 int
1206 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1207 {
1208 	struct kex *kex = ssh->kex;
1209 
1210 	if (kex->verify_host_key == NULL) {
1211 		error_f("missing hostkey verifier");
1212 		return SSH_ERR_INVALID_ARGUMENT;
1213 	}
1214 	if (server_host_key->type != kex->hostkey_type ||
1215 	    (kex->hostkey_type == KEY_ECDSA &&
1216 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1217 		return SSH_ERR_KEY_TYPE_MISMATCH;
1218 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1219 		return  SSH_ERR_SIGNATURE_INVALID;
1220 	return 0;
1221 }
1222 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: write the label msg on its own line to stderr, followed by
 * a dump of the first len bytes of digest via sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1231 
1232 /*
1233  * Send a plaintext error message to the peer, suffixed by \r\n.
1234  * Only used during banner exchange, and there only for the server.
1235  */
1236 static void
1237 send_error(struct ssh *ssh, char *msg)
1238 {
1239 	char *crnl = "\r\n";
1240 
1241 	if (!ssh->kex->server)
1242 		return;
1243 
1244 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1245 	    msg, strlen(msg)) != strlen(msg) ||
1246 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1247 	    crnl, strlen(crnl)) != strlen(crnl))
1248 		error_f("write: %.100s", strerror(errno));
1249 }
1250 
1251 /*
1252  * Sends our identification string and waits for the peer's. Will block for
1253  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1254  * Returns on 0 success or a ssherr.h code on failure.
1255  */
1256 int
1257 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1258     const char *version_addendum)
1259 {
1260 	int remote_major, remote_minor, mismatch, oerrno = 0;
1261 	size_t len, n;
1262 	int r, expect_nl;
1263 	u_char c;
1264 	struct sshbuf *our_version = ssh->kex->server ?
1265 	    ssh->kex->server_version : ssh->kex->client_version;
1266 	struct sshbuf *peer_version = ssh->kex->server ?
1267 	    ssh->kex->client_version : ssh->kex->server_version;
1268 	char *our_version_string = NULL, *peer_version_string = NULL;
1269 	char *cp, *remote_version = NULL;
1270 
1271 	/* Prepare and send our banner */
1272 	sshbuf_reset(our_version);
1273 	if (version_addendum != NULL && *version_addendum == '\0')
1274 		version_addendum = NULL;
1275 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1276 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1277 	    version_addendum == NULL ? "" : " ",
1278 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1279 		oerrno = errno;
1280 		error_fr(r, "sshbuf_putf");
1281 		goto out;
1282 	}
1283 
1284 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1285 	    sshbuf_mutable_ptr(our_version),
1286 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1287 		oerrno = errno;
1288 		debug_f("write: %.100s", strerror(errno));
1289 		r = SSH_ERR_SYSTEM_ERROR;
1290 		goto out;
1291 	}
1292 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1293 		oerrno = errno;
1294 		error_fr(r, "sshbuf_consume_end");
1295 		goto out;
1296 	}
1297 	our_version_string = sshbuf_dup_string(our_version);
1298 	if (our_version_string == NULL) {
1299 		error_f("sshbuf_dup_string failed");
1300 		r = SSH_ERR_ALLOC_FAIL;
1301 		goto out;
1302 	}
1303 	debug("Local version string %.100s", our_version_string);
1304 
1305 	/* Read other side's version identification. */
1306 	for (n = 0; ; n++) {
1307 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1308 			send_error(ssh, "No SSH identification string "
1309 			    "received.");
1310 			error_f("No SSH version received in first %u lines "
1311 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1312 			r = SSH_ERR_INVALID_FORMAT;
1313 			goto out;
1314 		}
1315 		sshbuf_reset(peer_version);
1316 		expect_nl = 0;
1317 		for (;;) {
1318 			if (timeout_ms > 0) {
1319 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1320 				    &timeout_ms);
1321 				if (r == -1 && errno == ETIMEDOUT) {
1322 					send_error(ssh, "Timed out waiting "
1323 					    "for SSH identification string.");
1324 					error("Connection timed out during "
1325 					    "banner exchange");
1326 					r = SSH_ERR_CONN_TIMEOUT;
1327 					goto out;
1328 				} else if (r == -1) {
1329 					oerrno = errno;
1330 					error_f("%s", strerror(errno));
1331 					r = SSH_ERR_SYSTEM_ERROR;
1332 					goto out;
1333 				}
1334 			}
1335 
1336 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1337 			    &c, 1);
1338 			if (len != 1 && errno == EPIPE) {
1339 				error_f("Connection closed by remote host");
1340 				r = SSH_ERR_CONN_CLOSED;
1341 				goto out;
1342 			} else if (len != 1) {
1343 				oerrno = errno;
1344 				error_f("read: %.100s", strerror(errno));
1345 				r = SSH_ERR_SYSTEM_ERROR;
1346 				goto out;
1347 			}
1348 			if (c == '\r') {
1349 				expect_nl = 1;
1350 				continue;
1351 			}
1352 			if (c == '\n')
1353 				break;
1354 			if (c == '\0' || expect_nl) {
1355 				error_f("banner line contains invalid "
1356 				    "characters");
1357 				goto invalid;
1358 			}
1359 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1360 				oerrno = errno;
1361 				error_fr(r, "sshbuf_put");
1362 				goto out;
1363 			}
1364 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1365 				error_f("banner line too long");
1366 				goto invalid;
1367 			}
1368 		}
1369 		/* Is this an actual protocol banner? */
1370 		if (sshbuf_len(peer_version) > 4 &&
1371 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1372 			break;
1373 		/* If not, then just log the line and continue */
1374 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1375 			error_f("sshbuf_dup_string failed");
1376 			r = SSH_ERR_ALLOC_FAIL;
1377 			goto out;
1378 		}
1379 		/* Do not accept lines before the SSH ident from a client */
1380 		if (ssh->kex->server) {
1381 			error_f("client sent invalid protocol identifier "
1382 			    "\"%.256s\"", cp);
1383 			free(cp);
1384 			goto invalid;
1385 		}
1386 		debug_f("banner line %zu: %s", n, cp);
1387 		free(cp);
1388 	}
1389 	peer_version_string = sshbuf_dup_string(peer_version);
1390 	if (peer_version_string == NULL)
1391 		fatal_f("sshbuf_dup_string failed");
1392 	/* XXX must be same size for sscanf */
1393 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1394 		error_f("calloc failed");
1395 		r = SSH_ERR_ALLOC_FAIL;
1396 		goto out;
1397 	}
1398 
1399 	/*
1400 	 * Check that the versions match.  In future this might accept
1401 	 * several versions and set appropriate flags to handle them.
1402 	 */
1403 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1404 	    &remote_major, &remote_minor, remote_version) != 3) {
1405 		error("Bad remote protocol version identification: '%.100s'",
1406 		    peer_version_string);
1407  invalid:
1408 		send_error(ssh, "Invalid SSH identification string.");
1409 		r = SSH_ERR_INVALID_FORMAT;
1410 		goto out;
1411 	}
1412 	debug("Remote protocol version %d.%d, remote software version %.100s",
1413 	    remote_major, remote_minor, remote_version);
1414 	compat_banner(ssh, remote_version);
1415 
1416 	mismatch = 0;
1417 	switch (remote_major) {
1418 	case 2:
1419 		break;
1420 	case 1:
1421 		if (remote_minor != 99)
1422 			mismatch = 1;
1423 		break;
1424 	default:
1425 		mismatch = 1;
1426 		break;
1427 	}
1428 	if (mismatch) {
1429 		error("Protocol major versions differ: %d vs. %d",
1430 		    PROTOCOL_MAJOR_2, remote_major);
1431 		send_error(ssh, "Protocol major versions differ.");
1432 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1433 		goto out;
1434 	}
1435 
1436 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1437 		logit("probed from %s port %d with %s.  Don't panic.",
1438 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1439 		    peer_version_string);
1440 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1441 		goto out;
1442 	}
1443 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1444 		logit("scanned from %s port %d with %s.  Don't panic.",
1445 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1446 		    peer_version_string);
1447 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1448 		goto out;
1449 	}
1450 	/* success */
1451 	r = 0;
1452  out:
1453 	free(our_version_string);
1454 	free(peer_version_string);
1455 	free(remote_version);
1456 	if (r == SSH_ERR_SYSTEM_ERROR)
1457 		errno = oerrno;
1458 	return r;
1459 }
1460 
1461