xref: /openbsd-src/usr.bin/ssh/sshkey-xmss.c (revision c90a81c56dcebd6a1b73fe4aff9b03385b8e63b3)
1 /* $OpenBSD: sshkey-xmss.c,v 1.3 2018/07/09 21:59:10 markus Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
#include <sys/types.h>
#include <sys/uio.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
34 
35 #include "ssh2.h"
36 #include "ssherr.h"
37 #include "sshbuf.h"
38 #include "cipher.h"
39 #include "sshkey.h"
40 #include "sshkey-xmss.h"
41 #include "atomicio.h"
42 
43 #include "xmss_fast.h"
44 
45 /* opaque internal XMSS state */
46 #define XMSS_MAGIC		"xmss-state-v1"
47 #define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
48 struct ssh_xmss_state {
49 	xmss_params	params;
50 	u_int32_t	n, w, h, k;
51 
52 	bds_state	bds;
53 	u_char		*stack;
54 	u_int32_t	stackoffset;
55 	u_char		*stacklevels;
56 	u_char		*auth;
57 	u_char		*keep;
58 	u_char		*th_nodes;
59 	u_char		*retain;
60 	treehash_inst	*treehash;
61 
62 	u_int32_t	idx;		/* state read from file */
63 	u_int32_t	maxidx;		/* restricted # of signatures */
64 	int		have_state;	/* .state file exists */
65 	int		lockfd;		/* locked in sshkey_xmss_get_state() */
66 	int		allow_update;	/* allow sshkey_xmss_update_state() */
67 	char		*enc_ciphername;/* encrypt state with cipher */
68 	u_char		*enc_keyiv;	/* encrypt state with key */
69 	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
70 };
71 
72 int	 sshkey_xmss_init_bds_state(struct sshkey *);
73 int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
74 void	 sshkey_xmss_free_bds(struct sshkey *);
75 int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
76 	    int *, sshkey_printfn *);
77 int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
78 	    struct sshbuf **);
79 int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
80 	    struct sshbuf **);
81 int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
82 int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
83 
/*
 * Emit a diagnostic through the printf-like callback `pr`, which must be
 * in scope wherever PRINT() is expanded; a NULL `pr` suppresses output.
 * Uses standard C99 __VA_ARGS__ rather than the GNU named-variadic form.
 */
#define PRINT(...) do { if (pr) pr(__VA_ARGS__); } while (0)
85 
86 int
87 sshkey_xmss_init(struct sshkey *key, const char *name)
88 {
89 	struct ssh_xmss_state *state;
90 
91 	if (key->xmss_state != NULL)
92 		return SSH_ERR_INVALID_FORMAT;
93 	if (name == NULL)
94 		return SSH_ERR_INVALID_FORMAT;
95 	state = calloc(sizeof(struct ssh_xmss_state), 1);
96 	if (state == NULL)
97 		return SSH_ERR_ALLOC_FAIL;
98 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
99 		state->n = 32;
100 		state->w = 16;
101 		state->h = 10;
102 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
103 		state->n = 32;
104 		state->w = 16;
105 		state->h = 16;
106 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
107 		state->n = 32;
108 		state->w = 16;
109 		state->h = 20;
110 	} else {
111 		free(state);
112 		return SSH_ERR_KEY_TYPE_UNKNOWN;
113 	}
114 	if ((key->xmss_name = strdup(name)) == NULL) {
115 		free(state);
116 		return SSH_ERR_ALLOC_FAIL;
117 	}
118 	state->k = 2;	/* XXX hardcoded */
119 	state->lockfd = -1;
120 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
121 	    state->k) != 0) {
122 		free(state);
123 		return SSH_ERR_INVALID_FORMAT;
124 	}
125 	key->xmss_state = state;
126 	return 0;
127 }
128 
129 void
130 sshkey_xmss_free_state(struct sshkey *key)
131 {
132 	struct ssh_xmss_state *state = key->xmss_state;
133 
134 	sshkey_xmss_free_bds(key);
135 	if (state) {
136 		if (state->enc_keyiv) {
137 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
138 			free(state->enc_keyiv);
139 		}
140 		free(state->enc_ciphername);
141 		free(state);
142 	}
143 	key->xmss_state = NULL;
144 }
145 
/*
 * Tag written at the start of a serialized state, followed by the byte
 * (or element, for stacklevels/treehash) counts of each BDS buffer for
 * state `x`.  The argument is parenthesized, so any pointer expression
 * (e.g. `&st`) can be passed safely.
 */
