xref: /openbsd-src/usr.bin/ssh/kex.c (revision d59bb9942320b767f2a19aaa7690c8c6e30b724c)
1 /* $OpenBSD: kex.c,v 1.128 2017/02/03 23:01:19 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 
27 #include <signal.h>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 
32 #ifdef WITH_OPENSSL
33 #include <openssl/crypto.h>
34 #endif
35 
36 #include "ssh2.h"
37 #include "packet.h"
38 #include "compat.h"
39 #include "cipher.h"
40 #include "sshkey.h"
41 #include "kex.h"
42 #include "log.h"
43 #include "mac.h"
44 #include "match.h"
45 #include "misc.h"
46 #include "dispatch.h"
47 #include "monitor.h"
48 
49 #include "ssherr.h"
50 #include "sshbuf.h"
51 #include "digest.h"
52 
53 /* prototype */
54 static int kex_choose_conf(struct ssh *);
55 static int kex_input_newkeys(int, u_int32_t, void *);
56 
57 static const char *proposal_names[PROPOSAL_MAX] = {
58 	"KEX algorithms",
59 	"host key algorithms",
60 	"ciphers ctos",
61 	"ciphers stoc",
62 	"MACs ctos",
63 	"MACs stoc",
64 	"compression ctos",
65 	"compression stoc",
66 	"languages ctos",
67 	"languages stoc",
68 };
69 
70 struct kexalg {
71 	char *name;
72 	u_int type;
73 	int ec_nid;
74 	int hash_alg;
75 };
76 static const struct kexalg kexalgs[] = {
77 #ifdef WITH_OPENSSL
78 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
79 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
80 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
81 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
82 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
83 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
84 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
85 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
86 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
87 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
88 	    SSH_DIGEST_SHA384 },
89 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
90 	    SSH_DIGEST_SHA512 },
91 #endif
92 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
93 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ NULL, -1, -1, -1},
95 };
96 
97 char *
98 kex_alg_list(char sep)
99 {
100 	char *ret = NULL, *tmp;
101 	size_t nlen, rlen = 0;
102 	const struct kexalg *k;
103 
104 	for (k = kexalgs; k->name != NULL; k++) {
105 		if (ret != NULL)
106 			ret[rlen++] = sep;
107 		nlen = strlen(k->name);
108 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
109 			free(ret);
110 			return NULL;
111 		}
112 		ret = tmp;
113 		memcpy(ret + rlen, k->name, nlen + 1);
114 		rlen += nlen;
115 	}
116 	return ret;
117 }
118 
119 static const struct kexalg *
120 kex_alg_by_name(const char *name)
121 {
122 	const struct kexalg *k;
123 
124 	for (k = kexalgs; k->name != NULL; k++) {
125 		if (strcmp(k->name, name) == 0)
126 			return k;
127 	}
128 	return NULL;
129 }
130 
/*
 * Validate a comma-separated KEX method list. Returns 1 when every name
 * (up to the first empty element) is a supported method. Returns 0 for a
 * NULL or empty list, on allocation failure, or on the first unsupported
 * name, which is logged.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *cursor, *alg;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	cursor = copy;
	while ((alg = strsep(&cursor, ",")) != NULL && *alg != '\0') {
		if (kex_alg_by_name(alg) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", alg);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
153 
/*
 * Concatenate algorithm names, avoiding duplicates in the process:
 * returns 'a' followed by each name from 'b' that 'a' does not already
 * contain. Parsing of 'b' stops at its first empty element.
 * Returns a newly allocated string the caller must free, or NULL if 'a'
 * is NULL/empty, 'b' is over 1MB, or on allocation failure.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	if (a == NULL || *a == '\0')
		return NULL;
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* worst case: all of a, a separator, all of b, NUL */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		/* match_list() returns an allocated copy of the match: free it */
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
190 
191 /*
192  * Assemble a list of algorithms from a default list and a string from a
193  * configuration file. The user-provided string may begin with '+' to
194  * indicate that it should be appended to the default or '-' that the
195  * specified names should be removed.
196  */
197 int
198 kex_assemble_names(const char *def, char **list)
199 {
200 	char *ret;
201 
202 	if (list == NULL || *list == NULL || **list == '\0') {
203 		*list = strdup(def);
204 		return 0;
205 	}
206 	if (**list == '+') {
207 		if ((ret = kex_names_cat(def, *list + 1)) == NULL)
208 			return SSH_ERR_ALLOC_FAIL;
209 		free(*list);
210 		*list = ret;
211 	} else if (**list == '-') {
212 		if ((ret = match_filter_list(def, *list + 1)) == NULL)
213 			return SSH_ERR_ALLOC_FAIL;
214 		free(*list);
215 		*list = ret;
216 	}
217 
218 	return 0;
219 }
220 
221 /* put algorithm proposal into buffer */
222 int
223 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
224 {
225 	u_int i;
226 	int r;
227 
228 	sshbuf_reset(b);
229 
230 	/*
231 	 * add a dummy cookie, the cookie will be overwritten by
232 	 * kex_send_kexinit(), each time a kexinit is set
233 	 */
234 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
235 		if ((r = sshbuf_put_u8(b, 0)) != 0)
236 			return r;
237 	}
238 	for (i = 0; i < PROPOSAL_MAX; i++) {
239 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
240 			return r;
241 	}
242 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
243 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
244 		return r;
245 	return 0;
246 }
247 
248 /* parse buffer and return algorithm proposal */
249 int
250 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
251 {
252 	struct sshbuf *b = NULL;
253 	u_char v;
254 	u_int i;
255 	char **proposal = NULL;
256 	int r;
257 
258 	*propp = NULL;
259 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
260 		return SSH_ERR_ALLOC_FAIL;
261 	if ((b = sshbuf_fromb(raw)) == NULL) {
262 		r = SSH_ERR_ALLOC_FAIL;
263 		goto out;
264 	}
265 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
266 		goto out;
267 	/* extract kex init proposal strings */
268 	for (i = 0; i < PROPOSAL_MAX; i++) {
269 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
270 			goto out;
271 		debug2("%s: %s", proposal_names[i], proposal[i]);
272 	}
273 	/* first kex follows / reserved */
274 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
275 	    (r = sshbuf_get_u32(b, &i)) != 0)	/* reserved */
276 		goto out;
277 	if (first_kex_follows != NULL)
278 		*first_kex_follows = v;
279 	debug2("first_kex_follows %d ", v);
280 	debug2("reserved %u ", i);
281 	r = 0;
282 	*propp = proposal;
283  out:
284 	if (r != 0 && proposal != NULL)
285 		kex_prop_free(proposal);
286 	sshbuf_free(b);
287 	return r;
288 }
289 
290 void
291 kex_prop_free(char **proposal)
292 {
293 	u_int i;
294 
295 	if (proposal == NULL)
296 		return;
297 	for (i = 0; i < PROPOSAL_MAX; i++)
298 		free(proposal[i]);
299 	free(proposal);
300 }
301 
302 /* ARGSUSED */
303 static int
304 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
305 {
306 	struct ssh *ssh = active_state; /* XXX */
307 	int r;
308 
309 	error("kex protocol error: type %d seq %u", type, seq);
310 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
311 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
312 	    (r = sshpkt_send(ssh)) != 0)
313 		return r;
314 	return 0;
315 }
316 
/*
 * Restore the between-exchanges dispatch state: every transport-layer
 * message type from SSH2_MSG_TRANSPORT_MIN to SSH2_MSG_TRANSPORT_MAX is
 * routed to kex_protocol_error, then SSH2_MSG_KEXINIT is pointed at
 * kex_input_kexinit so a new key exchange can start.
 * The call order matters: presumably KEXINIT lies inside the transport
 * range, so setting it first would be undone by the range call.
 */
