1*12347e81Stb /* $OpenBSD: bn_isqrt.c,v 1.4 2023/08/03 18:53:56 tb Exp $ */
2b8d22d11Stb /*
3b8d22d11Stb * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4b8d22d11Stb *
5b8d22d11Stb * Permission to use, copy, modify, and distribute this software for any
6b8d22d11Stb * purpose with or without fee is hereby granted, provided that the above
7b8d22d11Stb * copyright notice and this permission notice appear in all copies.
8b8d22d11Stb *
9b8d22d11Stb * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10b8d22d11Stb * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11b8d22d11Stb * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12b8d22d11Stb * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13b8d22d11Stb * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14b8d22d11Stb * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15b8d22d11Stb * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16b8d22d11Stb */
17b8d22d11Stb
18b8d22d11Stb #include <err.h>
1920c8b67dStb #include <stdint.h>
2020c8b67dStb #include <stdio.h>
2120c8b67dStb #include <stdlib.h>
22b8d22d11Stb #include <string.h>
23b8d22d11Stb #include <unistd.h>
24b8d22d11Stb
25b8d22d11Stb #include <openssl/bn.h>
26b8d22d11Stb
27b8d22d11Stb #include "bn_local.h"
28b8d22d11Stb
29d79da6ceStb #define N_TESTS 100
30b8d22d11Stb
/* Sample roots n between 2^128 and 2^4096; test inputs are built from n^2. */
32b8d22d11Stb #define LOWER_BITS 128
33b8d22d11Stb #define UPPER_BITS 4096
34b8d22d11Stb
35b8d22d11Stb extern const uint8_t is_square_mod_11[];
36b8d22d11Stb extern const uint8_t is_square_mod_63[];
37b8d22d11Stb extern const uint8_t is_square_mod_64[];
38b8d22d11Stb extern const uint8_t is_square_mod_65[];
39b8d22d11Stb
/*
 * Write len bytes of buf to stderr as a comma separated list of hex
 * literals, eight per line.  Non-empty output always ends with a newline.
 */
static void
hexdump(const unsigned char *buf, size_t len)
{
	size_t pos;

	for (pos = 0; pos < len; pos++) {
		const char *sep = "";

		/* Break the line after every eighth byte. */
		if ((pos + 1) % 8 == 0)
			sep = "\n";

		fprintf(stderr, " 0x%02hhx,%s", buf[pos], sep);
	}

	/* Terminate a trailing partial line. */
	if (len % 8 != 0)
		fprintf(stderr, "\n");
}
51b8d22d11Stb
52b8d22d11Stb static const uint8_t *
get_table(int modulus)53b8d22d11Stb get_table(int modulus)
54b8d22d11Stb {
55b8d22d11Stb switch (modulus) {
56b8d22d11Stb case 11:
57b8d22d11Stb return is_square_mod_11;
58b8d22d11Stb case 63:
59b8d22d11Stb return is_square_mod_63;
60b8d22d11Stb case 64:
61b8d22d11Stb return is_square_mod_64;
62b8d22d11Stb case 65:
63b8d22d11Stb return is_square_mod_65;
64b8d22d11Stb default:
65b8d22d11Stb return NULL;
66b8d22d11Stb }
67b8d22d11Stb }
68b8d22d11Stb
/*
 * Recompute the table of squares for each modulus and compare it byte for
 * byte with the table returned by get_table().  A missing or mismatching
 * table is reported on stderr.  If print is non-zero, each table that
 * matches is also written to stdout as a C array definition followed by a
 * CTASSERT on its size.
 *
 * Returns 0 if every table was found and matched, 1 otherwise.
 */
