xref: /openbsd-src/usr.bin/ssh/sshkey-xmss.c (revision fcde59b201a29a2b4570b00b71e7aa25d61cb5c1)
1 /* $OpenBSD: sshkey-xmss.c,v 1.9 2020/10/19 22:49:23 dtucker Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include <sys/types.h>
27 #include <sys/uio.h>
28 
29 #include <stdio.h>
30 #include <string.h>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <errno.h>
34 
35 #include "ssh2.h"
36 #include "ssherr.h"
37 #include "sshbuf.h"
38 #include "cipher.h"
39 #include "sshkey.h"
40 #include "sshkey-xmss.h"
41 #include "atomicio.h"
42 #include "log.h"
43 
44 #include "xmss_fast.h"
45 
/*
 * Opaque internal XMSS state, reached through key->xmss_state.
 *
 * XMSS_MAGIC prefixes every encrypted state blob; it is written including
 * its terminating NUL (sizeof(XMSS_MAGIC) bytes).  XMSS_CIPHERNAME is the
 * cipher used to encrypt the on-disk state of newly generated keys.
 */
#define XMSS_MAGIC		"xmss-state-v1"
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	xmss_params	params;		/* filled in by xmss_set_params() */
	u_int32_t	n, w, h, k;	/* node size (bytes), w, tree height, k */

	/* BDS traversal state; the buffers below back the pointers in bds */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;	/* copy of bds.stackoffset for (de)serialization */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* treehash[i].node points into this */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