static void
kex_reset_dispatch(struct ssh *ssh)
{
	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
}
324 
325 static int
326 kex_send_ext_info(struct ssh *ssh)
327 {
328 	int r;
329 	char *algs;
330 
331 	if ((algs = sshkey_alg_list(0, 1, ',')) == NULL)
332 		return SSH_ERR_ALLOC_FAIL;
333 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
334 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
335 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
336 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
337 	    (r = sshpkt_send(ssh)) != 0)
338 		goto out;
339 	/* success */
340 	r = 0;
341  out:
342 	free(algs);
343 	return r;
344 }
345 
346 int
347 kex_send_newkeys(struct ssh *ssh)
348 {
349 	int r;
350 
351 	kex_reset_dispatch(ssh);
352 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
353 	    (r = sshpkt_send(ssh)) != 0)
354 		return r;
355 	debug("SSH2_MSG_NEWKEYS sent");
356 	debug("expecting SSH2_MSG_NEWKEYS");
357 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
358 	if (ssh->kex->ext_info_c)
359 		if ((r = kex_send_ext_info(ssh)) != 0)
360 			return r;
361 	return 0;
362 }
363 
364 int
365 kex_input_ext_info(int type, u_int32_t seq, void *ctxt)
366 {
367 	struct ssh *ssh = ctxt;
368 	struct kex *kex = ssh->kex;
369 	u_int32_t i, ninfo;
370 	char *name, *val, *found;
371 	int r;
372 
373 	debug("SSH2_MSG_EXT_INFO received");
374 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
375 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
376 		return r;
377 	for (i = 0; i < ninfo; i++) {
378 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
379 			return r;
380 		if ((r = sshpkt_get_cstring(ssh, &val, NULL)) != 0) {
381 			free(name);
382 			return r;
383 		}
384 		debug("%s: %s=<%s>", __func__, name, val);
385 		if (strcmp(name, "server-sig-algs") == 0) {
386 			found = match_list("rsa-sha2-256", val, NULL);
387 			if (found) {
388 				kex->rsa_sha2 = 256;
389 				free(found);
390 			}
391 			found = match_list("rsa-sha2-512", val, NULL);
392 			if (found) {
393 				kex->rsa_sha2 = 512;
394 				free(found);
395 			}
396 		}
397 		free(name);
398 		free(val);
399 	}
400 	return sshpkt_get_end(ssh);
401 }
402 
403 static int
404 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
405 {
406 	struct ssh *ssh = ctxt;
407 	struct kex *kex = ssh->kex;
408 	int r;
409 
410 	debug("SSH2_MSG_NEWKEYS received");
411 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
412 	if ((r = sshpkt_get_end(ssh)) != 0)
413 		return r;
414 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
415 		return r;
416 	kex->done = 1;
417 	sshbuf_reset(kex->peer);
418 	/* sshbuf_reset(kex->my); */
419 	kex->flags &= ~KEX_INIT_SENT;
420 	free(kex->name);
421 	kex->name = NULL;
422 	return 0;
423 }
424 
425 int
426 kex_send_kexinit(struct ssh *ssh)
427 {
428 	u_char *cookie;
429 	struct kex *kex = ssh->kex;
430 	int r;
431 
432 	if (kex == NULL)
433 		return SSH_ERR_INTERNAL_ERROR;
434 	if (kex->flags & KEX_INIT_SENT)
435 		return 0;
436 	kex->done = 0;
437 
438 	/* generate a random cookie */
439 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
440 		return SSH_ERR_INVALID_FORMAT;
441 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
442 		return SSH_ERR_INTERNAL_ERROR;
443 	arc4random_buf(cookie, KEX_COOKIE_LEN);
444 
445 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
446 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
447 	    (r = sshpkt_send(ssh)) != 0)
448 		return r;
449 	debug("SSH2_MSG_KEXINIT sent");
450 	kex->flags |= KEX_INIT_SENT;
451 	return 0;
452 }
453 
454 /* ARGSUSED */
455 int
456 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
457 {
458 	struct ssh *ssh = ctxt;
459 	struct kex *kex = ssh->kex;
460 	const u_char *ptr;
461 	u_int i;
462 	size_t dlen;
463 	int r;
464 
465 	debug("SSH2_MSG_KEXINIT received");
466 	if (kex == NULL)
467 		return SSH_ERR_INVALID_ARGUMENT;
468 
469 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
470 	ptr = sshpkt_ptr(ssh, &dlen);
471 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
472 		return r;
473 
474 	/* discard packet */
475 	for (i = 0; i < KEX_COOKIE_LEN; i++)
476 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
477 			return r;
478 	for (i = 0; i < PROPOSAL_MAX; i++)
479 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
480 			return r;
481 	/*
482 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
483 	 * KEX method has the server move first, but a server might be using
484 	 * a custom method or one that we otherwise don't support. We should
485 	 * be prepared to remember first_kex_follows here so we can eat a
486 	 * packet later.
487 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
488 	 * for cases where the server *doesn't* go first. I guess we should
489 	 * ignore it when it is set for these cases, which is what we do now.
490 	 */
491 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
492 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
493 	    (r = sshpkt_get_end(ssh)) != 0)
494 			return r;
495 
496 	if (!(kex->flags & KEX_INIT_SENT))
497 		if ((r = kex_send_kexinit(ssh)) != 0)
498 			return r;
499 	if ((r = kex_choose_conf(ssh)) != 0)
500 		return r;
501 
502 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
503 		return (kex->kex[kex->kex_type])(ssh);
504 
505 	return SSH_ERR_INTERNAL_ERROR;
506 }
507 
508 int
509 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
510 {
511 	struct kex *kex;
512 	int r;
513 
514 	*kexp = NULL;
515 	if ((kex = calloc(1, sizeof(*kex))) == NULL)
516 		return SSH_ERR_ALLOC_FAIL;
517 	if ((kex->peer = sshbuf_new()) == NULL ||
518 	    (kex->my = sshbuf_new()) == NULL) {
519 		r = SSH_ERR_ALLOC_FAIL;
520 		goto out;
521 	}
522 	if ((r = kex_prop2buf(kex->my, proposal)) != 0)
523 		goto out;
524 	kex->done = 0;
525 	kex_reset_dispatch(ssh);
526 	r = 0;
527 	*kexp = kex;
528  out:
529 	if (r != 0)
530 		kex_free(kex);
531 	return r;
532 }
533 
/*
 * Destroy a newkeys structure. Steps, in order:
 *   1. wipe and free the cipher key and IV;
 *   2. free the cipher and compression names and wipe those structs;
 *   3. call mac_clear() on the MAC state;
 *   4. wipe and free the MAC key and name;
 *   5. zero and free the structure itself.
 * NULL is accepted.
 */