static int
check_tables(int print)
{
	static const int moduli[] = {11, 63, 64, 65};
	const size_t n_moduli = sizeof(moduli) / sizeof(moduli[0]);
	uint8_t squares[65];
	size_t idx;
	int failed = 0;

	for (idx = 0; idx < n_moduli; idx++) {
		const int m = moduli[idx];
		const uint8_t *table;
		int x;

		/* squares[r] == 1 iff r == x^2 (mod m) for some 0 <= x < m. */
		memset(squares, 0, sizeof(squares));
		for (x = 0; x < m; x++)
			squares[(x * x) % m] = 1;

		table = get_table(m);
		if (table == NULL) {
			fprintf(stderr, "failed to get table %d\n", m);
			failed = 1;
			continue;
		}

		if (memcmp(table, squares, m) != 0) {
			fprintf(stderr, "table %d does not match:\n", m);
			fprintf(stderr, "want:\n");
			hexdump(table, m);
			fprintf(stderr, "got:\n");
			hexdump(squares, m);
			failed = 1;
			continue;
		}

		if (!print)
			continue;

		/* Emit the table as C source, sixteen entries per line. */
		printf("const uint8_t is_square_mod_%d[] = {\n\t", m);
		for (x = 0; x < m; x++) {
			const char *sep;

			if (x + 1 == m)
				sep = "";
			else if (x % 16 == 15)
				sep = "\n\t";
			else
				sep = " ";

			printf("%d,%s", squares[x], sep);
		}
		printf("\n};\nCTASSERT(sizeof(is_square_mod_%d) == %d);\n\n",
		    m, m);
	}

	return failed;
}
121b8d22d11Stb
/*
 * Cross-check every entry of each is_square_mod_* table against a brute
 * force search for a square root modulo the table's modulus.  An entry of 0
 * for a residue that has a square root, or an entry of 1 for a residue that
 * has none, is reported on stderr, one line per mismatch.
 *
 * Returns 0 if all tables were found and are consistent, 1 otherwise.
 */
