xref: /openbsd-src/regress/lib/libcrypto/bn/bn_isqrt.c (revision 12347e819a161ea252d407d10e9bab13ecdbe9e2)
1 /*	$OpenBSD: bn_isqrt.c,v 1.4 2023/08/03 18:53:56 tb Exp $ */
2 /*
3  * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <unistd.h>
24 
25 #include <openssl/bn.h>
26 
27 #include "bn_local.h"
28 
/* Rounds of isqrt_test() run by main(), and random cases per round. */
#define N_TESTS		100

/*
 * Square roots n are sampled with 2^LOWER_BITS <= n < 2^UPPER_BITS, so the
 * perfect squares under test satisfy 2^(2 * LOWER_BITS) <= n^2 <
 * 2^(2 * UPPER_BITS), i.e. they lie between 2^256 and 2^8192.
 */
#define LOWER_BITS	128
#define UPPER_BITS	4096

/*
 * Quadratic-residue lookup tables provided by the library.  For modulus m,
 * entry r is expected to be 1 if r == k^2 (mod m) for some k and 0
 * otherwise; check_tables() and validate_tables() verify this.
 */
extern const uint8_t is_square_mod_11[];
extern const uint8_t is_square_mod_63[];
extern const uint8_t is_square_mod_64[];
extern const uint8_t is_square_mod_65[];
39 
/*
 * Write len bytes of buf to stderr as a comma-separated list of hex
 * literals, eight per line; the final line is always newline-terminated.
 */
static void
hexdump(const unsigned char *buf, size_t len)
{
	size_t off;

	for (off = 0; off < len; off++) {
		/* Break the line after every eighth byte. */
		const char *sep = ((off + 1) % 8 == 0) ? "\n" : "";

		fprintf(stderr, " 0x%02hhx,%s", buf[off], sep);
	}

	/* A partial last line still needs its newline. */
	if (len % 8 != 0)
		fprintf(stderr, "\n");
}
51 
52 static const uint8_t *
get_table(int modulus)53 get_table(int modulus)
54 {
55 	switch (modulus) {
56 	case 11:
57 		return is_square_mod_11;
58 	case 63:
59 		return is_square_mod_63;
60 	case 64:
61 		return is_square_mod_64;
62 	case 65:
63 		return is_square_mod_65;
64 	default:
65 		return NULL;
66 	}
67 }
68 
/*
 * Recompute the quadratic-residue table for each modulus by marking
 * (j * j) % m for every 0 <= j < m, and compare it with the table returned
 * by get_table().  On a mismatch both are dumped to stderr: "want" is the
 * recomputed reference table, "got" is the table under test.
 *
 * If print is non-zero, every table that matches is also written to stdout
 * as C source (array definition plus a CTASSERT on its size).
 *
 * Returns 0 if all tables match and 1 otherwise.
 */
static int
check_tables(int print)
{
	static const int moduli[] = {11, 63, 64, 65};
	const uint8_t *table;
	uint8_t q[65];		/* must hold the largest modulus */
	size_t i;
	int j, m;
	int failed = 0;

	for (i = 0; i < sizeof(moduli) / sizeof(moduli[0]); i++) {
		m = moduli[i];

		/* Guard the fixed-size scratch table against new moduli. */
		if ((size_t)m > sizeof(q)) {
			fprintf(stderr, "table %d too large\n", m);
			failed |= 1;
			continue;
		}

		memset(q, 0, sizeof(q));
		for (j = 0; j < m; j++)
			q[(j * j) % m] = 1;

		if ((table = get_table(m)) == NULL) {
			fprintf(stderr, "failed to get table %d\n", m);
			failed |= 1;
			continue;
		}

		if (memcmp(table, q, m) != 0) {
			/* q is the reference; table is what the library has. */
			fprintf(stderr, "table %d does not match:\n", m);
			fprintf(stderr, "want:\n");
			hexdump(q, m);
			fprintf(stderr, "got:\n");
			hexdump(table, m);
			failed |= 1;
			continue;
		}

		if (!print)
			continue;

		printf("const uint8_t is_square_mod_%d[] = {\n\t", m);
		for (j = 0; j < m; j++) {
			const char *end = " ";

			/* 16 entries per line, no separator after the last. */
			if (j % 16 == 15)
				end = "\n\t";
			if (j + 1 == m)
				end = "";

			printf("%d,%s", q[j], end);
		}
		printf("\n};\nCTASSERT(sizeof(is_square_mod_%d) == %d);\n\n",
		    m, m);
	}

	return failed;
}
121 
/*
 * Verify each library quadratic-residue table entry by brute force:
 * for modulus m, table[r] must be 1 if r == k^2 (mod m) for some
 * 0 <= k < m and 0 otherwise.  Each bad entry is reported on its own
 * line on stderr.
 *
 * Returns 0 if every entry is correct and 1 otherwise.
 */
