xref: /openbsd-src/regress/lib/libcrypto/mlkem/parse_test_file.c (revision a22ce74601c5d6878965e0c1b7b4f0a3c9646976)
1 /*	$OpenBSD: parse_test_file.c,v 1.3 2024/12/27 11:17:48 tb Exp $ */
2 
3 /*
4  * Copyright (c) 2024 Theo Buehler <tb@openbsd.org>
5  *
6  * Permission to use, copy, modify, and distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 #include <sys/types.h>
20 
21 #include <assert.h>
22 #include <err.h>
23 #include <stdarg.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 
29 #include "bytestring.h"
30 
31 #include "parse_test_file.h"
32 
33 struct line_data {
34 	uint8_t	*data;
35 	size_t	 data_len;
36 	CBS	 cbs;
37 	int	 val;
38 };
39 
40 static struct line_data *
41 line_data_new(void)
42 {
43 	return calloc(1, sizeof(struct line_data));
44 }
45 
46 static void
47 line_data_clear(struct line_data *ld)
48 {
49 	freezero(ld->data, ld->data_len);
50 	explicit_bzero(ld, sizeof(*ld));
51 }
52 
/* Clear and free ld. A NULL pointer is accepted and ignored. */
static void
line_data_free(struct line_data *ld)
{
	if (ld != NULL) {
		line_data_clear(ld);
		free(ld);
	}
}
61 
62 static void
63 line_data_get_int(struct line_data *ld, int *out)
64 {
65 	*out = ld->val;
66 }
67 
68 static void
69 line_data_get_cbs(struct line_data *ld, CBS *out)
70 {
71 	CBS_dup(&ld->cbs, out);
72 }
73 
74 static void
75 line_data_set_int(struct line_data *ld, int val)
76 {
77 	ld->val = val;
78 }
79 
80 static int
81 line_data_set_from_cbb(struct line_data *ld, CBB *cbb)
82 {
83 	if (!CBB_finish(cbb, &ld->data, &ld->data_len))
84 		return 0;
85 
86 	CBS_init(&ld->cbs, ld->data, ld->data_len);
87 
88 	return 1;
89 }
90 
/* Bookkeeping for the position in the test file and the parsed line data. */
struct parse_state {
	size_t line;			/* bumped once per line read */
	size_t test;			/* bumped whenever cur wraps to 0 */

	/* Per-test lines: number of states, current state and its data. */
	size_t max;
	size_t cur;
	struct line_data **data;

	/* Instruction lines: same layout as the per-test lines above. */
	size_t instruction_max;
	size_t instruction_cur;
	struct line_data **instruction_data;

	int running_test_case;		/* set while run_test_case() runs */
};
105 
106 static void
107 parse_state_init(struct parse_state *ps, size_t max, size_t instruction_max)
108 {
109 	size_t i;
110 
111 	assert(max > 0);
112 
113 	memset(ps, 0, sizeof(*ps));
114 	ps->test = 1;
115 
116 	ps->max = max;
117 	if ((ps->data = calloc(max, sizeof(*ps->data))) == NULL)
118 		err(1, NULL);
119 	for (i = 0; i < max; i++) {
120 		if ((ps->data[i] = line_data_new()) == NULL)
121 			err(1, NULL);
122 	}
123 
124 	if ((ps->instruction_max = instruction_max) > 0) {
125 		if ((ps->instruction_data = calloc(instruction_max,
126 		    sizeof(*ps->instruction_data))) == NULL)
127 			err(1, NULL);
128 		for (i = 0; i < instruction_max; i++)
129 			if ((ps->instruction_data[i] = line_data_new()) == NULL)
130 				err(1, NULL);
131 	}
132 }
133 
134 static void
135 parse_state_finish(struct parse_state *ps)
136 {
137 	size_t i;
138 
139 	for (i = 0; i < ps->max; i++)
140 		line_data_free(ps->data[i]);
141 	free(ps->data);
142 
143 	for (i = 0; i < ps->instruction_max; i++)
144 		line_data_free(ps->instruction_data[i]);
145 	free(ps->instruction_data);
146 }
147 
148 static void
149 parse_state_new_line(struct parse_state *ps)
150 {
151 	ps->line++;
152 }
153 
154 static void
155 parse_instruction_advance(struct parse_state *ps)
156 {
157 	assert(ps->instruction_cur < ps->instruction_max);
158 	ps->instruction_cur++;
159 }
160 
161 static void
162 parse_state_advance(struct parse_state *ps)
163 {
164 	assert(ps->cur < ps->max);
165 
166 	ps->cur++;
167 	if ((ps->cur %= ps->max) == 0)
168 		ps->test++;
169 }
170 
171 struct parse {
172 	struct parse_state state;
173 
174 	char	*buf;
175 	size_t	 buf_max;
176 	CBS	 cbs;
177 
178 	const struct test_parse *tctx;
179 	void *ctx;
180 
181 	const char *fn;
182 	FILE *fp;
183 };
184 
185 static int
186 parse_instructions_parsed(struct parse *p)
187 {
188 	return p->state.instruction_max == p->state.instruction_cur;
189 }
190 
191 static void
192 parse_advance(struct parse *p)
193 {
194 	if (!parse_instructions_parsed(p)) {
195 		parse_instruction_advance(&p->state);
196 		return;
197 	}
198 	parse_state_advance(&p->state);
199 }
200 
201 static size_t
202 parse_max(struct parse *p)
203 {
204 	return p->state.max;
205 }
206 
207 static size_t
208 parse_instruction_max(struct parse *p)
209 {
210 	return p->state.instruction_max;
211 }
212 
213 static size_t
214 parse_cur(struct parse *p)
215 {
216 	if (!parse_instructions_parsed(p)) {
217 		assert(p->state.instruction_cur < p->state.instruction_max);
218 		return p->state.instruction_cur;
219 	}
220 
221 	assert(p->state.cur < parse_max(p));
222 	return p->state.cur;
223 }
224 
/*
 * Nonzero if the instructions are done and the current state is the last
 * one of the test case.
 */
