xref: /openbsd-src/regress/lib/libtls/tls/tlstest.c (revision 5143226f3a439735e88060d24f13f914a0b7dd48)
1 /* $OpenBSD: tlstest.c,v 1.14 2022/06/22 10:01:17 tb Exp $ */
2 /*
3  * Copyright (c) 2017 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <sys/socket.h>
19 
20 #include <fcntl.h>
21 #include <unistd.h>
22 
23 #include <err.h>
24 #include <fcntl.h>
25 #include <stdio.h>
26 #include <string.h>
27 #include <unistd.h>
28 
29 #include <tls.h>
30 
/*
 * In-memory transport for the callback tests: one ring buffer per
 * direction, each with its own read and write cursor.  Bytes written by
 * the server end up in client_buffer and bytes written by the client end
 * up in server_buffer.
 */
#define CIRCULAR_BUFFER_SIZE 512

unsigned char client_buffer[CIRCULAR_BUFFER_SIZE];
unsigned char *client_readptr;
unsigned char *client_writeptr;

unsigned char server_buffer[CIRCULAR_BUFFER_SIZE];
unsigned char *server_readptr;
unsigned char *server_writeptr;

/* CA, certificate and key file paths, taken from the command line. */
char *cafile;
char *certfile;
char *keyfile;

/* When non-zero, each ring buffer transfer is traced on stderr. */
int debug = 0;
42 
43 static void
44 circular_init(void)
45 {
46 	client_readptr = client_writeptr = client_buffer;
47 	server_readptr = server_writeptr = server_buffer;
48 }
49 
50 static ssize_t
51 circular_read(char *name, unsigned char *buf, size_t bufsize,
52     unsigned char **readptr, unsigned char *writeptr,
53     unsigned char *outbuf, size_t outlen)
54 {
55 	unsigned char *nextptr = *readptr;
56 	size_t n = 0;
57 
58 	while (n < outlen) {
59 		if (nextptr == writeptr)
60 			break;
61 		*outbuf++ = *nextptr++;
62 		if ((size_t)(nextptr - buf) >= bufsize)
63 			nextptr = buf;
64 		*readptr = nextptr;
65 		n++;
66 	}
67 
68 	if (debug && n > 0)
69 		fprintf(stderr, "%s buffer: read %zi bytes\n", name, n);
70 
71 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN);
72 }
73 
74 static ssize_t
75 circular_write(char *name, unsigned char *buf, size_t bufsize,
76     unsigned char *readptr, unsigned char **writeptr,
77     const unsigned char *inbuf, size_t inlen)
78 {
79 	unsigned char *nextptr = *writeptr;
80 	unsigned char *prevptr;
81 	size_t n = 0;
82 
83 	while (n < inlen) {
84 		prevptr = nextptr++;
85 		if ((size_t)(nextptr - buf) >= bufsize)
86 			nextptr = buf;
87 		if (nextptr == readptr)
88 			break;
89 		*prevptr = *inbuf++;
90 		*writeptr = nextptr;
91 		n++;
92 	}
93 
94 	if (debug && n > 0)
95 		fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n);
96 
97 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT);
98 }
99 
100 static ssize_t
101 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
102 {
103 	return circular_read("client", client_buffer, sizeof(client_buffer),
104 	    &client_readptr, client_writeptr, buf, buflen);
105 }
106 
107 static ssize_t
108 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
109 {
110 	return circular_write("server", server_buffer, sizeof(server_buffer),
111 	    server_readptr, &server_writeptr, buf, buflen);
112 }
113 
114 static ssize_t
115 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
116 {
117 	return circular_read("server", server_buffer, sizeof(server_buffer),
118 	    &server_readptr, server_writeptr, buf, buflen);
119 }
120 
121 static ssize_t
122 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
123 {
124 	return circular_write("client", client_buffer, sizeof(client_buffer),
125 	    client_readptr, &client_writeptr, buf, buflen);
126 }
127 
128 static int
129 do_tls_handshake(char *name, struct tls *ctx)
130 {
131 	int rv;
132 
133 	rv = tls_handshake(ctx);
134 	if (rv == 0)
135 		return (1);
136 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
137 		return (0);
138 
139 	errx(1, "%s handshake failed: %s", name, tls_error(ctx));
140 }
141 
142 static int
143 do_tls_close(char *name, struct tls *ctx)
144 {
145 	int rv;
146 
147 	rv = tls_close(ctx);
148 	if (rv == 0)
149 		return (1);
150 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
151 		return (0);
152 
153 	errx(1, "%s close failed: %s", name, tls_error(ctx));
154 }
155 
/*
 * Step the client and server handshakes alternately until both report
 * completion, for at most 101 rounds.  A side that has already finished
 * is not stepped again.
 *
 * Returns 0 on success, or 1 after printing a FAIL line naming desc.
 */
