xref: /openbsd-src/regress/lib/libssl/handshake/handshake_table.c (revision 9a54de93546c4704261e5531082e48fff2e8073b)
1 /*	$OpenBSD: handshake_table.c,v 1.10 2019/02/13 17:04:17 tb Exp $	*/
2 /*
3  * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23 
24 #include "tls13_handshake.h"
25 
26 /*
27  * From RFC 8446:
28  *
29  * Appendix A.  State Machine
30  *
31  *    This appendix provides a summary of the legal state transitions for
32  *    the client and server handshakes.  State names (in all capitals,
33  *    e.g., START) have no formal meaning but are provided for ease of
34  *    comprehension.  Actions which are taken only in certain circumstances
35  *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
36  *    the send/recv key to the given key".
37  *
38  * A.1.  Client
39  *
40  *                               START <----+
41  *                Send ClientHello |        | Recv HelloRetryRequest
42  *           [K_send = early data] |        |
43  *                                 v        |
44  *            /                 WAIT_SH ----+
45  *            |                    | Recv ServerHello
46  *            |                    | K_recv = handshake
47  *        Can |                    V
48  *       send |                 WAIT_EE
49  *      early |                    | Recv EncryptedExtensions
50  *       data |           +--------+--------+
51  *            |     Using |                 | Using certificate
52  *            |       PSK |                 v
53  *            |           |            WAIT_CERT_CR
54  *            |           |        Recv |       | Recv CertificateRequest
55  *            |           | Certificate |       v
56  *            |           |             |    WAIT_CERT
57  *            |           |             |       | Recv Certificate
58  *            |           |             v       v
59  *            |           |              WAIT_CV
60  *            |           |                 | Recv CertificateVerify
61  *            |           +> WAIT_FINISHED <+
62  *            |                  | Recv Finished
63  *            \                  | [Send EndOfEarlyData]
64  *                               | K_send = handshake
65  *                               | [Send Certificate [+ CertificateVerify]]
66  *     Can send                  | Send Finished
67  *     app data   -->            | K_send = K_recv = application
68  *     after here                v
69  *                           CONNECTED
70  *
71  *    Note that with the transitions as shown above, clients may send
72  *    alerts that derive from post-ServerHello messages in the clear or
73  *    with the early data keys.  If clients need to send such alerts, they
74  *    SHOULD first rekey to the handshake keys if possible.
75  *
76  */
77 
/*
 * One edge of the handshake state graph, as stored in stateinfo[].
 *
 * mt       message type the edge leads to; an entry with mt == 0
 *          terminates a list of children.
 * flag     bits OR'd into the accumulated path flags when build_table()
 *          visits this node; also used as the edge label by
 *          generate_graphics() when forced is 0.
 * forced   if non-zero, build_table() only follows this edge when at
 *          least one of these bits is already set in the accumulated
 *          flags; generate_graphics() draws one edge per forced bit.
 * illegal  if non-zero, build_table() does not follow this edge when any
 *          of these bits is already set in the accumulated flags.
 */
struct child {
	enum tls13_message_type	mt;
	uint8_t			flag;
	uint8_t			forced;
	uint8_t			illegal;
};