72 
/* Prototypes for helpers that are defined later in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Log an error through sshlog() at SYSLOG_LEVEL_ERROR, but only when a
 * variable named "printerror" in the calling scope is non-zero.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, __VA_ARGS__); } while (0)
87 
88 int
89 sshkey_xmss_init(struct sshkey *key, const char *name)
90 {
91 	struct ssh_xmss_state *state;
92 
93 	if (key->xmss_state != NULL)
94 		return SSH_ERR_INVALID_FORMAT;
95 	if (name == NULL)
96 		return SSH_ERR_INVALID_FORMAT;
97 	state = calloc(sizeof(struct ssh_xmss_state), 1);
98 	if (state == NULL)
99 		return SSH_ERR_ALLOC_FAIL;
100 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
101 		state->n = 32;
102 		state->w = 16;
103 		state->h = 10;
104 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
105 		state->n = 32;
106 		state->w = 16;
107 		state->h = 16;
108 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
109 		state->n = 32;
110 		state->w = 16;
111 		state->h = 20;
112 	} else {
113 		free(state);
114 		return SSH_ERR_KEY_TYPE_UNKNOWN;
115 	}
116 	if ((key->xmss_name = strdup(name)) == NULL) {
117 		free(state);
118 		return SSH_ERR_ALLOC_FAIL;
119 	}
120 	state->k = 2;	/* XXX hardcoded */
121 	state->lockfd = -1;
122 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
123 	    state->k) != 0) {
124 		free(state);
125 		return SSH_ERR_INVALID_FORMAT;
126 	}
127 	key->xmss_state = state;
128 	return 0;
129 }
130 
131 void
132 sshkey_xmss_free_state(struct sshkey *key)
133 {
134 	struct ssh_xmss_state *state = key->xmss_state;
135 
136 	sshkey_xmss_free_bds(key);
137 	if (state) {
138 		if (state->enc_keyiv) {
139 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
140 			free(state->enc_keyiv);
141 		}
142 		free(state->enc_ciphername);
143 		free(state);
144 	}
145 	key->xmss_state = NULL;
146 }
147 
/* Magic string that starts every serialized BDS state (k = 2 layout). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers for state "x".  "x" is a pointer to a struct
 * with n, h and k members, normally struct ssh_xmss_state.
 * num_treehash() is an element count of treehash_inst.  All the others
 * are byte counts, since those buffers are calloc()ed with size 1.
 * The argument is fully parenthesized so that expressions such as "&st"
 * expand correctly.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
156 
157 int
158 sshkey_xmss_init_bds_state(struct sshkey *key)
159 {
160 	struct ssh_xmss_state *state = key->xmss_state;
161 	u_int32_t i;
162 
163 	state->stackoffset = 0;
164 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
165 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
166 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
167 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
168 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
169 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
170 	    (state->treehash = calloc(num_treehash(state),
171 	    sizeof(treehash_inst))) == NULL) {
172 		sshkey_xmss_free_bds(key);
173 		return SSH_ERR_ALLOC_FAIL;
174 	}
175 	for (i = 0; i < state->h - state->k; i++)
176 		state->treehash[i].node = &state->th_nodes[state->n*i];
177 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
178 	    state->stacklevels, state->auth, state->keep, state->treehash,
179 	    state->retain, 0);
180 	return 0;
181 }
182 
183 void
184 sshkey_xmss_free_bds(struct sshkey *key)
185 {
186 	struct ssh_xmss_state *state = key->xmss_state;
187 
188 	if (state == NULL)
189 		return;
190 	free(state->stack);
191 	free(state->stacklevels);
192 	free(state->auth);
193 	free(state->keep);
194 	free(state->th_nodes);
195 	free(state->retain);
196 	free(state->treehash);
197 	state->stack = NULL;
198 	state->stacklevels = NULL;
199 	state->auth = NULL;
200 	state->keep = NULL;
201 	state->th_nodes = NULL;
202 	state->retain = NULL;
203 	state->treehash = NULL;
204 }
205 
206 void *
207 sshkey_xmss_params(const struct sshkey *key)
208 {
209 	struct ssh_xmss_state *state = key->xmss_state;
210 
211 	if (state == NULL)
212 		return NULL;
213 	return &state->params;
214 }
215 
216 void *
217 sshkey_xmss_bds_state(const struct sshkey *key)
218 {
219 	struct ssh_xmss_state *state = key->xmss_state;
220 
221 	if (state == NULL)
222 		return NULL;
223 	return &state->bds;
224 }
225 
226 int
227 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
228 {
229 	struct ssh_xmss_state *state = key->xmss_state;
230 
231 	if (lenp == NULL)
232 		return SSH_ERR_INVALID_ARGUMENT;
233 	if (state == NULL)
234 		return SSH_ERR_INVALID_FORMAT;
235 	*lenp = 4 + state->n +
236 	    state->params.wots_par.keysize +
237 	    state->h * state->n;
238 	return 0;
239 }
240 
241 size_t
242 sshkey_xmss_pklen(const struct sshkey *key)
243 {
244 	struct ssh_xmss_state *state = key->xmss_state;
245 
246 	if (state == NULL)
247 		return 0;
248 	return state->n * 2;
249 }
250 
251 size_t
252 sshkey_xmss_sklen(const struct sshkey *key)
253 {
254 	struct ssh_xmss_state *state = key->xmss_state;
255 
256 	if (state == NULL)
257 		return 0;
258 	return state->n * 4 + 4;
259 }
260 
261 int
262 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
263 {
264 	struct ssh_xmss_state *state = k->xmss_state;
265 	const struct sshcipher *cipher;
266 	size_t keylen = 0, ivlen = 0;
267 
268 	if (state == NULL)
269 		return SSH_ERR_INVALID_ARGUMENT;
270 	if ((cipher = cipher_by_name(ciphername)) == NULL)
271 		return SSH_ERR_INTERNAL_ERROR;
272 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
273 		return SSH_ERR_ALLOC_FAIL;
274 	keylen = cipher_keylen(cipher);
275 	ivlen = cipher_ivlen(cipher);
276 	state->enc_keyiv_len = keylen + ivlen;
277 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
278 		free(state->enc_ciphername);
279 		state->enc_ciphername = NULL;
280 		return SSH_ERR_ALLOC_FAIL;
281 	}
282 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
283 	return 0;
284 }
285 
286 int
287 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
288 {
289 	struct ssh_xmss_state *state = k->xmss_state;
290 	int r;
291 
292 	if (state == NULL || state->enc_keyiv == NULL ||
293 	    state->enc_ciphername == NULL)
294 		return SSH_ERR_INVALID_ARGUMENT;
295 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
296 	    (r = sshbuf_put_string(b, state->enc_keyiv,
297 	    state->enc_keyiv_len)) != 0)
298 		return r;
299 	return 0;
300 }
301 
302 int
303 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
304 {
305 	struct ssh_xmss_state *state = k->xmss_state;
306 	size_t len;
307 	int r;
308 
309 	if (state == NULL)
310 		return SSH_ERR_INVALID_ARGUMENT;
311 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
312 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
313 		return r;
314 	state->enc_keyiv_len = len;
315 	return 0;
316 }
317 
318 int
319 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
320     enum sshkey_serialize_rep opts)
321 {
322 	struct ssh_xmss_state *state = k->xmss_state;
323 	u_char have_info = 1;
324 	u_int32_t idx;
325 	int r;
326 
327 	if (state == NULL)
328 		return SSH_ERR_INVALID_ARGUMENT;
329 	if (opts != SSHKEY_SERIALIZE_INFO)
330 		return 0;
331 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
332 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
333 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
334 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
335 		return r;
336 	return 0;
337 }
338 
339 int
340 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
341 {
342 	struct ssh_xmss_state *state = k->xmss_state;
343 	u_char have_info;
344 	int r;
345 
346 	if (state == NULL)
347 		return SSH_ERR_INVALID_ARGUMENT;
348 	/* optional */
349 	if (sshbuf_len(b) == 0)
350 		return 0;
351 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
352 		return r;
353 	if (have_info != 1)
354 		return SSH_ERR_INVALID_ARGUMENT;
355 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
356 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
357 		return r;
358 	return 0;
359 }
360 
361 int
362 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
363 {
364 	int r;
365 	const char *name;
366 
367 	if (bits == 10) {
368 		name = XMSS_SHA2_256_W16_H10_NAME;
369 	} else if (bits == 16) {
370 		name = XMSS_SHA2_256_W16_H16_NAME;
371 	} else if (bits == 20) {
372 		name = XMSS_SHA2_256_W16_H20_NAME;
373 	} else {
374 		name = XMSS_DEFAULT_NAME;
375 	}
376 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
377 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
378 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
379 		return r;
380 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
381 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
382 		return SSH_ERR_ALLOC_FAIL;
383 	}
384 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
385 	    sshkey_xmss_params(k));
386 	return 0;
387 }
388 
389 int
390 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
391     int *have_file, int printerror)
392 {
393 	struct sshbuf *b = NULL, *enc = NULL;
394 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
395 	u_int32_t len;
396 	unsigned char buf[4], *data = NULL;
397 
398 	*have_file = 0;
399 	if ((fd = open(filename, O_RDONLY)) >= 0) {
400 		*have_file = 1;
401 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
402 			PRINT("%s: corrupt state file: %s", __func__, filename);
403 			goto done;
404 		}
405 		len = PEEK_U32(buf);
406 		if ((data = calloc(len, 1)) == NULL) {
407 			ret = SSH_ERR_ALLOC_FAIL;
408 			goto done;
409 		}
410 		if (atomicio(read, fd, data, len) != len) {
411 			PRINT("%s: cannot read blob: %s", __func__, filename);
412 			goto done;
413 		}
414 		if ((enc = sshbuf_from(data, len)) == NULL) {
415 			ret = SSH_ERR_ALLOC_FAIL;
416 			goto done;
417 		}
418 		sshkey_xmss_free_bds(k);
419 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
420 			ret = r;
421 			goto done;
422 		}
423 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
424 			ret = r;
425 			goto done;
426 		}
427 		ret = 0;
428 	}
429 done:
430 	if (fd != -1)
431 		close(fd);
432 	free(data);
433 	sshbuf_free(enc);
434 	sshbuf_free(b);
435 	return ret;
436 }
437 
/*
 * Acquire the signing state of "k" before a signature is made.
 *
 * If state->maxidx is set (a key limited to a fixed number of
 * signatures), the disk is not consulted.  Signing is allowed while the
 * index stored in the secret key is below maxidx.
 *
 * Otherwise the function proceeds as follows:
 *  - "<xmss_filename>.lock" is created if needed and locked with
 *    flock(LOCK_EX|LOCK_NB), retrying up to 10 times with a growing
 *    sleep.
 *  - The state is loaded from "<xmss_filename>.state".  If that fails,
 *    "<xmss_filename>.ostate" is tried, and a state recovered from it is
 *    advanced by one signature.
 *  - If neither file exists, the in-memory BDS state is used, provided it
 *    has been initialized.
 *
 * On success the lock descriptor is kept in state->lockfd and
 * allow_update is set; sshkey_xmss_update_state() later closes the
 * descriptor.  "printerror" enables the PRINT() diagnostics.  Returns 0
 * or an SSH_ERR_* code.
 */