static size_t
parse_must_run_test_case(struct parse *p)
{
	if (!parse_instructions_parsed(p))
		return 0;

	return parse_max(p) - parse_cur(p) == 1;
}
230 
231 static const struct line_spec *
232 parse_states(struct parse *p)
233 {
234 	if (!parse_instructions_parsed(p))
235 		return p->tctx->instructions;
236 	return p->tctx->states;
237 }
238 
239 static const struct line_spec *
240 parse_instruction_states(struct parse *p)
241 {
242 	return p->tctx->instructions;
243 }
244 
245 static const struct line_spec *
246 parse_state(struct parse *p)
247 {
248 	return &parse_states(p)[parse_cur(p)];
249 }
250 
251 static size_t
252 line(struct parse *p)
253 {
254 	return p->state.line;
255 }
256 
257 static size_t
258 test(struct parse *p)
259 {
260 	return p->state.test;
261 }
262 
263 static const char *
264 name(struct parse *p)
265 {
266 	if (p->state.running_test_case)
267 		return "running test case";
268 	return parse_state(p)->name;
269 }
270 
271 static const char *
272 label(struct parse *p)
273 {
274 	return parse_state(p)->label;
275 }
276 
277 static const char *
278 match(struct parse *p)
279 {
280 	return parse_state(p)->match;
281 }
282 
283 static enum line
284 parse_line_type(struct parse *p)
285 {
286 	return parse_state(p)->type;
287 }
288 
289 static void
290 parse_vinfo(struct parse *p, const char *fmt, va_list ap)
291 {
292 	fprintf(stderr, "%s:%zu test #%zu (%s): ",
293 	    p->fn, line(p), test(p), name(p));
294 	vfprintf(stderr, fmt, ap);
295 	fprintf(stderr, "\n");
296 }
297 
/* printf-style diagnostic with the position prefix added by parse_vinfo(). */
void
parse_info(struct parse *p, const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	parse_vinfo(p, fmt, args);
	va_end(args);
}
307 
/* Like parse_info(), but exits with status 1 after printing. */
void
parse_errx(struct parse *p, const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	parse_vinfo(p, fmt, args);
	va_end(args);

	exit(1);
}
319 
/*
 * Compare two lengths. Returns 1 if they agree; otherwise reports the
 * mismatch for descr and returns 0.
 */
int
parse_length_equal(struct parse *p, const char *descr, size_t want, size_t got)
{
	if (want != got) {
		parse_info(p, "%s length: want %zu, got %zu", descr, want, got);
		return 0;
	}

	return 1;
}
329 
/*
 * Dump len bytes of buf to stderr as 0x-prefixed hex values, eight per
 * line. If compare is not NULL, each byte is preceded by "*" if it differs
 * from the byte at the same offset in compare and by " " if it does not.
 */
