xref: /openbsd-src/usr.bin/ssh/kex.c (revision fe38b55cb0aae270de3f844146814682e8cd345c)
1 /* $OpenBSD: kex.c,v 1.118 2016/05/02 10:26:04 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include <sys/param.h>	/* MAX roundup */
27 
28 #include <signal.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 
33 #ifdef WITH_OPENSSL
34 #include <openssl/crypto.h>
35 #endif
36 
37 #include "ssh2.h"
38 #include "packet.h"
39 #include "compat.h"
40 #include "cipher.h"
41 #include "sshkey.h"
42 #include "kex.h"
43 #include "log.h"
44 #include "mac.h"
45 #include "match.h"
46 #include "misc.h"
47 #include "dispatch.h"
48 #include "monitor.h"
49 
50 #include "ssherr.h"
51 #include "sshbuf.h"
52 #include "digest.h"
53 
/* Forward declarations of file-local functions referenced before definition */
static int kex_input_newkeys(int, u_int32_t, void *);
static int kex_choose_conf(struct ssh *);
57 
58 static const char *proposal_names[PROPOSAL_MAX] = {
59 	"KEX algorithms",
60 	"host key algorithms",
61 	"ciphers ctos",
62 	"ciphers stoc",
63 	"MACs ctos",
64 	"MACs stoc",
65 	"compression ctos",
66 	"compression stoc",
67 	"languages ctos",
68 	"languages stoc",
69 };
70 
71 struct kexalg {
72 	char *name;
73 	u_int type;
74 	int ec_nid;
75 	int hash_alg;
76 };
77 static const struct kexalg kexalgs[] = {
78 #ifdef WITH_OPENSSL
79 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
80 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
81 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
82 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
83 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
84 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
85 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
86 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
87 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
88 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
89 	    SSH_DIGEST_SHA384 },
90 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
91 	    SSH_DIGEST_SHA512 },
92 #endif
93 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
94 	{ NULL, -1, -1, -1},
95 };
96 
97 char *
98 kex_alg_list(char sep)
99 {
100 	char *ret = NULL, *tmp;
101 	size_t nlen, rlen = 0;
102 	const struct kexalg *k;
103 
104 	for (k = kexalgs; k->name != NULL; k++) {
105 		if (ret != NULL)
106 			ret[rlen++] = sep;
107 		nlen = strlen(k->name);
108 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
109 			free(ret);
110 			return NULL;
111 		}
112 		ret = tmp;
113 		memcpy(ret + rlen, k->name, nlen + 1);
114 		rlen += nlen;
115 	}
116 	return ret;
117 }
118 
119 static const struct kexalg *
120 kex_alg_by_name(const char *name)
121 {
122 	const struct kexalg *k;
123 
124 	for (k = kexalgs; k->name != NULL; k++) {
125 		if (strcmp(k->name, name) == 0)
126 			return k;
127 	}
128 	return NULL;
129 }
130 
/*
 * Validate a comma-separated list of KEX method names.
 * Returns 1 if every name examined is a known method, 0 otherwise
 * (including for a NULL or empty list, or if the working copy cannot
 * be allocated). Scanning stops at the first empty element.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *tok;
	int ok = 1;

	if (names == NULL || names[0] == '\0')
		return 0;
	/* strsep() modifies its input, so work on a private copy */
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((tok = strsep(&rest, ",")) != NULL && *tok != '\0') {
		if (kex_alg_by_name(tok) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", tok);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
153 
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 * Names from "b" that are not already present are appended to "a",
 * separated by commas.
 * Returns a newly allocated string that the caller must free, or NULL
 * if "a" is NULL/empty, "b" is longer than 1MB, or allocation fails.
 * If "b" is NULL or empty a copy of "a" is returned.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	if (a == NULL || *a == '\0')
		return NULL;
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* room for both lists, one joining comma and the terminator */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if ((m = match_list(ret, p, NULL)) != NULL) {
			/* match_list() returns an allocated copy; release it */
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
190 
191 /*
192  * Assemble a list of algorithms from a default list and a string from a
193  * configuration file. The user-provided string may begin with '+' to
194  * indicate that it should be appended to the default.
195  */
196 int
197 kex_assemble_names(const char *def, char **list)
198 {
199 	char *ret;
200 
201 	if (list == NULL || *list == NULL || **list == '\0') {
202 		*list = strdup(def);
203 		return 0;
204 	}
205 	if (**list != '+') {
206 		return 0;
207 	}
208 
209 	if ((ret = kex_names_cat(def, *list + 1)) == NULL)
210 		return SSH_ERR_ALLOC_FAIL;
211 	free(*list);
212 	*list = ret;
213 	return 0;
214 }
215 
216 /* put algorithm proposal into buffer */
217 int
218 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
219 {
220 	u_int i;
221 	int r;
222 
223 	sshbuf_reset(b);
224 
225 	/*
226 	 * add a dummy cookie, the cookie will be overwritten by
227 	 * kex_send_kexinit(), each time a kexinit is set
228 	 */
229 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
230 		if ((r = sshbuf_put_u8(b, 0)) != 0)
231 			return r;
232 	}
233 	for (i = 0; i < PROPOSAL_MAX; i++) {
234 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
235 			return r;
236 	}
237 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
238 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
239 		return r;
240 	return 0;
241 }
242 
243 /* parse buffer and return algorithm proposal */
244 int
245 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
246 {
247 	struct sshbuf *b = NULL;
248 	u_char v;
249 	u_int i;
250 	char **proposal = NULL;
251 	int r;
252 
253 	*propp = NULL;
254 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
255 		return SSH_ERR_ALLOC_FAIL;
256 	if ((b = sshbuf_fromb(raw)) == NULL) {
257 		r = SSH_ERR_ALLOC_FAIL;
258 		goto out;
259 	}
260 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
261 		goto out;
262 	/* extract kex init proposal strings */
263 	for (i = 0; i < PROPOSAL_MAX; i++) {
264 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
265 			goto out;
266 		debug2("%s: %s", proposal_names[i], proposal[i]);
267 	}
268 	/* first kex follows / reserved */
269 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
270 	    (r = sshbuf_get_u32(b, &i)) != 0)	/* reserved */
271 		goto out;
272 	if (first_kex_follows != NULL)
273 		*first_kex_follows = v;
274 	debug2("first_kex_follows %d ", v);
275 	debug2("reserved %u ", i);
276 	r = 0;
277 	*propp = proposal;
278  out:
279 	if (r != 0 && proposal != NULL)
280 		kex_prop_free(proposal);
281 	sshbuf_free(b);
282 	return r;
283 }
284 
285 void
286 kex_prop_free(char **proposal)
287 {
288 	u_int i;
289 
290 	if (proposal == NULL)
291 		return;
292 	for (i = 0; i < PROPOSAL_MAX; i++)
293 		free(proposal[i]);
294 	free(proposal);
295 }
296 
297 /* ARGSUSED */
298 static int
299 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
300 {
301 	struct ssh *ssh = active_state; /* XXX */
302 	int r;
303 
304 	error("kex protocol error: type %d seq %u", type, seq);
305 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
306 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
307 	    (r = sshpkt_send(ssh)) != 0)
308 		return r;
309 	return 0;
310 }
311 
312 static void
313 kex_reset_dispatch(struct ssh *ssh)
314 {
315 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
316 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
317 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
318 }
319 
320 static int
321 kex_send_ext_info(struct ssh *ssh)
322 {
323 	int r;
324 
325 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
326 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
327 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
328 	    (r = sshpkt_put_cstring(ssh, "rsa-sha2-256,rsa-sha2-512")) != 0 ||
329 	    (r = sshpkt_send(ssh)) != 0)
330 		return r;
331 	return 0;
332 }
333 
334 int
335 kex_send_newkeys(struct ssh *ssh)
336 {
337 	int r;
338 
339 	kex_reset_dispatch(ssh);
340 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
341 	    (r = sshpkt_send(ssh)) != 0)
342 		return r;
343 	debug("SSH2_MSG_NEWKEYS sent");
344 	debug("expecting SSH2_MSG_NEWKEYS");
345 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
346 	if (ssh->kex->ext_info_c)
347 		if ((r = kex_send_ext_info(ssh)) != 0)
348 			return r;
349 	return 0;
350 }
351 
352 int
353 kex_input_ext_info(int type, u_int32_t seq, void *ctxt)
354 {
355 	struct ssh *ssh = ctxt;
356 	struct kex *kex = ssh->kex;
357 	u_int32_t i, ninfo;
358 	char *name, *val, *found;
359 	int r;
360 
361 	debug("SSH2_MSG_EXT_INFO received");
362 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
363 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
364 		return r;
365 	for (i = 0; i < ninfo; i++) {
366 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
367 			return r;
368 		if ((r = sshpkt_get_cstring(ssh, &val, NULL)) != 0) {
369 			free(name);
370 			return r;
371 		}
372 		debug("%s: %s=<%s>", __func__, name, val);
373 		if (strcmp(name, "server-sig-algs") == 0) {
374 			found = match_list("rsa-sha2-256", val, NULL);
375 			if (found) {
376 				kex->rsa_sha2 = 256;
377 				free(found);
378 			}
379 			found = match_list("rsa-sha2-512", val, NULL);
380 			if (found) {
381 				kex->rsa_sha2 = 512;
382 				free(found);
383 			}
384 		}
385 		free(name);
386 		free(val);
387 	}
388 	return sshpkt_get_end(ssh);
389 }
390 
391 static int
392 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
393 {
394 	struct ssh *ssh = ctxt;
395 	struct kex *kex = ssh->kex;
396 	int r;
397 
398 	debug("SSH2_MSG_NEWKEYS received");
399 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
400 	if ((r = sshpkt_get_end(ssh)) != 0)
401 		return r;
402 	kex->done = 1;
403 	sshbuf_reset(kex->peer);
404 	/* sshbuf_reset(kex->my); */
405 	kex->flags &= ~KEX_INIT_SENT;
406 	free(kex->name);
407 	kex->name = NULL;
408 	return 0;
409 }
410 
411 int
412 kex_send_kexinit(struct ssh *ssh)
413 {
414 	u_char *cookie;
415 	struct kex *kex = ssh->kex;
416 	int r;
417 
418 	if (kex == NULL)
419 		return SSH_ERR_INTERNAL_ERROR;
420 	if (kex->flags & KEX_INIT_SENT)
421 		return 0;
422 	kex->done = 0;
423 
424 	/* generate a random cookie */
425 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
426 		return SSH_ERR_INVALID_FORMAT;
427 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
428 		return SSH_ERR_INTERNAL_ERROR;
429 	arc4random_buf(cookie, KEX_COOKIE_LEN);
430 
431 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
432 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
433 	    (r = sshpkt_send(ssh)) != 0)
434 		return r;
435 	debug("SSH2_MSG_KEXINIT sent");
436 	kex->flags |= KEX_INIT_SENT;
437 	return 0;
438 }
439 
440 /* ARGSUSED */
441 int
442 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
443 {
444 	struct ssh *ssh = ctxt;
445 	struct kex *kex = ssh->kex;
446 	const u_char *ptr;
447 	u_int i;
448 	size_t dlen;
449 	int r;
450 
451 	debug("SSH2_MSG_KEXINIT received");
452 	if (kex == NULL)
453 		return SSH_ERR_INVALID_ARGUMENT;
454 
455 	ptr = sshpkt_ptr(ssh, &dlen);
456 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
457 		return r;
458 
459 	/* discard packet */
460 	for (i = 0; i < KEX_COOKIE_LEN; i++)
461 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
462 			return r;
463 	for (i = 0; i < PROPOSAL_MAX; i++)
464 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
465 			return r;
466 	/*
467 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
468 	 * KEX method has the server move first, but a server might be using
469 	 * a custom method or one that we otherwise don't support. We should
470 	 * be prepared to remember first_kex_follows here so we can eat a
471 	 * packet later.
472 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
473 	 * for cases where the server *doesn't* go first. I guess we should
474 	 * ignore it when it is set for these cases, which is what we do now.
475 	 */
476 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
477 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
478 	    (r = sshpkt_get_end(ssh)) != 0)
479 			return r;
480 
481 	if (!(kex->flags & KEX_INIT_SENT))
482 		if ((r = kex_send_kexinit(ssh)) != 0)
483 			return r;
484 	if ((r = kex_choose_conf(ssh)) != 0)
485 		return r;
486 
487 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
488 		return (kex->kex[kex->kex_type])(ssh);
489 
490 	return SSH_ERR_INTERNAL_ERROR;
491 }
492 
493 int
494 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
495 {
496 	struct kex *kex;
497 	int r;
498 
499 	*kexp = NULL;
500 	if ((kex = calloc(1, sizeof(*kex))) == NULL)
501 		return SSH_ERR_ALLOC_FAIL;
502 	if ((kex->peer = sshbuf_new()) == NULL ||
503 	    (kex->my = sshbuf_new()) == NULL) {
504 		r = SSH_ERR_ALLOC_FAIL;
505 		goto out;
506 	}
507 	if ((r = kex_prop2buf(kex->my, proposal)) != 0)
508 		goto out;
509 	kex->done = 0;
510 	kex_reset_dispatch(ssh);
511 	r = 0;
512 	*kexp = kex;
513  out:
514 	if (r != 0)
515 		kex_free(kex);
516 	return r;
517 }
518 
519 void
520 kex_free_newkeys(struct newkeys *newkeys)
521 {
522 	if (newkeys == NULL)
523 		return;
524 	if (newkeys->enc.key) {
525 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
526 		free(newkeys->enc.key);
527 		newkeys->enc.key = NULL;
528 	}
529 	if (newkeys->enc.iv) {
530 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
531 		free(newkeys->enc.iv);
532 		newkeys->enc.iv = NULL;
533 	}
534 	free(newkeys->enc.name);
535 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
536 	free(newkeys->comp.name);
537 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
538 	mac_clear(&newkeys->mac);
539 	if (newkeys->mac.key) {
540 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
541 		free(newkeys->mac.key);
542 		newkeys->mac.key = NULL;
543 	}
544 	free(newkeys->mac.name);
545 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
546 	explicit_bzero(newkeys, sizeof(*newkeys));
547 	free(newkeys);
548 }
549 
550 void
551 kex_free(struct kex *kex)
552 {
553 	u_int mode;
554 
555 #ifdef WITH_OPENSSL
556 	if (kex->dh)
557 		DH_free(kex->dh);
558 	if (kex->ec_client_key)
559 		EC_KEY_free(kex->ec_client_key);
560 #endif
561 	for (mode = 0; mode < MODE_MAX; mode++) {
562 		kex_free_newkeys(kex->newkeys[mode]);
563 		kex->newkeys[mode] = NULL;
564 	}
565 	sshbuf_free(kex->peer);
566 	sshbuf_free(kex->my);
567 	free(kex->session_id);
568 	free(kex->client_version_string);
569 	free(kex->server_version_string);
570 	free(kex->failed_choice);
571 	free(kex->hostkey_alg);
572 	free(kex->name);
573 	free(kex);
574 }
575 
576 int
577 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
578 {
579 	int r;
580 
581 	if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
582 		return r;
583 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
584 		kex_free(ssh->kex);
585 		ssh->kex = NULL;
586 		return r;
587 	}
588 	return 0;
589 }
590 
591 /*
592  * Request key re-exchange, returns 0 on success or a ssherr.h error
593  * code otherwise. Must not be called if KEX is incomplete or in-progress.
594  */
595 int
596 kex_start_rekex(struct ssh *ssh)
597 {
598 	if (ssh->kex == NULL) {
599 		error("%s: no kex", __func__);
600 		return SSH_ERR_INTERNAL_ERROR;
601 	}
602 	if (ssh->kex->done == 0) {
603 		error("%s: requested twice", __func__);
604 		return SSH_ERR_INTERNAL_ERROR;
605 	}
606 	ssh->kex->done = 0;
607 	return kex_send_kexinit(ssh);
608 }
609 
610 static int
611 choose_enc(struct sshenc *enc, char *client, char *server)
612 {
613 	char *name = match_list(client, server, NULL);
614 
615 	if (name == NULL)
616 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
617 	if ((enc->cipher = cipher_by_name(name)) == NULL)
618 		return SSH_ERR_INTERNAL_ERROR;
619 	enc->name = name;
620 	enc->enabled = 0;
621 	enc->iv = NULL;
622 	enc->iv_len = cipher_ivlen(enc->cipher);
623 	enc->key = NULL;
624 	enc->key_len = cipher_keylen(enc->cipher);
625 	enc->block_size = cipher_blocksize(enc->cipher);
626 	return 0;
627 }
628 
629 static int
630 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
631 {
632 	char *name = match_list(client, server, NULL);
633 
634 	if (name == NULL)
635 		return SSH_ERR_NO_MAC_ALG_MATCH;
636 	if (mac_setup(mac, name) < 0)
637 		return SSH_ERR_INTERNAL_ERROR;
638 	/* truncate the key */
639 	if (ssh->compat & SSH_BUG_HMAC)
640 		mac->key_len = 16;
641 	mac->name = name;
642 	mac->key = NULL;
643 	mac->enabled = 0;
644 	return 0;
645 }
646 
647 static int
648 choose_comp(struct sshcomp *comp, char *client, char *server)
649 {
650 	char *name = match_list(client, server, NULL);
651 
652 	if (name == NULL)
653 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
654 	if (strcmp(name, "zlib@openssh.com") == 0) {
655 		comp->type = COMP_DELAYED;
656 	} else if (strcmp(name, "zlib") == 0) {
657 		comp->type = COMP_ZLIB;
658 	} else if (strcmp(name, "none") == 0) {
659 		comp->type = COMP_NONE;
660 	} else {
661 		return SSH_ERR_INTERNAL_ERROR;
662 	}
663 	comp->name = name;
664 	return 0;
665 }
666 
667 static int
668 choose_kex(struct kex *k, char *client, char *server)
669 {
670 	const struct kexalg *kexalg;
671 
672 	k->name = match_list(client, server, NULL);
673 
674 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
675 	if (k->name == NULL)
676 		return SSH_ERR_NO_KEX_ALG_MATCH;
677 	if ((kexalg = kex_alg_by_name(k->name)) == NULL)
678 		return SSH_ERR_INTERNAL_ERROR;
679 	k->kex_type = kexalg->type;
680 	k->hash_alg = kexalg->hash_alg;
681 	k->ec_nid = kexalg->ec_nid;
682 	return 0;
683 }
684 
685 static int
686 choose_hostkeyalg(struct kex *k, char *client, char *server)
687 {
688 	k->hostkey_alg = match_list(client, server, NULL);
689 
690 	debug("kex: host key algorithm: %s",
691 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
692 	if (k->hostkey_alg == NULL)
693 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
694 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
695 	if (k->hostkey_type == KEY_UNSPEC)
696 		return SSH_ERR_INTERNAL_ERROR;
697 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
698 	return 0;
699 }
700 
701 static int
702 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
703 {
704 	static int check[] = {
705 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
706 	};
707 	int *idx;
708 	char *p;
709 
710 	for (idx = &check[0]; *idx != -1; idx++) {
711 		if ((p = strchr(my[*idx], ',')) != NULL)
712 			*p = '\0';
713 		if ((p = strchr(peer[*idx], ',')) != NULL)
714 			*p = '\0';
715 		if (strcmp(my[*idx], peer[*idx]) != 0) {
716 			debug2("proposal mismatch: my %s peer %s",
717 			    my[*idx], peer[*idx]);
718 			return (0);
719 		}
720 	}
721 	debug2("proposals match");
722 	return (1);
723 }
724 
725 static int
726 kex_choose_conf(struct ssh *ssh)
727 {
728 	struct kex *kex = ssh->kex;
729 	struct newkeys *newkeys;
730 	char **my = NULL, **peer = NULL;
731 	char **cprop, **sprop;
732 	int nenc, nmac, ncomp;
733 	u_int mode, ctos, need, dh_need, authlen;
734 	int r, first_kex_follows;
735 
736 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
737 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
738 		goto out;
739 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
740 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
741 		goto out;
742 
743 	if (kex->server) {
744 		cprop=peer;
745 		sprop=my;
746 	} else {
747 		cprop=my;
748 		sprop=peer;
749 	}
750 
751 	/* Check whether client supports ext_info_c */
752 	if (kex->server) {
753 		char *ext;
754 
755 		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
756 		if (ext) {
757 			kex->ext_info_c = 1;
758 			free(ext);
759 		}
760 	}
761 
762 	/* Algorithm Negotiation */
763 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
764 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
765 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
766 		peer[PROPOSAL_KEX_ALGS] = NULL;
767 		goto out;
768 	}
769 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
770 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
771 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
772 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
773 		goto out;
774 	}
775 	for (mode = 0; mode < MODE_MAX; mode++) {
776 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
777 			r = SSH_ERR_ALLOC_FAIL;
778 			goto out;
779 		}
780 		kex->newkeys[mode] = newkeys;
781 		ctos = (!kex->server && mode == MODE_OUT) ||
782 		    (kex->server && mode == MODE_IN);
783 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
784 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
785 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
786 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
787 		    sprop[nenc])) != 0) {
788 			kex->failed_choice = peer[nenc];
789 			peer[nenc] = NULL;
790 			goto out;
791 		}
792 		authlen = cipher_authlen(newkeys->enc.cipher);
793 		/* ignore mac for authenticated encryption */
794 		if (authlen == 0 &&
795 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
796 		    sprop[nmac])) != 0) {
797 			kex->failed_choice = peer[nmac];
798 			peer[nmac] = NULL;
799 			goto out;
800 		}
801 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
802 		    sprop[ncomp])) != 0) {
803 			kex->failed_choice = peer[ncomp];
804 			peer[ncomp] = NULL;
805 			goto out;
806 		}
807 		debug("kex: %s cipher: %s MAC: %s compression: %s",
808 		    ctos ? "client->server" : "server->client",
809 		    newkeys->enc.name,
810 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
811 		    newkeys->comp.name);
812 	}
813 	need = dh_need = 0;
814 	for (mode = 0; mode < MODE_MAX; mode++) {
815 		newkeys = kex->newkeys[mode];
816 		need = MAX(need, newkeys->enc.key_len);
817 		need = MAX(need, newkeys->enc.block_size);
818 		need = MAX(need, newkeys->enc.iv_len);
819 		need = MAX(need, newkeys->mac.key_len);
820 		dh_need = MAX(dh_need, cipher_seclen(newkeys->enc.cipher));
821 		dh_need = MAX(dh_need, newkeys->enc.block_size);
822 		dh_need = MAX(dh_need, newkeys->enc.iv_len);
823 		dh_need = MAX(dh_need, newkeys->mac.key_len);
824 	}
825 	/* XXX need runden? */
826 	kex->we_need = need;
827 	kex->dh_need = dh_need;
828 
829 	/* ignore the next message if the proposals do not match */
830 	if (first_kex_follows && !proposals_match(my, peer) &&
831 	    !(ssh->compat & SSH_BUG_FIRSTKEX))
832 		ssh->dispatch_skip_packets = 1;
833 	r = 0;
834  out:
835 	kex_prop_free(my);
836 	kex_prop_free(peer);
837 	return r;
838 }
839 
840 static int
841 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
842     const struct sshbuf *shared_secret, u_char **keyp)
843 {
844 	struct kex *kex = ssh->kex;
845 	struct ssh_digest_ctx *hashctx = NULL;
846 	char c = id;
847 	u_int have;
848 	size_t mdsz;
849 	u_char *digest;
850 	int r;
851 
852 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
853 		return SSH_ERR_INVALID_ARGUMENT;
854 	if ((digest = calloc(1, roundup(need, mdsz))) == NULL) {
855 		r = SSH_ERR_ALLOC_FAIL;
856 		goto out;
857 	}
858 
859 	/* K1 = HASH(K || H || "A" || session_id) */
860 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
861 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
862 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
863 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
864 	    ssh_digest_update(hashctx, kex->session_id,
865 	    kex->session_id_len) != 0 ||
866 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
867 		r = SSH_ERR_LIBCRYPTO_ERROR;
868 		goto out;
869 	}
870 	ssh_digest_free(hashctx);
871 	hashctx = NULL;
872 
873 	/*
874 	 * expand key:
875 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
876 	 * Key = K1 || K2 || ... || Kn
877 	 */
878 	for (have = mdsz; need > have; have += mdsz) {
879 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
880 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
881 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
882 		    ssh_digest_update(hashctx, digest, have) != 0 ||
883 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
884 			r = SSH_ERR_LIBCRYPTO_ERROR;
885 			goto out;
886 		}
887 		ssh_digest_free(hashctx);
888 		hashctx = NULL;
889 	}
890 #ifdef DEBUG_KEX
891 	fprintf(stderr, "key '%c'== ", c);
892 	dump_digest("key", digest, need);
893 #endif
894 	*keyp = digest;
895 	digest = NULL;
896 	r = 0;
897  out:
898 	free(digest);
899 	ssh_digest_free(hashctx);
900 	return r;
901 }
902 
903 #define NKEYS	6
904 int
905 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
906     const struct sshbuf *shared_secret)
907 {
908 	struct kex *kex = ssh->kex;
909 	u_char *keys[NKEYS];
910 	u_int i, j, mode, ctos;
911 	int r;
912 
913 	for (i = 0; i < NKEYS; i++) {
914 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
915 		    shared_secret, &keys[i])) != 0) {
916 			for (j = 0; j < i; j++)
917 				free(keys[j]);
918 			return r;
919 		}
920 	}
921 	for (mode = 0; mode < MODE_MAX; mode++) {
922 		ctos = (!kex->server && mode == MODE_OUT) ||
923 		    (kex->server && mode == MODE_IN);
924 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
925 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
926 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
927 	}
928 	return 0;
929 }
930 
931 #ifdef WITH_OPENSSL
932 int
933 kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
934     const BIGNUM *secret)
935 {
936 	struct sshbuf *shared_secret;
937 	int r;
938 
939 	if ((shared_secret = sshbuf_new()) == NULL)
940 		return SSH_ERR_ALLOC_FAIL;
941 	if ((r = sshbuf_put_bignum2(shared_secret, secret)) == 0)
942 		r = kex_derive_keys(ssh, hash, hashlen, shared_secret);
943 	sshbuf_free(shared_secret);
944 	return r;
945 }
946 #endif
947 
948 #ifdef WITH_SSH1
949 int
950 derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
951     u_int8_t cookie[8], u_int8_t id[16])
952 {
953 	u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
954 	struct ssh_digest_ctx *hashctx = NULL;
955 	size_t hlen, slen;
956 	int r;
957 
958 	hlen = BN_num_bytes(host_modulus);
959 	slen = BN_num_bytes(server_modulus);
960 	if (hlen < (512 / 8) || (u_int)hlen > sizeof(hbuf) ||
961 	    slen < (512 / 8) || (u_int)slen > sizeof(sbuf))
962 		return SSH_ERR_KEY_BITS_MISMATCH;
963 	if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
964 	    BN_bn2bin(server_modulus, sbuf) <= 0) {
965 		r = SSH_ERR_LIBCRYPTO_ERROR;
966 		goto out;
967 	}
968 	if ((hashctx = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
969 		r = SSH_ERR_ALLOC_FAIL;
970 		goto out;
971 	}
972 	if (ssh_digest_update(hashctx, hbuf, hlen) != 0 ||
973 	    ssh_digest_update(hashctx, sbuf, slen) != 0 ||
974 	    ssh_digest_update(hashctx, cookie, 8) != 0 ||
975 	    ssh_digest_final(hashctx, obuf, sizeof(obuf)) != 0) {
976 		r = SSH_ERR_LIBCRYPTO_ERROR;
977 		goto out;
978 	}
979 	memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
980 	r = 0;
981  out:
982 	ssh_digest_free(hashctx);
983 	explicit_bzero(hbuf, sizeof(hbuf));
984 	explicit_bzero(sbuf, sizeof(sbuf));
985 	explicit_bzero(obuf, sizeof(obuf));
986 	return r;
987 }
988 #endif
989 
990 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
991 void
992 dump_digest(char *msg, u_char *digest, int len)
993 {
994 	fprintf(stderr, "%s\n", msg);
995 	sshbuf_dump_data(digest, len, stderr);
996 }
997 #endif
998