xref: /netbsd-src/crypto/external/bsd/openssh/dist/ssh-pkcs11-helper.c (revision 9469f4f13c84743995b7d51c506f9c9849ba30de)
1 /*	$NetBSD: ssh-pkcs11-helper.c,v 1.23 2024/09/24 21:32:19 christos Exp $	*/
2 /* $OpenBSD: ssh-pkcs11-helper.c,v 1.27 2024/08/15 00:51:51 djm Exp $ */
3 
4 /*
5  * Copyright (c) 2010 Markus Friedl.  All rights reserved.
6  *
7  * Permission to use, copy, modify, and distribute this software for any
8  * purpose with or without fee is hereby granted, provided that the above
9  * copyright notice and this permission notice appear in all copies.
10  *
11  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
12  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
13  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
14  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
15  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
16  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
17  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18  */
19 #include "includes.h"
20 __RCSID("$NetBSD: ssh-pkcs11-helper.c,v 1.23 2024/09/24 21:32:19 christos Exp $");
21 
22 #include <sys/types.h>
23 #include <sys/queue.h>
24 #include <sys/time.h>
25 #include <sys/param.h>
26 
27 #include <stdlib.h>
28 #include <errno.h>
29 #include <poll.h>
30 #include <stdarg.h>
31 #include <string.h>
32 #include <unistd.h>
33 
34 #include "xmalloc.h"
35 #include "sshbuf.h"
36 #include "log.h"
37 #include "misc.h"
38 #include "sshkey.h"
39 #include "authfd.h"
40 #include "ssh-pkcs11.h"
41 #include "ssherr.h"
42 
43 #ifdef WITH_OPENSSL
44 #include <openssl/ec.h>
45 #include <openssl/rsa.h>
46 
47 /* borrows code from sftp-server and ssh-agent */
48 
/*
 * One key obtained from a PKCS#11 provider.  Entries are linked together
 * on pkcs11_keylist; add_key() creates them and del_keys_by_name()
 * destroys them.
 */
struct pkcs11_keyinfo {
	/* The key itself; released with sshkey_free() when the entry dies. */
	struct sshkey	*key;
	/* Heap copy of the name of the provider the key was loaded from. */
	char		*providername;
	/* Heap copy of the label reported for the key. */
	char		*label;
	/* Linkage for pkcs11_keylist. */
	TAILQ_ENTRY(pkcs11_keyinfo) next;
};
54 
/* Every key currently loaded from any provider (see add_key()). */
TAILQ_HEAD(, pkcs11_keyinfo) pkcs11_keylist;

/*
 * Upper bound, in bytes, on the length field of a request; process()
 * exits when a client announces more.  It is also the amount of free
 * space main() requires in the output queue before handling a request.
 * XXX the value itself is arbitrary.
 */
#define MAX_MSG_LENGTH		(10 * 1024)

/*
 * input and output queue: raw bytes read from stdin that have not been
 * parsed yet, and framed replies that have not been written to stdout yet.
 */