void
kex_free_newkeys(struct newkeys *newkeys)
{
	if (newkeys == NULL)
		return;
	/*
	 * explicit_bzero() rather than memset() so the wipes are not elided.
	 * NOTE(review): only key_len/iv_len bytes are wiped, but these buffers
	 * come from derive_key(), which allocates ROUNDUP(need, digest size)
	 * bytes of derived material - the tail is freed unwiped.
	 */
	if (newkeys->enc.key) {
		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
		free(newkeys->enc.key);
		newkeys->enc.key = NULL;
	}
	if (newkeys->enc.iv) {
		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
		free(newkeys->enc.iv);
		newkeys->enc.iv = NULL;
	}
	free(newkeys->enc.name);
	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
	free(newkeys->comp.name);
	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
	/* presumably releases MAC context state; must run before the wipe below */
	mac_clear(&newkeys->mac);
	if (newkeys->mac.key) {
		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
		free(newkeys->mac.key);
		newkeys->mac.key = NULL;
	}
	free(newkeys->mac.name);
	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
	explicit_bzero(newkeys, sizeof(*newkeys));
	free(newkeys);
}
564 
565 void
566 kex_free(struct kex *kex)
567 {
568 	u_int mode;
569 
570 #ifdef WITH_OPENSSL
571 	if (kex->dh)
572 		DH_free(kex->dh);
573 	if (kex->ec_client_key)
574 		EC_KEY_free(kex->ec_client_key);
575 #endif
576 	for (mode = 0; mode < MODE_MAX; mode++) {
577 		kex_free_newkeys(kex->newkeys[mode]);
578 		kex->newkeys[mode] = NULL;
579 	}
580 	sshbuf_free(kex->peer);
581 	sshbuf_free(kex->my);
582 	free(kex->session_id);
583 	free(kex->client_version_string);
584 	free(kex->server_version_string);
585 	free(kex->failed_choice);
586 	free(kex->hostkey_alg);
587 	free(kex->name);
588 	free(kex);
589 }
590 
591 int
592 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
593 {
594 	int r;
595 
596 	if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
597 		return r;
598 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
599 		kex_free(ssh->kex);
600 		ssh->kex = NULL;
601 		return r;
602 	}
603 	return 0;
604 }
605 
606 /*
607  * Request key re-exchange, returns 0 on success or a ssherr.h error
608  * code otherwise. Must not be called if KEX is incomplete or in-progress.
609  */
610 int
611 kex_start_rekex(struct ssh *ssh)
612 {
613 	if (ssh->kex == NULL) {
614 		error("%s: no kex", __func__);
615 		return SSH_ERR_INTERNAL_ERROR;
616 	}
617 	if (ssh->kex->done == 0) {
618 		error("%s: requested twice", __func__);
619 		return SSH_ERR_INTERNAL_ERROR;
620 	}
621 	ssh->kex->done = 0;
622 	return kex_send_kexinit(ssh);
623 }
624 
625 static int
626 choose_enc(struct sshenc *enc, char *client, char *server)
627 {
628 	char *name = match_list(client, server, NULL);
629 
630 	if (name == NULL)
631 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
632 	if ((enc->cipher = cipher_by_name(name)) == NULL)
633 		return SSH_ERR_INTERNAL_ERROR;
634 	enc->name = name;
635 	enc->enabled = 0;
636 	enc->iv = NULL;
637 	enc->iv_len = cipher_ivlen(enc->cipher);
638 	enc->key = NULL;
639 	enc->key_len = cipher_keylen(enc->cipher);
640 	enc->block_size = cipher_blocksize(enc->cipher);
641 	return 0;
642 }
643 
644 static int
645 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
646 {
647 	char *name = match_list(client, server, NULL);
648 
649 	if (name == NULL)
650 		return SSH_ERR_NO_MAC_ALG_MATCH;
651 	if (mac_setup(mac, name) < 0)
652 		return SSH_ERR_INTERNAL_ERROR;
653 	/* truncate the key */
654 	if (ssh->compat & SSH_BUG_HMAC)
655 		mac->key_len = 16;
656 	mac->name = name;
657 	mac->key = NULL;
658 	mac->enabled = 0;
659 	return 0;
660 }
661 
662 static int
663 choose_comp(struct sshcomp *comp, char *client, char *server)
664 {
665 	char *name = match_list(client, server, NULL);
666 
667 	if (name == NULL)
668 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
669 	if (strcmp(name, "zlib@openssh.com") == 0) {
670 		comp->type = COMP_DELAYED;
671 	} else if (strcmp(name, "zlib") == 0) {
672 		comp->type = COMP_ZLIB;
673 	} else if (strcmp(name, "none") == 0) {
674 		comp->type = COMP_NONE;
675 	} else {
676 		return SSH_ERR_INTERNAL_ERROR;
677 	}
678 	comp->name = name;
679 	return 0;
680 }
681 
682 static int
683 choose_kex(struct kex *k, char *client, char *server)
684 {
685 	const struct kexalg *kexalg;
686 
687 	k->name = match_list(client, server, NULL);
688 
689 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
690 	if (k->name == NULL)
691 		return SSH_ERR_NO_KEX_ALG_MATCH;
692 	if ((kexalg = kex_alg_by_name(k->name)) == NULL)
693 		return SSH_ERR_INTERNAL_ERROR;
694 	k->kex_type = kexalg->type;
695 	k->hash_alg = kexalg->hash_alg;
696 	k->ec_nid = kexalg->ec_nid;
697 	return 0;
698 }
699 
700 static int
701 choose_hostkeyalg(struct kex *k, char *client, char *server)
702 {
703 	k->hostkey_alg = match_list(client, server, NULL);
704 
705 	debug("kex: host key algorithm: %s",
706 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
707 	if (k->hostkey_alg == NULL)
708 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
709 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
710 	if (k->hostkey_type == KEY_UNSPEC)
711 		return SSH_ERR_INTERNAL_ERROR;
712 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
713 	return 0;
714 }
715 
716 static int
717 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
718 {
719 	static int check[] = {
720 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
721 	};
722 	int *idx;
723 	char *p;
724 
725 	for (idx = &check[0]; *idx != -1; idx++) {
726 		if ((p = strchr(my[*idx], ',')) != NULL)
727 			*p = '\0';
728 		if ((p = strchr(peer[*idx], ',')) != NULL)
729 			*p = '\0';
730 		if (strcmp(my[*idx], peer[*idx]) != 0) {
731 			debug2("proposal mismatch: my %s peer %s",
732 			    my[*idx], peer[*idx]);
733 			return (0);
734 		}
735 	}
736 	debug2("proposals match");
737 	return (1);
738 }
739 
/*
 * Negotiate the algorithms for the current key exchange.
 *
 * Parses our KEXINIT (kex->my) and the peer's (kex->peer) and maps them
 * onto client/server roles. It then selects, in order:
 *   - the KEX method;
 *   - the host key algorithm;
 *   - for each direction, a cipher, a MAC (skipped when the cipher has
 *     built-in authentication) and a compression method.
 * Allocates kex->newkeys[mode] for every mode, and records the required key
 * material sizes in kex->we_need and kex->dh_need.
 *
 * On a negotiation failure the peer's offending name-list is moved into
 * kex->failed_choice. Its proposal slot is set to NULL so kex_prop_free()
 * here does not free it; kex_free() releases it later.
 * Returns 0 on success or an ssherr.h error code.
 */
static int
kex_choose_conf(struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* negotiation is defined in client/server terms, not local/peer */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether client supports ext_info_c */
	if (kex->server) {
		char *ext;

		/* signalled by the name "ext-info-c" in the client's KEX list */
		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
		kex->ext_info_c = (ext != NULL);
		free(ext);
	}

	/* Algorithm Negotiation */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		/* owned by kex from here; kex_free() releases it if unused */
		kex->newkeys[mode] = newkeys;
		/* does this mode carry client->server traffic? */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * need: the largest key, block, IV or MAC-key length over both
	 * directions. kex_derive_keys() derives every key with this length.
	 * dh_need: the same, but using cipher_seclen() instead of the cipher
	 * key length. It is stored for the KEX method code; it is not used
	 * in this file.
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX need rounding up? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/*
	 * ignore the next message if the proposals do not match
	 * (the peer's guessed first KEX packet is then unusable).
	 * proposals_match() truncates my/peer in place; they are freed below.
	 */
	if (first_kex_follows && !proposals_match(my, peer) &&
	    !(ssh->compat & SSH_BUG_FIRSTKEX))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
852 
/*
 * Derive one session key of at least 'need' bytes:
 *   K1  = HASH(K || H || id || session_id)
 *   Kn  = HASH(K || H || K1 || ... || Kn-1)
 *   key = K1 || K2 || ...
 * where K is 'shared_secret', H is 'hash' (hashlen bytes), id is the single
 * character 'id', and HASH is kex->hash_alg.
 *
 * On success *keyp receives a calloc'd buffer of ROUNDUP(need, mdsz)
 * bytes, which the caller frees.
 * Returns 0, SSH_ERR_INVALID_ARGUMENT if the digest is unknown,
 * SSH_ERR_ALLOC_FAIL, or SSH_ERR_LIBCRYPTO_ERROR.
 */
static int
derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
    const struct sshbuf *shared_secret, u_char **keyp)
{
	struct kex *kex = ssh->kex;
	struct ssh_digest_ctx *hashctx = NULL;
	char c = id;
	u_int have;
	size_t mdsz;
	u_char *digest;
	int r;

	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
		return SSH_ERR_INVALID_ARGUMENT;
	/* rounded up so each full mdsz-byte hash output fits in the buffer */
	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* K1 = HASH(K || H || "A" || session_id), "A" being the letter c */
	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
	    ssh_digest_update(hashctx, &c, 1) != 0 ||
	    ssh_digest_update(hashctx, kex->session_id,
	    kex->session_id_len) != 0 ||
	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
		r = SSH_ERR_LIBCRYPTO_ERROR;
		goto out;
	}
	ssh_digest_free(hashctx);
	hashctx = NULL;

	/*
	 * expand key:
	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
	 * Key = K1 || K2 || ... || Kn
	 * Each round hashes everything derived so far (digest[0..have)) and
	 * appends the next block at digest + have.
	 */
	for (have = mdsz; need > have; have += mdsz) {
		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
		    ssh_digest_update(hashctx, digest, have) != 0 ||
		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
			r = SSH_ERR_LIBCRYPTO_ERROR;
			goto out;
		}
		ssh_digest_free(hashctx);
		hashctx = NULL;
	}
