xref: /netbsd-src/crypto/external/bsd/openssh/dist/kex.c (revision 195d6bf6173f7fc3dac511ecb082e722f8fcd656)
1 /*	$NetBSD: kex.c,v 1.38 2024/10/03 08:14:13 rin Exp $	*/
2 /* $OpenBSD: kex.c,v 1.187 2024/08/23 04:51:00 deraadt Exp $ */
3 
4 /*
5  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #include "includes.h"
29 __RCSID("$NetBSD: kex.c,v 1.38 2024/10/03 08:14:13 rin Exp $");
30 
31 #include <sys/param.h>	/* MAX roundup */
32 #include <sys/types.h>
33 #include <errno.h>
34 #include <signal.h>
35 #include <stdio.h>
36 #include <stdlib.h>
37 #include <string.h>
38 #include <unistd.h>
39 #include <poll.h>
40 
41 #ifdef WITH_OPENSSL
42 #include <openssl/crypto.h>
43 #include <openssl/dh.h>
44 #endif
45 
46 #include "ssh.h"
47 #include "ssh2.h"
48 #include "atomicio.h"
49 #include "version.h"
50 #include "packet.h"
51 #include "compat.h"
52 #include "cipher.h"
53 #include "sshkey.h"
54 #include "kex.h"
55 #include "log.h"
56 #include "mac.h"
57 #include "match.h"
58 #include "misc.h"
59 #include "dispatch.h"
60 #include "packet.h"
61 #include "monitor.h"
62 #include "myproposal.h"
63 
64 #include "ssherr.h"
65 #include "sshbuf.h"
66 #include "digest.h"
67 #include "xmalloc.h"
68 
69 /* prototype */
70 static int kex_choose_conf(struct ssh *, uint32_t seq);
71 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
72 
73 static const char * const proposal_names[PROPOSAL_MAX] = {
74 	"KEX algorithms",
75 	"host key algorithms",
76 	"ciphers ctos",
77 	"ciphers stoc",
78 	"MACs ctos",
79 	"MACs stoc",
80 	"compression ctos",
81 	"compression stoc",
82 	"languages ctos",
83 	"languages stoc",
84 };
85 
86 /*
87  * Fill out a proposal array with dynamically allocated values, which may
88  * be modified as required for compatibility reasons.
89  * Any of the options may be NULL, in which case the default is used.
90  * Array contents must be freed by calling kex_proposal_free_entries.
91  */
92 void
93 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
94     const char *kexalgos, const char *ciphers, const char *macs,
95     const char *comp, const char *hkalgs)
96 {
97 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
98 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
99 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
100 	u_int i;
101 	char *cp;
102 
103 	if (prop == NULL)
104 		fatal_f("proposal missing");
105 
106 	/* Append EXT_INFO signalling to KexAlgorithms */
107 	if (kexalgos == NULL)
108 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
109 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
110 	    "ext-info-s,kex-strict-s-v00@openssh.com" :
111 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
112 		fatal_f("kex_names_cat");
113 
114 	for (i = 0; i < PROPOSAL_MAX; i++) {
115 		switch(i) {
116 		case PROPOSAL_KEX_ALGS:
117 			prop[i] = compat_kex_proposal(ssh, cp);
118 			break;
119 		case PROPOSAL_ENC_ALGS_CTOS:
120 		case PROPOSAL_ENC_ALGS_STOC:
121 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
122 			break;
123 		case PROPOSAL_MAC_ALGS_CTOS:
124 		case PROPOSAL_MAC_ALGS_STOC:
125 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
126 			break;
127 		case PROPOSAL_COMP_ALGS_CTOS:
128 		case PROPOSAL_COMP_ALGS_STOC:
129 			prop[i] = xstrdup(comp ? comp : defprop[i]);
130 			break;
131 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
132 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
133 			break;
134 		default:
135 			prop[i] = xstrdup(defprop[i]);
136 		}
137 	}
138 	free(cp);
139 }
140 
141 void
142 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
143 {
144 	u_int i;
145 
146 	for (i = 0; i < PROPOSAL_MAX; i++)
147 		free(prop[i]);
148 }
149 
150 /* put algorithm proposal into buffer */
151 int
152 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
153 {
154 	u_int i;
155 	int r;
156 
157 	sshbuf_reset(b);
158 
159 	/*
160 	 * add a dummy cookie, the cookie will be overwritten by
161 	 * kex_send_kexinit(), each time a kexinit is set
162 	 */
163 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
164 		if ((r = sshbuf_put_u8(b, 0)) != 0)
165 			return r;
166 	}
167 	for (i = 0; i < PROPOSAL_MAX; i++) {
168 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
169 			return r;
170 	}
171 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
172 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
173 		return r;
174 	return 0;
175 }
176 
177 /* parse buffer and return algorithm proposal */
178 int
179 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
180 {
181 	struct sshbuf *b = NULL;
182 	u_char v;
183 	u_int i;
184 	char **proposal = NULL;
185 	int r;
186 
187 	*propp = NULL;
188 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
189 		return SSH_ERR_ALLOC_FAIL;
190 	if ((b = sshbuf_fromb(raw)) == NULL) {
191 		r = SSH_ERR_ALLOC_FAIL;
192 		goto out;
193 	}
194 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
195 		error_fr(r, "consume cookie");
196 		goto out;
197 	}
198 	/* extract kex init proposal strings */
199 	for (i = 0; i < PROPOSAL_MAX; i++) {
200 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
201 			error_fr(r, "parse proposal %u", i);
202 			goto out;
203 		}
204 		debug2("%s: %s", proposal_names[i], proposal[i]);
205 	}
206 	/* first kex follows / reserved */
207 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
208 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
209 		error_fr(r, "parse");
210 		goto out;
211 	}
212 	if (first_kex_follows != NULL)
213 		*first_kex_follows = v;
214 	debug2("first_kex_follows %d ", v);
215 	debug2("reserved %u ", i);
216 	r = 0;
217 	*propp = proposal;
218  out:
219 	if (r != 0 && proposal != NULL)
220 		kex_prop_free(proposal);
221 	sshbuf_free(b);
222 	return r;
223 }
224 
225 void
226 kex_prop_free(char **proposal)
227 {
228 	u_int i;
229 
230 	if (proposal == NULL)
231 		return;
232 	for (i = 0; i < PROPOSAL_MAX; i++)
233 		free(proposal[i]);
234 	free(proposal);
235 }
236 
237 int
238 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
239 {
240 	int r;
241 
242 	/* If in strict mode, any unexpected message is an error */
243 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
244 		ssh_packet_disconnect(ssh, "strict KEX violation: "
245 		    "unexpected packet type %u (seqnr %u)", type, seq);
246 	}
247 	error_f("type %u seq %u", type, seq);
248 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
249 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
250 	    (r = sshpkt_send(ssh)) != 0)
251 		return r;
252 	return 0;
253 }
254 
255 static void
256 kex_reset_dispatch(struct ssh *ssh)
257 {
258 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
259 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
260 }
261 
262 void
263 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
264 {
265 	char *alg, *oalgs, *algs, *sigalgs;
266 	const char *sigalg;
267 
268 	/*
269 	 * NB. allowed algorithms may contain certificate algorithms that
270 	 * map to a specific plain signature type, e.g.
271 	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
272 	 * We need to be careful here to match these, retain the mapping
273 	 * and only add each signature algorithm once.
274 	 */
275 	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
276 		fatal_f("sshkey_alg_list failed");
277 	oalgs = algs = xstrdup(allowed_algs);
278 	free(ssh->kex->server_sig_algs);
279 	ssh->kex->server_sig_algs = NULL;
280 	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
281 	    (alg = strsep(&algs, ","))) {
282 		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
283 			continue;
284 		if (!kex_has_any_alg(sigalg, sigalgs))
285 			continue;
286 		/* Don't add an algorithm twice. */
287 		if (ssh->kex->server_sig_algs != NULL &&
288 		    kex_has_any_alg(sigalg, ssh->kex->server_sig_algs))
289 			continue;
290 		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
291 	}
292 	free(oalgs);
293 	free(sigalgs);
294 	if (ssh->kex->server_sig_algs == NULL)
295 		ssh->kex->server_sig_algs = xstrdup("");
296 }
297 
298 static int
299 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
300 {
301 	int r;
302 
303 	if (ssh->kex->server_sig_algs == NULL &&
304 	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
305 		return SSH_ERR_ALLOC_FAIL;
306 	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
307 	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
308 	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
309 	    (r = sshbuf_put_cstring(m,
310 	    "publickey-hostbound@openssh.com")) != 0 ||
311 	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
312 	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
313 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
314 		error_fr(r, "compose");
315 		return r;
316 	}
317 	return 0;
318 }
319 
320 static int
321 kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
322 {
323 	int r;
324 
325 	if ((r = sshbuf_put_u32(m, 1)) != 0 ||
326 	    (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) != 0 ||
327 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
328 		error_fr(r, "compose");
329 		goto out;
330 	}
331 	/* success */
332 	r = 0;
333  out:
334 	return r;
335 }
336 
337 static int
338 kex_maybe_send_ext_info(struct ssh *ssh)
339 {
340 	int r;
341 	struct sshbuf *m = NULL;
342 
343 	if ((ssh->kex->flags & KEX_INITIAL) == 0)
344 		return 0;
345 	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
346 		return 0;
347 
348 	/* Compose EXT_INFO packet. */
349 	if ((m = sshbuf_new()) == NULL)
350 		fatal_f("sshbuf_new failed");
351 	if (ssh->kex->ext_info_c &&
352 	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
353 		goto fail;
354 	if (ssh->kex->ext_info_s &&
355 	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
356 		goto fail;
357 
358 	/* Send the actual KEX_INFO packet */
359 	debug("Sending SSH2_MSG_EXT_INFO");
360 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
361 	    (r = sshpkt_putb(ssh, m)) != 0 ||
362 	    (r = sshpkt_send(ssh)) != 0) {
363 		error_f("send EXT_INFO");
364 		goto fail;
365 	}
366 
367 	r = 0;
368 
369  fail:
370 	sshbuf_free(m);
371 	return r;
372 }
373 
374 int
375 kex_server_update_ext_info(struct ssh *ssh)
376 {
377 	int r;
378 
379 	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
380 		return 0;
381 
382 	debug_f("Sending SSH2_MSG_EXT_INFO");
383 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
384 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
385 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
386 	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
387 	    (r = sshpkt_send(ssh)) != 0) {
388 		error_f("send EXT_INFO");
389 		return r;
390 	}
391 	return 0;
392 }
393 
394 int
395 kex_send_newkeys(struct ssh *ssh)
396 {
397 	int r;
398 
399 	kex_reset_dispatch(ssh);
400 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
401 	    (r = sshpkt_send(ssh)) != 0)
402 		return r;
403 	debug("SSH2_MSG_NEWKEYS sent");
404 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
405 	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
406 		return r;
407 	debug("expecting SSH2_MSG_NEWKEYS");
408 	return 0;
409 }
410 
411 /* Check whether an ext_info value contains the expected version string */
412 static int
413 kex_ext_info_check_ver(struct kex *kex, const char *name,
414     const u_char *val, size_t len, const char *want_ver, u_int flag)
415 {
416 	if (memchr(val, '\0', len) != NULL) {
417 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
418 		return SSH_ERR_INVALID_FORMAT;
419 	}
420 	debug_f("%s=<%s>", name, val);
421 	if (strcmp((const char *)val, want_ver) == 0)
422 		kex->flags |= flag;
423 	else
424 		debug_f("unsupported version of %s extension", name);
425 	return 0;
426 }
427 
428 static int
429 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
430     const u_char *value, size_t vlen)
431 {
432 	int r;
433 
434 	/* NB. some messages are only accepted in the initial EXT_INFO */
435 	if (strcmp(name, "server-sig-algs") == 0) {
436 		/* Ensure no \0 lurking in value */
437 		if (memchr(value, '\0', vlen) != NULL) {
438 			error_f("nul byte in %s", name);
439 			return SSH_ERR_INVALID_FORMAT;
440 		}
441 		debug_f("%s=<%s>", name, value);
442 		free(ssh->kex->server_sig_algs);
443 		ssh->kex->server_sig_algs = xstrdup((const char *)value);
444 	} else if (ssh->kex->ext_info_received == 1 &&
445 	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
446 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
447 		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
448 			return r;
449 		}
450 	} else if (ssh->kex->ext_info_received == 1 &&
451 	    strcmp(name, "ping@openssh.com") == 0) {
452 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
453 		    "0", KEX_HAS_PING)) != 0) {
454 			return r;
455 		}
456 	} else
457 		debug_f("%s (unrecognised)", name);
458 
459 	return 0;
460 }
461 
462 static int
463 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
464     const u_char *value, size_t vlen)
465 {
466 	int r;
467 
468 	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
469 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
470 		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
471 			return r;
472 		}
473 	} else
474 		debug_f("%s (unrecognised)", name);
475 	return 0;
476 }
477 
478 int
479 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
480 {
481 	struct kex *kex = ssh->kex;
482 	const int max_ext_info = kex->server ? 1 : 2;
483 	u_int32_t i, ninfo;
484 	char *name;
485 	u_char *val;
486 	size_t vlen;
487 	int r;
488 
489 	debug("SSH2_MSG_EXT_INFO received");
490 	if (++kex->ext_info_received > max_ext_info) {
491 		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
492 		return dispatch_protocol_error(type, seq, ssh);
493 	}
494 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
495 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
496 		return r;
497 	if (ninfo >= 1024) {
498 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
499 		    "<=1024, received %u", ninfo);
500 		return dispatch_protocol_error(type, seq, ssh);
501 	}
502 	for (i = 0; i < ninfo; i++) {
503 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
504 			return r;
505 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
506 			free(name);
507 			return r;
508 		}
509 		debug3_f("extension %s", name);
510 		if (kex->server) {
511 			if ((r = kex_ext_info_server_parse(ssh, name,
512 			    val, vlen)) != 0)
513 				return r;
514 		} else {
515 			if ((r = kex_ext_info_client_parse(ssh, name,
516 			    val, vlen)) != 0)
517 				return r;
518 		}
519 		free(name);
520 		free(val);
521 	}
522 	return sshpkt_get_end(ssh);
523 }
524 
525 static int
526 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
527 {
528 	struct kex *kex = ssh->kex;
529 	int r, initial = (kex->flags & KEX_INITIAL) != 0;
530 	char *cp, **prop;
531 
532 	debug("SSH2_MSG_NEWKEYS received");
533 	if (kex->ext_info_c && initial)
534 		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
535 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
536 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
537 	if ((r = sshpkt_get_end(ssh)) != 0)
538 		return r;
539 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
540 		return r;
541 	if (initial) {
542 		/* Remove initial KEX signalling from proposal for rekeying */
543 		if ((r = kex_buf2prop(kex->my, NULL, &prop)) != 0)
544 			return r;
545 		if ((cp = match_filter_denylist(prop[PROPOSAL_KEX_ALGS],
546 		    kex->server ?
547 		    "ext-info-s,kex-strict-s-v00@openssh.com" :
548 		    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL) {
549 			error_f("match_filter_denylist failed");
550 			goto fail;
551 		}
552 		free(prop[PROPOSAL_KEX_ALGS]);
553 		prop[PROPOSAL_KEX_ALGS] = cp;
554 		if ((r = kex_prop2buf(ssh->kex->my, prop)) != 0) {
555 			error_f("kex_prop2buf failed");
556  fail:
557 			kex_proposal_free_entries(prop);
558 			free(prop);
559 			return SSH_ERR_INTERNAL_ERROR;
560 		}
561 		kex_proposal_free_entries(prop);
562 		free(prop);
563 	}
564 	kex->done = 1;
565 	kex->flags &= ~KEX_INITIAL;
566 	sshbuf_reset(kex->peer);
567 	kex->flags &= ~KEX_INIT_SENT;
568 	free(kex->name);
569 	kex->name = NULL;
570 	return 0;
571 }
572 
573 int
574 kex_send_kexinit(struct ssh *ssh)
575 {
576 	u_char *cookie;
577 	struct kex *kex = ssh->kex;
578 	int r;
579 
580 	if (kex == NULL) {
581 		error_f("no kex");
582 		return SSH_ERR_INTERNAL_ERROR;
583 	}
584 	if (kex->flags & KEX_INIT_SENT)
585 		return 0;
586 	kex->done = 0;
587 
588 	/* generate a random cookie */
589 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
590 		error_f("bad kex length: %zu < %d",
591 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
592 		return SSH_ERR_INVALID_FORMAT;
593 	}
594 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
595 		error_f("buffer error");
596 		return SSH_ERR_INTERNAL_ERROR;
597 	}
598 	arc4random_buf(cookie, KEX_COOKIE_LEN);
599 
600 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
601 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
602 	    (r = sshpkt_send(ssh)) != 0) {
603 		error_fr(r, "compose reply");
604 		return r;
605 	}
606 	debug("SSH2_MSG_KEXINIT sent");
607 	kex->flags |= KEX_INIT_SENT;
608 	return 0;
609 }
610 
611 int
612 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
613 {
614 	struct kex *kex = ssh->kex;
615 	const u_char *ptr;
616 	u_int i;
617 	size_t dlen;
618 	int r;
619 
620 	debug("SSH2_MSG_KEXINIT received");
621 	if (kex == NULL) {
622 		error_f("no kex");
623 		return SSH_ERR_INTERNAL_ERROR;
624 	}
625 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
626 	ptr = sshpkt_ptr(ssh, &dlen);
627 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
628 		return r;
629 
630 	/* discard packet */
631 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
632 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
633 			error_fr(r, "discard cookie");
634 			return r;
635 		}
636 	}
637 	for (i = 0; i < PROPOSAL_MAX; i++) {
638 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
639 			error_fr(r, "discard proposal");
640 			return r;
641 		}
642 	}
643 	/*
644 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
645 	 * KEX method has the server move first, but a server might be using
646 	 * a custom method or one that we otherwise don't support. We should
647 	 * be prepared to remember first_kex_follows here so we can eat a
648 	 * packet later.
649 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
650 	 * for cases where the server *doesn't* go first. I guess we should
651 	 * ignore it when it is set for these cases, which is what we do now.
652 	 */
653 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
654 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
655 	    (r = sshpkt_get_end(ssh)) != 0)
656 			return r;
657 
658 	if (!(kex->flags & KEX_INIT_SENT))
659 		if ((r = kex_send_kexinit(ssh)) != 0)
660 			return r;
661 	if ((r = kex_choose_conf(ssh, seq)) != 0)
662 		return r;
663 
664 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
665 		return (kex->kex[kex->kex_type])(ssh);
666 
667 	error_f("unknown kex type %u", kex->kex_type);
668 	return SSH_ERR_INTERNAL_ERROR;
669 }
670 
671 struct kex *
672 kex_new(void)
673 {
674 	struct kex *kex;
675 
676 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
677 	    (kex->peer = sshbuf_new()) == NULL ||
678 	    (kex->my = sshbuf_new()) == NULL ||
679 	    (kex->client_version = sshbuf_new()) == NULL ||
680 	    (kex->server_version = sshbuf_new()) == NULL ||
681 	    (kex->session_id = sshbuf_new()) == NULL) {
682 		kex_free(kex);
683 		return NULL;
684 	}
685 	return kex;
686 }
687 
688 void
689 kex_free_newkeys(struct newkeys *newkeys)
690 {
691 	if (newkeys == NULL)
692 		return;
693 	if (newkeys->enc.key) {
694 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
695 		free(newkeys->enc.key);
696 		newkeys->enc.key = NULL;
697 	}
698 	if (newkeys->enc.iv) {
699 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
700 		free(newkeys->enc.iv);
701 		newkeys->enc.iv = NULL;
702 	}
703 	free(newkeys->enc.name);
704 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
705 	free(newkeys->comp.name);
706 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
707 	mac_clear(&newkeys->mac);
708 	if (newkeys->mac.key) {
709 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
710 		free(newkeys->mac.key);
711 		newkeys->mac.key = NULL;
712 	}
713 	free(newkeys->mac.name);
714 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
715 	freezero(newkeys, sizeof(*newkeys));
716 }
717 
718 void
719 kex_free(struct kex *kex)
720 {
721 	u_int mode;
722 
723 	if (kex == NULL)
724 		return;
725 
726 #ifdef WITH_OPENSSL
727 	DH_free(kex->dh);
728 	EC_KEY_free(kex->ec_client_key);
729 #endif
730 	for (mode = 0; mode < MODE_MAX; mode++) {
731 		kex_free_newkeys(kex->newkeys[mode]);
732 		kex->newkeys[mode] = NULL;
733 	}
734 	sshbuf_free(kex->peer);
735 	sshbuf_free(kex->my);
736 	sshbuf_free(kex->client_version);
737 	sshbuf_free(kex->server_version);
738 	sshbuf_free(kex->client_pub);
739 	sshbuf_free(kex->session_id);
740 	sshbuf_free(kex->initial_sig);
741 	sshkey_free(kex->initial_hostkey);
742 	free(kex->failed_choice);
743 	free(kex->hostkey_alg);
744 	free(kex->name);
745 	free(kex);
746 }
747 
748 int
749 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
750 {
751 	int r;
752 
753 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
754 		return r;
755 	ssh->kex->flags = KEX_INITIAL;
756 	kex_reset_dispatch(ssh);
757 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
758 	return 0;
759 }
760 
761 int
762 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
763 {
764 	int r;
765 
766 	if ((r = kex_ready(ssh, proposal)) != 0)
767 		return r;
768 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
769 		kex_free(ssh->kex);
770 		ssh->kex = NULL;
771 		return r;
772 	}
773 	return 0;
774 }
775 
776 /*
777  * Request key re-exchange, returns 0 on success or a ssherr.h error
778  * code otherwise. Must not be called if KEX is incomplete or in-progress.
779  */
780 int
781 kex_start_rekex(struct ssh *ssh)
782 {
783 	if (ssh->kex == NULL) {
784 		error_f("no kex");
785 		return SSH_ERR_INTERNAL_ERROR;
786 	}
787 	if (ssh->kex->done == 0) {
788 		error_f("requested twice");
789 		return SSH_ERR_INTERNAL_ERROR;
790 	}
791 	ssh->kex->done = 0;
792 	return kex_send_kexinit(ssh);
793 }
794 
795 static int
796 choose_enc(struct sshenc *enc, char *client, char *server)
797 {
798 	char *name = match_list(client, server, NULL);
799 
800 	if (name == NULL)
801 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
802 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
803 		error_f("unsupported cipher %s", name);
804 		free(name);
805 		return SSH_ERR_INTERNAL_ERROR;
806 	}
807 	enc->name = name;
808 	enc->enabled = 0;
809 	enc->iv = NULL;
810 	enc->iv_len = cipher_ivlen(enc->cipher);
811 	enc->key = NULL;
812 	enc->key_len = cipher_keylen(enc->cipher);
813 	enc->block_size = cipher_blocksize(enc->cipher);
814 	return 0;
815 }
816 
817 static int
818 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
819 {
820 	char *name = match_list(client, server, NULL);
821 
822 	if (name == NULL)
823 		return SSH_ERR_NO_MAC_ALG_MATCH;
824 	if (mac_setup(mac, name) < 0) {
825 		error_f("unsupported MAC %s", name);
826 		free(name);
827 		return SSH_ERR_INTERNAL_ERROR;
828 	}
829 	mac->name = name;
830 	mac->key = NULL;
831 	mac->enabled = 0;
832 	return 0;
833 }
834 
835 static int
836 choose_comp(struct sshcomp *comp, char *client, char *server)
837 {
838 	char *name = match_list(client, server, NULL);
839 
840 	if (name == NULL)
841 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
842 #ifdef WITH_ZLIB
843 	if (strcmp(name, "zlib@openssh.com") == 0) {
844 		comp->type = COMP_DELAYED;
845 	} else
846 #endif	/* WITH_ZLIB */
847 	if (strcmp(name, "none") == 0) {
848 		comp->type = COMP_NONE;
849 	} else {
850 		error_f("unsupported compression scheme %s", name);
851 		free(name);
852 		return SSH_ERR_INTERNAL_ERROR;
853 	}
854 	comp->name = name;
855 	return 0;
856 }
857 
858 static int
859 choose_kex(struct kex *k, char *client, char *server)
860 {
861 	k->name = match_list(client, server, NULL);
862 
863 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
864 	if (k->name == NULL)
865 		return SSH_ERR_NO_KEX_ALG_MATCH;
866 	if (!kex_name_valid(k->name)) {
867 		error_f("unsupported KEX method %s", k->name);
868 		return SSH_ERR_INTERNAL_ERROR;
869 	}
870 	k->kex_type = kex_type_from_name(k->name);
871 	k->hash_alg = kex_hash_from_name(k->name);
872 	k->ec_nid = kex_nid_from_name(k->name);
873 	return 0;
874 }
875 
876 static int
877 choose_hostkeyalg(struct kex *k, char *client, char *server)
878 {
879 	free(k->hostkey_alg);
880 	k->hostkey_alg = match_list(client, server, NULL);
881 
882 	debug("kex: host key algorithm: %s",
883 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
884 	if (k->hostkey_alg == NULL)
885 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
886 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
887 	if (k->hostkey_type == KEY_UNSPEC) {
888 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
889 		return SSH_ERR_INTERNAL_ERROR;
890 	}
891 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
892 	return 0;
893 }
894 
895 static int
896 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
897 {
898 	static int check[] = {
899 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
900 	};
901 	int *idx;
902 	char *p;
903 
904 	for (idx = &check[0]; *idx != -1; idx++) {
905 		if ((p = strchr(my[*idx], ',')) != NULL)
906 			*p = '\0';
907 		if ((p = strchr(peer[*idx], ',')) != NULL)
908 			*p = '\0';
909 		if (strcmp(my[*idx], peer[*idx]) != 0) {
910 			debug2("proposal mismatch: my %s peer %s",
911 			    my[*idx], peer[*idx]);
912 			return (0);
913 		}
914 	}
915 	debug2("proposals match");
916 	return (1);
917 }
918 
919 static int
920 kexalgs_contains(char **peer, const char *ext)
921 {
922 	return kex_has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
923 }
924 
925 static int
926 kex_choose_conf(struct ssh *ssh, uint32_t seq)
927 {
928 	struct kex *kex = ssh->kex;
929 	struct newkeys *newkeys;
930 	char **my = NULL, **peer = NULL;
931 	char **cprop, **sprop;
932 	int nenc, nmac, ncomp;
933 	u_int mode, ctos, need, dh_need, authlen;
934 	int log_flag = 0;
935 	int r, first_kex_follows;
936 
937 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
938 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
939 		goto out;
940 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
941 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
942 		goto out;
943 
944 	if (kex->server) {
945 		cprop=peer;
946 		sprop=my;
947 	} else {
948 		cprop=my;
949 		sprop=peer;
950 	}
951 
952 	/* Check whether peer supports ext_info/kex_strict */
953 	if ((kex->flags & KEX_INITIAL) != 0) {
954 		if (kex->server) {
955 			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
956 			kex->kex_strict = kexalgs_contains(peer,
957 			    "kex-strict-c-v00@openssh.com");
958 		} else {
959 			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
960 			kex->kex_strict = kexalgs_contains(peer,
961 			    "kex-strict-s-v00@openssh.com");
962 		}
963 		if (kex->kex_strict) {
964 			debug3_f("will use strict KEX ordering");
965 			if (seq != 0)
966 				ssh_packet_disconnect(ssh,
967 				    "strict KEX violation: "
968 				    "KEXINIT was not the first packet");
969 		}
970 	}
971 
972 	/* Check whether client supports rsa-sha2 algorithms */
973 	if (kex->server && (kex->flags & KEX_INITIAL)) {
974 		if (kex_has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
975 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
976 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
977 		if (kex_has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
978 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
979 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
980 	}
981 
982 	/* Algorithm Negotiation */
983 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
984 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
985 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
986 		peer[PROPOSAL_KEX_ALGS] = NULL;
987 		goto out;
988 	}
989 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
990 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
991 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
992 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
993 		goto out;
994 	}
995 	for (mode = 0; mode < MODE_MAX; mode++) {
996 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
997 			r = SSH_ERR_ALLOC_FAIL;
998 			goto out;
999 		}
1000 		kex->newkeys[mode] = newkeys;
1001 		ctos = (!kex->server && mode == MODE_OUT) ||
1002 		    (kex->server && mode == MODE_IN);
1003 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1004 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1005 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1006 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1007 		    sprop[nenc])) != 0) {
1008 			kex->failed_choice = peer[nenc];
1009 			peer[nenc] = NULL;
1010 			goto out;
1011 		}
1012 		authlen = cipher_authlen(newkeys->enc.cipher);
1013 		/* ignore mac for authenticated encryption */
1014 		if (authlen == 0 &&
1015 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1016 		    sprop[nmac])) != 0) {
1017 			kex->failed_choice = peer[nmac];
1018 			peer[nmac] = NULL;
1019 			goto out;
1020 		}
1021 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1022 		    sprop[ncomp])) != 0) {
1023 			kex->failed_choice = peer[ncomp];
1024 			peer[ncomp] = NULL;
1025 			goto out;
1026 		}
1027 		debug("REQUESTED ENC.NAME is '%s'", newkeys->enc.name);
1028 		if (strcmp(newkeys->enc.name, "none") == 0) {
1029 			int auth_flag;
1030 
1031 			auth_flag = ssh_packet_authentication_state(ssh);
1032 			debug("Requesting NONE. Authflag is %d", auth_flag);
1033 			if (auth_flag == 1) {
1034 				debug("None requested post authentication.");
1035 			} else {
1036 				fatal("Pre-authentication none cipher requests are not allowed.");
1037 			}
1038 		}
1039 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1040 		    ctos ? "client->server" : "server->client",
1041 		    newkeys->enc.name,
1042 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1043 		    newkeys->comp.name);
1044 		/* client starts withctos = 0 && log flag = 0 and no log*/
1045 		/* 2nd client pass ctos=1 and flag = 1 so no log*/
1046 		/* server starts with ctos =1 && log_flag = 0 so log */
1047 		/* 2nd sever pass ctos = 1 && log flag = 1 so no log*/
1048 		/* -cjr*/
1049 		if (ctos && !log_flag) {
1050 			logit("SSH: Server;Ltype: Kex;Remote: %s-%d;Enc: %s;MAC: %s;Comp: %s",
1051 			      ssh_remote_ipaddr(ssh),
1052 			      ssh_remote_port(ssh),
1053 			      newkeys->enc.name,
1054 			      authlen == 0 ? newkeys->mac.name : "<implicit>",
1055 			      newkeys->comp.name);
1056 		}
1057 		log_flag = 1;
1058 	}
1059 	need = dh_need = 0;
1060 	for (mode = 0; mode < MODE_MAX; mode++) {
1061 		newkeys = kex->newkeys[mode];
1062 		need = MAXIMUM(need, newkeys->enc.key_len);
1063 		need = MAXIMUM(need, newkeys->enc.block_size);
1064 		need = MAXIMUM(need, newkeys->enc.iv_len);
1065 		need = MAXIMUM(need, newkeys->mac.key_len);
1066 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1067 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1068 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1069 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1070 	}
1071 	/* XXX need runden? */
1072 	kex->we_need = need;
1073 	kex->dh_need = dh_need;
1074 
1075 	/* ignore the next message if the proposals do not match */
1076 	if (first_kex_follows && !proposals_match(my, peer))
1077 		ssh->dispatch_skip_packets = 1;
1078 	r = 0;
1079  out:
1080 	kex_prop_free(my);
1081 	kex_prop_free(peer);
1082 	return r;
1083 }
1084 
1085 static int
1086 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1087     const struct sshbuf *shared_secret, u_char **keyp)
1088 {
1089 	struct kex *kex = ssh->kex;
1090 	struct ssh_digest_ctx *hashctx = NULL;
1091 	char c = id;
1092 	u_int have;
1093 	size_t mdsz;
1094 	u_char *digest;
1095 	int r;
1096 
1097 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1098 		return SSH_ERR_INVALID_ARGUMENT;
1099 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1100 		r = SSH_ERR_ALLOC_FAIL;
1101 		goto out;
1102 	}
1103 
1104 	/* K1 = HASH(K || H || "A" || session_id) */
1105 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1106 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1107 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1108 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1109 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1110 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1111 		r = SSH_ERR_LIBCRYPTO_ERROR;
1112 		error_f("KEX hash failed");
1113 		goto out;
1114 	}
1115 	ssh_digest_free(hashctx);
1116 	hashctx = NULL;
1117 
1118 	/*
1119 	 * expand key:
1120 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1121 	 * Key = K1 || K2 || ... || Kn
1122 	 */
1123 	for (have = mdsz; need > have; have += mdsz) {
1124 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1125 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1126 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1127 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1128 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1129 			error_f("KDF failed");
1130 			r = SSH_ERR_LIBCRYPTO_ERROR;
1131 			goto out;
1132 		}
1133 		ssh_digest_free(hashctx);
1134 		hashctx = NULL;
1135 	}
1136 #ifdef DEBUG_KEX
1137 	fprintf(stderr, "key '%c'== ", c);
1138 	dump_digest("key", digest, need);
1139 #endif
1140 	*keyp = digest;
1141 	digest = NULL;
1142 	r = 0;
1143  out:
1144 	free(digest);
1145 	ssh_digest_free(hashctx);
1146 	return r;
1147 }
1148 
1149 #define NKEYS	6
1150 int
1151 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1152     const struct sshbuf *shared_secret)
1153 {
1154 	struct kex *kex = ssh->kex;
1155 	u_char *keys[NKEYS];
1156 	u_int i, j, mode, ctos;
1157 	int r;
1158 
1159 	/* save initial hash as session id */
1160 	if ((kex->flags & KEX_INITIAL) != 0) {
1161 		if (sshbuf_len(kex->session_id) != 0) {
1162 			error_f("already have session ID at kex");
1163 			return SSH_ERR_INTERNAL_ERROR;
1164 		}
1165 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1166 			return r;
1167 	} else if (sshbuf_len(kex->session_id) == 0) {
1168 		error_f("no session ID in rekex");
1169 		return SSH_ERR_INTERNAL_ERROR;
1170 	}
1171 	for (i = 0; i < NKEYS; i++) {
1172 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1173 		    shared_secret, &keys[i])) != 0) {
1174 			for (j = 0; j < i; j++)
1175 				free(keys[j]);
1176 			return r;
1177 		}
1178 	}
1179 	for (mode = 0; mode < MODE_MAX; mode++) {
1180 		ctos = (!kex->server && mode == MODE_OUT) ||
1181 		    (kex->server && mode == MODE_IN);
1182 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1183 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1184 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1185 	}
1186 	return 0;
1187 }
1188 
1189 int
1190 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1191 {
1192 	struct kex *kex = ssh->kex;
1193 
1194 	*pubp = NULL;
1195 	*prvp = NULL;
1196 	if (kex->load_host_public_key == NULL ||
1197 	    kex->load_host_private_key == NULL) {
1198 		error_f("missing hostkey loader");
1199 		return SSH_ERR_INVALID_ARGUMENT;
1200 	}
1201 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1202 	    kex->hostkey_nid, ssh);
1203 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1204 	    kex->hostkey_nid, ssh);
1205 	if (*pubp == NULL)
1206 		return SSH_ERR_NO_HOSTKEY_LOADED;
1207 	return 0;
1208 }
1209 
1210 int
1211 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1212 {
1213 	struct kex *kex = ssh->kex;
1214 
1215 	if (kex->verify_host_key == NULL) {
1216 		error_f("missing hostkey verifier");
1217 		return SSH_ERR_INVALID_ARGUMENT;
1218 	}
1219 	if (server_host_key->type != kex->hostkey_type ||
1220 	    (kex->hostkey_type == KEY_ECDSA &&
1221 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1222 		return SSH_ERR_KEY_TYPE_MISMATCH;
1223 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1224 		return  SSH_ERR_SIGNATURE_INVALID;
1225 	return 0;
1226 }
1227 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: write the label "msg" on its own line to stderr, then
 * dump "len" bytes of "digest" there via sshbuf_dump_data().  Only built
 * when one of the KEX debug macros is defined.
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	FILE *out = stderr;

	fprintf(out, "%s\n", msg);
	sshbuf_dump_data(digest, len, out);
}
#endif
1236 
1237 /*
1238  * Send a plaintext error message to the peer, suffixed by \r\n.
1239  * Only used during banner exchange, and there only for the server.
1240  */
1241 static void
1242 send_error(struct ssh *ssh, const char *msg)
1243 {
1244 	const char *crnl = "\r\n";
1245 
1246 	if (!ssh->kex->server)
1247 		return;
1248 
1249 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1250 	    __UNCONST(msg), strlen(msg)) != strlen(msg) ||
1251 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1252 	    __UNCONST(crnl), strlen(crnl)) != strlen(crnl))
1253 		error_f("write: %.100s", strerror(errno));
1254 }
1255 
1256 /*
1257  * Sends our identification string and waits for the peer's. Will block for
1258  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1259  * Returns on 0 success or a ssherr.h code on failure.
1260  */
1261 int
1262 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1263     const char *version_addendum)
1264 {
1265 	int remote_major, remote_minor, mismatch, oerrno = 0;
1266 	size_t len, n;
1267 	int r, expect_nl;
1268 	u_char c;
1269 	struct sshbuf *our_version = ssh->kex->server ?
1270 	    ssh->kex->server_version : ssh->kex->client_version;
1271 	struct sshbuf *peer_version = ssh->kex->server ?
1272 	    ssh->kex->client_version : ssh->kex->server_version;
1273 	char *our_version_string = NULL, *peer_version_string = NULL;
1274 	char *cp, *remote_version = NULL;
1275 
1276 	/* Prepare and send our banner */
1277 	sshbuf_reset(our_version);
1278 	if (version_addendum != NULL && *version_addendum == '\0')
1279 		version_addendum = NULL;
1280 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1281 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1282 	    version_addendum == NULL ? "" : " ",
1283 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1284 		oerrno = errno;
1285 		error_fr(r, "sshbuf_putf");
1286 		goto out;
1287 	}
1288 
1289 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1290 	    sshbuf_mutable_ptr(our_version),
1291 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1292 		oerrno = errno;
1293 		debug_f("write: %.100s", strerror(errno));
1294 		r = SSH_ERR_SYSTEM_ERROR;
1295 		goto out;
1296 	}
1297 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1298 		oerrno = errno;
1299 		error_fr(r, "sshbuf_consume_end");
1300 		goto out;
1301 	}
1302 	our_version_string = sshbuf_dup_string(our_version);
1303 	if (our_version_string == NULL) {
1304 		error_f("sshbuf_dup_string failed");
1305 		r = SSH_ERR_ALLOC_FAIL;
1306 		goto out;
1307 	}
1308 	debug("Local version string %.100s", our_version_string);
1309 
1310 	/* Read other side's version identification. */
1311 	for (n = 0; ; n++) {
1312 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1313 			send_error(ssh, "No SSH identification string "
1314 			    "received.");
1315 			error_f("No SSH version received in first %u lines "
1316 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1317 			r = SSH_ERR_INVALID_FORMAT;
1318 			goto out;
1319 		}
1320 		sshbuf_reset(peer_version);
1321 		expect_nl = 0;
1322 		for (;;) {
1323 			if (timeout_ms > 0) {
1324 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1325 				    &timeout_ms, NULL);
1326 				if (r == -1 && errno == ETIMEDOUT) {
1327 					send_error(ssh, "Timed out waiting "
1328 					    "for SSH identification string.");
1329 					error("Connection timed out during "
1330 					    "banner exchange");
1331 					r = SSH_ERR_CONN_TIMEOUT;
1332 					goto out;
1333 				} else if (r == -1) {
1334 					oerrno = errno;
1335 					error_f("%s", strerror(errno));
1336 					r = SSH_ERR_SYSTEM_ERROR;
1337 					goto out;
1338 				}
1339 			}
1340 
1341 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1342 			    &c, 1);
1343 			if (len != 1 && errno == EPIPE) {
1344 				verbose_f("Connection closed by remote host");
1345 				r = SSH_ERR_CONN_CLOSED;
1346 				goto out;
1347 			} else if (len != 1) {
1348 				oerrno = errno;
1349 				error_f("read: %.100s", strerror(errno));
1350 				r = SSH_ERR_SYSTEM_ERROR;
1351 				goto out;
1352 			}
1353 			if (c == '\r') {
1354 				expect_nl = 1;
1355 				continue;
1356 			}
1357 			if (c == '\n')
1358 				break;
1359 			if (c == '\0' || expect_nl) {
1360 				verbose_f("banner line contains invalid "
1361 				    "characters");
1362 				goto invalid;
1363 			}
1364 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1365 				oerrno = errno;
1366 				error_fr(r, "sshbuf_put");
1367 				goto out;
1368 			}
1369 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1370 				verbose_f("banner line too long");
1371 				goto invalid;
1372 			}
1373 		}
1374 		/* Is this an actual protocol banner? */
1375 		if (sshbuf_len(peer_version) > 4 &&
1376 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1377 			break;
1378 		/* If not, then just log the line and continue */
1379 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1380 			error_f("sshbuf_dup_string failed");
1381 			r = SSH_ERR_ALLOC_FAIL;
1382 			goto out;
1383 		}
1384 		/* Do not accept lines before the SSH ident from a client */
1385 		if (ssh->kex->server) {
1386 			verbose_f("client sent invalid protocol identifier "
1387 			    "\"%.256s\"", cp);
1388 			free(cp);
1389 			goto invalid;
1390 		}
1391 		debug_f("banner line %zu: %s", n, cp);
1392 		free(cp);
1393 	}
1394 	peer_version_string = sshbuf_dup_string(peer_version);
1395 	if (peer_version_string == NULL)
1396 		fatal_f("sshbuf_dup_string failed");
1397 	/* XXX must be same size for sscanf */
1398 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1399 		error_f("calloc failed");
1400 		r = SSH_ERR_ALLOC_FAIL;
1401 		goto out;
1402 	}
1403 
1404 	/*
1405 	 * Check that the versions match.  In future this might accept
1406 	 * several versions and set appropriate flags to handle them.
1407 	 */
1408 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1409 	    &remote_major, &remote_minor, remote_version) != 3) {
1410 		error("Bad remote protocol version identification: '%.100s'",
1411 		    peer_version_string);
1412  invalid:
1413 		send_error(ssh, "Invalid SSH identification string.");
1414 		r = SSH_ERR_INVALID_FORMAT;
1415 		goto out;
1416 	}
1417 	debug("Remote protocol version %d.%d, remote software version %.100s",
1418 	    remote_major, remote_minor, remote_version);
1419 	compat_banner(ssh, remote_version);
1420 
1421 	mismatch = 0;
1422 	switch (remote_major) {
1423 	case 2:
1424 		break;
1425 	case 1:
1426 		if (remote_minor != 99)
1427 			mismatch = 1;
1428 		break;
1429 	default:
1430 		mismatch = 1;
1431 		break;
1432 	}
1433 	if (mismatch) {
1434 		error("Protocol major versions differ: %d vs. %d",
1435 		    PROTOCOL_MAJOR_2, remote_major);
1436 		send_error(ssh, "Protocol major versions differ.");
1437 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1438 		goto out;
1439 	}
1440 
1441 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1442 		logit("probed from %s port %d with %s.  Don't panic.",
1443 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1444 		    peer_version_string);
1445 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1446 		goto out;
1447 	}
1448 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1449 		logit("scanned from %s port %d with %s.  Don't panic.",
1450 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1451 		    peer_version_string);
1452 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1453 		goto out;
1454 	}
1455 	/* success */
1456 	r = 0;
1457  out:
1458 	free(our_version_string);
1459 	free(peer_version_string);
1460 	free(remote_version);
1461 	if (r == SSH_ERR_SYSTEM_ERROR)
1462 		errno = oerrno;
1463 	return r;
1464 }
1465 
1466