struct sshbuf *iqueue, *oqueue;
62 
63 static void
64 add_key(struct sshkey *k, char *name, char *label)
65 {
66 	struct pkcs11_keyinfo *ki;
67 
68 	ki = xcalloc(1, sizeof(*ki));
69 	ki->providername = xstrdup(name);
70 	ki->key = k;
71 	ki->label = xstrdup(label);
72 	TAILQ_INSERT_TAIL(&pkcs11_keylist, ki, next);
73 }
74 
75 static void
76 del_keys_by_name(char *name)
77 {
78 	struct pkcs11_keyinfo *ki, *nxt;
79 
80 	for (ki = TAILQ_FIRST(&pkcs11_keylist); ki; ki = nxt) {
81 		nxt = TAILQ_NEXT(ki, next);
82 		if (!strcmp(ki->providername, name)) {
83 			TAILQ_REMOVE(&pkcs11_keylist, ki, next);
84 			free(ki->providername);
85 			free(ki->label);
86 			sshkey_free(ki->key);
87 			free(ki);
88 		}
89 	}
90 }
91 
92 /* lookup matching 'private' key */
93 static struct sshkey *
94 lookup_key(struct sshkey *k)
95 {
96 	struct pkcs11_keyinfo *ki;
97 
98 	TAILQ_FOREACH(ki, &pkcs11_keylist, next) {
99 		debug("check %s %s %s", sshkey_type(ki->key),
100 		    ki->providername, ki->label);
101 		if (sshkey_equal(k, ki->key))
102 			return (ki->key);
103 	}
104 	return (NULL);
105 }
106 
107 static void
108 send_msg(struct sshbuf *m)
109 {
110 	int r;
111 
112 	if ((r = sshbuf_put_stringb(oqueue, m)) != 0)
113 		fatal_fr(r, "enqueue");
114 }
115 
116 static void
117 process_add(void)
118 {
119 	char *name, *pin;
120 	struct sshkey **keys = NULL;
121 	int r, i, nkeys;
122 	u_char *blob;
123 	size_t blen;
124 	struct sshbuf *msg;
125 	char **labels = NULL;
126 
127 	if ((msg = sshbuf_new()) == NULL)
128 		fatal_f("sshbuf_new failed");
129 	if ((r = sshbuf_get_cstring(iqueue, &name, NULL)) != 0 ||
130 	    (r = sshbuf_get_cstring(iqueue, &pin, NULL)) != 0)
131 		fatal_fr(r, "parse");
132 	if ((nkeys = pkcs11_add_provider(name, pin, &keys, &labels)) > 0) {
133 		if ((r = sshbuf_put_u8(msg,
134 		    SSH2_AGENT_IDENTITIES_ANSWER)) != 0 ||
135 		    (r = sshbuf_put_u32(msg, nkeys)) != 0)
136 			fatal_fr(r, "compose");
137 		for (i = 0; i < nkeys; i++) {
138 			if ((r = sshkey_to_blob(keys[i], &blob, &blen)) != 0) {
139 				debug_fr(r, "encode key");
140 				continue;
141 			}
142 			if ((r = sshbuf_put_string(msg, blob, blen)) != 0 ||
143 			    (r = sshbuf_put_cstring(msg, labels[i])) != 0)
144 				fatal_fr(r, "compose key");
145 			free(blob);
146 			add_key(keys[i], name, labels[i]);
147 			free(labels[i]);
148 		}
149 	} else if ((r = sshbuf_put_u8(msg, SSH_AGENT_FAILURE)) != 0 ||
150 	    (r = sshbuf_put_u32(msg, -nkeys)) != 0)
151 		fatal_fr(r, "compose");
152 	free(labels);
153 	free(keys); /* keys themselves are transferred to pkcs11_keylist */
154 	free(pin);
155 	free(name);
156 	send_msg(msg);
157 	sshbuf_free(msg);
158 }
159 
160 static void
161 process_del(void)
162 {
163 	char *name, *pin;
164 	struct sshbuf *msg;
165 	int r;
166 
167 	if ((msg = sshbuf_new()) == NULL)
168 		fatal_f("sshbuf_new failed");
169 	if ((r = sshbuf_get_cstring(iqueue, &name, NULL)) != 0 ||
170 	    (r = sshbuf_get_cstring(iqueue, &pin, NULL)) != 0)
171 		fatal_fr(r, "parse");
172 	del_keys_by_name(name);
173 	if ((r = sshbuf_put_u8(msg, pkcs11_del_provider(name) == 0 ?
174 	    SSH_AGENT_SUCCESS : SSH_AGENT_FAILURE)) != 0)
175 		fatal_fr(r, "compose");
176 	free(pin);
177 	free(name);
178 	send_msg(msg);
179 	sshbuf_free(msg);
180 }
181 
182 static void
183 process_sign(void)
184 {
185 	u_char *blob, *data, *signature = NULL;
186 	size_t blen, dlen;
187 	u_int slen = 0;
188 	int len, r, ok = -1;
189 	struct sshkey *key = NULL, *found;
190 	struct sshbuf *msg;
191 	RSA *rsa = NULL;
192 	EC_KEY *ecdsa = NULL;
193 
194 	/* XXX support SHA2 signature flags */
195 	if ((r = sshbuf_get_string(iqueue, &blob, &blen)) != 0 ||
196 	    (r = sshbuf_get_string(iqueue, &data, &dlen)) != 0 ||
197 	    (r = sshbuf_get_u32(iqueue, NULL)) != 0)
198 		fatal_fr(r, "parse");
199 
200 	if ((r = sshkey_from_blob(blob, blen, &key)) != 0)
201 		fatal_fr(r, "decode key");
202 	if ((found = lookup_key(key)) == NULL)
203 		goto reply;
204 
205 	/* XXX use pkey API properly for signing */
206 	switch (key->type) {
207 	case KEY_RSA:
208 		if ((rsa = EVP_PKEY_get1_RSA(found->pkey)) == NULL)
209 			fatal_f("no RSA in pkey");
210 		if ((len = RSA_size(rsa)) < 0)
211 			fatal_f("bad RSA length");
212 		signature = xmalloc(len);
213 		if ((len = RSA_private_encrypt(dlen, data, signature,
214 		    rsa, RSA_PKCS1_PADDING)) < 0) {
215 			error_f("RSA_private_encrypt failed");
216 			goto reply;
217 		}
218 		slen = (u_int)len;
219 		break;
220 	case KEY_ECDSA:
221 		if ((ecdsa = EVP_PKEY_get1_EC_KEY(found->pkey)) == NULL)
222 			fatal_f("no ECDSA in pkey");
223 		if ((len = ECDSA_size(ecdsa)) < 0)
224 			fatal_f("bad ECDSA length");
225 		slen = (u_int)len;
226 		signature = xmalloc(slen);
227 		/* "The parameter type is ignored." */
228 		if (!ECDSA_sign(-1, data, dlen, signature, &slen, ecdsa)) {
229 			error_f("ECDSA_sign failed");
230 			goto reply;
231 		}
232 		break;
233 	default:
234 		fatal_f("unsupported key type %d", key->type);
235 	}
236 	/* success */
237 	ok = 0;
238  reply:
239 	if ((msg = sshbuf_new()) == NULL)
240 		fatal_f("sshbuf_new failed");
241 	if (ok == 0) {
242 		if ((r = sshbuf_put_u8(msg, SSH2_AGENT_SIGN_RESPONSE)) != 0 ||
243 		    (r = sshbuf_put_string(msg, signature, slen)) != 0)
244 			fatal_fr(r, "compose response");
245 	} else {
246 		if ((r = sshbuf_put_u8(msg, SSH2_AGENT_FAILURE)) != 0)
247 			fatal_fr(r, "compose failure response");
248 	}
249 	sshkey_free(key);
250 	RSA_free(rsa);
251 	EC_KEY_free(ecdsa);
252 	free(data);
253 	free(blob);
254 	free(signature);
255 	send_msg(msg);
256 	sshbuf_free(msg);
257 }
258 
259 static void
260 process(void)
261 {
262 	u_int msg_len;
263 	u_int buf_len;
264 	u_int consumed;
265 	u_char type;
266 	const u_char *cp;
267 	int r;
268 
269 	buf_len = sshbuf_len(iqueue);
270 	if (buf_len < 5)
271 		return;		/* Incomplete message. */
272 	cp = sshbuf_ptr(iqueue);
273 	msg_len = get_u32(cp);
274 	if (msg_len > MAX_MSG_LENGTH) {
275 		error("bad message len %d", msg_len);
276 		cleanup_exit(11);
277 	}
278 	if (buf_len < msg_len + 4)
279 		return;
280 	if ((r = sshbuf_consume(iqueue, 4)) != 0 ||
281 	    (r = sshbuf_get_u8(iqueue, &type)) != 0)
282 		fatal_fr(r, "parse type/len");
283 	buf_len -= 4;
284 	switch (type) {
285 	case SSH_AGENTC_ADD_SMARTCARD_KEY:
286 		debug("process_add");
287 		process_add();
288 		break;
289 	case SSH_AGENTC_REMOVE_SMARTCARD_KEY:
290 		debug("process_del");
291 		process_del();
292 		break;
293 	case SSH2_AGENTC_SIGN_REQUEST:
294 		debug("process_sign");
295 		process_sign();
296 		break;
297 	default:
298 		error("Unknown message %d", type);
299 		break;
300 	}
301 	/* discard the remaining bytes from the current packet */
302 	if (buf_len < sshbuf_len(iqueue)) {
303 		error("iqueue grew unexpectedly");
304 		cleanup_exit(255);
305 	}
306 	consumed = buf_len - sshbuf_len(iqueue);
307 	if (msg_len < consumed) {
308 		error("msg_len %d < consumed %d", msg_len, consumed);
309 		cleanup_exit(255);
310 	}
311 	if (msg_len > consumed) {
312 		if ((r = sshbuf_consume(iqueue, msg_len - consumed)) != 0)
313 			fatal_fr(r, "consume");
314 	}
315 }
316 
/*
 * Terminate the helper with exit status `i'.  The process leaves through
 * _exit() straight away; no cleanup is performed here.
 */
