1 /* $OpenBSD: tlstest.c,v 1.1 2017/01/12 15:50:16 jsing Exp $ */ 2 /* 3 * Copyright (c) 2017 Joel Sing <jsing@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <sys/socket.h> 19 20 #include <err.h> 21 #include <fcntl.h> 22 #include <stdio.h> 23 #include <string.h> 24 #include <unistd.h> 25 26 #include <tls.h> 27 28 #define CIRCULAR_BUFFER_SIZE 512 29 30 unsigned char client_buffer[CIRCULAR_BUFFER_SIZE]; 31 unsigned char *client_readptr, *client_writeptr; 32 33 unsigned char server_buffer[CIRCULAR_BUFFER_SIZE]; 34 unsigned char *server_readptr, *server_writeptr; 35 36 int debug = 0; 37 38 static void 39 circular_init(void) 40 { 41 client_readptr = client_writeptr = client_buffer; 42 server_readptr = server_writeptr = server_buffer; 43 } 44 45 static ssize_t 46 circular_read(char *name, unsigned char *buf, size_t bufsize, 47 unsigned char **readptr, unsigned char *writeptr, 48 unsigned char *outbuf, size_t outlen) 49 { 50 unsigned char *nextptr = *readptr; 51 size_t n = 0; 52 53 while (n < outlen) { 54 if (nextptr == writeptr) 55 break; 56 *outbuf++ = *nextptr++; 57 if ((size_t)(nextptr - buf) >= bufsize) 58 nextptr = buf; 59 *readptr = nextptr; 60 n++; 61 } 62 63 if (debug && n > 0) 64 fprintf(stderr, "%s buffer: read %zi bytes\n", name, n); 65 66 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN); 67 } 68 69 static ssize_t 70 circular_write(char *name, unsigned char *buf, size_t bufsize, 71 unsigned char *readptr, unsigned char **writeptr, 72 const unsigned char *inbuf, size_t inlen) 73 { 74 unsigned char *nextptr = *writeptr; 75 unsigned char *prevptr; 76 size_t n = 0; 77 78 while (n < inlen) { 79 prevptr = nextptr++; 80 if ((size_t)(nextptr - buf) >= bufsize) 81 nextptr = buf; 82 if (nextptr == readptr) 83 break; 84 *prevptr = *inbuf++; 85 *writeptr = nextptr; 86 n++; 87 } 88 89 if (debug && n > 0) 90 fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n); 91 92 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT); 93 } 94 95 static ssize_t 96 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg) 97 { 98 return circular_read("client", client_buffer, sizeof(client_buffer), 99 &client_readptr, client_writeptr, buf, buflen); 100 } 101 102 static ssize_t 103 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg) 104 { 105 return circular_write("server", server_buffer, sizeof(server_buffer), 106 server_readptr, &server_writeptr, buf, buflen); 107 } 108 109 static ssize_t 110 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg) 111 { 112 return circular_read("server", server_buffer, sizeof(server_buffer), 113 &server_readptr, server_writeptr, buf, buflen); 114 } 115 116 static ssize_t 117 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg) 118 { 119 return circular_write("client", client_buffer, sizeof(client_buffer), 120 client_readptr, &client_writeptr, buf, buflen); 121 } 122 123 static int 124 do_tls_handshake(char *name, struct tls *ctx) 125 { 126 int rv; 127 128 rv = tls_handshake(ctx); 129 if (rv == 0) 130 return (1); 131 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT) 132 return (0); 133 134 errx(1, "%s handshake failed: %s", name, tls_error(ctx)); 135 } 136 137 static int 138 do_tls_close(char *name, struct tls *ctx) 139 { 140 int rv; 141 142 rv = tls_close(ctx); 143 if (rv == 0) 144 return (1); 145 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT) 146 return (0); 147 148 errx(1, "%s close failed: %s", name, tls_error(ctx)); 149 } 150 151 static int 152 do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx) 153 { 154 int i, client_done, server_done; 155 156 i = client_done = server_done = 0; 157 do { 158 if (client_done == 0) 159 client_done = do_tls_handshake("client", client); 160 if (server_done == 0) 161 server_done = do_tls_handshake("server", server_cctx); 162 } while (i++ < 100 && (client_done == 0 || server_done == 0)); 163 164 if (client_done == 0 || server_done == 0) { 165 printf("FAIL: %s TLS handshake did not complete\n", desc); 166 return (1); 167 } 168 printf("INFO: %s TLS handshake completed successfully\n", desc); 169 170 /* XXX - Do some reads and writes... */ 171 172 i = client_done = server_done = 0; 173 do { 174 if (client_done == 0) 175 client_done = do_tls_close("client", client); 176 if (server_done == 0) 177 server_done = do_tls_close("server", server_cctx); 178 } while (i++ < 100 && (client_done == 0 || server_done == 0)); 179 180 if (client_done == 0 || server_done == 0) { 181 printf("FAIL: %s TLS close did not complete\n", desc); 182 return (1); 183 } 184 printf("INFO: %s TLS close completed successfully\n", desc); 185 186 return (0); 187 } 188 189 static int 190 test_tls_cbs(struct tls *client, struct tls *server) 191 { 192 struct tls *server_cctx; 193 int failure; 194 195 circular_init(); 196 197 if (tls_accept_cbs(server, &server_cctx, server_read, server_write, 198 NULL) == -1) 199 errx(1, "failed to accept: %s", tls_error(server)); 200 201 if (tls_connect_cbs(client, client_read, client_write, NULL, 202 "test") == -1) 203 errx(1, "failed to connect: %s", tls_error(client)); 204 205 failure = do_client_server_test("callback", client, server_cctx); 206 207 tls_free(server_cctx); 208 209 return (failure); 210 } 211 212 static int 213 test_tls_fds(struct tls *client, struct tls *server) 214 { 215 struct tls *server_cctx; 216 int cfds[2], sfds[2]; 217 int failure; 218 219 if (pipe2(cfds, O_NONBLOCK) == -1) 220 err(1, "failed to create pipe"); 221 if (pipe2(sfds, O_NONBLOCK) == -1) 222 err(1, "failed to create pipe"); 223 224 if (tls_accept_fds(server, &server_cctx, sfds[0], cfds[1]) == -1) 225 errx(1, "failed to accept: %s", tls_error(server)); 226 227 if (tls_connect_fds(client, cfds[0], sfds[1], "test") == -1) 228 errx(1, "failed to connect: %s", tls_error(client)); 229 230 failure = do_client_server_test("file descriptor", client, server_cctx); 231 232 tls_free(server_cctx); 233 234 close(cfds[0]); 235 close(cfds[1]); 236 close(sfds[0]); 237 close(sfds[1]); 238 239 return (failure); 240 } 241 242 static int 243 test_tls_socket(struct tls *client, struct tls *server) 244 { 245 struct tls *server_cctx; 246 int failure; 247 int sv[2]; 248 249 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC, 250 sv) == -1) 251 err(1, "failed to create socketpair"); 252 253 if (tls_accept_socket(server, &server_cctx, sv[0]) == -1) 254 errx(1, "failed to accept: %s", tls_error(server)); 255 256 if (tls_connect_socket(client, sv[1], "test") == -1) 257 errx(1, "failed to connect: %s", tls_error(client)); 258 259 failure = do_client_server_test("socket", client, server_cctx); 260 261 tls_free(server_cctx); 262 263 close(sv[0]); 264 close(sv[1]); 265 266 return (failure); 267 } 268 269 int 270 main(int argc, char **argv) 271 { 272 struct tls_config *client_cfg, *server_cfg; 273 struct tls *client, *server; 274 int failure = 0; 275 276 if (argc != 4) { 277 fprintf(stderr, "usage: %s keyfile certfile cafile\n", 278 argv[0]); 279 return (1); 280 } 281 282 if (tls_init() == -1) 283 errx(1, "failed to initialise tls"); 284 285 if ((client = tls_client()) == NULL) 286 errx(1, "failed to create tls client"); 287 if ((client_cfg = tls_config_new()) == NULL) 288 errx(1, "failed to create tls client config"); 289 tls_config_insecure_noverifyname(client_cfg); 290 if (tls_config_set_ca_file(client_cfg, argv[3])) 291 errx(1, "failed to set ca: %s", tls_config_error(client_cfg)); 292 293 if ((server = tls_server()) == NULL) 294 errx(1, "failed to create tls server"); 295 if ((server_cfg = tls_config_new()) == NULL) 296 errx(1, "failed to create tls server config"); 297 if (tls_config_set_keypair_file(server_cfg, argv[1], argv[2]) == -1) 298 errx(1, "failed to set keypair: %s", 299 tls_config_error(server_cfg)); 300 301 tls_reset(client); 302 if (tls_configure(client, client_cfg) == -1) 303 errx(1, "failed to configure client: %s", tls_error(client)); 304 tls_reset(server); 305 if (tls_configure(server, server_cfg) == -1) 306 errx(1, "failed to configure server: %s", tls_error(server)); 307 308 failure |= test_tls_cbs(client, server); 309 310 tls_reset(client); 311 if (tls_configure(client, client_cfg) == -1) 312 errx(1, "failed to configure client: %s", tls_error(client)); 313 tls_reset(server); 314 if (tls_configure(server, server_cfg) == -1) 315 errx(1, "failed to configure server: %s", tls_error(server)); 316 317 failure |= test_tls_fds(client, server); 318 319 tls_reset(client); 320 if (tls_configure(client, client_cfg) == -1) 321 errx(1, "failed to configure client: %s", tls_error(client)); 322 tls_reset(server); 323 if (tls_configure(server, server_cfg) == -1) 324 errx(1, "failed to configure server: %s", tls_error(server)); 325 326 failure |= test_tls_socket(client, server); 327 328 tls_free(client); 329 tls_free(server); 330 331 tls_config_free(client_cfg); 332 tls_config_free(server_cfg); 333 334 return (failure); 335 } 336