xref: /openbsd-src/usr.bin/ssh/sshkey-xmss.c (revision b6025feb5ae4d7a3bac25b25ea94cdcc87366a3d)
1*b6025febSdjm /* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2a6be8e7cSmarkus /*
3a6be8e7cSmarkus  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4a6be8e7cSmarkus  *
5a6be8e7cSmarkus  * Redistribution and use in source and binary forms, with or without
6a6be8e7cSmarkus  * modification, are permitted provided that the following conditions
7a6be8e7cSmarkus  * are met:
8a6be8e7cSmarkus  * 1. Redistributions of source code must retain the above copyright
9a6be8e7cSmarkus  *    notice, this list of conditions and the following disclaimer.
10a6be8e7cSmarkus  * 2. Redistributions in binary form must reproduce the above copyright
11a6be8e7cSmarkus  *    notice, this list of conditions and the following disclaimer in the
12a6be8e7cSmarkus  *    documentation and/or other materials provided with the distribution.
13a6be8e7cSmarkus  *
14a6be8e7cSmarkus  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15a6be8e7cSmarkus  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16a6be8e7cSmarkus  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17a6be8e7cSmarkus  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18a6be8e7cSmarkus  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19a6be8e7cSmarkus  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20a6be8e7cSmarkus  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21a6be8e7cSmarkus  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22a6be8e7cSmarkus  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23a6be8e7cSmarkus  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24a6be8e7cSmarkus  */
25a6be8e7cSmarkus 
26a6be8e7cSmarkus #include <sys/types.h>
27a6be8e7cSmarkus #include <sys/uio.h>
28a6be8e7cSmarkus 
29a6be8e7cSmarkus #include <stdio.h>
30a6be8e7cSmarkus #include <string.h>
31a6be8e7cSmarkus #include <unistd.h>
32a6be8e7cSmarkus #include <fcntl.h>
33a6be8e7cSmarkus #include <errno.h>
34a6be8e7cSmarkus 
35a6be8e7cSmarkus #include "ssh2.h"
36a6be8e7cSmarkus #include "ssherr.h"
37a6be8e7cSmarkus #include "sshbuf.h"
38a6be8e7cSmarkus #include "cipher.h"
39a6be8e7cSmarkus #include "sshkey.h"
40a6be8e7cSmarkus #include "sshkey-xmss.h"
41a6be8e7cSmarkus #include "atomicio.h"
4279e62715Sdtucker #include "log.h"
43a6be8e7cSmarkus 
44a6be8e7cSmarkus #include "xmss_fast.h"
45a6be8e7cSmarkus 
/* opaque internal XMSS state */
/* magic string for the state format; not referenced in this part of the file */
#define XMSS_MAGIC		"xmss-state-v1"
/* cipher sshkey_xmss_generate_private_key() uses to encrypt the state */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	xmss_params	params;		/* filled by xmss_set_params() */
	u_int32_t	n, w, h, k;	/* parameter set chosen by name; k is fixed at 2 */

	bds_state	bds;		/* handed to the XMSS code via xmss_set_bds_state() */
	/* BDS traversal buffers, sized by the num_*() macros below */
	u_char		*stack;
	u_int32_t	stackoffset;
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* backing storage for treehash[i].node */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
72a6be8e7cSmarkus 
/*
 * Forward declarations for the non-static helpers of this file, so they
 * can be used before (or independently of) their definitions.
 */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
84a6be8e7cSmarkus 
/*
 * Log an error-level message via sshlog(), but only when the enclosing
 * function's local `printerror' is non-zero.  Any function using PRINT()
 * must therefore have a variable named printerror in scope.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
87a6be8e7cSmarkus 
88a6be8e7cSmarkus int
sshkey_xmss_init(struct sshkey * key,const char * name)89a6be8e7cSmarkus sshkey_xmss_init(struct sshkey *key, const char *name)
90a6be8e7cSmarkus {
91a6be8e7cSmarkus 	struct ssh_xmss_state *state;
92a6be8e7cSmarkus 
93a6be8e7cSmarkus 	if (key->xmss_state != NULL)
94a6be8e7cSmarkus 		return SSH_ERR_INVALID_FORMAT;
95a6be8e7cSmarkus 	if (name == NULL)
96a6be8e7cSmarkus 		return SSH_ERR_INVALID_FORMAT;
97a6be8e7cSmarkus 	state = calloc(sizeof(struct ssh_xmss_state), 1);
98a6be8e7cSmarkus 	if (state == NULL)
99a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
100a6be8e7cSmarkus 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
101a6be8e7cSmarkus 		state->n = 32;
102a6be8e7cSmarkus 		state->w = 16;
103a6be8e7cSmarkus 		state->h = 10;
104a6be8e7cSmarkus 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
105a6be8e7cSmarkus 		state->n = 32;
106a6be8e7cSmarkus 		state->w = 16;
107a6be8e7cSmarkus 		state->h = 16;
108a6be8e7cSmarkus 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
109a6be8e7cSmarkus 		state->n = 32;
110a6be8e7cSmarkus 		state->w = 16;
111a6be8e7cSmarkus 		state->h = 20;
112a6be8e7cSmarkus 	} else {
113a6be8e7cSmarkus 		free(state);
114a6be8e7cSmarkus 		return SSH_ERR_KEY_TYPE_UNKNOWN;
115a6be8e7cSmarkus 	}
116a6be8e7cSmarkus 	if ((key->xmss_name = strdup(name)) == NULL) {
117a6be8e7cSmarkus 		free(state);
118a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
119a6be8e7cSmarkus 	}
120a6be8e7cSmarkus 	state->k = 2;	/* XXX hardcoded */
121a6be8e7cSmarkus 	state->lockfd = -1;
122a6be8e7cSmarkus 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
123a6be8e7cSmarkus 	    state->k) != 0) {
124a6be8e7cSmarkus 		free(state);
125a6be8e7cSmarkus 		return SSH_ERR_INVALID_FORMAT;
126a6be8e7cSmarkus 	}
127a6be8e7cSmarkus 	key->xmss_state = state;
128a6be8e7cSmarkus 	return 0;
129a6be8e7cSmarkus }
130a6be8e7cSmarkus 
131a6be8e7cSmarkus void
sshkey_xmss_free_state(struct sshkey * key)132a6be8e7cSmarkus sshkey_xmss_free_state(struct sshkey *key)
133a6be8e7cSmarkus {
134a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
135a6be8e7cSmarkus 
136a6be8e7cSmarkus 	sshkey_xmss_free_bds(key);
137a6be8e7cSmarkus 	if (state) {
138a6be8e7cSmarkus 		if (state->enc_keyiv) {
139a6be8e7cSmarkus 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
140a6be8e7cSmarkus 			free(state->enc_keyiv);
141a6be8e7cSmarkus 		}
142a6be8e7cSmarkus 		free(state->enc_ciphername);
143a6be8e7cSmarkus 		free(state);
144a6be8e7cSmarkus 	}
145a6be8e7cSmarkus 	key->xmss_state = NULL;
146a6be8e7cSmarkus }
147a6be8e7cSmarkus 
/* leading string written by sshkey_xmss_serialize_state() */
#define SSH_XMSS_K2_MAGIC	"k=2"

