1 /* $OpenBSD: tlstest.c,v 1.9 2017/05/07 03:25:26 jsing Exp $ */ 2 /* 3 * Copyright (c) 2017 Joel Sing <jsing@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <sys/socket.h> 19 20 #include <err.h> 21 #include <fcntl.h> 22 #include <stdio.h> 23 #include <string.h> 24 #include <unistd.h> 25 26 #include <tls.h> 27 28 #define CIRCULAR_BUFFER_SIZE 512 29 30 unsigned char client_buffer[CIRCULAR_BUFFER_SIZE]; 31 unsigned char *client_readptr, *client_writeptr; 32 33 unsigned char server_buffer[CIRCULAR_BUFFER_SIZE]; 34 unsigned char *server_readptr, *server_writeptr; 35 36 char *cafile, *certfile, *keyfile; 37 38 int debug = 0; 39 40 static void 41 circular_init(void) 42 { 43 client_readptr = client_writeptr = client_buffer; 44 server_readptr = server_writeptr = server_buffer; 45 } 46 47 static ssize_t 48 circular_read(char *name, unsigned char *buf, size_t bufsize, 49 unsigned char **readptr, unsigned char *writeptr, 50 unsigned char *outbuf, size_t outlen) 51 { 52 unsigned char *nextptr = *readptr; 53 size_t n = 0; 54 55 while (n < outlen) { 56 if (nextptr == writeptr) 57 break; 58 *outbuf++ = *nextptr++; 59 if ((size_t)(nextptr - buf) >= bufsize) 60 nextptr = buf; 61 *readptr = nextptr; 62 n++; 63 } 64 65 if (debug && n > 0) 66 fprintf(stderr, "%s buffer: read %zi bytes\n", name, n); 67 68 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN); 69 } 70 71 static ssize_t 72 circular_write(char *name, unsigned char *buf, size_t bufsize, 73 unsigned char *readptr, unsigned char **writeptr, 74 const unsigned char *inbuf, size_t inlen) 75 { 76 unsigned char *nextptr = *writeptr; 77 unsigned char *prevptr; 78 size_t n = 0; 79 80 while (n < inlen) { 81 prevptr = nextptr++; 82 if ((size_t)(nextptr - buf) >= bufsize) 83 nextptr = buf; 84 if (nextptr == readptr) 85 break; 86 *prevptr = *inbuf++; 87 *writeptr = nextptr; 88 n++; 89 } 90 91 if (debug && n > 0) 92 fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n); 93 94 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT); 95 } 96 97 static ssize_t 98 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg) 99 { 100 return circular_read("client", client_buffer, sizeof(client_buffer), 101 &client_readptr, client_writeptr, buf, buflen); 102 } 103 104 static ssize_t 105 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg) 106 { 107 return circular_write("server", server_buffer, sizeof(server_buffer), 108 server_readptr, &server_writeptr, buf, buflen); 109 } 110 111 static ssize_t 112 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg) 113 { 114 return circular_read("server", server_buffer, sizeof(server_buffer), 115 &server_readptr, server_writeptr, buf, buflen); 116 } 117 118 static ssize_t 119 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg) 120 { 121 return circular_write("client", client_buffer, sizeof(client_buffer), 122 client_readptr, &client_writeptr, buf, buflen); 123 } 124 125 static int 126 do_tls_handshake(char *name, struct tls *ctx) 127 { 128 int rv; 129 130 rv = tls_handshake(ctx); 131 if (rv == 0) 132 return (1); 133 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT) 134 return (0); 135 136 errx(1, "%s handshake failed: %s", name, tls_error(ctx)); 137 } 138 139 static int 140 do_tls_close(char *name, struct tls *ctx) 141 { 142 int rv; 143 144 rv = tls_close(ctx); 145 if (rv == 0) 146 return (1); 147 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT) 148 return (0); 149 150 errx(1, "%s close failed: %s", name, tls_error(ctx)); 151 } 152 153 static int 154 do_client_server_handshake(char *desc, struct tls *client, 155 struct tls *server_cctx) 156 { 157 int i, client_done, server_done; 158 159 i = client_done = server_done = 0; 160 do { 161 if (client_done == 0) 162 client_done = do_tls_handshake("client", client); 163 if (server_done == 0) 164 server_done = do_tls_handshake("server", server_cctx); 165 } while (i++ < 100 && (client_done == 0 || server_done == 0)); 166 167 if (client_done == 0 || server_done == 0) { 168 printf("FAIL: %s TLS handshake did not complete\n", desc); 169 return (1); 170 } 171 172 return (0); 173 } 174 175 static int 176 do_client_server_close(char *desc, struct tls *client, struct tls *server_cctx) 177 { 178 int i, client_done, server_done; 179 180 i = client_done = server_done = 0; 181 do { 182 if (client_done == 0) 183 client_done = do_tls_close("client", client); 184 if (server_done == 0) 185 server_done = do_tls_close("server", server_cctx); 186 } while (i++ < 100 && (client_done == 0 || server_done == 0)); 187 188 if (client_done == 0 || server_done == 0) { 189 printf("FAIL: %s TLS close did not complete\n", desc); 190 return (1); 191 } 192 193 return (0); 194 } 195 196 static int 197 do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx) 198 { 199 if (do_client_server_handshake(desc, client, server_cctx) != 0) 200 return (1); 201 202 printf("INFO: %s TLS handshake completed successfully\n", desc); 203 204 /* XXX - Do some reads and writes... */ 205 206 if (do_client_server_close(desc, client, server_cctx) != 0) 207 return (1); 208 209 printf("INFO: %s TLS close completed successfully\n", desc); 210 211 return (0); 212 } 213 214 static int 215 test_tls_cbs(struct tls *client, struct tls *server) 216 { 217 struct tls *server_cctx; 218 int failure; 219 220 circular_init(); 221 222 if (tls_accept_cbs(server, &server_cctx, server_read, server_write, 223 NULL) == -1) 224 errx(1, "failed to accept: %s", tls_error(server)); 225 226 if (tls_connect_cbs(client, client_read, client_write, NULL, 227 "test") == -1) 228 errx(1, "failed to connect: %s", tls_error(client)); 229 230 failure = do_client_server_test("callback", client, server_cctx); 231 232 tls_free(server_cctx); 233 234 return (failure); 235 } 236 237 static int 238 test_tls_fds(struct tls *client, struct tls *server) 239 { 240 struct tls *server_cctx; 241 int cfds[2], sfds[2]; 242 int failure; 243 244 if (pipe2(cfds, O_NONBLOCK) == -1) 245 err(1, "failed to create pipe"); 246 if (pipe2(sfds, O_NONBLOCK) == -1) 247 err(1, "failed to create pipe"); 248 249 if (tls_accept_fds(server, &server_cctx, sfds[0], cfds[1]) == -1) 250 errx(1, "failed to accept: %s", tls_error(server)); 251 252 if (tls_connect_fds(client, cfds[0], sfds[1], "test") == -1) 253 errx(1, "failed to connect: %s", tls_error(client)); 254 255 failure = do_client_server_test("file descriptor", client, server_cctx); 256 257 tls_free(server_cctx); 258 259 close(cfds[0]); 260 close(cfds[1]); 261 close(sfds[0]); 262 close(sfds[1]); 263 264 return (failure); 265 } 266 267 static int 268 test_tls_socket(struct tls *client, struct tls *server) 269 { 270 struct tls *server_cctx; 271 int failure; 272 int sv[2]; 273 274 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC, 275 sv) == -1) 276 err(1, "failed to create socketpair"); 277 278 if (tls_accept_socket(server, &server_cctx, sv[0]) == -1) 279 errx(1, "failed to accept: %s", tls_error(server)); 280 281 if (tls_connect_socket(client, sv[1], "test") == -1) 282 errx(1, "failed to connect: %s", tls_error(client)); 283 284 failure = do_client_server_test("socket", client, server_cctx); 285 286 tls_free(server_cctx); 287 288 close(sv[0]); 289 close(sv[1]); 290 291 return (failure); 292 } 293 294 static int 295 do_tls_tests(void) 296 { 297 struct tls_config *client_cfg, *server_cfg; 298 struct tls *client, *server; 299 int failure = 0; 300 301 if ((client = tls_client()) == NULL) 302 errx(1, "failed to create tls client"); 303 if ((client_cfg = tls_config_new()) == NULL) 304 errx(1, "failed to create tls client config"); 305 tls_config_insecure_noverifyname(client_cfg); 306 if (tls_config_set_ca_file(client_cfg, cafile) == -1) 307 errx(1, "failed to set ca: %s", tls_config_error(client_cfg)); 308 309 if ((server = tls_server()) == NULL) 310 errx(1, "failed to create tls server"); 311 if ((server_cfg = tls_config_new()) == NULL) 312 errx(1, "failed to create tls server config"); 313 if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1) 314 errx(1, "failed to set keypair: %s", 315 tls_config_error(server_cfg)); 316 317 tls_reset(client); 318 if (tls_configure(client, client_cfg) == -1) 319 errx(1, "failed to configure client: %s", tls_error(client)); 320 tls_reset(server); 321 if (tls_configure(server, server_cfg) == -1) 322 errx(1, "failed to configure server: %s", tls_error(server)); 323 324 failure |= test_tls_cbs(client, server); 325 326 tls_reset(client); 327 if (tls_configure(client, client_cfg) == -1) 328 errx(1, "failed to configure client: %s", tls_error(client)); 329 tls_reset(server); 330 if (tls_configure(server, server_cfg) == -1) 331 errx(1, "failed to configure server: %s", tls_error(server)); 332 333 failure |= test_tls_fds(client, server); 334 335 tls_reset(client); 336 if (tls_configure(client, client_cfg) == -1) 337 errx(1, "failed to configure client: %s", tls_error(client)); 338 tls_reset(server); 339 if (tls_configure(server, server_cfg) == -1) 340 errx(1, "failed to configure server: %s", tls_error(server)); 341 342 tls_config_free(client_cfg); 343 tls_config_free(server_cfg); 344 345 failure |= test_tls_socket(client, server); 346 347 tls_free(client); 348 tls_free(server); 349 350 return (failure); 351 } 352 353 static int 354 do_tls_ordering_tests(void) 355 { 356 struct tls *client = NULL, *server = NULL, *server_cctx = NULL; 357 struct tls_config *client_cfg, *server_cfg; 358 int failure = 0; 359 360 circular_init(); 361 362 if ((client = tls_client()) == NULL) 363 errx(1, "failed to create tls client"); 364 if ((client_cfg = tls_config_new()) == NULL) 365 errx(1, "failed to create tls client config"); 366 tls_config_insecure_noverifyname(client_cfg); 367 if (tls_config_set_ca_file(client_cfg, cafile) == -1) 368 errx(1, "failed to set ca: %s", tls_config_error(client_cfg)); 369 370 if ((server = tls_server()) == NULL) 371 errx(1, "failed to create tls server"); 372 if ((server_cfg = tls_config_new()) == NULL) 373 errx(1, "failed to create tls server config"); 374 if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1) 375 errx(1, "failed to set keypair: %s", 376 tls_config_error(server_cfg)); 377 378 if (tls_configure(client, client_cfg) == -1) 379 errx(1, "failed to configure client: %s", tls_error(client)); 380 if (tls_configure(server, server_cfg) == -1) 381 errx(1, "failed to configure server: %s", tls_error(server)); 382 383 tls_config_free(client_cfg); 384 tls_config_free(server_cfg); 385 386 if (tls_handshake(client) != -1) { 387 printf("FAIL: TLS handshake succeeded on unconnnected " 388 "client context\n"); 389 failure = 1; 390 goto done; 391 } 392 393 if (tls_accept_cbs(server, &server_cctx, server_read, server_write, 394 NULL) == -1) 395 errx(1, "failed to accept: %s", tls_error(server)); 396 397 if (tls_connect_cbs(client, client_read, client_write, NULL, 398 "test") == -1) 399 errx(1, "failed to connect: %s", tls_error(client)); 400 401 if (do_client_server_handshake("ordering", client, server_cctx) != 0) { 402 failure = 1; 403 goto done; 404 } 405 406 if (tls_handshake(client) != -1) { 407 printf("FAIL: TLS handshake succeeded twice\n"); 408 failure = 1; 409 goto done; 410 } 411 412 if (tls_handshake(server_cctx) != -1) { 413 printf("FAIL: TLS handshake succeeded twice\n"); 414 failure = 1; 415 goto done; 416 } 417 418 if (do_client_server_close("ordering", client, server_cctx) != 0) { 419 failure = 1; 420 goto done; 421 } 422 423 done: 424 tls_free(client); 425 tls_free(server); 426 tls_free(server_cctx); 427 428 return (failure); 429 } 430 431 int 432 main(int argc, char **argv) 433 { 434 int failure = 0; 435 436 if (argc != 4) { 437 fprintf(stderr, "usage: %s cafile certfile keyfile\n", 438 argv[0]); 439 return (1); 440 } 441 442 cafile = argv[1]; 443 certfile = argv[2]; 444 keyfile = argv[3]; 445 446 if (tls_init() == -1) 447 errx(1, "failed to initialise tls"); 448 449 failure |= do_tls_tests(); 450 failure |= do_tls_ordering_tests(); 451 452 return (failure); 453 } 454