static int
do_client_server_handshake(char *desc, struct tls *client,
    struct tls *server_cctx)
{
	int round;
	int client_done = 0, server_done = 0;

	for (round = 0; round <= 100; round++) {
		if (!client_done)
			client_done = do_tls_handshake("client", client);
		if (!server_done)
			server_done = do_tls_handshake("server", server_cctx);
		if (client_done && server_done)
			break;
	}

	if (!client_done || !server_done) {
		printf("FAIL: %s TLS handshake did not complete\n", desc);
		return (1);
	}

	return (0);
}
176 }
177 
/*
 * Step the client and server closes alternately until both report
 * completion, for at most 101 rounds.  A side that has already finished
 * is not stepped again.
 *
 * Returns 0 on success, or 1 after printing a FAIL line naming desc.
 */
static int
do_client_server_close(char *desc, struct tls *client, struct tls *server_cctx)
{
	int rounds = 0;
	int client_done = 0, server_done = 0;

	for (;;) {
		if (!client_done)
			client_done = do_tls_close("client", client);
		if (!server_done)
			server_done = do_tls_close("server", server_cctx);
		if (client_done && server_done)
			break;
		if (rounds++ >= 100)
			break;
	}

	if (!client_done || !server_done) {
		printf("FAIL: %s TLS close did not complete\n", desc);
		return (1);
	}

	return (0);
}
197 }
198 
/*
 * Run a full connection between client and server_cctx: handshake, then
 * (eventually) data transfer, then close.  An INFO line is printed after
 * each stage that succeeds.
 *
 * Returns 0 if every stage succeeded and 1 otherwise.
 */
static int
do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx)
{
	int rv;

	rv = do_client_server_handshake(desc, client, server_cctx);
	if (rv == 0) {
		printf("INFO: %s TLS handshake completed successfully\n", desc);

		/* XXX - Do some reads and writes... */

		rv = do_client_server_close(desc, client, server_cctx);
		if (rv == 0)
			printf("INFO: %s TLS close completed successfully\n",
			    desc);
	}

	return (rv != 0);
}
215 }
216 
217 static int
218 test_tls_cbs(struct tls *client, struct tls *server)
219 {
220 	struct tls *server_cctx;
221 	int failure;
222 
223 	circular_init();
224 
225 	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
226 	    NULL) == -1)
227 		errx(1, "failed to accept: %s", tls_error(server));
228 
229 	if (tls_connect_cbs(client, client_read, client_write, NULL,
230 	    "test") == -1)
231 		errx(1, "failed to connect: %s", tls_error(client));
232 
233 	failure = do_client_server_test("callback", client, server_cctx);
234 
235 	tls_free(server_cctx);
236 
237 	return (failure);
238 }
239 
/*
 * Run the connection test over two non-blocking pipes, one per direction.
 * to_server carries client -> server traffic and to_client carries
 * server -> client traffic.
 *
 * Exits on setup failure; otherwise returns the result of
 * do_client_server_test().  All four pipe descriptors are closed before
 * returning.
 */
