xref: /netbsd-src/sbin/nvmectl/bignum.c (revision 19c1490f616ddeb018ac174706e04b29131f530e)
1 /*	$NetBSD: bignum.c,v 1.6 2023/02/27 22:00:25 andvar Exp $	*/
2 
3 /*-
4  * Copyright (c) 2012 Alistair Crooks <agc@NetBSD.org>
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 /* LibTomMath, multiple-precision integer library -- Tom St Denis
28  *
29  * LibTomMath is a library that provides multiple-precision
30  * integer arithmetic as well as number theoretic functionality.
31  *
32  * The library was designed directly after the MPI library by
33  * Michael Fromberger but has been written from scratch with
34  * additional optimizations in place.
35  *
36  * The library is free for all purposes without any express
37  * guarantee it works.
38  *
39  * Tom St Denis, tomstdenis@gmail.com, http://libtom.org
40  */
41 
42 #include <sys/types.h>
43 #include <sys/param.h>
44 
45 #include <limits.h>
46 #include <stdarg.h>
47 #include <stdio.h>
48 #include <stdlib.h>
49 #include <string.h>
50 #include <unistd.h>
51 
52 #include "bn.h"
53 
54 /**************************************************************************/
55 
56 /* LibTomMath, multiple-precision integer library -- Tom St Denis
57  *
58  * LibTomMath is a library that provides multiple-precision
59  * integer arithmetic as well as number theoretic functionality.
60  *
61  * The library was designed directly after the MPI library by
62  * Michael Fromberger but has been written from scratch with
63  * additional optimizations in place.
64  *
65  * The library is free for all purposes without any express
66  * guarantee it works.
67  *
68  * Tom St Denis, tomstdenis@gmail.com, http://libtom.org
69  */
70 
/* default number of digits given to a fresh mp_int; growth is rounded in units of this */
#define MP_PREC		32
/* number of value bits carried by each mp_digit */
#define DIGIT_BIT	28
/* mask selecting the low DIGIT_BIT bits of a digit */
#define MP_MASK		((((mp_digit)1) << ((mp_digit)DIGIT_BIT)) - ((mp_digit)1))

/* 2 ** (bits in an mp_word - 2 * DIGIT_BIT + 1) */
#define MP_WARRAY	/*LINTED*/(1U << ((sizeof(mp_word) * CHAR_BIT) - (2 * DIGIT_BIT) + 1))

/* boolean results used by the predicates below */
#define MP_NO		0
#define MP_YES		1

/* mark a parameter as deliberately unused */
#if !defined(USE_ARG)
#define USE_ARG(x)	/*LINTED*/(void)&(x)
#endif

/* number of elements in a true array (not a pointer) */
#if !defined(__arraycount)
#define	__arraycount(__x)	(sizeof(__x) / sizeof((__x)[0]))
#endif

/* smaller of two values; note that both arguments may be evaluated twice */
#if !defined(MIN)
#define MIN(a,b)	(((a) < (b)) ? (a) : (b))
#endif

/* MP_YES when the integer has no digits in use, MP_NO otherwise */
#define MP_ISZERO(a)	(((a)->used == 0) ? MP_YES : MP_NO)
93 
/* integer status code type; holds values such as MP_OKAY or MP_MEM */
typedef int           mp_err;

/* forward declarations; the definitions are not in this part of the file */
static int signed_multiply(mp_int * a, mp_int * b, mp_int * c);
static int square(mp_int * a, mp_int * b);

static int signed_subtract_word(mp_int *a, mp_digit b, mp_int *c);
100 
/*
 * allocate a zero-filled array of n elements of m bytes each.
 * Thin wrapper around calloc(3); the result is NULL when calloc fails.
 */
static inline void *
allocate(size_t n, size_t m)
{
	void	*mem = calloc(n, m);

	return mem;
}
106 
/*
 * release memory through free(3).
 * The size argument is accepted but ignored.
 */
