xref: /openbsd-src/regress/lib/libssl/quic/quictest.c (revision fe6fa0cbb366148c8c6bcddd7c39efb2a131a9c5)
1*fe6fa0cbSjsing /* $OpenBSD: quictest.c,v 1.1 2022/08/27 09:16:29 jsing Exp $ */
2*fe6fa0cbSjsing /*
3*fe6fa0cbSjsing  * Copyright (c) 2020, 2021, 2022 Joel Sing <jsing@openbsd.org>
4*fe6fa0cbSjsing  *
5*fe6fa0cbSjsing  * Permission to use, copy, modify, and distribute this software for any
6*fe6fa0cbSjsing  * purpose with or without fee is hereby granted, provided that the above
7*fe6fa0cbSjsing  * copyright notice and this permission notice appear in all copies.
8*fe6fa0cbSjsing  *
9*fe6fa0cbSjsing  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10*fe6fa0cbSjsing  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11*fe6fa0cbSjsing  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12*fe6fa0cbSjsing  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13*fe6fa0cbSjsing  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14*fe6fa0cbSjsing  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15*fe6fa0cbSjsing  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16*fe6fa0cbSjsing  */
17*fe6fa0cbSjsing 
18*fe6fa0cbSjsing #include <err.h>
19*fe6fa0cbSjsing 
20*fe6fa0cbSjsing #include <openssl/bio.h>
21*fe6fa0cbSjsing #include <openssl/err.h>
22*fe6fa0cbSjsing #include <openssl/ssl.h>
23*fe6fa0cbSjsing 
24*fe6fa0cbSjsing const char *server_ca_file;
25*fe6fa0cbSjsing const char *server_cert_file;
26*fe6fa0cbSjsing const char *server_key_file;
27*fe6fa0cbSjsing 
28*fe6fa0cbSjsing int debug = 0;
29*fe6fa0cbSjsing 
30*fe6fa0cbSjsing static void
hexdump(const unsigned char * buf,size_t len)31*fe6fa0cbSjsing hexdump(const unsigned char *buf, size_t len)
32*fe6fa0cbSjsing {
33*fe6fa0cbSjsing 	size_t i;
34*fe6fa0cbSjsing 
35*fe6fa0cbSjsing 	for (i = 1; i <= len; i++)
36*fe6fa0cbSjsing 		fprintf(stderr, " 0x%02hhx,%s", buf[i - 1], i % 8 ? "" : "\n");
37*fe6fa0cbSjsing 
38*fe6fa0cbSjsing 	if (len % 8)
39*fe6fa0cbSjsing 		fprintf(stderr, "\n");
40*fe6fa0cbSjsing }
41*fe6fa0cbSjsing 
42*fe6fa0cbSjsing struct quic_data {
43*fe6fa0cbSjsing 	enum ssl_encryption_level_t rlevel;
44*fe6fa0cbSjsing 	enum ssl_encryption_level_t wlevel;
45*fe6fa0cbSjsing 	BIO *rbio;
46*fe6fa0cbSjsing 	BIO *wbio;
47*fe6fa0cbSjsing };
48*fe6fa0cbSjsing 
49*fe6fa0cbSjsing static int
quic_set_read_secret(SSL * ssl,enum ssl_encryption_level_t level,const SSL_CIPHER * cipher,const uint8_t * secret,size_t secret_len)50*fe6fa0cbSjsing quic_set_read_secret(SSL *ssl, enum ssl_encryption_level_t level,
51*fe6fa0cbSjsing     const SSL_CIPHER *cipher, const uint8_t *secret, size_t secret_len)
52*fe6fa0cbSjsing {
53*fe6fa0cbSjsing 	struct quic_data *qd = SSL_get_app_data(ssl);
54*fe6fa0cbSjsing 
55*fe6fa0cbSjsing 	qd->rlevel = level;
56*fe6fa0cbSjsing 
57*fe6fa0cbSjsing 	return 1;
58*fe6fa0cbSjsing }
59*fe6fa0cbSjsing 
60*fe6fa0cbSjsing static int
quic_set_write_secret(SSL * ssl,enum ssl_encryption_level_t level,const SSL_CIPHER * cipher,const uint8_t * secret,size_t secret_len)61*fe6fa0cbSjsing quic_set_write_secret(SSL *ssl, enum ssl_encryption_level_t level,
62*fe6fa0cbSjsing     const SSL_CIPHER *cipher, const uint8_t *secret, size_t secret_len)
63*fe6fa0cbSjsing {
64*fe6fa0cbSjsing 	struct quic_data *qd = SSL_get_app_data(ssl);
65*fe6fa0cbSjsing 
66*fe6fa0cbSjsing 	qd->wlevel = level;
67*fe6fa0cbSjsing 
68*fe6fa0cbSjsing 	return 1;
69*fe6fa0cbSjsing }
70*fe6fa0cbSjsing 
71*fe6fa0cbSjsing static int
quic_read_handshake_data(SSL * ssl)72*fe6fa0cbSjsing quic_read_handshake_data(SSL *ssl)
73*fe6fa0cbSjsing {
74*fe6fa0cbSjsing 	struct quic_data *qd = SSL_get_app_data(ssl);
75*fe6fa0cbSjsing 	uint8_t buf[2048];
76*fe6fa0cbSjsing 	int ret;
77*fe6fa0cbSjsing 
78*fe6fa0cbSjsing 	if ((ret = BIO_read(qd->rbio, buf, sizeof(buf))) > 0) {
79*fe6fa0cbSjsing 		if (debug > 1) {
80*fe6fa0cbSjsing 			fprintf(stderr, "== quic_read_handshake_data ==\n");
81*fe6fa0cbSjsing 			hexdump(buf, ret);
82*fe6fa0cbSjsing 		}
83*fe6fa0cbSjsing 		if (!SSL_provide_quic_data(ssl, qd->rlevel, buf, ret))
84*fe6fa0cbSjsing 			return -1;
85*fe6fa0cbSjsing 	}
86*fe6fa0cbSjsing 
87*fe6fa0cbSjsing 	return 1;
88*fe6fa0cbSjsing }
89*fe6fa0cbSjsing 
90*fe6fa0cbSjsing static int
quic_add_handshake_data(SSL * ssl,enum ssl_encryption_level_t level,const uint8_t * data,size_t len)91*fe6fa0cbSjsing quic_add_handshake_data(SSL *ssl, enum ssl_encryption_level_t level,
92*fe6fa0cbSjsing     const uint8_t *data, size_t len)
93*fe6fa0cbSjsing {
94*fe6fa0cbSjsing 	struct quic_data *qd = SSL_get_app_data(ssl);
95*fe6fa0cbSjsing 	int ret;
96*fe6fa0cbSjsing 
97*fe6fa0cbSjsing 	if (debug > 1) {
98*fe6fa0cbSjsing 		fprintf(stderr, "== quic_add_handshake_data\n");
99*fe6fa0cbSjsing 		hexdump(data, len);
100*fe6fa0cbSjsing 	}
101*fe6fa0cbSjsing 
102*fe6fa0cbSjsing 	if ((ret = BIO_write(qd->wbio, data, len)) <= 0)
103*fe6fa0cbSjsing 		return 0;
104*fe6fa0cbSjsing 
105*fe6fa0cbSjsing 	return (size_t)ret == len;
106*fe6fa0cbSjsing }
107*fe6fa0cbSjsing 
108*fe6fa0cbSjsing static int
quic_flush_flight(SSL * ssl)109*fe6fa0cbSjsing quic_flush_flight(SSL *ssl)
110*fe6fa0cbSjsing {
111*fe6fa0cbSjsing 	return 1;
112*fe6fa0cbSjsing }
113*fe6fa0cbSjsing 
114*fe6fa0cbSjsing static int
quic_send_alert(SSL * ssl,enum ssl_encryption_level_t level,uint8_t alert)115*fe6fa0cbSjsing quic_send_alert(SSL *ssl, enum ssl_encryption_level_t level, uint8_t alert)
116*fe6fa0cbSjsing {
117*fe6fa0cbSjsing 	return 1;
118*fe6fa0cbSjsing }
119*fe6fa0cbSjsing 
120*fe6fa0cbSjsing const SSL_QUIC_METHOD quic_method = {
121*fe6fa0cbSjsing 	.set_read_secret = quic_set_read_secret,
122*fe6fa0cbSjsing 	.set_write_secret = quic_set_write_secret,
123*fe6fa0cbSjsing 	.add_handshake_data = quic_add_handshake_data,
124*fe6fa0cbSjsing 	.flush_flight = quic_flush_flight,
125*fe6fa0cbSjsing 	.send_alert = quic_send_alert,
126*fe6fa0cbSjsing };
127*fe6fa0cbSjsing 
128*fe6fa0cbSjsing static SSL *
quic_client(struct quic_data * data)129*fe6fa0cbSjsing quic_client(struct quic_data *data)
130*fe6fa0cbSjsing {
131*fe6fa0cbSjsing 	SSL_CTX *ssl_ctx = NULL;
132*fe6fa0cbSjsing 	SSL *ssl = NULL;
133*fe6fa0cbSjsing 
134*fe6fa0cbSjsing 	if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL)
135*fe6fa0cbSjsing 		errx(1, "client context");
136*fe6fa0cbSjsing 
137*fe6fa0cbSjsing 	if (!SSL_CTX_set_quic_method(ssl_ctx, &quic_method)) {
138*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: Failed to set QUIC method\n");
139*fe6fa0cbSjsing 		goto failure;
140*fe6fa0cbSjsing 	}
141*fe6fa0cbSjsing 
142*fe6fa0cbSjsing 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
143*fe6fa0cbSjsing 		errx(1, "client ssl");
144*fe6fa0cbSjsing 
145*fe6fa0cbSjsing 	SSL_set_connect_state(ssl);
146*fe6fa0cbSjsing 	SSL_set_app_data(ssl, data);
147*fe6fa0cbSjsing 
148*fe6fa0cbSjsing  failure:
149*fe6fa0cbSjsing 	SSL_CTX_free(ssl_ctx);
150*fe6fa0cbSjsing 
151*fe6fa0cbSjsing 	return ssl;
152*fe6fa0cbSjsing }
153*fe6fa0cbSjsing 
154*fe6fa0cbSjsing static SSL *
quic_server(struct quic_data * data)155*fe6fa0cbSjsing quic_server(struct quic_data *data)
156*fe6fa0cbSjsing {
157*fe6fa0cbSjsing 	SSL_CTX *ssl_ctx = NULL;
158*fe6fa0cbSjsing 	SSL *ssl = NULL;
159*fe6fa0cbSjsing 
160*fe6fa0cbSjsing 	if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL)
161*fe6fa0cbSjsing 		errx(1, "server context");
162*fe6fa0cbSjsing 
163*fe6fa0cbSjsing 	SSL_CTX_set_dh_auto(ssl_ctx, 2);
164*fe6fa0cbSjsing 
165*fe6fa0cbSjsing 	if (SSL_CTX_use_certificate_file(ssl_ctx, server_cert_file,
166*fe6fa0cbSjsing 	    SSL_FILETYPE_PEM) != 1) {
167*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: Failed to load server certificate\n");
168*fe6fa0cbSjsing 		goto failure;
169*fe6fa0cbSjsing 	}
170*fe6fa0cbSjsing 	if (SSL_CTX_use_PrivateKey_file(ssl_ctx, server_key_file,
171*fe6fa0cbSjsing 	    SSL_FILETYPE_PEM) != 1) {
172*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: Failed to load server private key\n");
173*fe6fa0cbSjsing 		goto failure;
174*fe6fa0cbSjsing 	}
175*fe6fa0cbSjsing 
176*fe6fa0cbSjsing 	if (!SSL_CTX_set_quic_method(ssl_ctx, &quic_method)) {
177*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: Failed to set QUIC method\n");
178*fe6fa0cbSjsing 		goto failure;
179*fe6fa0cbSjsing 	}
180*fe6fa0cbSjsing 
181*fe6fa0cbSjsing 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
182*fe6fa0cbSjsing 		errx(1, "server ssl");
183*fe6fa0cbSjsing 
184*fe6fa0cbSjsing 	SSL_set_accept_state(ssl);
185*fe6fa0cbSjsing 	SSL_set_app_data(ssl, data);
186*fe6fa0cbSjsing 
187*fe6fa0cbSjsing  failure:
188*fe6fa0cbSjsing 	SSL_CTX_free(ssl_ctx);
189*fe6fa0cbSjsing 
190*fe6fa0cbSjsing 	return ssl;
191*fe6fa0cbSjsing }
192*fe6fa0cbSjsing 
193*fe6fa0cbSjsing static int
ssl_error(SSL * ssl,const char * name,const char * desc,int ssl_ret)194*fe6fa0cbSjsing ssl_error(SSL *ssl, const char *name, const char *desc, int ssl_ret)
195*fe6fa0cbSjsing {
196*fe6fa0cbSjsing 	int ssl_err;
197*fe6fa0cbSjsing 
198*fe6fa0cbSjsing 	ssl_err = SSL_get_error(ssl, ssl_ret);
199*fe6fa0cbSjsing 
200*fe6fa0cbSjsing 	if (ssl_err == SSL_ERROR_WANT_READ) {
201*fe6fa0cbSjsing 		if (quic_read_handshake_data(ssl) < 0)
202*fe6fa0cbSjsing 			return 0;
203*fe6fa0cbSjsing 		return 1;
204*fe6fa0cbSjsing 	} else if (ssl_err == SSL_ERROR_WANT_WRITE) {
205*fe6fa0cbSjsing 		return 1;
206*fe6fa0cbSjsing 	} else if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) {
207*fe6fa0cbSjsing 		/* Yup, this is apparently a thing... */
208*fe6fa0cbSjsing 	} else {
209*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: %s %s failed - ssl err = %d, errno = %d\n",
210*fe6fa0cbSjsing 		    name, desc, ssl_err, errno);
211*fe6fa0cbSjsing 		ERR_print_errors_fp(stderr);
212*fe6fa0cbSjsing 		return 0;
213*fe6fa0cbSjsing 	}
214*fe6fa0cbSjsing 
215*fe6fa0cbSjsing 	return 1;
216*fe6fa0cbSjsing }
217*fe6fa0cbSjsing 
218*fe6fa0cbSjsing static int
do_handshake(SSL * ssl,const char * name,int * done)219*fe6fa0cbSjsing do_handshake(SSL *ssl, const char *name, int *done)
220*fe6fa0cbSjsing {
221*fe6fa0cbSjsing 	int ssl_ret;
222*fe6fa0cbSjsing 
223*fe6fa0cbSjsing 	if ((ssl_ret = SSL_do_handshake(ssl)) == 1) {
224*fe6fa0cbSjsing 		fprintf(stderr, "INFO: %s handshake done\n", name);
225*fe6fa0cbSjsing 		*done = 1;
226*fe6fa0cbSjsing 		return 1;
227*fe6fa0cbSjsing 	}
228*fe6fa0cbSjsing 
229*fe6fa0cbSjsing 	return ssl_error(ssl, name, "handshake", ssl_ret);
230*fe6fa0cbSjsing }
231*fe6fa0cbSjsing 
232*fe6fa0cbSjsing typedef int (*ssl_func)(SSL *ssl, const char *name, int *done);
233*fe6fa0cbSjsing 
234*fe6fa0cbSjsing static int
do_client_server_loop(SSL * client,ssl_func client_func,SSL * server,ssl_func server_func)235*fe6fa0cbSjsing do_client_server_loop(SSL *client, ssl_func client_func, SSL *server,
236*fe6fa0cbSjsing     ssl_func server_func)
237*fe6fa0cbSjsing {
238*fe6fa0cbSjsing 	int client_done = 0, server_done = 0;
239*fe6fa0cbSjsing 	int i = 0;
240*fe6fa0cbSjsing 
241*fe6fa0cbSjsing 	do {
242*fe6fa0cbSjsing 		if (!client_done) {
243*fe6fa0cbSjsing 			if (debug)
244*fe6fa0cbSjsing 				fprintf(stderr, "DEBUG: client loop\n");
245*fe6fa0cbSjsing 			if (!client_func(client, "client", &client_done))
246*fe6fa0cbSjsing 				return 0;
247*fe6fa0cbSjsing 		}
248*fe6fa0cbSjsing 		if (!server_done) {
249*fe6fa0cbSjsing 			if (debug)
250*fe6fa0cbSjsing 				fprintf(stderr, "DEBUG: server loop\n");
251*fe6fa0cbSjsing 			if (!server_func(server, "server", &server_done))
252*fe6fa0cbSjsing 				return 0;
253*fe6fa0cbSjsing 		}
254*fe6fa0cbSjsing 	} while (i++ < 100 && (!client_done || !server_done));
255*fe6fa0cbSjsing 
256*fe6fa0cbSjsing 	if (!client_done || !server_done)
257*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: gave up\n");
258*fe6fa0cbSjsing 
259*fe6fa0cbSjsing 	return client_done && server_done;
260*fe6fa0cbSjsing }
261*fe6fa0cbSjsing 
262*fe6fa0cbSjsing static int
quictest(void)263*fe6fa0cbSjsing quictest(void)
264*fe6fa0cbSjsing {
265*fe6fa0cbSjsing 	struct quic_data *client_data = NULL, *server_data = NULL;
266*fe6fa0cbSjsing 	BIO *client_wbio = NULL, *server_wbio = NULL;
267*fe6fa0cbSjsing 	SSL *client = NULL, *server = NULL;
268*fe6fa0cbSjsing 	int failed = 1;
269*fe6fa0cbSjsing 
270*fe6fa0cbSjsing 	if ((client_wbio = BIO_new(BIO_s_mem())) == NULL)
271*fe6fa0cbSjsing 		goto failure;
272*fe6fa0cbSjsing 	if (BIO_set_mem_eof_return(client_wbio, -1) <= 0)
273*fe6fa0cbSjsing 		goto failure;
274*fe6fa0cbSjsing 
275*fe6fa0cbSjsing 	if ((server_wbio = BIO_new(BIO_s_mem())) == NULL)
276*fe6fa0cbSjsing 		goto failure;
277*fe6fa0cbSjsing 	if (BIO_set_mem_eof_return(server_wbio, -1) <= 0)
278*fe6fa0cbSjsing 		goto failure;
279*fe6fa0cbSjsing 
280*fe6fa0cbSjsing 	if ((client_data = calloc(1, sizeof(*client_data))) == NULL)
281*fe6fa0cbSjsing 		goto failure;
282*fe6fa0cbSjsing 
283*fe6fa0cbSjsing 	client_data->rbio = server_wbio;
284*fe6fa0cbSjsing 	client_data->wbio = client_wbio;
285*fe6fa0cbSjsing 
286*fe6fa0cbSjsing 	if ((client = quic_client(client_data)) == NULL)
287*fe6fa0cbSjsing 		goto failure;
288*fe6fa0cbSjsing 
289*fe6fa0cbSjsing 	if ((server_data = calloc(1, sizeof(*server_data))) == NULL)
290*fe6fa0cbSjsing 		goto failure;
291*fe6fa0cbSjsing 
292*fe6fa0cbSjsing 	server_data->rbio = client_wbio;
293*fe6fa0cbSjsing 	server_data->wbio = server_wbio;
294*fe6fa0cbSjsing 
295*fe6fa0cbSjsing 	if ((server = quic_server(server_data)) == NULL)
296*fe6fa0cbSjsing 		goto failure;
297*fe6fa0cbSjsing 
298*fe6fa0cbSjsing 	if (!do_client_server_loop(client, do_handshake, server, do_handshake)) {
299*fe6fa0cbSjsing 		fprintf(stderr, "FAIL: client and server handshake failed\n");
300*fe6fa0cbSjsing 		ERR_print_errors_fp(stderr);
301*fe6fa0cbSjsing 		goto failure;
302*fe6fa0cbSjsing 	}
303*fe6fa0cbSjsing 
304*fe6fa0cbSjsing 	fprintf(stderr, "INFO: Done!\n");
305*fe6fa0cbSjsing 
306*fe6fa0cbSjsing 	failed = 0;
307*fe6fa0cbSjsing 
308*fe6fa0cbSjsing  failure:
309*fe6fa0cbSjsing 	BIO_free(client_wbio);
310*fe6fa0cbSjsing 	BIO_free(server_wbio);
311*fe6fa0cbSjsing 
312*fe6fa0cbSjsing 	free(client_data);
313*fe6fa0cbSjsing 	free(server_data);
314*fe6fa0cbSjsing 
315*fe6fa0cbSjsing 	SSL_free(client);
316*fe6fa0cbSjsing 	SSL_free(server);
317*fe6fa0cbSjsing 
318*fe6fa0cbSjsing 	return failed;
319*fe6fa0cbSjsing }
320*fe6fa0cbSjsing 
321*fe6fa0cbSjsing int
main(int argc,char ** argv)322*fe6fa0cbSjsing main(int argc, char **argv)
323*fe6fa0cbSjsing {
324*fe6fa0cbSjsing 	int failed = 0;
325*fe6fa0cbSjsing 
326*fe6fa0cbSjsing 	if (argc != 4) {
327*fe6fa0cbSjsing 		fprintf(stderr, "usage: %s keyfile certfile cafile\n",
328*fe6fa0cbSjsing 		    argv[0]);
329*fe6fa0cbSjsing 		exit(1);
330*fe6fa0cbSjsing 	}
331*fe6fa0cbSjsing 
332*fe6fa0cbSjsing 	server_key_file = argv[1];
333*fe6fa0cbSjsing 	server_cert_file = argv[2];
334*fe6fa0cbSjsing 	server_ca_file = argv[3];
335*fe6fa0cbSjsing 
336*fe6fa0cbSjsing 	failed |= quictest();
337*fe6fa0cbSjsing 
338*fe6fa0cbSjsing 	return failed;
339*fe6fa0cbSjsing }
340