/* flag value for an edge that contributes no flag bits. */
#define DEFAULT			0x00
86 
87 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
88 	[CLIENT_HELLO] = {
89 		{SERVER_HELLO, DEFAULT, 0, 0},
90 	},
91 	[SERVER_HELLO] = {
92 		{SERVER_ENCRYPTED_EXTENSIONS, DEFAULT, 0, 0},
93 		{CLIENT_HELLO_RETRY, WITH_HRR, 0, 0},
94 	},
95 	[CLIENT_HELLO_RETRY] = {
96 		{SERVER_ENCRYPTED_EXTENSIONS, DEFAULT, 0, 0},
97 	},
98 	[SERVER_ENCRYPTED_EXTENSIONS] = {
99 		{SERVER_CERTIFICATE_REQUEST, DEFAULT, 0, 0},
100 		{SERVER_CERTIFICATE, WITHOUT_CR, 0, 0},
101 		{SERVER_FINISHED, WITH_PSK, 0, 0},
102 	},
103 	[SERVER_CERTIFICATE_REQUEST] = {
104 		{SERVER_CERTIFICATE, DEFAULT, 0, 0},
105 	},
106 	[SERVER_CERTIFICATE] = {
107 		{SERVER_CERTIFICATE_VERIFY, DEFAULT, 0, 0},
108 	},
109 	[SERVER_CERTIFICATE_VERIFY] = {
110 		{SERVER_FINISHED, DEFAULT, 0, 0},
111 	},
112 	[SERVER_FINISHED] = {
113 		{CLIENT_FINISHED, DEFAULT, WITHOUT_CR | WITH_PSK, 0},
114 		{CLIENT_CERTIFICATE, DEFAULT, 0, WITHOUT_CR | WITH_PSK},
115 	},
116 	[CLIENT_CERTIFICATE] = {
117 		{CLIENT_FINISHED, DEFAULT, 0, 0},
118 		{CLIENT_CERTIFICATE_VERIFY, WITH_CCV, 0, 0},
119 	},
120 	[CLIENT_CERTIFICATE_VERIFY] = {
121 		{CLIENT_FINISHED, DEFAULT, 0, 0},
122 	},
123 	[CLIENT_FINISHED] = {
124 		{APPLICATION_DATA, DEFAULT, 0, 0},
125 	},
126 	[APPLICATION_DATA] = {
127 		{0, DEFAULT, 0, 0},
128 	},
129 };
130 
131 const size_t	 stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
132 
133 void		 build_table(enum tls13_message_type
134 		     table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
135 		     struct child current, struct child end,
136 		     struct child path[], uint8_t flags, unsigned int depth);
137 size_t		 count_handshakes(void);
138 void		 edge(enum tls13_message_type start,
139 		     enum tls13_message_type end, uint8_t flag);
140 const char	*flag2str(uint8_t flag);
141 void		 flag_label(uint8_t flag);
142 void		 forced_edges(enum tls13_message_type start,
143 		     enum tls13_message_type end, uint8_t forced);
144 int		 generate_graphics(void);
145 void		 fprint_entry(FILE *stream,
146 		     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
147 		     uint8_t flags);
148 void		 fprint_flags(FILE *stream, uint8_t flags);
149 const char	*mt2str(enum tls13_message_type mt);
150 __dead void	 usage(void);
151 int		 verify_table(enum tls13_message_type
152 		     table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES], int print);
153 
154 const char *
155 flag2str(uint8_t flag)
156 {
157 	const char *ret;
158 
159 	if (flag & (flag - 1))
160 		errx(1, "more than one bit is set");
161 
162 	switch (flag) {
163 	case INITIAL:
164 		ret = "INITIAL";
165 		break;
166 	case NEGOTIATED:
167 		ret = "NEGOTIATED";
168 		break;
169 	case WITHOUT_CR:
170 		ret = "WITHOUT_CR";
171 		break;
172 	case WITH_HRR:
173 		ret = "WITH_HRR";
174 		break;
175 	case WITH_PSK:
176 		ret = "WITH_PSK";
177 		break;
178 	case WITH_CCV:
179 		ret = "WITH_CCV";
180 		break;
181 	case WITH_0RTT:
182 		ret = "WITH_0RTT";
183 		break;
184 	default:
185 		ret = "UNKNOWN";
186 	}
187 
188 	return ret;
189 }
190 
191 const char *
192 mt2str(enum tls13_message_type mt)
193 {
194 	const char *ret;
195 
196 	switch (mt) {
197 	case INVALID:
198 		ret = "INVALID";
199 		break;
200 	case CLIENT_HELLO:
201 		ret = "CLIENT_HELLO";
202 		break;
203 	case CLIENT_HELLO_RETRY:
204 		ret = "CLIENT_HELLO_RETRY";
205 		break;
206 	case CLIENT_END_OF_EARLY_DATA:
207 		ret = "CLIENT_END_OF_EARLY_DATA";
208 		break;
209 	case CLIENT_CERTIFICATE:
210 		ret = "CLIENT_CERTIFICATE";
211 		break;
212 	case CLIENT_CERTIFICATE_VERIFY:
213 		ret = "CLIENT_CERTIFICATE_VERIFY";
214 		break;
215 	case CLIENT_FINISHED:
216 		ret = "CLIENT_FINISHED";
217 		break;
218 	case CLIENT_KEY_UPDATE:
219 		ret = "CLIENT_KEY_UPDATE";
220 		break;
221 	case SERVER_HELLO:
222 		ret = "SERVER_HELLO";
223 		break;
224 	case SERVER_NEW_SESSION_TICKET:
225 		ret = "SERVER_NEW_SESSION_TICKET";
226 		break;
227 	case SERVER_ENCRYPTED_EXTENSIONS:
228 		ret = "SERVER_ENCRYPTED_EXTENSIONS";
229 		break;
230 	case SERVER_CERTIFICATE:
231 		ret = "SERVER_CERTIFICATE";
232 		break;
233 	case SERVER_CERTIFICATE_VERIFY:
234 		ret = "SERVER_CERTIFICATE_VERIFY";
235 		break;
236 	case SERVER_CERTIFICATE_REQUEST:
237 		ret = "SERVER_CERTIFICATE_REQUEST";
238 		break;
239 	case SERVER_FINISHED:
240 		ret = "SERVER_FINISHED";
241 		break;
242 	case APPLICATION_DATA:
243 		ret = "APPLICATION_DATA";
244 		break;
245 	case TLS13_NUM_MESSAGE_TYPES:
246 		ret = "TLS13_NUM_MESSAGE_TYPES";
247 		break;
248 	default:
249 		ret = "UNKNOWN";
250 		break;
251 	}
252 
253 	return ret;
254 }
255 
/*
 * Print the flag set to stream as flag names joined by " | ", lowest bit
 * first.  An empty set is printed as the name flag2str() gives to 0.
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char	*sep = "";
	unsigned int	 bit;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(0));
		return;
	}

	/* Walk all eight bits of the uint8_t. */
	for (bit = 1; bit <= 0x80; bit <<= 1) {
		if ((flags & bit) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(flags & bit));
		sep = " | ";
	}
}
276 
277 void
278 fprint_entry(FILE *stream,
279     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
280 {
281 	int i;
282 
283 	fprintf(stream, "\t[");
284 	fprint_flags(stream, flags);
285 	fprintf(stream, "] = {\n");
286 
287 	for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
288 		if (path[i] == 0)
289 			break;
290 		fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
291 	}
292 	fprintf(stream, "\t},\n");
293 }
294 
295 void
296 edge(enum tls13_message_type start, enum tls13_message_type end,
297     uint8_t flag)
298 {
299 	printf("\t%s -> %s", mt2str(start), mt2str(end));
300 	flag_label(flag);
301 	printf(";\n");
302 }
303 
/* Print a graphviz label naming flag; print nothing when flag is 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
310 
311 void
312 forced_edges(enum tls13_message_type start, enum tls13_message_type end,
313     uint8_t forced)
314 {
315 	uint8_t	forced_flag, i;
316 
317 	if (forced == 0)
318 		return;
319 
320 	for (i = 0; i < 8; i++) {
321 		forced_flag = forced & (1U << i);
322 		if (forced_flag)
323 			edge(start, end, forced_flag);
324 	}
325 }
326 
327 int
328 generate_graphics(void)
329 {
330 	enum tls13_message_type	start, end;
331 	unsigned int		child;
332 	uint8_t			flag;
333 	uint8_t			forced;
334 
335 	printf("digraph G {\n");
336 	printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
337 	printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
338 
339 	for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
340 		for (child = 0; stateinfo[start][child].mt != 0; child++) {
341 			end = stateinfo[start][child].mt;
342 			flag = stateinfo[start][child].flag;
343 			forced = stateinfo[start][child].forced;
344 
345 			if (forced == 0)
346 				edge(start, end, flag);
347 			else
348 				forced_edges(start, end, forced);
349 		}
350 	}
351 
352 	printf("}\n");
353 	return 0;
354 }
355 
356 extern enum tls13_message_type	handshakes[][TLS13_NUM_MESSAGE_TYPES];
357 extern size_t			handshake_count;
358 
359 size_t
360 count_handshakes(void)
361 {
362 	size_t	ret = 0, i;
363 
364 	for (i = 0; i < handshake_count; i++) {
365 		if (handshakes[i][0] != INVALID)
366 			ret++;
367 	}
368 
369 	return ret;
370 }
371 
372 void
373 build_table(enum tls13_message_type table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
374     struct child current, struct child end, struct child path[], uint8_t flags,
375     unsigned int depth)
376 {
377 	unsigned int i;
378 
379 	if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
380 		errx(1, "recursed too deeply");
381 
382 	/* Record current node. */
383 	path[depth++] = current;
384 	flags |= current.flag;
385 
386 	/* If we haven't reached the end, recurse over the children. */
387 	if (current.mt != end.mt) {
388 		for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
389 			struct child child = stateinfo[current.mt][i];
390 			int forced = stateinfo[current.mt][i].forced;
391 			int illegal = stateinfo[current.mt][i].illegal;
392 
393 			if ((forced == 0 || (forced & flags)) &&
394 			    (illegal == 0 || !(illegal & flags)))
395 				build_table(table, child, end, path, flags,
396 				    depth);
397 		}
398 		return;
399 	}
400 
401 	if (flags == 0)
402 		errx(1, "path does not set flags");
403 
404 	if (table[flags][0] != 0)
405 		errx(1, "path traversed twice");
406 
407 	for (i = 0; i < depth; i++)
408 		table[flags][i] = path[i].mt;
409 }
410 
411 int
412 verify_table(enum tls13_message_type table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES],
413     int print)
414 {
415 	int	success = 1, i;
416 	size_t	num_valid, num_found = 0;
417 	uint8_t	flags = 0;
418 
419 	do {
420 		if (table[flags][0] == 0)
421 			continue;
422 
423 		num_found++;
424 
425 		for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
426 			if (table[flags][i] != handshakes[flags][i]) {
427 				fprintf(stderr,
428 				    "incorrect entry %d of handshake ", i);
429 				fprint_flags(stderr, flags);
430 				fprintf(stderr, "\n");
431 				success = 0;
432 			}
433 		}
434 
435 		if (print)
436 			fprint_entry(stdout, table[flags], flags);
437 	} while(++flags != 0);
438 
439 	num_valid = count_handshakes();
440 	if (num_valid != num_found) {
441 		fprintf(stderr,
442 		    "incorrect number of handshakes: want %zu, got %zu.\n",
443 		    num_valid, num_found);
444 		success = 0;
445 	}
446 
447 	return success;
448 }
449 
450 __dead void
451 usage(void)
452 {
453 	fprintf(stderr, "usage: handshake_table [-C | -g]\n");
454 	exit(1);
455 }
456 
457 int
458 main(int argc, char *argv[])
459 {
460 	static enum tls13_message_type
461 	    hs_table[UINT8_MAX][TLS13_NUM_MESSAGE_TYPES] = {
462 		[INITIAL] = {
463 			CLIENT_HELLO,
464 			SERVER_HELLO,
465 		},
466 	};
467 	struct child	start = {
468 		CLIENT_HELLO, DEFAULT, 0, 0,
469 	};
470 	struct child	end = {
471 		APPLICATION_DATA, DEFAULT, 0, 0,
472 	};
473 	struct child	path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
474 	uint8_t		flags = NEGOTIATED;
475 	unsigned int	depth = 0;
476 	int		ch, graphviz = 0, print = 0;
477 
478 	while ((ch = getopt(argc, argv, "Cg")) != -1) {
479 		switch (ch) {
480 		case 'C':
481 			print = 1;
482 			break;
483 		case 'g':
484 			graphviz = 1;
485 			break;
486 		default:
487 			usage();
488 		}
489 	}
490 	argc -= optind;
491 	argv += optind;
492 
493 	if (argc != 0)
494 		usage();
495 
496 	if (graphviz && print)
497 		usage();
498 
499 	if (graphviz)
500 		return generate_graphics();
501 
502 	build_table(hs_table, start, end, path, flags, depth);
503 	if (!verify_table(hs_table, print))
504 		return 1;
505 
506 	if (!print)
507 		printf("SUCCESS\n");
508 
509 	return 0;
510 }
511