int
sshkey_xmss_get_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx = 0;
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
	/*
	 * have_ostate is written only by the .ostate attempt.  It is read only
	 * when have_state is 0, which implies that attempt has run.
	 */
	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
	int ret = SSH_ERR_INVALID_ARGUMENT, r;

	if (state == NULL)
		goto done;
	/*
	 * If maxidx is set, then we are allowed a limited number
	 * of signatures, but don't need to access the disk.
	 * Otherwise we need to deal with the on-disk state.
	 */
	if (state->maxidx) {
		/* xmss_sk always contains the current state */
		idx = PEEK_U32(k->xmss_sk);
		if (idx < state->maxidx) {
			state->allow_update = 1;
			return 0;
		}
		return SSH_ERR_INVALID_ARGUMENT;
	}
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
	    asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: cannot open/create: %s", __func__, lockfile);
		goto done;
	}
	/* non-blocking lock; back off 100ms * tries between attempts */
	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
		if (errno != EWOULDBLOCK) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: cannot lock: %s", __func__, lockfile);
			goto done;
		}
		if (++tries > 10) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: giving up on: %s", __func__, lockfile);
			goto done;
		}
		usleep(1000*100*tries);
	}
	/* XXX no longer const */
	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
	    statefile, &have_state, printerror)) != 0) {
		/*
		 * .state is missing or unusable: fall back to the backup and
		 * skip one index, presumably because that index may already
		 * have been used after the backup was taken.
		 */
		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
		    ostatefile, &have_ostate, printerror)) == 0) {
			state->allow_update = 1;
			r = sshkey_xmss_forward_state(k, 1);
			state->idx = PEEK_U32(k->xmss_sk);
			state->allow_update = 0;
		}
	}
	if (!have_state && !have_ostate) {
		/* check that bds state is initialized */
		if (state->bds.auth == NULL)
			goto done;
		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
	} else if (r != 0) {
		ret = r;
		goto done;
	}
	/* refuse an index that cannot be advanced any further */
	if (state->idx + 1 < state->idx) {
		PRINT("%s: state wrap: %u", __func__, state->idx);
		goto done;
	}
	/* success: hand the held lock over to the state */
	state->have_state = have_state;
	state->lockfd = lockfd;
	state->allow_update = 1;
	lockfd = -1;
	ret = 0;