static inline void
deallocate(void *v, size_t sz)
{
	free(v);
	/* sz is intentionally unused */
	/*LINTED*/(void)&(sz);
}
113 
114 /* set to zero */
115 static inline void
mp_zero(mp_int * a)116 mp_zero(mp_int *a)
117 {
118 	a->sign = MP_ZPOS;
119 	a->used = 0;
120 	memset(a->dp, 0x0, a->alloc * sizeof(*a->dp));
121 }
122 
123 /* grow as required */
124 static int
mp_grow(mp_int * a,int size)125 mp_grow(mp_int *a, int size)
126 {
127 	mp_digit *tmp;
128 
129 	/* if the alloc size is smaller alloc more ram */
130 	if (a->alloc < size) {
131 		/* ensure there are always at least MP_PREC digits extra on top */
132 		size += (MP_PREC * 2) - (size % MP_PREC);
133 
134 		/* reallocate the array a->dp
135 		*
136 		* We store the return in a temporary variable
137 		* in case the operation failed we don't want
138 		* to overwrite the dp member of a.
139 		*/
140 		tmp = realloc(a->dp, sizeof(*tmp) * size);
141 		if (tmp == NULL) {
142 			/* reallocation failed but "a" is still valid [can be freed] */
143 			return MP_MEM;
144 		}
145 
146 		/* reallocation succeeded so set a->dp */
147 		a->dp = tmp;
148 		/* zero excess digits */
149 		memset(&a->dp[a->alloc], 0x0, (size - a->alloc) * sizeof(*a->dp));
150 		a->alloc = size;
151 	}
152 	return MP_OKAY;
153 }
154 
155 /* shift left a certain amount of digits */
156 static int
lshift_digits(mp_int * a,int b)157 lshift_digits(mp_int * a, int b)
158 {
159 	mp_digit *top, *bottom;
160 	int     x, res;
161 
162 	/* if its less than zero return */
163 	if (b <= 0) {
164 		return MP_OKAY;
165 	}
166 
167 	/* grow to fit the new digits */
168 	if (a->alloc < a->used + b) {
169 		if ((res = mp_grow(a, a->used + b)) != MP_OKAY) {
170 			return res;
171 		}
172 	}
173 
174 	/* increment the used by the shift amount then copy upwards */
175 	a->used += b;
176 
177 	/* top */
178 	top = a->dp + a->used - 1;
179 
180 	/* base */
181 	bottom = a->dp + a->used - 1 - b;
182 
183 	/* much like rshift_digits this is implemented using a sliding window
184 	* except the window goes the otherway around.  Copying from
185 	* the bottom to the top.
186 	*/
187 	for (x = a->used - 1; x >= b; x--) {
188 		*top-- = *bottom--;
189 	}
190 
191 	/* zero the lower digits */
192 	memset(a->dp, 0x0, b * sizeof(*a->dp));
193 	return MP_OKAY;
194 }
195 
196 /* trim unused digits
197  *
198  * This is used to ensure that leading zero digits are
199  * trimmed and the leading "used" digit will be non-zero
200  * Typically very fast.  Also fixes the sign if there
201  * are no more leading digits
202  */
203 static void
trim_unused_digits(mp_int * a)204 trim_unused_digits(mp_int * a)
205 {
206 	/* decrease used while the most significant digit is
207 	* zero.
208 	*/
209 	while (a->used > 0 && a->dp[a->used - 1] == 0) {
210 		a->used -= 1;
211 	}
212 	/* reset the sign flag if used == 0 */
213 	if (a->used == 0) {
214 		a->sign = MP_ZPOS;
215 	}
216 }
217 
218 /* copy, b = a */
219 static int
mp_copy(BIGNUM * a,BIGNUM * b)220 mp_copy(BIGNUM *a, BIGNUM *b)
221 {
222 	int	res;
223 
224 	/* if dst == src do nothing */
225 	if (a == b) {
226 		return MP_OKAY;
227 	}
228 	if (a == NULL || b == NULL) {
229 		return MP_VAL;
230 	}
231 
232 	/* grow dest */
233 	if (b->alloc < a->used) {
234 		if ((res = mp_grow(b, a->used)) != MP_OKAY) {
235 			return res;
236 		}
237 	}
238 
239 	memcpy(b->dp, a->dp, a->used * sizeof(*b->dp));
240 	if (b->used > a->used) {
241 		memset(&b->dp[a->used], 0x0, (b->used - a->used) * sizeof(*b->dp));
242 	}
243 
244 	/* copy used count and sign */
245 	b->used = a->used;
246 	b->sign = a->sign;
247 	return MP_OKAY;
248 }
249 
250 /* shift left by a certain bit count */
251 static int
lshift_bits(mp_int * a,int b,mp_int * c)252 lshift_bits(mp_int *a, int b, mp_int *c)
253 {
254 	mp_digit d;
255 	int      res;
256 
257 	/* copy */
258 	if (a != c) {
259 		if ((res = mp_copy(a, c)) != MP_OKAY) {
260 			return res;
261 		}
262 	}
263 
264 	if (c->alloc < (int)(c->used + b/DIGIT_BIT + 1)) {
265 		if ((res = mp_grow(c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
266 			return res;
267 		}
268 	}
269 
270 	/* shift by as many digits in the bit count */
271 	if (b >= (int)DIGIT_BIT) {
272 		if ((res = lshift_digits(c, b / DIGIT_BIT)) != MP_OKAY) {
273 			return res;
274 		}
275 	}
276 
277 	/* shift any bit count < DIGIT_BIT */
278 	d = (mp_digit) (b % DIGIT_BIT);
279 	if (d != 0) {
280 		mp_digit *tmpc, shift, mask, carry, rr;
281 		int x;
282 
283 		/* bitmask for carries */
284 		mask = (((mp_digit)1) << d) - 1;
285 
286 		/* shift for msbs */
287 		shift = DIGIT_BIT - d;
288 
289 		/* alias */
290 		tmpc = c->dp;
291 
292 		/* carry */
293 		carry = 0;
294 		for (x = 0; x < c->used; x++) {
295 			/* get the higher bits of the current word */
296 			rr = (*tmpc >> shift) & mask;
297 
298 			/* shift the current word and OR in the carry */
299 			*tmpc = ((*tmpc << d) | carry) & MP_MASK;
300 			++tmpc;
301 
302 			/* set the carry to the carry bits of the current word */
303 			carry = rr;
304 		}
305 
306 		/* set final carry */
307 		if (carry != 0) {
308 			c->dp[c->used++] = carry;
309 		}
310 	}
311 	trim_unused_digits(c);
312 	return MP_OKAY;
313 }
314 
315 /* reads a unsigned char array, assumes the msb is stored first [big endian] */
316 static int
mp_read_unsigned_bin(mp_int * a,const uint8_t * b,int c)317 mp_read_unsigned_bin(mp_int *a, const uint8_t *b, int c)
318 {
319 	int     res;
320 
321 	/* make sure there are at least two digits */
322 	if (a->alloc < 2) {
323 		if ((res = mp_grow(a, 2)) != MP_OKAY) {
324 			return res;
325 		}
326 	}
327 
328 	/* zero the int */
329 	mp_zero(a);
330 
331 	/* read the bytes in */
332 	while (c-- > 0) {
333 		if ((res = lshift_bits(a, 8, a)) != MP_OKAY) {
334 			return res;
335 		}
336 
337 		a->dp[0] |= *b++;
338 		a->used += 1;
339 	}
340 	trim_unused_digits(a);
341 	return MP_OKAY;
342 }
343 
344 /* returns the number of bits in an mpi */
345 static int
mp_count_bits(const mp_int * a)346 mp_count_bits(const mp_int *a)
347 {
348 	int     r;
349 	mp_digit q;
350 
351 	/* shortcut */
352 	if (a->used == 0) {
353 		return 0;
354 	}
355 
356 	/* get number of digits and add that */
357 	r = (a->used - 1) * DIGIT_BIT;
358 
359 	/* take the last digit and count the bits in it */
360 	for (q = a->dp[a->used - 1]; q > ((mp_digit) 0) ; r++) {
361 		q >>= ((mp_digit) 1);
362 	}
363 	return r;
364 }
365 
366 /* compare magnitude of two ints (unsigned) */
367 static int
compare_magnitude(mp_int * a,mp_int * b)368 compare_magnitude(mp_int * a, mp_int * b)
369 {
370 	int     n;
371 	mp_digit *tmpa, *tmpb;
372 
373 	/* compare based on # of non-zero digits */
374 	if (a->used > b->used) {
375 		return MP_GT;
376 	}
377 
378 	if (a->used < b->used) {
379 		return MP_LT;
380 	}
381 
382 	/* alias for a */
383 	tmpa = a->dp + (a->used - 1);
384 
385 	/* alias for b */
386 	tmpb = b->dp + (a->used - 1);
387 
388 	/* compare based on digits  */
389 	for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
390 		if (*tmpa > *tmpb) {
391 			return MP_GT;
392 		}
393 
394 		if (*tmpa < *tmpb) {
395 			return MP_LT;
396 		}
397 	}
398 	return MP_EQ;
399 }
400 
401 /* compare two ints (signed)*/
402 static int
signed_compare(mp_int * a,mp_int * b)403 signed_compare(mp_int * a, mp_int * b)
404 {
405 	/* compare based on sign */
406 	if (a->sign != b->sign) {
407 		return (a->sign == MP_NEG) ? MP_LT : MP_GT;
408 	}
409 	return (a->sign == MP_NEG) ? compare_magnitude(b, a) : compare_magnitude(a, b);
410 }
411 
412 /* get the size for an unsigned equivalent */
413 static int
mp_unsigned_bin_size(mp_int * a)414 mp_unsigned_bin_size(mp_int * a)
415 {
416 	int     size = mp_count_bits(a);
417 
418 	return (size / 8 + ((size & 7) != 0 ? 1 : 0));
419 }
420 
421 /* init a new mp_int */
422 static int
mp_init(mp_int * a)423 mp_init(mp_int * a)
424 {
425 	/* allocate memory required and clear it */
426 	a->dp = allocate(1, sizeof(*a->dp) * MP_PREC);
427 	if (a->dp == NULL) {
428 		return MP_MEM;
429 	}
430 
431 	/* set the digits to zero */
432 	memset(a->dp, 0x0, MP_PREC * sizeof(*a->dp));
433 
434 	/* set the used to zero, allocated digits to the default precision
435 	* and sign to positive */
436 	a->used  = 0;
437 	a->alloc = MP_PREC;
438 	a->sign  = MP_ZPOS;
439 
440 	return MP_OKAY;
441 }
442 
443 /* clear one (frees)  */
444 static void
mp_clear(mp_int * a)445 mp_clear(mp_int * a)
446 {
447 	/* only do anything if a hasn't been freed previously */
448 	if (a->dp != NULL) {
449 		memset(a->dp, 0x0, a->used * sizeof(*a->dp));
450 
451 		/* free ram */
452 		deallocate(a->dp, (size_t)a->alloc);
453 
454 		/* reset members to make debugging easier */
455 		a->dp = NULL;
456 		a->alloc = a->used = 0;
457 		a->sign  = MP_ZPOS;
458 	}
459 }
460 
461 static int
mp_init_multi(mp_int * mp,...)462 mp_init_multi(mp_int *mp, ...)
463 {
464 	mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
465 	int n = 0;                 /* Number of ok inits */
466 	mp_int* cur_arg = mp;
467 	va_list args;
468 
469 	va_start(args, mp);        /* init args to next argument from caller */
470 	while (cur_arg != NULL) {
471 		if (mp_init(cur_arg) != MP_OKAY) {
472 			/* Oops - error! Back-track and mp_clear what we already
473 			succeeded in init-ing, then return error.
474 			*/
475 			va_list clean_args;
476 
477 			/* end the current list */
478 			va_end(args);
479 
480 			/* now start cleaning up */
481 			cur_arg = mp;
482 			va_start(clean_args, mp);
483 			while (n--) {
484 				mp_clear(cur_arg);
485 				cur_arg = va_arg(clean_args, mp_int*);
486 			}
487 			va_end(clean_args);
488 			res = MP_MEM;
489 			break;
490 		}
491 		n++;
492 		cur_arg = va_arg(args, mp_int*);
493 	}
494 	va_end(args);
495 	return res;                /* Assumed ok, if error flagged above. */
496 }
497 
498 /* init an mp_init for a given size */
499 static int
mp_init_size(mp_int * a,int size)500 mp_init_size(mp_int * a, int size)
501 {
502 	/* pad size so there are always extra digits */
503 	size += (MP_PREC * 2) - (size % MP_PREC);
504 
505 	/* alloc mem */
506 	a->dp = allocate(1, sizeof(*a->dp) * size);
507 	if (a->dp == NULL) {
508 		return MP_MEM;
509 	}
510 
511 	/* set the members */
512 	a->used  = 0;
513 	a->alloc = size;
514 	a->sign  = MP_ZPOS;
515 
516 	/* zero the digits */
517 	memset(a->dp, 0x0, size * sizeof(*a->dp));
518 	return MP_OKAY;
519 }
520 
521 /* creates "a" then copies b into it */
522 static int
mp_init_copy(mp_int * a,mp_int * b)523 mp_init_copy(mp_int * a, mp_int * b)
524 {
525 	int     res;
526 
527 	if ((res = mp_init(a)) != MP_OKAY) {
528 		return res;
529 	}
530 	return mp_copy(b, a);
531 }
532 
533 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
534 static int
basic_add(mp_int * a,mp_int * b,mp_int * c)535 basic_add(mp_int * a, mp_int * b, mp_int * c)
536 {
537 	mp_int *x;
538 	int     olduse, res, min, max;
539 
540 	/* find sizes, we let |a| <= |b| which means we have to sort
541 	* them.  "x" will point to the input with the most digits
542 	*/
543 	if (a->used > b->used) {
544 		min = b->used;
545 		max = a->used;
546 		x = a;
547 	} else {
548 		min = a->used;
549 		max = b->used;
550 		x = b;
551 	}
552 
553 	/* init result */
554 	if (c->alloc < max + 1) {
555 		if ((res = mp_grow(c, max + 1)) != MP_OKAY) {
556 			return res;
557 		}
558 	}
559 
560 	/* get old used digit count and set new one */
561 	olduse = c->used;
562 	c->used = max + 1;
563 
564 	{
565 		mp_digit carry, *tmpa, *tmpb, *tmpc;
566 		int i;
567 
568 		/* alias for digit pointers */
569 
570 		/* first input */
571 		tmpa = a->dp;
572 
573 		/* second input */
574 		tmpb = b->dp;
575 
576 		/* destination */
577 		tmpc = c->dp;
578 
579 		/* zero the carry */
580 		carry = 0;
581 		for (i = 0; i < min; i++) {
582 			/* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
583 			*tmpc = *tmpa++ + *tmpb++ + carry;
584 
585 			/* U = carry bit of T[i] */
586 			carry = *tmpc >> ((mp_digit)DIGIT_BIT);
587 
588 			/* take away carry bit from T[i] */
589 			*tmpc++ &= MP_MASK;
590 		}
591 
592 		/* now copy higher words if any, that is in A+B
593 		* if A or B has more digits add those in
594 		*/
595 		if (min != max) {
596 			for (; i < max; i++) {
597 				/* T[i] = X[i] + U */
598 				*tmpc = x->dp[i] + carry;
599 
600 				/* U = carry bit of T[i] */
601 				carry = *tmpc >> ((mp_digit)DIGIT_BIT);
602 
603 				/* take away carry bit from T[i] */
604 				*tmpc++ &= MP_MASK;
605 			}
606 		}
607 
608 		/* add carry */
609 		*tmpc++ = carry;
610 
611 		/* clear digits above oldused */
612 		if (olduse > c->used) {
613 			memset(tmpc, 0x0, (olduse - c->used) * sizeof(*c->dp));
614 		}
615 	}
616 
617 	trim_unused_digits(c);
618 	return MP_OKAY;
619 }
620 
621 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
622 static int
basic_subtract(mp_int * a,mp_int * b,mp_int * c)623 basic_subtract(mp_int * a, mp_int * b, mp_int * c)
624 {
625 	int     olduse, res, min, max;
626 
627 	/* find sizes */
628 	min = b->used;
629 	max = a->used;
630 
631 	/* init result */
632 	if (c->alloc < max) {
633 		if ((res = mp_grow(c, max)) != MP_OKAY) {
634 			return res;
635 		}
636 	}
637 	olduse = c->used;
638 	c->used = max;
639 
640 	{
641 		mp_digit carry, *tmpa, *tmpb, *tmpc;
642 		int i;
643 
644 		/* alias for digit pointers */
645 		tmpa = a->dp;
646 		tmpb = b->dp;
647 		tmpc = c->dp;
648 
649 		/* set carry to zero */
650 		carry = 0;
651 		for (i = 0; i < min; i++) {
652 			/* T[i] = A[i] - B[i] - U */
653 			*tmpc = *tmpa++ - *tmpb++ - carry;
654 
655 			/* U = carry bit of T[i]
656 			* Note this saves performing an AND operation since
657 			* if a carry does occur it will propagate all the way to the
658 			* MSB.  As a result a single shift is enough to get the carry
659 			*/
660 			carry = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof(mp_digit) - 1));
661 
662 			/* Clear carry from T[i] */
663 			*tmpc++ &= MP_MASK;
664 		}
665 
666 		/* now copy higher words if any, e.g. if A has more digits than B  */
667 		for (; i < max; i++) {
668 			/* T[i] = A[i] - U */
669 			*tmpc = *tmpa++ - carry;
670 
671 			/* U = carry bit of T[i] */
672 			carry = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof(mp_digit) - 1));
673 
674 			/* Clear carry from T[i] */
675 			*tmpc++ &= MP_MASK;
676 		}
677 
678 		/* clear digits above used (since we may not have grown result above) */
679 		if (olduse > c->used) {
680 			memset(tmpc, 0x0, (olduse - c->used) * sizeof(*a->dp));
681 		}
682 	}
683 
684 	trim_unused_digits(c);
685 	return MP_OKAY;
686 }
687 
688 /* high level subtraction (handles signs) */
689 static int
signed_subtract(mp_int * a,mp_int * b,mp_int * c)690 signed_subtract(mp_int * a, mp_int * b, mp_int * c)
691 {
692 	int     sa, sb, res;
693 
694 	sa = a->sign;
695 	sb = b->sign;
696 
697 	if (sa != sb) {
698 		/* subtract a negative from a positive, OR */
699 		/* subtract a positive from a negative. */
700 		/* In either case, ADD their magnitudes, */
701 		/* and use the sign of the first number. */
702 		c->sign = sa;
703 		res = basic_add(a, b, c);
704 	} else {
705 		/* subtract a positive from a positive, OR */
706 		/* subtract a negative from a negative. */
707 		/* First, take the difference between their */
708 		/* magnitudes, then... */
709 		if (compare_magnitude(a, b) != MP_LT) {
710 			/* Copy the sign from the first */
711 			c->sign = sa;
712 			/* The first has a larger or equal magnitude */
713 			res = basic_subtract(a, b, c);
714 		} else {
715 			/* The result has the *opposite* sign from */
716 			/* the first number. */
717 			c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
718 			/* The second has a larger magnitude */
719 			res = basic_subtract(b, a, c);
720 		}
721 	}
722 	return res;
723 }
724 
725 /* shift right a certain amount of digits */
726 static int
rshift_digits(mp_int * a,int b)727 rshift_digits(mp_int * a, int b)
728 {
729 	/* if b <= 0 then ignore it */
730 	if (b <= 0) {
731 		return 0;
732 	}
733 
734 	/* if b > used then simply zero it and return */
735 	if (a->used <= b) {
736 		mp_zero(a);
737 		return 0;
738 	}
739 
740 	/* this is implemented as a sliding window where
741 	* the window is b-digits long and digits from
742 	* the top of the window are copied to the bottom
743 	*
744 	* e.g.
745 
746 	b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
747 		 /\                   |      ---->
748 		  \-------------------/      ---->
749 	*/
750 	memmove(a->dp, &a->dp[b], (a->used - b) * sizeof(*a->dp));
751 	memset(&a->dp[a->used - b], 0x0, b * sizeof(*a->dp));
752 
753 	/* remove excess digits */
754 	a->used -= b;
755 	return 1;
756 }
757 
758 /* multiply by a digit */
759 static int
multiply_digit(mp_int * a,mp_digit b,mp_int * c)760 multiply_digit(mp_int * a, mp_digit b, mp_int * c)
761 {
762 	mp_digit carry, *tmpa, *tmpc;
763 	mp_word  r;
764 	int      ix, res, olduse;
765 
766 	/* make sure c is big enough to hold a*b */
767 	if (c->alloc < a->used + 1) {
768 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
769 			return res;
770 		}
771 	}
772 
773 	/* get the original destinations used count */
774 	olduse = c->used;
775 
776 	/* set the sign */
777 	c->sign = a->sign;
778 
779 	/* alias for a->dp [source] */
780 	tmpa = a->dp;
781 
782 	/* alias for c->dp [dest] */
783 	tmpc = c->dp;
784 
785 	/* zero carry */
786 	carry = 0;
787 
788 	/* compute columns */
789 	for (ix = 0; ix < a->used; ix++) {
790 		/* compute product and carry sum for this term */
791 		r = ((mp_word) carry) + ((mp_word)*tmpa++) * ((mp_word)b);
792 
793 		/* mask off higher bits to get a single digit */
794 		*tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
795 
796 		/* send carry into next iteration */
797 		carry = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
798 	}
799 
800 	/* store final carry [if any] and increment ix offset  */
801 	*tmpc++ = carry;
802 	++ix;
803 	if (olduse > ix) {
804 		memset(tmpc, 0x0, (olduse - ix) * sizeof(*tmpc));
805 	}
806 
807 	/* set used count */
808 	c->used = a->used + 1;
809 	trim_unused_digits(c);
810 
811 	return MP_OKAY;
812 }
813 
814 /* high level addition (handles signs) */
815 static int
signed_add(mp_int * a,mp_int * b,mp_int * c)816 signed_add(mp_int * a, mp_int * b, mp_int * c)
817 {
818 	int     asign, bsign, res;
819 
820 	/* get sign of both inputs */
821 	asign = a->sign;
822 	bsign = b->sign;
823 
824 	/* handle two cases, not four */
825 	if (asign == bsign) {
826 		/* both positive or both negative */
827 		/* add their magnitudes, copy the sign */
828 		c->sign = asign;
829 		res = basic_add(a, b, c);
830 	} else {
831 		/* one positive, the other negative */
832 		/* subtract the one with the greater magnitude from */
833 		/* the one of the lesser magnitude.  The result gets */
834 		/* the sign of the one with the greater magnitude. */
835 		if (compare_magnitude(a, b) == MP_LT) {
836 			c->sign = bsign;
837 			res = basic_subtract(b, a, c);
838 		} else {
839 			c->sign = asign;
840 			res = basic_subtract(a, b, c);
841 		}
842 	}
843 	return res;
844 }
845 
846 /* swap the elements of two integers, for cases where you can't simply swap the
847  * mp_int pointers around
848  */
849 static void
mp_exch(mp_int * a,mp_int * b)850 mp_exch(mp_int *a, mp_int *b)
851 {
852 	mp_int  t;
853 
854 	t  = *a;
855 	*a = *b;
856 	*b = t;
857 }
858 
859 /* calc a value mod 2**b */
860 static int
modulo_2_to_power(mp_int * a,int b,mp_int * c)861 modulo_2_to_power(mp_int * a, int b, mp_int * c)
862 {
863 	int     x, res;
864 
865 	/* if b is <= 0 then zero the int */
866 	if (b <= 0) {
867 		mp_zero(c);
868 		return MP_OKAY;
869 	}
870 
871 	/* if the modulus is larger than the value than return */
872 	if (b >= (int) (a->used * DIGIT_BIT)) {
873 		res = mp_copy(a, c);
874 		return res;
875 	}
876 
877 	/* copy */
878 	if ((res = mp_copy(a, c)) != MP_OKAY) {
879 		return res;
880 	}
881 
882 	/* zero digits above the last digit of the modulus */
883 	for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
884 		c->dp[x] = 0;
885 	}
886 	/* clear the digit that is not completely outside/inside the modulus */
887 	c->dp[b / DIGIT_BIT] &=
888 		(mp_digit) ((((mp_digit) 1) << (((mp_digit) b) % DIGIT_BIT)) - ((mp_digit) 1));
889 	trim_unused_digits(c);
890 	return MP_OKAY;
891 }
892 
893 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
894 static int
rshift_bits(mp_int * a,int b,mp_int * c,mp_int * d)895 rshift_bits(mp_int * a, int b, mp_int * c, mp_int * d)
896 {
897 	mp_digit D, r, rr;
898 	int     x, res;
899 	mp_int  t;
900 
901 
902 	/* if the shift count is <= 0 then we do no work */
903 	if (b <= 0) {
904 		res = mp_copy(a, c);
905 		if (d != NULL) {
906 			mp_zero(d);
907 		}
908 		return res;
909 	}
910 
911 	if ((res = mp_init(&t)) != MP_OKAY) {
912 		return res;
913 	}
914 
915 	/* get the remainder */
916 	if (d != NULL) {
917 		if ((res = modulo_2_to_power(a, b, &t)) != MP_OKAY) {
918 			mp_clear(&t);
919 			return res;
920 		}
921 	}
922 
923 	/* copy */
924 	if ((res = mp_copy(a, c)) != MP_OKAY) {
925 		mp_clear(&t);
926 		return res;
927 	}
928 
929 	/* shift by as many digits in the bit count */
930 	if (b >= (int)DIGIT_BIT) {
931 		rshift_digits(c, b / DIGIT_BIT);
932 	}
933 
934 	/* shift any bit count < DIGIT_BIT */
935 	D = (mp_digit) (b % DIGIT_BIT);
936 	if (D != 0) {
937 		mp_digit *tmpc, mask, shift;
938 
939 		/* mask */
940 		mask = (((mp_digit)1) << D) - 1;
941 
942 		/* shift for lsb */
943 		shift = DIGIT_BIT - D;
944 
945 		/* alias */
946 		tmpc = c->dp + (c->used - 1);
947 
948 		/* carry */
949 		r = 0;
950 		for (x = c->used - 1; x >= 0; x--) {
951 			/* get the lower  bits of this word in a temp */
952 			rr = *tmpc & mask;
953 
954 			/* shift the current word and mix in the carry bits from the previous word */
955 			*tmpc = (*tmpc >> D) | (r << shift);
956 			--tmpc;
957 
958 			/* set the carry to the carry bits of the current word found above */
959 			r = rr;
960 		}
961 	}
962 	trim_unused_digits(c);
963 	if (d != NULL) {
964 		mp_exch(&t, d);
965 	}
966 	mp_clear(&t);
967 	return MP_OKAY;
968 }
969 
970 /* integer signed division.
971  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
972  * HAC pp.598 Algorithm 14.20
973  *
974  * Note that the description in HAC is horribly
975  * incomplete.  For example, it doesn't consider
976  * the case where digits are removed from 'x' in
977  * the inner loop.  It also doesn't consider the
978  * case that y has fewer than three digits, etc..
979  *
980  * The overall algorithm is as described as
981  * 14.20 from HAC but fixed to treat these cases.
982 */
983 static int
signed_divide(mp_int * c,mp_int * d,mp_int * a,mp_int * b)984 signed_divide(mp_int *c, mp_int *d, mp_int *a, mp_int *b)
985 {
986 	mp_int  q, x, y, t1, t2;
987 	int     res, n, t, i, norm, neg;
988 
989 	/* is divisor zero ? */
990 	if (MP_ISZERO(b) == MP_YES) {
991 		return MP_VAL;
992 	}
993 
994 	/* if a < b then q=0, r = a */
995 	if (compare_magnitude(a, b) == MP_LT) {
996 		if (d != NULL) {
997 			res = mp_copy(a, d);
998 		} else {
999 			res = MP_OKAY;
1000 		}
1001 		if (c != NULL) {
1002 			mp_zero(c);
1003 		}
1004 		return res;
1005 	}
1006 
1007 	if ((res = mp_init_size(&q, a->used + 2)) != MP_OKAY) {
1008 		return res;
1009 	}
1010 	q.used = a->used + 2;
1011 
1012 	if ((res = mp_init(&t1)) != MP_OKAY) {
1013 		goto LBL_Q;
1014 	}
1015 
1016 	if ((res = mp_init(&t2)) != MP_OKAY) {
1017 		goto LBL_T1;
1018 	}
1019 
1020 	if ((res = mp_init_copy(&x, a)) != MP_OKAY) {
1021 		goto LBL_T2;
1022 	}
1023 
1024 	if ((res = mp_init_copy(&y, b)) != MP_OKAY) {
1025 		goto LBL_X;
1026 	}
1027 
1028 	/* fix the sign */
1029 	neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1030 	x.sign = y.sign = MP_ZPOS;
1031 
1032 	/* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1033 	norm = mp_count_bits(&y) % DIGIT_BIT;
1034 	if (norm < (int)(DIGIT_BIT-1)) {
1035 		norm = (DIGIT_BIT-1) - norm;
1036 		if ((res = lshift_bits(&x, norm, &x)) != MP_OKAY) {
1037 			goto LBL_Y;
1038 		}
1039 		if ((res = lshift_bits(&y, norm, &y)) != MP_OKAY) {
1040 			goto LBL_Y;
1041 		}
1042 	} else {
1043 		norm = 0;
1044 	}
1045 
1046 	/* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1047 	n = x.used - 1;
1048 	t = y.used - 1;
1049 
1050 	/* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1051 	if ((res = lshift_digits(&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1052 		goto LBL_Y;
1053 	}
1054 
1055 	while (signed_compare(&x, &y) != MP_LT) {
1056 		++(q.dp[n - t]);
1057 		if ((res = signed_subtract(&x, &y, &x)) != MP_OKAY) {
1058 			goto LBL_Y;
1059 		}
1060 	}
1061 
1062 	/* reset y by shifting it back down */
1063 	rshift_digits(&y, n - t);
1064 
1065 	/* step 3. for i from n down to (t + 1) */
1066 	for (i = n; i >= (t + 1); i--) {
1067 		if (i > x.used) {
1068 			continue;
1069 		}
1070 
1071 		/* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1072 		* otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1073 		if (x.dp[i] == y.dp[t]) {
1074 			q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1075 		} else {
1076 			mp_word tmp;
1077 			tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1078 			tmp |= ((mp_word) x.dp[i - 1]);
1079 			tmp /= ((mp_word) y.dp[t]);
1080 			if (tmp > (mp_word) MP_MASK) {
1081 				tmp = MP_MASK;
1082 			}
1083 			q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1084 		}
1085 
1086 		/* while (q{i-t-1} * (yt * b + y{t-1})) >
1087 		     xi * b**2 + xi-1 * b + xi-2
1088 			do q{i-t-1} -= 1;
1089 		*/
1090 		q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1091 		do {
1092 			q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1093 
1094 			/* find left hand */
1095 			mp_zero(&t1);
1096 			t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1097 			t1.dp[1] = y.dp[t];
1098 			t1.used = 2;
1099 			if ((res = multiply_digit(&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1100 				goto LBL_Y;
1101 			}
1102 
1103 			/* find right hand */
1104 			t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1105 			t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1106 			t2.dp[2] = x.dp[i];
1107 			t2.used = 3;
1108 		} while (compare_magnitude(&t1, &t2) == MP_GT);
1109 
1110 		/* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1111 		if ((res = multiply_digit(&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1112 			goto LBL_Y;
1113 		}
1114 
1115 		if ((res = lshift_digits(&t1, i - t - 1)) != MP_OKAY) {
1116 			goto LBL_Y;
1117 		}
1118 
1119 		if ((res = signed_subtract(&x, &t1, &x)) != MP_OKAY) {
1120 			goto LBL_Y;
1121 		}
1122 
1123 		/* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1124 		if (x.sign == MP_NEG) {
1125 			if ((res = mp_copy(&y, &t1)) != MP_OKAY) {
1126 				goto LBL_Y;
1127 			}
1128 			if ((res = lshift_digits(&t1, i - t - 1)) != MP_OKAY) {
1129 				goto LBL_Y;
1130 			}
1131 			if ((res = signed_add(&x, &t1, &x)) != MP_OKAY) {
1132 				goto LBL_Y;
1133 			}
1134 
1135 			q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1136 		}
1137 	}
1138 
1139 	/* now q is the quotient and x is the remainder
1140 	* [which we have to normalize]
1141 	*/
1142 
1143 	/* get sign before writing to c */
1144 	x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1145 
1146 	if (c != NULL) {
1147 		trim_unused_digits(&q);
1148 		mp_exch(&q, c);
1149 		c->sign = neg;
1150 	}
1151 
1152 	if (d != NULL) {
1153 		rshift_bits(&x, norm, &x, NULL);
1154 		mp_exch(&x, d);
1155 	}
1156 
1157 	res = MP_OKAY;
1158 
1159 LBL_Y:
1160 	mp_clear(&y);
1161 LBL_X:
1162 	mp_clear(&x);
1163 LBL_T2:
1164 	mp_clear(&t2);
1165 LBL_T1:
1166 	mp_clear(&t1);
1167 LBL_Q:
1168 	mp_clear(&q);
1169 	return res;
1170 }
1171 
1172 /* c = a mod b, 0 <= c < b */
1173 static int
modulo(mp_int * a,mp_int * b,mp_int * c)1174 modulo(mp_int * a, mp_int * b, mp_int * c)
1175 {
1176 	mp_int  t;
1177 	int     res;
1178 
1179 	if ((res = mp_init(&t)) != MP_OKAY) {
1180 		return res;
1181 	}
1182 
1183 	if ((res = signed_divide(NULL, &t, a, b)) != MP_OKAY) {
1184 		mp_clear(&t);
1185 		return res;
1186 	}
1187 
1188 	if (t.sign != b->sign) {
1189 		res = signed_add(b, &t, c);
1190 	} else {
1191 		res = MP_OKAY;
1192 		mp_exch(&t, c);
1193 	}
1194 
1195 	mp_clear(&t);
1196 	return res;
1197 }
1198 
1199 /* set to a digit */
1200 static void
set_word(mp_int * a,mp_digit b)1201 set_word(mp_int * a, mp_digit b)
1202 {
1203 	mp_zero(a);
1204 	a->dp[0] = b & MP_MASK;
1205 	a->used = (a->dp[0] != 0) ? 1 : 0;
1206 }
1207 
1208 /* b = a/2 */
1209 static int
half(mp_int * a,mp_int * b)1210 half(mp_int * a, mp_int * b)
1211 {
1212 	int     x, res, oldused;
1213 
1214 	/* copy */
1215 	if (b->alloc < a->used) {
1216 		if ((res = mp_grow(b, a->used)) != MP_OKAY) {
1217 			return res;
1218 		}
1219 	}
1220 
1221 	oldused = b->used;
1222 	b->used = a->used;
1223 	{
1224 		mp_digit r, rr, *tmpa, *tmpb;
1225 
1226 		/* source alias */
1227 		tmpa = a->dp + b->used - 1;
1228 
1229 		/* dest alias */
1230 		tmpb = b->dp + b->used - 1;
1231 
1232 		/* carry */
1233 		r = 0;
1234 		for (x = b->used - 1; x >= 0; x--) {
1235 			/* get the carry for the next iteration */
1236 			rr = *tmpa & 1;
1237 
1238 			/* shift the current digit, add in carry and store */
1239 			*tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1240 
1241 			/* forward carry to next iteration */
1242 			r = rr;
1243 		}
1244 
1245 		/* zero excess digits */
1246 		tmpb = b->dp + b->used;
1247 		for (x = b->used; x < oldused; x++) {
1248 			*tmpb++ = 0;
1249 		}
1250 	}
1251 	b->sign = a->sign;
1252 	trim_unused_digits(b);
1253 	return MP_OKAY;
1254 }
1255 
1256 /* compare a digit */
1257 static int
compare_digit(mp_int * a,mp_digit b)1258 compare_digit(mp_int * a, mp_digit b)
1259 {
1260 	/* compare based on sign */
1261 	if (a->sign == MP_NEG) {
1262 		return MP_LT;
1263 	}
1264 
1265 	/* compare based on magnitude */
1266 	if (a->used > 1) {
1267 		return MP_GT;
1268 	}
1269 
1270 	/* compare the only digit of a to b */
1271 	if (a->dp[0] > b) {
1272 		return MP_GT;
1273 	} else if (a->dp[0] < b) {
1274 		return MP_LT;
1275 	} else {
1276 		return MP_EQ;
1277 	}
1278 }
1279 
1280 static void
mp_clear_multi(mp_int * mp,...)1281 mp_clear_multi(mp_int *mp, ...)
1282 {
1283 	mp_int* next_mp = mp;
1284 	va_list args;
1285 
1286 	va_start(args, mp);
1287 	while (next_mp != NULL) {
1288 		mp_clear(next_mp);
1289 		next_mp = va_arg(args, mp_int*);
1290 	}
1291 	va_end(args);
1292 }
1293 
1294 /* computes the modular inverse via binary extended euclidean algorithm,
1295  * that is c = 1/a mod b
1296  *
1297  * Based on slow invmod except this is optimized for the case where b is
1298  * odd as per HAC Note 14.64 on pp. 610
1299  */
1300 static int
fast_modular_inverse(mp_int * a,mp_int * b,mp_int * c)1301 fast_modular_inverse(mp_int * a, mp_int * b, mp_int * c)
1302 {
1303 	mp_int  x, y, u, v, B, D;
1304 	int     res, neg;
1305 
1306 	/* 2. [modified] b must be odd   */
1307 	if (MP_ISZERO(b) == MP_YES) {
1308 		return MP_VAL;
1309 	}
1310 
1311 	/* init all our temps */
1312 	if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
1313 		return res;
1314 	}
1315 
1316 	/* x == modulus, y == value to invert */
1317 	if ((res = mp_copy(b, &x)) != MP_OKAY) {
1318 		goto LBL_ERR;
1319 	}
1320 
1321 	/* we need y = |a| */
1322 	if ((res = modulo(a, b, &y)) != MP_OKAY) {
1323 		goto LBL_ERR;
1324 	}
1325 
1326 	/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
1327 	if ((res = mp_copy(&x, &u)) != MP_OKAY) {
1328 		goto LBL_ERR;
1329 	}
1330 	if ((res = mp_copy(&y, &v)) != MP_OKAY) {
1331 		goto LBL_ERR;
1332 	}
1333 	set_word(&D, 1);
1334 
1335 top:
1336 	/* 4.  while u is even do */
1337 	while (BN_is_even(&u) == 1) {
1338 		/* 4.1 u = u/2 */
1339 		if ((res = half(&u, &u)) != MP_OKAY) {
1340 			goto LBL_ERR;
1341 		}
1342 		/* 4.2 if B is odd then */
1343 		if (BN_is_odd(&B) == 1) {
1344 			if ((res = signed_subtract(&B, &x, &B)) != MP_OKAY) {
1345 				goto LBL_ERR;
1346 			}
1347 		}
1348 		/* B = B/2 */
1349 		if ((res = half(&B, &B)) != MP_OKAY) {
1350 			goto LBL_ERR;
1351 		}
1352 	}
1353 
1354 	/* 5.  while v is even do */
1355 	while (BN_is_even(&v) == 1) {
1356 		/* 5.1 v = v/2 */
1357 		if ((res = half(&v, &v)) != MP_OKAY) {
1358 			goto LBL_ERR;
1359 		}
1360 		/* 5.2 if D is odd then */
1361 		if (BN_is_odd(&D) == 1) {
1362 			/* D = (D-x)/2 */
1363 			if ((res = signed_subtract(&D, &x, &D)) != MP_OKAY) {
1364 				goto LBL_ERR;
1365 			}
1366 		}
1367 		/* D = D/2 */
1368 		if ((res = half(&D, &D)) != MP_OKAY) {
1369 			goto LBL_ERR;
1370 		}
1371 	}
1372 
1373 	/* 6.  if u >= v then */
1374 	if (signed_compare(&u, &v) != MP_LT) {
1375 		/* u = u - v, B = B - D */
1376 		if ((res = signed_subtract(&u, &v, &u)) != MP_OKAY) {
1377 			goto LBL_ERR;
1378 		}
1379 
1380 		if ((res = signed_subtract(&B, &D, &B)) != MP_OKAY) {
1381 			goto LBL_ERR;
1382 		}
1383 	} else {
1384 		/* v - v - u, D = D - B */
1385 		if ((res = signed_subtract(&v, &u, &v)) != MP_OKAY) {
1386 			goto LBL_ERR;
1387 		}
1388 
1389 		if ((res = signed_subtract(&D, &B, &D)) != MP_OKAY) {
1390 			goto LBL_ERR;
1391 		}
1392 	}
1393 
1394 	/* if not zero goto step 4 */
1395 	if (MP_ISZERO(&u) == MP_NO) {
1396 		goto top;
1397 	}
1398 
1399 	/* now a = C, b = D, gcd == g*v */
1400 
1401 	/* if v != 1 then there is no inverse */
1402 	if (compare_digit(&v, 1) != MP_EQ) {
1403 		res = MP_VAL;
1404 		goto LBL_ERR;
1405 	}
1406 
1407 	/* b is now the inverse */
1408 	neg = a->sign;
1409 	while (D.sign == MP_NEG) {
1410 		if ((res = signed_add(&D, b, &D)) != MP_OKAY) {
1411 			goto LBL_ERR;
1412 		}
1413 	}
1414 	mp_exch(&D, c);
1415 	c->sign = neg;
1416 	res = MP_OKAY;
1417 
1418 LBL_ERR:
1419 	mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
1420 	return res;
1421 }
1422 
1423 /* hac 14.61, pp608 */
1424 static int
slow_modular_inverse(mp_int * a,mp_int * b,mp_int * c)1425 slow_modular_inverse(mp_int * a, mp_int * b, mp_int * c)
1426 {
1427 	mp_int  x, y, u, v, A, B, C, D;
1428 	int     res;
1429 
1430 	/* b cannot be negative */
1431 	if (b->sign == MP_NEG || MP_ISZERO(b) == MP_YES) {
1432 		return MP_VAL;
1433 	}
1434 
1435 	/* init temps */
1436 	if ((res = mp_init_multi(&x, &y, &u, &v,
1437 		   &A, &B, &C, &D, NULL)) != MP_OKAY) {
1438 		return res;
1439 	}
1440 
1441 	/* x = a, y = b */
1442 	if ((res = modulo(a, b, &x)) != MP_OKAY) {
1443 		goto LBL_ERR;
1444 	}
1445 	if ((res = mp_copy(b, &y)) != MP_OKAY) {
1446 		goto LBL_ERR;
1447 	}
1448 
1449 	/* 2. [modified] if x,y are both even then return an error! */
1450 	if (BN_is_even(&x) == 1 && BN_is_even(&y) == 1) {
1451 		res = MP_VAL;
1452 		goto LBL_ERR;
1453 	}
1454 
1455 	/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
1456 	if ((res = mp_copy(&x, &u)) != MP_OKAY) {
1457 		goto LBL_ERR;
1458 	}
1459 	if ((res = mp_copy(&y, &v)) != MP_OKAY) {
1460 		goto LBL_ERR;
1461 	}
1462 	set_word(&A, 1);
1463 	set_word(&D, 1);
1464 
1465 top:
1466 	/* 4.  while u is even do */
1467 	while (BN_is_even(&u) == 1) {
1468 		/* 4.1 u = u/2 */
1469 		if ((res = half(&u, &u)) != MP_OKAY) {
1470 			goto LBL_ERR;
1471 		}
1472 		/* 4.2 if A or B is odd then */
1473 		if (BN_is_odd(&A) == 1 || BN_is_odd(&B) == 1) {
1474 			/* A = (A+y)/2, B = (B-x)/2 */
1475 			if ((res = signed_add(&A, &y, &A)) != MP_OKAY) {
1476 				 goto LBL_ERR;
1477 			}
1478 			if ((res = signed_subtract(&B, &x, &B)) != MP_OKAY) {
1479 				 goto LBL_ERR;
1480 			}
1481 		}
1482 		/* A = A/2, B = B/2 */
1483 		if ((res = half(&A, &A)) != MP_OKAY) {
1484 			goto LBL_ERR;
1485 		}
1486 		if ((res = half(&B, &B)) != MP_OKAY) {
1487 			goto LBL_ERR;
1488 		}
1489 	}
1490 
1491 	/* 5.  while v is even do */
1492 	while (BN_is_even(&v) == 1) {
1493 		/* 5.1 v = v/2 */
1494 		if ((res = half(&v, &v)) != MP_OKAY) {
1495 			goto LBL_ERR;
1496 		}
1497 		/* 5.2 if C or D is odd then */
1498 		if (BN_is_odd(&C) == 1 || BN_is_odd(&D) == 1) {
1499 			/* C = (C+y)/2, D = (D-x)/2 */
1500 			if ((res = signed_add(&C, &y, &C)) != MP_OKAY) {
1501 				 goto LBL_ERR;
1502 			}
1503 			if ((res = signed_subtract(&D, &x, &D)) != MP_OKAY) {
1504 				 goto LBL_ERR;
1505 			}
1506 		}
1507 		/* C = C/2, D = D/2 */
1508 		if ((res = half(&C, &C)) != MP_OKAY) {
1509 			goto LBL_ERR;
1510 		}
1511 		if ((res = half(&D, &D)) != MP_OKAY) {
1512 			goto LBL_ERR;
1513 		}
1514 	}
1515 
1516 	/* 6.  if u >= v then */
1517 	if (signed_compare(&u, &v) != MP_LT) {
1518 		/* u = u - v, A = A - C, B = B - D */
1519 		if ((res = signed_subtract(&u, &v, &u)) != MP_OKAY) {
1520 			goto LBL_ERR;
1521 		}
1522 
1523 		if ((res = signed_subtract(&A, &C, &A)) != MP_OKAY) {
1524 			goto LBL_ERR;
1525 		}
1526 
1527 		if ((res = signed_subtract(&B, &D, &B)) != MP_OKAY) {
1528 			goto LBL_ERR;
1529 		}
1530 	} else {
1531 		/* v - v - u, C = C - A, D = D - B */
1532 		if ((res = signed_subtract(&v, &u, &v)) != MP_OKAY) {
1533 			goto LBL_ERR;
1534 		}
1535 
1536 		if ((res = signed_subtract(&C, &A, &C)) != MP_OKAY) {
1537 			goto LBL_ERR;
1538 		}
1539 
1540 		if ((res = signed_subtract(&D, &B, &D)) != MP_OKAY) {
1541 			goto LBL_ERR;
1542 		}
1543 	}
1544 
1545 	/* if not zero goto step 4 */
1546 	if (BN_is_zero(&u) == 0) {
1547 		goto top;
1548 	}
1549 	/* now a = C, b = D, gcd == g*v */
1550 
1551 	/* if v != 1 then there is no inverse */
1552 	if (compare_digit(&v, 1) != MP_EQ) {
1553 		res = MP_VAL;
1554 		goto LBL_ERR;
1555 	}
1556 
1557 	/* if its too low */
1558 	while (compare_digit(&C, 0) == MP_LT) {
1559 		if ((res = signed_add(&C, b, &C)) != MP_OKAY) {
1560 			 goto LBL_ERR;
1561 		}
1562 	}
1563 
1564 	/* too big */
1565 	while (compare_magnitude(&C, b) != MP_LT) {
1566 		if ((res = signed_subtract(&C, b, &C)) != MP_OKAY) {
1567 			 goto LBL_ERR;
1568 		}
1569 	}
1570 
1571 	/* C is now the inverse */
1572 	mp_exch(&C, c);
1573 	res = MP_OKAY;
1574 LBL_ERR:
1575 	mp_clear_multi(&x, &y, &u, &v, &A, &B, &C, &D, NULL);
1576 	return res;
1577 }
1578 
1579 static int
modular_inverse(mp_int * c,mp_int * a,mp_int * b)1580 modular_inverse(mp_int *c, mp_int *a, mp_int *b)
1581 {
1582 	/* b cannot be negative */
1583 	if (b->sign == MP_NEG || MP_ISZERO(b) == MP_YES) {
1584 		return MP_VAL;
1585 	}
1586 
1587 	/* if the modulus is odd we can use a faster routine instead */
1588 	if (BN_is_odd(b) == 1) {
1589 		return fast_modular_inverse(a, b, c);
1590 	}
1591 	return slow_modular_inverse(a, b, c);
1592 }
1593 
1594 /* b = |a|
1595  *
1596  * Simple function copies the input and fixes the sign to positive
1597  */
1598 static int
absolute(mp_int * a,mp_int * b)1599 absolute(mp_int * a, mp_int * b)
1600 {
1601 	int     res;
1602 
1603 	/* copy a to b */
1604 	if (a != b) {
1605 		if ((res = mp_copy(a, b)) != MP_OKAY) {
1606 			return res;
1607 		}
1608 	}
1609 
1610 	/* force the sign of b to positive */
1611 	b->sign = MP_ZPOS;
1612 
1613 	return MP_OKAY;
1614 }
1615 
1616 /* determines if reduce_2k_l can be used */
1617 static int
mp_reduce_is_2k_l(mp_int * a)1618 mp_reduce_is_2k_l(mp_int *a)
1619 {
1620 	int ix, iy;
1621 
1622 	if (a->used == 0) {
1623 		return MP_NO;
1624 	} else if (a->used == 1) {
1625 		return MP_YES;
1626 	} else if (a->used > 1) {
1627 		/* if more than half of the digits are -1 we're sold */
1628 		for (iy = ix = 0; ix < a->used; ix++) {
1629 			if (a->dp[ix] == MP_MASK) {
1630 				++iy;
1631 			}
1632 		}
1633 		return (iy >= (a->used/2)) ? MP_YES : MP_NO;
1634 
1635 	}
1636 	return MP_NO;
1637 }
1638 
1639 /* computes a = 2**b
1640  *
1641  * Simple algorithm which zeroes the int, grows it then just sets one bit
1642  * as required.
1643  */
1644 static int
mp_2expt(mp_int * a,int b)1645 mp_2expt(mp_int * a, int b)
1646 {
1647 	int     res;
1648 
1649 	/* zero a as per default */
1650 	mp_zero(a);
1651 
1652 	/* grow a to accommodate the single bit */
1653 	if ((res = mp_grow(a, b / DIGIT_BIT + 1)) != MP_OKAY) {
1654 		return res;
1655 	}
1656 
1657 	/* set the used count of where the bit will go */
1658 	a->used = b / DIGIT_BIT + 1;
1659 
1660 	/* put the single bit in its place */
1661 	a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
1662 
1663 	return MP_OKAY;
1664 }
1665 
1666 /* pre-calculate the value required for Barrett reduction
1667  * For a given modulus "b" it calculates the value required in "a"
1668  */
1669 static int
mp_reduce_setup(mp_int * a,mp_int * b)1670 mp_reduce_setup(mp_int * a, mp_int * b)
1671 {
1672 	int     res;
1673 
1674 	if ((res = mp_2expt(a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
1675 		return res;
1676 	}
1677 	return signed_divide(a, NULL, a, b);
1678 }
1679 
1680 /* b = a*2 */
1681 static int
doubled(mp_int * a,mp_int * b)1682 doubled(mp_int * a, mp_int * b)
1683 {
1684 	int     x, res, oldused;
1685 
1686 	/* grow to accommodate result */
1687 	if (b->alloc < a->used + 1) {
1688 		if ((res = mp_grow(b, a->used + 1)) != MP_OKAY) {
1689 			return res;
1690 		}
1691 	}
1692 
1693 	oldused = b->used;
1694 	b->used = a->used;
1695 
1696 	{
1697 		mp_digit r, rr, *tmpa, *tmpb;
1698 
1699 		/* alias for source */
1700 		tmpa = a->dp;
1701 
1702 		/* alias for dest */
1703 		tmpb = b->dp;
1704 
1705 		/* carry */
1706 		r = 0;
1707 		for (x = 0; x < a->used; x++) {
1708 
1709 			/* get what will be the *next* carry bit from the
1710 			* MSB of the current digit
1711 			*/
1712 			rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
1713 
1714 			/* now shift up this digit, add in the carry [from the previous] */
1715 			*tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
1716 
1717 			/* copy the carry that would be from the source
1718 			* digit into the next iteration
1719 			*/
1720 			r = rr;
1721 		}
1722 
1723 		/* new leading digit? */
1724 		if (r != 0) {
1725 			/* add a MSB which is always 1 at this point */
1726 			*tmpb = 1;
1727 			++(b->used);
1728 		}
1729 
1730 		/* now zero any excess digits on the destination
1731 		* that we didn't write to
1732 		*/
1733 		tmpb = b->dp + b->used;
1734 		for (x = b->used; x < oldused; x++) {
1735 			*tmpb++ = 0;
1736 		}
1737 	}
1738 	b->sign = a->sign;
1739 	return MP_OKAY;
1740 }
1741 
1742 /* divide by three (based on routine from MPI and the GMP manual) */
1743 static int
third(mp_int * a,mp_int * c,mp_digit * d)1744 third(mp_int * a, mp_int *c, mp_digit * d)
1745 {
1746 	mp_int   q;
1747 	mp_word  w, t;
1748 	mp_digit b;
1749 	int      res, ix;
1750 
1751 	/* b = 2**DIGIT_BIT / 3 */
1752 	b = (((mp_word)1) << ((mp_word)DIGIT_BIT)) / ((mp_word)3);
1753 
1754 	if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1755 		return res;
1756 	}
1757 
1758 	q.used = a->used;
1759 	q.sign = a->sign;
1760 	w = 0;
1761 	for (ix = a->used - 1; ix >= 0; ix--) {
1762 		w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1763 
1764 		if (w >= 3) {
1765 			/* multiply w by [1/3] */
1766 			t = (w * ((mp_word)b)) >> ((mp_word)DIGIT_BIT);
1767 
1768 			/* now subtract 3 * [w/3] from w, to get the remainder */
1769 			w -= t+t+t;
1770 
1771 			/* fixup the remainder as required since
1772 			* the optimization is not exact.
1773 			*/
1774 			while (w >= 3) {
1775 				t += 1;
1776 				w -= 3;
1777 			}
1778 		} else {
1779 			t = 0;
1780 		}
1781 		q.dp[ix] = (mp_digit)t;
1782 	}
1783 
1784 	/* [optional] store the remainder */
1785 	if (d != NULL) {
1786 		*d = (mp_digit)w;
1787 	}
1788 
1789 	/* [optional] store the quotient */
1790 	if (c != NULL) {
1791 		trim_unused_digits(&q);
1792 		mp_exch(&q, c);
1793 	}
1794 	mp_clear(&q);
1795 
1796 	return res;
1797 }
1798 
1799 /* multiplication using the Toom-Cook 3-way algorithm
1800  *
1801  * Much more complicated than Karatsuba but has a lower
1802  * asymptotic running time of O(N**1.464).  This algorithm is
1803  * only particularly useful on VERY large inputs
1804  * (we're talking 1000s of digits here...).
1805 */
1806 static int
toom_cook_multiply(mp_int * a,mp_int * b,mp_int * c)1807 toom_cook_multiply(mp_int *a, mp_int *b, mp_int *c)
1808 {
1809 	mp_int w0, w1, w2, w3, w4, tmp1, tmp2, a0, a1, a2, b0, b1, b2;
1810 	int res, B;
1811 
1812 	/* init temps */
1813 	if ((res = mp_init_multi(&w0, &w1, &w2, &w3, &w4,
1814 			&a0, &a1, &a2, &b0, &b1,
1815 			&b2, &tmp1, &tmp2, NULL)) != MP_OKAY) {
1816 		return res;
1817 	}
1818 
1819 	/* B */
1820 	B = MIN(a->used, b->used) / 3;
1821 
1822 	/* a = a2 * B**2 + a1 * B + a0 */
1823 	if ((res = modulo_2_to_power(a, DIGIT_BIT * B, &a0)) != MP_OKAY) {
1824 		goto ERR;
1825 	}
1826 
1827 	if ((res = mp_copy(a, &a1)) != MP_OKAY) {
1828 		goto ERR;
1829 	}
1830 	rshift_digits(&a1, B);
1831 	modulo_2_to_power(&a1, DIGIT_BIT * B, &a1);
1832 
1833 	if ((res = mp_copy(a, &a2)) != MP_OKAY) {
1834 		goto ERR;
1835 	}
1836 	rshift_digits(&a2, B*2);
1837 
1838 	/* b = b2 * B**2 + b1 * B + b0 */
1839 	if ((res = modulo_2_to_power(b, DIGIT_BIT * B, &b0)) != MP_OKAY) {
1840 		goto ERR;
1841 	}
1842 
1843 	if ((res = mp_copy(b, &b1)) != MP_OKAY) {
1844 		goto ERR;
1845 	}
1846 	rshift_digits(&b1, B);
1847 	modulo_2_to_power(&b1, DIGIT_BIT * B, &b1);
1848 
1849 	if ((res = mp_copy(b, &b2)) != MP_OKAY) {
1850 		goto ERR;
1851 	}
1852 	rshift_digits(&b2, B*2);
1853 
1854 	/* w0 = a0*b0 */
1855 	if ((res = signed_multiply(&a0, &b0, &w0)) != MP_OKAY) {
1856 		goto ERR;
1857 	}
1858 
1859 	/* w4 = a2 * b2 */
1860 	if ((res = signed_multiply(&a2, &b2, &w4)) != MP_OKAY) {
1861 		goto ERR;
1862 	}
1863 
1864 	/* w1 = (a2 + 2(a1 + 2a0))(b2 + 2(b1 + 2b0)) */
1865 	if ((res = doubled(&a0, &tmp1)) != MP_OKAY) {
1866 		goto ERR;
1867 	}
1868 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
1869 		goto ERR;
1870 	}
1871 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
1872 		goto ERR;
1873 	}
1874 	if ((res = signed_add(&tmp1, &a2, &tmp1)) != MP_OKAY) {
1875 		goto ERR;
1876 	}
1877 
1878 	if ((res = doubled(&b0, &tmp2)) != MP_OKAY) {
1879 		goto ERR;
1880 	}
1881 	if ((res = signed_add(&tmp2, &b1, &tmp2)) != MP_OKAY) {
1882 		goto ERR;
1883 	}
1884 	if ((res = doubled(&tmp2, &tmp2)) != MP_OKAY) {
1885 		goto ERR;
1886 	}
1887 	if ((res = signed_add(&tmp2, &b2, &tmp2)) != MP_OKAY) {
1888 		goto ERR;
1889 	}
1890 
1891 	if ((res = signed_multiply(&tmp1, &tmp2, &w1)) != MP_OKAY) {
1892 		goto ERR;
1893 	}
1894 
1895 	/* w3 = (a0 + 2(a1 + 2a2))(b0 + 2(b1 + 2b2)) */
1896 	if ((res = doubled(&a2, &tmp1)) != MP_OKAY) {
1897 		goto ERR;
1898 	}
1899 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
1900 		goto ERR;
1901 	}
1902 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
1903 		goto ERR;
1904 	}
1905 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
1906 		goto ERR;
1907 	}
1908 
1909 	if ((res = doubled(&b2, &tmp2)) != MP_OKAY) {
1910 		goto ERR;
1911 	}
1912 	if ((res = signed_add(&tmp2, &b1, &tmp2)) != MP_OKAY) {
1913 		goto ERR;
1914 	}
1915 	if ((res = doubled(&tmp2, &tmp2)) != MP_OKAY) {
1916 		goto ERR;
1917 	}
1918 	if ((res = signed_add(&tmp2, &b0, &tmp2)) != MP_OKAY) {
1919 		goto ERR;
1920 	}
1921 
1922 	if ((res = signed_multiply(&tmp1, &tmp2, &w3)) != MP_OKAY) {
1923 		goto ERR;
1924 	}
1925 
1926 
1927 	/* w2 = (a2 + a1 + a0)(b2 + b1 + b0) */
1928 	if ((res = signed_add(&a2, &a1, &tmp1)) != MP_OKAY) {
1929 		goto ERR;
1930 	}
1931 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
1932 		goto ERR;
1933 	}
1934 	if ((res = signed_add(&b2, &b1, &tmp2)) != MP_OKAY) {
1935 		goto ERR;
1936 	}
1937 	if ((res = signed_add(&tmp2, &b0, &tmp2)) != MP_OKAY) {
1938 		goto ERR;
1939 	}
1940 	if ((res = signed_multiply(&tmp1, &tmp2, &w2)) != MP_OKAY) {
1941 		goto ERR;
1942 	}
1943 
1944 	/* now solve the matrix
1945 
1946 	0  0  0  0  1
1947 	1  2  4  8  16
1948 	1  1  1  1  1
1949 	16 8  4  2  1
1950 	1  0  0  0  0
1951 
1952 	using 12 subtractions, 4 shifts,
1953 	2 small divisions and 1 small multiplication
1954 	*/
1955 
1956 	/* r1 - r4 */
1957 	if ((res = signed_subtract(&w1, &w4, &w1)) != MP_OKAY) {
1958 		goto ERR;
1959 	}
1960 	/* r3 - r0 */
1961 	if ((res = signed_subtract(&w3, &w0, &w3)) != MP_OKAY) {
1962 		goto ERR;
1963 	}
1964 	/* r1/2 */
1965 	if ((res = half(&w1, &w1)) != MP_OKAY) {
1966 		goto ERR;
1967 	}
1968 	/* r3/2 */
1969 	if ((res = half(&w3, &w3)) != MP_OKAY) {
1970 		goto ERR;
1971 	}
1972 	/* r2 - r0 - r4 */
1973 	if ((res = signed_subtract(&w2, &w0, &w2)) != MP_OKAY) {
1974 		goto ERR;
1975 	}
1976 	if ((res = signed_subtract(&w2, &w4, &w2)) != MP_OKAY) {
1977 		goto ERR;
1978 	}
1979 	/* r1 - r2 */
1980 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
1981 		goto ERR;
1982 	}
1983 	/* r3 - r2 */
1984 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
1985 		goto ERR;
1986 	}
1987 	/* r1 - 8r0 */
1988 	if ((res = lshift_bits(&w0, 3, &tmp1)) != MP_OKAY) {
1989 		goto ERR;
1990 	}
1991 	if ((res = signed_subtract(&w1, &tmp1, &w1)) != MP_OKAY) {
1992 		goto ERR;
1993 	}
1994 	/* r3 - 8r4 */
1995 	if ((res = lshift_bits(&w4, 3, &tmp1)) != MP_OKAY) {
1996 		goto ERR;
1997 	}
1998 	if ((res = signed_subtract(&w3, &tmp1, &w3)) != MP_OKAY) {
1999 		goto ERR;
2000 	}
2001 	/* 3r2 - r1 - r3 */
2002 	if ((res = multiply_digit(&w2, 3, &w2)) != MP_OKAY) {
2003 		goto ERR;
2004 	}
2005 	if ((res = signed_subtract(&w2, &w1, &w2)) != MP_OKAY) {
2006 		goto ERR;
2007 	}
2008 	if ((res = signed_subtract(&w2, &w3, &w2)) != MP_OKAY) {
2009 		goto ERR;
2010 	}
2011 	/* r1 - r2 */
2012 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
2013 		goto ERR;
2014 	}
2015 	/* r3 - r2 */
2016 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
2017 		goto ERR;
2018 	}
2019 	/* r1/3 */
2020 	if ((res = third(&w1, &w1, NULL)) != MP_OKAY) {
2021 		goto ERR;
2022 	}
2023 	/* r3/3 */
2024 	if ((res = third(&w3, &w3, NULL)) != MP_OKAY) {
2025 		goto ERR;
2026 	}
2027 
2028 	/* at this point shift W[n] by B*n */
2029 	if ((res = lshift_digits(&w1, 1*B)) != MP_OKAY) {
2030 		goto ERR;
2031 	}
2032 	if ((res = lshift_digits(&w2, 2*B)) != MP_OKAY) {
2033 		goto ERR;
2034 	}
2035 	if ((res = lshift_digits(&w3, 3*B)) != MP_OKAY) {
2036 		goto ERR;
2037 	}
2038 	if ((res = lshift_digits(&w4, 4*B)) != MP_OKAY) {
2039 		goto ERR;
2040 	}
2041 
2042 	if ((res = signed_add(&w0, &w1, c)) != MP_OKAY) {
2043 		goto ERR;
2044 	}
2045 	if ((res = signed_add(&w2, &w3, &tmp1)) != MP_OKAY) {
2046 		goto ERR;
2047 	}
2048 	if ((res = signed_add(&w4, &tmp1, &tmp1)) != MP_OKAY) {
2049 		goto ERR;
2050 	}
2051 	if ((res = signed_add(&tmp1, c, c)) != MP_OKAY) {
2052 		goto ERR;
2053 	}
2054 
2055 ERR:
2056 	mp_clear_multi(&w0, &w1, &w2, &w3, &w4,
2057 		&a0, &a1, &a2, &b0, &b1,
2058 		&b2, &tmp1, &tmp2, NULL);
2059 	return res;
2060 }
2061 
/* operand sizes, in digits of the shorter operand, from which
 * signed_multiply() switches to the Toom-Cook and the Karatsuba
 * multiplier respectively
 */
