xref: /openbsd-src/usr.bin/ssh/ssh-pkcs11-helper.c (revision c90a81c56dcebd6a1b73fe4aff9b03385b8e63b3)
1 /* $OpenBSD: ssh-pkcs11-helper.c,v 1.14 2018/01/08 15:18:46 markus Exp $ */
2 /*
3  * Copyright (c) 2010 Markus Friedl.  All rights reserved.
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <sys/types.h>
19 #include <sys/queue.h>
20 #include <sys/time.h>
21 
22 #include <stdarg.h>
23 #include <string.h>
24 #include <unistd.h>
25 #include <errno.h>
26 
27 #include "xmalloc.h"
28 #include "sshbuf.h"
29 #include "log.h"
30 #include "misc.h"
31 #include "sshkey.h"
32 #include "authfd.h"
33 #include "ssh-pkcs11.h"
34 #include "ssherr.h"
35 
36 /* borrows code from sftp-server and ssh-agent */
37 
/*
 * One entry per key obtained from a PKCS#11 provider, tagged with the
 * name of the provider it came from so the key can be dropped when that
 * provider is removed.
 */
struct pkcs11_keyinfo {
	struct sshkey			*key;		/* key from the provider */
	char				*providername;	/* heap copy of the name */
	TAILQ_ENTRY(pkcs11_keyinfo)	 next;		/* list linkage */
};

/* Every key currently loaded, kept in insertion order. */
TAILQ_HEAD(, pkcs11_keyinfo) pkcs11_keylist;

/*
 * Largest request body accepted from the client; also used as the
 * worst-case reply size when reserving output space.  XXX arbitrary.
 */
#define MAX_MSG_LENGTH		10240

