xref: /openbsd-src/regress/lib/libtls/tls/tlstest.c (revision 1a8dbaac879b9f3335ad7fb25429ce63ac1d6bac)
1 /* $OpenBSD: tlstest.c,v 1.12 2020/07/04 09:07:02 jsing Exp $ */
2 /*
3  * Copyright (c) 2017 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <sys/socket.h>
19 
20 #include <err.h>
21 #include <fcntl.h>
22 #include <stdio.h>
23 #include <string.h>
24 #include <unistd.h>
25 
26 #include <tls.h>
27 
/* Capacity, in bytes, of each in-memory ring buffer used by the callback tests. */
#define CIRCULAR_BUFFER_SIZE 512

/*
 * Ring filled by server_write() and drained by client_read().  Writers always
 * leave one slot unused, so readptr == writeptr means "empty".
 */
static unsigned char client_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *client_readptr, *client_writeptr;

/* Ring filled by client_write() and drained by server_read(). */
static unsigned char server_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *server_readptr, *server_writeptr;

/* Paths taken from the command line: CA file, server certificate and key. */
static char *cafile, *certfile, *keyfile;

/* When non-zero, log the size of every ring buffer transfer to stderr. */
static int debug = 0;
39 
40 static void
41 circular_init(void)
42 {
43 	client_readptr = client_writeptr = client_buffer;
44 	server_readptr = server_writeptr = server_buffer;
45 }
46 
47 static ssize_t
48 circular_read(char *name, unsigned char *buf, size_t bufsize,
49     unsigned char **readptr, unsigned char *writeptr,
50     unsigned char *outbuf, size_t outlen)
51 {
52 	unsigned char *nextptr = *readptr;
53 	size_t n = 0;
54 
55 	while (n < outlen) {
56 		if (nextptr == writeptr)
57 			break;
58 		*outbuf++ = *nextptr++;
59 		if ((size_t)(nextptr - buf) >= bufsize)
60 			nextptr = buf;
61 		*readptr = nextptr;
62 		n++;
63 	}
64 
65 	if (debug && n > 0)
66 		fprintf(stderr, "%s buffer: read %zi bytes\n", name, n);
67 
68 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN);
69 }
70 
71 static ssize_t
72 circular_write(char *name, unsigned char *buf, size_t bufsize,
73     unsigned char *readptr, unsigned char **writeptr,
74     const unsigned char *inbuf, size_t inlen)
75 {
76 	unsigned char *nextptr = *writeptr;
77 	unsigned char *prevptr;
78 	size_t n = 0;
79 
80 	while (n < inlen) {
81 		prevptr = nextptr++;
82 		if ((size_t)(nextptr - buf) >= bufsize)
83 			nextptr = buf;
84 		if (nextptr == readptr)
85 			break;
86 		*prevptr = *inbuf++;
87 		*writeptr = nextptr;
88 		n++;
89 	}
90 
91 	if (debug && n > 0)
92 		fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n);
93 
94 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT);
95 }
96 
97 static ssize_t
98 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
99 {
100 	return circular_read("client", client_buffer, sizeof(client_buffer),
101 	    &client_readptr, client_writeptr, buf, buflen);
102 }
103 
104 static ssize_t
105 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
106 {
107 	return circular_write("server", server_buffer, sizeof(server_buffer),
108 	    server_readptr, &server_writeptr, buf, buflen);
109 }
110 
111 static ssize_t
112 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
113 {
114 	return circular_read("server", server_buffer, sizeof(server_buffer),
115 	    &server_readptr, server_writeptr, buf, buflen);
116 }
117 
118 static ssize_t
119 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
120 {
121 	return circular_write("client", client_buffer, sizeof(client_buffer),
122 	    client_readptr, &client_writeptr, buf, buflen);
123 }
124 
125 static int
126 do_tls_handshake(char *name, struct tls *ctx)
127 {
128 	int rv;
129 
130 	rv = tls_handshake(ctx);
131 	if (rv == 0)
132 		return (1);
133 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
134 		return (0);
135 
136 	errx(1, "%s handshake failed: %s", name, tls_error(ctx));
137 }
138 
139 static int
140 do_tls_close(char *name, struct tls *ctx)
141 {
142 	int rv;
143 
144 	rv = tls_close(ctx);
145 	if (rv == 0)
146 		return (1);
147 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
148 		return (0);
149 
150 	errx(1, "%s close failed: %s", name, tls_error(ctx));
151 }
152 
/*
 * Alternate handshake steps between the client and the server-side
 * connection context until both report completion.  Gives up after 101
 * rounds.
 *
 * Returns 0 on success.  Returns 1, after printing a FAIL line tagged with
 * desc, if either side never finished.
 */