static int
validate_tables(void)
{
	static const int moduli[] = {11, 63, 64, 65};
	const uint8_t *table;
	size_t i;
	int m, r, k;
	int failed = 0;

	for (i = 0; i < sizeof(moduli) / sizeof(moduli[0]); i++) {
		m = moduli[i];

		if ((table = get_table(m)) == NULL) {
			fprintf(stderr, "failed to get table %d\n", m);
			failed |= 1;
			continue;
		}

		for (r = 0; r < m; r++) {
			/* Search for k with k^2 == r (mod m); k == m if none. */
			for (k = 0; k < m; k++) {
				if (r == (k * k) % m)
					break;
			}

			/* Entries are flags; anything else is a broken table. */
			if (table[r] > 1) {
				fprintf(stderr, "table %d: entry %d is %d, "
				    "want 0 or 1\n", m, r, table[r]);
				failed |= 1;
				continue;
			}
			if (table[r] == 0 && k < m) {
				fprintf(stderr, "%d == %d^2 (mod %d)\n", r, k,
				    m);
				failed |= 1;
			}
			if (table[r] == 1 && k == m) {
				fprintf(stderr, "%d not a square (mod %d)\n", r,
				    m);
				failed |= 1;
			}
		}
	}

	return failed;
}
159 
160 /*
161  * Choose a random number n of bit length between LOWER_BITS and UPPER_BITS and
162  * check that n == isqrt(n^2). Random numbers n^2 <= testcase < (n + 1)^2 are
163  * checked to have isqrt(testcase) == n.
164  */
165 static int
isqrt_test(void)166 isqrt_test(void)
167 {
168 	BN_CTX *ctx;
169 	BIGNUM *n, *n_sqr, *lower, *upper, *testcase, *isqrt;
170 	int cmp, i, is_perfect_square;
171 	int failed = 0;
172 
173 	if ((ctx = BN_CTX_new()) == NULL)
174 		errx(1, "BN_CTX_new");
175 
176 	BN_CTX_start(ctx);
177 
178 	if ((lower = BN_CTX_get(ctx)) == NULL)
179 		errx(1, "lower = BN_CTX_get(ctx)");
180 	if ((upper = BN_CTX_get(ctx)) == NULL)
181 		errx(1, "upper = BN_CTX_get(ctx)");
182 	if ((n = BN_CTX_get(ctx)) == NULL)
183 		errx(1, "n = BN_CTX_get(ctx)");
184 	if ((n_sqr = BN_CTX_get(ctx)) == NULL)
185 		errx(1, "n = BN_CTX_get(ctx)");
186 	if ((isqrt = BN_CTX_get(ctx)) == NULL)
187 		errx(1, "result = BN_CTX_get(ctx)");
188 	if ((testcase = BN_CTX_get(ctx)) == NULL)
189 		errx(1, "testcase = BN_CTX_get(ctx)");
190 
191 	/* lower = 2^LOWER_BITS, upper = 2^UPPER_BITS. */
192 	if (!BN_set_bit(lower, LOWER_BITS))
193 		errx(1, "BN_set_bit(lower, %d)", LOWER_BITS);
194 	if (!BN_set_bit(upper, UPPER_BITS))
195 		errx(1, "BN_set_bit(upper, %d)", UPPER_BITS);
196 
197 	if (!bn_rand_in_range(n, lower, upper))
198 		errx(1, "bn_rand_in_range n");
199 
200 	/* n_sqr = n^2 */
201 	if (!BN_sqr(n_sqr, n, ctx))
202 		errx(1, "BN_sqr");
203 
204 	if (!bn_isqrt(isqrt, &is_perfect_square, n_sqr, ctx))
205 		errx(1, "bn_isqrt n_sqr");
206 
207 	if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
208 		fprintf(stderr, "n = ");
209 		BN_print_fp(stderr, n);
210 		fprintf(stderr, "\nn^2 is_perfect_square: %d, cmp: %d\n",
211 		    is_perfect_square, cmp);
212 		failed = 1;
213 	}
214 
215 	/* upper = 2 * n + 1 */
216 	if (!BN_lshift1(upper, n))
217 		errx(1, "BN_lshift1(upper, n)");
218 	if (!BN_add_word(upper, 1))
219 		errx(1, "BN_sub_word(upper, 1)");
220 
221 	/* upper = (n + 1)^2 = n^2 + upper */
222 	if (!BN_add(upper, n_sqr, upper))
223 		errx(1, "BN_add");
224 
225 	/*
226 	 * Check that isqrt((n + 1)^2) - 1 == n.
227 	 */
228 
229 	if (!bn_isqrt(isqrt, &is_perfect_square, upper, ctx))
230 		errx(1, "bn_isqrt(upper)");
231 
232 	if (!BN_sub_word(isqrt, 1))
233 		errx(1, "BN_add_word(isqrt, 1)");
234 
235 	if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
236 		fprintf(stderr, "n = ");
237 		BN_print_fp(stderr, n);
238 		fprintf(stderr, "\n(n + 1)^2 is_perfect_square: %d, cmp: %d\n",
239 		    is_perfect_square, cmp);
240 		failed = 1;
241 	}
242 
243 	/*
244 	 * Test N_TESTS random numbers n^2 <= testcase < (n + 1)^2 and check
245 	 * that their isqrt is n.
246 	 */
247 
248 	for (i = 0; i < N_TESTS; i++) {
249 		if (!bn_rand_in_range(testcase, n_sqr, upper))
250 			errx(1, "bn_rand_in_range testcase");
251 
252 		if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
253 			errx(1, "bn_isqrt testcase");
254 
255 		if ((cmp = BN_cmp(n, isqrt)) != 0 ||
256 		    (is_perfect_square && BN_cmp(n_sqr, testcase) != 0)) {
257 			fprintf(stderr, "n = ");
258 			BN_print_fp(stderr, n);
259 			fprintf(stderr, "\ntestcase = ");
260 			BN_print_fp(stderr, testcase);
261 			fprintf(stderr,
262 			    "\ntestcase is_perfect_square: %d, cmp: %d\n",
263 			    is_perfect_square, cmp);
264 			failed = 1;
265 		}
266 	}
267 
268 	/*
269 	 * Finally check that isqrt(n^2 - 1) + 1 == n.
270 	 */
271 
272 	if (!BN_sub(testcase, n_sqr, BN_value_one()))
273 		errx(1, "BN_sub(testcase, n_sqr, 1)");
274 
275 	if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
276 		errx(1, "bn_isqrt(n_sqr - 1)");
277 
278 	if (!BN_add_word(isqrt, 1))
279 		errx(1, "BN_add_word(isqrt, 1)");
280 
281 	if ((cmp = BN_cmp(n, isqrt)) != 0 || is_perfect_square) {
282 		fprintf(stderr, "n = ");
283 		BN_print_fp(stderr, n);
284 		fprintf(stderr, "\nn_sqr - 1 is_perfect_square: %d, cmp: %d\n",
285 		    is_perfect_square, cmp);
286 		failed = 1;
287 	}
288 
289 	BN_CTX_end(ctx);
290 	BN_CTX_free(ctx);
291 
292 	return failed;
293 }
294 
/* Print the command synopsis to stderr and terminate with status 1. */
static void
usage(void)
{
	const char *synopsis = "usage: bn_isqrt [-C]\n";

	fputs(synopsis, stderr);
	exit(1);
}
301 
302 int
main(int argc,char * argv[])303 main(int argc, char *argv[])
304 {
305 	size_t i;
306 	int ch;
307 	int failed = 0, print = 0;
308 
309 	while ((ch = getopt(argc, argv, "C")) != -1) {
310 		switch (ch) {
311 		case 'C':
312 			print = 1;
313 			break;
314 		default:
315 			usage();
316 			break;
317 		}
318 	}
319 
320 	if (print)
321 		return check_tables(1);
322 
323 	for (i = 0; i < N_TESTS; i++)
324 		failed |= isqrt_test();
325 
326 	failed |= check_tables(0);
327 	failed |= validate_tables();
328 
329 	return failed;
330 }
331