done:
	if (lockfd != -1)
		close(lockfd);
	free(lockfile);
	free(statefile);
	free(ostatefile);
	return ret;
}
527 
528 int
529 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
530 {
531 	struct ssh_xmss_state *state = k->xmss_state;
532 	u_char *sig = NULL;
533 	size_t required_siglen;
534 	unsigned long long smlen;
535 	u_char data;
536 	int ret, r;
537 
538 	if (state == NULL || !state->allow_update)
539 		return SSH_ERR_INVALID_ARGUMENT;
540 	if (reserve == 0)
541 		return SSH_ERR_INVALID_ARGUMENT;
542 	if (state->idx + reserve <= state->idx)
543 		return SSH_ERR_INVALID_ARGUMENT;
544 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
545 		return r;
546 	if ((sig = malloc(required_siglen)) == NULL)
547 		return SSH_ERR_ALLOC_FAIL;
548 	while (reserve-- > 0) {
549 		state->idx = PEEK_U32(k->xmss_sk);
550 		smlen = required_siglen;
551 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
552 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
553 			r = SSH_ERR_INVALID_ARGUMENT;
554 			break;
555 		}
556 	}
557 	free(sig);
558 	return r;
559 }
560 
/*
 * Persist the key's state after at most one new signature.
 *
 * Requires allow_update (set by sshkey_xmss_get_state()).  Keys limited
 * by maxidx are never written.  Otherwise the index in the secret key
 * must be either unchanged (nothing to write) or exactly one past
 * state->idx.  The write sequence is:
 *  - The state is serialized, encrypted and written to
 *    "<xmss_filename>.nstate" (created exclusively, mode 0600), then
 *    fsync'd.
 *  - If state->have_state is set, "<xmss_filename>.state" is hard-linked
 *    to ".ostate" as a backup.
 *  - ".nstate" is renamed over ".state".
 *
 * Every path past the allow_update check closes the lock descriptor
 * taken in sshkey_xmss_get_state() and unlinks any leftover ".nstate".
 * Returns 0 or an SSH_ERR_* code.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("%s: more than one signature happened: idx %u state %u",
		     __func__, idx, state->idx);
		goto done;
	}
	/* the in-memory index advances even if writing the file fails below */
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove a stale temporary so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: open new state file: %s", __func__, nstatefile);
		goto done;
	}
	/* 4-byte length header, read back with PEEK_U32 when loading */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file data: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	/* make the new state durable before it replaces the old one */
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: sync new state file: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: close new state file: %s", __func__, nstatefile);
		goto done;
	}
	if (state->have_state) {
		/* keep the previous state as .ostate for recovery */
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: backup state %s to %s", __func__, statefile,
			    ostatefile);
			goto done;
		}
	}
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock taken by sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
671 
672 int
673 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
674 {
675 	struct ssh_xmss_state *state = k->xmss_state;
676 	treehash_inst *th;
677 	u_int32_t i, node;
678 	int r;
679 
680 	if (state == NULL)
681 		return SSH_ERR_INVALID_ARGUMENT;
682 	if (state->stack == NULL)
683 		return SSH_ERR_INVALID_ARGUMENT;
684 	state->stackoffset = state->bds.stackoffset;	/* copy back */
685 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
686 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
687 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
688 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
689 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
690 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
691 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
692 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
693 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
694 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
695 		return r;
696 	for (i = 0; i < num_treehash(state); i++) {
697 		th = &state->treehash[i];
698 		node = th->node - state->th_nodes;
699 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
700 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
701 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
702 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
703 		    (r = sshbuf_put_u32(b, node)) != 0)
704 			return r;
705 	}
706 	return 0;
707 }
708 
709 int
710 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
711     enum sshkey_serialize_rep opts)
712 {
713 	struct ssh_xmss_state *state = k->xmss_state;
714 	int r = SSH_ERR_INVALID_ARGUMENT;
715 	u_char have_stack, have_filename, have_enc;
716 
717 	if (state == NULL)
718 		return SSH_ERR_INVALID_ARGUMENT;
719 	if ((r = sshbuf_put_u8(b, opts)) != 0)
720 		return r;
721 	switch (opts) {
722 	case SSHKEY_SERIALIZE_STATE:
723 		r = sshkey_xmss_serialize_state(k, b);
724 		break;
725 	case SSHKEY_SERIALIZE_FULL:
726 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
727 			return r;
728 		r = sshkey_xmss_serialize_state(k, b);
729 		break;
730 	case SSHKEY_SERIALIZE_SHIELD:
731 		/* all of stack/filename/enc are optional */
732 		have_stack = state->stack != NULL;
733 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
734 			return r;
735 		if (have_stack) {
736 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
737 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
738 				return r;
739 		}
740 		have_filename = k->xmss_filename != NULL;
741 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
742 			return r;
743 		if (have_filename &&
744 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
745 			return r;
746 		have_enc = state->enc_keyiv != NULL;
747 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
748 			return r;
749 		if (have_enc &&
750 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
751 			return r;
752 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
753 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
754 			return r;
755 		break;
756 	case SSHKEY_SERIALIZE_DEFAULT:
757 		r = 0;
758 		break;
759 	default:
760 		r = SSH_ERR_INVALID_ARGUMENT;
761 		break;
762 	}
763 	return r;
764 }
765 
766 int
767 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
768 {
769 	struct ssh_xmss_state *state = k->xmss_state;
770 	treehash_inst *th;
771 	u_int32_t i, lh, node;
772 	size_t ls, lsl, la, lk, ln, lr;
773 	char *magic;
774 	int r = SSH_ERR_INTERNAL_ERROR;
775 
776 	if (state == NULL)
777 		return SSH_ERR_INVALID_ARGUMENT;
778 	if (k->xmss_sk == NULL)
779 		return SSH_ERR_INVALID_ARGUMENT;
780 	if ((state->treehash = calloc(num_treehash(state),
781 	    sizeof(treehash_inst))) == NULL)
782 		return SSH_ERR_ALLOC_FAIL;
783 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
784 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
785 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
786 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
787 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
788 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
789 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
790 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
791 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
792 	    (r = sshbuf_get_u32(b, &lh)) != 0)
793 		goto out;
794 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
795 		r = SSH_ERR_INVALID_ARGUMENT;
796 		goto out;
797 	}
798 	/* XXX check stackoffset */
799 	if (ls != num_stack(state) ||
800 	    lsl != num_stacklevels(state) ||
801 	    la != num_auth(state) ||
802 	    lk != num_keep(state) ||
803 	    ln != num_th_nodes(state) ||
804 	    lr != num_retain(state) ||
805 	    lh != num_treehash(state)) {
806 		r = SSH_ERR_INVALID_ARGUMENT;
807 		goto out;
808 	}
809 	for (i = 0; i < num_treehash(state); i++) {
810 		th = &state->treehash[i];
811 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
812 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
813 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
814 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
815 		    (r = sshbuf_get_u32(b, &node)) != 0)
816 			goto out;
817 		if (node < num_th_nodes(state))
818 			th->node = &state->th_nodes[node];
819 	}
820 	POKE_U32(k->xmss_sk, state->idx);
821 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
822 	    state->stacklevels, state->auth, state->keep, state->treehash,
823 	    state->retain, 0);
824 	/* success */
825 	r = 0;
826  out:
827 	free(magic);
828 	return r;
829 }
830 
831 int
832 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
833 {
834 	struct ssh_xmss_state *state = k->xmss_state;
835 	enum sshkey_serialize_rep opts;
836 	u_char have_state, have_stack, have_filename, have_enc;
837 	int r;
838 
839 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
840 		return r;
841 
842 	opts = have_state;
843 	switch (opts) {
844 	case SSHKEY_SERIALIZE_DEFAULT:
845 		r = 0;
846 		break;
847 	case SSHKEY_SERIALIZE_SHIELD:
848 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
849 			return r;
850 		if (have_stack &&
851 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
852 			return r;
853 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
854 			return r;
855 		if (have_filename &&
856 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
857 			return r;
858 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
859 			return r;
860 		if (have_enc &&
861 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
862 			return r;
863 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
864 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
865 			return r;
866 		break;
867 	case SSHKEY_SERIALIZE_STATE:
868 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
869 			return r;
870 		break;
871 	case SSHKEY_SERIALIZE_FULL:
872 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
873 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
874 			return r;
875 		break;
876 	default:
877 		r = SSH_ERR_INVALID_FORMAT;
878 		break;
879 	}
880 	return r;
881 }
882 
883 int
884 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
885    struct sshbuf **retp)
886 {
887 	struct ssh_xmss_state *state = k->xmss_state;
888 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
889 	struct sshcipher_ctx *ciphercontext = NULL;
890 	const struct sshcipher *cipher;
891 	u_char *cp, *key, *iv = NULL;
892 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
893 	int r = SSH_ERR_INTERNAL_ERROR;
894 
895 	if (retp != NULL)
896 		*retp = NULL;
897 	if (state == NULL ||
898 	    state->enc_keyiv == NULL ||
899 	    state->enc_ciphername == NULL)
900 		return SSH_ERR_INTERNAL_ERROR;
901 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
902 		r = SSH_ERR_INTERNAL_ERROR;
903 		goto out;
904 	}
905 	blocksize = cipher_blocksize(cipher);
906 	keylen = cipher_keylen(cipher);
907 	ivlen = cipher_ivlen(cipher);
908 	authlen = cipher_authlen(cipher);
909 	if (state->enc_keyiv_len != keylen + ivlen) {
910 		r = SSH_ERR_INVALID_FORMAT;
911 		goto out;
912 	}
913 	key = state->enc_keyiv;
914 	if ((encrypted = sshbuf_new()) == NULL ||
915 	    (encoded = sshbuf_new()) == NULL ||
916 	    (padded = sshbuf_new()) == NULL ||
917 	    (iv = malloc(ivlen)) == NULL) {
918 		r = SSH_ERR_ALLOC_FAIL;
919 		goto out;
920 	}
921 
922 	/* replace first 4 bytes of IV with index to ensure uniqueness */
923 	memcpy(iv, key + keylen, ivlen);
924 	POKE_U32(iv, state->idx);
925 
926 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
927 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
928 		goto out;
929 
930 	/* padded state will be encrypted */
931 	if ((r = sshbuf_putb(padded, b)) != 0)
932 		goto out;
933 	i = 0;
934 	while (sshbuf_len(padded) % blocksize) {
935 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
936 			goto out;
937 	}
938 	encrypted_len = sshbuf_len(padded);
939 
940 	/* header including the length of state is used as AAD */
941 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
942 		goto out;
943 	aadlen = sshbuf_len(encoded);
944 
945 	/* concat header and state */
946 	if ((r = sshbuf_putb(encoded, padded)) != 0)
947 		goto out;
948 
949 	/* reserve space for encryption of encoded data plus auth tag */
950 	/* encrypt at offset addlen */
951 	if ((r = sshbuf_reserve(encrypted,
952 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
953 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
954 	    iv, ivlen, 1)) != 0 ||
955 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
956 	    encrypted_len, aadlen, authlen)) != 0)
957 		goto out;
958 
959 	/* success */
960 	r = 0;
961  out:
962 	if (retp != NULL) {
963 		*retp = encrypted;
964 		encrypted = NULL;
965 	}
966 	sshbuf_free(padded);
967 	sshbuf_free(encoded);
968 	sshbuf_free(encrypted);
969 	cipher_free(ciphercontext);
970 	free(iv);
971 	return r;
972 }
973 
/*
 * Reverse sshkey_xmss_encrypt_state().
 *
 * Checks XMSS_MAGIC, reads the index and the encrypted length from the
 * header, and rebuilds the IV from the stored IV with its first four
 * bytes replaced by the index.  It then decrypts and authenticates
 * (header as AAD) and rejects trailing data.
 *
 * "encoded" is consumed: its read position is advanced past everything
 * parsed.  On success returns 0 and, if retp is non-NULL, stores the
 * decrypted state in *retp (the caller frees it).  That state still
 * carries the block padding added at encryption time, which is not
 * checked here.  On failure *retp stays NULL.
 */
