xref: /openbsd-src/regress/lib/libcrypto/mlkem/mlkem_tests_util.c (revision 8889493e35cc5c346e19d64639edddcd18b8ae19)
1 /*	$OpenBSD: mlkem_tests_util.c,v 1.3 2024/12/20 00:07:12 tb Exp $ */
2 /*
3  * Copyright (c) 2024 Google Inc.
4  * Copyright (c) 2024 Bob Beck <beck@obtuse.com>
5  * Copyright (c) 2024 Theo Buehler <tb@openbsd.org>
6  *
7  * Permission to use, copy, modify, and/or distribute this software for any
8  * purpose with or without fee is hereby granted, provided that the above
9  * copyright notice and this permission notice appear in all copies.
10  *
11  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
12  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
13  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
14  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
15  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
16  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
17  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18  */
19 
20 #include <err.h>
21 #include <stdint.h>
22 #include <stdio.h>
23 #include <string.h>
24 
25 #include "bytestring.h"
26 #include "mlkem.h"
27 
28 #include "mlkem_internal.h"
29 
30 #include "mlkem_tests_util.h"
31 
/*
 * Test-harness state shared with other translation units. Neither variable is
 * referenced in this file; presumably the test drivers record failures and the
 * current test vector number here -- confirm against the callers.
 */
int failure = 0;
int test_number = 0;
34 
/*
 * Write the len bytes of buf to stderr as " 0xNN," items, eight per line.
 * If compare is non-NULL, each item is prefixed with '*' when the byte differs
 * from the byte at the same offset in compare, and with ' ' when it matches.
 * The dump is always followed by one extra newline.
 */
static void
hexdump(const uint8_t *buf, size_t len, const uint8_t *compare)
{
	size_t off;

	for (off = 0; off < len; off++) {
		const char *mark = "";
		const char *eol = "";

		if (compare != NULL)
			mark = buf[off] == compare[off] ? " " : "*";

		/* Break the line after every eighth byte and after the last one. */
		if ((off + 1) % 8 == 0 || off + 1 == len)
			eol = "\n";

		fprintf(stderr, " %s0x%02hhx,%s", mark, buf[off], eol);
	}
	fprintf(stderr, "\n");
}
48 }
49 
/*
 * Compare the first len bytes of want and got.
 *
 * Returns 0 when they are identical. Otherwise warns with the caller's line
 * number and msg, dumps both buffers to stderr with mismatching bytes marked
 * by '*', and returns 1. Note this is the opposite sense of compare_length(),
 * which returns 1 on a match.
 */
int
compare_data(const uint8_t *want, const uint8_t *got, size_t len, size_t line,
    const char *msg)
{
	int differs = memcmp(want, got, len) != 0;

	if (differs) {
		warnx("FAIL: #%zu - %s differs", line, msg);
		fputs("want:\n", stderr);
		hexdump(want, len, got);
		fputs("got:\n", stderr);
		hexdump(got, len, want);
		fputs("\n", stderr);
	}

	return differs;
}
65 }
66 
/*
 * Check that a length matches its expected value.
 *
 * Returns 1 when want == got. Otherwise warns with the caller's line number,
 * msg and both values, and returns 0. Note this is the opposite sense of
 * compare_data(), which returns 0 on a match.
 */
