1 /* $OpenBSD: bn_isqrt.c,v 1.4 2023/08/03 18:53:56 tb Exp $ */
2 /*
3 * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <unistd.h>
24
25 #include <openssl/bn.h>
26
27 #include "bn_local.h"
28
29 #define N_TESTS 100
30
/* Roots n are sampled between 2^128 and 2^4096; their squares n^2 are then tested. */
32 #define LOWER_BITS 128
33 #define UPPER_BITS 4096
34
35 extern const uint8_t is_square_mod_11[];
36 extern const uint8_t is_square_mod_63[];
37 extern const uint8_t is_square_mod_64[];
38 extern const uint8_t is_square_mod_65[];
39
/*
 * Write len bytes of buf to stderr as C-style hex literals (" 0xNN,"),
 * eight per line. A final newline is added only when the last line is
 * not already complete, so the output always ends at a line boundary.
 */
static void
hexdump(const unsigned char *buf, size_t len)
{
	size_t pos;

	for (pos = 0; pos < len; pos++) {
		const char *tail = "";

		/* Break the line after every eighth byte. */
		if ((pos + 1) % 8 == 0)
			tail = "\n";

		fprintf(stderr, " 0x%02hhx,%s", buf[pos], tail);
	}

	/* Terminate a partially filled last line. */
	if (len % 8 != 0)
		fprintf(stderr, "\n");
}
51
52 static const uint8_t *
get_table(int modulus)53 get_table(int modulus)
54 {
55 switch (modulus) {
56 case 11:
57 return is_square_mod_11;
58 case 63:
59 return is_square_mod_63;
60 case 64:
61 return is_square_mod_64;
62 case 65:
63 return is_square_mod_65;
64 default:
65 return NULL;
66 }
67 }
68
/*
 * Recompute the table of squares modulo 11, 63, 64 and 65 and compare each
 * against the table returned by get_table(). On a mismatch both tables are
 * dumped to stderr (the table from get_table() under "want", the recomputed
 * one under "got"). If print is non-zero, every matching table is also
 * written to stdout as a C array definition followed by a CTASSERT line.
 *
 * Returns 0 if all tables were found and matched, 1 otherwise.
 */
static int
check_tables(int print)
{
	const int moduli[] = {11, 63, 64, 65};
	const size_t n_moduli = sizeof(moduli) / sizeof(moduli[0]);
	uint8_t squares[65];
	size_t idx;
	int status = 0;

	for (idx = 0; idx < n_moduli; idx++) {
		const int m = moduli[idx];
		const uint8_t *expected;
		int r;

		/* Mark every residue that is hit by some r^2 (mod m). */
		memset(squares, 0, sizeof(squares));
		for (r = 0; r < m; r++)
			squares[(r * r) % m] = 1;

		expected = get_table(m);
		if (expected == NULL) {
			fprintf(stderr, "failed to get table %d\n", m);
			status = 1;
			continue;
		}

		if (memcmp(expected, squares, m) != 0) {
			fprintf(stderr, "table %d does not match:\n", m);
			fprintf(stderr, "want:\n");
			hexdump(expected, m);
			fprintf(stderr, "got:\n");
			hexdump(squares, m);
			status = 1;
			continue;
		}

		if (print) {
			printf("const uint8_t is_square_mod_%d[] = {\n\t", m);
			for (r = 0; r < m; r++) {
				const char *sep;

				/*
				 * No separator after the last entry; otherwise
				 * wrap after every 16th entry.
				 */
				if (r + 1 == m)
					sep = "";
				else if (r % 16 == 15)
					sep = "\n\t";
				else
					sep = " ";

				printf("%d,%s", squares[r], sep);
			}
			printf("\n};\nCTASSERT(sizeof(is_square_mod_%d) == %d);\n\n",
			    m, m);
		}
	}

	return status;
}
121
/*
 * Brute-force cross-check of the tables returned by get_table() for the
 * moduli 11, 63, 64 and 65: for every residue j, search for a k with
 * k^2 == j (mod m) and compare the outcome with the table entry.
 *
 * A diagnostic is written to stderr if an entry is 0 although a square root
 * exists, or if an entry is 1 although none exists. Entries with any other
 * value are not flagged here.
 *
 * Returns 0 if no inconsistency was found, 1 otherwise.
 */