static int
do_client_server_handshake(char *desc, struct tls *client,
    struct tls *server_cctx)
{
	int attempts = 0;
	int client_ok = 0, server_ok = 0;

	for (;;) {
		if (!client_ok)
			client_ok = do_tls_handshake("client", client);
		if (!server_ok)
			server_ok = do_tls_handshake("server", server_cctx);
		if (client_ok && server_ok)
			break;
		if (attempts++ >= 100)
			break;
	}

	if (client_ok && server_ok)
		return (0);

	printf("FAIL: %s TLS handshake did not complete\n", desc);
	return (1);
}
174 
/*
 * Alternate close steps between the client and the server-side connection
 * context until both report completion.  Gives up after 101 rounds.
 *
 * Returns 0 on success.  Returns 1, after printing a FAIL line tagged with
 * desc, if either side never finished.
 */
static int
do_client_server_close(char *desc, struct tls *client, struct tls *server_cctx)
{
	int attempts = 0;
	int client_ok = 0, server_ok = 0;

	for (;;) {
		if (!client_ok)
			client_ok = do_tls_close("client", client);
		if (!server_ok)
			server_ok = do_tls_close("server", server_cctx);
		if (client_ok && server_ok)
			break;
		if (attempts++ >= 100)
			break;
	}

	if (client_ok && server_ok)
		return (0);

	printf("FAIL: %s TLS close did not complete\n", desc);
	return (1);
}
195 
/*
 * Run a full connection lifecycle (handshake, then close) between client
 * and server_cctx.  An INFO line tagged with desc is printed after each
 * phase that succeeds.
 *
 * Returns 0 if both phases succeed, 1 otherwise.
 */
static int
do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx)
{
	int failed;

	failed = do_client_server_handshake(desc, client, server_cctx);
	if (failed == 0) {
		printf("INFO: %s TLS handshake completed successfully\n", desc);

		/* XXX - exercise tls_read()/tls_write() here as well. */

		failed = do_client_server_close(desc, client, server_cctx);
		if (failed == 0)
			printf("INFO: %s TLS close completed successfully\n",
			    desc);
	}

	return (failed != 0);
}
213 
214 static int
215 test_tls_cbs(struct tls *client, struct tls *server)
216 {
217 	struct tls *server_cctx;
218 	int failure;
219 
220 	circular_init();
221 
222 	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
223 	    NULL) == -1)
224 		errx(1, "failed to accept: %s", tls_error(server));
225 
226 	if (tls_connect_cbs(client, client_read, client_write, NULL,
227 	    "test") == -1)
228 		errx(1, "failed to connect: %s", tls_error(client));
229 
230 	failure = do_client_server_test("callback", client, server_cctx);
231 
232 	tls_free(server_cctx);
233 
234 	return (failure);
235 }
236 
/*
 * Connect client to server over two non-blocking pipes, one per direction,
 * and run the handshake/close round trip.  The server-side context and all
 * four descriptors are released before returning.
 *
 * Returns the result of do_client_server_test().
 */