void
cleanup_exit(int i)
{
	/* XXX */
	_exit(i);
}
323 
324 
325 int
326 main(int argc, char **argv)
327 {
328 	int r, ch, in, out, log_stderr = 0;
329 	ssize_t len;
330 	SyslogFacility log_facility = SYSLOG_FACILITY_AUTH;
331 	LogLevel log_level = SYSLOG_LEVEL_ERROR;
332 	char buf[4*4096];
333 	extern char *__progname;
334 	struct pollfd pfd[2];
335 
336 	TAILQ_INIT(&pkcs11_keylist);
337 
338 	log_init(__progname, log_level, log_facility, log_stderr);
339 
340 	while ((ch = getopt(argc, argv, "v")) != -1) {
341 		switch (ch) {
342 		case 'v':
343 			log_stderr = 1;
344 			if (log_level == SYSLOG_LEVEL_ERROR)
345 				log_level = SYSLOG_LEVEL_DEBUG1;
346 			else if (log_level < SYSLOG_LEVEL_DEBUG3)
347 				log_level++;
348 			break;
349 		default:
350 			fprintf(stderr, "usage: %s [-v]\n", __progname);
351 			exit(1);
352 		}
353 	}
354 
355 	log_init(__progname, log_level, log_facility, log_stderr);
356 
357 	pkcs11_init(0);
358 	in = STDIN_FILENO;
359 	out = STDOUT_FILENO;
360 
361 	if ((iqueue = sshbuf_new()) == NULL)
362 		fatal_f("sshbuf_new failed");
363 	if ((oqueue = sshbuf_new()) == NULL)
364 		fatal_f("sshbuf_new failed");
365 
366 	while (1) {
367 		memset(pfd, 0, sizeof(pfd));
368 		pfd[0].fd = in;
369 		pfd[1].fd = out;
370 
371 		/*
372 		 * Ensure that we can read a full buffer and handle
373 		 * the worst-case length packet it can generate,
374 		 * otherwise apply backpressure by stopping reads.
375 		 */
376 		if ((r = sshbuf_check_reserve(iqueue, sizeof(buf))) == 0 &&
377 		    (r = sshbuf_check_reserve(oqueue, MAX_MSG_LENGTH)) == 0)
378 			pfd[0].events = POLLIN;
379 		else if (r != SSH_ERR_NO_BUFFER_SPACE)
380 			fatal_fr(r, "reserve");
381 
382 		if (sshbuf_len(oqueue) > 0)
383 			pfd[1].events = POLLOUT;
384 
385 		if ((r = poll(pfd, 2, -1 /* INFTIM */)) <= 0) {
386 			if (r == 0 || errno == EINTR)
387 				continue;
388 			fatal("poll: %s", strerror(errno));
389 		}
390 
391 		/* copy stdin to iqueue */
392 		if ((pfd[0].revents & (POLLIN|POLLHUP|POLLERR)) != 0) {
393 			len = read(in, buf, sizeof buf);
394 			if (len == 0) {
395 				debug("read eof");
396 				cleanup_exit(0);
397 			} else if (len < 0) {
398 				error("read: %s", strerror(errno));
399 				cleanup_exit(1);
400 			} else if ((r = sshbuf_put(iqueue, buf, len)) != 0)
401 				fatal_fr(r, "sshbuf_put");
402 		}
403 		/* send oqueue to stdout */
404 		if ((pfd[1].revents & (POLLOUT|POLLHUP)) != 0) {
405 			len = write(out, sshbuf_ptr(oqueue),
406 			    sshbuf_len(oqueue));
407 			if (len < 0) {
408 				error("write: %s", strerror(errno));
409 				cleanup_exit(1);
410 			} else if ((r = sshbuf_consume(oqueue, len)) != 0)
411 				fatal_fr(r, "consume");
412 		}
413 
414 		/*
415 		 * Process requests from client if we can fit the results
416 		 * into the output buffer, otherwise stop processing input
417 		 * and let the output queue drain.
418 		 */
419 		if ((r = sshbuf_check_reserve(oqueue, MAX_MSG_LENGTH)) == 0)
420 			process();
421 		else if (r != SSH_ERR_NO_BUFFER_SPACE)
422 			fatal_fr(r, "reserve");
423 	}
424 }
425 
426 #else /* WITH_OPENSSL */
/*
 * Terminate the process with exit status `i' through _exit().  This is
 * the variant compiled when WITH_OPENSSL is not defined.
 */
void
cleanup_exit(int i)
{
	_exit(i);
}
432 
/*
 * Entry point used when WITH_OPENSSL is not defined: report on stderr
 * that PKCS#11 support is missing and fail with status 1.  The command
 * line is ignored.
 */
int
main(int argc, char **argv)
{
	fputs("PKCS#11 code is not enabled\n", stderr);
	return 1;
}
439 #endif /* WITH_OPENSSL */
440