#ifdef DEBUG_KEX
	fprintf(stderr, "key '%c'== ", c);
	dump_digest("key", digest, need);
#endif
	/* hand ownership to the caller; digest = NULL skips the free below */
	*keyp = digest;
	digest = NULL;
	r = 0;
 out:
	/* NOTE(review): on error, partially derived material is freed unwiped */
	free(digest);
	ssh_digest_free(hashctx);
	return r;
}
915 
916 #define NKEYS	6
917 int
918 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
919     const struct sshbuf *shared_secret)
920 {
921 	struct kex *kex = ssh->kex;
922 	u_char *keys[NKEYS];
923 	u_int i, j, mode, ctos;
924 	int r;
925 
926 	for (i = 0; i < NKEYS; i++) {
927 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
928 		    shared_secret, &keys[i])) != 0) {
929 			for (j = 0; j < i; j++)
930 				free(keys[j]);
931 			return r;
932 		}
933 	}
934 	for (mode = 0; mode < MODE_MAX; mode++) {
935 		ctos = (!kex->server && mode == MODE_OUT) ||
936 		    (kex->server && mode == MODE_IN);
937 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
938 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
939 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
940 	}
941 	return 0;
942 }
943 
#ifdef WITH_OPENSSL
/*
 * Like kex_derive_keys(), but the shared secret K is supplied as a BIGNUM.
 * It is first encoded into a temporary buffer with sshbuf_put_bignum2().
 * Returns 0 or an ssherr.h code.
 */