/*
 * Byte/element counts of the BDS traversal buffers for a state pointer `x'
 * (anything with u_int32_t members h, n and k).  Every use of the argument
 * is parenthesised so that expressions such as `&st' expand correctly.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
156a6be8e7cSmarkus 
157a6be8e7cSmarkus int
sshkey_xmss_init_bds_state(struct sshkey * key)158a6be8e7cSmarkus sshkey_xmss_init_bds_state(struct sshkey *key)
159a6be8e7cSmarkus {
160a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
161a6be8e7cSmarkus 	u_int32_t i;
162a6be8e7cSmarkus 
163a6be8e7cSmarkus 	state->stackoffset = 0;
164a6be8e7cSmarkus 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
165a6be8e7cSmarkus 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
166a6be8e7cSmarkus 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
167a6be8e7cSmarkus 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
168a6be8e7cSmarkus 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
169a6be8e7cSmarkus 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
170a6be8e7cSmarkus 	    (state->treehash = calloc(num_treehash(state),
171a6be8e7cSmarkus 	    sizeof(treehash_inst))) == NULL) {
172a6be8e7cSmarkus 		sshkey_xmss_free_bds(key);
173a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
174a6be8e7cSmarkus 	}
175a6be8e7cSmarkus 	for (i = 0; i < state->h - state->k; i++)
176a6be8e7cSmarkus 		state->treehash[i].node = &state->th_nodes[state->n*i];
177a6be8e7cSmarkus 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
178a6be8e7cSmarkus 	    state->stacklevels, state->auth, state->keep, state->treehash,
179a6be8e7cSmarkus 	    state->retain, 0);
180a6be8e7cSmarkus 	return 0;
181a6be8e7cSmarkus }
182a6be8e7cSmarkus 
183a6be8e7cSmarkus void
sshkey_xmss_free_bds(struct sshkey * key)184a6be8e7cSmarkus sshkey_xmss_free_bds(struct sshkey *key)
185a6be8e7cSmarkus {
186a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
187a6be8e7cSmarkus 
188a6be8e7cSmarkus 	if (state == NULL)
189a6be8e7cSmarkus 		return;
190a6be8e7cSmarkus 	free(state->stack);
191a6be8e7cSmarkus 	free(state->stacklevels);
192a6be8e7cSmarkus 	free(state->auth);
193a6be8e7cSmarkus 	free(state->keep);
194a6be8e7cSmarkus 	free(state->th_nodes);
195a6be8e7cSmarkus 	free(state->retain);
196a6be8e7cSmarkus 	free(state->treehash);
197a6be8e7cSmarkus 	state->stack = NULL;
198a6be8e7cSmarkus 	state->stacklevels = NULL;
199a6be8e7cSmarkus 	state->auth = NULL;
200a6be8e7cSmarkus 	state->keep = NULL;
201a6be8e7cSmarkus 	state->th_nodes = NULL;
202a6be8e7cSmarkus 	state->retain = NULL;
203a6be8e7cSmarkus 	state->treehash = NULL;
204a6be8e7cSmarkus }
205a6be8e7cSmarkus 
206a6be8e7cSmarkus void *
sshkey_xmss_params(const struct sshkey * key)207a6be8e7cSmarkus sshkey_xmss_params(const struct sshkey *key)
208a6be8e7cSmarkus {
209a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
210a6be8e7cSmarkus 
211a6be8e7cSmarkus 	if (state == NULL)
212a6be8e7cSmarkus 		return NULL;
213a6be8e7cSmarkus 	return &state->params;
214a6be8e7cSmarkus }
215a6be8e7cSmarkus 
216a6be8e7cSmarkus void *
sshkey_xmss_bds_state(const struct sshkey * key)217a6be8e7cSmarkus sshkey_xmss_bds_state(const struct sshkey *key)
218a6be8e7cSmarkus {
219a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
220a6be8e7cSmarkus 
221a6be8e7cSmarkus 	if (state == NULL)
222a6be8e7cSmarkus 		return NULL;
223a6be8e7cSmarkus 	return &state->bds;
224a6be8e7cSmarkus }
225a6be8e7cSmarkus 
226a6be8e7cSmarkus int
sshkey_xmss_siglen(const struct sshkey * key,size_t * lenp)227a6be8e7cSmarkus sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
228a6be8e7cSmarkus {
229a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
230a6be8e7cSmarkus 
231a6be8e7cSmarkus 	if (lenp == NULL)
232a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
233a6be8e7cSmarkus 	if (state == NULL)
234a6be8e7cSmarkus 		return SSH_ERR_INVALID_FORMAT;
235a6be8e7cSmarkus 	*lenp = 4 + state->n +
236a6be8e7cSmarkus 	    state->params.wots_par.keysize +
237a6be8e7cSmarkus 	    state->h * state->n;
238a6be8e7cSmarkus 	return 0;
239a6be8e7cSmarkus }
240a6be8e7cSmarkus 
241a6be8e7cSmarkus size_t
sshkey_xmss_pklen(const struct sshkey * key)242a6be8e7cSmarkus sshkey_xmss_pklen(const struct sshkey *key)
243a6be8e7cSmarkus {
244a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
245a6be8e7cSmarkus 
246a6be8e7cSmarkus 	if (state == NULL)
247a6be8e7cSmarkus 		return 0;
248a6be8e7cSmarkus 	return state->n * 2;
249a6be8e7cSmarkus }
250a6be8e7cSmarkus 
251a6be8e7cSmarkus size_t
sshkey_xmss_sklen(const struct sshkey * key)252a6be8e7cSmarkus sshkey_xmss_sklen(const struct sshkey *key)
253a6be8e7cSmarkus {
254a6be8e7cSmarkus 	struct ssh_xmss_state *state = key->xmss_state;
255a6be8e7cSmarkus 
256a6be8e7cSmarkus 	if (state == NULL)
257a6be8e7cSmarkus 		return 0;
258a6be8e7cSmarkus 	return state->n * 4 + 4;
259a6be8e7cSmarkus }
260a6be8e7cSmarkus 
261a6be8e7cSmarkus int
sshkey_xmss_init_enc_key(struct sshkey * k,const char * ciphername)262a6be8e7cSmarkus sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
263a6be8e7cSmarkus {
264a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
265a6be8e7cSmarkus 	const struct sshcipher *cipher;
266a6be8e7cSmarkus 	size_t keylen = 0, ivlen = 0;
267a6be8e7cSmarkus 
268a6be8e7cSmarkus 	if (state == NULL)
269a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
270a6be8e7cSmarkus 	if ((cipher = cipher_by_name(ciphername)) == NULL)
271a6be8e7cSmarkus 		return SSH_ERR_INTERNAL_ERROR;
272a6be8e7cSmarkus 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
273a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
274a6be8e7cSmarkus 	keylen = cipher_keylen(cipher);
275a6be8e7cSmarkus 	ivlen = cipher_ivlen(cipher);
276a6be8e7cSmarkus 	state->enc_keyiv_len = keylen + ivlen;
277a6be8e7cSmarkus 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
278a6be8e7cSmarkus 		free(state->enc_ciphername);
279a6be8e7cSmarkus 		state->enc_ciphername = NULL;
280a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
281a6be8e7cSmarkus 	}
282a6be8e7cSmarkus 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
283a6be8e7cSmarkus 	return 0;
284a6be8e7cSmarkus }
285a6be8e7cSmarkus 
286a6be8e7cSmarkus int
sshkey_xmss_serialize_enc_key(const struct sshkey * k,struct sshbuf * b)287a6be8e7cSmarkus sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
288a6be8e7cSmarkus {
289a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
290a6be8e7cSmarkus 	int r;
291a6be8e7cSmarkus 
292a6be8e7cSmarkus 	if (state == NULL || state->enc_keyiv == NULL ||
293a6be8e7cSmarkus 	    state->enc_ciphername == NULL)
294a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
295a6be8e7cSmarkus 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
296a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->enc_keyiv,
297a6be8e7cSmarkus 	    state->enc_keyiv_len)) != 0)
298a6be8e7cSmarkus 		return r;
299a6be8e7cSmarkus 	return 0;
300a6be8e7cSmarkus }
301a6be8e7cSmarkus 
302a6be8e7cSmarkus int
sshkey_xmss_deserialize_enc_key(struct sshkey * k,struct sshbuf * b)303a6be8e7cSmarkus sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
304a6be8e7cSmarkus {
305a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
306a6be8e7cSmarkus 	size_t len;
307a6be8e7cSmarkus 	int r;
308a6be8e7cSmarkus 
309a6be8e7cSmarkus 	if (state == NULL)
310a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
311a6be8e7cSmarkus 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
312a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
313a6be8e7cSmarkus 		return r;
314a6be8e7cSmarkus 	state->enc_keyiv_len = len;
315a6be8e7cSmarkus 	return 0;
316a6be8e7cSmarkus }
317a6be8e7cSmarkus 
318a6be8e7cSmarkus int
sshkey_xmss_serialize_pk_info(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)319a6be8e7cSmarkus sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
320a6be8e7cSmarkus     enum sshkey_serialize_rep opts)
321a6be8e7cSmarkus {
322a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
323a6be8e7cSmarkus 	u_char have_info = 1;
324a6be8e7cSmarkus 	u_int32_t idx;
325a6be8e7cSmarkus 	int r;
326a6be8e7cSmarkus 
327a6be8e7cSmarkus 	if (state == NULL)
328a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
329a6be8e7cSmarkus 	if (opts != SSHKEY_SERIALIZE_INFO)
330a6be8e7cSmarkus 		return 0;
331a6be8e7cSmarkus 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
332a6be8e7cSmarkus 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
333a6be8e7cSmarkus 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
334a6be8e7cSmarkus 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
335a6be8e7cSmarkus 		return r;
336a6be8e7cSmarkus 	return 0;
337a6be8e7cSmarkus }
338a6be8e7cSmarkus 
339a6be8e7cSmarkus int
sshkey_xmss_deserialize_pk_info(struct sshkey * k,struct sshbuf * b)340a6be8e7cSmarkus sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
341a6be8e7cSmarkus {
342a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
343a6be8e7cSmarkus 	u_char have_info;
344a6be8e7cSmarkus 	int r;
345a6be8e7cSmarkus 
346a6be8e7cSmarkus 	if (state == NULL)
347a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
348a6be8e7cSmarkus 	/* optional */
349a6be8e7cSmarkus 	if (sshbuf_len(b) == 0)
350a6be8e7cSmarkus 		return 0;
351a6be8e7cSmarkus 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
352a6be8e7cSmarkus 		return r;
353a6be8e7cSmarkus 	if (have_info != 1)
354a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
355a6be8e7cSmarkus 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
356a6be8e7cSmarkus 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
357a6be8e7cSmarkus 		return r;
358a6be8e7cSmarkus 	return 0;
359a6be8e7cSmarkus }
360a6be8e7cSmarkus 
361a6be8e7cSmarkus int
sshkey_xmss_generate_private_key(struct sshkey * k,int bits)362*b6025febSdjm sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
363a6be8e7cSmarkus {
364a6be8e7cSmarkus 	int r;
365a6be8e7cSmarkus 	const char *name;
366a6be8e7cSmarkus 
367a6be8e7cSmarkus 	if (bits == 10) {
368a6be8e7cSmarkus 		name = XMSS_SHA2_256_W16_H10_NAME;
369a6be8e7cSmarkus 	} else if (bits == 16) {
370a6be8e7cSmarkus 		name = XMSS_SHA2_256_W16_H16_NAME;
371a6be8e7cSmarkus 	} else if (bits == 20) {
372a6be8e7cSmarkus 		name = XMSS_SHA2_256_W16_H20_NAME;
373a6be8e7cSmarkus 	} else {
374a6be8e7cSmarkus 		name = XMSS_DEFAULT_NAME;
375a6be8e7cSmarkus 	}
376a6be8e7cSmarkus 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
377a6be8e7cSmarkus 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
378a6be8e7cSmarkus 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
379a6be8e7cSmarkus 		return r;
380a6be8e7cSmarkus 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
381a6be8e7cSmarkus 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
382a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
383a6be8e7cSmarkus 	}
384a6be8e7cSmarkus 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
385a6be8e7cSmarkus 	    sshkey_xmss_params(k));
386a6be8e7cSmarkus 	return 0;
387a6be8e7cSmarkus }
388a6be8e7cSmarkus 
389a6be8e7cSmarkus int
sshkey_xmss_get_state_from_file(struct sshkey * k,const char * filename,int * have_file,int printerror)390a6be8e7cSmarkus sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
39179e62715Sdtucker     int *have_file, int printerror)
392a6be8e7cSmarkus {
393a6be8e7cSmarkus 	struct sshbuf *b = NULL, *enc = NULL;
394a6be8e7cSmarkus 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
395a6be8e7cSmarkus 	u_int32_t len;
396a6be8e7cSmarkus 	unsigned char buf[4], *data = NULL;
397a6be8e7cSmarkus 
398a6be8e7cSmarkus 	*have_file = 0;
399a6be8e7cSmarkus 	if ((fd = open(filename, O_RDONLY)) >= 0) {
400a6be8e7cSmarkus 		*have_file = 1;
401a6be8e7cSmarkus 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
402ecf7a8f6Smillert 			PRINT("corrupt state file: %s", filename);
403a6be8e7cSmarkus 			goto done;
404a6be8e7cSmarkus 		}
405a6be8e7cSmarkus 		len = PEEK_U32(buf);
406a6be8e7cSmarkus 		if ((data = calloc(len, 1)) == NULL) {
407a6be8e7cSmarkus 			ret = SSH_ERR_ALLOC_FAIL;
408a6be8e7cSmarkus 			goto done;
409a6be8e7cSmarkus 		}
410a6be8e7cSmarkus 		if (atomicio(read, fd, data, len) != len) {
411ecf7a8f6Smillert 			PRINT("cannot read blob: %s", filename);
412a6be8e7cSmarkus 			goto done;
413a6be8e7cSmarkus 		}
414a6be8e7cSmarkus 		if ((enc = sshbuf_from(data, len)) == NULL) {
415a6be8e7cSmarkus 			ret = SSH_ERR_ALLOC_FAIL;
416a6be8e7cSmarkus 			goto done;
417a6be8e7cSmarkus 		}
418a6be8e7cSmarkus 		sshkey_xmss_free_bds(k);
419a6be8e7cSmarkus 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
420a6be8e7cSmarkus 			ret = r;
421a6be8e7cSmarkus 			goto done;
422a6be8e7cSmarkus 		}
423a6be8e7cSmarkus 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
424a6be8e7cSmarkus 			ret = r;
425a6be8e7cSmarkus 			goto done;
426a6be8e7cSmarkus 		}
427a6be8e7cSmarkus 		ret = 0;
428a6be8e7cSmarkus 	}
429a6be8e7cSmarkus done:
430a6be8e7cSmarkus 	if (fd != -1)
431a6be8e7cSmarkus 		close(fd);
432a6be8e7cSmarkus 	free(data);
433a6be8e7cSmarkus 	sshbuf_free(enc);
434a6be8e7cSmarkus 	sshbuf_free(b);
435a6be8e7cSmarkus 	return ret;
436a6be8e7cSmarkus }
437a6be8e7cSmarkus 
438a6be8e7cSmarkus int
sshkey_xmss_get_state(const struct sshkey * k,int printerror)43979e62715Sdtucker sshkey_xmss_get_state(const struct sshkey *k, int printerror)
440a6be8e7cSmarkus {
441a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
442a6be8e7cSmarkus 	u_int32_t idx = 0;
443a6be8e7cSmarkus 	char *filename = NULL;
444a6be8e7cSmarkus 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
445a6be8e7cSmarkus 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
446a6be8e7cSmarkus 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
447a6be8e7cSmarkus 
448a6be8e7cSmarkus 	if (state == NULL)
449a6be8e7cSmarkus 		goto done;
450a6be8e7cSmarkus 	/*
451a6be8e7cSmarkus 	 * If maxidx is set, then we are allowed a limited number
452a6be8e7cSmarkus 	 * of signatures, but don't need to access the disk.
453a6be8e7cSmarkus 	 * Otherwise we need to deal with the on-disk state.
454a6be8e7cSmarkus 	 */
455a6be8e7cSmarkus 	if (state->maxidx) {
456a6be8e7cSmarkus 		/* xmss_sk always contains the current state */
457a6be8e7cSmarkus 		idx = PEEK_U32(k->xmss_sk);
458a6be8e7cSmarkus 		if (idx < state->maxidx) {
459a6be8e7cSmarkus 			state->allow_update = 1;
460a6be8e7cSmarkus 			return 0;
461a6be8e7cSmarkus 		}
462a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
463a6be8e7cSmarkus 	}
464a6be8e7cSmarkus 	if ((filename = k->xmss_filename) == NULL)
465a6be8e7cSmarkus 		goto done;
46695af8abfSderaadt 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
46795af8abfSderaadt 	    asprintf(&statefile, "%s.state", filename) == -1 ||
46895af8abfSderaadt 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
469a6be8e7cSmarkus 		ret = SSH_ERR_ALLOC_FAIL;
470a6be8e7cSmarkus 		goto done;
471a6be8e7cSmarkus 	}
4723aaa63ebSderaadt 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
473a6be8e7cSmarkus 		ret = SSH_ERR_SYSTEM_ERROR;
474ecf7a8f6Smillert 		PRINT("cannot open/create: %s", lockfile);
475a6be8e7cSmarkus 		goto done;
476a6be8e7cSmarkus 	}
4773aaa63ebSderaadt 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
478a6be8e7cSmarkus 		if (errno != EWOULDBLOCK) {
479a6be8e7cSmarkus 			ret = SSH_ERR_SYSTEM_ERROR;
480ecf7a8f6Smillert 			PRINT("cannot lock: %s", lockfile);
481a6be8e7cSmarkus 			goto done;
482a6be8e7cSmarkus 		}
483a6be8e7cSmarkus 		if (++tries > 10) {
484a6be8e7cSmarkus 			ret = SSH_ERR_SYSTEM_ERROR;
485ecf7a8f6Smillert 			PRINT("giving up on: %s", lockfile);
486a6be8e7cSmarkus 			goto done;
487a6be8e7cSmarkus 		}
488a6be8e7cSmarkus 		usleep(1000*100*tries);
489a6be8e7cSmarkus 	}
490a6be8e7cSmarkus 	/* XXX no longer const */
491a6be8e7cSmarkus 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
49279e62715Sdtucker 	    statefile, &have_state, printerror)) != 0) {
493a6be8e7cSmarkus 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
49479e62715Sdtucker 		    ostatefile, &have_ostate, printerror)) == 0) {
495a6be8e7cSmarkus 			state->allow_update = 1;
496a6be8e7cSmarkus 			r = sshkey_xmss_forward_state(k, 1);
497a6be8e7cSmarkus 			state->idx = PEEK_U32(k->xmss_sk);
498a6be8e7cSmarkus 			state->allow_update = 0;
499a6be8e7cSmarkus 		}
500a6be8e7cSmarkus 	}
501a6be8e7cSmarkus 	if (!have_state && !have_ostate) {
502a6be8e7cSmarkus 		/* check that bds state is initialized */
503a6be8e7cSmarkus 		if (state->bds.auth == NULL)
504a6be8e7cSmarkus 			goto done;
505ecf7a8f6Smillert 		PRINT("start from scratch idx 0: %u", state->idx);
506a6be8e7cSmarkus 	} else if (r != 0) {
507a6be8e7cSmarkus 		ret = r;
508a6be8e7cSmarkus 		goto done;
509a6be8e7cSmarkus 	}
510a6be8e7cSmarkus 	if (state->idx + 1 < state->idx) {
511ecf7a8f6Smillert 		PRINT("state wrap: %u", state->idx);
512a6be8e7cSmarkus 		goto done;
513a6be8e7cSmarkus 	}
514a6be8e7cSmarkus 	state->have_state = have_state;
515a6be8e7cSmarkus 	state->lockfd = lockfd;
516a6be8e7cSmarkus 	state->allow_update = 1;
517a6be8e7cSmarkus 	lockfd = -1;
518a6be8e7cSmarkus 	ret = 0;
519a6be8e7cSmarkus done:
520a6be8e7cSmarkus 	if (lockfd != -1)
521a6be8e7cSmarkus 		close(lockfd);
522a6be8e7cSmarkus 	free(lockfile);
523a6be8e7cSmarkus 	free(statefile);
524a6be8e7cSmarkus 	free(ostatefile);
525a6be8e7cSmarkus 	return ret;
526a6be8e7cSmarkus }
527a6be8e7cSmarkus 
528a6be8e7cSmarkus int
sshkey_xmss_forward_state(const struct sshkey * k,u_int32_t reserve)529a6be8e7cSmarkus sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
530a6be8e7cSmarkus {
531a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
532a6be8e7cSmarkus 	u_char *sig = NULL;
533a6be8e7cSmarkus 	size_t required_siglen;
534a6be8e7cSmarkus 	unsigned long long smlen;
535a6be8e7cSmarkus 	u_char data;
536a6be8e7cSmarkus 	int ret, r;
537a6be8e7cSmarkus 
538a6be8e7cSmarkus 	if (state == NULL || !state->allow_update)
539a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
540a6be8e7cSmarkus 	if (reserve == 0)
541a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
542a6be8e7cSmarkus 	if (state->idx + reserve <= state->idx)
543a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
544a6be8e7cSmarkus 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
545a6be8e7cSmarkus 		return r;
546a6be8e7cSmarkus 	if ((sig = malloc(required_siglen)) == NULL)
547a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
548a6be8e7cSmarkus 	while (reserve-- > 0) {
549a6be8e7cSmarkus 		state->idx = PEEK_U32(k->xmss_sk);
550a6be8e7cSmarkus 		smlen = required_siglen;
551a6be8e7cSmarkus 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
552a6be8e7cSmarkus 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
553a6be8e7cSmarkus 			r = SSH_ERR_INVALID_ARGUMENT;
554a6be8e7cSmarkus 			break;
555a6be8e7cSmarkus 		}
556a6be8e7cSmarkus 	}
557a6be8e7cSmarkus 	free(sig);
558a6be8e7cSmarkus 	return r;
559a6be8e7cSmarkus }
560a6be8e7cSmarkus 
/*
 * Persist the XMSS state after a signature made under
 * sshkey_xmss_get_state().
 *
 * Requires allow_update.  For maxidx-limited keys nothing is written.
 * Otherwise the index in xmss_sk must equal state->idx (no signature:
 * nothing to do) or state->idx + 1; any other value is rejected.  The
 * state is then written and replaced in these steps:
 *  - serialized and encrypted;
 *  - written to "<xmss_filename>.nstate" as a 4-byte length plus the
 *    blob, then fsync()ed and closed;
 *  - if a .state file was present at load time, it is hard-linked to
 *    ".ostate";
 *  - ".nstate" is rename()d over ".state".
 *
 * On every path that reaches `done', the lock fd taken by
 * sshkey_xmss_get_state() is closed and any leftover .nstate is unlinked.
 * NOTE(review): allow_update is not cleared here -- confirm callers
 * handle that.  Returns 0 or an SSH_ERR_* code; messages go through
 * PRINT() when `printerror' is set.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("more than one signature happened: idx %u state %u",
		    idx, state->idx);
		goto done;
	}
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove any stale temporary so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("SERLIALIZE FAILED: %d", ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("ENCRYPT FAILED: %d", ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("open new state file: %s", nstatefile);
		goto done;
	}
	/* length header, read back via PEEK_U32 by the state loader */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file hdr: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file data: %s", nstatefile);
		close(fd);
		goto done;
	}
	/* make the new state durable before it replaces the old one */
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("sync new state file: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("close new state file: %s", nstatefile);
		goto done;
	}
	/* keep the previous state as .ostate (fallback for the loader) */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("backup state %s to %s", statefile, ostatefile);
			goto done;
		}
	}
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("rename %s to %s", nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock acquired in sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
670a6be8e7cSmarkus 
671a6be8e7cSmarkus int
sshkey_xmss_serialize_state(const struct sshkey * k,struct sshbuf * b)672a6be8e7cSmarkus sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
673a6be8e7cSmarkus {
674a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
675a6be8e7cSmarkus 	treehash_inst *th;
676a6be8e7cSmarkus 	u_int32_t i, node;
677a6be8e7cSmarkus 	int r;
678a6be8e7cSmarkus 
679a6be8e7cSmarkus 	if (state == NULL)
680a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
681a6be8e7cSmarkus 	if (state->stack == NULL)
682a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
683a6be8e7cSmarkus 	state->stackoffset = state->bds.stackoffset;	/* copy back */
684a6be8e7cSmarkus 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
685a6be8e7cSmarkus 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
686a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
687a6be8e7cSmarkus 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
688a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
689a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
690a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
691a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
692a6be8e7cSmarkus 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
693a6be8e7cSmarkus 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
694a6be8e7cSmarkus 		return r;
695a6be8e7cSmarkus 	for (i = 0; i < num_treehash(state); i++) {
696a6be8e7cSmarkus 		th = &state->treehash[i];
697a6be8e7cSmarkus 		node = th->node - state->th_nodes;
698a6be8e7cSmarkus 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
699a6be8e7cSmarkus 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
700a6be8e7cSmarkus 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
701a6be8e7cSmarkus 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
702a6be8e7cSmarkus 		    (r = sshbuf_put_u32(b, node)) != 0)
703a6be8e7cSmarkus 			return r;
704a6be8e7cSmarkus 	}
705a6be8e7cSmarkus 	return 0;
706a6be8e7cSmarkus }
707a6be8e7cSmarkus 
708a6be8e7cSmarkus int
sshkey_xmss_serialize_state_opt(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)709a6be8e7cSmarkus sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
710a6be8e7cSmarkus     enum sshkey_serialize_rep opts)
711a6be8e7cSmarkus {
712a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
713a6be8e7cSmarkus 	int r = SSH_ERR_INVALID_ARGUMENT;
714d3c68393Smarkus 	u_char have_stack, have_filename, have_enc;
715a6be8e7cSmarkus 
716a6be8e7cSmarkus 	if (state == NULL)
717a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
718a6be8e7cSmarkus 	if ((r = sshbuf_put_u8(b, opts)) != 0)
719a6be8e7cSmarkus 		return r;
720a6be8e7cSmarkus 	switch (opts) {
721a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_STATE:
722a6be8e7cSmarkus 		r = sshkey_xmss_serialize_state(k, b);
723a6be8e7cSmarkus 		break;
724a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_FULL:
725a6be8e7cSmarkus 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
726d3c68393Smarkus 			return r;
727a6be8e7cSmarkus 		r = sshkey_xmss_serialize_state(k, b);
728a6be8e7cSmarkus 		break;
729d3c68393Smarkus 	case SSHKEY_SERIALIZE_SHIELD:
730d3c68393Smarkus 		/* all of stack/filename/enc are optional */
731d3c68393Smarkus 		have_stack = state->stack != NULL;
732d3c68393Smarkus 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
733d3c68393Smarkus 			return r;
734d3c68393Smarkus 		if (have_stack) {
735d3c68393Smarkus 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
736d3c68393Smarkus 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
737d3c68393Smarkus 				return r;
738d3c68393Smarkus 		}
739d3c68393Smarkus 		have_filename = k->xmss_filename != NULL;
740d3c68393Smarkus 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
741d3c68393Smarkus 			return r;
742d3c68393Smarkus 		if (have_filename &&
743d3c68393Smarkus 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
744d3c68393Smarkus 			return r;
745d3c68393Smarkus 		have_enc = state->enc_keyiv != NULL;
746d3c68393Smarkus 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
747d3c68393Smarkus 			return r;
748d3c68393Smarkus 		if (have_enc &&
749d3c68393Smarkus 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
750d3c68393Smarkus 			return r;
751d3c68393Smarkus 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
752d3c68393Smarkus 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
753d3c68393Smarkus 			return r;
754d3c68393Smarkus 		break;
755a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_DEFAULT:
756a6be8e7cSmarkus 		r = 0;
757a6be8e7cSmarkus 		break;
758a6be8e7cSmarkus 	default:
759a6be8e7cSmarkus 		r = SSH_ERR_INVALID_ARGUMENT;
760a6be8e7cSmarkus 		break;
761a6be8e7cSmarkus 	}
762a6be8e7cSmarkus 	return r;
763a6be8e7cSmarkus }
764a6be8e7cSmarkus 
765a6be8e7cSmarkus int
sshkey_xmss_deserialize_state(struct sshkey * k,struct sshbuf * b)766a6be8e7cSmarkus sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
767a6be8e7cSmarkus {
768a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
769a6be8e7cSmarkus 	treehash_inst *th;
770a6be8e7cSmarkus 	u_int32_t i, lh, node;
771a6be8e7cSmarkus 	size_t ls, lsl, la, lk, ln, lr;
772a6be8e7cSmarkus 	char *magic;
7730bd8f658Sdjm 	int r = SSH_ERR_INTERNAL_ERROR;
774a6be8e7cSmarkus 
775a6be8e7cSmarkus 	if (state == NULL)
776a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
777a6be8e7cSmarkus 	if (k->xmss_sk == NULL)
778a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
779a6be8e7cSmarkus 	if ((state->treehash = calloc(num_treehash(state),
780a6be8e7cSmarkus 	    sizeof(treehash_inst))) == NULL)
781a6be8e7cSmarkus 		return SSH_ERR_ALLOC_FAIL;
782a6be8e7cSmarkus 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
783a6be8e7cSmarkus 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
784a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
785a6be8e7cSmarkus 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
786a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
787a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
788a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
789a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
790a6be8e7cSmarkus 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
791a6be8e7cSmarkus 	    (r = sshbuf_get_u32(b, &lh)) != 0)
7920bd8f658Sdjm 		goto out;
7930bd8f658Sdjm 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
7940bd8f658Sdjm 		r = SSH_ERR_INVALID_ARGUMENT;
7950bd8f658Sdjm 		goto out;
7960bd8f658Sdjm 	}
797a6be8e7cSmarkus 	/* XXX check stackoffset */
798a6be8e7cSmarkus 	if (ls != num_stack(state) ||
799a6be8e7cSmarkus 	    lsl != num_stacklevels(state) ||
800a6be8e7cSmarkus 	    la != num_auth(state) ||
801a6be8e7cSmarkus 	    lk != num_keep(state) ||
802a6be8e7cSmarkus 	    ln != num_th_nodes(state) ||
803a6be8e7cSmarkus 	    lr != num_retain(state) ||
8040bd8f658Sdjm 	    lh != num_treehash(state)) {
8050bd8f658Sdjm 		r = SSH_ERR_INVALID_ARGUMENT;
8060bd8f658Sdjm 		goto out;
8070bd8f658Sdjm 	}
808a6be8e7cSmarkus 	for (i = 0; i < num_treehash(state); i++) {
809a6be8e7cSmarkus 		th = &state->treehash[i];
810a6be8e7cSmarkus 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
811a6be8e7cSmarkus 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
812a6be8e7cSmarkus 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
813a6be8e7cSmarkus 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
814a6be8e7cSmarkus 		    (r = sshbuf_get_u32(b, &node)) != 0)
8150bd8f658Sdjm 			goto out;
816a6be8e7cSmarkus 		if (node < num_th_nodes(state))
817a6be8e7cSmarkus 			th->node = &state->th_nodes[node];
818a6be8e7cSmarkus 	}
819a6be8e7cSmarkus 	POKE_U32(k->xmss_sk, state->idx);
820a6be8e7cSmarkus 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
821a6be8e7cSmarkus 	    state->stacklevels, state->auth, state->keep, state->treehash,
822a6be8e7cSmarkus 	    state->retain, 0);
8230bd8f658Sdjm 	/* success */
8240bd8f658Sdjm 	r = 0;
8250bd8f658Sdjm  out:
8260bd8f658Sdjm 	free(magic);
8270bd8f658Sdjm 	return r;
828a6be8e7cSmarkus }
829a6be8e7cSmarkus 
830a6be8e7cSmarkus int
sshkey_xmss_deserialize_state_opt(struct sshkey * k,struct sshbuf * b)831a6be8e7cSmarkus sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
832a6be8e7cSmarkus {
833d3c68393Smarkus 	struct ssh_xmss_state *state = k->xmss_state;
834a6be8e7cSmarkus 	enum sshkey_serialize_rep opts;
835d3c68393Smarkus 	u_char have_state, have_stack, have_filename, have_enc;
836a6be8e7cSmarkus 	int r;
837a6be8e7cSmarkus 
838a6be8e7cSmarkus 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
839a6be8e7cSmarkus 		return r;
840a6be8e7cSmarkus 
841a6be8e7cSmarkus 	opts = have_state;
842a6be8e7cSmarkus 	switch (opts) {
843a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_DEFAULT:
844a6be8e7cSmarkus 		r = 0;
845a6be8e7cSmarkus 		break;
846d3c68393Smarkus 	case SSHKEY_SERIALIZE_SHIELD:
847d3c68393Smarkus 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
848d3c68393Smarkus 			return r;
849d3c68393Smarkus 		if (have_stack &&
850d3c68393Smarkus 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
851d3c68393Smarkus 			return r;
852d3c68393Smarkus 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
853d3c68393Smarkus 			return r;
854d3c68393Smarkus 		if (have_filename &&
855d3c68393Smarkus 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
856d3c68393Smarkus 			return r;
857d3c68393Smarkus 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
858d3c68393Smarkus 			return r;
859d3c68393Smarkus 		if (have_enc &&
860d3c68393Smarkus 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
861d3c68393Smarkus 			return r;
862d3c68393Smarkus 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
863d3c68393Smarkus 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
864d3c68393Smarkus 			return r;
865d3c68393Smarkus 		break;
866a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_STATE:
867a6be8e7cSmarkus 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
868a6be8e7cSmarkus 			return r;
869a6be8e7cSmarkus 		break;
870a6be8e7cSmarkus 	case SSHKEY_SERIALIZE_FULL:
871a6be8e7cSmarkus 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
872a6be8e7cSmarkus 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
873a6be8e7cSmarkus 			return r;
874a6be8e7cSmarkus 		break;
875a6be8e7cSmarkus 	default:
876a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
877a6be8e7cSmarkus 		break;
878a6be8e7cSmarkus 	}
879a6be8e7cSmarkus 	return r;
880a6be8e7cSmarkus }
881a6be8e7cSmarkus 
882a6be8e7cSmarkus int
sshkey_xmss_encrypt_state(const struct sshkey * k,struct sshbuf * b,struct sshbuf ** retp)883a6be8e7cSmarkus sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
884a6be8e7cSmarkus    struct sshbuf **retp)
885a6be8e7cSmarkus {
886a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
887a6be8e7cSmarkus 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
888a6be8e7cSmarkus 	struct sshcipher_ctx *ciphercontext = NULL;
889a6be8e7cSmarkus 	const struct sshcipher *cipher;
890a6be8e7cSmarkus 	u_char *cp, *key, *iv = NULL;
891a6be8e7cSmarkus 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
892a6be8e7cSmarkus 	int r = SSH_ERR_INTERNAL_ERROR;
893a6be8e7cSmarkus 
894a6be8e7cSmarkus 	if (retp != NULL)
895a6be8e7cSmarkus 		*retp = NULL;
896a6be8e7cSmarkus 	if (state == NULL ||
897a6be8e7cSmarkus 	    state->enc_keyiv == NULL ||
898a6be8e7cSmarkus 	    state->enc_ciphername == NULL)
899a6be8e7cSmarkus 		return SSH_ERR_INTERNAL_ERROR;
900a6be8e7cSmarkus 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
901a6be8e7cSmarkus 		r = SSH_ERR_INTERNAL_ERROR;
902a6be8e7cSmarkus 		goto out;
903a6be8e7cSmarkus 	}
904a6be8e7cSmarkus 	blocksize = cipher_blocksize(cipher);
905a6be8e7cSmarkus 	keylen = cipher_keylen(cipher);
906a6be8e7cSmarkus 	ivlen = cipher_ivlen(cipher);
907a6be8e7cSmarkus 	authlen = cipher_authlen(cipher);
908a6be8e7cSmarkus 	if (state->enc_keyiv_len != keylen + ivlen) {
909a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
910a6be8e7cSmarkus 		goto out;
911a6be8e7cSmarkus 	}
912a6be8e7cSmarkus 	key = state->enc_keyiv;
913a6be8e7cSmarkus 	if ((encrypted = sshbuf_new()) == NULL ||
914a6be8e7cSmarkus 	    (encoded = sshbuf_new()) == NULL ||
915a6be8e7cSmarkus 	    (padded = sshbuf_new()) == NULL ||
916a6be8e7cSmarkus 	    (iv = malloc(ivlen)) == NULL) {
917a6be8e7cSmarkus 		r = SSH_ERR_ALLOC_FAIL;
918a6be8e7cSmarkus 		goto out;
919a6be8e7cSmarkus 	}
920a6be8e7cSmarkus 
921a6be8e7cSmarkus 	/* replace first 4 bytes of IV with index to ensure uniqueness */
922a6be8e7cSmarkus 	memcpy(iv, key + keylen, ivlen);
923a6be8e7cSmarkus 	POKE_U32(iv, state->idx);
924a6be8e7cSmarkus 
925a6be8e7cSmarkus 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
926a6be8e7cSmarkus 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
927a6be8e7cSmarkus 		goto out;
928a6be8e7cSmarkus 
929a6be8e7cSmarkus 	/* padded state will be encrypted */
930a6be8e7cSmarkus 	if ((r = sshbuf_putb(padded, b)) != 0)
931a6be8e7cSmarkus 		goto out;
932a6be8e7cSmarkus 	i = 0;
933a6be8e7cSmarkus 	while (sshbuf_len(padded) % blocksize) {
934a6be8e7cSmarkus 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
935a6be8e7cSmarkus 			goto out;
936a6be8e7cSmarkus 	}
937a6be8e7cSmarkus 	encrypted_len = sshbuf_len(padded);
938a6be8e7cSmarkus 
939a6be8e7cSmarkus 	/* header including the length of state is used as AAD */
940a6be8e7cSmarkus 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
941a6be8e7cSmarkus 		goto out;
942a6be8e7cSmarkus 	aadlen = sshbuf_len(encoded);
943a6be8e7cSmarkus 
944a6be8e7cSmarkus 	/* concat header and state */
945a6be8e7cSmarkus 	if ((r = sshbuf_putb(encoded, padded)) != 0)
946a6be8e7cSmarkus 		goto out;
947a6be8e7cSmarkus 
948a6be8e7cSmarkus 	/* reserve space for encryption of encoded data plus auth tag */
949a6be8e7cSmarkus 	/* encrypt at offset addlen */
950a6be8e7cSmarkus 	if ((r = sshbuf_reserve(encrypted,
951a6be8e7cSmarkus 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
952a6be8e7cSmarkus 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
953a6be8e7cSmarkus 	    iv, ivlen, 1)) != 0 ||
954a6be8e7cSmarkus 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
955a6be8e7cSmarkus 	    encrypted_len, aadlen, authlen)) != 0)
956a6be8e7cSmarkus 		goto out;
957a6be8e7cSmarkus 
958a6be8e7cSmarkus 	/* success */
959a6be8e7cSmarkus 	r = 0;
960a6be8e7cSmarkus  out:
961a6be8e7cSmarkus 	if (retp != NULL) {
962a6be8e7cSmarkus 		*retp = encrypted;
963a6be8e7cSmarkus 		encrypted = NULL;
964a6be8e7cSmarkus 	}
965a6be8e7cSmarkus 	sshbuf_free(padded);
966a6be8e7cSmarkus 	sshbuf_free(encoded);
967a6be8e7cSmarkus 	sshbuf_free(encrypted);
968a6be8e7cSmarkus 	cipher_free(ciphercontext);
969a6be8e7cSmarkus 	free(iv);
970a6be8e7cSmarkus 	return r;
971a6be8e7cSmarkus }
972a6be8e7cSmarkus 
973a6be8e7cSmarkus int
sshkey_xmss_decrypt_state(const struct sshkey * k,struct sshbuf * encoded,struct sshbuf ** retp)974a6be8e7cSmarkus sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
975a6be8e7cSmarkus    struct sshbuf **retp)
976a6be8e7cSmarkus {
977a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
978a6be8e7cSmarkus 	struct sshbuf *copy = NULL, *decrypted = NULL;
979a6be8e7cSmarkus 	struct sshcipher_ctx *ciphercontext = NULL;
980a6be8e7cSmarkus 	const struct sshcipher *cipher = NULL;
981a6be8e7cSmarkus 	u_char *key, *iv = NULL, *dp;
982a6be8e7cSmarkus 	size_t keylen, ivlen, authlen, aadlen;
983a6be8e7cSmarkus 	u_int blocksize, encrypted_len, index;
984a6be8e7cSmarkus 	int r = SSH_ERR_INTERNAL_ERROR;
985a6be8e7cSmarkus 
986a6be8e7cSmarkus 	if (retp != NULL)
987a6be8e7cSmarkus 		*retp = NULL;
988a6be8e7cSmarkus 	if (state == NULL ||
989a6be8e7cSmarkus 	    state->enc_keyiv == NULL ||
990a6be8e7cSmarkus 	    state->enc_ciphername == NULL)
991a6be8e7cSmarkus 		return SSH_ERR_INTERNAL_ERROR;
992a6be8e7cSmarkus 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
993a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
994a6be8e7cSmarkus 		goto out;
995a6be8e7cSmarkus 	}
996a6be8e7cSmarkus 	blocksize = cipher_blocksize(cipher);
997a6be8e7cSmarkus 	keylen = cipher_keylen(cipher);
998a6be8e7cSmarkus 	ivlen = cipher_ivlen(cipher);
999a6be8e7cSmarkus 	authlen = cipher_authlen(cipher);
1000a6be8e7cSmarkus 	if (state->enc_keyiv_len != keylen + ivlen) {
1001a6be8e7cSmarkus 		r = SSH_ERR_INTERNAL_ERROR;
1002a6be8e7cSmarkus 		goto out;
1003a6be8e7cSmarkus 	}
1004a6be8e7cSmarkus 	key = state->enc_keyiv;
1005a6be8e7cSmarkus 
1006a6be8e7cSmarkus 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1007a6be8e7cSmarkus 	    (decrypted = sshbuf_new()) == NULL ||
1008a6be8e7cSmarkus 	    (iv = malloc(ivlen)) == NULL) {
1009a6be8e7cSmarkus 		r = SSH_ERR_ALLOC_FAIL;
1010a6be8e7cSmarkus 		goto out;
1011a6be8e7cSmarkus 	}
1012a6be8e7cSmarkus 
1013a6be8e7cSmarkus 	/* check magic */
1014a6be8e7cSmarkus 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1015a6be8e7cSmarkus 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1016a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
1017a6be8e7cSmarkus 		goto out;
1018a6be8e7cSmarkus 	}
1019a6be8e7cSmarkus 	/* parse public portion */
1020a6be8e7cSmarkus 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1021a6be8e7cSmarkus 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1022a6be8e7cSmarkus 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1023a6be8e7cSmarkus 		goto out;
1024a6be8e7cSmarkus 
1025a6be8e7cSmarkus 	/* check size of encrypted key blob */
1026a6be8e7cSmarkus 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1027a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
1028a6be8e7cSmarkus 		goto out;
1029a6be8e7cSmarkus 	}
1030a6be8e7cSmarkus 	/* check that an appropriate amount of auth data is present */
103142e35d48Sdjm 	if (sshbuf_len(encoded) < authlen ||
103242e35d48Sdjm 	    sshbuf_len(encoded) - authlen < encrypted_len) {
1033a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
1034a6be8e7cSmarkus 		goto out;
1035a6be8e7cSmarkus 	}
1036a6be8e7cSmarkus 
1037a6be8e7cSmarkus 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1038a6be8e7cSmarkus 
1039a6be8e7cSmarkus 	/* replace first 4 bytes of IV with index to ensure uniqueness */
1040a6be8e7cSmarkus 	memcpy(iv, key + keylen, ivlen);
1041a6be8e7cSmarkus 	POKE_U32(iv, index);
1042a6be8e7cSmarkus 
1043a6be8e7cSmarkus 	/* decrypt private state of key */
1044a6be8e7cSmarkus 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1045a6be8e7cSmarkus 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1046a6be8e7cSmarkus 	    iv, ivlen, 0)) != 0 ||
1047a6be8e7cSmarkus 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1048a6be8e7cSmarkus 	    encrypted_len, aadlen, authlen)) != 0)
1049a6be8e7cSmarkus 		goto out;
1050a6be8e7cSmarkus 
1051a6be8e7cSmarkus 	/* there should be no trailing data */
1052a6be8e7cSmarkus 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1053a6be8e7cSmarkus 		goto out;
1054a6be8e7cSmarkus 	if (sshbuf_len(encoded) != 0) {
1055a6be8e7cSmarkus 		r = SSH_ERR_INVALID_FORMAT;
1056a6be8e7cSmarkus 		goto out;
1057a6be8e7cSmarkus 	}
1058a6be8e7cSmarkus 
1059a6be8e7cSmarkus 	/* remove AAD */
1060a6be8e7cSmarkus 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1061a6be8e7cSmarkus 		goto out;
1062a6be8e7cSmarkus 	/* XXX encrypted includes unchecked padding */
1063a6be8e7cSmarkus 
1064a6be8e7cSmarkus 	/* success */
1065a6be8e7cSmarkus 	r = 0;
1066a6be8e7cSmarkus 	if (retp != NULL) {
1067a6be8e7cSmarkus 		*retp = decrypted;
1068a6be8e7cSmarkus 		decrypted = NULL;
1069a6be8e7cSmarkus 	}
1070a6be8e7cSmarkus  out:
1071a6be8e7cSmarkus 	cipher_free(ciphercontext);
1072a6be8e7cSmarkus 	sshbuf_free(copy);
1073a6be8e7cSmarkus 	sshbuf_free(decrypted);
1074a6be8e7cSmarkus 	free(iv);
1075a6be8e7cSmarkus 	return r;
1076a6be8e7cSmarkus }
1077a6be8e7cSmarkus 
1078a6be8e7cSmarkus u_int32_t
sshkey_xmss_signatures_left(const struct sshkey * k)1079a6be8e7cSmarkus sshkey_xmss_signatures_left(const struct sshkey *k)
1080a6be8e7cSmarkus {
1081a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
1082a6be8e7cSmarkus 	u_int32_t idx;
1083a6be8e7cSmarkus 
1084a6be8e7cSmarkus 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1085a6be8e7cSmarkus 	    state->maxidx) {
1086a6be8e7cSmarkus 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1087a6be8e7cSmarkus 		if (idx < state->maxidx)
1088a6be8e7cSmarkus 			return state->maxidx - idx;
1089a6be8e7cSmarkus 	}
1090a6be8e7cSmarkus 	return 0;
1091a6be8e7cSmarkus }
1092a6be8e7cSmarkus 
1093a6be8e7cSmarkus int
sshkey_xmss_enable_maxsign(struct sshkey * k,u_int32_t maxsign)1094a6be8e7cSmarkus sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1095a6be8e7cSmarkus {
1096a6be8e7cSmarkus 	struct ssh_xmss_state *state = k->xmss_state;
1097a6be8e7cSmarkus 
1098a6be8e7cSmarkus 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1099a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
1100a6be8e7cSmarkus 	if (maxsign == 0)
1101a6be8e7cSmarkus 		return 0;
1102a6be8e7cSmarkus 	if (state->idx + maxsign < state->idx)
1103a6be8e7cSmarkus 		return SSH_ERR_INVALID_ARGUMENT;
1104a6be8e7cSmarkus 	state->maxidx = state->idx + maxsign;
1105a6be8e7cSmarkus 	return 0;
1106a6be8e7cSmarkus }
1107