/* Bytes read from stdin awaiting parsing / replies awaiting stdout. */
struct sshbuf *iqueue, *oqueue;
51 
52 static void
53 add_key(struct sshkey *k, char *name)
54 {
55 	struct pkcs11_keyinfo *ki;
56 
57 	ki = xcalloc(1, sizeof(*ki));
58 	ki->providername = xstrdup(name);
59 	ki->key = k;
60 	TAILQ_INSERT_TAIL(&pkcs11_keylist, ki, next);
61 }
62 
63 static void
64 del_keys_by_name(char *name)
65 {
66 	struct pkcs11_keyinfo *ki, *nxt;
67 
68 	for (ki = TAILQ_FIRST(&pkcs11_keylist); ki; ki = nxt) {
69 		nxt = TAILQ_NEXT(ki, next);
70 		if (!strcmp(ki->providername, name)) {
71 			TAILQ_REMOVE(&pkcs11_keylist, ki, next);
72 			free(ki->providername);
73 			sshkey_free(ki->key);
74 			free(ki);
75 		}
76 	}
77 }
78 
79 /* lookup matching 'private' key */
80 static struct sshkey *
81 lookup_key(struct sshkey *k)
82 {
83 	struct pkcs11_keyinfo *ki;
84 
85 	TAILQ_FOREACH(ki, &pkcs11_keylist, next) {
86 		debug("check %p %s", ki, ki->providername);
87 		if (sshkey_equal(k, ki->key))
88 			return (ki->key);
89 	}
90 	return (NULL);
91 }
92 
93 static void
94 send_msg(struct sshbuf *m)
95 {
96 	int r;
97 
98 	if ((r = sshbuf_put_stringb(oqueue, m)) != 0)
99 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
100 }
101 
102 static void
103 process_add(void)
104 {
105 	char *name, *pin;
106 	struct sshkey **keys;
107 	int r, i, nkeys;
108 	u_char *blob;
109 	size_t blen;
110 	struct sshbuf *msg;
111 
112 	if ((msg = sshbuf_new()) == NULL)
113 		fatal("%s: sshbuf_new failed", __func__);
114 	if ((r = sshbuf_get_cstring(iqueue, &name, NULL)) != 0 ||
115 	    (r = sshbuf_get_cstring(iqueue, &pin, NULL)) != 0)
116 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
117 	if ((nkeys = pkcs11_add_provider(name, pin, &keys)) > 0) {
118 		if ((r = sshbuf_put_u8(msg,
119 		    SSH2_AGENT_IDENTITIES_ANSWER)) != 0 ||
120 		    (r = sshbuf_put_u32(msg, nkeys)) != 0)
121 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
122 		for (i = 0; i < nkeys; i++) {
123 			if ((r = sshkey_to_blob(keys[i], &blob, &blen)) != 0) {
124 				debug("%s: sshkey_to_blob: %s",
125 				    __func__, ssh_err(r));
126 				continue;
127 			}
128 			if ((r = sshbuf_put_string(msg, blob, blen)) != 0 ||
129 			    (r = sshbuf_put_cstring(msg, name)) != 0)
130 				fatal("%s: buffer error: %s",
131 				    __func__, ssh_err(r));
132 			free(blob);
133 			add_key(keys[i], name);
134 		}
135 		free(keys);
136 	} else {
137 		if ((r = sshbuf_put_u8(msg, SSH_AGENT_FAILURE)) != 0)
138 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
139 	}
140 	free(pin);
141 	free(name);
142 	send_msg(msg);
143 	sshbuf_free(msg);
144 }
145 
146 static void
147 process_del(void)
148 {
149 	char *name, *pin;
150 	struct sshbuf *msg;
151 	int r;
152 
153 	if ((msg = sshbuf_new()) == NULL)
154 		fatal("%s: sshbuf_new failed", __func__);
155 	if ((r = sshbuf_get_cstring(iqueue, &name, NULL)) != 0 ||
156 	    (r = sshbuf_get_cstring(iqueue, &pin, NULL)) != 0)
157 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
158 	del_keys_by_name(name);
159 	if ((r = sshbuf_put_u8(msg, pkcs11_del_provider(name) == 0 ?
160 	    SSH_AGENT_SUCCESS : SSH_AGENT_FAILURE)) != 0)
161 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
162 	free(pin);
163 	free(name);
164 	send_msg(msg);
165 	sshbuf_free(msg);
166 }
167 
168 static void
169 process_sign(void)
170 {
171 	u_char *blob, *data, *signature = NULL;
172 	size_t blen, dlen, slen = 0;
173 	int r, ok = -1;
174 	struct sshkey *key, *found;
175 	struct sshbuf *msg;
176 
177 	/* XXX support SHA2 signature flags */
178 	if ((r = sshbuf_get_string(iqueue, &blob, &blen)) != 0 ||
179 	    (r = sshbuf_get_string(iqueue, &data, &dlen)) != 0 ||
180 	    (r = sshbuf_get_u32(iqueue, NULL)) != 0)
181 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
182 
183 	if ((r = sshkey_from_blob(blob, blen, &key)) != 0)
184 		error("%s: sshkey_from_blob: %s", __func__, ssh_err(r));
185 	else {
186 		if ((found = lookup_key(key)) != NULL) {
187 #ifdef WITH_OPENSSL
188 			int ret;
189 
190 			slen = RSA_size(key->rsa);
191 			signature = xmalloc(slen);
192 			if ((ret = RSA_private_encrypt(dlen, data, signature,
193 			    found->rsa, RSA_PKCS1_PADDING)) != -1) {
194 				slen = ret;
195 				ok = 0;
196 			}
197 #endif /* WITH_OPENSSL */
198 		}
199 		sshkey_free(key);
200 	}
201 	if ((msg = sshbuf_new()) == NULL)
202 		fatal("%s: sshbuf_new failed", __func__);
203 	if (ok == 0) {
204 		if ((r = sshbuf_put_u8(msg, SSH2_AGENT_SIGN_RESPONSE)) != 0 ||
205 		    (r = sshbuf_put_string(msg, signature, slen)) != 0)
206 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
207 	} else {
208 		if ((r = sshbuf_put_u8(msg, SSH2_AGENT_FAILURE)) != 0)
209 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
210 	}
211 	free(data);
212 	free(blob);
213 	free(signature);
214 	send_msg(msg);
215 	sshbuf_free(msg);
216 }
217 
218 static void
219 process(void)
220 {
221 	u_int msg_len;
222 	u_int buf_len;
223 	u_int consumed;
224 	u_char type;
225 	const u_char *cp;
226 	int r;
227 
228 	buf_len = sshbuf_len(iqueue);
229 	if (buf_len < 5)
230 		return;		/* Incomplete message. */
231 	cp = sshbuf_ptr(iqueue);
232 	msg_len = get_u32(cp);
233 	if (msg_len > MAX_MSG_LENGTH) {
234 		error("bad message len %d", msg_len);
235 		cleanup_exit(11);
236 	}
237 	if (buf_len < msg_len + 4)
238 		return;
239 	if ((r = sshbuf_consume(iqueue, 4)) != 0 ||
240 	    (r = sshbuf_get_u8(iqueue, &type)) != 0)
241 		fatal("%s: buffer error: %s", __func__, ssh_err(r));
242 	buf_len -= 4;
243 	switch (type) {
244 	case SSH_AGENTC_ADD_SMARTCARD_KEY:
245 		debug("process_add");
246 		process_add();
247 		break;
248 	case SSH_AGENTC_REMOVE_SMARTCARD_KEY:
249 		debug("process_del");
250 		process_del();
251 		break;
252 	case SSH2_AGENTC_SIGN_REQUEST:
253 		debug("process_sign");
254 		process_sign();
255 		break;
256 	default:
257 		error("Unknown message %d", type);
258 		break;
259 	}
260 	/* discard the remaining bytes from the current packet */
261 	if (buf_len < sshbuf_len(iqueue)) {
262 		error("iqueue grew unexpectedly");
263 		cleanup_exit(255);
264 	}
265 	consumed = buf_len - sshbuf_len(iqueue);
266 	if (msg_len < consumed) {
267 		error("msg_len %d < consumed %d", msg_len, consumed);
268 		cleanup_exit(255);
269 	}
270 	if (msg_len > consumed) {
271 		if ((r = sshbuf_consume(iqueue, msg_len - consumed)) != 0)
272 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
273 	}
274 }
275 
/*
 * Terminate the helper immediately with exit code 'status' via _exit(),
 * which does not run atexit handlers or flush stdio.
 * XXX performs no cleanup of its own.
 */