#define TOOM_MUL_CUTOFF		350
#define KARATSUBA_MUL_CUTOFF	80
2064 
2065 /* c = |a| * |b| using Karatsuba Multiplication using
2066  * three half size multiplications
2067  *
2068  * Let B represent the radix [e.g. 2**DIGIT_BIT] and
2069  * let n represent half of the number of digits in
2070  * the min(a,b)
2071  *
2072  * a = a1 * B**n + a0
2073  * b = b1 * B**n + b0
2074  *
2075  * Then, a * b =>
2076    a1b1 * B**2n + ((a1 + a0)(b1 + b0) - (a0b0 + a1b1)) * B + a0b0
2077  *
2078  * Note that a1b1 and a0b0 are used twice and only need to be
2079  * computed once.  So in total three half size (half # of
2080  * digit) multiplications are performed, a0b0, a1b1 and
2081  * (a1+b1)(a0+b0)
2082  *
2083  * Note that a multiplication of half the digits requires
2084  * 1/4th the number of single precision multiplications so in
2085  * total after one call 25% of the single precision multiplications
2086  * are saved.  Note also that the call to signed_multiply can end up back
2087  * in this function if the a0, a1, b0, or b1 are above the threshold.
2088  * This is known as divide-and-conquer and leads to the famous
2089  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2090  * the standard O(N**2) that the baseline/comba methods use.
2091  * Generally though the overhead of this method doesn't pay off
2092  * until a certain size (N ~ 80) is reached.
2093  */
2094 static int
karatsuba_multiply(mp_int * a,mp_int * b,mp_int * c)2095 karatsuba_multiply(mp_int * a, mp_int * b, mp_int * c)
2096 {
2097 	mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2098 	int     B;
2099 	int     err;
2100 
2101 	/* default the return code to an error */
2102 	err = MP_MEM;
2103 
2104 	/* min # of digits */
2105 	B = MIN(a->used, b->used);
2106 
2107 	/* now divide in two */
2108 	B = (int)((unsigned)B >> 1);
2109 
2110 	/* init copy all the temps */
2111 	if (mp_init_size(&x0, B) != MP_OKAY) {
2112 		goto ERR;
2113 	}
2114 	if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
2115 		goto X0;
2116 	}
2117 	if (mp_init_size(&y0, B) != MP_OKAY) {
2118 		goto X1;
2119 	}
2120 	if (mp_init_size(&y1, b->used - B) != MP_OKAY) {
2121 		goto Y0;
2122 	}
2123 	/* init temps */
2124 	if (mp_init_size(&t1, B * 2) != MP_OKAY) {
2125 		goto Y1;
2126 	}
2127 	if (mp_init_size(&x0y0, B * 2) != MP_OKAY) {
2128 		goto T1;
2129 	}
2130 	if (mp_init_size(&x1y1, B * 2) != MP_OKAY) {
2131 		goto X0Y0;
2132 	}
2133 	/* now shift the digits */
2134 	x0.used = y0.used = B;
2135 	x1.used = a->used - B;
2136 	y1.used = b->used - B;
2137 
2138 	{
2139 		int x;
2140 		mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2141 
2142 		/* we copy the digits directly instead of using higher level functions
2143 		* since we also need to shift the digits
2144 		*/
2145 		tmpa = a->dp;
2146 		tmpb = b->dp;
2147 
2148 		tmpx = x0.dp;
2149 		tmpy = y0.dp;
2150 		for (x = 0; x < B; x++) {
2151 			*tmpx++ = *tmpa++;
2152 			*tmpy++ = *tmpb++;
2153 		}
2154 
2155 		tmpx = x1.dp;
2156 		for (x = B; x < a->used; x++) {
2157 			*tmpx++ = *tmpa++;
2158 		}
2159 
2160 		tmpy = y1.dp;
2161 		for (x = B; x < b->used; x++) {
2162 			*tmpy++ = *tmpb++;
2163 		}
2164 	}
2165 
2166 	/* only need to clamp the lower words since by definition the
2167 	* upper words x1/y1 must have a known number of digits
2168 	*/
2169 	trim_unused_digits(&x0);
2170 	trim_unused_digits(&y0);
2171 
2172 	/* now calc the products x0y0 and x1y1 */
2173 	/* after this x0 is no longer required, free temp [x0==t2]! */
2174 	if (signed_multiply(&x0, &y0, &x0y0) != MP_OKAY)  {
2175 		goto X1Y1;          /* x0y0 = x0*y0 */
2176 	}
2177 	if (signed_multiply(&x1, &y1, &x1y1) != MP_OKAY) {
2178 		goto X1Y1;          /* x1y1 = x1*y1 */
2179 	}
2180 	/* now calc x1+x0 and y1+y0 */
2181 	if (basic_add(&x1, &x0, &t1) != MP_OKAY) {
2182 		goto X1Y1;          /* t1 = x1 - x0 */
2183 	}
2184 	if (basic_add(&y1, &y0, &x0) != MP_OKAY) {
2185 		goto X1Y1;          /* t2 = y1 - y0 */
2186 	}
2187 	if (signed_multiply(&t1, &x0, &t1) != MP_OKAY) {
2188 		goto X1Y1;          /* t1 = (x1 + x0) * (y1 + y0) */
2189 	}
2190 	/* add x0y0 */
2191 	if (signed_add(&x0y0, &x1y1, &x0) != MP_OKAY) {
2192 		goto X1Y1;          /* t2 = x0y0 + x1y1 */
2193 	}
2194 	if (basic_subtract(&t1, &x0, &t1) != MP_OKAY) {
2195 		goto X1Y1;          /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
2196 	}
2197 	/* shift by B */
2198 	if (lshift_digits(&t1, B) != MP_OKAY) {
2199 		goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2200 	}
2201 	if (lshift_digits(&x1y1, B * 2) != MP_OKAY) {
2202 		goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2203 	}
2204 	if (signed_add(&x0y0, &t1, &t1) != MP_OKAY) {
2205 		goto X1Y1;          /* t1 = x0y0 + t1 */
2206 	}
2207 	if (signed_add(&t1, &x1y1, c) != MP_OKAY) {
2208 		goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2209 	}
2210 	/* Algorithm succeeded set the return code to MP_OKAY */
2211 	err = MP_OKAY;
2212 
2213 X1Y1:
2214 	mp_clear(&x1y1);
2215 X0Y0:
2216 	mp_clear(&x0y0);
2217 T1:
2218 	mp_clear(&t1);
2219 Y1:
2220 	mp_clear(&y1);
2221 Y0:
2222 	mp_clear(&y0);
2223 X1:
2224 	mp_clear(&x1);
2225 X0:
2226 	mp_clear(&x0);
2227 ERR:
2228 	return err;
2229 }
2230 
2231 /* Fast (comba) multiplier
2232  *
2233  * This is the fast column-array [comba] multiplier.  It is
2234  * designed to compute the columns of the product first
2235  * then handle the carries afterwards.  This has the effect
2236  * of making the nested loops that compute the columns very
2237  * simple and schedulable on super-scalar processors.
2238  *
2239  * This has been modified to produce a variable number of
2240  * digits of output so if say only a half-product is required
2241  * you don't have to compute the upper half (a feature
2242  * required for fast Barrett reduction).
2243  *
2244  * Based on Algorithm 14.12 on pp.595 of HAC.
2245  *
2246  */
2247 static int
fast_col_array_multiply(mp_int * a,mp_int * b,mp_int * c,int digs)2248 fast_col_array_multiply(mp_int * a, mp_int * b, mp_int * c, int digs)
2249 {
2250 	int     olduse, res, pa, ix, iz;
2251 	/*LINTED*/
2252 	mp_digit W[MP_WARRAY];
2253 	mp_word  _W;
2254 
2255 	/* grow the destination as required */
2256 	if (c->alloc < digs) {
2257 		if ((res = mp_grow(c, digs)) != MP_OKAY) {
2258 			return res;
2259 		}
2260 	}
2261 
2262 	/* number of output digits to produce */
2263 	pa = MIN(digs, a->used + b->used);
2264 
2265 	/* clear the carry */
2266 	_W = 0;
2267 	for (ix = 0; ix < pa; ix++) {
2268 		int      tx, ty;
2269 		int      iy;
2270 		mp_digit *tmpx, *tmpy;
2271 
2272 		/* get offsets into the two bignums */
2273 		ty = MIN(b->used-1, ix);
2274 		tx = ix - ty;
2275 
2276 		/* setup temp aliases */
2277 		tmpx = a->dp + tx;
2278 		tmpy = b->dp + ty;
2279 
2280 		/* this is the number of times the loop will iterate, essentially
2281 		while (tx++ < a->used && ty-- >= 0) { ... }
2282 		*/
2283 		iy = MIN(a->used-tx, ty+1);
2284 
2285 		/* execute loop */
2286 		for (iz = 0; iz < iy; ++iz) {
2287 			_W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
2288 
2289 		}
2290 
2291 		/* store term */
2292 		W[ix] = ((mp_digit)_W) & MP_MASK;
2293 
2294 		/* make next carry */
2295 		_W = _W >> ((mp_word)DIGIT_BIT);
2296 	}
2297 
2298 	/* setup dest */
2299 	olduse  = c->used;
2300 	c->used = pa;
2301 
2302 	{
2303 		mp_digit *tmpc;
2304 		tmpc = c->dp;
2305 		for (ix = 0; ix < pa+1; ix++) {
2306 			/* now extract the previous digit [below the carry] */
2307 			*tmpc++ = (ix < pa) ? W[ix] : 0;
2308 		}
2309 
2310 		/* clear unused digits [that existed in the old copy of c] */
2311 		for (; ix < olduse; ix++) {
2312 			*tmpc++ = 0;
2313 		}
2314 	}
2315 	trim_unused_digits(c);
2316 	return MP_OKAY;
2317 }
2318 
2319 /* return 1 if we can use fast column array multiply */
2320 /*
2321 * The fast multiplier can be used if the output will
2322 * have less than MP_WARRAY digits and the number of
2323 * digits won't affect carry propagation
2324 */
2325 static inline int
can_use_fast_column_array(int ndigits,int used)2326 can_use_fast_column_array(int ndigits, int used)
2327 {
2328 	return (((unsigned)ndigits < MP_WARRAY) &&
2329 		used < (1 << (unsigned)((CHAR_BIT * sizeof(mp_word)) - (2 * DIGIT_BIT))));
2330 }
2331 
2332 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_mul_digs.c,v $ */
2333 /* Revision: 1.2 $ */
2334 /* Date: 2011/03/18 16:22:09 $ */
2335 
2336 
2337 /* multiplies |a| * |b| and only computes upto digs digits of result
2338  * HAC pp. 595, Algorithm 14.12  Modified so you can control how
2339  * many digits of output are created.
2340  */
2341 static int
basic_multiply_partial_lower(mp_int * a,mp_int * b,mp_int * c,int digs)2342 basic_multiply_partial_lower(mp_int * a, mp_int * b, mp_int * c, int digs)
2343 {
2344 	mp_int  t;
2345 	int     res, pa, pb, ix, iy;
2346 	mp_digit u;
2347 	mp_word r;
2348 	mp_digit tmpx, *tmpt, *tmpy;
2349 
2350 	/* can we use the fast multiplier? */
2351 	if (can_use_fast_column_array(digs, MIN(a->used, b->used))) {
2352 		return fast_col_array_multiply(a, b, c, digs);
2353 	}
2354 
2355 	if ((res = mp_init_size(&t, digs)) != MP_OKAY) {
2356 		return res;
2357 	}
2358 	t.used = digs;
2359 
2360 	/* compute the digits of the product directly */
2361 	pa = a->used;
2362 	for (ix = 0; ix < pa; ix++) {
2363 		/* set the carry to zero */
2364 		u = 0;
2365 
2366 		/* limit ourselves to making digs digits of output */
2367 		pb = MIN(b->used, digs - ix);
2368 
2369 		/* setup some aliases */
2370 		/* copy of the digit from a used within the nested loop */
2371 		tmpx = a->dp[ix];
2372 
2373 		/* an alias for the destination shifted ix places */
2374 		tmpt = t.dp + ix;
2375 
2376 		/* an alias for the digits of b */
2377 		tmpy = b->dp;
2378 
2379 		/* compute the columns of the output and propagate the carry */
2380 		for (iy = 0; iy < pb; iy++) {
2381 			/* compute the column as a mp_word */
2382 			r = ((mp_word)*tmpt) +
2383 				((mp_word)tmpx) * ((mp_word)*tmpy++) +
2384 				((mp_word) u);
2385 
2386 			/* the new column is the lower part of the result */
2387 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2388 
2389 			/* get the carry word from the result */
2390 			u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2391 		}
2392 		/* set carry if it is placed below digs */
2393 		if (ix + iy < digs) {
2394 			*tmpt = u;
2395 		}
2396 	}
2397 
2398 	trim_unused_digits(&t);
2399 	mp_exch(&t, c);
2400 
2401 	mp_clear(&t);
2402 	return MP_OKAY;
2403 }
2404 
2405 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_mul_digs.c,v $ */
2406 /* Revision: 1.3 $ */
2407 /* Date: 2011/03/18 16:43:04 $ */
2408 
2409 /* high level multiplication (handles sign) */
2410 static int
signed_multiply(mp_int * a,mp_int * b,mp_int * c)2411 signed_multiply(mp_int * a, mp_int * b, mp_int * c)
2412 {
2413 	int     res, neg;
2414 
2415 	neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
2416 	/* use Toom-Cook? */
2417 	if (MIN(a->used, b->used) >= TOOM_MUL_CUTOFF) {
2418 		res = toom_cook_multiply(a, b, c);
2419 	} else if (MIN(a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
2420 		/* use Karatsuba? */
2421 		res = karatsuba_multiply(a, b, c);
2422 	} else {
2423 		/* can we use the fast multiplier? */
2424 		int     digs = a->used + b->used + 1;
2425 
2426 		if (can_use_fast_column_array(digs, MIN(a->used, b->used))) {
2427 			res = fast_col_array_multiply(a, b, c, digs);
2428 		} else  {
2429 			res = basic_multiply_partial_lower(a, b, c, (a)->used + (b)->used + 1);
2430 		}
2431 	}
2432 	c->sign = (c->used > 0) ? neg : MP_ZPOS;
2433 	return res;
2434 }
2435 
2436 /* this is a modified version of fast_s_mul_digs that only produces
2437  * output digits *above* digs.  See the comments for fast_s_mul_digs
2438  * to see how it works.
2439  *
2440  * This is used in the Barrett reduction since for one of the multiplications
2441  * only the higher digits were needed.  This essentially halves the work.
2442  *
2443  * Based on Algorithm 14.12 on pp.595 of HAC.
2444  */
2445 static int
fast_basic_multiply_partial_upper(mp_int * a,mp_int * b,mp_int * c,int digs)2446 fast_basic_multiply_partial_upper(mp_int * a, mp_int * b, mp_int * c, int digs)
2447 {
2448 	int     olduse, res, pa, ix, iz;
2449 	mp_digit W[MP_WARRAY];
2450 	mp_word  _W;
2451 
2452 	/* grow the destination as required */
2453 	pa = a->used + b->used;
2454 	if (c->alloc < pa) {
2455 		if ((res = mp_grow(c, pa)) != MP_OKAY) {
2456 			return res;
2457 		}
2458 	}
2459 
2460 	/* number of output digits to produce */
2461 	pa = a->used + b->used;
2462 	_W = 0;
2463 	for (ix = digs; ix < pa; ix++) {
2464 		int      tx, ty, iy;
2465 		mp_digit *tmpx, *tmpy;
2466 
2467 		/* get offsets into the two bignums */
2468 		ty = MIN(b->used-1, ix);
2469 		tx = ix - ty;
2470 
2471 		/* setup temp aliases */
2472 		tmpx = a->dp + tx;
2473 		tmpy = b->dp + ty;
2474 
2475 		/* this is the number of times the loop will iterate, essentially its
2476 		 while (tx++ < a->used && ty-- >= 0) { ... }
2477 		*/
2478 		iy = MIN(a->used-tx, ty+1);
2479 
2480 		/* execute loop */
2481 		for (iz = 0; iz < iy; iz++) {
2482 			 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
2483 		}
2484 
2485 		/* store term */
2486 		W[ix] = ((mp_digit)_W) & MP_MASK;
2487 
2488 		/* make next carry */
2489 		_W = _W >> ((mp_word)DIGIT_BIT);
2490 	}
2491 
2492 	/* setup dest */
2493 	olduse  = c->used;
2494 	c->used = pa;
2495 
2496 	{
2497 		mp_digit *tmpc;
2498 
2499 		tmpc = c->dp + digs;
2500 		for (ix = digs; ix < pa; ix++) {
2501 			/* now extract the previous digit [below the carry] */
2502 			*tmpc++ = W[ix];
2503 		}
2504 
2505 		/* clear unused digits [that existed in the old copy of c] */
2506 		for (; ix < olduse; ix++) {
2507 			*tmpc++ = 0;
2508 		}
2509 	}
2510 	trim_unused_digits(c);
2511 	return MP_OKAY;
2512 }
2513 
2514 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_mul_high_digs.c,v $ */
2515 /* Revision: 1.1.1.1 $ */
2516 /* Date: 2011/03/12 22:58:18 $ */
2517 
2518 /* multiplies |a| * |b| and does not compute the lower digs digits
2519  * [meant to get the higher part of the product]
2520  */
2521 static int
basic_multiply_partial_upper(mp_int * a,mp_int * b,mp_int * c,int digs)2522 basic_multiply_partial_upper(mp_int * a, mp_int * b, mp_int * c, int digs)
2523 {
2524 	mp_int  t;
2525 	int     res, pa, pb, ix, iy;
2526 	mp_digit carry;
2527 	mp_word r;
2528 	mp_digit tmpx, *tmpt, *tmpy;
2529 
2530 	/* can we use the fast multiplier? */
2531 	if (can_use_fast_column_array(a->used + b->used + 1, MIN(a->used, b->used))) {
2532 		return fast_basic_multiply_partial_upper(a, b, c, digs);
2533 	}
2534 
2535 	if ((res = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) {
2536 		return res;
2537 	}
2538 	t.used = a->used + b->used + 1;
2539 
2540 	pa = a->used;
2541 	pb = b->used;
2542 	for (ix = 0; ix < pa; ix++) {
2543 		/* clear the carry */
2544 		carry = 0;
2545 
2546 		/* left hand side of A[ix] * B[iy] */
2547 		tmpx = a->dp[ix];
2548 
2549 		/* alias to the address of where the digits will be stored */
2550 		tmpt = &(t.dp[digs]);
2551 
2552 		/* alias for where to read the right hand side from */
2553 		tmpy = b->dp + (digs - ix);
2554 
2555 		for (iy = digs - ix; iy < pb; iy++) {
2556 			/* calculate the double precision result */
2557 			r = ((mp_word)*tmpt) +
2558 				((mp_word)tmpx) * ((mp_word)*tmpy++) +
2559 				((mp_word) carry);
2560 
2561 			/* get the lower part */
2562 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2563 
2564 			/* carry the carry */
2565 			carry = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2566 		}
2567 		*tmpt = carry;
2568 	}
2569 	trim_unused_digits(&t);
2570 	mp_exch(&t, c);
2571 	mp_clear(&t);
2572 	return MP_OKAY;
2573 }
2574 
2575 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_mul_high_digs.c,v $ */
2576 /* Revision: 1.3 $ */
2577 /* Date: 2011/03/18 16:43:04 $ */
2578 
2579 /* reduces x mod m, assumes 0 < x < m**2, mu is
2580  * precomputed via mp_reduce_setup.
2581  * From HAC pp.604 Algorithm 14.42
2582  */
2583 static int
mp_reduce(mp_int * x,mp_int * m,mp_int * mu)2584 mp_reduce(mp_int * x, mp_int * m, mp_int * mu)
2585 {
2586 	mp_int  q;
2587 	int     res, um = m->used;
2588 
2589 	/* q = x */
2590 	if ((res = mp_init_copy(&q, x)) != MP_OKAY) {
2591 		return res;
2592 	}
2593 
2594 	/* q1 = x / b**(k-1)  */
2595 	rshift_digits(&q, um - 1);
2596 
2597 	/* according to HAC this optimization is ok */
2598 	if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
2599 		if ((res = signed_multiply(&q, mu, &q)) != MP_OKAY) {
2600 			goto CLEANUP;
2601 		}
2602 	} else {
2603 		if ((res = basic_multiply_partial_upper(&q, mu, &q, um)) != MP_OKAY) {
2604 			goto CLEANUP;
2605 		}
2606 	}
2607 
2608 	/* q3 = q2 / b**(k+1) */
2609 	rshift_digits(&q, um + 1);
2610 
2611 	/* x = x mod b**(k+1), quick (no division) */
2612 	if ((res = modulo_2_to_power(x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
2613 		goto CLEANUP;
2614 	}
2615 
2616 	/* q = q * m mod b**(k+1), quick (no division) */
2617 	if ((res = basic_multiply_partial_lower(&q, m, &q, um + 1)) != MP_OKAY) {
2618 		goto CLEANUP;
2619 	}
2620 
2621 	/* x = x - q */
2622 	if ((res = signed_subtract(x, &q, x)) != MP_OKAY) {
2623 		goto CLEANUP;
2624 	}
2625 
2626 	/* If x < 0, add b**(k+1) to it */
2627 	if (compare_digit(x, 0) == MP_LT) {
2628 		set_word(&q, 1);
2629 		if ((res = lshift_digits(&q, um + 1)) != MP_OKAY) {
2630 			goto CLEANUP;
2631 		}
2632 		if ((res = signed_add(x, &q, x)) != MP_OKAY) {
2633 			goto CLEANUP;
2634 		}
2635 	}
2636 
2637 	/* Back off if it's too big */
2638 	while (signed_compare(x, m) != MP_LT) {
2639 		if ((res = basic_subtract(x, m, x)) != MP_OKAY) {
2640 			goto CLEANUP;
2641 		}
2642 	}
2643 
2644 CLEANUP:
2645 	mp_clear(&q);
2646 
2647 	return res;
2648 }
2649 
2650 /* determines the setup value */
2651 static int
mp_reduce_2k_setup_l(mp_int * a,mp_int * d)2652 mp_reduce_2k_setup_l(mp_int *a, mp_int *d)
2653 {
2654 	int    res;
2655 	mp_int tmp;
2656 
2657 	if ((res = mp_init(&tmp)) != MP_OKAY) {
2658 		return res;
2659 	}
2660 
2661 	if ((res = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
2662 		goto ERR;
2663 	}
2664 
2665 	if ((res = basic_subtract(&tmp, a, d)) != MP_OKAY) {
2666 		goto ERR;
2667 	}
2668 
2669 ERR:
2670 	mp_clear(&tmp);
2671 	return res;
2672 }
2673 
2674 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_setup_l.c,v $ */
2675 /* Revision: 1.1.1.1 $ */
2676 /* Date: 2011/03/12 22:58:18 $ */
2677 
2678 /* reduces a modulo n where n is of the form 2**p - d
2679    This differs from reduce_2k since "d" can be larger
2680    than a single digit.
2681 */
2682 static int
mp_reduce_2k_l(mp_int * a,mp_int * n,mp_int * d)2683 mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
2684 {
2685 	mp_int q;
2686 	int    p, res;
2687 
2688 	if ((res = mp_init(&q)) != MP_OKAY) {
2689 		return res;
2690 	}
2691 
2692 	p = mp_count_bits(n);
2693 top:
2694 	/* q = a/2**p, a = a mod 2**p */
2695 	if ((res = rshift_bits(a, p, &q, a)) != MP_OKAY) {
2696 		goto ERR;
2697 	}
2698 
2699 	/* q = q * d */
2700 	if ((res = signed_multiply(&q, d, &q)) != MP_OKAY) {
2701 		goto ERR;
2702 	}
2703 
2704 	/* a = a + q */
2705 	if ((res = basic_add(a, &q, a)) != MP_OKAY) {
2706 		goto ERR;
2707 	}
2708 
2709 	if (compare_magnitude(a, n) != MP_LT) {
2710 		basic_subtract(a, n, a);
2711 		goto top;
2712 	}
2713 
2714 ERR:
2715 	mp_clear(&q);
2716 	return res;
2717 }
2718 
2719 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_l.c,v $ */
2720 /* Revision: 1.1.1.1 $ */
2721 /* Date: 2011/03/12 22:58:18 $ */
2722 
2723 /* squaring using Toom-Cook 3-way algorithm */
2724 static int
toom_cook_square(mp_int * a,mp_int * b)2725 toom_cook_square(mp_int *a, mp_int *b)
2726 {
2727 	mp_int w0, w1, w2, w3, w4, tmp1, a0, a1, a2;
2728 	int res, B;
2729 
2730 	/* init temps */
2731 	if ((res = mp_init_multi(&w0, &w1, &w2, &w3, &w4, &a0, &a1, &a2, &tmp1, NULL)) != MP_OKAY) {
2732 		return res;
2733 	}
2734 
2735 	/* B */
2736 	B = a->used / 3;
2737 
2738 	/* a = a2 * B**2 + a1 * B + a0 */
2739 	if ((res = modulo_2_to_power(a, DIGIT_BIT * B, &a0)) != MP_OKAY) {
2740 		goto ERR;
2741 	}
2742 
2743 	if ((res = mp_copy(a, &a1)) != MP_OKAY) {
2744 		goto ERR;
2745 	}
2746 	rshift_digits(&a1, B);
2747 	modulo_2_to_power(&a1, DIGIT_BIT * B, &a1);
2748 
2749 	if ((res = mp_copy(a, &a2)) != MP_OKAY) {
2750 		goto ERR;
2751 	}
2752 	rshift_digits(&a2, B*2);
2753 
2754 	/* w0 = a0*a0 */
2755 	if ((res = square(&a0, &w0)) != MP_OKAY) {
2756 		goto ERR;
2757 	}
2758 
2759 	/* w4 = a2 * a2 */
2760 	if ((res = square(&a2, &w4)) != MP_OKAY) {
2761 		goto ERR;
2762 	}
2763 
2764 	/* w1 = (a2 + 2(a1 + 2a0))**2 */
2765 	if ((res = doubled(&a0, &tmp1)) != MP_OKAY) {
2766 		goto ERR;
2767 	}
2768 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
2769 		goto ERR;
2770 	}
2771 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
2772 		goto ERR;
2773 	}
2774 	if ((res = signed_add(&tmp1, &a2, &tmp1)) != MP_OKAY) {
2775 		goto ERR;
2776 	}
2777 
2778 	if ((res = square(&tmp1, &w1)) != MP_OKAY) {
2779 		goto ERR;
2780 	}
2781 
2782 	/* w3 = (a0 + 2(a1 + 2a2))**2 */
2783 	if ((res = doubled(&a2, &tmp1)) != MP_OKAY) {
2784 		goto ERR;
2785 	}
2786 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
2787 		goto ERR;
2788 	}
2789 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
2790 		goto ERR;
2791 	}
2792 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
2793 		goto ERR;
2794 	}
2795 
2796 	if ((res = square(&tmp1, &w3)) != MP_OKAY) {
2797 		goto ERR;
2798 	}
2799 
2800 
2801 	/* w2 = (a2 + a1 + a0)**2 */
2802 	if ((res = signed_add(&a2, &a1, &tmp1)) != MP_OKAY) {
2803 		goto ERR;
2804 	}
2805 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
2806 		goto ERR;
2807 	}
2808 	if ((res = square(&tmp1, &w2)) != MP_OKAY) {
2809 		goto ERR;
2810 	}
2811 
2812 	/* now solve the matrix
2813 
2814 	0  0  0  0  1
2815 	1  2  4  8  16
2816 	1  1  1  1  1
2817 	16 8  4  2  1
2818 	1  0  0  0  0
2819 
2820 	using 12 subtractions, 4 shifts, 2 small divisions and 1 small multiplication.
2821 	*/
2822 
2823 	/* r1 - r4 */
2824 	if ((res = signed_subtract(&w1, &w4, &w1)) != MP_OKAY) {
2825 		goto ERR;
2826 	}
2827 	/* r3 - r0 */
2828 	if ((res = signed_subtract(&w3, &w0, &w3)) != MP_OKAY) {
2829 		goto ERR;
2830 	}
2831 	/* r1/2 */
2832 	if ((res = half(&w1, &w1)) != MP_OKAY) {
2833 		goto ERR;
2834 	}
2835 	/* r3/2 */
2836 	if ((res = half(&w3, &w3)) != MP_OKAY) {
2837 		goto ERR;
2838 	}
2839 	/* r2 - r0 - r4 */
2840 	if ((res = signed_subtract(&w2, &w0, &w2)) != MP_OKAY) {
2841 		goto ERR;
2842 	}
2843 	if ((res = signed_subtract(&w2, &w4, &w2)) != MP_OKAY) {
2844 		goto ERR;
2845 	}
2846 	/* r1 - r2 */
2847 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
2848 		goto ERR;
2849 	}
2850 	/* r3 - r2 */
2851 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
2852 		goto ERR;
2853 	}
2854 	/* r1 - 8r0 */
2855 	if ((res = lshift_bits(&w0, 3, &tmp1)) != MP_OKAY) {
2856 		goto ERR;
2857 	}
2858 	if ((res = signed_subtract(&w1, &tmp1, &w1)) != MP_OKAY) {
2859 		goto ERR;
2860 	}
2861 	/* r3 - 8r4 */
2862 	if ((res = lshift_bits(&w4, 3, &tmp1)) != MP_OKAY) {
2863 		goto ERR;
2864 	}
2865 	if ((res = signed_subtract(&w3, &tmp1, &w3)) != MP_OKAY) {
2866 		goto ERR;
2867 	}
2868 	/* 3r2 - r1 - r3 */
2869 	if ((res = multiply_digit(&w2, 3, &w2)) != MP_OKAY) {
2870 		goto ERR;
2871 	}
2872 	if ((res = signed_subtract(&w2, &w1, &w2)) != MP_OKAY) {
2873 		goto ERR;
2874 	}
2875 	if ((res = signed_subtract(&w2, &w3, &w2)) != MP_OKAY) {
2876 		goto ERR;
2877 	}
2878 	/* r1 - r2 */
2879 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
2880 		goto ERR;
2881 	}
2882 	/* r3 - r2 */
2883 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
2884 		goto ERR;
2885 	}
2886 	/* r1/3 */
2887 	if ((res = third(&w1, &w1, NULL)) != MP_OKAY) {
2888 		goto ERR;
2889 	}
2890 	/* r3/3 */
2891 	if ((res = third(&w3, &w3, NULL)) != MP_OKAY) {
2892 		goto ERR;
2893 	}
2894 
2895 	/* at this point shift W[n] by B*n */
2896 	if ((res = lshift_digits(&w1, 1*B)) != MP_OKAY) {
2897 		goto ERR;
2898 	}
2899 	if ((res = lshift_digits(&w2, 2*B)) != MP_OKAY) {
2900 		goto ERR;
2901 	}
2902 	if ((res = lshift_digits(&w3, 3*B)) != MP_OKAY) {
2903 		goto ERR;
2904 	}
2905 	if ((res = lshift_digits(&w4, 4*B)) != MP_OKAY) {
2906 		goto ERR;
2907 	}
2908 
2909 	if ((res = signed_add(&w0, &w1, b)) != MP_OKAY) {
2910 		goto ERR;
2911 	}
2912 	if ((res = signed_add(&w2, &w3, &tmp1)) != MP_OKAY) {
2913 		goto ERR;
2914 	}
2915 	if ((res = signed_add(&w4, &tmp1, &tmp1)) != MP_OKAY) {
2916 		goto ERR;
2917 	}
2918 	if ((res = signed_add(&tmp1, b, b)) != MP_OKAY) {
2919 		goto ERR;
2920 	}
2921 
2922 ERR:
2923 	mp_clear_multi(&w0, &w1, &w2, &w3, &w4, &a0, &a1, &a2, &tmp1, NULL);
2924 	return res;
2925 }
2926 
2927 
2928 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_toom_sqr.c,v $ */
2929 /* Revision: 1.1.1.1 $ */
2930 /* Date: 2011/03/12 22:58:18 $ */
2931 
2932 /* Karatsuba squaring, computes b = a*a using three
2933  * half size squarings
2934  *
2935  * See comments of karatsuba_mul for details.  It
2936  * is essentially the same algorithm but merely
2937  * tuned to perform recursive squarings.
2938  */
2939 static int
karatsuba_square(mp_int * a,mp_int * b)2940 karatsuba_square(mp_int * a, mp_int * b)
2941 {
2942 	mp_int  x0, x1, t1, t2, x0x0, x1x1;
2943 	int     B, err;
2944 
2945 	err = MP_MEM;
2946 
2947 	/* min # of digits */
2948 	B = a->used;
2949 
2950 	/* now divide in two */
2951 	B = (unsigned)B >> 1;
2952 
2953 	/* init copy all the temps */
2954 	if (mp_init_size(&x0, B) != MP_OKAY) {
2955 		goto ERR;
2956 	}
2957 	if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
2958 		goto X0;
2959 	}
2960 	/* init temps */
2961 	if (mp_init_size(&t1, a->used * 2) != MP_OKAY) {
2962 		goto X1;
2963 	}
2964 	if (mp_init_size(&t2, a->used * 2) != MP_OKAY) {
2965 		goto T1;
2966 	}
2967 	if (mp_init_size(&x0x0, B * 2) != MP_OKAY) {
2968 		goto T2;
2969 	}
2970 	if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY) {
2971 		goto X0X0;
2972 	}
2973 
2974 	memcpy(x0.dp, a->dp, B * sizeof(*x0.dp));
2975 	memcpy(x1.dp, &a->dp[B], (a->used - B) * sizeof(*x1.dp));
2976 
2977 	x0.used = B;
2978 	x1.used = a->used - B;
2979 
2980 	trim_unused_digits(&x0);
2981 
2982 	/* now calc the products x0*x0 and x1*x1 */
2983 	if (square(&x0, &x0x0) != MP_OKAY) {
2984 		goto X1X1;           /* x0x0 = x0*x0 */
2985 	}
2986 	if (square(&x1, &x1x1) != MP_OKAY) {
2987 		goto X1X1;           /* x1x1 = x1*x1 */
2988 	}
2989 	/* now calc (x1+x0)**2 */
2990 	if (basic_add(&x1, &x0, &t1) != MP_OKAY) {
2991 		goto X1X1;           /* t1 = x1 - x0 */
2992 	}
2993 	if (square(&t1, &t1) != MP_OKAY) {
2994 		goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2995 	}
2996 	/* add x0y0 */
2997 	if (basic_add(&x0x0, &x1x1, &t2) != MP_OKAY) {
2998 		goto X1X1;           /* t2 = x0x0 + x1x1 */
2999 	}
3000 	if (basic_subtract(&t1, &t2, &t1) != MP_OKAY) {
3001 		goto X1X1;           /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
3002 	}
3003 	/* shift by B */
3004 	if (lshift_digits(&t1, B) != MP_OKAY) {
3005 		goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
3006 	}
3007 	if (lshift_digits(&x1x1, B * 2) != MP_OKAY) {
3008 		goto X1X1;           /* x1x1 = x1x1 << 2*B */
3009 	}
3010 	if (signed_add(&x0x0, &t1, &t1) != MP_OKAY) {
3011 		goto X1X1;           /* t1 = x0x0 + t1 */
3012 	}
3013 	if (signed_add(&t1, &x1x1, b) != MP_OKAY) {
3014 		goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
3015 	}
3016 	err = MP_OKAY;
3017 
3018 X1X1:
3019 	mp_clear(&x1x1);
3020 X0X0:
3021 	mp_clear(&x0x0);
3022 T2:
3023 	mp_clear(&t2);
3024 T1:
3025 	mp_clear(&t1);
3026 X1:
3027 	mp_clear(&x1);
3028 X0:
3029 	mp_clear(&x0);
3030 ERR:
3031 	return err;
3032 }
3033 
3034 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_karatsuba_sqr.c,v $ */
3035 /* Revision: 1.2 $ */
3036 /* Date: 2011/03/12 23:43:54 $ */
3037 
3038 /* the jist of squaring...
3039  * you do like mult except the offset of the tmpx [one that
3040  * starts closer to zero] can't equal the offset of tmpy.
3041  * So basically you set up iy like before then you min it with
3042  * (ty-tx) so that it never happens.  You double all those
3043  * you add in the inner loop
3044 
3045 After that loop you do the squares and add them in.
3046 */
3047 
3048 static int
fast_basic_square(mp_int * a,mp_int * b)3049 fast_basic_square(mp_int * a, mp_int * b)
3050 {
3051 	int       olduse, res, pa, ix, iz;
3052 	mp_digit   W[MP_WARRAY], *tmpx;
3053 	mp_word   W1;
3054 
3055 	/* grow the destination as required */
3056 	pa = a->used + a->used;
3057 	if (b->alloc < pa) {
3058 		if ((res = mp_grow(b, pa)) != MP_OKAY) {
3059 			return res;
3060 		}
3061 	}
3062 
3063 	/* number of output digits to produce */
3064 	W1 = 0;
3065 	for (ix = 0; ix < pa; ix++) {
3066 		int      tx, ty, iy;
3067 		mp_word  _W;
3068 		mp_digit *tmpy;
3069 
3070 		/* clear counter */
3071 		_W = 0;
3072 
3073 		/* get offsets into the two bignums */
3074 		ty = MIN(a->used-1, ix);
3075 		tx = ix - ty;
3076 
3077 		/* setup temp aliases */
3078 		tmpx = a->dp + tx;
3079 		tmpy = a->dp + ty;
3080 
3081 		/* this is the number of times the loop will iterate, essentially
3082 		 while (tx++ < a->used && ty-- >= 0) { ... }
3083 		*/
3084 		iy = MIN(a->used-tx, ty+1);
3085 
3086 		/* now for squaring tx can never equal ty
3087 		* we halve the distance since they approach at a rate of 2x
3088 		* and we have to round because odd cases need to be executed
3089 		*/
3090 		iy = MIN(iy, (int)((unsigned)(ty-tx+1)>>1));
3091 
3092 		/* execute loop */
3093 		for (iz = 0; iz < iy; iz++) {
3094 			 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
3095 		}
3096 
3097 		/* double the inner product and add carry */
3098 		_W = _W + _W + W1;
3099 
3100 		/* even columns have the square term in them */
3101 		if ((ix&1) == 0) {
3102 			 _W += ((mp_word)a->dp[(unsigned)ix>>1])*((mp_word)a->dp[(unsigned)ix>>1]);
3103 		}
3104 
3105 		/* store it */
3106 		W[ix] = (mp_digit)(_W & MP_MASK);
3107 
3108 		/* make next carry */
3109 		W1 = _W >> ((mp_word)DIGIT_BIT);
3110 	}
3111 
3112 	/* setup dest */
3113 	olduse  = b->used;
3114 	b->used = a->used+a->used;
3115 
3116 	{
3117 		mp_digit *tmpb;
3118 		tmpb = b->dp;
3119 		for (ix = 0; ix < pa; ix++) {
3120 			*tmpb++ = W[ix] & MP_MASK;
3121 		}
3122 
3123 		/* clear unused digits [that existed in the old copy of c] */
3124 		for (; ix < olduse; ix++) {
3125 			*tmpb++ = 0;
3126 		}
3127 	}
3128 	trim_unused_digits(b);
3129 	return MP_OKAY;
3130 }
3131 
3132 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_sqr.c,v $ */
3133 /* Revision: 1.3 $ */
3134 /* Date: 2011/03/18 16:43:04 $ */
3135 
3136 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
3137 static int
basic_square(mp_int * a,mp_int * b)3138 basic_square(mp_int * a, mp_int * b)
3139 {
3140 	mp_int  t;
3141 	int     res, ix, iy, pa;
3142 	mp_word r;
3143 	mp_digit carry, tmpx, *tmpt;
3144 
3145 	pa = a->used;
3146 	if ((res = mp_init_size(&t, 2*pa + 1)) != MP_OKAY) {
3147 		return res;
3148 	}
3149 
3150 	/* default used is maximum possible size */
3151 	t.used = 2*pa + 1;
3152 
3153 	for (ix = 0; ix < pa; ix++) {
3154 		/* first calculate the digit at 2*ix */
3155 		/* calculate double precision result */
3156 		r = ((mp_word) t.dp[2*ix]) +
3157 		((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
3158 
3159 		/* store lower part in result */
3160 		t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
3161 
3162 		/* get the carry */
3163 		carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3164 
3165 		/* left hand side of A[ix] * A[iy] */
3166 		tmpx = a->dp[ix];
3167 
3168 		/* alias for where to store the results */
3169 		tmpt = t.dp + (2*ix + 1);
3170 
3171 		for (iy = ix + 1; iy < pa; iy++) {
3172 			/* first calculate the product */
3173 			r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
3174 
3175 			/* now calculate the double precision result, note we use
3176 			* addition instead of *2 since it's easier to optimize
3177 			*/
3178 			r = ((mp_word) *tmpt) + r + r + ((mp_word) carry);
3179 
3180 			/* store lower part */
3181 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3182 
3183 			/* get carry */
3184 			carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3185 		}
3186 		/* propagate upwards */
3187 		while (carry != ((mp_digit) 0)) {
3188 			r = ((mp_word) *tmpt) + ((mp_word) carry);
3189 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3190 			carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3191 		}
3192 	}
3193 
3194 	trim_unused_digits(&t);
3195 	mp_exch(&t, b);
3196 	mp_clear(&t);
3197 	return MP_OKAY;
3198 }
3199 
3200 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_sqr.c,v $ */
3201 /* Revision: 1.1.1.1 $ */
3202 /* Date: 2011/03/12 22:58:18 $ */
3203 
3204 #define TOOM_SQR_CUTOFF      400
3205 #define KARATSUBA_SQR_CUTOFF 120
3206 
3207 /* computes b = a*a */
3208 static int
square(mp_int * a,mp_int * b)3209 square(mp_int * a, mp_int * b)
3210 {
3211 	int     res;
3212 
3213 	/* use Toom-Cook? */
3214 	if (a->used >= TOOM_SQR_CUTOFF) {
3215 		res = toom_cook_square(a, b);
3216 		/* Karatsuba? */
3217 	} else if (a->used >= KARATSUBA_SQR_CUTOFF) {
3218 		res = karatsuba_square(a, b);
3219 	} else {
3220 		/* can we use the fast comba multiplier? */
3221 		if (can_use_fast_column_array(a->used + a->used + 1, a->used)) {
3222 			res = fast_basic_square(a, b);
3223 		} else {
3224 			res = basic_square(a, b);
3225 		}
3226 	}
3227 	b->sign = MP_ZPOS;
3228 	return res;
3229 }
3230 
3231 /* find window size */
3232 static inline int
find_window_size(mp_int * X)3233 find_window_size(mp_int *X)
3234 {
3235 	int	x;
3236 
3237 	x = mp_count_bits(X);
3238 	return (x <= 7) ? 2 : (x <= 36) ? 3 : (x <= 140) ? 4 : (x <= 450) ? 5 : (x <= 1303) ? 6 : (x <= 3529) ? 7 : 8;
3239 }
3240 
3241 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_sqr.c,v $ */
3242 /* Revision: 1.3 $ */
3243 /* Date: 2011/03/18 16:43:04 $ */
3244 
3245 #define TAB_SIZE 256
3246 
3247 static int
basic_exponent_mod(mp_int * G,mp_int * X,mp_int * P,mp_int * Y,int redmode)3248 basic_exponent_mod(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
3249 {
3250 	mp_digit buf;
3251 	mp_int  M[TAB_SIZE], res, mu;
3252 	int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3253 	int	(*redux)(mp_int*,mp_int*,mp_int*);
3254 
3255 	winsize = find_window_size(X);
3256 
3257 	/* init M array */
3258 	/* init first cell */
3259 	if ((err = mp_init(&M[1])) != MP_OKAY) {
3260 		return err;
3261 	}
3262 
3263 	/* now init the second half of the array */
3264 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3265 		if ((err = mp_init(&M[x])) != MP_OKAY) {
3266 			for (y = 1<<(winsize-1); y < x; y++) {
3267 				mp_clear(&M[y]);
3268 			}
3269 			mp_clear(&M[1]);
3270 			return err;
3271 		}
3272 	}
3273 
3274 	/* create mu, used for Barrett reduction */
3275 	if ((err = mp_init(&mu)) != MP_OKAY) {
3276 		goto LBL_M;
3277 	}
3278 
3279 	if (redmode == 0) {
3280 		if ((err = mp_reduce_setup(&mu, P)) != MP_OKAY) {
3281 			goto LBL_MU;
3282 		}
3283 		redux = mp_reduce;
3284 	} else {
3285 		if ((err = mp_reduce_2k_setup_l(P, &mu)) != MP_OKAY) {
3286 			goto LBL_MU;
3287 		}
3288 		redux = mp_reduce_2k_l;
3289 	}
3290 
3291 	/* create M table
3292 	*
3293 	* The M table contains powers of the base,
3294 	* e.g. M[x] = G**x mod P
3295 	*
3296 	* The first half of the table is not
3297 	* computed though accept for M[0] and M[1]
3298 	*/
3299 	if ((err = modulo(G, P, &M[1])) != MP_OKAY) {
3300 		goto LBL_MU;
3301 	}
3302 
3303 	/* compute the value at M[1<<(winsize-1)] by squaring
3304 	* M[1] (winsize-1) times
3305 	*/
3306 	if ((err = mp_copy( &M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
3307 		goto LBL_MU;
3308 	}
3309 
3310 	for (x = 0; x < (winsize - 1); x++) {
3311 		/* square it */
3312 		if ((err = square(&M[1 << (winsize - 1)],
3313 		       &M[1 << (winsize - 1)])) != MP_OKAY) {
3314 			goto LBL_MU;
3315 		}
3316 
3317 		/* reduce modulo P */
3318 		if ((err = redux(&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
3319 			goto LBL_MU;
3320 		}
3321 	}
3322 
3323 	/* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
3324 	* for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
3325 	*/
3326 	for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
3327 		if ((err = signed_multiply(&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
3328 			goto LBL_MU;
3329 		}
3330 		if ((err = redux(&M[x], P, &mu)) != MP_OKAY) {
3331 			goto LBL_MU;
3332 		}
3333 	}
3334 
3335 	/* setup result */
3336 	if ((err = mp_init(&res)) != MP_OKAY) {
3337 		goto LBL_MU;
3338 	}
3339 	set_word(&res, 1);
3340 
3341 	/* set initial mode and bit cnt */
3342 	mode = 0;
3343 	bitcnt = 1;
3344 	buf = 0;
3345 	digidx = X->used - 1;
3346 	bitcpy = 0;
3347 	bitbuf = 0;
3348 
3349 	for (;;) {
3350 		/* grab next digit as required */
3351 		if (--bitcnt == 0) {
3352 			/* if digidx == -1 we are out of digits */
3353 			if (digidx == -1) {
3354 				break;
3355 			}
3356 			/* read next digit and reset the bitcnt */
3357 			buf = X->dp[digidx--];
3358 			bitcnt = (int) DIGIT_BIT;
3359 		}
3360 
3361 		/* grab the next msb from the exponent */
3362 		y = (unsigned)(buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
3363 		buf <<= (mp_digit)1;
3364 
3365 		/* if the bit is zero and mode == 0 then we ignore it
3366 		* These represent the leading zero bits before the first 1 bit
3367 		* in the exponent.  Technically this opt is not required but it
3368 		* does lower the # of trivial squaring/reductions used
3369 		*/
3370 		if (mode == 0 && y == 0) {
3371 			continue;
3372 		}
3373 
3374 		/* if the bit is zero and mode == 1 then we square */
3375 		if (mode == 1 && y == 0) {
3376 			if ((err = square(&res, &res)) != MP_OKAY) {
3377 				goto LBL_RES;
3378 			}
3379 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3380 				goto LBL_RES;
3381 			}
3382 			continue;
3383 		}
3384 
3385 		/* else we add it to the window */
3386 		bitbuf |= (y << (winsize - ++bitcpy));
3387 		mode = 2;
3388 
3389 		if (bitcpy == winsize) {
3390 			/* ok window is filled so square as required and multiply  */
3391 			/* square first */
3392 			for (x = 0; x < winsize; x++) {
3393 				if ((err = square(&res, &res)) != MP_OKAY) {
3394 					goto LBL_RES;
3395 				}
3396 				if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3397 					goto LBL_RES;
3398 				}
3399 			}
3400 
3401 			/* then multiply */
3402 			if ((err = signed_multiply(&res, &M[bitbuf], &res)) != MP_OKAY) {
3403 				goto LBL_RES;
3404 			}
3405 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3406 				goto LBL_RES;
3407 			}
3408 
3409 			/* empty window and reset */
3410 			bitcpy = 0;
3411 			bitbuf = 0;
3412 			mode = 1;
3413 		}
3414 	}
3415 
3416 	/* if bits remain then square/multiply */
3417 	if (mode == 2 && bitcpy > 0) {
3418 		/* square then multiply if the bit is set */
3419 		for (x = 0; x < bitcpy; x++) {
3420 			if ((err = square(&res, &res)) != MP_OKAY) {
3421 				goto LBL_RES;
3422 			}
3423 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3424 				goto LBL_RES;
3425 			}
3426 
3427 			bitbuf <<= 1;
3428 			if ((bitbuf & (1 << winsize)) != 0) {
3429 				/* then multiply */
3430 				if ((err = signed_multiply(&res, &M[1], &res)) != MP_OKAY) {
3431 					goto LBL_RES;
3432 				}
3433 				if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3434 					goto LBL_RES;
3435 				}
3436 			}
3437 		}
3438 	}
3439 
3440 	mp_exch(&res, Y);
3441 	err = MP_OKAY;
3442 LBL_RES:
3443 	mp_clear(&res);
3444 LBL_MU:
3445 	mp_clear(&mu);
3446 LBL_M:
3447 	mp_clear(&M[1]);
3448 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3449 		mp_clear(&M[x]);
3450 	}
3451 	return err;
3452 }
3453 
3454 /* determines if a number is a valid DR modulus */
3455 static int
is_diminished_radix_modulus(mp_int * a)3456 is_diminished_radix_modulus(mp_int *a)
3457 {
3458 	int ix;
3459 
3460 	/* must be at least two digits */
3461 	if (a->used < 2) {
3462 		return 0;
3463 	}
3464 
3465 	/* must be of the form b**k - a [a <= b] so all
3466 	* but the first digit must be equal to -1 (mod b).
3467 	*/
3468 	for (ix = 1; ix < a->used; ix++) {
3469 		if (a->dp[ix] != MP_MASK) {
3470 			  return 0;
3471 		}
3472 	}
3473 	return 1;
3474 }
3475 
3476 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_is_modulus.c,v $ */
3477 /* Revision: 1.1.1.1 $ */
3478 /* Date: 2011/03/12 22:58:18 $ */
3479 
3480 /* determines if mp_reduce_2k can be used */
3481 static int
mp_reduce_is_2k(mp_int * a)3482 mp_reduce_is_2k(mp_int *a)
3483 {
3484 	int ix, iy, iw;
3485 	mp_digit iz;
3486 
3487 	if (a->used == 0) {
3488 		return MP_NO;
3489 	}
3490 	if (a->used == 1) {
3491 		return MP_YES;
3492 	}
3493 	if (a->used > 1) {
3494 		iy = mp_count_bits(a);
3495 		iz = 1;
3496 		iw = 1;
3497 
3498 		/* Test every bit from the second digit up, must be 1 */
3499 		for (ix = DIGIT_BIT; ix < iy; ix++) {
3500 			if ((a->dp[iw] & iz) == 0) {
3501 				return MP_NO;
3502 			}
3503 			iz <<= 1;
3504 			if (iz > (mp_digit)MP_MASK) {
3505 				++iw;
3506 				iz = 1;
3507 			}
3508 		}
3509 	}
3510 	return MP_YES;
3511 }
3512 
3513 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_is_2k.c,v $ */
3514 /* Revision: 1.1.1.1 $ */
3515 /* Date: 2011/03/12 22:58:18 $ */
3516 
3517 
3518 /* d = a * b (mod c) */
3519 static int
multiply_modulo(mp_int * d,mp_int * a,mp_int * b,mp_int * c)3520 multiply_modulo(mp_int *d, mp_int * a, mp_int * b, mp_int * c)
3521 {
3522 	mp_int  t;
3523 	int     res;
3524 
3525 	if ((res = mp_init(&t)) != MP_OKAY) {
3526 		return res;
3527 	}
3528 
3529 	if ((res = signed_multiply(a, b, &t)) != MP_OKAY) {
3530 		mp_clear(&t);
3531 		return res;
3532 	}
3533 	res = modulo(&t, c, d);
3534 	mp_clear(&t);
3535 	return res;
3536 }
3537 
3538 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_mulmod.c,v $ */
3539 /* Revision: 1.1.1.1 $ */
3540 /* Date: 2011/03/12 22:58:18 $ */
3541 
3542 /* setups the montgomery reduction stuff */
3543 static int
mp_montgomery_setup(mp_int * n,mp_digit * rho)3544 mp_montgomery_setup(mp_int * n, mp_digit * rho)
3545 {
3546 	mp_digit x, b;
3547 
3548 	/* fast inversion mod 2**k
3549 	*
3550 	* Based on the fact that
3551 	*
3552 	* XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3553 	*                    =>  2*X*A - X*X*A*A = 1
3554 	*                    =>  2*(1) - (1)     = 1
3555 	*/
3556 	b = n->dp[0];
3557 
3558 	if ((b & 1) == 0) {
3559 		return MP_VAL;
3560 	}
3561 
3562 	x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3563 	x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3564 	x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3565 	x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3566 	if (/*CONSTCOND*/sizeof(mp_digit) == 8) {
3567 		x *= 2 - b * x;	/* here x*a==1 mod 2**64 */
3568 	}
3569 
3570 	/* rho = -1/m mod b */
3571 	*rho = (unsigned long)(((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3572 
3573 	return MP_OKAY;
3574 }
3575 
3576 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_setup.c,v $ */
3577 /* Revision: 1.1.1.1 $ */
3578 /* Date: 2011/03/12 22:58:18 $ */
3579 
3580 /* computes xR**-1 == x (mod N) via Montgomery Reduction
3581  *
3582  * This is an optimized implementation of montgomery_reduce
3583  * which uses the comba method to quickly calculate the columns of the
3584  * reduction.
3585  *
3586  * Based on Algorithm 14.32 on pp.601 of HAC.
3587 */
3588 static int
fast_mp_montgomery_reduce(mp_int * x,mp_int * n,mp_digit rho)3589 fast_mp_montgomery_reduce(mp_int * x, mp_int * n, mp_digit rho)
3590 {
3591 	int     ix, res, olduse;
3592 	/*LINTED*/
3593 	mp_word W[MP_WARRAY];
3594 
3595 	/* get old used count */
3596 	olduse = x->used;
3597 
3598 	/* grow a as required */
3599 	if (x->alloc < n->used + 1) {
3600 		if ((res = mp_grow(x, n->used + 1)) != MP_OKAY) {
3601 			return res;
3602 		}
3603 	}
3604 
3605 	/* first we have to get the digits of the input into
3606 	* an array of double precision words W[...]
3607 	*/
3608 	{
3609 		mp_word *_W;
3610 		mp_digit *tmpx;
3611 
3612 		/* alias for the W[] array */
3613 		_W = W;
3614 
3615 		/* alias for the digits of  x*/
3616 		tmpx = x->dp;
3617 
3618 		/* copy the digits of a into W[0..a->used-1] */
3619 		for (ix = 0; ix < x->used; ix++) {
3620 			*_W++ = *tmpx++;
3621 		}
3622 
3623 		/* zero the high words of W[a->used..m->used*2] */
3624 		for (; ix < n->used * 2 + 1; ix++) {
3625 			*_W++ = 0;
3626 		}
3627 	}
3628 
3629 	/* now we proceed to zero successive digits
3630 	* from the least significant upwards
3631 	*/
3632 	for (ix = 0; ix < n->used; ix++) {
3633 		/* mu = ai * m' mod b
3634 		*
3635 		* We avoid a double precision multiplication (which isn't required)
3636 		* by casting the value down to a mp_digit.  Note this requires
3637 		* that W[ix-1] have  the carry cleared (see after the inner loop)
3638 		*/
3639 		mp_digit mu;
3640 		mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
3641 
3642 		/* a = a + mu * m * b**i
3643 		*
3644 		* This is computed in place and on the fly.  The multiplication
3645 		* by b**i is handled by offsetting which columns the results
3646 		* are added to.
3647 		*
3648 		* Note the comba method normally doesn't handle carries in the
3649 		* inner loop In this case we fix the carry from the previous
3650 		* column since the Montgomery reduction requires digits of the
3651 		* result (so far) [see above] to work.  This is
3652 		* handled by fixing up one carry after the inner loop.  The
3653 		* carry fixups are done in order so after these loops the
3654 		* first m->used words of W[] have the carries fixed
3655 		*/
3656 		{
3657 			int iy;
3658 			mp_digit *tmpn;
3659 			mp_word *_W;
3660 
3661 			/* alias for the digits of the modulus */
3662 			tmpn = n->dp;
3663 
3664 			/* Alias for the columns set by an offset of ix */
3665 			_W = W + ix;
3666 
3667 			/* inner loop */
3668 			for (iy = 0; iy < n->used; iy++) {
3669 				  *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
3670 			}
3671 		}
3672 
3673 		/* now fix carry for next digit, W[ix+1] */
3674 		W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
3675 	}
3676 
3677 	/* now we have to propagate the carries and
3678 	* shift the words downward [all those least
3679 	* significant digits we zeroed].
3680 	*/
3681 	{
3682 		mp_digit *tmpx;
3683 		mp_word *_W, *_W1;
3684 
3685 		/* nox fix rest of carries */
3686 
3687 		/* alias for current word */
3688 		_W1 = W + ix;
3689 
3690 		/* alias for next word, where the carry goes */
3691 		_W = W + ++ix;
3692 
3693 		for (; ix <= n->used * 2 + 1; ix++) {
3694 			*_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
3695 		}
3696 
3697 		/* copy out, A = A/b**n
3698 		*
3699 		* The result is A/b**n but instead of converting from an
3700 		* array of mp_word to mp_digit than calling rshift_digits
3701 		* we just copy them in the right order
3702 		*/
3703 
3704 		/* alias for destination word */
3705 		tmpx = x->dp;
3706 
3707 		/* alias for shifted double precision result */
3708 		_W = W + n->used;
3709 
3710 		for (ix = 0; ix < n->used + 1; ix++) {
3711 			*tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
3712 		}
3713 
3714 		/* zero oldused digits, if the input a was larger than
3715 		* m->used+1 we'll have to clear the digits
3716 		*/
3717 		for (; ix < olduse; ix++) {
3718 			*tmpx++ = 0;
3719 		}
3720 	}
3721 
3722 	/* set the max used and clamp */
3723 	x->used = n->used + 1;
3724 	trim_unused_digits(x);
3725 
3726 	/* if A >= m then A = A - m */
3727 	if (compare_magnitude(x, n) != MP_LT) {
3728 		return basic_subtract(x, n, x);
3729 	}
3730 	return MP_OKAY;
3731 }
3732 
3733 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_mp_montgomery_reduce.c,v $ */
3734 /* Revision: 1.2 $ */
3735 /* Date: 2011/03/18 16:22:09 $ */
3736 
3737 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
3738 static int
mp_montgomery_reduce(mp_int * x,mp_int * n,mp_digit rho)3739 mp_montgomery_reduce(mp_int * x, mp_int * n, mp_digit rho)
3740 {
3741 	int     ix, res, digs;
3742 	mp_digit mu;
3743 
3744 	/* can the fast reduction [comba] method be used?
3745 	*
3746 	* Note that unlike in mul you're safely allowed *less*
3747 	* than the available columns [255 per default] since carries
3748 	* are fixed up in the inner loop.
3749 	*/
3750 	digs = n->used * 2 + 1;
3751 	if (can_use_fast_column_array(digs, n->used)) {
3752 		return fast_mp_montgomery_reduce(x, n, rho);
3753 	}
3754 
3755 	/* grow the input as required */
3756 	if (x->alloc < digs) {
3757 		if ((res = mp_grow(x, digs)) != MP_OKAY) {
3758 			return res;
3759 		}
3760 	}
3761 	x->used = digs;
3762 
3763 	for (ix = 0; ix < n->used; ix++) {
3764 		/* mu = ai * rho mod b
3765 		*
3766 		* The value of rho must be precalculated via
3767 		* montgomery_setup() such that
3768 		* it equals -1/n0 mod b this allows the
3769 		* following inner loop to reduce the
3770 		* input one digit at a time
3771 		*/
3772 		mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
3773 
3774 		/* a = a + mu * m * b**i */
3775 		{
3776 			int iy;
3777 			mp_digit *tmpn, *tmpx, carry;
3778 			mp_word r;
3779 
3780 			/* alias for digits of the modulus */
3781 			tmpn = n->dp;
3782 
3783 			/* alias for the digits of x [the input] */
3784 			tmpx = x->dp + ix;
3785 
3786 			/* set the carry to zero */
3787 			carry = 0;
3788 
3789 			/* Multiply and add in place */
3790 			for (iy = 0; iy < n->used; iy++) {
3791 				/* compute product and sum */
3792 				r = ((mp_word)mu) * ((mp_word)*tmpn++) +
3793 					  ((mp_word) carry) + ((mp_word) * tmpx);
3794 
3795 				/* get carry */
3796 				carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3797 
3798 				/* fix digit */
3799 				*tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
3800 			}
3801 			/* At this point the ix'th digit of x should be zero */
3802 
3803 
3804 			/* propagate carries upwards as required*/
3805 			while (carry) {
3806 				*tmpx += carry;
3807 				carry = *tmpx >> DIGIT_BIT;
3808 				*tmpx++ &= MP_MASK;
3809 			}
3810 		}
3811 	}
3812 
3813 	/* at this point the n.used'th least
3814 	* significant digits of x are all zero
3815 	* which means we can shift x to the
3816 	* right by n.used digits and the
3817 	* residue is unchanged.
3818 	*/
3819 
3820 	/* x = x/b**n.used */
3821 	trim_unused_digits(x);
3822 	rshift_digits(x, n->used);
3823 
3824 	/* if x >= n then x = x - n */
3825 	if (compare_magnitude(x, n) != MP_LT) {
3826 		return basic_subtract(x, n, x);
3827 	}
3828 
3829 	return MP_OKAY;
3830 }
3831 
3832 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_reduce.c,v $ */
3833 /* Revision: 1.3 $ */
3834 /* Date: 2011/03/18 16:43:04 $ */
3835 
3836 /* determines the setup value */
3837 static void
diminished_radix_setup(mp_int * a,mp_digit * d)3838 diminished_radix_setup(mp_int *a, mp_digit *d)
3839 {
3840 	/* the casts are required if DIGIT_BIT is one less than
3841 	* the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
3842 	*/
3843 	*d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
3844 		((mp_word)a->dp[0]));
3845 }
3846 
3847 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_setup.c,v $ */
3848 /* Revision: 1.1.1.1 $ */
3849 /* Date: 2011/03/12 22:58:18 $ */
3850 
3851 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
3852  *
3853  * Based on algorithm from the paper
3854  *
3855  * "Generating Efficient Primes for Discrete Log Cryptosystems"
3856  *                 Chae Hoon Lim, Pil Joong Lee,
3857  *          POSTECH Information Research Laboratories
3858  *
3859  * The modulus must be of a special format [see manual]
3860  *
3861  * Has been modified to use algorithm 7.10 from the LTM book instead
3862  *
3863  * Input x must be in the range 0 <= x <= (n-1)**2
3864  */
3865 static int
diminished_radix_reduce(mp_int * x,mp_int * n,mp_digit k)3866 diminished_radix_reduce(mp_int * x, mp_int * n, mp_digit k)
3867 {
3868 	int      err, i, m;
3869 	mp_word  r;
3870 	mp_digit mu, *tmpx1, *tmpx2;
3871 
3872 	/* m = digits in modulus */
3873 	m = n->used;
3874 
3875 	/* ensure that "x" has at least 2m digits */
3876 	if (x->alloc < m + m) {
3877 		if ((err = mp_grow(x, m + m)) != MP_OKAY) {
3878 			return err;
3879 		}
3880 	}
3881 
3882 	/* top of loop, this is where the code resumes if
3883 	* another reduction pass is required.
3884 	*/
3885 top:
3886 	/* aliases for digits */
3887 	/* alias for lower half of x */
3888 	tmpx1 = x->dp;
3889 
3890 	/* alias for upper half of x, or x/B**m */
3891 	tmpx2 = x->dp + m;
3892 
3893 	/* set carry to zero */
3894 	mu = 0;
3895 
3896 	/* compute (x mod B**m) + k * [x/B**m] inline and inplace */
3897 	for (i = 0; i < m; i++) {
3898 		r = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
3899 		*tmpx1++  = (mp_digit)(r & MP_MASK);
3900 		mu = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
3901 	}
3902 
3903 	/* set final carry */
3904 	*tmpx1++ = mu;
3905 
3906 	/* zero words above m */
3907 	for (i = m + 1; i < x->used; i++) {
3908 		*tmpx1++ = 0;
3909 	}
3910 
3911 	/* clamp, sub and return */
3912 	trim_unused_digits(x);
3913 
3914 	/* if x >= n then subtract and reduce again
3915 	* Each successive "recursion" makes the input smaller and smaller.
3916 	*/
3917 	if (compare_magnitude(x, n) != MP_LT) {
3918 		basic_subtract(x, n, x);
3919 		goto top;
3920 	}
3921 	return MP_OKAY;
3922 }
3923 
3924 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_reduce.c,v $ */
3925 /* Revision: 1.1.1.1 $ */
3926 /* Date: 2011/03/12 22:58:18 $ */
3927 
3928 /* determines the setup value */
3929 static int
mp_reduce_2k_setup(mp_int * a,mp_digit * d)3930 mp_reduce_2k_setup(mp_int *a, mp_digit *d)
3931 {
3932 	int res, p;
3933 	mp_int tmp;
3934 
3935 	if ((res = mp_init(&tmp)) != MP_OKAY) {
3936 		return res;
3937 	}
3938 
3939 	p = mp_count_bits(a);
3940 	if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3941 		mp_clear(&tmp);
3942 		return res;
3943 	}
3944 
3945 	if ((res = basic_subtract(&tmp, a, &tmp)) != MP_OKAY) {
3946 		mp_clear(&tmp);
3947 		return res;
3948 	}
3949 
3950 	*d = tmp.dp[0];
3951 	mp_clear(&tmp);
3952 	return MP_OKAY;
3953 }
3954 
3955 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_setup.c,v $ */
3956 /* Revision: 1.1.1.1 $ */
3957 /* Date: 2011/03/12 22:58:18 $ */
3958 
3959 /* reduces a modulo n where n is of the form 2**p - d */
3960 static int
mp_reduce_2k(mp_int * a,mp_int * n,mp_digit d)3961 mp_reduce_2k(mp_int *a, mp_int *n, mp_digit d)
3962 {
3963 	mp_int q;
3964 	int    p, res;
3965 
3966 	if ((res = mp_init(&q)) != MP_OKAY) {
3967 		return res;
3968 	}
3969 
3970 	p = mp_count_bits(n);
3971 top:
3972 	/* q = a/2**p, a = a mod 2**p */
3973 	if ((res = rshift_bits(a, p, &q, a)) != MP_OKAY) {
3974 		goto ERR;
3975 	}
3976 
3977 	if (d != 1) {
3978 		/* q = q * d */
3979 		if ((res = multiply_digit(&q, d, &q)) != MP_OKAY) {
3980 			 goto ERR;
3981 		}
3982 	}
3983 
3984 	/* a = a + q */
3985 	if ((res = basic_add(a, &q, a)) != MP_OKAY) {
3986 		goto ERR;
3987 	}
3988 
3989 	if (compare_magnitude(a, n) != MP_LT) {
3990 		basic_subtract(a, n, a);
3991 		goto top;
3992 	}
3993 
3994 ERR:
3995 	mp_clear(&q);
3996 	return res;
3997 }
3998 
3999 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k.c,v $ */
4000 /* Revision: 1.1.1.1 $ */
4001 /* Date: 2011/03/12 22:58:18 $ */
4002 
4003 /*
4004  * shifts with subtractions when the result is greater than b.
4005  *
4006  * The method is slightly modified to shift B unconditionally upto just under
4007  * the leading bit of b.  This saves a lot of multiple precision shifting.
4008  */
4009 static int
mp_montgomery_calc_normalization(mp_int * a,mp_int * b)4010 mp_montgomery_calc_normalization(mp_int * a, mp_int * b)
4011 {
4012 	int     x, bits, res;
4013 
4014 	/* how many bits of last digit does b use */
4015 	bits = mp_count_bits(b) % DIGIT_BIT;
4016 
4017 	if (b->used > 1) {
4018 		if ((res = mp_2expt(a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
4019 			return res;
4020 		}
4021 	} else {
4022 		set_word(a, 1);
4023 		bits = 1;
4024 	}
4025 
4026 
4027 	/* now compute C = A * B mod b */
4028 	for (x = bits - 1; x < (int)DIGIT_BIT; x++) {
4029 		if ((res = doubled(a, a)) != MP_OKAY) {
4030 			return res;
4031 		}
4032 		if (compare_magnitude(a, b) != MP_LT) {
4033 			if ((res = basic_subtract(a, b, a)) != MP_OKAY) {
4034 				return res;
4035 			}
4036 		}
4037 	}
4038 
4039 	return MP_OKAY;
4040 }
4041 
4042 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_calc_normalization.c,v $ */
4043 /* Revision: 1.1.1.1 $ */
4044 /* Date: 2011/03/12 22:58:18 $ */
4045 
4046 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
4047  *
4048  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
4049  * The value of k changes based on the size of the exponent.
4050  *
4051  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
4052  */
4053 
4054 #define TAB_SIZE 256
4055 
4056 static int
fast_exponent_modulo(mp_int * G,mp_int * X,mp_int * P,mp_int * Y,int redmode)4057 fast_exponent_modulo(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
4058 {
4059 	mp_int  M[TAB_SIZE], res;
4060 	mp_digit buf, mp;
4061 	int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
4062 
4063 	/* use a pointer to the reduction algorithm.  This allows us to use
4064 	* one of many reduction algorithms without modding the guts of
4065 	* the code with if statements everywhere.
4066 	*/
4067 	int     (*redux)(mp_int*,mp_int*,mp_digit);
4068 
4069 	winsize = find_window_size(X);
4070 
4071 	/* init M array */
4072 	/* init first cell */
4073 	if ((err = mp_init(&M[1])) != MP_OKAY) {
4074 		return err;
4075 	}
4076 
4077 	/* now init the second half of the array */
4078 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4079 		if ((err = mp_init(&M[x])) != MP_OKAY) {
4080 			for (y = 1<<(winsize-1); y < x; y++) {
4081 				mp_clear(&M[y]);
4082 			}
4083 			mp_clear(&M[1]);
4084 			return err;
4085 		}
4086 	}
4087 
4088 	/* determine and setup reduction code */
4089 	if (redmode == 0) {
4090 		/* now setup montgomery  */
4091 		if ((err = mp_montgomery_setup(P, &mp)) != MP_OKAY) {
4092 			goto LBL_M;
4093 		}
4094 
4095 		/* automatically pick the comba one if available (saves quite a few calls/ifs) */
4096 		if (can_use_fast_column_array(P->used + P->used + 1, P->used)) {
4097 			redux = fast_mp_montgomery_reduce;
4098 		} else {
4099 			/* use slower baseline Montgomery method */
4100 			redux = mp_montgomery_reduce;
4101 		}
4102 	} else if (redmode == 1) {
4103 		/* setup DR reduction for moduli of the form B**k - b */
4104 		diminished_radix_setup(P, &mp);
4105 		redux = diminished_radix_reduce;
4106 	} else {
4107 		/* setup DR reduction for moduli of the form 2**k - b */
4108 		if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
4109 			goto LBL_M;
4110 		}
4111 		redux = mp_reduce_2k;
4112 	}
4113 
4114 	/* setup result */
4115 	if ((err = mp_init(&res)) != MP_OKAY) {
4116 		goto LBL_M;
4117 	}
4118 
4119 	/* create M table
4120 	*
4121 
4122 	*
4123 	* The first half of the table is not computed though accept for M[0] and M[1]
4124 	*/
4125 
4126 	if (redmode == 0) {
4127 		/* now we need R mod m */
4128 		if ((err = mp_montgomery_calc_normalization(&res, P)) != MP_OKAY) {
4129 			goto LBL_RES;
4130 		}
4131 
4132 		/* now set M[1] to G * R mod m */
4133 		if ((err = multiply_modulo(&M[1], G, &res, P)) != MP_OKAY) {
4134 			goto LBL_RES;
4135 		}
4136 	} else {
4137 		set_word(&res, 1);
4138 		if ((err = modulo(G, P, &M[1])) != MP_OKAY) {
4139 			goto LBL_RES;
4140 		}
4141 	}
4142 
4143 	/* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
4144 	if ((err = mp_copy( &M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4145 		goto LBL_RES;
4146 	}
4147 
4148 	for (x = 0; x < (winsize - 1); x++) {
4149 		if ((err = square(&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
4150 			goto LBL_RES;
4151 		}
4152 		if ((err = (*redux)(&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
4153 			goto LBL_RES;
4154 		}
4155 	}
4156 
4157 	/* create upper table */
4158 	for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4159 		if ((err = signed_multiply(&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4160 			goto LBL_RES;
4161 		}
4162 		if ((err = (*redux)(&M[x], P, mp)) != MP_OKAY) {
4163 			goto LBL_RES;
4164 		}
4165 	}
4166 
4167 	/* set initial mode and bit cnt */
4168 	mode = 0;
4169 	bitcnt = 1;
4170 	buf = 0;
4171 	digidx = X->used - 1;
4172 	bitcpy = 0;
4173 	bitbuf = 0;
4174 
4175 	for (;;) {
4176 		/* grab next digit as required */
4177 		if (--bitcnt == 0) {
4178 			/* if digidx == -1 we are out of digits so break */
4179 			if (digidx == -1) {
4180 				break;
4181 			}
4182 			/* read next digit and reset bitcnt */
4183 			buf = X->dp[digidx--];
4184 			bitcnt = (int)DIGIT_BIT;
4185 		}
4186 
4187 		/* grab the next msb from the exponent */
4188 		y = (int)(mp_digit)((mp_digit)buf >> (unsigned)(DIGIT_BIT - 1)) & 1;
4189 		buf <<= (mp_digit)1;
4190 
4191 		/* if the bit is zero and mode == 0 then we ignore it
4192 		* These represent the leading zero bits before the first 1 bit
4193 		* in the exponent.  Technically this opt is not required but it
4194 		* does lower the # of trivial squaring/reductions used
4195 		*/
4196 		if (mode == 0 && y == 0) {
4197 			continue;
4198 		}
4199 
4200 		/* if the bit is zero and mode == 1 then we square */
4201 		if (mode == 1 && y == 0) {
4202 			if ((err = square(&res, &res)) != MP_OKAY) {
4203 				goto LBL_RES;
4204 			}
4205 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4206 				goto LBL_RES;
4207 			}
4208 			continue;
4209 		}
4210 
4211 		/* else we add it to the window */
4212 		bitbuf |= (y << (winsize - ++bitcpy));
4213 		mode = 2;
4214 
4215 		if (bitcpy == winsize) {
4216 			/* ok window is filled so square as required and multiply  */
4217 			/* square first */
4218 			for (x = 0; x < winsize; x++) {
4219 				if ((err = square(&res, &res)) != MP_OKAY) {
4220 					goto LBL_RES;
4221 				}
4222 				if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4223 					goto LBL_RES;
4224 				}
4225 			}
4226 
4227 			/* then multiply */
4228 			if ((err = signed_multiply(&res, &M[bitbuf], &res)) != MP_OKAY) {
4229 				goto LBL_RES;
4230 			}
4231 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4232 				goto LBL_RES;
4233 			}
4234 
4235 			/* empty window and reset */
4236 			bitcpy = 0;
4237 			bitbuf = 0;
4238 			mode = 1;
4239 		}
4240 	}
4241 
4242 	/* if bits remain then square/multiply */
4243 	if (mode == 2 && bitcpy > 0) {
4244 		/* square then multiply if the bit is set */
4245 		for (x = 0; x < bitcpy; x++) {
4246 			if ((err = square(&res, &res)) != MP_OKAY) {
4247 				goto LBL_RES;
4248 			}
4249 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4250 				goto LBL_RES;
4251 			}
4252 
4253 			/* get next bit of the window */
4254 			bitbuf <<= 1;
4255 			if ((bitbuf & (1 << winsize)) != 0) {
4256 				/* then multiply */
4257 				if ((err = signed_multiply(&res, &M[1], &res)) != MP_OKAY) {
4258 					goto LBL_RES;
4259 				}
4260 				if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4261 					goto LBL_RES;
4262 				}
4263 			}
4264 		}
4265 	}
4266 
4267 	if (redmode == 0) {
4268 		/* fixup result if Montgomery reduction is used
4269 		* recall that any value in a Montgomery system is
4270 		* actually multiplied by R mod n.  So we have
4271 		* to reduce one more time to cancel out the factor
4272 		* of R.
4273 		*/
4274 		if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4275 			goto LBL_RES;
4276 		}
4277 	}
4278 
4279 	/* swap res with Y */
4280 	mp_exch(&res, Y);
4281 	err = MP_OKAY;
4282 LBL_RES:
4283 	mp_clear(&res);
4284 LBL_M:
4285 	mp_clear(&M[1]);
4286 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4287 		mp_clear(&M[x]);
4288 	}
4289 	return err;
4290 }
4291 
4292 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_exponent_modulo.c,v $ */
4293 /* Revision: 1.4 $ */
4294 /* Date: 2011/03/18 16:43:04 $ */
4295 
4296 /* this is a shell function that calls either the normal or Montgomery
4297  * exptmod functions.  Originally the call to the montgomery code was
4298  * embedded in the normal function but that wasted a lot of stack space
4299  * for nothing (since 99% of the time the Montgomery code would be called)
4300  */
4301 static int
exponent_modulo(mp_int * G,mp_int * X,mp_int * P,mp_int * Y)4302 exponent_modulo(mp_int * G, mp_int * X, mp_int * P, mp_int *Y)
4303 {
4304 	int diminished_radix;
4305 
4306 	/* modulus P must be positive */
4307 	if (P->sign == MP_NEG) {
4308 		return MP_VAL;
4309 	}
4310 
4311 	/* if exponent X is negative we have to recurse */
4312 	if (X->sign == MP_NEG) {
4313 		mp_int tmpG, tmpX;
4314 		int err;
4315 
4316 		/* first compute 1/G mod P */
4317 		if ((err = mp_init(&tmpG)) != MP_OKAY) {
4318 			return err;
4319 		}
4320 		if ((err = modular_inverse(&tmpG, G, P)) != MP_OKAY) {
4321 			mp_clear(&tmpG);
4322 			return err;
4323 		}
4324 
4325 		/* now get |X| */
4326 		if ((err = mp_init(&tmpX)) != MP_OKAY) {
4327 			mp_clear(&tmpG);
4328 			return err;
4329 		}
4330 		if ((err = absolute(X, &tmpX)) != MP_OKAY) {
4331 			mp_clear_multi(&tmpG, &tmpX, NULL);
4332 			return err;
4333 		}
4334 
4335 		/* and now compute (1/G)**|X| instead of G**X [X < 0] */
4336 		err = exponent_modulo(&tmpG, &tmpX, P, Y);
4337 		mp_clear_multi(&tmpG, &tmpX, NULL);
4338 		return err;
4339 	}
4340 
4341 	/* modified diminished radix reduction */
4342 	if (mp_reduce_is_2k_l(P) == MP_YES) {
4343 		return basic_exponent_mod(G, X, P, Y, 1);
4344 	}
4345 
4346 	/* is it a DR modulus? */
4347 	diminished_radix = is_diminished_radix_modulus(P);
4348 
4349 	/* if not, is it an unrestricted DR modulus? */
4350 	if (!diminished_radix) {
4351 		diminished_radix = mp_reduce_is_2k(P) << 1;
4352 	}
4353 
4354 	/* if the modulus is odd or diminished_radix, use the montgomery method */
4355 	if (BN_is_odd(P) == 1 || diminished_radix) {
4356 		return fast_exponent_modulo(G, X, P, Y, diminished_radix);
4357 	}
4358 	/* otherwise use the generic Barrett reduction technique */
4359 	return basic_exponent_mod(G, X, P, Y, 0);
4360 }
4361 
/*
 * Reverse the `len' bytes at `s' in place; used by the radix conversion
 * code.  Nothing is touched when len is less than 2.
 */