static int
test_tls_fds(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int to_client[2], to_server[2];
	int rv;

	if (pipe2(to_client, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");
	if (pipe2(to_server, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");

	/* Server: read end of to_server, write end of to_client. */
	if (tls_accept_fds(server, &server_cctx, to_server[0],
	    to_client[1]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));

	/* Client: read end of to_client, write end of to_server. */
	if (tls_connect_fds(client, to_client[0], to_server[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	rv = do_client_server_test("file descriptor", client, server_cctx);

	tls_free(server_cctx);

	close(to_client[0]);
	close(to_client[1]);
	close(to_server[0]);
	close(to_server[1]);

	return (rv);
}
268 }
269 
/*
 * Run the connection test over a connected pair of non-blocking UNIX
 * domain stream sockets.  The server uses sv[0] and the client uses
 * sv[1].
 *
 * Exits on setup failure; otherwise returns the result of
 * do_client_server_test().  Both sockets are closed before returning.
 */
static int
test_tls_socket(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int failure;
	int sv[2];

	/*
	 * The third argument is a protocol number, not a protocol family;
	 * 0 selects the default protocol for AF_UNIX/SOCK_STREAM.
	 */
	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == -1)
		err(1, "failed to create socketpair");

	if (tls_accept_socket(server, &server_cctx, sv[0]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));

	if (tls_connect_socket(client, sv[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	failure = do_client_server_test("socket", client, server_cctx);

	tls_free(server_cctx);

	close(sv[0]);
	close(sv[1]);

	return (failure);
}
295 }
296 
297 static int
298 test_tls(char *client_protocols, char *server_protocols, char *ciphers)
299 {
300 	struct tls_config *client_cfg, *server_cfg;
301 	struct tls *client, *server;
302 	uint32_t protocols;
303 	int failure = 0;
304 
305 	if ((client = tls_client()) == NULL)
306 		errx(1, "failed to create tls client");
307 	if ((client_cfg = tls_config_new()) == NULL)
308 		errx(1, "failed to create tls client config");
309 	tls_config_insecure_noverifyname(client_cfg);
310 	if (tls_config_parse_protocols(&protocols, client_protocols) == -1)
311 		errx(1, "failed to parse protocols: %s", tls_config_error(client_cfg));
312 	if (tls_config_set_protocols(client_cfg, protocols) == -1)
313 		errx(1, "failed to set protocols: %s", tls_config_error(client_cfg));
314 	if (tls_config_set_ciphers(client_cfg, ciphers) == -1)
315 		errx(1, "failed to set ciphers: %s", tls_config_error(client_cfg));
316 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
317 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
318 
319 	if ((server = tls_server()) == NULL)
320 		errx(1, "failed to create tls server");
321 	if ((server_cfg = tls_config_new()) == NULL)
322 		errx(1, "failed to create tls server config");
323 	if (tls_config_parse_protocols(&protocols, server_protocols) == -1)
324 		errx(1, "failed to parse protocols: %s", tls_config_error(server_cfg));
325 	if (tls_config_set_protocols(server_cfg, protocols) == -1)
326 		errx(1, "failed to set protocols: %s", tls_config_error(server_cfg));
327 	if (tls_config_set_ciphers(server_cfg, ciphers) == -1)
328 		errx(1, "failed to set ciphers: %s", tls_config_error(server_cfg));
329 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
330 		errx(1, "failed to set keypair: %s",
331 		    tls_config_error(server_cfg));
332 
333 	if (tls_configure(client, client_cfg) == -1)
334 		errx(1, "failed to configure client: %s", tls_error(client));
335 	tls_reset(server);
336 	if (tls_configure(server, server_cfg) == -1)
337 		errx(1, "failed to configure server: %s", tls_error(server));
338 
339 	tls_config_free(client_cfg);
340 	tls_config_free(server_cfg);
341 
342 	failure |= test_tls_cbs(client, server);
343 
344 	tls_free(client);
345 	tls_free(server);
346 
347 	return (failure);
348 }
349 
350 static int
351 do_tls_tests(void)
352 {
353 	struct tls_config *client_cfg, *server_cfg;
354 	struct tls *client, *server;
355 	int failure = 0;
356 
357 	printf("== TLS tests ==\n");
358 
359 	if ((client = tls_client()) == NULL)
360 		errx(1, "failed to create tls client");
361 	if ((client_cfg = tls_config_new()) == NULL)
362 		errx(1, "failed to create tls client config");
363 	tls_config_insecure_noverifyname(client_cfg);
364 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
365 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
366 
367 	if ((server = tls_server()) == NULL)
368 		errx(1, "failed to create tls server");
369 	if ((server_cfg = tls_config_new()) == NULL)
370 		errx(1, "failed to create tls server config");
371 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
372 		errx(1, "failed to set keypair: %s",
373 		    tls_config_error(server_cfg));
374 
375 	tls_reset(client);
376 	if (tls_configure(client, client_cfg) == -1)
377 		errx(1, "failed to configure client: %s", tls_error(client));
378 	tls_reset(server);
379 	if (tls_configure(server, server_cfg) == -1)
380 		errx(1, "failed to configure server: %s", tls_error(server));
381 
382 	failure |= test_tls_cbs(client, server);
383 
384 	tls_reset(client);
385 	if (tls_configure(client, client_cfg) == -1)
386 		errx(1, "failed to configure client: %s", tls_error(client));
387 	tls_reset(server);
388 	if (tls_configure(server, server_cfg) == -1)
389 		errx(1, "failed to configure server: %s", tls_error(server));
390 
391 	failure |= test_tls_fds(client, server);
392 
393 	tls_reset(client);
394 	if (tls_configure(client, client_cfg) == -1)
395 		errx(1, "failed to configure client: %s", tls_error(client));
396 	tls_reset(server);
397 	if (tls_configure(server, server_cfg) == -1)
398 		errx(1, "failed to configure server: %s", tls_error(server));
399 
400 	tls_config_free(client_cfg);
401 	tls_config_free(server_cfg);
402 
403 	failure |= test_tls_socket(client, server);
404 
405 	tls_free(client);
406 	tls_free(server);
407 
408 	printf("\n");
409 
410 	return (failure);
411 }
412 
413 static int
414 do_tls_ordering_tests(void)
415 {
416 	struct tls *client = NULL, *server = NULL, *server_cctx = NULL;
417 	struct tls_config *client_cfg, *server_cfg;
418 	int failure = 0;
419 
420 	printf("== TLS ordering tests ==\n");
421 
422 	if ((client = tls_client()) == NULL)
423 		errx(1, "failed to create tls client");
424 	if ((client_cfg = tls_config_new()) == NULL)
425 		errx(1, "failed to create tls client config");
426 	tls_config_insecure_noverifyname(client_cfg);
427 	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
428 		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
429 
430 	if ((server = tls_server()) == NULL)
431 		errx(1, "failed to create tls server");
432 	if ((server_cfg = tls_config_new()) == NULL)
433 		errx(1, "failed to create tls server config");
434 	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
435 		errx(1, "failed to set keypair: %s",
436 		    tls_config_error(server_cfg));
437 
438 	if (tls_configure(client, client_cfg) == -1)
439 		errx(1, "failed to configure client: %s", tls_error(client));
440 	if (tls_configure(server, server_cfg) == -1)
441 		errx(1, "failed to configure server: %s", tls_error(server));
442 
443 	tls_config_free(client_cfg);
444 	tls_config_free(server_cfg);
445 
446 	if (tls_handshake(client) != -1) {
447 		printf("FAIL: TLS handshake succeeded on unconnnected "
448 		    "client context\n");
449 		failure = 1;
450 		goto done;
451 	}
452 
453 	circular_init();
454 
455 	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
456 	    NULL) == -1)
457 		errx(1, "failed to accept: %s", tls_error(server));
458 
459 	if (tls_connect_cbs(client, client_read, client_write, NULL,
460 	    "test") == -1)
461 		errx(1, "failed to connect: %s", tls_error(client));
462 
463 	if (do_client_server_handshake("ordering", client, server_cctx) != 0) {
464 		failure = 1;
465 		goto done;
466 	}
467 
468 	if (tls_handshake(client) != -1) {
469 		printf("FAIL: TLS handshake succeeded twice\n");
470 		failure = 1;
471 		goto done;
472 	}
473 
474 	if (tls_handshake(server_cctx) != -1) {
475 		printf("FAIL: TLS handshake succeeded twice\n");
476 		failure = 1;
477 		goto done;
478 	}
479 
480 	if (do_client_server_close("ordering", client, server_cctx) != 0) {
481 		failure = 1;
482 		goto done;
483 	}
484 
485  done:
486 	tls_free(client);
487 	tls_free(server);
488 	tls_free(server_cctx);
489 
490 	printf("\n");
491 
492 	return (failure);
493 }
494 
/*
 * A client/server pair of protocol strings, in the format parsed by
 * tls_config_parse_protocols(), for one run of test_tls().
 */
struct test_versions {
	char *client;	/* protocols enabled on the client */
	char *server;	/* protocols enabled on the server */
};

static struct test_versions tls_test_versions[] = {
	/* Client pinned to one version, server accepts all. */
	{ .client = "tlsv1.3", .server = "all" },
	{ .client = "tlsv1.2", .server = "all" },
	{ .client = "tlsv1.1", .server = "all" },
	{ .client = "tlsv1.0", .server = "all" },
	/* Client accepts all, server pinned to one version. */
	{ .client = "all", .server = "tlsv1.3" },
	{ .client = "all", .server = "tlsv1.2" },
	{ .client = "all", .server = "tlsv1.1" },
	{ .client = "all", .server = "tlsv1.0" },
	/* Both sides pinned to the same version. */
	{ .client = "tlsv1.3", .server = "tlsv1.3" },
	{ .client = "tlsv1.2", .server = "tlsv1.2" },
	{ .client = "tlsv1.1", .server = "tlsv1.1" },
	{ .client = "tlsv1.0", .server = "tlsv1.0" },
};

/* Number of entries in tls_test_versions. */
#define N_TLS_VERSION_TESTS \
    (sizeof(tls_test_versions) / sizeof(tls_test_versions[0]))
517 
518 static int
519 do_tls_version_tests(void)
520 {
521 	struct test_versions *tv;
522 	int failure = 0;
523 	size_t i;
524 
525 	printf("== TLS version tests ==\n");
526 
527 	for (i = 0; i < N_TLS_VERSION_TESTS; i++) {
528 		tv = &tls_test_versions[i];
529 		printf("INFO: version test %zu - client versions '%s' "
530 		    "and server versions '%s'\n", i, tv->client, tv->server);
531 		failure |= test_tls(tv->client, tv->server, "legacy");
532 		printf("\n");
533 	}
534 
535 	return failure;
536 }
537 
538 int
539 main(int argc, char **argv)
540 {
541 	int failure = 0;
542 
543 	if (argc != 4) {
544 		fprintf(stderr, "usage: %s cafile certfile keyfile\n",
545 		    argv[0]);
546 		return (1);
547 	}
548 
549 	cafile = argv[1];
550 	certfile = argv[2];
551 	keyfile = argv[3];
552 
553 	failure |= do_tls_tests();
554 	failure |= do_tls_ordering_tests();
555 	failure |= do_tls_version_tests();
556 
557 	return (failure);
558 }
559