xref: /openbsd-src/regress/lib/libtls/tls/tlstest.c (revision d7e1855272fac526a7ecb530129f34b42a3cddda)
1 /* $OpenBSD: tlstest.c,v 1.1 2017/01/12 15:50:16 jsing Exp $ */
2 /*
3  * Copyright (c) 2017 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <sys/socket.h>
19 
20 #include <err.h>
21 #include <fcntl.h>
22 #include <stdio.h>
23 #include <string.h>
24 #include <unistd.h>
25 
26 #include <tls.h>
27 
#define CIRCULAR_BUFFER_SIZE 512

/*
 * In-memory transport used by the callback test. Each ring buffer carries
 * bytes in one direction: client_buffer is written by the server and read
 * by the client, server_buffer is written by the client and read by the
 * server. A buffer is empty when its read and write cursors are equal.
 * All of this state is private to this test program, hence static.
 */
static unsigned char client_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *client_readptr, *client_writeptr;

static unsigned char server_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *server_readptr, *server_writeptr;

/* When non-zero, the ring buffer helpers log transfer sizes to stderr. */
static int debug = 0;
37 
38 static void
39 circular_init(void)
40 {
41 	client_readptr = client_writeptr = client_buffer;
42 	server_readptr = server_writeptr = server_buffer;
43 }
44 
45 static ssize_t
46 circular_read(char *name, unsigned char *buf, size_t bufsize,
47     unsigned char **readptr, unsigned char *writeptr,
48     unsigned char *outbuf, size_t outlen)
49 {
50 	unsigned char *nextptr = *readptr;
51 	size_t n = 0;
52 
53 	while (n < outlen) {
54 		if (nextptr == writeptr)
55 			break;
56 		*outbuf++ = *nextptr++;
57 		if ((size_t)(nextptr - buf) >= bufsize)
58 			nextptr = buf;
59 		*readptr = nextptr;
60 		n++;
61 	}
62 
63 	if (debug && n > 0)
64 		fprintf(stderr, "%s buffer: read %zi bytes\n", name, n);
65 
66 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN);
67 }
68 
69 static ssize_t
70 circular_write(char *name, unsigned char *buf, size_t bufsize,
71     unsigned char *readptr, unsigned char **writeptr,
72     const unsigned char *inbuf, size_t inlen)
73 {
74 	unsigned char *nextptr = *writeptr;
75 	unsigned char *prevptr;
76 	size_t n = 0;
77 
78 	while (n < inlen) {
79 		prevptr = nextptr++;
80 		if ((size_t)(nextptr - buf) >= bufsize)
81 			nextptr = buf;
82 		if (nextptr == readptr)
83 			break;
84 		*prevptr = *inbuf++;
85 		*writeptr = nextptr;
86 		n++;
87 	}
88 
89 	if (debug && n > 0)
90 		fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n);
91 
92 	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT);
93 }
94 
95 static ssize_t
96 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
97 {
98 	return circular_read("client", client_buffer, sizeof(client_buffer),
99 	    &client_readptr, client_writeptr, buf, buflen);
100 }
101 
102 static ssize_t
103 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
104 {
105 	return circular_write("server", server_buffer, sizeof(server_buffer),
106 	    server_readptr, &server_writeptr, buf, buflen);
107 }
108 
109 static ssize_t
110 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
111 {
112 	return circular_read("server", server_buffer, sizeof(server_buffer),
113 	    &server_readptr, server_writeptr, buf, buflen);
114 }
115 
116 static ssize_t
117 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
118 {
119 	return circular_write("client", client_buffer, sizeof(client_buffer),
120 	    client_readptr, &client_writeptr, buf, buflen);
121 }
122 
123 static int
124 do_tls_handshake(char *name, struct tls *ctx)
125 {
126 	int rv;
127 
128 	rv = tls_handshake(ctx);
129 	if (rv == 0)
130 		return (1);
131 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
132 		return (0);
133 
134 	errx(1, "%s handshake failed: %s", name, tls_error(ctx));
135 }
136 
137 static int
138 do_tls_close(char *name, struct tls *ctx)
139 {
140 	int rv;
141 
142 	rv = tls_close(ctx);
143 	if (rv == 0)
144 		return (1);
145 	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
146 		return (0);
147 
148 	errx(1, "%s close failed: %s", name, tls_error(ctx));
149 }
150 
151 static int
152 do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx)
153 {
154 	int i, client_done, server_done;
155 
156 	i = client_done = server_done = 0;
157 	do {
158 		if (client_done == 0)
159 			client_done = do_tls_handshake("client", client);
160 		if (server_done == 0)
161 			server_done = do_tls_handshake("server", server_cctx);
162 	} while (i++ < 100 && (client_done == 0 || server_done == 0));
163 
164 	if (client_done == 0 || server_done == 0) {
165 		printf("FAIL: %s TLS handshake did not complete\n", desc);
166 		return (1);
167 	}
168 	printf("INFO: %s TLS handshake completed successfully\n", desc);
169 
170 	/* XXX - Do some reads and writes... */
171 
172 	i = client_done = server_done = 0;
173 	do {
174 		if (client_done == 0)
175 			client_done = do_tls_close("client", client);
176 		if (server_done == 0)
177 			server_done = do_tls_close("server", server_cctx);
178 	} while (i++ < 100 && (client_done == 0 || server_done == 0));
179 
180 	if (client_done == 0 || server_done == 0) {
181 		printf("FAIL: %s TLS close did not complete\n", desc);
182 		return (1);
183 	}
184 	printf("INFO: %s TLS close completed successfully\n", desc);
185 
186 	return (0);
187 }
188 
189 static int
190 test_tls_cbs(struct tls *client, struct tls *server)
191 {
192 	struct tls *server_cctx;
193 	int failure;
194 
195 	circular_init();
196 
197 	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
198 	    NULL) == -1)
199 		errx(1, "failed to accept: %s", tls_error(server));
200 
201 	if (tls_connect_cbs(client, client_read, client_write, NULL,
202 	    "test") == -1)
203 		errx(1, "failed to connect: %s", tls_error(client));
204 
205 	failure = do_client_server_test("callback", client, server_cctx);
206 
207 	tls_free(server_cctx);
208 
209 	return (failure);
210 }
211 
/*
 * Run the handshake/close test over a pair of non-blocking pipes, one per
 * direction: s2c carries server-to-client bytes and c2s carries
 * client-to-server bytes. Setup failures exit via err()/errx().
 * Returns non-zero if the test failed.
 */
