xref: /openbsd-src/regress/lib/libssl/handshake/handshake_table.c (revision 99fd087599a8791921855f21bd7e36130f39aadc)
1 /*	$OpenBSD: handshake_table.c,v 1.11 2019/04/05 20:25:25 tb Exp $	*/
2 /*
3  * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23 
24 #include "tls13_handshake.h"
25 
26 /*
27  * From RFC 8446:
28  *
29  * Appendix A.  State Machine
30  *
31  *    This appendix provides a summary of the legal state transitions for
32  *    the client and server handshakes.  State names (in all capitals,
33  *    e.g., START) have no formal meaning but are provided for ease of
34  *    comprehension.  Actions which are taken only in certain circumstances
35  *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
36  *    the send/recv key to the given key".
37  *
38  * A.1.  Client
39  *
40  *                               START <----+
41  *                Send ClientHello |        | Recv HelloRetryRequest
42  *           [K_send = early data] |        |
43  *                                 v        |
44  *            /                 WAIT_SH ----+
45  *            |                    | Recv ServerHello
46  *            |                    | K_recv = handshake
47  *        Can |                    V
48  *       send |                 WAIT_EE
49  *      early |                    | Recv EncryptedExtensions
50  *       data |           +--------+--------+
51  *            |     Using |                 | Using certificate
52  *            |       PSK |                 v
53  *            |           |            WAIT_CERT_CR
54  *            |           |        Recv |       | Recv CertificateRequest
55  *            |           | Certificate |       v
56  *            |           |             |    WAIT_CERT
57  *            |           |             |       | Recv Certificate
58  *            |           |             v       v
59  *            |           |              WAIT_CV
60  *            |           |                 | Recv CertificateVerify
61  *            |           +> WAIT_FINISHED <+
62  *            |                  | Recv Finished
63  *            \                  | [Send EndOfEarlyData]
64  *                               | K_send = handshake
65  *                               | [Send Certificate [+ CertificateVerify]]
66  *     Can send                  | Send Finished
67  *     app data   -->            | K_send = K_recv = application
68  *     after here                v
69  *                           CONNECTED
70  *
71  *    Note that with the transitions as shown above, clients may send
72  *    alerts that derive from post-ServerHello messages in the clear or
73  *    with the early data keys.  If clients need to send such alerts, they
74  *    SHOULD first rekey to the handshake keys if possible.
75  *
76  */
77 
78 struct child {
79 	enum tls13_message_type	mt;
80 	uint8_t			flag;
81 	uint8_t			forced;
82 	uint8_t			illegal;
83 };
84 
85 #define DEFAULT			0x00
86 
87 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
88 	[CLIENT_HELLO] = {
89 		{SERVER_HELLO, DEFAULT, 0, 0},
90 	},
91 	[SERVER_HELLO] = {
92 		{SERVER_ENCRYPTED_EXTENSIONS, DEFAULT, 0, 0},
93 		{CLIENT_HELLO_RETRY, WITH_HRR, 0, 0},
94 	},
95 	[CLIENT_HELLO_RETRY] = {
96 		{SERVER_HELLO_RETRY, DEFAULT, 0, 0},
97 	},
98 	[SERVER_HELLO_RETRY] = {
99 		{SERVER_ENCRYPTED_EXTENSIONS, DEFAULT, 0, 0},
100 	},
101 	[SERVER_ENCRYPTED_EXTENSIONS] = {
102 		{SERVER_CERTIFICATE_REQUEST, DEFAULT, 0, 0},
103 		{SERVER_CERTIFICATE, WITHOUT_CR, 0, 0},
104 		{SERVER_FINISHED, WITH_PSK, 0, 0},
105 	},
106 	[SERVER_CERTIFICATE_REQUEST] = {
107 		{SERVER_CERTIFICATE, DEFAULT, 0, 0},
108 	},
109 	[SERVER_CERTIFICATE] = {
110 		{SERVER_CERTIFICATE_VERIFY, DEFAULT, 0, 0},
111 	},
112 	[SERVER_CERTIFICATE_VERIFY] = {
113 		{SERVER_FINISHED, DEFAULT, 0, 0},
114 	},
115 	[SERVER_FINISHED] = {
116 		{CLIENT_FINISHED, DEFAULT, WITHOUT_CR | WITH_PSK, 0},
117 		{CLIENT_CERTIFICATE, DEFAULT, 0, WITHOUT_CR | WITH_PSK},
118 	},
119 	[CLIENT_CERTIFICATE] = {
120 		{CLIENT_FINISHED, DEFAULT, 0, 0},
121 		{CLIENT_CERTIFICATE_VERIFY, WITH_CCV, 0, 0},
122 	},
123 	[CLIENT_CERTIFICATE_VERIFY] = {
124 		{CLIENT_FINISHED, DEFAULT, 0, 0},
125 	},
126 	[CLIENT_FINISHED] = {
127 		{APPLICATION_DATA, DEFAULT, 0, 0},
128 	},
129 	[APPLICATION_DATA] = {
130 		{0, DEFAULT, 0, 0},
131 	},
132 };
133 
134 const size_t	 stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
135 
136 void		 build_table(enum tls13_message_type
137 		     table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
138 		     struct child current, struct child end,
139 		     struct child path[], uint8_t flags, unsigned int depth);
140 size_t		 count_handshakes(void);
141 void		 edge(enum tls13_message_type start,
142 		     enum tls13_message_type end, uint8_t flag);
143 const char	*flag2str(uint8_t flag);
144 void		 flag_label(uint8_t flag);
145 void		 forced_edges(enum tls13_message_type start,
146 		     enum tls13_message_type end, uint8_t forced);
147 int		 generate_graphics(void);
148 void		 fprint_entry(FILE *stream,
149 		     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
150 		     uint8_t flags);
151 void		 fprint_flags(FILE *stream, uint8_t flags);
152 const char	*mt2str(enum tls13_message_type mt);
153 __dead void	 usage(void);
154 int		 verify_table(enum tls13_message_type
155 		     table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES], int print);
156 
157 const char *
158 flag2str(uint8_t flag)
159 {
160 	const char *ret;
161 
162 	if (flag & (flag - 1))
163 		errx(1, "more than one bit is set");
164 
165 	switch (flag) {
166 	case INITIAL:
167 		ret = "INITIAL";
168 		break;
169 	case NEGOTIATED:
170 		ret = "NEGOTIATED";
171 		break;
172 	case WITHOUT_CR:
173 		ret = "WITHOUT_CR";
174 		break;
175 	case WITH_HRR:
176 		ret = "WITH_HRR";
177 		break;
178 	case WITH_PSK:
179 		ret = "WITH_PSK";
180 		break;
181 	case WITH_CCV:
182 		ret = "WITH_CCV";
183 		break;
184 	case WITH_0RTT:
185 		ret = "WITH_0RTT";
186 		break;
187 	default:
188 		ret = "UNKNOWN";
189 	}
190 
191 	return ret;
192 }
193 
194 const char *
195 mt2str(enum tls13_message_type mt)
196 {
197 	const char *ret;
198 
199 	switch (mt) {
200 	case INVALID:
201 		ret = "INVALID";
202 		break;
203 	case CLIENT_HELLO:
204 		ret = "CLIENT_HELLO";
205 		break;
206 	case CLIENT_HELLO_RETRY:
207 		ret = "CLIENT_HELLO_RETRY";
208 		break;
209 	case CLIENT_END_OF_EARLY_DATA:
210 		ret = "CLIENT_END_OF_EARLY_DATA";
211 		break;
212 	case CLIENT_CERTIFICATE:
213 		ret = "CLIENT_CERTIFICATE";
214 		break;
215 	case CLIENT_CERTIFICATE_VERIFY:
216 		ret = "CLIENT_CERTIFICATE_VERIFY";
217 		break;
218 	case CLIENT_FINISHED:
219 		ret = "CLIENT_FINISHED";
220 		break;
221 	case CLIENT_KEY_UPDATE:
222 		ret = "CLIENT_KEY_UPDATE";
223 		break;
224 	case SERVER_HELLO:
225 		ret = "SERVER_HELLO";
226 		break;
227 	case SERVER_HELLO_RETRY:
228 		ret = "SERVER_HELLO_RETRY";
229 		break;
230 	case SERVER_NEW_SESSION_TICKET:
231 		ret = "SERVER_NEW_SESSION_TICKET";
232 		break;
233 	case SERVER_ENCRYPTED_EXTENSIONS:
234 		ret = "SERVER_ENCRYPTED_EXTENSIONS";
235 		break;
236 	case SERVER_CERTIFICATE:
237 		ret = "SERVER_CERTIFICATE";
238 		break;
239 	case SERVER_CERTIFICATE_VERIFY:
240 		ret = "SERVER_CERTIFICATE_VERIFY";
241 		break;
242 	case SERVER_CERTIFICATE_REQUEST:
243 		ret = "SERVER_CERTIFICATE_REQUEST";
244 		break;
245 	case SERVER_FINISHED:
246 		ret = "SERVER_FINISHED";
247 		break;
248 	case APPLICATION_DATA:
249 		ret = "APPLICATION_DATA";
250 		break;
251 	case TLS13_NUM_MESSAGE_TYPES:
252 		ret = "TLS13_NUM_MESSAGE_TYPES";
253 		break;
254 	default:
255 		ret = "UNKNOWN";
256 		break;
257 	}
258 
259 	return ret;
260 }
261 
/*
 * Write flags to stream as the names of its set bits joined by " | ".
 * A zero value is written as the name flag2str() gives to 0.
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char	*sep = "";
	unsigned int	 bit;

	/* No bits to walk: print the name of the zero value itself. */
	if (flags == 0) {
		fprintf(stream, "%s", flag2str(0));
		return;
	}

	for (bit = 0; bit < 8; bit++) {
		uint8_t	mask = flags & (1U << bit);

		if (mask == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(mask));
		sep = " | ";
	}
}
282 
283 void
284 fprint_entry(FILE *stream,
285     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
286 {
287 	int i;
288 
289 	fprintf(stream, "\t[");
290 	fprint_flags(stream, flags);
291 	fprintf(stream, "] = {\n");
292 
293 	for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
294 		if (path[i] == 0)
295 			break;
296 		fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
297 	}
298 	fprintf(stream, "\t},\n");
299 }
300 
301 void
302 edge(enum tls13_message_type start, enum tls13_message_type end,
303     uint8_t flag)
304 {
305 	printf("\t%s -> %s", mt2str(start), mt2str(end));
306 	flag_label(flag);
307 	printf(";\n");
308 }
309 
/* Print a graphviz edge label naming flag; prints nothing for 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;
	printf(" [label=\"%s\"]", flag2str(flag));
}
316 
317 void
318 forced_edges(enum tls13_message_type start, enum tls13_message_type end,
319     uint8_t forced)
320 {
321 	uint8_t	forced_flag, i;
322 
323 	if (forced == 0)
324 		return;
325 
326 	for (i = 0; i < 8; i++) {
327 		forced_flag = forced & (1U << i);
328 		if (forced_flag)
329 			edge(start, end, forced_flag);
330 	}
331 }
332 
333 int
334 generate_graphics(void)
335 {
336 	enum tls13_message_type	start, end;
337 	unsigned int		child;
338 	uint8_t			flag;
339 	uint8_t			forced;
340 
341 	printf("digraph G {\n");
342 	printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
343 	printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
344 
345 	for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
346 		for (child = 0; stateinfo[start][child].mt != 0; child++) {
347 			end = stateinfo[start][child].mt;
348 			flag = stateinfo[start][child].flag;
349 			forced = stateinfo[start][child].forced;
350 
351 			if (forced == 0)
352 				edge(start, end, flag);
353 			else
354 				forced_edges(start, end, forced);
355 		}
356 	}
357 
358 	printf("}\n");
359 	return 0;
360 }
361 
362 extern enum tls13_message_type	handshakes[][TLS13_NUM_MESSAGE_TYPES];
363 extern size_t			handshake_count;
364 
365 size_t
366 count_handshakes(void)
367 {
368 	size_t	ret = 0, i;
369 
370 	for (i = 0; i < handshake_count; i++) {
371 		if (handshakes[i][0] != INVALID)
372 			ret++;
373 	}
374 
375 	return ret;
376 }
377 
378 void
379 build_table(enum tls13_message_type table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
380     struct child current, struct child end, struct child path[], uint8_t flags,
381     unsigned int depth)
382 {
383 	unsigned int i;
384 
385 	if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
386 		errx(1, "recursed too deeply");
387 
388 	/* Record current node. */
389 	path[depth++] = current;
390 	flags |= current.flag;
391 
392 	/* If we haven't reached the end, recurse over the children. */
393 	if (current.mt != end.mt) {
394 		for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
395 			struct child child = stateinfo[current.mt][i];
396 			int forced = stateinfo[current.mt][i].forced;
397 			int illegal = stateinfo[current.mt][i].illegal;
398 
399 			if ((forced == 0 || (forced & flags)) &&
400 			    (illegal == 0 || !(illegal & flags)))
401 				build_table(table, child, end, path, flags,
402 				    depth);
403 		}
404 		return;
405 	}
406 
407 	if (flags == 0)
408 		errx(1, "path does not set flags");
409 
410 	if (table[flags][0] != 0)
411 		errx(1, "path traversed twice");
412 
413 	for (i = 0; i < depth; i++)
414 		table[flags][i] = path[i].mt;
415 }
416 
417 int
418 verify_table(enum tls13_message_type table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
419     int print)
420 {
421 	int	success = 1, i;
422 	size_t	num_valid, num_found = 0;
423 	uint8_t	flags = 0;
424 
425 	do {
426 		if (table[flags][0] == 0)
427 			continue;
428 
429 		num_found++;
430 
431 		for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
432 			if (table[flags][i] != handshakes[flags][i]) {
433 				fprintf(stderr,
434 				    "incorrect entry %d of handshake ", i);
435 				fprint_flags(stderr, flags);
436 				fprintf(stderr, "\n");
437 				success = 0;
438 			}
439 		}
440 
441 		if (print)
442 			fprint_entry(stdout, table[flags], flags);
443 	} while(++flags != 0);
444 
445 	num_valid = count_handshakes();
446 	if (num_valid != num_found) {
447 		fprintf(stderr,
448 		    "incorrect number of handshakes: want %zu, got %zu.\n",
449 		    num_valid, num_found);
450 		success = 0;
451 	}
452 
453 	return success;
454 }
455 
456 __dead void
457 usage(void)
458 {
459 	fprintf(stderr, "usage: handshake_table [-C | -g]\n");
460 	exit(1);
461 }
462 
463 int
464 main(int argc, char *argv[])
465 {
466 	static enum tls13_message_type
467 	    hs_table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES] = {
468 		[INITIAL] = {
469 			CLIENT_HELLO,
470 			SERVER_HELLO,
471 		},
472 	};
473 	struct child	start = {
474 		CLIENT_HELLO, DEFAULT, 0, 0,
475 	};
476 	struct child	end = {
477 		APPLICATION_DATA, DEFAULT, 0, 0,
478 	};
479 	struct child	path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
480 	uint8_t		flags = NEGOTIATED;
481 	unsigned int	depth = 0;
482 	int		ch, graphviz = 0, print = 0;
483 
484 	while ((ch = getopt(argc, argv, "Cg")) != -1) {
485 		switch (ch) {
486 		case 'C':
487 			print = 1;
488 			break;
489 		case 'g':
490 			graphviz = 1;
491 			break;
492 		default:
493 			usage();
494 		}
495 	}
496 	argc -= optind;
497 	argv += optind;
498 
499 	if (argc != 0)
500 		usage();
501 
502 	if (graphviz && print)
503 		usage();
504 
505 	if (graphviz)
506 		return generate_graphics();
507 
508 	build_table(hs_table, start, end, path, flags, depth);
509 	if (!verify_table(hs_table, print))
510 		return 1;
511 
512 	if (!print)
513 		printf("SUCCESS\n");
514 
515 	return 0;
516 }
517