static int
test_tls_fds(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int to_client[2], to_server[2];
	int result, i;

	if (pipe2(to_client, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");
	if (pipe2(to_server, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");

	/* Server reads to_server, writes to_client; the client does the reverse. */
	if (tls_accept_fds(server, &server_cctx, to_server[0],
	    to_client[1]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_fds(client, to_client[0], to_server[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	result = do_client_server_test("file descriptor", client, server_cctx);
	tls_free(server_cctx);

	for (i = 0; i < 2; i++) {
		close(to_client[i]);
		close(to_server[i]);
	}

	return (result);
}
266 
/*
 * Connect client to server over a non-blocking AF_UNIX stream socketpair
 * and run the handshake/close round trip.  The server-side context and both
 * sockets are released before returning.
 *
 * Returns the result of do_client_server_test().
 */
static int
test_tls_socket(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int pair[2];
	int server_sock, client_sock;
	int result;

	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC,
	    pair) == -1)
		err(1, "failed to create socketpair");
	server_sock = pair[0];
	client_sock = pair[1];

	if (tls_accept_socket(server, &server_cctx, server_sock) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_socket(client, client_sock, "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	result = do_client_server_test("socket", client, server_cctx);
	tls_free(server_cctx);

	close(server_sock);
	close(client_sock);

	return (result);
}
293 
294 static int
295 test_tls(char *client_protocols, char *server_protocols, char *ciphers)
296 {
297 	struct tls_config *client_cfg, *server_cfg;
298 	struct tls *client, *server;
299 	uint32_t protocols;
300 
301 	if ((client = tls_client()) == NULL)
302 		errx(1, "failed to create tls client");
303 	if ((client_cfg = tls_config_new()) == NULL)
304 		errx(1, "failed to create tls client config");
305 	tls_config_insecure_noverifyname(client_cfg);
306 	if (tls_config_parse_protocols(&protocols, client_protocols) == -1)
307 		errx(1, "failed to parse protocols: %s", tls_config_error(client_cfg));
308 	if (tls_config_set_protocols(client_cfg, protocols) == -1)
309 		errx(1, "failed to set protocols: %s", tls_config_error(client_cfg));
310 	if (tls_config_set_ciphers(client_cfg, ciphers) == -1)
311 		errx(1, "failed to set ciphers: %s", tls_config_error(client_cfg));
312 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
313 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
314 
315 	if ((server = tls_server()) == NULL)
316 		errx(1, "failed to create tls server");
317 	if ((server_cfg = tls_config_new()) == NULL)
318 		errx(1, "failed to create tls server config");
319 	if (tls_config_parse_protocols(&protocols, server_protocols) == -1)
320 		errx(1, "failed to parse protocols: %s", tls_config_error(server_cfg));
321 	if (tls_config_set_protocols(server_cfg, protocols) == -1)
322 		errx(1, "failed to set protocols: %s", tls_config_error(server_cfg));
323 	if (tls_config_set_ciphers(server_cfg, ciphers) == -1)
324 		errx(1, "failed to set ciphers: %s", tls_config_error(server_cfg));
325 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
326 		errx(1, "failed to set keypair: %s",
327 		    tls_config_error(server_cfg));
328 
329 	if (tls_configure(client, client_cfg) == -1)
330 		errx(1, "failed to configure client: %s", tls_error(client));
331 	tls_reset(server);
332 	if (tls_configure(server, server_cfg) == -1)
333 		errx(1, "failed to configure server: %s", tls_error(server));
334 
335 	return test_tls_cbs(client, server);
336 }
337 
338 static int
339 do_tls_tests(void)
340 {
341 	struct tls_config *client_cfg, *server_cfg;
342 	struct tls *client, *server;
343 	int failure = 0;
344 
345 	printf("== TLS tests ==\n");
346 
347 	if ((client = tls_client()) == NULL)
348 		errx(1, "failed to create tls client");
349 	if ((client_cfg = tls_config_new()) == NULL)
350 		errx(1, "failed to create tls client config");
351 	tls_config_insecure_noverifyname(client_cfg);
352 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
353 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
354 
355 	if ((server = tls_server()) == NULL)
356 		errx(1, "failed to create tls server");
357 	if ((server_cfg = tls_config_new()) == NULL)
358 		errx(1, "failed to create tls server config");
359 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
360 		errx(1, "failed to set keypair: %s",
361 		    tls_config_error(server_cfg));
362 
363 	tls_reset(client);
364 	if (tls_configure(client, client_cfg) == -1)
365 		errx(1, "failed to configure client: %s", tls_error(client));
366 	tls_reset(server);
367 	if (tls_configure(server, server_cfg) == -1)
368 		errx(1, "failed to configure server: %s", tls_error(server));
369 
370 	failure |= test_tls_cbs(client, server);
371 
372 	tls_reset(client);
373 	if (tls_configure(client, client_cfg) == -1)
374 		errx(1, "failed to configure client: %s", tls_error(client));
375 	tls_reset(server);
376 	if (tls_configure(server, server_cfg) == -1)
377 		errx(1, "failed to configure server: %s", tls_error(server));
378 
379 	failure |= test_tls_fds(client, server);
380 
381 	tls_reset(client);
382 	if (tls_configure(client, client_cfg) == -1)
383 		errx(1, "failed to configure client: %s", tls_error(client));
384 	tls_reset(server);
385 	if (tls_configure(server, server_cfg) == -1)
386 		errx(1, "failed to configure server: %s", tls_error(server));
387 
388 	tls_config_free(client_cfg);
389 	tls_config_free(server_cfg);
390 
391 	failure |= test_tls_socket(client, server);
392 
393 	tls_free(client);
394 	tls_free(server);
395 
396 	printf("\n");
397 
398 	return (failure);
399 }
400 
401 static int
402 do_tls_ordering_tests(void)
403 {
404 	struct tls *client = NULL, *server = NULL, *server_cctx = NULL;
405 	struct tls_config *client_cfg, *server_cfg;
406 	int failure = 0;
407 
408 	printf("== TLS ordering tests ==\n");
409 
410 	if ((client = tls_client()) == NULL)
411 		errx(1, "failed to create tls client");
412 	if ((client_cfg = tls_config_new()) == NULL)
413 		errx(1, "failed to create tls client config");
414 	tls_config_insecure_noverifyname(client_cfg);
415 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
416 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
417 
418 	if ((server = tls_server()) == NULL)
419 		errx(1, "failed to create tls server");
420 	if ((server_cfg = tls_config_new()) == NULL)
421 		errx(1, "failed to create tls server config");
422 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
423 		errx(1, "failed to set keypair: %s",
424 		    tls_config_error(server_cfg));
425 
426 	if (tls_configure(client, client_cfg) == -1)
427 		errx(1, "failed to configure client: %s", tls_error(client));
428 	if (tls_configure(server, server_cfg) == -1)
429 		errx(1, "failed to configure server: %s", tls_error(server));
430 
431 	tls_config_free(client_cfg);
432 	tls_config_free(server_cfg);
433 
434 	if (tls_handshake(client) != -1) {
435 		printf("FAIL: TLS handshake succeeded on unconnnected "
436 		    "client context\n");
437 		failure = 1;
438 		goto done;
439 	}
440 
441 	circular_init();
442 
443 	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
444 	    NULL) == -1)
445 		errx(1, "failed to accept: %s", tls_error(server));
446 
447 	if (tls_connect_cbs(client, client_read, client_write, NULL,
448 	    "test") == -1)
449 		errx(1, "failed to connect: %s", tls_error(client));
450 
451 	if (do_client_server_handshake("ordering", client, server_cctx) != 0) {
452 		failure = 1;
453 		goto done;
454 	}
455 
456 	if (tls_handshake(client) != -1) {
457 		printf("FAIL: TLS handshake succeeded twice\n");
458 		failure = 1;
459 		goto done;
460 	}
461 
462 	if (tls_handshake(server_cctx) != -1) {
463 		printf("FAIL: TLS handshake succeeded twice\n");
464 		failure = 1;
465 		goto done;
466 	}
467 
468 	if (do_client_server_close("ordering", client, server_cctx) != 0) {
469 		failure = 1;
470 		goto done;
471 	}
472 
473  done:
474 	tls_free(client);
475 	tls_free(server);
476 	tls_free(server_cctx);
477 
478 	printf("\n");
479 
480 	return (failure);
481 }
482 
/* One protocol-negotiation case: the protocol strings enabled on each side. */
struct test_versions {
	char *client;	/* protocols enabled on the client */
	char *server;	/* protocols enabled on the server */
};

static struct test_versions tls_test_versions[] = {
	/* One fixed client version against a permissive server. */
	{ .client = "tlsv1.3", .server = "all" },
	{ .client = "tlsv1.2", .server = "all" },
	{ .client = "tlsv1.1", .server = "all" },
	{ .client = "tlsv1.0", .server = "all" },

	/* A permissive client against one fixed server version. */
	{ .client = "all", .server = "tlsv1.3" },
	{ .client = "all", .server = "tlsv1.2" },
	{ .client = "all", .server = "tlsv1.1" },
	{ .client = "all", .server = "tlsv1.0" },

	/* Both sides pinned to the same single version. */
	{ .client = "tlsv1.3", .server = "tlsv1.3" },
	{ .client = "tlsv1.2", .server = "tlsv1.2" },
	{ .client = "tlsv1.1", .server = "tlsv1.1" },
	{ .client = "tlsv1.0", .server = "tlsv1.0" },
};

#define N_TLS_VERSION_TESTS \
    (sizeof(tls_test_versions) / sizeof(tls_test_versions[0]))
505 
506 static int
507 do_tls_version_tests(void)
508 {
509 	struct test_versions *tv;
510 	int failure = 0;
511 	size_t i;
512 
513 	printf("== TLS version tests ==\n");
514 
515 	for (i = 0; i < N_TLS_VERSION_TESTS; i++) {
516 		tv = &tls_test_versions[i];
517 		printf("INFO: version test %zu - client versions '%s' "
518 		    "and server versions '%s'\n", i, tv->client, tv->server);
519 		failure |= test_tls(tv->client, tv->server, "legacy");
520 		printf("\n");
521 	}
522 
523 	return failure;
524 }
525 
526 int
527 main(int argc, char **argv)
528 {
529 	int failure = 0;
530 
531 	if (argc != 4) {
532 		fprintf(stderr, "usage: %s cafile certfile keyfile\n",
533 		    argv[0]);
534 		return (1);
535 	}
536 
537 	cafile = argv[1];
538 	certfile = argv[2];
539 	keyfile = argv[3];
540 
541 	failure |= do_tls_tests();
542 	failure |= do_tls_ordering_tests();
543 	failure |= do_tls_version_tests();
544 
545 	return (failure);
546 }
547