static int
test_tls_fds(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int s2c[2], c2s[2];
	int result, i;

	if (pipe2(s2c, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");
	if (pipe2(c2s, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");

	/* Server reads what the client writes and writes what it reads. */
	if (tls_accept_fds(server, &server_cctx, c2s[0], s2c[1]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_fds(client, s2c[0], c2s[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	result = do_client_server_test("file descriptor", client, server_cctx);
	tls_free(server_cctx);

	for (i = 0; i < 2; i++)
		close(s2c[i]);
	for (i = 0; i < 2; i++)
		close(c2s[i]);

	return (result);
}
241 
/*
 * Run the handshake/close test over a non-blocking AF_UNIX stream
 * socketpair: sock[0] is handed to the server, sock[1] to the client.
 * Setup failures exit via err()/errx(). Returns non-zero if the test failed.
 */
static int
test_tls_socket(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int sock[2];
	int result;

	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC,
	    sock) == -1)
		err(1, "failed to create socketpair");

	if (tls_accept_socket(server, &server_cctx, sock[0]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_socket(client, sock[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	result = do_client_server_test("socket", client, server_cctx);
	tls_free(server_cctx);

	close(sock[0]);
	close(sock[1]);

	return (result);
}
268 
/*
 * Call tls_reset() and then tls_configure() on the client and then on the
 * server context, so each transport test starts from a newly configured
 * state. Configuration failures exit via errx().
 */
static void
reconfigure_contexts(struct tls *client, struct tls_config *client_cfg,
    struct tls *server, struct tls_config *server_cfg)
{
	tls_reset(client);
	if (tls_configure(client, client_cfg) == -1)
		errx(1, "failed to configure client: %s", tls_error(client));
	tls_reset(server);
	if (tls_configure(server, server_cfg) == -1)
		errx(1, "failed to configure server: %s", tls_error(server));
}

/*
 * Usage: tlstest keyfile certfile cafile
 *
 * Builds a client context that trusts cafile (without server name
 * verification) and a server context that presents keyfile/certfile. It
 * then runs the handshake/close test over callbacks, pipes and a
 * socketpair. The exit status is non-zero if any test failed.
 */
int
main(int argc, char **argv)
{
	struct tls_config *client_cfg, *server_cfg;
	struct tls *client, *server;
	int failure = 0;

	if (argc != 4) {
		fprintf(stderr, "usage: %s keyfile certfile cafile\n",
		    argv[0]);
		return (1);
	}

	if (tls_init() == -1)
		errx(1, "failed to initialise tls");

	client = tls_client();
	if (client == NULL)
		errx(1, "failed to create tls client");
	client_cfg = tls_config_new();
	if (client_cfg == NULL)
		errx(1, "failed to create tls client config");
	tls_config_insecure_noverifyname(client_cfg);
	if (tls_config_set_ca_file(client_cfg, argv[3]) != 0)
		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));

	server = tls_server();
	if (server == NULL)
		errx(1, "failed to create tls server");
	server_cfg = tls_config_new();
	if (server_cfg == NULL)
		errx(1, "failed to create tls server config");
	if (tls_config_set_keypair_file(server_cfg, argv[1], argv[2]) == -1)
		errx(1, "failed to set keypair: %s",
		    tls_config_error(server_cfg));

	reconfigure_contexts(client, client_cfg, server, server_cfg);
	failure |= test_tls_cbs(client, server);

	reconfigure_contexts(client, client_cfg, server, server_cfg);
	failure |= test_tls_fds(client, server);

	reconfigure_contexts(client, client_cfg, server, server_cfg);
	failure |= test_tls_socket(client, server);

	tls_free(client);
	tls_free(server);

	tls_config_free(client_cfg);
	tls_config_free(server_cfg);

	return (failure);
}
335 }
336