1 /* $OpenBSD: dtlstest.c,v 1.4 2020/10/16 17:57:20 tb Exp $ */ 2 /* 3 * Copyright (c) 2020 Joel Sing <jsing@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <netinet/in.h> 19 #include <sys/limits.h> 20 #include <sys/socket.h> 21 22 #include <err.h> 23 #include <poll.h> 24 #include <unistd.h> 25 26 #include <openssl/bio.h> 27 #include <openssl/err.h> 28 #include <openssl/ssl.h> 29 30 const char *server_ca_file; 31 const char *server_cert_file; 32 const char *server_key_file; 33 34 char dtls_cookie[32]; 35 36 int debug = 0; 37 38 static void 39 hexdump(const unsigned char *buf, size_t len) 40 { 41 size_t i; 42 43 for (i = 1; i <= len; i++) 44 fprintf(stderr, " 0x%02hhx,%s", buf[i - 1], i % 8 ? "" : "\n"); 45 46 if (len % 8) 47 fprintf(stderr, "\n"); 48 } 49 50 #define BIO_C_DROP_PACKET 1000 51 #define BIO_C_DROP_RANDOM 1001 52 53 struct bio_packet_monkey_ctx { 54 unsigned int drop_rand; 55 unsigned int drop_mask; 56 }; 57 58 static int 59 bio_packet_monkey_new(BIO *bio) 60 { 61 struct bio_packet_monkey_ctx *ctx; 62 63 if ((ctx = calloc(1, sizeof(*ctx))) == NULL) 64 return 0; 65 66 bio->flags = 0; 67 bio->init = 1; 68 bio->num = 0; 69 bio->ptr = ctx; 70 71 return 1; 72 } 73 74 static int 75 bio_packet_monkey_free(BIO *bio) 76 { 77 struct bio_packet_monkey_ctx *ctx; 78 79 if (bio == NULL) 80 return 1; 81 82 ctx = bio->ptr; 83 free(ctx); 84 85 return 1; 86 } 87 88 static long 89 bio_packet_monkey_ctrl(BIO *bio, int cmd, long num, void *ptr) 90 { 91 struct bio_packet_monkey_ctx *ctx; 92 93 ctx = bio->ptr; 94 95 switch (cmd) { 96 case BIO_C_DROP_PACKET: 97 if (num < 1 || num > 31) 98 return 0; 99 ctx->drop_mask |= 1 << ((unsigned int)num - 1); 100 return 1; 101 102 case BIO_C_DROP_RANDOM: 103 if (num < 0 || (size_t)num > UINT_MAX) 104 return 0; 105 ctx->drop_rand = (unsigned int)num; 106 return 1; 107 } 108 109 if (bio->next_bio == NULL) 110 return 0; 111 112 return BIO_ctrl(bio->next_bio, cmd, num, ptr); 113 } 114 115 static int 116 bio_packet_monkey_read(BIO *bio, char *out, int out_len) 117 { 118 struct bio_packet_monkey_ctx *ctx = bio->ptr; 119 int ret; 120 121 if (ctx == NULL || bio->next_bio == NULL) 122 return 0; 123 124 ret = BIO_read(bio->next_bio, out, out_len); 125 126 BIO_clear_retry_flags(bio); 127 if (ret <= 0 && BIO_should_retry(bio->next_bio)) 128 BIO_set_retry_read(bio); 129 130 return ret; 131 } 132 133 static int 134 bio_packet_monkey_write(BIO *bio, const char *in, int in_len) 135 { 136 struct bio_packet_monkey_ctx *ctx = bio->ptr; 137 int drop = 0; 138 int ret; 139 140 if (ctx == NULL || bio->next_bio == NULL) 141 return 0; 142 143 if (ctx->drop_rand > 0) { 144 drop = arc4random_uniform(ctx->drop_rand) == 0; 145 } else if (ctx->drop_mask > 0) { 146 drop = ctx->drop_mask & 1; 147 ctx->drop_mask >>= 1; 148 } 149 if (debug) { 150 fprintf(stderr, "DEBUG: %s packet...\n", 151 drop ? "dropping" : "writing"); 152 hexdump(in, in_len); 153 } 154 if (drop) 155 return in_len; 156 157 ret = BIO_write(bio->next_bio, in, in_len); 158 159 BIO_clear_retry_flags(bio); 160 if (ret <= 0 && BIO_should_retry(bio->next_bio)) 161 BIO_set_retry_write(bio); 162 163 return ret; 164 } 165 166 static int 167 bio_packet_monkey_puts(BIO *bio, const char *str) 168 { 169 return bio_packet_monkey_write(bio, str, strlen(str)); 170 } 171 172 static const BIO_METHOD bio_packet_monkey = { 173 .type = BIO_TYPE_BUFFER, 174 .name = "packet monkey", 175 .bread = bio_packet_monkey_read, 176 .bwrite = bio_packet_monkey_write, 177 .bputs = bio_packet_monkey_puts, 178 .ctrl = bio_packet_monkey_ctrl, 179 .create = bio_packet_monkey_new, 180 .destroy = bio_packet_monkey_free 181 }; 182 183 static const BIO_METHOD * 184 BIO_f_packet_monkey(void) 185 { 186 return &bio_packet_monkey; 187 } 188 189 static BIO * 190 BIO_new_packet_monkey(void) 191 { 192 return BIO_new(BIO_f_packet_monkey()); 193 } 194 195 static int 196 BIO_packet_monkey_drop(BIO *bio, int num) 197 { 198 return BIO_ctrl(bio, BIO_C_DROP_PACKET, num, NULL); 199 } 200 201 #if 0 202 static int 203 BIO_packet_monkey_drop_random(BIO *bio, int num) 204 { 205 return BIO_ctrl(bio, BIO_C_DROP_RANDOM, num, NULL); 206 } 207 #endif 208 209 static int 210 datagram_pair(int *client_sock, int *server_sock, 211 struct sockaddr_in *server_sin) 212 { 213 struct sockaddr_in sin; 214 socklen_t sock_len; 215 int cs = -1, ss = -1; 216 217 memset(&sin, 0, sizeof(sin)); 218 sin.sin_family = AF_INET; 219 sin.sin_port = 0; 220 sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 221 222 if ((ss = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) 223 err(1, "server socket"); 224 if (bind(ss, (struct sockaddr *)&sin, sizeof(sin)) == -1) 225 err(1, "server bind"); 226 sock_len = sizeof(sin); 227 if (getsockname(ss, (struct sockaddr *)&sin, &sock_len) == -1) 228 err(1, "server getsockname"); 229 230 if ((cs = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) 231 err(1, "client socket"); 232 if (connect(cs, (struct sockaddr *)&sin, sizeof(sin)) == -1) 233 err(1, "client connect"); 234 235 *client_sock = cs; 236 *server_sock = ss; 237 memcpy(server_sin, &sin, sizeof(sin)); 238 239 return 1; 240 } 241 242 static int 243 poll_timeout(SSL *client, SSL *server) 244 { 245 int client_timeout = 0, server_timeout = 0; 246 struct timeval timeout; 247 248 if (DTLSv1_get_timeout(client, &timeout)) 249 client_timeout = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; 250 251 if (DTLSv1_get_timeout(server, &timeout)) 252 server_timeout = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; 253 254 if (client_timeout <= 0) 255 return server_timeout; 256 if (client_timeout > 0 && server_timeout <= 0) 257 return client_timeout; 258 if (client_timeout < server_timeout) 259 return client_timeout; 260 261 return server_timeout; 262 } 263 264 static int 265 dtls_cookie_generate(SSL *ssl, unsigned char *cookie, 266 unsigned int *cookie_len) 267 { 268 arc4random_buf(dtls_cookie, sizeof(dtls_cookie)); 269 memcpy(cookie, dtls_cookie, sizeof(dtls_cookie)); 270 *cookie_len = sizeof(dtls_cookie); 271 272 return 1; 273 } 274 275 static int 276 dtls_cookie_verify(SSL *ssl, const unsigned char *cookie, 277 unsigned int cookie_len) 278 { 279 return cookie_len == sizeof(dtls_cookie) && 280 memcmp(cookie, dtls_cookie, sizeof(dtls_cookie)) == 0; 281 } 282 283 static void 284 dtls_info_callback(const SSL *ssl, int type, int val) 285 { 286 /* 287 * Squeal's ahead... remove the bbio from the info callback, so we can 288 * drop specific messages. Ideally this would be an option for the SSL. 289 */ 290 if (ssl->wbio == ssl->bbio) 291 ((SSL *)ssl)->wbio = BIO_pop(ssl->wbio); 292 } 293 294 static SSL * 295 dtls_client(int sock, struct sockaddr_in *server_sin, long mtu) 296 { 297 SSL_CTX *ssl_ctx = NULL; 298 SSL *ssl = NULL; 299 BIO *bio = NULL; 300 301 if ((bio = BIO_new_dgram(sock, BIO_NOCLOSE)) == NULL) 302 errx(1, "client bio"); 303 if (!BIO_socket_nbio(sock, 1)) 304 errx(1, "client nbio"); 305 if (!BIO_ctrl_set_connected(bio, 1, server_sin)) 306 errx(1, "client set connected"); 307 308 if ((ssl_ctx = SSL_CTX_new(DTLS_method())) == NULL) 309 errx(1, "client context"); 310 SSL_CTX_set_read_ahead(ssl_ctx, 1); 311 312 if ((ssl = SSL_new(ssl_ctx)) == NULL) 313 errx(1, "client ssl"); 314 315 SSL_set_bio(ssl, bio, bio); 316 bio = NULL; 317 318 if (mtu > 0) { 319 SSL_set_options(ssl, SSL_OP_NO_QUERY_MTU); 320 SSL_set_mtu(ssl, mtu); 321 } 322 323 SSL_CTX_free(ssl_ctx); 324 BIO_free(bio); 325 326 return ssl; 327 } 328 329 static SSL * 330 dtls_server(int sock, long options, long mtu) 331 { 332 SSL_CTX *ssl_ctx = NULL; 333 SSL *ssl = NULL; 334 BIO *bio = NULL; 335 336 if ((bio = BIO_new_dgram(sock, BIO_NOCLOSE)) == NULL) 337 errx(1, "server bio"); 338 if (!BIO_socket_nbio(sock, 1)) 339 errx(1, "server nbio"); 340 341 if ((ssl_ctx = SSL_CTX_new(DTLS_method())) == NULL) 342 errx(1, "server context"); 343 344 SSL_CTX_set_cookie_generate_cb(ssl_ctx, dtls_cookie_generate); 345 SSL_CTX_set_cookie_verify_cb(ssl_ctx, dtls_cookie_verify); 346 SSL_CTX_set_options(ssl_ctx, options); 347 SSL_CTX_set_read_ahead(ssl_ctx, 1); 348 349 if (SSL_CTX_use_certificate_file(ssl_ctx, server_cert_file, 350 SSL_FILETYPE_PEM) != 1) { 351 fprintf(stderr, "FAIL: Failed to load server certificate"); 352 goto failure; 353 } 354 if (SSL_CTX_use_PrivateKey_file(ssl_ctx, server_key_file, 355 SSL_FILETYPE_PEM) != 1) { 356 fprintf(stderr, "FAIL: Failed to load server private key"); 357 goto failure; 358 } 359 360 if ((ssl = SSL_new(ssl_ctx)) == NULL) 361 errx(1, "server ssl"); 362 363 SSL_set_bio(ssl, bio, bio); 364 bio = NULL; 365 366 if (mtu > 0) { 367 SSL_set_options(ssl, SSL_OP_NO_QUERY_MTU); 368 SSL_set_mtu(ssl, mtu); 369 } 370 371 failure: 372 SSL_CTX_free(ssl_ctx); 373 BIO_free(bio); 374 375 return ssl; 376 } 377 378 static int 379 ssl_error(SSL *ssl, const char *name, const char *desc, int ssl_ret, 380 short *events) 381 { 382 int ssl_err; 383 384 ssl_err = SSL_get_error(ssl, ssl_ret); 385 386 if (ssl_err == SSL_ERROR_WANT_READ) { 387 *events = POLLIN; 388 } else if (ssl_err == SSL_ERROR_WANT_WRITE) { 389 *events = POLLOUT; 390 } else if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) { 391 /* Yup, this is apparently a thing... */ 392 } else { 393 fprintf(stderr, "FAIL: %s %s failed - ssl err = %d, errno = %d\n", 394 name, desc, ssl_err, errno); 395 ERR_print_errors_fp(stderr); 396 return 0; 397 } 398 399 return 1; 400 } 401 402 static int 403 do_connect(SSL *ssl, const char *name, int *done, short *events) 404 { 405 int ssl_ret; 406 407 if ((ssl_ret = SSL_connect(ssl)) == 1) { 408 fprintf(stderr, "INFO: %s connect done\n", name); 409 *done = 1; 410 return 1; 411 } 412 413 return ssl_error(ssl, name, "connect", ssl_ret, events); 414 } 415 416 static int 417 do_accept(SSL *ssl, const char *name, int *done, short *events) 418 { 419 int ssl_ret; 420 421 if ((ssl_ret = SSL_accept(ssl)) == 1) { 422 fprintf(stderr, "INFO: %s accept done\n", name); 423 *done = 1; 424 return 1; 425 } 426 427 return ssl_error(ssl, name, "accept", ssl_ret, events); 428 } 429 430 static int 431 do_shutdown(SSL *ssl, const char *name, int *done, short *events) 432 { 433 int ssl_ret; 434 435 ssl_ret = SSL_shutdown(ssl); 436 if (ssl_ret == 1) { 437 fprintf(stderr, "INFO: %s shutdown done\n", name); 438 *done = 1; 439 return 1; 440 } 441 return ssl_error(ssl, name, "shutdown", ssl_ret, events); 442 } 443 444 typedef int (*ssl_func)(SSL *ssl, const char *name, int *done, short *events); 445 446 static int 447 do_client_server_loop(SSL *client, ssl_func client_func, SSL *server, 448 ssl_func server_func, struct pollfd pfd[2]) 449 { 450 int client_done = 0, server_done = 0; 451 int i = 0; 452 453 pfd[0].revents = POLLIN; 454 pfd[1].revents = POLLIN; 455 456 do { 457 if (!client_done) { 458 if (debug) 459 fprintf(stderr, "DEBUG: client loop\n"); 460 if (DTLSv1_handle_timeout(client) > 0) 461 fprintf(stderr, "INFO: client timeout\n"); 462 if (!client_func(client, "client", &client_done, 463 &pfd[0].events)) 464 return 0; 465 if (client_done) 466 pfd[0].events = 0; 467 } 468 if (!server_done) { 469 if (debug) 470 fprintf(stderr, "DEBUG: server loop\n"); 471 if (DTLSv1_handle_timeout(server) > 0) 472 fprintf(stderr, "INFO: server timeout\n"); 473 if (!server_func(server, "server", &server_done, 474 &pfd[1].events)) 475 return 0; 476 if (server_done) 477 pfd[1].events = 0; 478 } 479 if (poll(pfd, 2, poll_timeout(client, server)) == -1) 480 err(1, "poll"); 481 482 } while (i++ < 100 && (!client_done || !server_done)); 483 484 if (!client_done || !server_done) 485 fprintf(stderr, "FAIL: gave up\n"); 486 487 return client_done && server_done; 488 } 489 490 #define MAX_PACKET_DROPS 32 491 492 struct dtls_test { 493 const unsigned char *desc; 494 long mtu; 495 long ssl_options; 496 int client_bbio_off; 497 int server_bbio_off; 498 uint8_t client_drops[MAX_PACKET_DROPS]; 499 uint8_t server_drops[MAX_PACKET_DROPS]; 500 }; 501 502 static const struct dtls_test dtls_tests[] = { 503 { 504 .desc = "DTLS without cookies", 505 .ssl_options = 0, 506 }, 507 { 508 .desc = "DTLS with cookies", 509 .ssl_options = SSL_OP_COOKIE_EXCHANGE, 510 }, 511 { 512 .desc = "DTLS with low MTU", 513 .mtu = 256, 514 .ssl_options = 0, 515 }, 516 { 517 .desc = "DTLS with low MTU and cookies", 518 .mtu = 256, 519 .ssl_options = SSL_OP_COOKIE_EXCHANGE, 520 }, 521 { 522 .desc = "DTLS with dropped server response", 523 .ssl_options = 0, 524 .server_drops = { 1 }, 525 }, 526 { 527 .desc = "DTLS with two dropped server responses", 528 .ssl_options = 0, 529 .server_drops = { 1, 2 }, 530 }, 531 { 532 .desc = "DTLS with dropped ServerHello", 533 .ssl_options = 0, 534 .server_bbio_off = 1, 535 .server_drops = { 1 }, 536 }, 537 { 538 .desc = "DTLS with dropped server Certificate", 539 .ssl_options = 0, 540 .server_bbio_off = 1, 541 .server_drops = { 2 }, 542 }, 543 { 544 .desc = "DTLS with dropped ServerKeyExchange", 545 .ssl_options = 0, 546 .server_bbio_off = 1, 547 .server_drops = { 3 }, 548 }, 549 #if 0 550 /* 551 * These three currently result in the server accept completing and the 552 * client looping on a timeout. Presumably the server should not 553 * complete until the client Finished is received... 554 */ 555 { 556 .desc = "DTLS with dropped ServerHelloDone", 557 .ssl_options = 0, 558 .server_bbio_off = 1, 559 .server_drops = { 4 }, 560 }, 561 { 562 .desc = "DTLS with dropped server CCS", 563 .ssl_options = 0, 564 .server_bbio_off = 1, 565 .server_drops = { 5 }, 566 }, 567 { 568 .desc = "DTLS with dropped server Finished", 569 .ssl_options = 0, 570 .server_bbio_off = 1, 571 .server_drops = { 6 }, 572 }, 573 #endif 574 { 575 .desc = "DTLS with dropped ClientKeyExchange", 576 .ssl_options = 0, 577 .client_bbio_off = 1, 578 .client_drops = { 2 }, 579 }, 580 { 581 .desc = "DTLS with dropped Client CCS", 582 .ssl_options = 0, 583 .client_bbio_off = 1, 584 .client_drops = { 3 }, 585 }, 586 { 587 .desc = "DTLS with dropped client Finished", 588 .ssl_options = 0, 589 .client_bbio_off = 1, 590 .client_drops = { 4 }, 591 }, 592 }; 593 594 #define N_DTLS_TESTS (sizeof(dtls_tests) / sizeof(*dtls_tests)) 595 596 static void 597 dtlstest_packet_monkey(SSL *ssl, const uint8_t drops[]) 598 { 599 BIO *bio_monkey; 600 BIO *bio; 601 int i; 602 603 if ((bio_monkey = BIO_new_packet_monkey()) == NULL) 604 errx(1, "packet monkey"); 605 606 for (i = 0; i < MAX_PACKET_DROPS; i++) { 607 if (drops[i] == 0) 608 break; 609 if (!BIO_packet_monkey_drop(bio_monkey, drops[i])) 610 errx(1, "drop failure"); 611 } 612 613 if ((bio = SSL_get_wbio(ssl)) == NULL) 614 errx(1, "SSL has NULL bio"); 615 616 BIO_up_ref(bio); 617 bio = BIO_push(bio_monkey, bio); 618 619 SSL_set_bio(ssl, bio, bio); 620 } 621 622 static int 623 dtlstest(const struct dtls_test *dt) 624 { 625 SSL *client = NULL, *server = NULL; 626 struct sockaddr_in server_sin; 627 struct pollfd pfd[2]; 628 int client_sock = -1; 629 int server_sock = -1; 630 int failed = 1; 631 632 fprintf(stderr, "\n== Testing %s... ==\n", dt->desc); 633 634 if (!datagram_pair(&client_sock, &server_sock, &server_sin)) 635 goto failure; 636 637 if ((client = dtls_client(client_sock, &server_sin, dt->mtu)) == NULL) 638 goto failure; 639 if ((server = dtls_server(server_sock, dt->ssl_options, dt->mtu)) == NULL) 640 goto failure; 641 642 if (dt->client_bbio_off) 643 SSL_set_info_callback(client, dtls_info_callback); 644 if (dt->server_bbio_off) 645 SSL_set_info_callback(server, dtls_info_callback); 646 647 dtlstest_packet_monkey(client, dt->client_drops); 648 dtlstest_packet_monkey(server, dt->server_drops); 649 650 pfd[0].fd = client_sock; 651 pfd[0].events = POLLOUT; 652 pfd[1].fd = server_sock; 653 pfd[1].events = POLLIN; 654 655 if (!do_client_server_loop(client, do_connect, server, do_accept, pfd)) { 656 fprintf(stderr, "FAIL: client and server handshake failed\n"); 657 goto failure; 658 } 659 660 /* XXX - do reads and writes. */ 661 662 pfd[0].events = POLLOUT; 663 pfd[1].events = POLLOUT; 664 665 if (!do_client_server_loop(client, do_shutdown, server, do_shutdown, pfd)) { 666 fprintf(stderr, "FAIL: client and server shutdown failed\n"); 667 goto failure; 668 } 669 670 fprintf(stderr, "INFO: Done!\n"); 671 672 failed = 0; 673 674 failure: 675 if (client_sock != -1) 676 close(client_sock); 677 if (server_sock != -1) 678 close(server_sock); 679 680 SSL_free(client); 681 SSL_free(server); 682 683 return failed; 684 } 685 686 int 687 main(int argc, char **argv) 688 { 689 int failed = 0; 690 size_t i; 691 692 if (argc != 4) { 693 fprintf(stderr, "usage: %s keyfile certfile cafile\n", 694 argv[0]); 695 exit(1); 696 } 697 698 server_key_file = argv[1]; 699 server_cert_file = argv[2]; 700 server_ca_file = argv[3]; 701 702 for (i = 0; i < N_DTLS_TESTS; i++) 703 failed |= dtlstest(&dtls_tests[i]); 704 705 return failed; 706 } 707