#define SSH_XMSS_K2_MAGIC	"k=2"
#define num_stack(x)		(((x)->h + 1) * (x)->n)
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		((x)->h * (x)->n)
#define num_keep(x)		(((x)->h >> 1) * (x)->n)
#define num_th_nodes(x)		(((x)->h - (x)->k) * (x)->n)
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * (x)->n)
#define num_treehash(x)		((x)->h - (x)->k)
154 
155 int
156 sshkey_xmss_init_bds_state(struct sshkey *key)
157 {
158 	struct ssh_xmss_state *state = key->xmss_state;
159 	u_int32_t i;
160 
161 	state->stackoffset = 0;
162 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
163 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
164 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
165 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
166 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
167 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
168 	    (state->treehash = calloc(num_treehash(state),
169 	    sizeof(treehash_inst))) == NULL) {
170 		sshkey_xmss_free_bds(key);
171 		return SSH_ERR_ALLOC_FAIL;
172 	}
173 	for (i = 0; i < state->h - state->k; i++)
174 		state->treehash[i].node = &state->th_nodes[state->n*i];
175 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
176 	    state->stacklevels, state->auth, state->keep, state->treehash,
177 	    state->retain, 0);
178 	return 0;
179 }
180 
181 void
182 sshkey_xmss_free_bds(struct sshkey *key)
183 {
184 	struct ssh_xmss_state *state = key->xmss_state;
185 
186 	if (state == NULL)
187 		return;
188 	free(state->stack);
189 	free(state->stacklevels);
190 	free(state->auth);
191 	free(state->keep);
192 	free(state->th_nodes);
193 	free(state->retain);
194 	free(state->treehash);
195 	state->stack = NULL;
196 	state->stacklevels = NULL;
197 	state->auth = NULL;
198 	state->keep = NULL;
199 	state->th_nodes = NULL;
200 	state->retain = NULL;
201 	state->treehash = NULL;
202 }
203 
204 void *
205 sshkey_xmss_params(const struct sshkey *key)
206 {
207 	struct ssh_xmss_state *state = key->xmss_state;
208 
209 	if (state == NULL)
210 		return NULL;
211 	return &state->params;
212 }
213 
214 void *
215 sshkey_xmss_bds_state(const struct sshkey *key)
216 {
217 	struct ssh_xmss_state *state = key->xmss_state;
218 
219 	if (state == NULL)
220 		return NULL;
221 	return &state->bds;
222 }
223 
224 int
225 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
226 {
227 	struct ssh_xmss_state *state = key->xmss_state;
228 
229 	if (lenp == NULL)
230 		return SSH_ERR_INVALID_ARGUMENT;
231 	if (state == NULL)
232 		return SSH_ERR_INVALID_FORMAT;
233 	*lenp = 4 + state->n +
234 	    state->params.wots_par.keysize +
235 	    state->h * state->n;
236 	return 0;
237 }
238 
239 size_t
240 sshkey_xmss_pklen(const struct sshkey *key)
241 {
242 	struct ssh_xmss_state *state = key->xmss_state;
243 
244 	if (state == NULL)
245 		return 0;
246 	return state->n * 2;
247 }
248 
249 size_t
250 sshkey_xmss_sklen(const struct sshkey *key)
251 {
252 	struct ssh_xmss_state *state = key->xmss_state;
253 
254 	if (state == NULL)
255 		return 0;
256 	return state->n * 4 + 4;
257 }
258 
259 int
260 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
261 {
262 	struct ssh_xmss_state *state = k->xmss_state;
263 	const struct sshcipher *cipher;
264 	size_t keylen = 0, ivlen = 0;
265 
266 	if (state == NULL)
267 		return SSH_ERR_INVALID_ARGUMENT;
268 	if ((cipher = cipher_by_name(ciphername)) == NULL)
269 		return SSH_ERR_INTERNAL_ERROR;
270 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
271 		return SSH_ERR_ALLOC_FAIL;
272 	keylen = cipher_keylen(cipher);
273 	ivlen = cipher_ivlen(cipher);
274 	state->enc_keyiv_len = keylen + ivlen;
275 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
276 		free(state->enc_ciphername);
277 		state->enc_ciphername = NULL;
278 		return SSH_ERR_ALLOC_FAIL;
279 	}
280 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
281 	return 0;
282 }
283 
284 int
285 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
286 {
287 	struct ssh_xmss_state *state = k->xmss_state;
288 	int r;
289 
290 	if (state == NULL || state->enc_keyiv == NULL ||
291 	    state->enc_ciphername == NULL)
292 		return SSH_ERR_INVALID_ARGUMENT;
293 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
294 	    (r = sshbuf_put_string(b, state->enc_keyiv,
295 	    state->enc_keyiv_len)) != 0)
296 		return r;
297 	return 0;
298 }
299 
300 int
301 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
302 {
303 	struct ssh_xmss_state *state = k->xmss_state;
304 	size_t len;
305 	int r;
306 
307 	if (state == NULL)
308 		return SSH_ERR_INVALID_ARGUMENT;
309 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
310 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
311 		return r;
312 	state->enc_keyiv_len = len;
313 	return 0;
314 }
315 
316 int
317 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
318     enum sshkey_serialize_rep opts)
319 {
320 	struct ssh_xmss_state *state = k->xmss_state;
321 	u_char have_info = 1;
322 	u_int32_t idx;
323 	int r;
324 
325 	if (state == NULL)
326 		return SSH_ERR_INVALID_ARGUMENT;
327 	if (opts != SSHKEY_SERIALIZE_INFO)
328 		return 0;
329 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
330 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
331 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
332 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
333 		return r;
334 	return 0;
335 }
336 
337 int
338 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
339 {
340 	struct ssh_xmss_state *state = k->xmss_state;
341 	u_char have_info;
342 	int r;
343 
344 	if (state == NULL)
345 		return SSH_ERR_INVALID_ARGUMENT;
346 	/* optional */
347 	if (sshbuf_len(b) == 0)
348 		return 0;
349 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
350 		return r;
351 	if (have_info != 1)
352 		return SSH_ERR_INVALID_ARGUMENT;
353 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
354 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
355 		return r;
356 	return 0;
357 }
358 
359 int
360 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
361 {
362 	int r;
363 	const char *name;
364 
365 	if (bits == 10) {
366 		name = XMSS_SHA2_256_W16_H10_NAME;
367 	} else if (bits == 16) {
368 		name = XMSS_SHA2_256_W16_H16_NAME;
369 	} else if (bits == 20) {
370 		name = XMSS_SHA2_256_W16_H20_NAME;
371 	} else {
372 		name = XMSS_DEFAULT_NAME;
373 	}
374 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
375 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
376 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
377 		return r;
378 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
379 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
380 		return SSH_ERR_ALLOC_FAIL;
381 	}
382 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
383 	    sshkey_xmss_params(k));
384 	return 0;
385 }
386 
387 int
388 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
389     int *have_file, sshkey_printfn *pr)
390 {
391 	struct sshbuf *b = NULL, *enc = NULL;
392 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
393 	u_int32_t len;
394 	unsigned char buf[4], *data = NULL;
395 
396 	*have_file = 0;
397 	if ((fd = open(filename, O_RDONLY)) >= 0) {
398 		*have_file = 1;
399 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
400 			PRINT("%s: corrupt state file: %s", __func__, filename);
401 			goto done;
402 		}
403 		len = PEEK_U32(buf);
404 		if ((data = calloc(len, 1)) == NULL) {
405 			ret = SSH_ERR_ALLOC_FAIL;
406 			goto done;
407 		}
408 		if (atomicio(read, fd, data, len) != len) {
409 			PRINT("%s: cannot read blob: %s", __func__, filename);
410 			goto done;
411 		}
412 		if ((enc = sshbuf_from(data, len)) == NULL) {
413 			ret = SSH_ERR_ALLOC_FAIL;
414 			goto done;
415 		}
416 		sshkey_xmss_free_bds(k);
417 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
418 			ret = r;
419 			goto done;
420 		}
421 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
422 			ret = r;
423 			goto done;
424 		}
425 		ret = 0;
426 	}
427 done:
428 	if (fd != -1)
429 		close(fd);
430 	free(data);
431 	sshbuf_free(enc);
432 	sshbuf_free(b);
433 	return ret;
434 }
435 
436 int
437 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
438 {
439 	struct ssh_xmss_state *state = k->xmss_state;
440 	u_int32_t idx = 0;
441 	char *filename = NULL;
442 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
443 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
444 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
445 
446 	if (state == NULL)
447 		goto done;
448 	/*
449 	 * If maxidx is set, then we are allowed a limited number
450 	 * of signatures, but don't need to access the disk.
451 	 * Otherwise we need to deal with the on-disk state.
452 	 */
453 	if (state->maxidx) {
454 		/* xmss_sk always contains the current state */
455 		idx = PEEK_U32(k->xmss_sk);
456 		if (idx < state->maxidx) {
457 			state->allow_update = 1;
458 			return 0;
459 		}
460 		return SSH_ERR_INVALID_ARGUMENT;
461 	}
462 	if ((filename = k->xmss_filename) == NULL)
463 		goto done;
464 	if (asprintf(&lockfile, "%s.lock", filename) < 0 ||
465 	    asprintf(&statefile, "%s.state", filename) < 0 ||
466 	    asprintf(&ostatefile, "%s.ostate", filename) < 0) {
467 		ret = SSH_ERR_ALLOC_FAIL;
468 		goto done;
469 	}
470 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) < 0) {
471 		ret = SSH_ERR_SYSTEM_ERROR;
472 		PRINT("%s: cannot open/create: %s", __func__, lockfile);
473 		goto done;
474 	}
475 	while (flock(lockfd, LOCK_EX|LOCK_NB) < 0) {
476 		if (errno != EWOULDBLOCK) {
477 			ret = SSH_ERR_SYSTEM_ERROR;
478 			PRINT("%s: cannot lock: %s", __func__, lockfile);
479 			goto done;
480 		}
481 		if (++tries > 10) {
482 			ret = SSH_ERR_SYSTEM_ERROR;
483 			PRINT("%s: giving up on: %s", __func__, lockfile);
484 			goto done;
485 		}
486 		usleep(1000*100*tries);
487 	}
488 	/* XXX no longer const */
489 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
490 	    statefile, &have_state, pr)) != 0) {
491 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
492 		    ostatefile, &have_ostate, pr)) == 0) {
493 			state->allow_update = 1;
494 			r = sshkey_xmss_forward_state(k, 1);
495 			state->idx = PEEK_U32(k->xmss_sk);
496 			state->allow_update = 0;
497 		}
498 	}
499 	if (!have_state && !have_ostate) {
500 		/* check that bds state is initialized */
501 		if (state->bds.auth == NULL)
502 			goto done;
503 		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
504 	} else if (r != 0) {
505 		ret = r;
506 		goto done;
507 	}
508 	if (state->idx + 1 < state->idx) {
509 		PRINT("%s: state wrap: %u", __func__, state->idx);
510 		goto done;
511 	}
512 	state->have_state = have_state;
513 	state->lockfd = lockfd;
514 	state->allow_update = 1;
515 	lockfd = -1;
516 	ret = 0;
517 done:
518 	if (lockfd != -1)
519 		close(lockfd);
520 	free(lockfile);
521 	free(statefile);
522 	free(ostatefile);
523 	return ret;
524 }
525 
526 int
527 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
528 {
529 	struct ssh_xmss_state *state = k->xmss_state;
530 	u_char *sig = NULL;
531 	size_t required_siglen;
532 	unsigned long long smlen;
533 	u_char data;
534 	int ret, r;
535 
536 	if (state == NULL || !state->allow_update)
537 		return SSH_ERR_INVALID_ARGUMENT;
538 	if (reserve == 0)
539 		return SSH_ERR_INVALID_ARGUMENT;
540 	if (state->idx + reserve <= state->idx)
541 		return SSH_ERR_INVALID_ARGUMENT;
542 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
543 		return r;
544 	if ((sig = malloc(required_siglen)) == NULL)
545 		return SSH_ERR_ALLOC_FAIL;
546 	while (reserve-- > 0) {
547 		state->idx = PEEK_U32(k->xmss_sk);
548 		smlen = required_siglen;
549 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
550 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
551 			r = SSH_ERR_INVALID_ARGUMENT;
552 			break;
553 		}
554 	}
555 	free(sig);
556 	return r;
557 }
558 
559 int
560 sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
561 {
562 	struct ssh_xmss_state *state = k->xmss_state;
563 	struct sshbuf *b = NULL, *enc = NULL;
564 	u_int32_t idx = 0;
565 	unsigned char buf[4];
566 	char *filename = NULL;
567 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
568 	int fd = -1;
569 	int ret = SSH_ERR_INVALID_ARGUMENT;
570 
571 	if (state == NULL || !state->allow_update)
572 		return ret;
573 	if (state->maxidx) {
574 		/* no update since the number of signatures is limited */
575 		ret = 0;
576 		goto done;
577 	}
578 	idx = PEEK_U32(k->xmss_sk);
579 	if (idx == state->idx) {
580 		/* no signature happened, no need to update */
581 		ret = 0;
582 		goto done;
583 	} else if (idx != state->idx + 1) {
584 		PRINT("%s: more than one signature happened: idx %u state %u",
585 		     __func__, idx, state->idx);
586 		goto done;
587 	}
588 	state->idx = idx;
589 	if ((filename = k->xmss_filename) == NULL)
590 		goto done;
591 	if (asprintf(&statefile, "%s.state", filename) < 0 ||
592 	    asprintf(&ostatefile, "%s.ostate", filename) < 0 ||
593 	    asprintf(&nstatefile, "%s.nstate", filename) < 0) {
594 		ret = SSH_ERR_ALLOC_FAIL;
595 		goto done;
596 	}
597 	unlink(nstatefile);
598 	if ((b = sshbuf_new()) == NULL) {
599 		ret = SSH_ERR_ALLOC_FAIL;
600 		goto done;
601 	}
602 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
603 		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
604 		goto done;
605 	}
606 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
607 		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
608 		goto done;
609 	}
610 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) < 0) {
611 		ret = SSH_ERR_SYSTEM_ERROR;
612 		PRINT("%s: open new state file: %s", __func__, nstatefile);
613 		goto done;
614 	}
615 	POKE_U32(buf, sshbuf_len(enc));
616 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
617 		ret = SSH_ERR_SYSTEM_ERROR;
618 		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
619 		close(fd);
620 		goto done;
621 	}
622 	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
623 	    sshbuf_len(enc)) {
624 		ret = SSH_ERR_SYSTEM_ERROR;
625 		PRINT("%s: write new state file data: %s", __func__, nstatefile);
626 		close(fd);
627 		goto done;
628 	}
629 	if (fsync(fd) < 0) {
630 		ret = SSH_ERR_SYSTEM_ERROR;
631 		PRINT("%s: sync new state file: %s", __func__, nstatefile);
632 		close(fd);
633 		goto done;
634 	}
635 	if (close(fd) < 0) {
636 		ret = SSH_ERR_SYSTEM_ERROR;
637 		PRINT("%s: close new state file: %s", __func__, nstatefile);
638 		goto done;
639 	}
640 	if (state->have_state) {
641 		unlink(ostatefile);
642 		if (link(statefile, ostatefile)) {
643 			ret = SSH_ERR_SYSTEM_ERROR;
644 			PRINT("%s: backup state %s to %s", __func__, statefile,
645 			    ostatefile);
646 			goto done;
647 		}
648 	}
649 	if (rename(nstatefile, statefile) < 0) {
650 		ret = SSH_ERR_SYSTEM_ERROR;
651 		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
652 		goto done;
653 	}
654 	ret = 0;
655 done:
656 	if (state->lockfd != -1) {
657 		close(state->lockfd);
658 		state->lockfd = -1;
659 	}
660 	if (nstatefile)
661 		unlink(nstatefile);
662 	free(statefile);
663 	free(ostatefile);
664 	free(nstatefile);
665 	sshbuf_free(b);
666 	sshbuf_free(enc);
667 	return ret;
668 }
669 
670 int
671 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
672 {
673 	struct ssh_xmss_state *state = k->xmss_state;
674 	treehash_inst *th;
675 	u_int32_t i, node;
676 	int r;
677 
678 	if (state == NULL)
679 		return SSH_ERR_INVALID_ARGUMENT;
680 	if (state->stack == NULL)
681 		return SSH_ERR_INVALID_ARGUMENT;
682 	state->stackoffset = state->bds.stackoffset;	/* copy back */
683 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
684 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
685 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
686 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
687 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
688 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
689 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
690 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
691 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
692 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
693 		return r;
694 	for (i = 0; i < num_treehash(state); i++) {
695 		th = &state->treehash[i];
696 		node = th->node - state->th_nodes;
697 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
698 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
699 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
700 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
701 		    (r = sshbuf_put_u32(b, node)) != 0)
702 			return r;
703 	}
704 	return 0;
705 }
706 
707 int
708 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
709     enum sshkey_serialize_rep opts)
710 {
711 	struct ssh_xmss_state *state = k->xmss_state;
712 	int r = SSH_ERR_INVALID_ARGUMENT;
713 
714 	if (state == NULL)
715 		return SSH_ERR_INVALID_ARGUMENT;
716 	if ((r = sshbuf_put_u8(b, opts)) != 0)
717 		return r;
718 	switch (opts) {
719 	case SSHKEY_SERIALIZE_STATE:
720 		r = sshkey_xmss_serialize_state(k, b);
721 		break;
722 	case SSHKEY_SERIALIZE_FULL:
723 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
724 			break;
725 		r = sshkey_xmss_serialize_state(k, b);
726 		break;
727 	case SSHKEY_SERIALIZE_DEFAULT:
728 		r = 0;
729 		break;
730 	default:
731 		r = SSH_ERR_INVALID_ARGUMENT;
732 		break;
733 	}
734 	return r;
735 }
736 
737 int
738 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
739 {
740 	struct ssh_xmss_state *state = k->xmss_state;
741 	treehash_inst *th;
742 	u_int32_t i, lh, node;
743 	size_t ls, lsl, la, lk, ln, lr;
744 	char *magic;
745 	int r;
746 
747 	if (state == NULL)
748 		return SSH_ERR_INVALID_ARGUMENT;
749 	if (k->xmss_sk == NULL)
750 		return SSH_ERR_INVALID_ARGUMENT;
751 	if ((state->treehash = calloc(num_treehash(state),
752 	    sizeof(treehash_inst))) == NULL)
753 		return SSH_ERR_ALLOC_FAIL;
754 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
755 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
756 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
757 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
758 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
759 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
760 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
761 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
762 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
763 	    (r = sshbuf_get_u32(b, &lh)) != 0)
764 		return r;
765 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0)
766 		return SSH_ERR_INVALID_ARGUMENT;
767 	/* XXX check stackoffset */
768 	if (ls != num_stack(state) ||
769 	    lsl != num_stacklevels(state) ||
770 	    la != num_auth(state) ||
771 	    lk != num_keep(state) ||
772 	    ln != num_th_nodes(state) ||
773 	    lr != num_retain(state) ||
774 	    lh != num_treehash(state))
775 		return SSH_ERR_INVALID_ARGUMENT;
776 	for (i = 0; i < num_treehash(state); i++) {
777 		th = &state->treehash[i];
778 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
779 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
780 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
781 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
782 		    (r = sshbuf_get_u32(b, &node)) != 0)
783 			return r;
784 		if (node < num_th_nodes(state))
785 			th->node = &state->th_nodes[node];
786 	}
787 	POKE_U32(k->xmss_sk, state->idx);
788 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
789 	    state->stacklevels, state->auth, state->keep, state->treehash,
790 	    state->retain, 0);
791 	return 0;
792 }
793 
794 int
795 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
796 {
797 	enum sshkey_serialize_rep opts;
798 	u_char have_state;
799 	int r;
800 
801 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
802 		return r;
803 
804 	opts = have_state;
805 	switch (opts) {
806 	case SSHKEY_SERIALIZE_DEFAULT:
807 		r = 0;
808 		break;
809 	case SSHKEY_SERIALIZE_STATE:
810 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
811 			return r;
812 		break;
813 	case SSHKEY_SERIALIZE_FULL:
814 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
815 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
816 			return r;
817 		break;
818 	default:
819 		r = SSH_ERR_INVALID_FORMAT;
820 		break;
821 	}
822 	return r;
823 }
824 
825 int
826 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
827    struct sshbuf **retp)
828 {
829 	struct ssh_xmss_state *state = k->xmss_state;
830 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
831 	struct sshcipher_ctx *ciphercontext = NULL;
832 	const struct sshcipher *cipher;
833 	u_char *cp, *key, *iv = NULL;
834 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
835 	int r = SSH_ERR_INTERNAL_ERROR;
836 
837 	if (retp != NULL)
838 		*retp = NULL;
839 	if (state == NULL ||
840 	    state->enc_keyiv == NULL ||
841 	    state->enc_ciphername == NULL)
842 		return SSH_ERR_INTERNAL_ERROR;
843 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
844 		r = SSH_ERR_INTERNAL_ERROR;
845 		goto out;
846 	}
847 	blocksize = cipher_blocksize(cipher);
848 	keylen = cipher_keylen(cipher);
849 	ivlen = cipher_ivlen(cipher);
850 	authlen = cipher_authlen(cipher);
851 	if (state->enc_keyiv_len != keylen + ivlen) {
852 		r = SSH_ERR_INVALID_FORMAT;
853 		goto out;
854 	}
855 	key = state->enc_keyiv;
856 	if ((encrypted = sshbuf_new()) == NULL ||
857 	    (encoded = sshbuf_new()) == NULL ||
858 	    (padded = sshbuf_new()) == NULL ||
859 	    (iv = malloc(ivlen)) == NULL) {
860 		r = SSH_ERR_ALLOC_FAIL;
861 		goto out;
862 	}
863 
864 	/* replace first 4 bytes of IV with index to ensure uniqueness */
865 	memcpy(iv, key + keylen, ivlen);
866 	POKE_U32(iv, state->idx);
867 
868 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
869 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
870 		goto out;
871 
872 	/* padded state will be encrypted */
873 	if ((r = sshbuf_putb(padded, b)) != 0)
874 		goto out;
875 	i = 0;
876 	while (sshbuf_len(padded) % blocksize) {
877 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
878 			goto out;
879 	}
880 	encrypted_len = sshbuf_len(padded);
881 
882 	/* header including the length of state is used as AAD */
883 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
884 		goto out;
885 	aadlen = sshbuf_len(encoded);
886 
887 	/* concat header and state */
888 	if ((r = sshbuf_putb(encoded, padded)) != 0)
889 		goto out;
890 
891 	/* reserve space for encryption of encoded data plus auth tag */
892 	/* encrypt at offset addlen */
893 	if ((r = sshbuf_reserve(encrypted,
894 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
895 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
896 	    iv, ivlen, 1)) != 0 ||
897 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
898 	    encrypted_len, aadlen, authlen)) != 0)
899 		goto out;
900 
901 	/* success */
902 	r = 0;
903  out:
904 	if (retp != NULL) {
905 		*retp = encrypted;
906 		encrypted = NULL;
907 	}
908 	sshbuf_free(padded);
909 	sshbuf_free(encoded);
910 	sshbuf_free(encrypted);
911 	cipher_free(ciphercontext);
912 	free(iv);
913 	return r;
914 }
915 
916 int
917 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
918    struct sshbuf **retp)
919 {
920 	struct ssh_xmss_state *state = k->xmss_state;
921 	struct sshbuf *copy = NULL, *decrypted = NULL;
922 	struct sshcipher_ctx *ciphercontext = NULL;
923 	const struct sshcipher *cipher = NULL;
924 	u_char *key, *iv = NULL, *dp;
925 	size_t keylen, ivlen, authlen, aadlen;
926 	u_int blocksize, encrypted_len, index;
927 	int r = SSH_ERR_INTERNAL_ERROR;
928 
929 	if (retp != NULL)
930 		*retp = NULL;
931 	if (state == NULL ||
932 	    state->enc_keyiv == NULL ||
933 	    state->enc_ciphername == NULL)
934 		return SSH_ERR_INTERNAL_ERROR;
935 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
936 		r = SSH_ERR_INVALID_FORMAT;
937 		goto out;
938 	}
939 	blocksize = cipher_blocksize(cipher);
940 	keylen = cipher_keylen(cipher);
941 	ivlen = cipher_ivlen(cipher);
942 	authlen = cipher_authlen(cipher);
943 	if (state->enc_keyiv_len != keylen + ivlen) {
944 		r = SSH_ERR_INTERNAL_ERROR;
945 		goto out;
946 	}
947 	key = state->enc_keyiv;
948 
949 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
950 	    (decrypted = sshbuf_new()) == NULL ||
951 	    (iv = malloc(ivlen)) == NULL) {
952 		r = SSH_ERR_ALLOC_FAIL;
953 		goto out;
954 	}
955 
956 	/* check magic */
957 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
958 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
959 		r = SSH_ERR_INVALID_FORMAT;
960 		goto out;
961 	}
962 	/* parse public portion */
963 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
964 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
965 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
966 		goto out;
967 
968 	/* check size of encrypted key blob */
969 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
970 		r = SSH_ERR_INVALID_FORMAT;
971 		goto out;
972 	}
973 	/* check that an appropriate amount of auth data is present */
974 	if (sshbuf_len(encoded) < encrypted_len + authlen) {
975 		r = SSH_ERR_INVALID_FORMAT;
976 		goto out;
977 	}
978 
979 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
980 
981 	/* replace first 4 bytes of IV with index to ensure uniqueness */
982 	memcpy(iv, key + keylen, ivlen);
983 	POKE_U32(iv, index);
984 
985 	/* decrypt private state of key */
986 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
987 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
988 	    iv, ivlen, 0)) != 0 ||
989 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
990 	    encrypted_len, aadlen, authlen)) != 0)
991 		goto out;
992 
993 	/* there should be no trailing data */
994 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
995 		goto out;
996 	if (sshbuf_len(encoded) != 0) {
997 		r = SSH_ERR_INVALID_FORMAT;
998 		goto out;
999 	}
1000 
1001 	/* remove AAD */
1002 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1003 		goto out;
1004 	/* XXX encrypted includes unchecked padding */
1005 
1006 	/* success */
1007 	r = 0;
1008 	if (retp != NULL) {
1009 		*retp = decrypted;
1010 		decrypted = NULL;
1011 	}
1012  out:
1013 	cipher_free(ciphercontext);
1014 	sshbuf_free(copy);
1015 	sshbuf_free(decrypted);
1016 	free(iv);
1017 	return r;
1018 }
1019 
1020 u_int32_t
1021 sshkey_xmss_signatures_left(const struct sshkey *k)
1022 {
1023 	struct ssh_xmss_state *state = k->xmss_state;
1024 	u_int32_t idx;
1025 
1026 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1027 	    state->maxidx) {
1028 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1029 		if (idx < state->maxidx)
1030 			return state->maxidx - idx;
1031 	}
1032 	return 0;
1033 }
1034 
1035 int
1036 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1037 {
1038 	struct ssh_xmss_state *state = k->xmss_state;
1039 
1040 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1041 		return SSH_ERR_INVALID_ARGUMENT;
1042 	if (maxsign == 0)
1043 		return 0;
1044 	if (state->idx + maxsign < state->idx)
1045 		return SSH_ERR_INVALID_ARGUMENT;
1046 	state->maxidx = state->idx + maxsign;
1047 	return 0;
1048 }
1049