xref: /netbsd-src/crypto/external/bsd/openssh/dist/sshkey-xmss.c (revision bdc22b2e01993381dcefeff2bc9b56ca75a4235c)
1 /*	$NetBSD: sshkey-xmss.c,v 1.2 2018/04/06 18:59:00 christos Exp $	*/
2 /* $OpenBSD: sshkey-xmss.c,v 1.1 2018/02/23 15:58:38 markus Exp $ */
3 /*
4  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
17  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
18  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
19  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
20  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
24  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  */
26 #include "includes.h"
27 __RCSID("$NetBSD: sshkey-xmss.c,v 1.2 2018/04/06 18:59:00 christos Exp $");
28 
29 #include <sys/types.h>
30 #include <sys/uio.h>
31 
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 
38 #include "ssh2.h"
39 #include "ssherr.h"
40 #include "sshbuf.h"
41 #include "cipher.h"
42 #include "sshkey.h"
43 #include "sshkey-xmss.h"
44 #include "atomicio.h"
45 
46 #include "xmss_fast.h"
47 
48 /* opaque internal XMSS state */
49 #define XMSS_MAGIC		"xmss-state-v1"
50 #define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
51 struct ssh_xmss_state {
52 	xmss_params	params;
53 	u_int32_t	n, w, h, k;
54 
55 	bds_state	bds;
56 	u_char		*stack;
57 	u_int32_t	stackoffset;
58 	u_char		*stacklevels;
59 	u_char		*auth;
60 	u_char		*keep;
61 	u_char		*th_nodes;
62 	u_char		*retain;
63 	treehash_inst	*treehash;
64 
65 	u_int32_t	idx;		/* state read from file */
66 	u_int32_t	maxidx;		/* resticted # of signatures */
67 	int		have_state;	/* .state file exists */
68 	int		lockfd;		/* locked in sshkey_xmss_get_state() */
69 	int		allow_update;	/* allow sshkey_xmss_update_state() */
70 	char		*enc_ciphername;/* encrypt state with cipher */
71 	u_char		*enc_keyiv;	/* encrypt state with key */
72 	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
73 };
74 
75 int	 sshkey_xmss_init_bds_state(struct sshkey *);
76 int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
77 void	 sshkey_xmss_free_bds(struct sshkey *);
78 int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
79 	    int *, sshkey_printfn *);
80 int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
81 	    struct sshbuf **);
82 int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
83 	    struct sshbuf **);
84 int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
85 int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
86 
87 #define PRINT(s...) do { if (pr) pr(s); } while (/*CONSTCOND*/0)
88 
89 int
90 sshkey_xmss_init(struct sshkey *key, const char *name)
91 {
92 	struct ssh_xmss_state *state;
93 
94 	if (key->xmss_state != NULL)
95 		return SSH_ERR_INVALID_FORMAT;
96 	if (name == NULL)
97 		return SSH_ERR_INVALID_FORMAT;
98 	state = calloc(sizeof(struct ssh_xmss_state), 1);
99 	if (state == NULL)
100 		return SSH_ERR_ALLOC_FAIL;
101 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
102 		state->n = 32;
103 		state->w = 16;
104 		state->h = 10;
105 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
106 		state->n = 32;
107 		state->w = 16;
108 		state->h = 16;
109 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
110 		state->n = 32;
111 		state->w = 16;
112 		state->h = 20;
113 	} else {
114 		free(state);
115 		return SSH_ERR_KEY_TYPE_UNKNOWN;
116 	}
117 	if ((key->xmss_name = strdup(name)) == NULL) {
118 		free(state);
119 		return SSH_ERR_ALLOC_FAIL;
120 	}
121 	state->k = 2;	/* XXX hardcoded */
122 	state->lockfd = -1;
123 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
124 	    state->k) != 0) {
125 		free(state);
126 		return SSH_ERR_INVALID_FORMAT;
127 	}
128 	key->xmss_state = state;
129 	return 0;
130 }
131 
132 void
133 sshkey_xmss_free_state(struct sshkey *key)
134 {
135 	struct ssh_xmss_state *state = key->xmss_state;
136 
137 	sshkey_xmss_free_bds(key);
138 	if (state) {
139 		if (state->enc_keyiv) {
140 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
141 			free(state->enc_keyiv);
142 		}
143 		free(state->enc_ciphername);
144 		free(state);
145 	}
146 	key->xmss_state = NULL;
147 }
148 
/* magic string that opens a serialized BDS state */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers, derived from the n, h and k members of the
 * structure that `x' points to.  The argument is parenthesised so that
 * any pointer expression may be passed, not only a plain identifier.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
157 
158 int
159 sshkey_xmss_init_bds_state(struct sshkey *key)
160 {
161 	struct ssh_xmss_state *state = key->xmss_state;
162 	u_int32_t i;
163 
164 	state->stackoffset = 0;
165 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
166 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
167 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
168 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
169 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
170 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
171 	    (state->treehash = calloc(num_treehash(state),
172 	    sizeof(treehash_inst))) == NULL) {
173 		sshkey_xmss_free_bds(key);
174 		return SSH_ERR_ALLOC_FAIL;
175 	}
176 	for (i = 0; i < state->h - state->k; i++)
177 		state->treehash[i].node = &state->th_nodes[state->n*i];
178 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
179 	    state->stacklevels, state->auth, state->keep, state->treehash,
180 	    state->retain, 0);
181 	return 0;
182 }
183 
184 void
185 sshkey_xmss_free_bds(struct sshkey *key)
186 {
187 	struct ssh_xmss_state *state = key->xmss_state;
188 
189 	if (state == NULL)
190 		return;
191 	free(state->stack);
192 	free(state->stacklevels);
193 	free(state->auth);
194 	free(state->keep);
195 	free(state->th_nodes);
196 	free(state->retain);
197 	free(state->treehash);
198 	state->stack = NULL;
199 	state->stacklevels = NULL;
200 	state->auth = NULL;
201 	state->keep = NULL;
202 	state->th_nodes = NULL;
203 	state->retain = NULL;
204 	state->treehash = NULL;
205 }
206 
207 void *
208 sshkey_xmss_params(const struct sshkey *key)
209 {
210 	struct ssh_xmss_state *state = key->xmss_state;
211 
212 	if (state == NULL)
213 		return NULL;
214 	return &state->params;
215 }
216 
217 void *
218 sshkey_xmss_bds_state(const struct sshkey *key)
219 {
220 	struct ssh_xmss_state *state = key->xmss_state;
221 
222 	if (state == NULL)
223 		return NULL;
224 	return &state->bds;
225 }
226 
227 int
228 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
229 {
230 	struct ssh_xmss_state *state = key->xmss_state;
231 
232 	if (lenp == NULL)
233 		return SSH_ERR_INVALID_ARGUMENT;
234 	if (state == NULL)
235 		return SSH_ERR_INVALID_FORMAT;
236 	*lenp = 4 + state->n +
237 	    state->params.wots_par.keysize +
238 	    state->h * state->n;
239 	return 0;
240 }
241 
242 size_t
243 sshkey_xmss_pklen(const struct sshkey *key)
244 {
245 	struct ssh_xmss_state *state = key->xmss_state;
246 
247 	if (state == NULL)
248 		return 0;
249 	return state->n * 2;
250 }
251 
252 size_t
253 sshkey_xmss_sklen(const struct sshkey *key)
254 {
255 	struct ssh_xmss_state *state = key->xmss_state;
256 
257 	if (state == NULL)
258 		return 0;
259 	return state->n * 4 + 4;
260 }
261 
262 int
263 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
264 {
265 	struct ssh_xmss_state *state = k->xmss_state;
266 	const struct sshcipher *cipher;
267 	size_t keylen = 0, ivlen = 0;
268 
269 	if (state == NULL)
270 		return SSH_ERR_INVALID_ARGUMENT;
271 	if ((cipher = cipher_by_name(ciphername)) == NULL)
272 		return SSH_ERR_INTERNAL_ERROR;
273 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
274 		return SSH_ERR_ALLOC_FAIL;
275 	keylen = cipher_keylen(cipher);
276 	ivlen = cipher_ivlen(cipher);
277 	state->enc_keyiv_len = keylen + ivlen;
278 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
279 		free(state->enc_ciphername);
280 		state->enc_ciphername = NULL;
281 		return SSH_ERR_ALLOC_FAIL;
282 	}
283 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
284 	return 0;
285 }
286 
287 int
288 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
289 {
290 	struct ssh_xmss_state *state = k->xmss_state;
291 	int r;
292 
293 	if (state == NULL || state->enc_keyiv == NULL ||
294 	    state->enc_ciphername == NULL)
295 		return SSH_ERR_INVALID_ARGUMENT;
296 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
297 	    (r = sshbuf_put_string(b, state->enc_keyiv,
298 	    state->enc_keyiv_len)) != 0)
299 		return r;
300 	return 0;
301 }
302 
303 int
304 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
305 {
306 	struct ssh_xmss_state *state = k->xmss_state;
307 	size_t len;
308 	int r;
309 
310 	if (state == NULL)
311 		return SSH_ERR_INVALID_ARGUMENT;
312 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
313 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
314 		return r;
315 	state->enc_keyiv_len = len;
316 	return 0;
317 }
318 
319 int
320 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
321     enum sshkey_serialize_rep opts)
322 {
323 	struct ssh_xmss_state *state = k->xmss_state;
324 	u_char have_info = 1;
325 	u_int32_t idx;
326 	int r;
327 
328 	if (state == NULL)
329 		return SSH_ERR_INVALID_ARGUMENT;
330 	if (opts != SSHKEY_SERIALIZE_INFO)
331 		return 0;
332 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
333 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
334 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
335 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
336 		return r;
337 	return 0;
338 }
339 
340 int
341 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
342 {
343 	struct ssh_xmss_state *state = k->xmss_state;
344 	u_char have_info;
345 	int r;
346 
347 	if (state == NULL)
348 		return SSH_ERR_INVALID_ARGUMENT;
349 	/* optional */
350 	if (sshbuf_len(b) == 0)
351 		return 0;
352 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
353 		return r;
354 	if (have_info != 1)
355 		return SSH_ERR_INVALID_ARGUMENT;
356 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
357 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
358 		return r;
359 	return 0;
360 }
361 
362 int
363 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
364 {
365 	int r;
366 	const char *name;
367 
368 	if (bits == 10) {
369 		name = XMSS_SHA2_256_W16_H10_NAME;
370 	} else if (bits == 16) {
371 		name = XMSS_SHA2_256_W16_H16_NAME;
372 	} else if (bits == 20) {
373 		name = XMSS_SHA2_256_W16_H20_NAME;
374 	} else {
375 		name = XMSS_DEFAULT_NAME;
376 	}
377 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
378 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
379 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
380 		return r;
381 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
382 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
383 		return SSH_ERR_ALLOC_FAIL;
384 	}
385 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
386 	    sshkey_xmss_params(k));
387 	return 0;
388 }
389 
390 int
391 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
392     int *have_file, sshkey_printfn *pr)
393 {
394 	struct sshbuf *b = NULL, *enc = NULL;
395 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
396 	u_int32_t len;
397 	unsigned char buf[4], *data = NULL;
398 
399 	*have_file = 0;
400 	if ((fd = open(filename, O_RDONLY)) >= 0) {
401 		*have_file = 1;
402 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
403 			PRINT("%s: corrupt state file: %s", __func__, filename);
404 			goto done;
405 		}
406 		len = PEEK_U32(buf);
407 		if ((data = calloc(len, 1)) == NULL) {
408 			ret = SSH_ERR_ALLOC_FAIL;
409 			goto done;
410 		}
411 		if (atomicio(read, fd, data, len) != len) {
412 			PRINT("%s: cannot read blob: %s", __func__, filename);
413 			goto done;
414 		}
415 		if ((enc = sshbuf_from(data, len)) == NULL) {
416 			ret = SSH_ERR_ALLOC_FAIL;
417 			goto done;
418 		}
419 		sshkey_xmss_free_bds(k);
420 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
421 			ret = r;
422 			goto done;
423 		}
424 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
425 			ret = r;
426 			goto done;
427 		}
428 		ret = 0;
429 	}
430 done:
431 	if (fd != -1)
432 		close(fd);
433 	free(data);
434 	sshbuf_free(enc);
435 	sshbuf_free(b);
436 	return ret;
437 }
438 
439 int
440 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
441 {
442 	struct ssh_xmss_state *state = k->xmss_state;
443 	u_int32_t idx = 0;
444 	char *filename = NULL;
445 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
446 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
447 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
448 
449 	if (state == NULL)
450 		goto done;
451 	/*
452 	 * If maxidx is set, then we are allowed a limited number
453 	 * of signatures, but don't need to access the disk.
454 	 * Otherwise we need to deal with the on-disk state.
455 	 */
456 	if (state->maxidx) {
457 		/* xmss_sk always contains the current state */
458 		idx = PEEK_U32(k->xmss_sk);
459 		if (idx < state->maxidx) {
460 			state->allow_update = 1;
461 			return 0;
462 		}
463 		return SSH_ERR_INVALID_ARGUMENT;
464 	}
465 	if ((filename = k->xmss_filename) == NULL)
466 		goto done;
467 	if (asprintf(&lockfile, "%s.lock", filename) < 0 ||
468 	    asprintf(&statefile, "%s.state", filename) < 0 ||
469 	    asprintf(&ostatefile, "%s.ostate", filename) < 0) {
470 		ret = SSH_ERR_ALLOC_FAIL;
471 		goto done;
472 	}
473 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) < 0) {
474 		ret = SSH_ERR_SYSTEM_ERROR;
475 		PRINT("%s: cannot open/create: %s", __func__, lockfile);
476 		goto done;
477 	}
478 	while (flock(lockfd, LOCK_EX|LOCK_NB) < 0) {
479 		if (errno != EWOULDBLOCK) {
480 			ret = SSH_ERR_SYSTEM_ERROR;
481 			PRINT("%s: cannot lock: %s", __func__, lockfile);
482 			goto done;
483 		}
484 		if (++tries > 10) {
485 			ret = SSH_ERR_SYSTEM_ERROR;
486 			PRINT("%s: giving up on: %s", __func__, lockfile);
487 			goto done;
488 		}
489 		usleep(1000*100*tries);
490 	}
491 	/* XXX no longer const */
492 	if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
493 	    statefile, &have_state, pr)) != 0) {
494 		if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
495 		    ostatefile, &have_ostate, pr)) == 0) {
496 			state->allow_update = 1;
497 			r = sshkey_xmss_forward_state(k, 1);
498 			state->idx = PEEK_U32(k->xmss_sk);
499 			state->allow_update = 0;
500 		}
501 	}
502 	if (!have_state && !have_ostate) {
503 		/* check that bds state is initialized */
504 		if (state->bds.auth == NULL)
505 			goto done;
506 		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
507 	} else if (r != 0) {
508 		ret = r;
509 		goto done;
510 	}
511 	if (state->idx + 1 < state->idx) {
512 		PRINT("%s: state wrap: %u", __func__, state->idx);
513 		goto done;
514 	}
515 	state->have_state = have_state;
516 	state->lockfd = lockfd;
517 	state->allow_update = 1;
518 	lockfd = -1;
519 	ret = 0;
520 done:
521 	if (lockfd != -1)
522 		close(lockfd);
523 	free(lockfile);
524 	free(statefile);
525 	free(ostatefile);
526 	return ret;
527 }
528 
529 int
530 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
531 {
532 	struct ssh_xmss_state *state = k->xmss_state;
533 	u_char *sig = NULL;
534 	size_t required_siglen;
535 	unsigned long long smlen;
536 	u_char data;
537 	int ret, r;
538 
539 	if (state == NULL || !state->allow_update)
540 		return SSH_ERR_INVALID_ARGUMENT;
541 	if (reserve == 0)
542 		return SSH_ERR_INVALID_ARGUMENT;
543 	if (state->idx + reserve <= state->idx)
544 		return SSH_ERR_INVALID_ARGUMENT;
545 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
546 		return r;
547 	if ((sig = malloc(required_siglen)) == NULL)
548 		return SSH_ERR_ALLOC_FAIL;
549 	while (reserve-- > 0) {
550 		state->idx = PEEK_U32(k->xmss_sk);
551 		smlen = required_siglen;
552 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
553 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
554 			r = SSH_ERR_INVALID_ARGUMENT;
555 			break;
556 		}
557 	}
558 	free(sig);
559 	return r;
560 }
561 
562 int
563 sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
564 {
565 	struct ssh_xmss_state *state = k->xmss_state;
566 	struct sshbuf *b = NULL, *enc = NULL;
567 	u_int32_t idx = 0;
568 	unsigned char buf[4];
569 	char *filename = NULL;
570 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
571 	int fd = -1;
572 	int ret = SSH_ERR_INVALID_ARGUMENT;
573 
574 	if (state == NULL || !state->allow_update)
575 		return ret;
576 	if (state->maxidx) {
577 		/* no update since the number of signatures is limited */
578 		ret = 0;
579 		goto done;
580 	}
581 	idx = PEEK_U32(k->xmss_sk);
582 	if (idx == state->idx) {
583 		/* no signature happend, no need to update */
584 		ret = 0;
585 		goto done;
586 	} else if (idx != state->idx + 1) {
587 		PRINT("%s: more than one signature happened: idx %u state %u",
588 		     __func__, idx, state->idx);
589 		goto done;
590 	}
591 	state->idx = idx;
592 	if ((filename = k->xmss_filename) == NULL)
593 		goto done;
594 	if (asprintf(&statefile, "%s.state", filename) < 0 ||
595 	    asprintf(&ostatefile, "%s.ostate", filename) < 0 ||
596 	    asprintf(&nstatefile, "%s.nstate", filename) < 0) {
597 		ret = SSH_ERR_ALLOC_FAIL;
598 		goto done;
599 	}
600 	unlink(nstatefile);
601 	if ((b = sshbuf_new()) == NULL) {
602 		ret = SSH_ERR_ALLOC_FAIL;
603 		goto done;
604 	}
605 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
606 		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
607 		goto done;
608 	}
609 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
610 		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
611 		goto done;
612 	}
613 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) < 0) {
614 		ret = SSH_ERR_SYSTEM_ERROR;
615 		PRINT("%s: open new state file: %s", __func__, nstatefile);
616 		goto done;
617 	}
618 	POKE_U32(buf, sshbuf_len(enc));
619 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
620 		ret = SSH_ERR_SYSTEM_ERROR;
621 		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
622 		close(fd);
623 		goto done;
624 	}
625 	if (atomicio(vwrite, fd, __UNCONST(sshbuf_ptr(enc)), sshbuf_len(enc)) !=
626 	    sshbuf_len(enc)) {
627 		ret = SSH_ERR_SYSTEM_ERROR;
628 		PRINT("%s: write new state file data: %s", __func__, nstatefile);
629 		close(fd);
630 		goto done;
631 	}
632 	if (fsync(fd) < 0) {
633 		ret = SSH_ERR_SYSTEM_ERROR;
634 		PRINT("%s: sync new state file: %s", __func__, nstatefile);
635 		close(fd);
636 		goto done;
637 	}
638 	if (close(fd) < 0) {
639 		ret = SSH_ERR_SYSTEM_ERROR;
640 		PRINT("%s: close new state file: %s", __func__, nstatefile);
641 		goto done;
642 	}
643 	if (state->have_state) {
644 		unlink(ostatefile);
645 		if (link(statefile, ostatefile)) {
646 			ret = SSH_ERR_SYSTEM_ERROR;
647 			PRINT("%s: backup state %s to %s", __func__, statefile,
648 			    ostatefile);
649 			goto done;
650 		}
651 	}
652 	if (rename(nstatefile, statefile) < 0) {
653 		ret = SSH_ERR_SYSTEM_ERROR;
654 		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
655 		goto done;
656 	}
657 	ret = 0;
658 done:
659 	if (state->lockfd != -1) {
660 		close(state->lockfd);
661 		state->lockfd = -1;
662 	}
663 	if (nstatefile)
664 		unlink(nstatefile);
665 	free(statefile);
666 	free(ostatefile);
667 	free(nstatefile);
668 	sshbuf_free(b);
669 	sshbuf_free(enc);
670 	return ret;
671 }
672 
673 int
674 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
675 {
676 	struct ssh_xmss_state *state = k->xmss_state;
677 	treehash_inst *th;
678 	u_int32_t i, node;
679 	int r;
680 
681 	if (state == NULL)
682 		return SSH_ERR_INVALID_ARGUMENT;
683 	if (state->stack == NULL)
684 		return SSH_ERR_INVALID_ARGUMENT;
685 	state->stackoffset = state->bds.stackoffset;	/* copy back */
686 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
687 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
688 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
689 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
690 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
691 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
692 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
693 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
694 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
695 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
696 		return r;
697 	for (i = 0; i < num_treehash(state); i++) {
698 		th = &state->treehash[i];
699 		node = th->node - state->th_nodes;
700 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
701 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
702 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
703 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
704 		    (r = sshbuf_put_u32(b, node)) != 0)
705 			return r;
706 	}
707 	return 0;
708 }
709 
710 int
711 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
712     enum sshkey_serialize_rep opts)
713 {
714 	struct ssh_xmss_state *state = k->xmss_state;
715 	int r = SSH_ERR_INVALID_ARGUMENT;
716 
717 	if (state == NULL)
718 		return SSH_ERR_INVALID_ARGUMENT;
719 	if ((r = sshbuf_put_u8(b, opts)) != 0)
720 		return r;
721 	switch (opts) {
722 	case SSHKEY_SERIALIZE_STATE:
723 		r = sshkey_xmss_serialize_state(k, b);
724 		break;
725 	case SSHKEY_SERIALIZE_FULL:
726 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
727 			break;
728 		r = sshkey_xmss_serialize_state(k, b);
729 		break;
730 	case SSHKEY_SERIALIZE_DEFAULT:
731 		r = 0;
732 		break;
733 	default:
734 		r = SSH_ERR_INVALID_ARGUMENT;
735 		break;
736 	}
737 	return r;
738 }
739 
740 int
741 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
742 {
743 	struct ssh_xmss_state *state = k->xmss_state;
744 	treehash_inst *th;
745 	u_int32_t i, lh, node;
746 	size_t ls, lsl, la, lk, ln, lr;
747 	char *magic;
748 	int r;
749 
750 	if (state == NULL)
751 		return SSH_ERR_INVALID_ARGUMENT;
752 	if (k->xmss_sk == NULL)
753 		return SSH_ERR_INVALID_ARGUMENT;
754 	if ((state->treehash = calloc(num_treehash(state),
755 	    sizeof(treehash_inst))) == NULL)
756 		return SSH_ERR_ALLOC_FAIL;
757 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
758 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
759 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
760 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
761 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
762 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
763 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
764 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
765 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
766 	    (r = sshbuf_get_u32(b, &lh)) != 0)
767 		return r;
768 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0)
769 		return SSH_ERR_INVALID_ARGUMENT;
770 	/* XXX check stackoffset */
771 	if (ls != num_stack(state) ||
772 	    lsl != num_stacklevels(state) ||
773 	    la != num_auth(state) ||
774 	    lk != num_keep(state) ||
775 	    ln != num_th_nodes(state) ||
776 	    lr != num_retain(state) ||
777 	    lh != num_treehash(state))
778 		return SSH_ERR_INVALID_ARGUMENT;
779 	for (i = 0; i < num_treehash(state); i++) {
780 		th = &state->treehash[i];
781 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
782 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
783 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
784 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
785 		    (r = sshbuf_get_u32(b, &node)) != 0)
786 			return r;
787 		if (node < num_th_nodes(state))
788 			th->node = &state->th_nodes[node];
789 	}
790 	POKE_U32(k->xmss_sk, state->idx);
791 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
792 	    state->stacklevels, state->auth, state->keep, state->treehash,
793 	    state->retain, 0);
794 	return 0;
795 }
796 
797 int
798 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
799 {
800 	enum sshkey_serialize_rep opts;
801 	u_char have_state;
802 	int r;
803 
804 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
805 		return r;
806 
807 	opts = have_state;
808 	switch (opts) {
809 	case SSHKEY_SERIALIZE_DEFAULT:
810 		r = 0;
811 		break;
812 	case SSHKEY_SERIALIZE_STATE:
813 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
814 			return r;
815 		break;
816 	case SSHKEY_SERIALIZE_FULL:
817 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
818 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
819 			return r;
820 		break;
821 	default:
822 		r = SSH_ERR_INVALID_FORMAT;
823 		break;
824 	}
825 	return r;
826 }
827 
828 int
829 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
830    struct sshbuf **retp)
831 {
832 	struct ssh_xmss_state *state = k->xmss_state;
833 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
834 	struct sshcipher_ctx *ciphercontext = NULL;
835 	const struct sshcipher *cipher;
836 	u_char *cp, *key, *iv = NULL;
837 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
838 	int r = SSH_ERR_INTERNAL_ERROR;
839 
840 	if (retp != NULL)
841 		*retp = NULL;
842 	if (state == NULL ||
843 	    state->enc_keyiv == NULL ||
844 	    state->enc_ciphername == NULL)
845 		return SSH_ERR_INTERNAL_ERROR;
846 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
847 		r = SSH_ERR_INTERNAL_ERROR;
848 		goto out;
849 	}
850 	blocksize = cipher_blocksize(cipher);
851 	keylen = cipher_keylen(cipher);
852 	ivlen = cipher_ivlen(cipher);
853 	authlen = cipher_authlen(cipher);
854 	if (state->enc_keyiv_len != keylen + ivlen) {
855 		r = SSH_ERR_INVALID_FORMAT;
856 		goto out;
857 	}
858 	key = state->enc_keyiv;
859 	if ((encrypted = sshbuf_new()) == NULL ||
860 	    (encoded = sshbuf_new()) == NULL ||
861 	    (padded = sshbuf_new()) == NULL ||
862 	    (iv = malloc(ivlen)) == NULL) {
863 		r = SSH_ERR_ALLOC_FAIL;
864 		goto out;
865 	}
866 
867 	/* replace first 4 bytes of IV with index to ensure uniqueness */
868 	memcpy(iv, key + keylen, ivlen);
869 	POKE_U32(iv, state->idx);
870 
871 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
872 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
873 		goto out;
874 
875 	/* padded state will be encrypted */
876 	if ((r = sshbuf_putb(padded, b)) != 0)
877 		goto out;
878 	i = 0;
879 	while (sshbuf_len(padded) % blocksize) {
880 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
881 			goto out;
882 	}
883 	encrypted_len = sshbuf_len(padded);
884 
885 	/* header including the length of state is used as AAD */
886 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
887 		goto out;
888 	aadlen = sshbuf_len(encoded);
889 
890 	/* concat header and state */
891 	if ((r = sshbuf_putb(encoded, padded)) != 0)
892 		goto out;
893 
894 	/* reserve space for encryption of encoded data plus auth tag */
895 	/* encrypt at offset addlen */
896 	if ((r = sshbuf_reserve(encrypted,
897 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
898 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
899 	    iv, ivlen, 1)) != 0 ||
900 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
901 	    encrypted_len, aadlen, authlen)) != 0)
902 		goto out;
903 
904 	/* success */
905 	r = 0;
906  out:
907 	if (retp != NULL) {
908 		*retp = encrypted;
909 		encrypted = NULL;
910 	}
911 	sshbuf_free(padded);
912 	sshbuf_free(encoded);
913 	sshbuf_free(encrypted);
914 	cipher_free(ciphercontext);
915 	free(iv);
916 	return r;
917 }
918 
919 int
920 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
921    struct sshbuf **retp)
922 {
923 	struct ssh_xmss_state *state = k->xmss_state;
924 	struct sshbuf *copy = NULL, *decrypted = NULL;
925 	struct sshcipher_ctx *ciphercontext = NULL;
926 	const struct sshcipher *cipher = NULL;
927 	u_char *key, *iv = NULL, *dp;
928 	size_t keylen, ivlen, authlen, aadlen;
929 	u_int blocksize, encrypted_len, index;
930 	int r = SSH_ERR_INTERNAL_ERROR;
931 
932 	if (retp != NULL)
933 		*retp = NULL;
934 	if (state == NULL ||
935 	    state->enc_keyiv == NULL ||
936 	    state->enc_ciphername == NULL)
937 		return SSH_ERR_INTERNAL_ERROR;
938 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
939 		r = SSH_ERR_INVALID_FORMAT;
940 		goto out;
941 	}
942 	blocksize = cipher_blocksize(cipher);
943 	keylen = cipher_keylen(cipher);
944 	ivlen = cipher_ivlen(cipher);
945 	authlen = cipher_authlen(cipher);
946 	if (state->enc_keyiv_len != keylen + ivlen) {
947 		r = SSH_ERR_INTERNAL_ERROR;
948 		goto out;
949 	}
950 	key = state->enc_keyiv;
951 
952 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
953 	    (decrypted = sshbuf_new()) == NULL ||
954 	    (iv = malloc(ivlen)) == NULL) {
955 		r = SSH_ERR_ALLOC_FAIL;
956 		goto out;
957 	}
958 
959 	/* check magic */
960 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
961 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
962 		r = SSH_ERR_INVALID_FORMAT;
963 		goto out;
964 	}
965 	/* parse public portion */
966 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
967 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
968 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
969 		goto out;
970 
971 	/* check size of encrypted key blob */
972 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
973 		r = SSH_ERR_INVALID_FORMAT;
974 		goto out;
975 	}
976 	/* check that an appropriate amount of auth data is present */
977 	if (sshbuf_len(encoded) < encrypted_len + authlen) {
978 		r = SSH_ERR_INVALID_FORMAT;
979 		goto out;
980 	}
981 
982 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
983 
984 	/* replace first 4 bytes of IV with index to ensure uniqueness */
985 	memcpy(iv, key + keylen, ivlen);
986 	POKE_U32(iv, index);
987 
988 	/* decrypt private state of key */
989 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
990 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
991 	    iv, ivlen, 0)) != 0 ||
992 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
993 	    encrypted_len, aadlen, authlen)) != 0)
994 		goto out;
995 
996 	/* there should be no trailing data */
997 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
998 		goto out;
999 	if (sshbuf_len(encoded) != 0) {
1000 		r = SSH_ERR_INVALID_FORMAT;
1001 		goto out;
1002 	}
1003 
1004 	/* remove AAD */
1005 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1006 		goto out;
1007 	/* XXX encrypted includes unchecked padding */
1008 
1009 	/* success */
1010 	r = 0;
1011 	if (retp != NULL) {
1012 		*retp = decrypted;
1013 		decrypted = NULL;
1014 	}
1015  out:
1016 	cipher_free(ciphercontext);
1017 	sshbuf_free(copy);
1018 	sshbuf_free(decrypted);
1019 	free(iv);
1020 	return r;
1021 }
1022 
1023 u_int32_t
1024 sshkey_xmss_signatures_left(const struct sshkey *k)
1025 {
1026 	struct ssh_xmss_state *state = k->xmss_state;
1027 	u_int32_t idx;
1028 
1029 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1030 	    state->maxidx) {
1031 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1032 		if (idx < state->maxidx)
1033 			return state->maxidx - idx;
1034 	}
1035 	return 0;
1036 }
1037 
1038 int
1039 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1040 {
1041 	struct ssh_xmss_state *state = k->xmss_state;
1042 
1043 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1044 		return SSH_ERR_INVALID_ARGUMENT;
1045 	if (maxsign == 0)
1046 		return 0;
1047 	if (state->idx + maxsign < state->idx)
1048 		return SSH_ERR_INVALID_ARGUMENT;
1049 	state->maxidx = state->idx + maxsign;
1050 	return 0;
1051 }
1052