static void
hexdump(const uint8_t *buf, size_t len, const uint8_t *compare)
{
	size_t pos;

	for (pos = 0; pos < len; pos++) {
		const char *mark = "";
		const char *eol = "";

		if (compare != NULL)
			mark = buf[pos] == compare[pos] ? " " : "*";
		/* Break the line after every eighth byte. */
		if ((pos + 1) % 8 == 0)
			eol = "\n";

		fprintf(stderr, " %s0x%02x,%s", mark, buf[pos], eol);
	}

	/* Terminate a trailing partial line. */
	if (len % 8 != 0)
		fprintf(stderr, "\n");
}
345 
346 int
347 parse_data_equal(struct parse *p, const char *descr, CBS *want,
348     const uint8_t *got, size_t got_len)
349 {
350 	if (!parse_length_equal(p, descr, CBS_len(want), got_len))
351 		return 0;
352 	if (CBS_mem_equal(want, got, got_len))
353 		return 1;
354 
355 	parse_info(p, "%s differs", descr);
356 	fprintf(stderr, "want:\n");
357 	hexdump(CBS_data(want), CBS_len(want), got);
358 	fprintf(stderr, "got:\n");
359 	hexdump(got, got_len, CBS_data(want));
360 	fprintf(stderr, "\n");
361 
362 	return 0;
363 }
364 
365 static void
366 parse_line_data_clear(struct parse *p)
367 {
368 	size_t i;
369 
370 	for (i = 0; i < parse_max(p); i++)
371 		line_data_clear(p->state.data[i]);
372 }
373 
374 static struct line_data **
375 parse_state_data(struct parse *p)
376 {
377 	if (!parse_instructions_parsed(p))
378 		return p->state.instruction_data;
379 	return p->state.data;
380 }
381 
382 static void
383 parse_state_set_int(struct parse *p, int val)
384 {
385 	if (parse_line_type(p) != LINE_STRING_MATCH)
386 		parse_errx(p, "%s: want %d, got %d", __func__,
387 		    LINE_STRING_MATCH, parse_line_type(p));
388 	line_data_set_int(parse_state_data(p)[parse_cur(p)], val);
389 }
390 
391 static void
392 parse_state_set_from_cbb(struct parse *p, CBB *cbb)
393 {
394 	if (parse_line_type(p) != LINE_HEX)
395 		parse_errx(p, "%s: want %d, got %d", __func__,
396 		    LINE_HEX, parse_line_type(p));
397 	if (!line_data_set_from_cbb(parse_state_data(p)[parse_cur(p)], cbb))
398 		parse_errx(p, "line_data_set_from_cbb");
399 }
400 
401 int
402 parse_get_int(struct parse *p, size_t idx, int *out)
403 {
404 	assert(parse_must_run_test_case(p));
405 	assert(idx < parse_max(p));
406 	assert(parse_states(p)[idx].type == LINE_STRING_MATCH);
407 
408 	line_data_get_int(p->state.data[idx], out);
409 
410 	return 1;
411 }
412 
413 int
414 parse_get_cbs(struct parse *p, size_t idx, CBS *out)
415 {
416 	assert(parse_must_run_test_case(p));
417 	assert(idx < parse_max(p));
418 	assert(parse_states(p)[idx].type == LINE_HEX);
419 
420 	line_data_get_cbs(p->state.data[idx], out);
421 
422 	return 1;
423 }
424 
425 int
426 parse_instruction_get_int(struct parse *p, size_t idx, int *out)
427 {
428 	assert(parse_must_run_test_case(p));
429 	assert(idx < parse_instruction_max(p));
430 	assert(parse_instruction_states(p)[idx].type == LINE_STRING_MATCH);
431 
432 	line_data_get_int(p->state.instruction_data[idx], out);
433 
434 	return 1;
435 }
436 
437 int
438 parse_instruction_get_cbs(struct parse *p, size_t idx, CBS *out)
439 {
440 	assert(parse_must_run_test_case(p));
441 	assert(idx < parse_instruction_max(p));
442 	assert(parse_instruction_states(p)[idx].type == LINE_HEX);
443 
444 	line_data_get_cbs(p->state.instruction_data[idx], out);
445 
446 	return 1;
447 }
448 
449 static void
450 parse_line_skip_to_end(struct parse *p)
451 {
452 	if (!CBS_skip(&p->cbs, CBS_len(&p->cbs)))
453 		parse_errx(p, "CBS_skip");
454 }
455 
456 static int
457 CBS_peek_bytes(CBS *cbs, CBS *out, size_t len)
458 {
459 	CBS dup;
460 
461 	CBS_dup(cbs, &dup);
462 	return CBS_get_bytes(&dup, out, len);
463 }
464 
465 static int
466 parse_peek_string_cbs(struct parse *p, const char *str)
467 {
468 	CBS cbs;
469 	size_t len = strlen(str);
470 
471 	if (!CBS_peek_bytes(&p->cbs, &cbs, len))
472 		parse_errx(p, "CBS_peek_bytes");
473 
474 	return CBS_mem_equal(&cbs, (const uint8_t *)str, len);
475 }
476 
477 static int
478 parse_get_string_cbs(struct parse *p, const char *str)
479 {
480 	CBS cbs;
481 	size_t len = strlen(str);
482 
483 	if (!CBS_get_bytes(&p->cbs, &cbs, len))
484 		parse_errx(p, "CBS_get_bytes");
485 
486 	return CBS_mem_equal(&cbs, (const uint8_t *)str, len);
487 }
488 
489 static int
490 parse_get_string_end_cbs(struct parse *p, const char *str)
491 {
492 	CBS cbs;
493 	int equal = 1;
494 
495 	CBS_init(&cbs, (const uint8_t *)str, strlen(str));
496 
497 	if (CBS_len(&p->cbs) < CBS_len(&cbs))
498 		parse_errx(p, "line too short to match %s", str);
499 
500 	while (CBS_len(&cbs) > 0) {
501 		uint8_t want, got;
502 
503 		if (!CBS_get_last_u8(&cbs, &want))
504 			parse_errx(p, "CBS_get_last_u8");
505 		if (!CBS_get_last_u8(&p->cbs, &got))
506 			parse_errx(p, "CBS_get_last_u8");
507 		if (want != got)
508 			equal = 0;
509 	}
510 
511 	return equal;
512 }
513 
/*
 * Consume the label of the current line spec followed by a separator,
 * which is either ": " or " = ". Exits with an error if either is missing.
 */