int
compare_length(size_t want, size_t got, size_t line, const char *msg)
{
	if (want != got) {
		warnx("#%zu: %s: want %zu, got %zu", line, msg, want, got);
		return 0;
	}

	return 1;
}
75 }
76 
77 static int
78 hex_get_nibble_cbs(CBS *cbs, uint8_t *out_nibble)
79 {
80 	uint8_t c;
81 
82 	if (!CBS_get_u8(cbs, &c))
83 		return 0;
84 
85 	if (c >= '0' && c <= '9') {
86 		*out_nibble = c - '0';
87 		return 1;
88 	}
89 	if (c >= 'a' && c <= 'f') {
90 		*out_nibble = c - 'a' + 10;
91 		return 1;
92 	}
93 	if (c >= 'A' && c <= 'F') {
94 		*out_nibble = c - 'A' + 10;
95 		return 1;
96 	}
97 
98 	return 0;
99 }
100 
101 void
102 hex_decode_cbs(CBS *cbs, CBB *cbb, size_t line, const char *msg)
103 {
104 	if (!CBB_init(cbb, 0))
105 		errx(1, "#%zu %s: %s CBB_init", line, msg, __func__);
106 
107 	while (CBS_len(cbs) > 0) {
108 		uint8_t hi, lo;
109 
110 		if (!hex_get_nibble_cbs(cbs, &hi))
111 			errx(1, "#%zu %s: %s nibble", line, msg, __func__);
112 		if (!hex_get_nibble_cbs(cbs, &lo))
113 			errx(1, "#%zu %s: %s nibble", line, msg, __func__);
114 
115 		if (!CBB_add_u8(cbb, hi << 4 | lo))
116 			errx(1, "#%zu %s: %s CBB_add_u8", line, msg, __func__);
117 	}
118 }
119 
120 int
121 get_string_cbs(CBS *cbs_in, const char *str, size_t line, const char *msg)
122 {
123 	CBS cbs;
124 	size_t len = strlen(str);
125 
126 	if (!CBS_get_bytes(cbs_in, &cbs, len))
127 		errx(1, "#%zu %s: %s CBB_get_bytes", line, msg, __func__);
128 
129 	return CBS_mem_equal(&cbs, str, len);
130 }
131 
132 int
133 mlkem768_encode_private_key(const struct MLKEM768_private_key *priv,
134     uint8_t **out_buf, size_t *out_len)
135 {
136 	CBB cbb;
137 	int ret = 0;
138 
139 	if (!CBB_init(&cbb, MLKEM768_PUBLIC_KEY_BYTES))
140 		goto err;
141 	if (!MLKEM768_marshal_private_key(&cbb, priv))
142 		goto err;
143 	if (!CBB_finish(&cbb, out_buf, out_len))
144 		goto err;
145 
146 	ret = 1;
147 
148  err:
149 	CBB_cleanup(&cbb);
150 
151 	return ret;
152 }
153 
154 int
155 mlkem768_encode_public_key(const struct MLKEM768_public_key *pub,
156     uint8_t **out_buf, size_t *out_len)
157 {
158 	CBB cbb;
159 	int ret = 0;
160 
161 	if (!CBB_init(&cbb, MLKEM768_PUBLIC_KEY_BYTES))
162 		goto err;
163 	if (!MLKEM768_marshal_public_key(&cbb, pub))
164 		goto err;
165 	if (!CBB_finish(&cbb, out_buf, out_len))
166 		goto err;
167 
168 	ret = 1;
169 
170  err:
171 	CBB_cleanup(&cbb);
172 
173 	return ret;
174 }
175 
176 int
177 mlkem1024_encode_private_key(const struct MLKEM1024_private_key *priv,
178     uint8_t **out_buf, size_t *out_len)
179 {
180 	CBB cbb;
181 	int ret = 0;
182 
183 	if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES))
184 		goto err;
185 	if (!MLKEM1024_marshal_private_key(&cbb, priv))
186 		goto err;
187 	if (!CBB_finish(&cbb, out_buf, out_len))
188 		goto err;
189 
190 	ret = 1;
191 
192  err:
193 	CBB_cleanup(&cbb);
194 
195 	return ret;
196 }
197 
198 int
199 mlkem1024_encode_public_key(const struct MLKEM1024_public_key *pub,
200     uint8_t **out_buf, size_t *out_len)
201 {
202 	CBB cbb;
203 	int ret = 0;
204 
205 	if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES))
206 		goto err;
207 	if (!MLKEM1024_marshal_public_key(&cbb, pub))
208 		goto err;
209 	if (!CBB_finish(&cbb, out_buf, out_len))
210 		goto err;
211 
212 	ret = 1;
213 
214  err:
215 	CBB_cleanup(&cbb);
216 
217 	return ret;
218 }
219