void
cleanup_exit(int status)
{
	_exit(status);
}
282 
283 int
284 main(int argc, char **argv)
285 {
286 	fd_set *rset, *wset;
287 	int r, in, out, max, log_stderr = 0;
288 	ssize_t len, olen, set_size;
289 	SyslogFacility log_facility = SYSLOG_FACILITY_AUTH;
290 	LogLevel log_level = SYSLOG_LEVEL_ERROR;
291 	char buf[4*4096];
292 	extern char *__progname;
293 
294 	ssh_malloc_init();	/* must be called before any mallocs */
295 	TAILQ_INIT(&pkcs11_keylist);
296 	pkcs11_init(0);
297 
298 	log_init(__progname, log_level, log_facility, log_stderr);
299 
300 	in = STDIN_FILENO;
301 	out = STDOUT_FILENO;
302 
303 	max = 0;
304 	if (in > max)
305 		max = in;
306 	if (out > max)
307 		max = out;
308 
309 	if ((iqueue = sshbuf_new()) == NULL)
310 		fatal("%s: sshbuf_new failed", __func__);
311 	if ((oqueue = sshbuf_new()) == NULL)
312 		fatal("%s: sshbuf_new failed", __func__);
313 
314 	set_size = howmany(max + 1, NFDBITS) * sizeof(fd_mask);
315 	rset = xmalloc(set_size);
316 	wset = xmalloc(set_size);
317 
318 	for (;;) {
319 		memset(rset, 0, set_size);
320 		memset(wset, 0, set_size);
321 
322 		/*
323 		 * Ensure that we can read a full buffer and handle
324 		 * the worst-case length packet it can generate,
325 		 * otherwise apply backpressure by stopping reads.
326 		 */
327 		if ((r = sshbuf_check_reserve(iqueue, sizeof(buf))) == 0 &&
328 		    (r = sshbuf_check_reserve(oqueue, MAX_MSG_LENGTH)) == 0)
329 			FD_SET(in, rset);
330 		else if (r != SSH_ERR_NO_BUFFER_SPACE)
331 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
332 
333 		olen = sshbuf_len(oqueue);
334 		if (olen > 0)
335 			FD_SET(out, wset);
336 
337 		if (select(max+1, rset, wset, NULL, NULL) < 0) {
338 			if (errno == EINTR)
339 				continue;
340 			error("select: %s", strerror(errno));
341 			cleanup_exit(2);
342 		}
343 
344 		/* copy stdin to iqueue */
345 		if (FD_ISSET(in, rset)) {
346 			len = read(in, buf, sizeof buf);
347 			if (len == 0) {
348 				debug("read eof");
349 				cleanup_exit(0);
350 			} else if (len < 0) {
351 				error("read: %s", strerror(errno));
352 				cleanup_exit(1);
353 			} else if ((r = sshbuf_put(iqueue, buf, len)) != 0) {
354 				fatal("%s: buffer error: %s",
355 				    __func__, ssh_err(r));
356 			}
357 		}
358 		/* send oqueue to stdout */
359 		if (FD_ISSET(out, wset)) {
360 			len = write(out, sshbuf_ptr(oqueue), olen);
361 			if (len < 0) {
362 				error("write: %s", strerror(errno));
363 				cleanup_exit(1);
364 			} else if ((r = sshbuf_consume(oqueue, len)) != 0) {
365 				fatal("%s: buffer error: %s",
366 				    __func__, ssh_err(r));
367 			}
368 		}
369 
370 		/*
371 		 * Process requests from client if we can fit the results
372 		 * into the output buffer, otherwise stop processing input
373 		 * and let the output queue drain.
374 		 */
375 		if ((r = sshbuf_check_reserve(oqueue, MAX_MSG_LENGTH)) == 0)
376 			process();
377 		else if (r != SSH_ERR_NO_BUFFER_SPACE)
378 			fatal("%s: buffer error: %s", __func__, ssh_err(r));
379 	}
380 }
381