static void
parse_check_label_matches(struct parse *p)
{
	const char *sep;

	if (!parse_get_string_cbs(p, label(p)))
		parse_errx(p, "label mismatch %s", label(p));

	/* Anything that does not start with ": " has to be " = ". */
	sep = parse_peek_string_cbs(p, ": ") ? ": " : " = ";
	if (!parse_get_string_cbs(p, sep))
		parse_errx(p, "error getting \"%s\"", sep);
}
528 
529 static int
530 parse_empty_or_comment_line(struct parse *p)
531 {
532 	if (CBS_len(&p->cbs) == 0) {
533 		return 1;
534 	}
535 	if (parse_peek_string_cbs(p, "#")) {
536 		parse_line_skip_to_end(p);
537 		return 1;
538 	}
539 	return 0;
540 }
541 
542 static void
543 parse_string_match_line(struct parse *p)
544 {
545 	int string_matches;
546 
547 	parse_check_label_matches(p);
548 
549 	string_matches = parse_get_string_cbs(p, match(p));
550 	parse_state_set_int(p, string_matches);
551 
552 	if (!string_matches)
553 		parse_line_skip_to_end(p);
554 }
555 
556 static int
557 parse_get_hex_nibble_cbs(CBS *cbs, uint8_t *out_nibble)
558 {
559 	uint8_t c;
560 
561 	if (!CBS_get_u8(cbs, &c))
562 		return 0;
563 
564 	if (c >= '0' && c <= '9') {
565 		*out_nibble = c - '0';
566 		return 1;
567 	}
568 	if (c >= 'a' && c <= 'f') {
569 		*out_nibble = c - 'a' + 10;
570 		return 1;
571 	}
572 	if (c >= 'A' && c <= 'F') {
573 		*out_nibble = c - 'A' + 10;
574 		return 1;
575 	}
576 
577 	return 0;
578 }
579 
580 static void
581 parse_hex_line(struct parse *p)
582 {
583 	CBB cbb;
584 
585 	parse_check_label_matches(p);
586 
587 	if (!CBB_init(&cbb, 0))
588 		parse_errx(p, "CBB_init");
589 
590 	while (CBS_len(&p->cbs) > 0) {
591 		uint8_t hi, lo;
592 
593 		if (!parse_get_hex_nibble_cbs(&p->cbs, &hi))
594 			parse_errx(p, "parse_get_hex_nibble_cbs");
595 		if (!parse_get_hex_nibble_cbs(&p->cbs, &lo))
596 			parse_errx(p, "parse_get_hex_nibble_cbs");
597 
598 		if (!CBB_add_u8(&cbb, hi << 4 | lo))
599 			parse_errx(p, "CBB_add_u8");
600 	}
601 
602 	parse_state_set_from_cbb(p, &cbb);
603 }
604 
605 static void
606 parse_maybe_prepare_instruction_line(struct parse *p)
607 {
608 	if (parse_instructions_parsed(p))
609 		return;
610 
611 	/* Should not happen due to parse_empty_or_comment_line(). */
612 	if (CBS_len(&p->cbs) == 0)
613 		parse_errx(p, "empty instruction line");
614 
615 	if (!parse_peek_string_cbs(p, "["))
616 		parse_errx(p, "expected instruction line");
617 	if (!parse_get_string_cbs(p, "["))
618 		parse_errx(p, "expected start of instruction line");
619 	if (!parse_get_string_end_cbs(p, "]"))
620 		parse_errx(p, "expected end of instruction line");
621 }
622 
623 static void
624 parse_check_line_consumed(struct parse *p)
625 {
626 	if (CBS_len(&p->cbs) > 0)
627 		parse_errx(p, "%zu unprocessed bytes", CBS_len(&p->cbs));
628 }
629 
630 static int
631 parse_run_test_case(struct parse *p)
632 {
633 	const struct test_parse *tctx = p->tctx;
634 
635 	p->state.running_test_case = 1;
636 	return tctx->run_test_case(p->ctx);
637 }
638 
639 static void
640 parse_reinit(struct parse *p)
641 {
642 	const struct test_parse *tctx = p->tctx;
643 
644 	p->state.running_test_case = 0;
645 	parse_line_data_clear(p);
646 	tctx->finish(p->ctx);
647 	tctx->init(p->ctx, p);
648 }
649 
/*
 * If the line just parsed was the last one of a test case, run the test
 * case and reinitialize. In any case advance to the next line spec.
 * Returns the result of the test case, or 0 if none was run.
 */