int
sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
   struct sshbuf **retp)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *copy = NULL, *decrypted = NULL;
	struct sshcipher_ctx *ciphercontext = NULL;
	const struct sshcipher *cipher = NULL;
	u_char *key, *iv = NULL, *dp;
	size_t keylen, ivlen, authlen, aadlen;
	u_int blocksize, encrypted_len, index;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (retp != NULL)
		*retp = NULL;
	if (state == NULL ||
	    state->enc_keyiv == NULL ||
	    state->enc_ciphername == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	blocksize = cipher_blocksize(cipher);
	keylen = cipher_keylen(cipher);
	ivlen = cipher_ivlen(cipher);
	authlen = cipher_authlen(cipher);
	if (state->enc_keyiv_len != keylen + ivlen) {
		r = SSH_ERR_INTERNAL_ERROR;
		goto out;
	}
	key = state->enc_keyiv;

	/*
	 * "copy" keeps a view of the whole input, including the header that
	 * is consumed from "encoded" below; it is the cipher input.
	 */
	if ((copy = sshbuf_fromb(encoded)) == NULL ||
	    (decrypted = sshbuf_new()) == NULL ||
	    (iv = malloc(ivlen)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* check magic */
	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* parse public portion */
	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
		goto out;

	/* check size of encrypted key blob */
	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* check that an appropriate amount of auth data is present */
	if (sshbuf_len(encoded) < authlen ||
	    sshbuf_len(encoded) - authlen < encrypted_len) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* the header bytes consumed above form the AAD */
	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);

	/* replace first 4 bytes of IV with index to ensure uniqueness */
	memcpy(iv, key + keylen, ivlen);
	POKE_U32(iv, index);

	/* decrypt private state of key */
	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
	    iv, ivlen, 0)) != 0 ||
	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
	    encrypted_len, aadlen, authlen)) != 0)
		goto out;

	/* there should be no trailing data */
	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
		goto out;
	if (sshbuf_len(encoded) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* remove AAD */
	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
		goto out;
	/* XXX encrypted includes unchecked padding */

	/* success */
	r = 0;
	if (retp != NULL) {
		*retp = decrypted;
		decrypted = NULL;
	}
 out:
	cipher_free(ciphercontext);
	sshbuf_free(copy);
	sshbuf_free(decrypted);
	free(iv);
	return r;
}
1078 
1079 u_int32_t
1080 sshkey_xmss_signatures_left(const struct sshkey *k)
1081 {
1082 	struct ssh_xmss_state *state = k->xmss_state;
1083 	u_int32_t idx;
1084 
1085 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1086 	    state->maxidx) {
1087 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1088 		if (idx < state->maxidx)
1089 			return state->maxidx - idx;
1090 	}
1091 	return 0;
1092 }
1093 
1094 int
1095 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1096 {
1097 	struct ssh_xmss_state *state = k->xmss_state;
1098 
1099 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1100 		return SSH_ERR_INVALID_ARGUMENT;
1101 	if (maxsign == 0)
1102 		return 0;
1103 	if (state->idx + maxsign < state->idx)
1104 		return SSH_ERR_INVALID_ARGUMENT;
1105 	state->maxidx = state->idx + maxsign;
1106 	return 0;
1107 }
1108