static int
validate_tables(void)
{
	int fill[] = {11, 63, 64, 65};
	const uint8_t *table;
	size_t i;
	int j, k;
	int failed = 0;

	for (i = 0; i < sizeof(fill) / sizeof(fill[0]); i++) {
		if ((table = get_table(fill[i])) == NULL) {
			fprintf(stderr, "failed to get table %d\n", fill[i]);
			failed |= 1;
			continue;
		}

		for (j = 0; j < fill[i]; j++) {
			/* k ends up < fill[i] iff j is a square mod fill[i]. */
			for (k = 0; k < fill[i]; k++) {
				if (j == (k * k) % fill[i])
					break;
			}

			/* Each diagnostic gets its own line on stderr. */
			if (table[j] == 0 && k < fill[i]) {
				fprintf(stderr, "%d == %d^2 (mod %d)\n", j, k,
				    fill[i]);
				failed |= 1;
			}
			if (table[j] == 1 && k == fill[i]) {
				fprintf(stderr, "%d not a square (mod %d)\n", j,
				    fill[i]);
				failed |= 1;
			}
		}
	}

	return failed;
}
159
160 /*
161 * Choose a random number n of bit length between LOWER_BITS and UPPER_BITS and
162 * check that n == isqrt(n^2). Random numbers n^2 <= testcase < (n + 1)^2 are
163 * checked to have isqrt(testcase) == n.
164 */
165 static int
isqrt_test(void)166 isqrt_test(void)
167 {
168 BN_CTX *ctx;
169 BIGNUM *n, *n_sqr, *lower, *upper, *testcase, *isqrt;
170 int cmp, i, is_perfect_square;
171 int failed = 0;
172
173 if ((ctx = BN_CTX_new()) == NULL)
174 errx(1, "BN_CTX_new");
175
176 BN_CTX_start(ctx);
177
178 if ((lower = BN_CTX_get(ctx)) == NULL)
179 errx(1, "lower = BN_CTX_get(ctx)");
180 if ((upper = BN_CTX_get(ctx)) == NULL)
181 errx(1, "upper = BN_CTX_get(ctx)");
182 if ((n = BN_CTX_get(ctx)) == NULL)
183 errx(1, "n = BN_CTX_get(ctx)");
184 if ((n_sqr = BN_CTX_get(ctx)) == NULL)
185 errx(1, "n = BN_CTX_get(ctx)");
186 if ((isqrt = BN_CTX_get(ctx)) == NULL)
187 errx(1, "result = BN_CTX_get(ctx)");
188 if ((testcase = BN_CTX_get(ctx)) == NULL)
189 errx(1, "testcase = BN_CTX_get(ctx)");
190
191 /* lower = 2^LOWER_BITS, upper = 2^UPPER_BITS. */
192 if (!BN_set_bit(lower, LOWER_BITS))
193 errx(1, "BN_set_bit(lower, %d)", LOWER_BITS);
194 if (!BN_set_bit(upper, UPPER_BITS))
195 errx(1, "BN_set_bit(upper, %d)", UPPER_BITS);
196
197 if (!bn_rand_in_range(n, lower, upper))
198 errx(1, "bn_rand_in_range n");
199
200 /* n_sqr = n^2 */
201 if (!BN_sqr(n_sqr, n, ctx))
202 errx(1, "BN_sqr");
203
204 if (!bn_isqrt(isqrt, &is_perfect_square, n_sqr, ctx))
205 errx(1, "bn_isqrt n_sqr");
206
207 if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
208 fprintf(stderr, "n = ");
209 BN_print_fp(stderr, n);
210 fprintf(stderr, "\nn^2 is_perfect_square: %d, cmp: %d\n",
211 is_perfect_square, cmp);
212 failed = 1;
213 }
214
215 /* upper = 2 * n + 1 */
216 if (!BN_lshift1(upper, n))
217 errx(1, "BN_lshift1(upper, n)");
218 if (!BN_add_word(upper, 1))
219 errx(1, "BN_sub_word(upper, 1)");
220
221 /* upper = (n + 1)^2 = n^2 + upper */
222 if (!BN_add(upper, n_sqr, upper))
223 errx(1, "BN_add");
224
225 /*
226 * Check that isqrt((n + 1)^2) - 1 == n.
227 */
228
229 if (!bn_isqrt(isqrt, &is_perfect_square, upper, ctx))
230 errx(1, "bn_isqrt(upper)");
231
232 if (!BN_sub_word(isqrt, 1))
233 errx(1, "BN_add_word(isqrt, 1)");
234
235 if ((cmp = BN_cmp(n, isqrt)) != 0 || !is_perfect_square) {
236 fprintf(stderr, "n = ");
237 BN_print_fp(stderr, n);
238 fprintf(stderr, "\n(n + 1)^2 is_perfect_square: %d, cmp: %d\n",
239 is_perfect_square, cmp);
240 failed = 1;
241 }
242
243 /*
244 * Test N_TESTS random numbers n^2 <= testcase < (n + 1)^2 and check
245 * that their isqrt is n.
246 */
247
248 for (i = 0; i < N_TESTS; i++) {
249 if (!bn_rand_in_range(testcase, n_sqr, upper))
250 errx(1, "bn_rand_in_range testcase");
251
252 if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
253 errx(1, "bn_isqrt testcase");
254
255 if ((cmp = BN_cmp(n, isqrt)) != 0 ||
256 (is_perfect_square && BN_cmp(n_sqr, testcase) != 0)) {
257 fprintf(stderr, "n = ");
258 BN_print_fp(stderr, n);
259 fprintf(stderr, "\ntestcase = ");
260 BN_print_fp(stderr, testcase);
261 fprintf(stderr,
262 "\ntestcase is_perfect_square: %d, cmp: %d\n",
263 is_perfect_square, cmp);
264 failed = 1;
265 }
266 }
267
268 /*
269 * Finally check that isqrt(n^2 - 1) + 1 == n.
270 */
271
272 if (!BN_sub(testcase, n_sqr, BN_value_one()))
273 errx(1, "BN_sub(testcase, n_sqr, 1)");
274
275 if (!bn_isqrt(isqrt, &is_perfect_square, testcase, ctx))
276 errx(1, "bn_isqrt(n_sqr - 1)");
277
278 if (!BN_add_word(isqrt, 1))
279 errx(1, "BN_add_word(isqrt, 1)");
280
281 if ((cmp = BN_cmp(n, isqrt)) != 0 || is_perfect_square) {
282 fprintf(stderr, "n = ");
283 BN_print_fp(stderr, n);
284 fprintf(stderr, "\nn_sqr - 1 is_perfect_square: %d, cmp: %d\n",
285 is_perfect_square, cmp);
286 failed = 1;
287 }
288
289 BN_CTX_end(ctx);
290 BN_CTX_free(ctx);
291
292 return failed;
293 }
294
/*
 * Print the command line synopsis to stderr and terminate the program
 * with exit status 1.
 */
static void
usage(void)
{
	fputs("usage: bn_isqrt [-C]\n", stderr);
	exit(1);
}
301
302 int
main(int argc,char * argv[])303 main(int argc, char *argv[])
304 {
305 size_t i;
306 int ch;
307 int failed = 0, print = 0;
308
309 while ((ch = getopt(argc, argv, "C")) != -1) {
310 switch (ch) {
311 case 'C':
312 print = 1;
313 break;
314 default:
315 usage();
316 break;
317 }
318 }
319
320 if (print)
321 return check_tables(1);
322
323 for (i = 0; i < N_TESTS; i++)
324 failed |= isqrt_test();
325
326 failed |= check_tables(0);
327 failed |= validate_tables();
328
329 return failed;
330 }
331