static void
bn_reverse(unsigned char *s, int len)
{
	uint8_t	saved;
	int	front;

	/* swap each byte of the first half with its mirror image */
	for (front = 0; front < len / 2; front++) {
		saved = s[front];
		s[front] = s[len - 1 - front];
		s[len - 1 - front] = saved;
	}
}
4375 
4376 static inline int
is_power_of_two(mp_digit b,int * p)4377 is_power_of_two(mp_digit b, int *p)
4378 {
4379 	int x;
4380 
4381 	/* fast return if no power of two */
4382 	if ((b==0) || (b & (b-1))) {
4383 		return 0;
4384 	}
4385 
4386 	for (x = 0; x < DIGIT_BIT; x++) {
4387 		if (b == (((mp_digit)1)<<x)) {
4388 			*p = x;
4389 			return 1;
4390 		}
4391 	}
4392 	return 0;
4393 }
4394 
4395 /* single digit division (based on routine from MPI) */
4396 static int
signed_divide_word(mp_int * a,mp_digit b,mp_int * c,mp_digit * d)4397 signed_divide_word(mp_int *a, mp_digit b, mp_int *c, mp_digit *d)
4398 {
4399 	mp_int  q;
4400 	mp_word w;
4401 	mp_digit t;
4402 	int     res, ix;
4403 
4404 	/* cannot divide by zero */
4405 	if (b == 0) {
4406 		return MP_VAL;
4407 	}
4408 
4409 	/* quick outs */
4410 	if (b == 1 || MP_ISZERO(a) == 1) {
4411 		if (d != NULL) {
4412 			*d = 0;
4413 		}
4414 		if (c != NULL) {
4415 			return mp_copy(a, c);
4416 		}
4417 		return MP_OKAY;
4418 	}
4419 
4420 	/* power of two ? */
4421 	if (is_power_of_two(b, &ix) == 1) {
4422 		if (d != NULL) {
4423 			*d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
4424 		}
4425 		if (c != NULL) {
4426 			return rshift_bits(a, ix, c, NULL);
4427 		}
4428 		return MP_OKAY;
4429 	}
4430 
4431 	/* three? */
4432 	if (b == 3) {
4433 		return third(a, c, d);
4434 	}
4435 
4436 	/* no easy answer [c'est la vie].  Just division */
4437 	if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
4438 		return res;
4439 	}
4440 
4441 	q.used = a->used;
4442 	q.sign = a->sign;
4443 	w = 0;
4444 	for (ix = a->used - 1; ix >= 0; ix--) {
4445 		w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
4446 
4447 		if (w >= b) {
4448 			t = (mp_digit)(w / b);
4449 			w -= ((mp_word)t) * ((mp_word)b);
4450 		} else {
4451 			t = 0;
4452 		}
4453 		q.dp[ix] = (mp_digit)t;
4454 	}
4455 
4456 	if (d != NULL) {
4457 		*d = (mp_digit)w;
4458 	}
4459 
4460 	if (c != NULL) {
4461 		trim_unused_digits(&q);
4462 		mp_exch(&q, c);
4463 	}
4464 	mp_clear(&q);
4465 
4466 	return res;
4467 }
4468 
4469 static const mp_digit ltm_prime_tab[] = {
4470 	0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
4471 	0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
4472 	0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
4473 	0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F,
4474 #ifndef MP_8BIT
4475 	0x0083,
4476 	0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
4477 	0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
4478 	0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
4479 	0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
4480 
4481 	0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
4482 	0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
4483 	0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
4484 	0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
4485 	0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
4486 	0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
4487 	0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
4488 	0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
4489 
4490 	0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
4491 	0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
4492 	0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
4493 	0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
4494 	0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
4495 	0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
4496 	0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
4497 	0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
4498 
4499 	0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
4500 	0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
4501 	0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
4502 	0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
4503 	0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
4504 	0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
4505 	0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
4506 	0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
4507 #endif
4508 };
4509 
4510 #define PRIME_SIZE	__arraycount(ltm_prime_tab)
4511 
4512 static inline int
mp_prime_is_divisible(mp_int * a,int * result)4513 mp_prime_is_divisible(mp_int *a, int *result)
4514 {
4515 	int     err, ix;
4516 	mp_digit res;
4517 
4518 	/* default to not */
4519 	*result = MP_NO;
4520 
4521 	for (ix = 0; ix < (int)PRIME_SIZE; ix++) {
4522 		/* what is a mod LBL_prime_tab[ix] */
4523 		if ((err = signed_divide_word(a, ltm_prime_tab[ix], NULL, &res)) != MP_OKAY) {
4524 			return err;
4525 		}
4526 
4527 		/* is the residue zero? */
4528 		if (res == 0) {
4529 			*result = MP_YES;
4530 			return MP_OKAY;
4531 		}
4532 	}
4533 
4534 	return MP_OKAY;
4535 }
4536 
4537 /* single digit addition */
4538 static int
add_single_digit(mp_int * a,mp_digit b,mp_int * c)4539 add_single_digit(mp_int *a, mp_digit b, mp_int *c)
4540 {
4541 	int     res, ix, oldused;
4542 	mp_digit *tmpa, *tmpc, mu;
4543 
4544 	/* grow c as required */
4545 	if (c->alloc < a->used + 1) {
4546 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4547 			return res;
4548 		}
4549 	}
4550 
4551 	/* if a is negative and |a| >= b, call c = |a| - b */
4552 	if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
4553 		/* temporarily fix sign of a */
4554 		a->sign = MP_ZPOS;
4555 
4556 		/* c = |a| - b */
4557 		res = signed_subtract_word(a, b, c);
4558 
4559 		/* fix sign  */
4560 		a->sign = c->sign = MP_NEG;
4561 
4562 		/* clamp */
4563 		trim_unused_digits(c);
4564 
4565 		return res;
4566 	}
4567 
4568 	/* old number of used digits in c */
4569 	oldused = c->used;
4570 
4571 	/* sign always positive */
4572 	c->sign = MP_ZPOS;
4573 
4574 	/* source alias */
4575 	tmpa = a->dp;
4576 
4577 	/* destination alias */
4578 	tmpc = c->dp;
4579 
4580 	/* if a is positive */
4581 	if (a->sign == MP_ZPOS) {
4582 		/* add digit, after this we're propagating
4583 		* the carry.
4584 		*/
4585 		*tmpc = *tmpa++ + b;
4586 		mu = *tmpc >> DIGIT_BIT;
4587 		*tmpc++ &= MP_MASK;
4588 
4589 		/* now handle rest of the digits */
4590 		for (ix = 1; ix < a->used; ix++) {
4591 			*tmpc = *tmpa++ + mu;
4592 			mu = *tmpc >> DIGIT_BIT;
4593 			*tmpc++ &= MP_MASK;
4594 		}
4595 		/* set final carry */
4596 		ix++;
4597 		*tmpc++  = mu;
4598 
4599 		/* setup size */
4600 		c->used = a->used + 1;
4601 	} else {
4602 		/* a was negative and |a| < b */
4603 		c->used  = 1;
4604 
4605 		/* the result is a single digit */
4606 		if (a->used == 1) {
4607 			*tmpc++  =  b - a->dp[0];
4608 		} else {
4609 			*tmpc++  =  b;
4610 		}
4611 
4612 		/* setup count so the clearing of oldused
4613 		* can fall through correctly
4614 		*/
4615 		ix = 1;
4616 	}
4617 
4618 	/* now zero to oldused */
4619 	while (ix++ < oldused) {
4620 		*tmpc++ = 0;
4621 	}
4622 	trim_unused_digits(c);
4623 
4624 	return MP_OKAY;
4625 }
4626 
4627 /* single digit subtraction */
4628 static int
signed_subtract_word(mp_int * a,mp_digit b,mp_int * c)4629 signed_subtract_word(mp_int *a, mp_digit b, mp_int *c)
4630 {
4631 	mp_digit *tmpa, *tmpc, mu;
4632 	int       res, ix, oldused;
4633 
4634 	/* grow c as required */
4635 	if (c->alloc < a->used + 1) {
4636 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4637 			return res;
4638 		}
4639 	}
4640 
4641 	/* if a is negative just do an unsigned
4642 	* addition [with fudged signs]
4643 	*/
4644 	if (a->sign == MP_NEG) {
4645 		a->sign = MP_ZPOS;
4646 		res = add_single_digit(a, b, c);
4647 		a->sign = c->sign = MP_NEG;
4648 
4649 		/* clamp */
4650 		trim_unused_digits(c);
4651 
4652 		return res;
4653 	}
4654 
4655 	/* setup regs */
4656 	oldused = c->used;
4657 	tmpa = a->dp;
4658 	tmpc = c->dp;
4659 
4660 	/* if a <= b simply fix the single digit */
4661 	if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
4662 		if (a->used == 1) {
4663 			*tmpc++ = b - *tmpa;
4664 		} else {
4665 			*tmpc++ = b;
4666 		}
4667 		ix = 1;
4668 
4669 		/* negative/1digit */
4670 		c->sign = MP_NEG;
4671 		c->used = 1;
4672 	} else {
4673 		/* positive/size */
4674 		c->sign = MP_ZPOS;
4675 		c->used = a->used;
4676 
4677 		/* subtract first digit */
4678 		*tmpc = *tmpa++ - b;
4679 		mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4680 		*tmpc++ &= MP_MASK;
4681 
4682 		/* handle rest of the digits */
4683 		for (ix = 1; ix < a->used; ix++) {
4684 			*tmpc = *tmpa++ - mu;
4685 			mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4686 			*tmpc++ &= MP_MASK;
4687 		}
4688 	}
4689 
4690 	/* zero excess digits */
4691 	while (ix++ < oldused) {
4692 		*tmpc++ = 0;
4693 	}
4694 	trim_unused_digits(c);
4695 	return MP_OKAY;
4696 }
4697 
4698 static const int lnz[16] = {
4699 	4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
4700 };
4701 
4702 /* Counts the number of lsbs which are zero before the first zero bit */
4703 static int
mp_cnt_lsb(mp_int * a)4704 mp_cnt_lsb(mp_int *a)
4705 {
4706 	int x;
4707 	mp_digit q, qq;
4708 
4709 	/* easy out */
4710 	if (MP_ISZERO(a) == 1) {
4711 		return 0;
4712 	}
4713 
4714 	/* scan lower digits until non-zero */
4715 	for (x = 0; x < a->used && a->dp[x] == 0; x++) {
4716 	}
4717 	q = a->dp[x];
4718 	x *= DIGIT_BIT;
4719 
4720 	/* now scan this digit until a 1 is found */
4721 	if ((q & 1) == 0) {
4722 		do {
4723 			 qq  = q & 15;
4724 			 /* LINTED previous op ensures range of qq */
4725 			 x  += lnz[qq];
4726 			 q >>= 4;
4727 		} while (qq == 0);
4728 	}
4729 	return x;
4730 }
4731 
4732 /* c = a * a (mod b) */
4733 static int
square_modulo(mp_int * a,mp_int * b,mp_int * c)4734 square_modulo(mp_int *a, mp_int *b, mp_int *c)
4735 {
4736 	int     res;
4737 	mp_int  t;
4738 
4739 	if ((res = mp_init(&t)) != MP_OKAY) {
4740 		return res;
4741 	}
4742 
4743 	if ((res = square(a, &t)) != MP_OKAY) {
4744 		mp_clear(&t);
4745 		return res;
4746 	}
4747 	res = modulo(&t, b, c);
4748 	mp_clear(&t);
4749 	return res;
4750 }
4751 
4752 static int
mp_prime_miller_rabin(mp_int * a,mp_int * b,int * result)4753 mp_prime_miller_rabin(mp_int *a, mp_int *b, int *result)
4754 {
4755 	mp_int  n1, y, r;
4756 	int     s, j, err;
4757 
4758 	/* default */
4759 	*result = MP_NO;
4760 
4761 	/* ensure b > 1 */
4762 	if (compare_digit(b, 1) != MP_GT) {
4763 		return MP_VAL;
4764 	}
4765 
4766 	/* get n1 = a - 1 */
4767 	if ((err = mp_init_copy(&n1, a)) != MP_OKAY) {
4768 		return err;
4769 	}
4770 	if ((err = signed_subtract_word(&n1, 1, &n1)) != MP_OKAY) {
4771 		goto LBL_N1;
4772 	}
4773 
4774 	/* set 2**s * r = n1 */
4775 	if ((err = mp_init_copy(&r, &n1)) != MP_OKAY) {
4776 		goto LBL_N1;
4777 	}
4778 
4779 	/* count the number of least significant bits
4780 	* which are zero
4781 	*/
4782 	s = mp_cnt_lsb(&r);
4783 
4784 	/* now divide n - 1 by 2**s */
4785 	if ((err = rshift_bits(&r, s, &r, NULL)) != MP_OKAY) {
4786 		goto LBL_R;
4787 	}
4788 
4789 	/* compute y = b**r mod a */
4790 	if ((err = mp_init(&y)) != MP_OKAY) {
4791 		goto LBL_R;
4792 	}
4793 	if ((err = exponent_modulo(b, &r, a, &y)) != MP_OKAY) {
4794 		goto LBL_Y;
4795 	}
4796 
4797 	/* if y != 1 and y != n1 do */
4798 	if (compare_digit(&y, 1) != MP_EQ && signed_compare(&y, &n1) != MP_EQ) {
4799 		j = 1;
4800 		/* while j <= s-1 and y != n1 */
4801 		while ((j <= (s - 1)) && signed_compare(&y, &n1) != MP_EQ) {
4802 			if ((err = square_modulo(&y, a, &y)) != MP_OKAY) {
4803 				goto LBL_Y;
4804 			}
4805 
4806 			/* if y == 1 then composite */
4807 			if (compare_digit(&y, 1) == MP_EQ) {
4808 				goto LBL_Y;
4809 			}
4810 
4811 			++j;
4812 		}
4813 
4814 		/* if y != n1 then composite */
4815 		if (signed_compare(&y, &n1) != MP_EQ) {
4816 			goto LBL_Y;
4817 		}
4818 	}
4819 
4820 	/* probably prime now */
4821 	*result = MP_YES;
4822 LBL_Y:
4823 	mp_clear(&y);
4824 LBL_R:
4825 	mp_clear(&r);
4826 LBL_N1:
4827 	mp_clear(&n1);
4828 	return err;
4829 }
4830 
4831 /* performs a variable number of rounds of Miller-Rabin
4832  *
4833  * Probability of error after t rounds is no more than
4834 
4835  *
4836  * Sets result to 1 if probably prime, 0 otherwise
4837  */
4838 static int
mp_prime_is_prime(mp_int * a,int t,int * result)4839 mp_prime_is_prime(mp_int *a, int t, int *result)
4840 {
4841 	mp_int  b;
4842 	int     ix, err, res;
4843 
4844 	/* default to no */
4845 	*result = MP_NO;
4846 
4847 	/* valid value of t? */
4848 	if (t <= 0 || t > (int)PRIME_SIZE) {
4849 		return MP_VAL;
4850 	}
4851 
4852 	/* is the input equal to one of the primes in the table? */
4853 	for (ix = 0; ix < (int)PRIME_SIZE; ix++) {
4854 		if (compare_digit(a, ltm_prime_tab[ix]) == MP_EQ) {
4855 			*result = 1;
4856 			return MP_OKAY;
4857 		}
4858 	}
4859 
4860 	/* first perform trial division */
4861 	if ((err = mp_prime_is_divisible(a, &res)) != MP_OKAY) {
4862 		return err;
4863 	}
4864 
4865 	/* return if it was trivially divisible */
4866 	if (res == MP_YES) {
4867 		return MP_OKAY;
4868 	}
4869 
4870 	/* now perform the miller-rabin rounds */
4871 	if ((err = mp_init(&b)) != MP_OKAY) {
4872 		return err;
4873 	}
4874 
4875 	for (ix = 0; ix < t; ix++) {
4876 		/* set the prime */
4877 		set_word(&b, ltm_prime_tab[ix]);
4878 
4879 		if ((err = mp_prime_miller_rabin(a, &b, &res)) != MP_OKAY) {
4880 			goto LBL_B;
4881 		}
4882 
4883 		if (res == MP_NO) {
4884 			goto LBL_B;
4885 		}
4886 	}
4887 
4888 	/* passed the test */
4889 	*result = MP_YES;
4890 LBL_B:
4891 	mp_clear(&b);
4892 	return err;
4893 }
4894 
4895 /* returns size of ASCII representation */
4896 static int
mp_radix_size(mp_int * a,int radix,int * size)4897 mp_radix_size(mp_int *a, int radix, int *size)
4898 {
4899 	int     res, digs;
4900 	mp_int  t;
4901 	mp_digit d;
4902 
4903 	*size = 0;
4904 
4905 	/* special case for binary */
4906 	if (radix == 2) {
4907 		*size = mp_count_bits(a) + (a->sign == MP_NEG ? 1 : 0) + 1;
4908 		return MP_OKAY;
4909 	}
4910 
4911 	/* make sure the radix is in range */
4912 	if (radix < 2 || radix > 64) {
4913 		return MP_VAL;
4914 	}
4915 
4916 	if (MP_ISZERO(a) == MP_YES) {
4917 		*size = 2;
4918 		return MP_OKAY;
4919 	}
4920 
4921 	/* digs is the digit count */
4922 	digs = 0;
4923 
4924 	/* if it's negative add one for the sign */
4925 	if (a->sign == MP_NEG) {
4926 		++digs;
4927 	}
4928 
4929 	/* init a copy of the input */
4930 	if ((res = mp_init_copy(&t, a)) != MP_OKAY) {
4931 		return res;
4932 	}
4933 
4934 	/* force temp to positive */
4935 	t.sign = MP_ZPOS;
4936 
4937 	/* fetch out all of the digits */
4938 	while (MP_ISZERO(&t) == MP_NO) {
4939 		if ((res = signed_divide_word(&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
4940 			mp_clear(&t);
4941 			return res;
4942 		}
4943 		++digs;
4944 	}
4945 	mp_clear(&t);
4946 
4947 	/* return digs + 1, the 1 is for the NULL byte that would be required. */
4948 	*size = digs + 1;
4949 	return MP_OKAY;
4950 }
4951 
4952 static const char *mp_s_rmap = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+/";
4953 
4954 /* stores a bignum as a ASCII string in a given radix (2..64)
4955  *
4956  * Stores upto maxlen-1 chars and always a NULL byte
4957  */
4958 static int
mp_toradix_n(mp_int * a,char * str,int radix,int maxlen)4959 mp_toradix_n(mp_int * a, char *str, int radix, int maxlen)
4960 {
4961 	int     res, digs;
4962 	mp_int  t;
4963 	mp_digit d;
4964 	char   *_s = str;
4965 
4966 	/* check range of the maxlen, radix */
4967 	if (maxlen < 2 || radix < 2 || radix > 64) {
4968 		return MP_VAL;
4969 	}
4970 
4971 	/* quick out if its zero */
4972 	if (MP_ISZERO(a) == MP_YES) {
4973 		*str++ = '0';
4974 		*str = '\0';
4975 		return MP_OKAY;
4976 	}
4977 
4978 	if ((res = mp_init_copy(&t, a)) != MP_OKAY) {
4979 		return res;
4980 	}
4981 
4982 	/* if it is negative output a - */
4983 	if (t.sign == MP_NEG) {
4984 		/* we have to reverse our digits later... but not the - sign!! */
4985 		++_s;
4986 
4987 		/* store the flag and mark the number as positive */
4988 		*str++ = '-';
4989 		t.sign = MP_ZPOS;
4990 
4991 		/* subtract a char */
4992 		--maxlen;
4993 	}
4994 
4995 	digs = 0;
4996 	while (MP_ISZERO(&t) == 0) {
4997 		if (--maxlen < 1) {
4998 			/* no more room */
4999 			break;
5000 		}
5001 		if ((res = signed_divide_word(&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
5002 			mp_clear(&t);
5003 			return res;
5004 		}
5005 		/* LINTED -- radix' range is checked above, limits d's range */
5006 		*str++ = mp_s_rmap[d];
5007 		++digs;
5008 	}
5009 
5010 	/* reverse the digits of the string.  In this case _s points
5011 	* to the first digit [excluding the sign] of the number
5012 	*/
5013 	bn_reverse((unsigned char *)_s, digs);
5014 
5015 	/* append a NULL so the string is properly terminated */
5016 	*str = '\0';
5017 
5018 	mp_clear(&t);
5019 	return MP_OKAY;
5020 }
5021 
5022 static char *
formatbn(const BIGNUM * a,const int radix)5023 formatbn(const BIGNUM *a, const int radix)
5024 {
5025 	char	*s;
5026 	int	 len;
5027 
5028 	if (mp_radix_size(__UNCONST(a), radix, &len) != MP_OKAY) {
5029 		return NULL;
5030 	}
5031 	if ((s = allocate(1, (size_t)len)) != NULL) {
5032 		if (mp_toradix_n(__UNCONST(a), s, radix, len) != MP_OKAY) {
5033 			deallocate(s, (size_t)len);
5034 			return NULL;
5035 		}
5036 	}
5037 	return s;
5038 }
5039 
5040 static int
mp_getradix_num(mp_int * a,int radix,char * s)5041 mp_getradix_num(mp_int *a, int radix, char *s)
5042 {
5043 	int err, ch, neg, y;
5044 
5045 	/* clear a */
5046 	mp_zero(a);
5047 
5048 	/* if first digit is - then set negative */
5049 	if ((ch = *s++) == '-') {
5050 		neg = MP_NEG;
5051 		ch = *s++;
5052 	} else {
5053 		neg = MP_ZPOS;
5054 	}
5055 
5056 	for (;;) {
5057 		/* find y in the radix map */
5058 		for (y = 0; y < radix; y++) {
5059 			if (mp_s_rmap[y] == ch) {
5060 				break;
5061 			}
5062 		}
5063 		if (y == radix) {
5064 			break;
5065 		}
5066 
5067 		/* shift up and add */
5068 		if ((err = multiply_digit(a, radix, a)) != MP_OKAY) {
5069 			return err;
5070 		}
5071 		if ((err = add_single_digit(a, y, a)) != MP_OKAY) {
5072 			return err;
5073 		}
5074 
5075 		ch = *s++;
5076 	}
5077 	if (compare_digit(a, 0) != MP_EQ) {
5078 		a->sign = neg;
5079 	}
5080 
5081 	return MP_OKAY;
5082 }
5083 
5084 static int
getbn(BIGNUM ** a,const char * str,int radix)5085 getbn(BIGNUM **a, const char *str, int radix)
5086 {
5087 	int	len;
5088 
5089 	if (a == NULL || str == NULL || (*a = BN_new()) == NULL) {
5090 		return 0;
5091 	}
5092 	if (mp_getradix_num(*a, radix, __UNCONST(str)) != MP_OKAY) {
5093 		return 0;
5094 	}
5095 	mp_radix_size(__UNCONST(*a), radix, &len);
5096 	return len - 1;
5097 }
5098 
5099 /* d = a - b (mod c) */
5100 static int
subtract_modulo(mp_int * a,mp_int * b,mp_int * c,mp_int * d)5101 subtract_modulo(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
5102 {
5103 	int     res;
5104 	mp_int  t;
5105 
5106 
5107 	if ((res = mp_init(&t)) != MP_OKAY) {
5108 		return res;
5109 	}
5110 
5111 	if ((res = signed_subtract(a, b, &t)) != MP_OKAY) {
5112 		mp_clear(&t);
5113 		return res;
5114 	}
5115 	res = modulo(&t, c, d);
5116 	mp_clear(&t);
5117 	return res;
5118 }
5119 
5120 /* bn_mp_gcd.c */
5121 /* Greatest Common Divisor using the binary method */
5122 static int
mp_gcd(mp_int * a,mp_int * b,mp_int * c)5123 mp_gcd(mp_int *a, mp_int *b, mp_int *c)
5124 {
5125 	mp_int  u, v;
5126 	int     k, u_lsb, v_lsb, res;
5127 
5128 	/* either zero than gcd is the largest */
5129 	if (BN_is_zero(a) == MP_YES) {
5130 		return absolute(b, c);
5131 	}
5132 	if (BN_is_zero(b) == MP_YES) {
5133 		return absolute(a, c);
5134 	}
5135 
5136 	/* get copies of a and b we can modify */
5137 	if ((res = mp_init_copy(&u, a)) != MP_OKAY) {
5138 		return res;
5139 	}
5140 
5141 	if ((res = mp_init_copy(&v, b)) != MP_OKAY) {
5142 		goto LBL_U;
5143 	}
5144 
5145 	/* must be positive for the remainder of the algorithm */
5146 	u.sign = v.sign = MP_ZPOS;
5147 
5148 	/* B1.  Find the common power of two for u and v */
5149 	u_lsb = mp_cnt_lsb(&u);
5150 	v_lsb = mp_cnt_lsb(&v);
5151 	k = MIN(u_lsb, v_lsb);
5152 
5153 	if (k > 0) {
5154 		/* divide the power of two out */
5155 		if ((res = rshift_bits(&u, k, &u, NULL)) != MP_OKAY) {
5156 			goto LBL_V;
5157 		}
5158 
5159 		if ((res = rshift_bits(&v, k, &v, NULL)) != MP_OKAY) {
5160 			goto LBL_V;
5161 		}
5162 	}
5163 
5164 	/* divide any remaining factors of two out */
5165 	if (u_lsb != k) {
5166 		if ((res = rshift_bits(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
5167 			goto LBL_V;
5168 		}
5169 	}
5170 
5171 	if (v_lsb != k) {
5172 		if ((res = rshift_bits(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
5173 			goto LBL_V;
5174 		}
5175 	}
5176 
5177 	while (BN_is_zero(&v) == 0) {
5178 		/* make sure v is the largest */
5179 		if (compare_magnitude(&u, &v) == MP_GT) {
5180 			/* swap u and v to make sure v is >= u */
5181 			mp_exch(&u, &v);
5182 		}
5183 
5184 		/* subtract smallest from largest */
5185 		if ((res = signed_subtract(&v, &u, &v)) != MP_OKAY) {
5186 			goto LBL_V;
5187 		}
5188 
5189 		/* Divide out all factors of two */
5190 		if ((res = rshift_bits(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
5191 			goto LBL_V;
5192 		}
5193 	}
5194 
5195 	/* multiply by 2**k which we divided out at the beginning */
5196 	if ((res = lshift_bits(&u, k, c)) != MP_OKAY) {
5197 		goto LBL_V;
5198 	}
5199 	c->sign = MP_ZPOS;
5200 	res = MP_OKAY;
5201 LBL_V:
5202 	mp_clear (&u);
5203 LBL_U:
5204 	mp_clear (&v);
5205 	return res;
5206 }
5207 
5208 /**************************************************************************/
5209 
5210 /* BIGNUM emulation layer */
5211 
5212 /* essentially, these are just wrappers around the libtommath functions */
5213 /* usually the order of args changes */
5214 /* the BIGNUM API tends to have more const poisoning */
5215 /* these wrappers also check the arguments passed for sanity */
5216 
5217 BIGNUM *
BN_bin2bn(const uint8_t * data,int len,BIGNUM * ret)5218 BN_bin2bn(const uint8_t *data, int len, BIGNUM *ret)
5219 {
5220 	if (data == NULL) {
5221 		return BN_new();
5222 	}
5223 	if (ret == NULL) {
5224 		ret = BN_new();
5225 	}
5226 	return (mp_read_unsigned_bin(ret, data, len) == MP_OKAY) ? ret : NULL;
5227 }
5228 
5229 /* store in unsigned [big endian] format */
5230 int
BN_bn2bin(const BIGNUM * a,unsigned char * b)5231 BN_bn2bin(const BIGNUM *a, unsigned char *b)
5232 {
5233 	BIGNUM	t;
5234 	int    	x;
5235 
5236 	if (a == NULL || b == NULL) {
5237 		return -1;
5238 	}
5239 	if (mp_init_copy (&t, __UNCONST(a)) != MP_OKAY) {
5240 		return -1;
5241 	}
5242 	for (x = 0; !BN_is_zero(&t) ; ) {
5243 		b[x++] = (unsigned char) (t.dp[0] & 0xff);
5244 		if (rshift_bits(&t, 8, &t, NULL) != MP_OKAY) {
5245 			mp_clear(&t);
5246 			return -1;
5247 		}
5248 	}
5249 	bn_reverse(b, x);
5250 	mp_clear(&t);
5251 	return x;
5252 }
5253 
5254 void
BN_init(BIGNUM * a)5255 BN_init(BIGNUM *a)
5256 {
5257 	if (a != NULL) {
5258 		mp_init(a);
5259 	}
5260 }
5261 
5262 BIGNUM *
BN_new(void)5263 BN_new(void)
5264 {
5265 	BIGNUM	*a;
5266 
5267 	if ((a = allocate(1, sizeof(*a))) != NULL) {
5268 		mp_init(a);
5269 	}
5270 	return a;
5271 }
5272 
5273 /* copy, b = a */
5274 int
BN_copy(BIGNUM * b,const BIGNUM * a)5275 BN_copy(BIGNUM *b, const BIGNUM *a)
5276 {
5277 	if (a == NULL || b == NULL) {
5278 		return MP_VAL;
5279 	}
5280 	return mp_copy(__UNCONST(a), b);
5281 }
5282 
5283 BIGNUM *
BN_dup(const BIGNUM * a)5284 BN_dup(const BIGNUM *a)
5285 {
5286 	BIGNUM	*ret;
5287 
5288 	if (a == NULL) {
5289 		return NULL;
5290 	}
5291 	if ((ret = BN_new()) != NULL) {
5292 		BN_copy(ret, a);
5293 	}
5294 	return ret;
5295 }
5296 
5297 void
BN_swap(BIGNUM * a,BIGNUM * b)5298 BN_swap(BIGNUM *a, BIGNUM *b)
5299 {
5300 	if (a && b) {
5301 		mp_exch(a, b);
5302 	}
5303 }
5304 
5305 int
BN_lshift(BIGNUM * r,const BIGNUM * a,int n)5306 BN_lshift(BIGNUM *r, const BIGNUM *a, int n)
5307 {
5308 	if (r == NULL || a == NULL || n < 0) {
5309 		return 0;
5310 	}
5311 	BN_copy(r, a);
5312 	return lshift_digits(r, n) == MP_OKAY;
5313 }
5314 
5315 int
BN_lshift1(BIGNUM * r,BIGNUM * a)5316 BN_lshift1(BIGNUM *r, BIGNUM *a)
5317 {
5318 	if (r == NULL || a == NULL) {
5319 		return 0;
5320 	}
5321 	BN_copy(r, a);
5322 	return lshift_digits(r, 1) == MP_OKAY;
5323 }
5324 
5325 int
BN_rshift(BIGNUM * r,const BIGNUM * a,int n)5326 BN_rshift(BIGNUM *r, const BIGNUM *a, int n)
5327 {
5328 	if (r == NULL || a == NULL || n < 0) {
5329 		return MP_VAL;
5330 	}
5331 	BN_copy(r, a);
5332 	return rshift_digits(r, n) == MP_OKAY;
5333 }
5334 
5335 int
BN_rshift1(BIGNUM * r,BIGNUM * a)5336 BN_rshift1(BIGNUM *r, BIGNUM *a)
5337 {
5338 	if (r == NULL || a == NULL) {
5339 		return 0;
5340 	}
5341 	BN_copy(r, a);
5342 	return rshift_digits(r, 1) == MP_OKAY;
5343 }
5344 
5345 int
BN_set_word(BIGNUM * a,BN_ULONG w)5346 BN_set_word(BIGNUM *a, BN_ULONG w)
5347 {
5348 	if (a == NULL) {
5349 		return 0;
5350 	}
5351 	set_word(a, w);
5352 	return 1;
5353 }
5354 
5355 int
BN_add(BIGNUM * r,const BIGNUM * a,const BIGNUM * b)5356 BN_add(BIGNUM *r, const BIGNUM *a, const BIGNUM *b)
5357 {
5358 	if (a == NULL || b == NULL || r == NULL) {
5359 		return 0;
5360 	}
5361 	return signed_add(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5362 }
5363 
5364 int
BN_sub(BIGNUM * r,const BIGNUM * a,const BIGNUM * b)5365 BN_sub(BIGNUM *r, const BIGNUM *a, const BIGNUM *b)
5366 {
5367 	if (a == NULL || b == NULL || r == NULL) {
5368 		return 0;
5369 	}
5370 	return signed_subtract(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5371 }
5372 
5373 int
BN_mul(BIGNUM * r,const BIGNUM * a,const BIGNUM * b,BN_CTX * ctx)5374 BN_mul(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx)
5375 {
5376 	if (a == NULL || b == NULL || r == NULL) {
5377 		return 0;
5378 	}
5379 	USE_ARG(ctx);
5380 	return signed_multiply(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5381 }
5382 
5383 int
BN_div(BIGNUM * dv,BIGNUM * rem,const BIGNUM * a,const BIGNUM * d,BN_CTX * ctx)5384 BN_div(BIGNUM *dv, BIGNUM *rem, const BIGNUM *a, const BIGNUM *d, BN_CTX *ctx)
5385 {
5386 	if ((dv == NULL && rem == NULL) || a == NULL || d == NULL) {
5387 		return 0;
5388 	}
5389 	USE_ARG(ctx);
5390 	return signed_divide(dv, rem, __UNCONST(a), __UNCONST(d)) == MP_OKAY;
5391 }
5392 
5393 /* perform a bit operation on the 2 bignums */
5394 int
BN_bitop(BIGNUM * r,const BIGNUM * a,char op,const BIGNUM * b)5395 BN_bitop(BIGNUM *r, const BIGNUM *a, char op, const BIGNUM *b)
5396 {
5397 	unsigned	ndigits;
5398 	mp_digit	ad;
5399 	mp_digit	bd;
5400 	int		i;
5401 
5402 	if (a == NULL || b == NULL || r == NULL) {
5403 		return 0;
5404 	}
5405 	if (BN_cmp(__UNCONST(a), __UNCONST(b)) >= 0) {
5406 		BN_copy(r, a);
5407 		ndigits = a->used;
5408 	} else {
5409 		BN_copy(r, b);
5410 		ndigits = b->used;
5411 	}
5412 	for (i = 0 ; i < (int)ndigits ; i++) {
5413 		ad = (i > a->used) ? 0 : a->dp[i];
5414 		bd = (i > b->used) ? 0 : b->dp[i];
5415 		switch(op) {
5416 		case '&':
5417 			r->dp[i] = (ad & bd);
5418 			break;
5419 		case '|':
5420 			r->dp[i] = (ad | bd);
5421 			break;
5422 		case '^':
5423 			r->dp[i] = (ad ^ bd);
5424 			break;
5425 		default:
5426 			break;
5427 		}
5428 	}
5429 	return 1;
5430 }
5431 
5432 void
BN_free(BIGNUM * a)5433 BN_free(BIGNUM *a)
5434 {
5435 	if (a) {
5436 		mp_clear(a);
5437 		free(a);
5438 	}
5439 }
5440 
5441 void
BN_clear(BIGNUM * a)5442 BN_clear(BIGNUM *a)
5443 {
5444 	if (a) {
5445 		mp_clear(a);
5446 	}
5447 }
5448 
5449 void
BN_clear_free(BIGNUM * a)5450 BN_clear_free(BIGNUM *a)
5451 {
5452 	BN_clear(a);
5453 	free(a);
5454 }
5455 
5456 int
BN_num_bytes(const BIGNUM * a)5457 BN_num_bytes(const BIGNUM *a)
5458 {
5459 	if (a == NULL) {
5460 		return MP_VAL;
5461 	}
5462 	return mp_unsigned_bin_size(__UNCONST(a));
5463 }
5464 
5465 int
BN_num_bits(const BIGNUM * a)5466 BN_num_bits(const BIGNUM *a)
5467 {
5468 	if (a == NULL) {
5469 		return 0;
5470 	}
5471 	return mp_count_bits(a);
5472 }
5473 
5474 void
BN_set_negative(BIGNUM * a,int n)5475 BN_set_negative(BIGNUM *a, int n)
5476 {
5477 	if (a) {
5478 		a->sign = (n) ? MP_NEG : 0;
5479 	}
5480 }
5481 
5482 int
BN_cmp(BIGNUM * a,BIGNUM * b)5483 BN_cmp(BIGNUM *a, BIGNUM *b)
5484 {
5485 	if (a == NULL || b == NULL) {
5486 		return MP_VAL;
5487 	}
5488 	switch(signed_compare(a, b)) {
5489 	case MP_LT:
5490 		return -1;
5491 	case MP_GT:
5492 		return 1;
5493 	case MP_EQ:
5494 	default:
5495 		return 0;
5496 	}
5497 }
5498 
5499 int
BN_mod_exp(BIGNUM * Y,BIGNUM * G,BIGNUM * X,BIGNUM * P,BN_CTX * ctx)5500 BN_mod_exp(BIGNUM *Y, BIGNUM *G, BIGNUM *X, BIGNUM *P, BN_CTX *ctx)
5501 {
5502 	if (Y == NULL || G == NULL || X == NULL || P == NULL) {
5503 		return MP_VAL;
5504 	}
5505 	USE_ARG(ctx);
5506 	return exponent_modulo(G, X, P, Y) == MP_OKAY;
5507 }
5508 
5509 BIGNUM *
BN_mod_inverse(BIGNUM * r,BIGNUM * a,const BIGNUM * n,BN_CTX * ctx)5510 BN_mod_inverse(BIGNUM *r, BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
5511 {
5512 	USE_ARG(ctx);
5513 	if (r == NULL || a == NULL || n == NULL) {
5514 		return NULL;
5515 	}
5516 	return (modular_inverse(r, a, __UNCONST(n)) == MP_OKAY) ? r : NULL;
5517 }
5518 
5519 int
BN_mod_mul(BIGNUM * ret,BIGNUM * a,BIGNUM * b,const BIGNUM * m,BN_CTX * ctx)5520 BN_mod_mul(BIGNUM *ret, BIGNUM *a, BIGNUM *b, const BIGNUM *m, BN_CTX *ctx)
5521 {
5522 	USE_ARG(ctx);
5523 	if (ret == NULL || a == NULL || b == NULL || m == NULL) {
5524 		return 0;
5525 	}
5526 	return multiply_modulo(ret, a, b, __UNCONST(m)) == MP_OKAY;
5527 }
5528 
5529 BN_CTX *
BN_CTX_new(void)5530 BN_CTX_new(void)
5531 {
5532 	return allocate(1, sizeof(BN_CTX));
5533 }
5534 
5535 void
BN_CTX_init(BN_CTX * c)5536 BN_CTX_init(BN_CTX *c)
5537 {
5538 	if (c != NULL) {
5539 		c->arraysize = 15;
5540 		if ((c->v = allocate(sizeof(*c->v), c->arraysize)) == NULL) {
5541 			c->arraysize = 0;
5542 		}
5543 	}
5544 }
5545 
5546 BIGNUM *
BN_CTX_get(BN_CTX * ctx)5547 BN_CTX_get(BN_CTX *ctx)
5548 {
5549 	if (ctx == NULL || ctx->v == NULL || ctx->arraysize == 0 || ctx->count == ctx->arraysize - 1) {
5550 		return NULL;
5551 	}
5552 	return ctx->v[ctx->count++] = BN_new();
5553 }
5554 
5555 void
BN_CTX_start(BN_CTX * ctx)5556 BN_CTX_start(BN_CTX *ctx)
5557 {
5558 	BN_CTX_init(ctx);
5559 }
5560 
5561 void
BN_CTX_free(BN_CTX * c)5562 BN_CTX_free(BN_CTX *c)
5563 {
5564 	unsigned	i;
5565 
5566 	if (c != NULL && c->v != NULL) {
5567 		for (i = 0 ; i < c->count ; i++) {
5568 			BN_clear_free(c->v[i]);
5569 		}
5570 		deallocate(c->v, sizeof(*c->v) * c->arraysize);
5571 	}
5572 }
5573 
5574 void
BN_CTX_end(BN_CTX * ctx)5575 BN_CTX_end(BN_CTX *ctx)
5576 {
5577 	BN_CTX_free(ctx);
5578 }
5579 
5580 char *
BN_bn2hex(const BIGNUM * a)5581 BN_bn2hex(const BIGNUM *a)
5582 {
5583 	return (a == NULL) ? NULL : formatbn(a, 16);
5584 }
5585 
5586 char *
BN_bn2dec(const BIGNUM * a)5587 BN_bn2dec(const BIGNUM *a)
5588 {
5589 	return (a == NULL) ? NULL : formatbn(a, 10);
5590 }
5591 
5592 char *
BN_bn2radix(const BIGNUM * a,unsigned radix)5593 BN_bn2radix(const BIGNUM *a, unsigned radix)
5594 {
5595 	return (a == NULL) ? NULL : formatbn(a, (int)radix);
5596 }
5597 
5598 int
BN_print_fp(FILE * fp,const BIGNUM * a)5599 BN_print_fp(FILE *fp, const BIGNUM *a)
5600 {
5601 	char	*s;
5602 	int	 ret;
5603 
5604 	if (fp == NULL || a == NULL) {
5605 		return 0;
5606 	}
5607 	s = BN_bn2hex(a);
5608 	ret = fprintf(fp, "%s", s);
5609 	deallocate(s, strlen(s) + 1);
5610 	return ret;
5611 }
5612 
5613 #ifdef BN_RAND_NEEDED
5614 int
BN_rand(BIGNUM * rnd,int bits,int top,int bottom)5615 BN_rand(BIGNUM *rnd, int bits, int top, int bottom)
5616 {
5617 	uint64_t	r;
5618 	int		digits;
5619 	int		i;
5620 
5621 	if (rnd == NULL) {
5622 		return 0;
5623 	}
5624 	mp_init_size(rnd, digits = howmany(bits, DIGIT_BIT));
5625 	for (i = 0 ; i < digits ; i++) {
5626 		r = (uint64_t)arc4random();
5627 		r <<= 32;
5628 		r |= arc4random();
5629 		rnd->dp[i] = (r & MP_MASK);
5630 		rnd->used += 1;
5631 	}
5632 	if (top == 0) {
5633 		rnd->dp[rnd->used - 1] |= (((mp_digit)1)<<((mp_digit)DIGIT_BIT));
5634 	}
5635 	if (top == 1) {
5636 		rnd->dp[rnd->used - 1] |= (((mp_digit)1)<<((mp_digit)DIGIT_BIT));
5637 		rnd->dp[rnd->used - 1] |= (((mp_digit)1)<<((mp_digit)(DIGIT_BIT - 1)));
5638 	}
5639 	if (bottom) {
5640 		rnd->dp[0] |= 0x1;
5641 	}
5642 	return 1;
5643 }
5644 
5645 int
BN_rand_range(BIGNUM * rnd,BIGNUM * range)5646 BN_rand_range(BIGNUM *rnd, BIGNUM *range)
5647 {
5648 	if (rnd == NULL || range == NULL || BN_is_zero(range)) {
5649 		return 0;
5650 	}
5651 	BN_rand(rnd, BN_num_bits(range), 1, 0);
5652 	return modulo(rnd, range, rnd) == MP_OKAY;
5653 }
5654 #endif
5655 
5656 int
BN_is_prime(const BIGNUM * a,int checks,void (* callback)(int,int,void *),BN_CTX * ctx,void * cb_arg)5657 BN_is_prime(const BIGNUM *a, int checks, void (*callback)(int, int, void *), BN_CTX *ctx, void *cb_arg)
5658 {
5659 	int	primality;
5660 
5661 	if (a == NULL) {
5662 		return 0;
5663 	}
5664 	USE_ARG(ctx);
5665 	USE_ARG(cb_arg);
5666 	USE_ARG(callback);
5667 	return (mp_prime_is_prime(__UNCONST(a), checks, &primality) == MP_OKAY) ? primality : 0;
5668 }
5669 
5670 const BIGNUM *
BN_value_one(void)5671 BN_value_one(void)
5672 {
5673 	static mp_digit		digit = 1UL;
5674 	static const BIGNUM	one = { &digit, 1, 1, 0 };
5675 
5676 	return &one;
5677 }
5678 
5679 int
BN_hex2bn(BIGNUM ** a,const char * str)5680 BN_hex2bn(BIGNUM **a, const char *str)
5681 {
5682 	return getbn(a, str, 16);
5683 }
5684 
5685 int
BN_dec2bn(BIGNUM ** a,const char * str)5686 BN_dec2bn(BIGNUM **a, const char *str)
5687 {
5688 	return getbn(a, str, 10);
5689 }
5690 
5691 int
BN_radix2bn(BIGNUM ** a,const char * str,unsigned radix)5692 BN_radix2bn(BIGNUM **a, const char *str, unsigned radix)
5693 {
5694 	return getbn(a, str, (int)radix);
5695 }
5696 
5697 int
BN_mod_sub(BIGNUM * r,BIGNUM * a,BIGNUM * b,const BIGNUM * m,BN_CTX * ctx)5698 BN_mod_sub(BIGNUM *r, BIGNUM *a, BIGNUM *b, const BIGNUM *m, BN_CTX *ctx)
5699 {
5700 	USE_ARG(ctx);
5701 	if (r == NULL || a == NULL || b == NULL || m == NULL) {
5702 		return 0;
5703 	}
5704 	return subtract_modulo(a, b, __UNCONST(m), r) == MP_OKAY;
5705 }
5706 
5707 int
BN_is_bit_set(const BIGNUM * a,int n)5708 BN_is_bit_set(const BIGNUM *a, int n)
5709 {
5710 	if (a == NULL || n < 0 || n >= a->used * DIGIT_BIT) {
5711 		return 0;
5712 	}
5713 	return (a->dp[n / DIGIT_BIT] & (1 << (n % DIGIT_BIT))) ? 1 : 0;
5714 }
5715 
5716 /* raise 'a' to power of 'b' */
5717 int
BN_raise(BIGNUM * res,BIGNUM * a,BIGNUM * b)5718 BN_raise(BIGNUM *res, BIGNUM *a, BIGNUM *b)
5719 {
5720 	uint64_t	 exponent;
5721 	BIGNUM		*power;
5722 	BIGNUM		*temp;
5723 	char		*t;
5724 
5725 	t = BN_bn2dec(b);
5726 	exponent = (uint64_t)strtoull(t, NULL, 10);
5727 	free(t);
5728 	if (exponent == 0) {
5729 		BN_copy(res, BN_value_one());
5730 	} else {
5731 		power = BN_dup(a);
5732 		for ( ; (exponent & 1) == 0 ; exponent >>= 1) {
5733 			BN_mul(power, power, power, NULL);
5734 		}
5735 		temp = BN_dup(power);
5736 		for (exponent >>= 1 ; exponent > 0 ; exponent >>= 1) {
5737 			BN_mul(power, power, power, NULL);
5738 			if (exponent & 1) {
5739 				BN_mul(temp, power, temp, NULL);
5740 			}
5741 		}
5742 		BN_copy(res, temp);
5743 		BN_free(power);
5744 		BN_free(temp);
5745 	}
5746 	return 1;
5747 }
5748 
5749 /* compute the factorial */
5750 int
BN_factorial(BIGNUM * res,BIGNUM * f)5751 BN_factorial(BIGNUM *res, BIGNUM *f)
5752 {
5753 	BIGNUM	*one;
5754 	BIGNUM	*i;
5755 
5756 	i = BN_dup(f);
5757 	one = __UNCONST(BN_value_one());
5758 	BN_sub(i, i, one);
5759 	BN_copy(res, f);
5760 	while (BN_cmp(i, one) > 0) {
5761 		BN_mul(res, res, i, NULL);
5762 		BN_sub(i, i, one);
5763 	}
5764 	BN_free(i);
5765 	return 1;
5766 }
5767 
5768 /* get greatest common divisor */
5769 int
BN_gcd(BIGNUM * r,BIGNUM * a,BIGNUM * b,BN_CTX * ctx)5770 BN_gcd(BIGNUM *r, BIGNUM *a, BIGNUM *b, BN_CTX *ctx)
5771 {
5772 	return mp_gcd(a, b, r);
5773 }
5774