static int
parse_maybe_run_test_case(struct parse *p)
{
	int failed;

	if (!parse_must_run_test_case(p)) {
		parse_advance(p);
		return 0;
	}

	failed = parse_run_test_case(p);
	parse_reinit(p);
	parse_advance(p);

	return failed;
}
664 
665 static int
666 parse_process_line(struct parse *p)
667 {
668 	if (parse_empty_or_comment_line(p))
669 		return 0;
670 
671 	parse_maybe_prepare_instruction_line(p);
672 
673 	switch (parse_line_type(p)) {
674 	case LINE_STRING_MATCH:
675 		parse_string_match_line(p);
676 		break;
677 	case LINE_HEX:
678 		parse_hex_line(p);
679 		break;
680 	default:
681 		parse_errx(p, "unknown line type %d", parse_line_type(p));
682 	}
683 	parse_check_line_consumed(p);
684 
685 	return parse_maybe_run_test_case(p);
686 }
687 
688 static void
689 parse_init(struct parse *p, const char *fn, const struct test_parse *tctx,
690     void *ctx)
691 {
692 	FILE *fp;
693 
694 	memset(p, 0, sizeof(*p));
695 
696 	if ((fp = fopen(fn, "r")) == NULL)
697 		err(1, "error opening %s", fn);
698 
699 	/* Poor man's basename since POSIX basename is stupid. */
700 	if ((p->fn = strrchr(fn, '/')) != NULL)
701 		p->fn++;
702 	else
703 		p->fn = fn;
704 
705 	p->fp = fp;
706 	parse_state_init(&p->state, tctx->num_states, tctx->num_instructions);
707 	p->tctx = tctx;
708 	p->ctx = ctx;
709 	tctx->init(ctx, p);
710 }
711 
712 static int
713 parse_next_line(struct parse *p)
714 {
715 	ssize_t len;
716 	uint8_t u8;
717 
718 	if ((len = getline(&p->buf, &p->buf_max, p->fp)) == -1)
719 		return 0;
720 
721 	CBS_init(&p->cbs, (const uint8_t *)p->buf, len);
722 	parse_state_new_line(&p->state);
723 
724 	if (!CBS_get_last_u8(&p->cbs, &u8))
725 		parse_errx(p, "CBS_get_last_u8");
726 
727 	assert(u8 == '\n');
728 
729 	return 1;
730 }
731 
732 static void
733 parse_finish(struct parse *p)
734 {
735 	parse_state_finish(&p->state);
736 
737 	free(p->buf);
738 
739 	if (ferror(p->fp))
740 		err(1, "%s", p->fn);
741 	fclose(p->fp);
742 }
743 
744 int
745 parse_test_file(const char *fn, const struct test_parse *tctx, void *ctx)
746 {
747 	struct parse p;
748 	int failed = 0;
749 
750 	parse_init(&p, fn, tctx, ctx);
751 
752 	while (parse_next_line(&p))
753 		failed |= parse_process_line(&p);
754 
755 	parse_finish(&p);
756 
757 	return failed;
758 }
759