xref: /netbsd-src/crypto/external/bsd/openssh/dist/sshkey-xmss.c (revision 82d56013d7b633d116a93943de88e08335357a7c)
1 /*	$NetBSD: sshkey-xmss.c,v 1.8 2021/04/19 14:40:15 christos Exp $	*/
2 /* $OpenBSD: sshkey-xmss.c,v 1.11 2021/04/03 06:18:41 djm Exp $ */
3 
4 /*
5  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 #include "includes.h"
28 __RCSID("$NetBSD: sshkey-xmss.c,v 1.8 2021/04/19 14:40:15 christos Exp $");
29 
30 #include <sys/types.h>
31 #include <sys/uio.h>
32 
33 #include <stdio.h>
34 #include <string.h>
35 #include <unistd.h>
36 #include <fcntl.h>
37 #include <errno.h>
38 
39 #include "ssh2.h"
40 #include "ssherr.h"
41 #include "sshbuf.h"
42 #include "cipher.h"
43 #include "sshkey.h"
44 #include "sshkey-xmss.h"
45 #include "atomicio.h"
46 #include "log.h"
47 
48 #include "xmss_fast.h"
49 
50 /* opaque internal XMSS state */
51 #define XMSS_MAGIC		"xmss-state-v1"
52 #define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
53 struct ssh_xmss_state {
54 	xmss_params	params;
55 	u_int32_t	n, w, h, k;
56 
57 	bds_state	bds;
58 	u_char		*stack;
59 	u_int32_t	stackoffset;
60 	u_char		*stacklevels;
61 	u_char		*auth;
62 	u_char		*keep;
63 	u_char		*th_nodes;
64 	u_char		*retain;
65 	treehash_inst	*treehash;
66 
67 	u_int32_t	idx;		/* state read from file */
68 	u_int32_t	maxidx;		/* restricted # of signatures */
69 	int		have_state;	/* .state file exists */
70 	int		lockfd;		/* locked in sshkey_xmss_get_state() */
71 	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
72 	char		*enc_ciphername;/* encrypt state with cipher */
73 	u_char		*enc_keyiv;	/* encrypt state with key */
74 	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
75 };
76 
77 int	 sshkey_xmss_init_bds_state(struct sshkey *);
78 int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
79 void	 sshkey_xmss_free_bds(struct sshkey *);
80 int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
81 	    int *, int);
82 int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
83 	    struct sshbuf **);
84 int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
85 	    struct sshbuf **);
86 int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
87 int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
88 
89 #define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
90     0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
91 
92 int
93 sshkey_xmss_init(struct sshkey *key, const char *name)
94 {
95 	struct ssh_xmss_state *state;
96 
97 	if (key->xmss_state != NULL)
98 		return SSH_ERR_INVALID_FORMAT;
99 	if (name == NULL)
100 		return SSH_ERR_INVALID_FORMAT;
101 	state = calloc(sizeof(struct ssh_xmss_state), 1);
102 	if (state == NULL)
103 		return SSH_ERR_ALLOC_FAIL;
104 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
105 		state->n = 32;
106 		state->w = 16;
107 		state->h = 10;
108 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
109 		state->n = 32;
110 		state->w = 16;
111 		state->h = 16;
112 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
113 		state->n = 32;
114 		state->w = 16;
115 		state->h = 20;
116 	} else {
117 		free(state);
118 		return SSH_ERR_KEY_TYPE_UNKNOWN;
119 	}
120 	if ((key->xmss_name = strdup(name)) == NULL) {
121 		free(state);
122 		return SSH_ERR_ALLOC_FAIL;
123 	}
124 	state->k = 2;	/* XXX hardcoded */
125 	state->lockfd = -1;
126 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
127 	    state->k) != 0) {
128 		free(state);
129 		return SSH_ERR_INVALID_FORMAT;
130 	}
131 	key->xmss_state = state;
132 	return 0;
133 }
134 
135 void
136 sshkey_xmss_free_state(struct sshkey *key)
137 {
138 	struct ssh_xmss_state *state = key->xmss_state;
139 
140 	sshkey_xmss_free_bds(key);
141 	if (state) {
142 		if (state->enc_keyiv) {
143 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
144 			free(state->enc_keyiv);
145 		}
146 		free(state->enc_ciphername);
147 		free(state);
148 	}
149 	key->xmss_state = NULL;
150 }
151 
152 #define SSH_XMSS_K2_MAGIC	"k=2"
153 #define num_stack(x)		((x->h+1)*(x->n))
154 #define num_stacklevels(x)	(x->h+1)
155 #define num_auth(x)		((x->h)*(x->n))
156 #define num_keep(x)		((x->h >> 1)*(x->n))
157 #define num_th_nodes(x)		((x->h - x->k)*(x->n))
158 #define num_retain(x)		(((1ULL << x->k) - x->k - 1) * (x->n))
159 #define num_treehash(x)		((x->h) - (x->k))
160 
161 int
162 sshkey_xmss_init_bds_state(struct sshkey *key)
163 {
164 	struct ssh_xmss_state *state = key->xmss_state;
165 	u_int32_t i;
166 
167 	state->stackoffset = 0;
168 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
169 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
170 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
171 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
172 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
173 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
174 	    (state->treehash = calloc(num_treehash(state),
175 	    sizeof(treehash_inst))) == NULL) {
176 		sshkey_xmss_free_bds(key);
177 		return SSH_ERR_ALLOC_FAIL;
178 	}
179 	for (i = 0; i < state->h - state->k; i++)
180 		state->treehash[i].node = &state->th_nodes[state->n*i];
181 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
182 	    state->stacklevels, state->auth, state->keep, state->treehash,
183 	    state->retain, 0);
184 	return 0;
185 }
186 
187 void
188 sshkey_xmss_free_bds(struct sshkey *key)
189 {
190 	struct ssh_xmss_state *state = key->xmss_state;
191 
192 	if (state == NULL)
193 		return;
194 	free(state->stack);
195 	free(state->stacklevels);
196 	free(state->auth);
197 	free(state->keep);
198 	free(state->th_nodes);
199 	free(state->retain);
200 	free(state->treehash);
201 	state->stack = NULL;
202 	state->stacklevels = NULL;
203 	state->auth = NULL;
204 	state->keep = NULL;
205 	state->th_nodes = NULL;
206 	state->retain = NULL;
207 	state->treehash = NULL;
208 }
209 
210 void *
211 sshkey_xmss_params(const struct sshkey *key)
212 {
213 	struct ssh_xmss_state *state = key->xmss_state;
214 
215 	if (state == NULL)
216 		return NULL;
217 	return &state->params;
218 }
219 
220 void *
221 sshkey_xmss_bds_state(const struct sshkey *key)
222 {
223 	struct ssh_xmss_state *state = key->xmss_state;
224 
225 	if (state == NULL)
226 		return NULL;
227 	return &state->bds;
228 }
229 
230 int
231 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
232 {
233 	struct ssh_xmss_state *state = key->xmss_state;
234 
235 	if (lenp == NULL)
236 		return SSH_ERR_INVALID_ARGUMENT;
237 	if (state == NULL)
238 		return SSH_ERR_INVALID_FORMAT;
239 	*lenp = 4 + state->n +
240 	    state->params.wots_par.keysize +
241 	    state->h * state->n;
242 	return 0;
243 }
244 
245 size_t
246 sshkey_xmss_pklen(const struct sshkey *key)
247 {
248 	struct ssh_xmss_state *state = key->xmss_state;
249 
250 	if (state == NULL)
251 		return 0;
252 	return state->n * 2;
253 }
254 
255 size_t
256 sshkey_xmss_sklen(const struct sshkey *key)
257 {
258 	struct ssh_xmss_state *state = key->xmss_state;
259 
260 	if (state == NULL)
261 		return 0;
262 	return state->n * 4 + 4;
263 }
264 
265 int
266 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
267 {
268 	struct ssh_xmss_state *state = k->xmss_state;
269 	const struct sshcipher *cipher;
270 	size_t keylen = 0, ivlen = 0;
271 
272 	if (state == NULL)
273 		return SSH_ERR_INVALID_ARGUMENT;
274 	if ((cipher = cipher_by_name(ciphername)) == NULL)
275 		return SSH_ERR_INTERNAL_ERROR;
276 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
277 		return SSH_ERR_ALLOC_FAIL;
278 	keylen = cipher_keylen(cipher);
279 	ivlen = cipher_ivlen(cipher);
280 	state->enc_keyiv_len = keylen + ivlen;
281 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
282 		free(state->enc_ciphername);
283 		state->enc_ciphername = NULL;
284 		return SSH_ERR_ALLOC_FAIL;
285 	}
286 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
287 	return 0;
288 }
289 
290 int
291 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
292 {
293 	struct ssh_xmss_state *state = k->xmss_state;
294 	int r;
295 
296 	if (state == NULL || state->enc_keyiv == NULL ||
297 	    state->enc_ciphername == NULL)
298 		return SSH_ERR_INVALID_ARGUMENT;
299 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
300 	    (r = sshbuf_put_string(b, state->enc_keyiv,
301 	    state->enc_keyiv_len)) != 0)
302 		return r;
303 	return 0;
304 }
305 
306 int
307 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
308 {
309 	struct ssh_xmss_state *state = k->xmss_state;
310 	size_t len;
311 	int r;
312 
313 	if (state == NULL)
314 		return SSH_ERR_INVALID_ARGUMENT;
315 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
316 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
317 		return r;
318 	state->enc_keyiv_len = len;
319 	return 0;
320 }
321 
322 int
323 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
324     enum sshkey_serialize_rep opts)
325 {
326 	struct ssh_xmss_state *state = k->xmss_state;
327 	u_char have_info = 1;
328 	u_int32_t idx;
329 	int r;
330 
331 	if (state == NULL)
332 		return SSH_ERR_INVALID_ARGUMENT;
333 	if (opts != SSHKEY_SERIALIZE_INFO)
334 		return 0;
335 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
336 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
337 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
338 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
339 		return r;
340 	return 0;
341 }
342 
343 int
344 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
345 {
346 	struct ssh_xmss_state *state = k->xmss_state;
347 	u_char have_info;
348 	int r;
349 
350 	if (state == NULL)
351 		return SSH_ERR_INVALID_ARGUMENT;
352 	/* optional */
353 	if (sshbuf_len(b) == 0)
354 		return 0;
355 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
356 		return r;
357 	if (have_info != 1)
358 		return SSH_ERR_INVALID_ARGUMENT;
359 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
360 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
361 		return r;
362 	return 0;
363 }
364 
365 int
366 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
367 {
368 	int r;
369 	const char *name;
370 
371 	if (bits == 10) {
372 		name = XMSS_SHA2_256_W16_H10_NAME;
373 	} else if (bits == 16) {
374 		name = XMSS_SHA2_256_W16_H16_NAME;
375 	} else if (bits == 20) {
376 		name = XMSS_SHA2_256_W16_H20_NAME;
377 	} else {
378 		name = XMSS_DEFAULT_NAME;
379 	}
380 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
381 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
382 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
383 		return r;
384 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
385 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
386 		return SSH_ERR_ALLOC_FAIL;
387 	}
388 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
389 	    sshkey_xmss_params(k));
390 	return 0;
391 }
392 
393 int
394 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
395     int *have_file, int printerror)
396 {
397 	struct sshbuf *b = NULL, *enc = NULL;
398 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
399 	u_int32_t len;
400 	unsigned char buf[4], *data = NULL;
401 
402 	*have_file = 0;
403 	if ((fd = open(filename, O_RDONLY)) >= 0) {
404 		*have_file = 1;
405 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
406 			PRINT("corrupt state file: %s", filename);
407 			goto done;
408 		}
409 		len = PEEK_U32(buf);
410 		if ((data = calloc(len, 1)) == NULL) {
411 			ret = SSH_ERR_ALLOC_FAIL;
412 			goto done;
413 		}
414 		if (atomicio(read, fd, data, len) != len) {
415 			PRINT("cannot read blob: %s", filename);
416 			goto done;
417 		}
418 		if ((enc = sshbuf_from(data, len)) == NULL) {
419 			ret = SSH_ERR_ALLOC_FAIL;
420 			goto done;
421 		}
422 		sshkey_xmss_free_bds(k);
423 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
424 			ret = r;
425 			goto done;
426 		}
427 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
428 			ret = r;
429 			goto done;
430 		}
431 		ret = 0;
432 	}
433 done:
434 	if (fd != -1)
435 		close(fd);
436 	free(data);
437 	sshbuf_free(enc);
438 	sshbuf_free(b);
439 	return ret;
440 }
441 
442 int
443 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
444 {
445 	struct ssh_xmss_state *state = k->xmss_state;
446 	u_int32_t idx = 0;
447 	char *filename = NULL;
448 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
449 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
450 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
451 
452 	if (state == NULL)
453 		goto done;
454 	/*
455 	 * If maxidx is set, then we are allowed a limited number
456 	 * of signatures, but don't need to access the disk.
457 	 * Otherwise we need to deal with the on-disk state.
458 	 */
459 	if (state->maxidx) {
460 		/* xmss_sk always contains the current state */
461 		idx = PEEK_U32(k->xmss_sk);
462 		if (idx < state->maxidx) {
463 			state->allow_update = 1;
464 			return 0;
465 		}
466 		return SSH_ERR_INVALID_ARGUMENT;
467 	}
468 	if ((filename = k->xmss_filename) == NULL)
469 		goto done;
470 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
471 	    asprintf(&statefile, "%s.state", filename) == -1 ||
472 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
473 		ret = SSH_ERR_ALLOC_FAIL;
474 		goto done;
475 	}
476 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
477 		ret = SSH_ERR_SYSTEM_ERROR;
478 		PRINT("cannot open/create: %s", lockfile);
479 		goto done;
480 	}
481 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
482 		if (errno != EWOULDBLOCK) {
483 			ret = SSH_ERR_SYSTEM_ERROR;
484 			PRINT("cannot lock: %s", lockfile);
485 			goto done;
486 		}
487 		if (++tries > 10) {
488 			ret = SSH_ERR_SYSTEM_ERROR;
489 			PRINT("giving up on: %s", lockfile);
490 			goto done;
491 		}
492 		usleep(1000*100*tries);
493 	}
494 	/* XXX no longer const */
495 	if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
496 	    statefile, &have_state, printerror)) != 0) {
497 		if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
498 		    ostatefile, &have_ostate, printerror)) == 0) {
499 			state->allow_update = 1;
500 			r = sshkey_xmss_forward_state(k, 1);
501 			state->idx = PEEK_U32(k->xmss_sk);
502 			state->allow_update = 0;
503 		}
504 	}
505 	if (!have_state && !have_ostate) {
506 		/* check that bds state is initialized */
507 		if (state->bds.auth == NULL)
508 			goto done;
509 		PRINT("start from scratch idx 0: %u", state->idx);
510 	} else if (r != 0) {
511 		ret = r;
512 		goto done;
513 	}
514 	if (state->idx + 1 < state->idx) {
515 		PRINT("state wrap: %u", state->idx);
516 		goto done;
517 	}
518 	state->have_state = have_state;
519 	state->lockfd = lockfd;
520 	state->allow_update = 1;
521 	lockfd = -1;
522 	ret = 0;
523 done:
524 	if (lockfd != -1)
525 		close(lockfd);
526 	free(lockfile);
527 	free(statefile);
528 	free(ostatefile);
529 	return ret;
530 }
531 
532 int
533 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
534 {
535 	struct ssh_xmss_state *state = k->xmss_state;
536 	u_char *sig = NULL;
537 	size_t required_siglen;
538 	unsigned long long smlen;
539 	u_char data;
540 	int ret, r;
541 
542 	if (state == NULL || !state->allow_update)
543 		return SSH_ERR_INVALID_ARGUMENT;
544 	if (reserve == 0)
545 		return SSH_ERR_INVALID_ARGUMENT;
546 	if (state->idx + reserve <= state->idx)
547 		return SSH_ERR_INVALID_ARGUMENT;
548 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
549 		return r;
550 	if ((sig = malloc(required_siglen)) == NULL)
551 		return SSH_ERR_ALLOC_FAIL;
552 	while (reserve-- > 0) {
553 		state->idx = PEEK_U32(k->xmss_sk);
554 		smlen = required_siglen;
555 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
556 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
557 			r = SSH_ERR_INVALID_ARGUMENT;
558 			break;
559 		}
560 	}
561 	free(sig);
562 	return r;
563 }
564 
565 int
566 sshkey_xmss_update_state(const struct sshkey *k, int printerror)
567 {
568 	struct ssh_xmss_state *state = k->xmss_state;
569 	struct sshbuf *b = NULL, *enc = NULL;
570 	u_int32_t idx = 0;
571 	unsigned char buf[4];
572 	char *filename = NULL;
573 	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
574 	int fd = -1;
575 	int ret = SSH_ERR_INVALID_ARGUMENT;
576 
577 	if (state == NULL || !state->allow_update)
578 		return ret;
579 	if (state->maxidx) {
580 		/* no update since the number of signatures is limited */
581 		ret = 0;
582 		goto done;
583 	}
584 	idx = PEEK_U32(k->xmss_sk);
585 	if (idx == state->idx) {
586 		/* no signature happened, no need to update */
587 		ret = 0;
588 		goto done;
589 	} else if (idx != state->idx + 1) {
590 		PRINT("more than one signature happened: idx %u state %u",
591 		    idx, state->idx);
592 		goto done;
593 	}
594 	state->idx = idx;
595 	if ((filename = k->xmss_filename) == NULL)
596 		goto done;
597 	if (asprintf(&statefile, "%s.state", filename) == -1 ||
598 	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
599 	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
600 		ret = SSH_ERR_ALLOC_FAIL;
601 		goto done;
602 	}
603 	unlink(nstatefile);
604 	if ((b = sshbuf_new()) == NULL) {
605 		ret = SSH_ERR_ALLOC_FAIL;
606 		goto done;
607 	}
608 	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
609 		PRINT("SERLIALIZE FAILED: %d", ret);
610 		goto done;
611 	}
612 	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
613 		PRINT("ENCRYPT FAILED: %d", ret);
614 		goto done;
615 	}
616 	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
617 		ret = SSH_ERR_SYSTEM_ERROR;
618 		PRINT("open new state file: %s", nstatefile);
619 		goto done;
620 	}
621 	POKE_U32(buf, sshbuf_len(enc));
622 	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
623 		ret = SSH_ERR_SYSTEM_ERROR;
624 		PRINT("write new state file hdr: %s", nstatefile);
625 		close(fd);
626 		goto done;
627 	}
628 	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
629 	    sshbuf_len(enc)) {
630 		ret = SSH_ERR_SYSTEM_ERROR;
631 		PRINT("write new state file data: %s", nstatefile);
632 		close(fd);
633 		goto done;
634 	}
635 	if (fsync(fd) == -1) {
636 		ret = SSH_ERR_SYSTEM_ERROR;
637 		PRINT("sync new state file: %s", nstatefile);
638 		close(fd);
639 		goto done;
640 	}
641 	if (close(fd) == -1) {
642 		ret = SSH_ERR_SYSTEM_ERROR;
643 		PRINT("close new state file: %s", nstatefile);
644 		goto done;
645 	}
646 	if (state->have_state) {
647 		unlink(ostatefile);
648 		if (link(statefile, ostatefile)) {
649 			ret = SSH_ERR_SYSTEM_ERROR;
650 			PRINT("backup state %s to %s", statefile, ostatefile);
651 			goto done;
652 		}
653 	}
654 	if (rename(nstatefile, statefile) == -1) {
655 		ret = SSH_ERR_SYSTEM_ERROR;
656 		PRINT("rename %s to %s", nstatefile, statefile);
657 		goto done;
658 	}
659 	ret = 0;
660 done:
661 	if (state->lockfd != -1) {
662 		close(state->lockfd);
663 		state->lockfd = -1;
664 	}
665 	if (nstatefile)
666 		unlink(nstatefile);
667 	free(statefile);
668 	free(ostatefile);
669 	free(nstatefile);
670 	sshbuf_free(b);
671 	sshbuf_free(enc);
672 	return ret;
673 }
674 
675 int
676 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
677 {
678 	struct ssh_xmss_state *state = k->xmss_state;
679 	treehash_inst *th;
680 	u_int32_t i, node;
681 	int r;
682 
683 	if (state == NULL)
684 		return SSH_ERR_INVALID_ARGUMENT;
685 	if (state->stack == NULL)
686 		return SSH_ERR_INVALID_ARGUMENT;
687 	state->stackoffset = state->bds.stackoffset;	/* copy back */
688 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
689 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
690 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
691 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
692 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
693 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
694 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
695 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
696 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
697 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
698 		return r;
699 	for (i = 0; i < num_treehash(state); i++) {
700 		th = &state->treehash[i];
701 		node = th->node - state->th_nodes;
702 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
703 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
704 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
705 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
706 		    (r = sshbuf_put_u32(b, node)) != 0)
707 			return r;
708 	}
709 	return 0;
710 }
711 
712 int
713 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
714     enum sshkey_serialize_rep opts)
715 {
716 	struct ssh_xmss_state *state = k->xmss_state;
717 	int r = SSH_ERR_INVALID_ARGUMENT;
718 	u_char have_stack, have_filename, have_enc;
719 
720 	if (state == NULL)
721 		return SSH_ERR_INVALID_ARGUMENT;
722 	if ((r = sshbuf_put_u8(b, opts)) != 0)
723 		return r;
724 	switch (opts) {
725 	case SSHKEY_SERIALIZE_STATE:
726 		r = sshkey_xmss_serialize_state(k, b);
727 		break;
728 	case SSHKEY_SERIALIZE_FULL:
729 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
730 			return r;
731 		r = sshkey_xmss_serialize_state(k, b);
732 		break;
733 	case SSHKEY_SERIALIZE_SHIELD:
734 		/* all of stack/filename/enc are optional */
735 		have_stack = state->stack != NULL;
736 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
737 			return r;
738 		if (have_stack) {
739 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
740 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
741 				return r;
742 		}
743 		have_filename = k->xmss_filename != NULL;
744 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
745 			return r;
746 		if (have_filename &&
747 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
748 			return r;
749 		have_enc = state->enc_keyiv != NULL;
750 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
751 			return r;
752 		if (have_enc &&
753 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
754 			return r;
755 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
756 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
757 			return r;
758 		break;
759 	case SSHKEY_SERIALIZE_DEFAULT:
760 		r = 0;
761 		break;
762 	default:
763 		r = SSH_ERR_INVALID_ARGUMENT;
764 		break;
765 	}
766 	return r;
767 }
768 
769 int
770 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
771 {
772 	struct ssh_xmss_state *state = k->xmss_state;
773 	treehash_inst *th;
774 	u_int32_t i, lh, node;
775 	size_t ls, lsl, la, lk, ln, lr;
776 	char *magic;
777 	int r = SSH_ERR_INTERNAL_ERROR;
778 
779 	if (state == NULL)
780 		return SSH_ERR_INVALID_ARGUMENT;
781 	if (k->xmss_sk == NULL)
782 		return SSH_ERR_INVALID_ARGUMENT;
783 	if ((state->treehash = calloc(num_treehash(state),
784 	    sizeof(treehash_inst))) == NULL)
785 		return SSH_ERR_ALLOC_FAIL;
786 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
787 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
788 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
789 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
790 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
791 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
792 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
793 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
794 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
795 	    (r = sshbuf_get_u32(b, &lh)) != 0)
796 		goto out;
797 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
798 		r = SSH_ERR_INVALID_ARGUMENT;
799 		goto out;
800 	}
801 	/* XXX check stackoffset */
802 	if (ls != num_stack(state) ||
803 	    lsl != num_stacklevels(state) ||
804 	    la != num_auth(state) ||
805 	    lk != num_keep(state) ||
806 	    ln != num_th_nodes(state) ||
807 	    lr != num_retain(state) ||
808 	    lh != num_treehash(state)) {
809 		r = SSH_ERR_INVALID_ARGUMENT;
810 		goto out;
811 	}
812 	for (i = 0; i < num_treehash(state); i++) {
813 		th = &state->treehash[i];
814 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
815 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
816 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
817 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
818 		    (r = sshbuf_get_u32(b, &node)) != 0)
819 			goto out;
820 		if (node < num_th_nodes(state))
821 			th->node = &state->th_nodes[node];
822 	}
823 	POKE_U32(k->xmss_sk, state->idx);
824 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
825 	    state->stacklevels, state->auth, state->keep, state->treehash,
826 	    state->retain, 0);
827 	/* success */
828 	r = 0;
829  out:
830 	free(magic);
831 	return r;
832 }
833 
834 int
835 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
836 {
837 	struct ssh_xmss_state *state = k->xmss_state;
838 	enum sshkey_serialize_rep opts;
839 	u_char have_state, have_stack, have_filename, have_enc;
840 	int r;
841 
842 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
843 		return r;
844 
845 	opts = have_state;
846 	switch (opts) {
847 	case SSHKEY_SERIALIZE_DEFAULT:
848 		r = 0;
849 		break;
850 	case SSHKEY_SERIALIZE_SHIELD:
851 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
852 			return r;
853 		if (have_stack &&
854 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
855 			return r;
856 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
857 			return r;
858 		if (have_filename &&
859 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
860 			return r;
861 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
862 			return r;
863 		if (have_enc &&
864 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
865 			return r;
866 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
867 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
868 			return r;
869 		break;
870 	case SSHKEY_SERIALIZE_STATE:
871 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
872 			return r;
873 		break;
874 	case SSHKEY_SERIALIZE_FULL:
875 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
876 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
877 			return r;
878 		break;
879 	default:
880 		r = SSH_ERR_INVALID_FORMAT;
881 		break;
882 	}
883 	return r;
884 }
885 
886 int
887 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
888    struct sshbuf **retp)
889 {
890 	struct ssh_xmss_state *state = k->xmss_state;
891 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
892 	struct sshcipher_ctx *ciphercontext = NULL;
893 	const struct sshcipher *cipher;
894 	u_char *cp, *key, *iv = NULL;
895 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
896 	int r = SSH_ERR_INTERNAL_ERROR;
897 
898 	if (retp != NULL)
899 		*retp = NULL;
900 	if (state == NULL ||
901 	    state->enc_keyiv == NULL ||
902 	    state->enc_ciphername == NULL)
903 		return SSH_ERR_INTERNAL_ERROR;
904 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
905 		r = SSH_ERR_INTERNAL_ERROR;
906 		goto out;
907 	}
908 	blocksize = cipher_blocksize(cipher);
909 	keylen = cipher_keylen(cipher);
910 	ivlen = cipher_ivlen(cipher);
911 	authlen = cipher_authlen(cipher);
912 	if (state->enc_keyiv_len != keylen + ivlen) {
913 		r = SSH_ERR_INVALID_FORMAT;
914 		goto out;
915 	}
916 	key = state->enc_keyiv;
917 	if ((encrypted = sshbuf_new()) == NULL ||
918 	    (encoded = sshbuf_new()) == NULL ||
919 	    (padded = sshbuf_new()) == NULL ||
920 	    (iv = malloc(ivlen)) == NULL) {
921 		r = SSH_ERR_ALLOC_FAIL;
922 		goto out;
923 	}
924 
925 	/* replace first 4 bytes of IV with index to ensure uniqueness */
926 	memcpy(iv, key + keylen, ivlen);
927 	POKE_U32(iv, state->idx);
928 
929 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
930 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
931 		goto out;
932 
933 	/* padded state will be encrypted */
934 	if ((r = sshbuf_putb(padded, b)) != 0)
935 		goto out;
936 	i = 0;
937 	while (sshbuf_len(padded) % blocksize) {
938 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
939 			goto out;
940 	}
941 	encrypted_len = sshbuf_len(padded);
942 
943 	/* header including the length of state is used as AAD */
944 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
945 		goto out;
946 	aadlen = sshbuf_len(encoded);
947 
948 	/* concat header and state */
949 	if ((r = sshbuf_putb(encoded, padded)) != 0)
950 		goto out;
951 
952 	/* reserve space for encryption of encoded data plus auth tag */
953 	/* encrypt at offset addlen */
954 	if ((r = sshbuf_reserve(encrypted,
955 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
956 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
957 	    iv, ivlen, 1)) != 0 ||
958 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
959 	    encrypted_len, aadlen, authlen)) != 0)
960 		goto out;
961 
962 	/* success */
963 	r = 0;
964  out:
965 	if (retp != NULL) {
966 		*retp = encrypted;
967 		encrypted = NULL;
968 	}
969 	sshbuf_free(padded);
970 	sshbuf_free(encoded);
971 	sshbuf_free(encrypted);
972 	cipher_free(ciphercontext);
973 	free(iv);
974 	return r;
975 }
976 
977 int
978 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
979    struct sshbuf **retp)
980 {
981 	struct ssh_xmss_state *state = k->xmss_state;
982 	struct sshbuf *copy = NULL, *decrypted = NULL;
983 	struct sshcipher_ctx *ciphercontext = NULL;
984 	const struct sshcipher *cipher = NULL;
985 	u_char *key, *iv = NULL, *dp;
986 	size_t keylen, ivlen, authlen, aadlen;
987 	u_int blocksize, encrypted_len, index;
988 	int r = SSH_ERR_INTERNAL_ERROR;
989 
990 	if (retp != NULL)
991 		*retp = NULL;
992 	if (state == NULL ||
993 	    state->enc_keyiv == NULL ||
994 	    state->enc_ciphername == NULL)
995 		return SSH_ERR_INTERNAL_ERROR;
996 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
997 		r = SSH_ERR_INVALID_FORMAT;
998 		goto out;
999 	}
1000 	blocksize = cipher_blocksize(cipher);
1001 	keylen = cipher_keylen(cipher);
1002 	ivlen = cipher_ivlen(cipher);
1003 	authlen = cipher_authlen(cipher);
1004 	if (state->enc_keyiv_len != keylen + ivlen) {
1005 		r = SSH_ERR_INTERNAL_ERROR;
1006 		goto out;
1007 	}
1008 	key = state->enc_keyiv;
1009 
1010 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1011 	    (decrypted = sshbuf_new()) == NULL ||
1012 	    (iv = malloc(ivlen)) == NULL) {
1013 		r = SSH_ERR_ALLOC_FAIL;
1014 		goto out;
1015 	}
1016 
1017 	/* check magic */
1018 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1019 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1020 		r = SSH_ERR_INVALID_FORMAT;
1021 		goto out;
1022 	}
1023 	/* parse public portion */
1024 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1025 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1026 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1027 		goto out;
1028 
1029 	/* check size of encrypted key blob */
1030 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1031 		r = SSH_ERR_INVALID_FORMAT;
1032 		goto out;
1033 	}
1034 	/* check that an appropriate amount of auth data is present */
1035 	if (sshbuf_len(encoded) < authlen ||
1036 	    sshbuf_len(encoded) - authlen < encrypted_len) {
1037 		r = SSH_ERR_INVALID_FORMAT;
1038 		goto out;
1039 	}
1040 
1041 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1042 
1043 	/* replace first 4 bytes of IV with index to ensure uniqueness */
1044 	memcpy(iv, key + keylen, ivlen);
1045 	POKE_U32(iv, index);
1046 
1047 	/* decrypt private state of key */
1048 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1049 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1050 	    iv, ivlen, 0)) != 0 ||
1051 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1052 	    encrypted_len, aadlen, authlen)) != 0)
1053 		goto out;
1054 
1055 	/* there should be no trailing data */
1056 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1057 		goto out;
1058 	if (sshbuf_len(encoded) != 0) {
1059 		r = SSH_ERR_INVALID_FORMAT;
1060 		goto out;
1061 	}
1062 
1063 	/* remove AAD */
1064 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1065 		goto out;
1066 	/* XXX encrypted includes unchecked padding */
1067 
1068 	/* success */
1069 	r = 0;
1070 	if (retp != NULL) {
1071 		*retp = decrypted;
1072 		decrypted = NULL;
1073 	}
1074  out:
1075 	cipher_free(ciphercontext);
1076 	sshbuf_free(copy);
1077 	sshbuf_free(decrypted);
1078 	free(iv);
1079 	return r;
1080 }
1081 
1082 u_int32_t
1083 sshkey_xmss_signatures_left(const struct sshkey *k)
1084 {
1085 	struct ssh_xmss_state *state = k->xmss_state;
1086 	u_int32_t idx;
1087 
1088 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1089 	    state->maxidx) {
1090 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1091 		if (idx < state->maxidx)
1092 			return state->maxidx - idx;
1093 	}
1094 	return 0;
1095 }
1096 
1097 int
1098 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1099 {
1100 	struct ssh_xmss_state *state = k->xmss_state;
1101 
1102 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1103 		return SSH_ERR_INVALID_ARGUMENT;
1104 	if (maxsign == 0)
1105 		return 0;
1106 	if (state->idx + maxsign < state->idx)
1107 		return SSH_ERR_INVALID_ARGUMENT;
1108 	state->maxidx = state->idx + maxsign;
1109 	return 0;
1110 }
1111