static int
validate_tables(void)
{
	int fill[] = {11, 63, 64, 65};
	const uint8_t *table;
	size_t i;
	int j, k;
	int failed = 0;

	for (i = 0; i < sizeof(fill) / sizeof(fill[0]); i++) {
		if ((table = get_table(fill[i])) == NULL) {
			fprintf(stderr, "failed to get table %d\n", fill[i]);
			failed |= 1;
			continue;
		}

		for (j = 0; j < fill[i]; j++) {
			/* After this loop, k < fill[i] iff j is a square. */
			for (k = 0; k < fill[i]; k++) {
				if (j == (k * k) % fill[i])
					break;
			}

			if (table[j] == 0 && k < fill[i]) {
				fprintf(stderr, "%d == %d^2 (mod %d)\n", j, k,
				    fill[i]);
				failed |= 1;
			}
			if (table[j] == 1 && k == fill[i]) {
				fprintf(stderr, "%d not a square (mod %d)\n", j,
				    fill[i]);
				failed |= 1;
			}
		}
	}

	return failed;
}
159b8d22d11Stb
160b8d22d11Stb /*
161b8d22d11Stb * Choose a random number n of bit length between LOWER_BITS and UPPER_BITS and
162b8d22d11Stb * check that n == isqrt(n^2). Random numbers n^2 <= testcase < (n + 1)^2 are
163b8d22d11Stb * checked to have isqrt(testcase) == n.
164b8d22d11Stb */
165b8d22d11Stb static int
isqrt_test(void)166b8d22d11Stb isqrt_test(void)
167b8d22d11Stb {
168b8d22d11Stb BN_CTX *ctx;
169b8d22d11Stb BIGNUM *n, *n_sqr, *lower, *upper, *testcase, *isqrt;
170b8d22d11Stb int cmp, i, is_perfect_square;
171b8d22d11Stb int failed = 0;
172b8d22d11Stb
173b8d22d11Stb if ((ctx = BN_CTX_new()) == NULL)
174b8d22d11Stb errx(1, "BN_CTX_new");
175b8d22d11Stb
176b8d22d11Stb BN_CTX_start(ctx);
177b8d22d11Stb
178b8d22d11Stb if ((lower = BN_CTX_get(ctx)) == NULL)
179b8d22d11Stb errx(1, "lower = BN_CTX_get(ctx)");
180b8d22d11Stb if ((upper = BN_CTX_get(ctx)) == NULL)
181b8d22d11Stb errx(1, "upper = BN_CTX_get(ctx)");
182b8d22d11Stb if ((n = BN_CTX_get(ctx)) == NULL)
183b8d22d11Stb errx(1, "n = BN_CTX_get(ctx)");
184b8d22d11Stb if ((n_sqr = BN_CTX_get(ctx)) == NULL)
185b8d22d11Stb errx(1, "n = BN_CTX_get(ctx)");
186b8d22d11Stb if ((isqrt = BN_CTX_get(ctx)) == NULL)
187b8d22d11Stb errx(1, "result = BN_CTX_get(ctx)");
188b8d22d11Stb if ((testcase = BN_CTX_get(ctx)) == NULL)
189b8d22d11Stb errx(1, "testcase = BN_CTX_get(ctx)");
190b8d22d11Stb
191b8d22d11Stb /* lower = 2^LOWER_BITS, upper = 2^UPPER_BITS. */
192b8d22d11Stb if (!BN_set_bit(lower, LOWER_BITS))
193b8d22d11Stb errx(1, "BN_set_bit(lower, %d)", LOWER_BITS);
194b8d22d11Stb if (!BN_set_bit(upper, UPPER_BITS))
195b8d22d11Stb errx(1, "BN_set_bit(upper, %d)", UPPER_BITS);
196b8d22d11Stb
197*12347e81Stb if (!bn_rand_in_range(n, lower, upper))
198*12347e81Stb errx(1, "bn_rand_in_range n");
199b8d22d11Stb
200b8d22d11Stb /* n_sqr = n^2 */
201b8d22d11Stb if (!BN_sqr(n_sqr, n, ctx))
202b8d22d11Stb errx(1, "BN_sqr");
203b8d22d11Stb
204b8d22d11Stb if (!bn_isqrt(isqrt, &is_perfect_square, n_sqr, ctx))
205b8d22d11Stb errx(1, "bn_isqrt n_sqr");
206b8d22d11Stb
207b8d22d11Stb if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
208b8d22d11Stb fprintf(stderr, "n = ");
209b8d22d11Stb BN_print_fp(stderr, n);
210b8d22d11Stb fprintf(stderr, "\nn^2 is_perfect_square: %d, cmp: %d\n",
211b8d22d11Stb is_perfect_square, cmp);
212b8d22d11Stb failed = 1;
213b8d22d11Stb }
214b8d22d11Stb
215b8d22d11Stb /* upper = 2 * n + 1 */
216b8d22d11Stb if (!BN_lshift1(upper, n))
217b8d22d11Stb errx(1, "BN_lshift1(upper, n)");
218b8d22d11Stb if (!BN_add_word(upper, 1))
219b8d22d11Stb errx(1, "BN_sub_word(upper, 1)");
220b8d22d11Stb
221b8d22d11Stb /* upper = (n + 1)^2 = n^2 + upper */
222b8d22d11Stb if (!BN_add(upper, n_sqr, upper))
223b8d22d11Stb errx(1, "BN_add");
224b8d22d11Stb
225b8d22d11Stb /*
226b8d22d11Stb * Check that isqrt((n + 1)^2) - 1 == n.
227b8d22d11Stb */
228b8d22d11Stb
229b8d22d11Stb if (!bn_isqrt(isqrt, &is_perfect_square, upper, ctx))
230b8d22d11Stb errx(1, "bn_isqrt(upper)");
231b8d22d11Stb
232b8d22d11Stb if (!BN_sub_word(isqrt, 1))
233b8d22d11Stb errx(1, "BN_add_word(isqrt, 1)");
234b8d22d11Stb
235b8d22d11Stb if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
236b8d22d11Stb fprintf(stderr, "n = ");
237b8d22d11Stb BN_print_fp(stderr, n);
238b8d22d11Stb fprintf(stderr, "\n(n + 1)^2 is_perfect_square: %d, cmp: %d\n",
239b8d22d11Stb is_perfect_square, cmp);
240b8d22d11Stb failed = 1;
241b8d22d11Stb }
242b8d22d11Stb
243b8d22d11Stb /*
244b8d22d11Stb * Test N_TESTS random numbers n^2 <= testcase < (n + 1)^2 and check
245b8d22d11Stb * that their isqrt is n.
246b8d22d11Stb */
247b8d22d11Stb
248b8d22d11Stb for (i = 0; i < N_TESTS; i++) {
249*12347e81Stb if (!bn_rand_in_range(testcase, n_sqr, upper))
250*12347e81Stb errx(1, "bn_rand_in_range testcase");
251b8d22d11Stb
252b8d22d11Stb if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
253b8d22d11Stb errx(1, "bn_isqrt testcase");
254b8d22d11Stb
255b8d22d11Stb if ((cmp = BN_cmp(n, isqrt)) != 0 ||
256b8d22d11Stb (is_perfect_square && BN_cmp(n_sqr, testcase) != 0)) {
257b8d22d11Stb fprintf(stderr, "n = ");
258b8d22d11Stb BN_print_fp(stderr, n);
259b8d22d11Stb fprintf(stderr, "\ntestcase = ");
260b8d22d11Stb BN_print_fp(stderr, testcase);
261b8d22d11Stb fprintf(stderr,
262b8d22d11Stb "\ntestcase is_perfect_square: %d, cmp: %d\n",
263b8d22d11Stb is_perfect_square, cmp);
264b8d22d11Stb failed = 1;
265b8d22d11Stb }
266b8d22d11Stb }
267b8d22d11Stb
268b8d22d11Stb /*
269b8d22d11Stb * Finally check that isqrt(n^2 - 1) + 1 == n.
270b8d22d11Stb */
271b8d22d11Stb
272b8d22d11Stb if (!BN_sub(testcase, n_sqr, BN_value_one()))
273b8d22d11Stb errx(1, "BN_sub(testcase, n_sqr, 1)");
274b8d22d11Stb
275b8d22d11Stb if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
276b8d22d11Stb errx(1, "bn_isqrt(n_sqr - 1)");
277b8d22d11Stb
278b8d22d11Stb if (!BN_add_word(isqrt, 1))
279b8d22d11Stb errx(1, "BN_add_word(isqrt, 1)");
280b8d22d11Stb
281b8d22d11Stb if ((cmp = BN_cmp(n, isqrt)) != 0 || is_perfect_square) {
282b8d22d11Stb fprintf(stderr, "n = ");
283b8d22d11Stb BN_print_fp(stderr, n);
284b8d22d11Stb fprintf(stderr, "\nn_sqr - 1 is_perfect_square: %d, cmp: %d\n",
285b8d22d11Stb is_perfect_square, cmp);
286b8d22d11Stb failed = 1;
287b8d22d11Stb }
288b8d22d11Stb
289b8d22d11Stb BN_CTX_end(ctx);
290b8d22d11Stb BN_CTX_free(ctx);
291b8d22d11Stb
292b8d22d11Stb return failed;
293b8d22d11Stb }
294b8d22d11Stb
/* Print the command line synopsis to stderr and exit with status 1. */
static void
usage(void)
{
	fputs("usage: bn_isqrt [-C]\n", stderr);
	exit(1);
}
301b8d22d11Stb
302b8d22d11Stb int
main(int argc,char * argv[])303b8d22d11Stb main(int argc, char *argv[])
304b8d22d11Stb {
305b8d22d11Stb size_t i;
306b8d22d11Stb int ch;
307b8d22d11Stb int failed = 0, print = 0;
308b8d22d11Stb
309b8d22d11Stb while ((ch = getopt(argc, argv, "C")) != -1) {
310b8d22d11Stb switch (ch) {
311b8d22d11Stb case 'C':
312b8d22d11Stb print = 1;
313b8d22d11Stb break;
314b8d22d11Stb default:
315b8d22d11Stb usage();
316b8d22d11Stb break;
317b8d22d11Stb }
318b8d22d11Stb }
319b8d22d11Stb
320b8d22d11Stb if (print)
321b8d22d11Stb return check_tables(1);
322b8d22d11Stb
323b8d22d11Stb for (i = 0; i < N_TESTS; i++)
324b8d22d11Stb failed |= isqrt_test();
325b8d22d11Stb
326b8d22d11Stb failed |= check_tables(0);
327b8d22d11Stb failed |= validate_tables();
328b8d22d11Stb
329b8d22d11Stb return failed;
330b8d22d11Stb }
331