int
kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
    const BIGNUM *secret)
{
	struct sshbuf *encoded;
	int r;

	encoded = sshbuf_new();
	if (encoded == NULL)
		return SSH_ERR_ALLOC_FAIL;
	r = sshbuf_put_bignum2(encoded, secret);
	if (r == 0)
		r = kex_derive_keys(ssh, hash, hashlen, encoded);
	sshbuf_free(encoded);
	return r;
}
#endif
960 
#ifdef WITH_SSH1
/*
 * Compute the SSH-1 session id:
 *   id = MD5(host modulus || server modulus || 8-byte cookie)
 * Each modulus must encode to between 64 bytes (512 bits) and
 * 2048 bytes. All intermediate buffers are wiped before returning.
 * Returns 0, SSH_ERR_KEY_BITS_MISMATCH, SSH_ERR_ALLOC_FAIL or
 * SSH_ERR_LIBCRYPTO_ERROR.
 */
int
derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
    u_int8_t cookie[8], u_int8_t id[16])
{
	u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
	struct ssh_digest_ctx *md = NULL;
	size_t hlen = BN_num_bytes(host_modulus);
	size_t slen = BN_num_bytes(server_modulus);
	int r = SSH_ERR_LIBCRYPTO_ERROR;

	if (hlen < 512 / 8 || hlen > sizeof(hbuf) ||
	    slen < 512 / 8 || slen > sizeof(sbuf))
		return SSH_ERR_KEY_BITS_MISMATCH;

	if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
	    BN_bn2bin(server_modulus, sbuf) <= 0)
		goto done;
	if ((md = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if (ssh_digest_update(md, hbuf, hlen) != 0 ||
	    ssh_digest_update(md, sbuf, slen) != 0 ||
	    ssh_digest_update(md, cookie, 8) != 0 ||
	    ssh_digest_final(md, obuf, sizeof(obuf)) != 0)
		goto done;	/* r is still SSH_ERR_LIBCRYPTO_ERROR */

	memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
	r = 0;
 done:
	ssh_digest_free(md);
	explicit_bzero(hbuf, sizeof(hbuf));
	explicit_bzero(sbuf, sizeof(sbuf));
	explicit_bzero(obuf, sizeof(obuf));
	return r;
}
#endif
1002 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: write 'msg' on its own line to stderr, then pass 'len'
 * bytes of 'digest' to sshbuf_dump_data() for output to stderr.
 */
void
dump_digest(char *msg, u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1011