xref: /minix3/crypto/external/bsd/netpgp/dist/src/netpgpverify/bignum.c (revision e1cdaee10649323af446eb1a74571984b2ab3181)
1 /*-
2  * Copyright (c) 2012 Alistair Crooks <agc@NetBSD.org>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 /* LibTomMath, multiple-precision integer library -- Tom St Denis
26  *
27  * LibTomMath is a library that provides multiple-precision
28  * integer arithmetic as well as number theoretic functionality.
29  *
30  * The library was designed directly after the MPI library by
31  * Michael Fromberger but has been written from scratch with
32  * additional optimizations in place.
33  *
34  * The library is free for all purposes without any express
35  * guarantee it works.
36  *
37  * Tom St Denis, tomstdenis@gmail.com, http://libtom.org
38  */
39 #include "config.h"
40 
41 #include <sys/types.h>
42 #include <sys/param.h>
43 
44 #ifdef _KERNEL
45 # include <sys/kmem.h>
46 #else
47 # include <arpa/inet.h>
48 # include <stdarg.h>
49 # include <stdio.h>
50 # include <stdlib.h>
51 # include <string.h>
52 # include <unistd.h>
53 #endif
54 
55 #include "bn.h"
56 
57 /**************************************************************************/
58 
59 /* LibTomMath, multiple-precision integer library -- Tom St Denis
60  *
61  * LibTomMath is a library that provides multiple-precision
62  * integer arithmetic as well as number theoretic functionality.
63  *
64  * The library was designed directly after the MPI library by
65  * Michael Fromberger but has been written from scratch with
66  * additional optimizations in place.
67  *
68  * The library is free for all purposes without any express
69  * guarantee it works.
70  *
71  * Tom St Denis, tomstdenis@gmail.com, http://libtom.org
72  */
73 
74 #define MP_PREC		32
75 #define DIGIT_BIT	28
76 #define MP_MASK          ((((mp_digit)1)<<((mp_digit)DIGIT_BIT))-((mp_digit)1))
77 
78 #define MP_WARRAY	/*LINTED*/(1U << (((sizeof(mp_word) * CHAR_BIT) - (2 * DIGIT_BIT) + 1)))
79 
80 #define MP_NO		0
81 #define MP_YES		1
82 
83 #ifndef USE_ARG
84 #define USE_ARG(x)	/*LINTED*/(void)&(x)
85 #endif
86 
87 #ifndef __arraycount
88 #define	__arraycount(__x)	(sizeof(__x) / sizeof(__x[0]))
89 #endif
90 
91 #define MP_ISZERO(a) (((a)->used == 0) ? MP_YES : MP_NO)
92 
93 typedef int           mp_err;
94 
95 static int signed_multiply(mp_int * a, mp_int * b, mp_int * c);
96 static int square(mp_int * a, mp_int * b);
97 
98 static int signed_subtract_word(mp_int *a, mp_digit b, mp_int *c);
99 
/*
 * allocate: obtain zero-filled storage for n objects of m bytes each.
 * Returns NULL on failure, exactly as calloc does.
 */
static inline void *
allocate(size_t n, size_t m)
{
	void	*mem;

	mem = calloc(n, m);
	return mem;
}
105 
/*
 * deallocate: release memory (presumably obtained from allocate()).
 * The size argument is accepted but not used by this implementation.
 */
static inline void
deallocate(void *v, size_t sz)
{
	free(v);
	USE_ARG(sz);
}
112 
113 /* set to zero */
114 static inline void
mp_zero(mp_int * a)115 mp_zero(mp_int *a)
116 {
117 	a->sign = MP_ZPOS;
118 	a->used = 0;
119 	memset(a->dp, 0x0, a->alloc * sizeof(*a->dp));
120 }
121 
122 /* grow as required */
123 static int
mp_grow(mp_int * a,int size)124 mp_grow(mp_int *a, int size)
125 {
126 	mp_digit *tmp;
127 
128 	/* if the alloc size is smaller alloc more ram */
129 	if (a->alloc < size) {
130 		/* ensure there are always at least MP_PREC digits extra on top */
131 		size += (MP_PREC * 2) - (size % MP_PREC);
132 
133 		/* reallocate the array a->dp
134 		*
135 		* We store the return in a temporary variable
136 		* in case the operation failed we don't want
137 		* to overwrite the dp member of a.
138 		*/
139 		tmp = realloc(a->dp, sizeof(*tmp) * size);
140 		if (tmp == NULL) {
141 			/* reallocation failed but "a" is still valid [can be freed] */
142 			return MP_MEM;
143 		}
144 
145 		/* reallocation succeeded so set a->dp */
146 		a->dp = tmp;
147 		/* zero excess digits */
148 		memset(&a->dp[a->alloc], 0x0, (size - a->alloc) * sizeof(*a->dp));
149 		a->alloc = size;
150 	}
151 	return MP_OKAY;
152 }
153 
154 /* shift left a certain amount of digits */
155 static int
lshift_digits(mp_int * a,int b)156 lshift_digits(mp_int * a, int b)
157 {
158 	mp_digit *top, *bottom;
159 	int     x, res;
160 
161 	/* if its less than zero return */
162 	if (b <= 0) {
163 		return MP_OKAY;
164 	}
165 
166 	/* grow to fit the new digits */
167 	if (a->alloc < a->used + b) {
168 		if ((res = mp_grow(a, a->used + b)) != MP_OKAY) {
169 			return res;
170 		}
171 	}
172 
173 	/* increment the used by the shift amount then copy upwards */
174 	a->used += b;
175 
176 	/* top */
177 	top = a->dp + a->used - 1;
178 
179 	/* base */
180 	bottom = a->dp + a->used - 1 - b;
181 
182 	/* much like rshift_digits this is implemented using a sliding window
183 	* except the window goes the otherway around.  Copying from
184 	* the bottom to the top.
185 	*/
186 	for (x = a->used - 1; x >= b; x--) {
187 		*top-- = *bottom--;
188 	}
189 
190 	/* zero the lower digits */
191 	memset(a->dp, 0x0, b * sizeof(*a->dp));
192 	return MP_OKAY;
193 }
194 
195 /* trim unused digits
196  *
197  * This is used to ensure that leading zero digits are
198  * trimed and the leading "used" digit will be non-zero
199  * Typically very fast.  Also fixes the sign if there
200  * are no more leading digits
201  */
202 static void
trim_unused_digits(mp_int * a)203 trim_unused_digits(mp_int * a)
204 {
205 	/* decrease used while the most significant digit is
206 	* zero.
207 	*/
208 	while (a->used > 0 && a->dp[a->used - 1] == 0) {
209 		a->used -= 1;
210 	}
211 	/* reset the sign flag if used == 0 */
212 	if (a->used == 0) {
213 		a->sign = MP_ZPOS;
214 	}
215 }
216 
217 /* copy, b = a */
218 static int
mp_copy(BIGNUM * a,BIGNUM * b)219 mp_copy(BIGNUM *a, BIGNUM *b)
220 {
221 	int	res;
222 
223 	/* if dst == src do nothing */
224 	if (a == b) {
225 		return MP_OKAY;
226 	}
227 	if (a == NULL || b == NULL) {
228 		return MP_VAL;
229 	}
230 
231 	/* grow dest */
232 	if (b->alloc < a->used) {
233 		if ((res = mp_grow(b, a->used)) != MP_OKAY) {
234 			return res;
235 		}
236 	}
237 
238 	memcpy(b->dp, a->dp, a->used * sizeof(*b->dp));
239 	if (b->used > a->used) {
240 		memset(&b->dp[a->used], 0x0, (b->used - a->used) * sizeof(*b->dp));
241 	}
242 
243 	/* copy used count and sign */
244 	b->used = a->used;
245 	b->sign = a->sign;
246 	return MP_OKAY;
247 }
248 
249 /* shift left by a certain bit count */
250 static int
lshift_bits(mp_int * a,int b,mp_int * c)251 lshift_bits(mp_int *a, int b, mp_int *c)
252 {
253 	mp_digit d;
254 	int      res;
255 
256 	/* copy */
257 	if (a != c) {
258 		if ((res = mp_copy(a, c)) != MP_OKAY) {
259 			return res;
260 		}
261 	}
262 
263 	if (c->alloc < (int)(c->used + b/DIGIT_BIT + 1)) {
264 		if ((res = mp_grow(c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
265 			return res;
266 		}
267 	}
268 
269 	/* shift by as many digits in the bit count */
270 	if (b >= (int)DIGIT_BIT) {
271 		if ((res = lshift_digits(c, b / DIGIT_BIT)) != MP_OKAY) {
272 			return res;
273 		}
274 	}
275 
276 	/* shift any bit count < DIGIT_BIT */
277 	d = (mp_digit) (b % DIGIT_BIT);
278 	if (d != 0) {
279 		mp_digit *tmpc, shift, mask, carry, rr;
280 		int x;
281 
282 		/* bitmask for carries */
283 		mask = (((mp_digit)1) << d) - 1;
284 
285 		/* shift for msbs */
286 		shift = DIGIT_BIT - d;
287 
288 		/* alias */
289 		tmpc = c->dp;
290 
291 		/* carry */
292 		carry = 0;
293 		for (x = 0; x < c->used; x++) {
294 			/* get the higher bits of the current word */
295 			rr = (*tmpc >> shift) & mask;
296 
297 			/* shift the current word and OR in the carry */
298 			*tmpc = ((*tmpc << d) | carry) & MP_MASK;
299 			++tmpc;
300 
301 			/* set the carry to the carry bits of the current word */
302 			carry = rr;
303 		}
304 
305 		/* set final carry */
306 		if (carry != 0) {
307 			c->dp[c->used++] = carry;
308 		}
309 	}
310 	trim_unused_digits(c);
311 	return MP_OKAY;
312 }
313 
314 /* reads a unsigned char array, assumes the msb is stored first [big endian] */
315 static int
mp_read_unsigned_bin(mp_int * a,const uint8_t * b,int c)316 mp_read_unsigned_bin(mp_int *a, const uint8_t *b, int c)
317 {
318 	int     res;
319 
320 	/* make sure there are at least two digits */
321 	if (a->alloc < 2) {
322 		if ((res = mp_grow(a, 2)) != MP_OKAY) {
323 			return res;
324 		}
325 	}
326 
327 	/* zero the int */
328 	mp_zero(a);
329 
330 	/* read the bytes in */
331 	while (c-- > 0) {
332 		if ((res = lshift_bits(a, 8, a)) != MP_OKAY) {
333 			return res;
334 		}
335 
336 		a->dp[0] |= *b++;
337 		a->used += 1;
338 	}
339 	trim_unused_digits(a);
340 	return MP_OKAY;
341 }
342 
343 /* returns the number of bits in an mpi */
344 static int
mp_count_bits(const mp_int * a)345 mp_count_bits(const mp_int *a)
346 {
347 	int     r;
348 	mp_digit q;
349 
350 	/* shortcut */
351 	if (a->used == 0) {
352 		return 0;
353 	}
354 
355 	/* get number of digits and add that */
356 	r = (a->used - 1) * DIGIT_BIT;
357 
358 	/* take the last digit and count the bits in it */
359 	for (q = a->dp[a->used - 1]; q > ((mp_digit) 0) ; r++) {
360 		q >>= ((mp_digit) 1);
361 	}
362 	return r;
363 }
364 
365 /* compare maginitude of two ints (unsigned) */
366 static int
compare_magnitude(mp_int * a,mp_int * b)367 compare_magnitude(mp_int * a, mp_int * b)
368 {
369 	int     n;
370 	mp_digit *tmpa, *tmpb;
371 
372 	/* compare based on # of non-zero digits */
373 	if (a->used > b->used) {
374 		return MP_GT;
375 	}
376 
377 	if (a->used < b->used) {
378 		return MP_LT;
379 	}
380 
381 	/* alias for a */
382 	tmpa = a->dp + (a->used - 1);
383 
384 	/* alias for b */
385 	tmpb = b->dp + (a->used - 1);
386 
387 	/* compare based on digits  */
388 	for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
389 		if (*tmpa > *tmpb) {
390 			return MP_GT;
391 		}
392 
393 		if (*tmpa < *tmpb) {
394 			return MP_LT;
395 		}
396 	}
397 	return MP_EQ;
398 }
399 
400 /* compare two ints (signed)*/
401 static int
signed_compare(mp_int * a,mp_int * b)402 signed_compare(mp_int * a, mp_int * b)
403 {
404 	/* compare based on sign */
405 	if (a->sign != b->sign) {
406 		return (a->sign == MP_NEG) ? MP_LT : MP_GT;
407 	}
408 	return (a->sign == MP_NEG) ? compare_magnitude(b, a) : compare_magnitude(a, b);
409 }
410 
411 /* get the size for an unsigned equivalent */
412 static int
mp_unsigned_bin_size(mp_int * a)413 mp_unsigned_bin_size(mp_int * a)
414 {
415 	int     size = mp_count_bits(a);
416 
417 	return (size / 8 + ((size & 7) != 0 ? 1 : 0));
418 }
419 
420 /* init a new mp_int */
421 static int
mp_init(mp_int * a)422 mp_init(mp_int * a)
423 {
424 	/* allocate memory required and clear it */
425 	a->dp = allocate(1, sizeof(*a->dp) * MP_PREC);
426 	if (a->dp == NULL) {
427 		return MP_MEM;
428 	}
429 
430 	/* set the digits to zero */
431 	memset(a->dp, 0x0, MP_PREC * sizeof(*a->dp));
432 
433 	/* set the used to zero, allocated digits to the default precision
434 	* and sign to positive */
435 	a->used  = 0;
436 	a->alloc = MP_PREC;
437 	a->sign  = MP_ZPOS;
438 
439 	return MP_OKAY;
440 }
441 
442 /* clear one (frees)  */
443 static void
mp_clear(mp_int * a)444 mp_clear(mp_int * a)
445 {
446 	/* only do anything if a hasn't been freed previously */
447 	if (a->dp != NULL) {
448 		memset(a->dp, 0x0, a->used * sizeof(*a->dp));
449 
450 		/* free ram */
451 		deallocate(a->dp, (size_t)a->alloc);
452 
453 		/* reset members to make debugging easier */
454 		a->dp = NULL;
455 		a->alloc = a->used = 0;
456 		a->sign  = MP_ZPOS;
457 	}
458 }
459 
460 static int
mp_init_multi(mp_int * mp,...)461 mp_init_multi(mp_int *mp, ...)
462 {
463 	mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
464 	int n = 0;                 /* Number of ok inits */
465 	mp_int* cur_arg = mp;
466 	va_list args;
467 
468 	va_start(args, mp);        /* init args to next argument from caller */
469 	while (cur_arg != NULL) {
470 		if (mp_init(cur_arg) != MP_OKAY) {
471 			/* Oops - error! Back-track and mp_clear what we already
472 			succeeded in init-ing, then return error.
473 			*/
474 			va_list clean_args;
475 
476 			/* end the current list */
477 			va_end(args);
478 
479 			/* now start cleaning up */
480 			cur_arg = mp;
481 			va_start(clean_args, mp);
482 			while (n--) {
483 				mp_clear(cur_arg);
484 				cur_arg = va_arg(clean_args, mp_int*);
485 			}
486 			va_end(clean_args);
487 			res = MP_MEM;
488 			break;
489 		}
490 		n++;
491 		cur_arg = va_arg(args, mp_int*);
492 	}
493 	va_end(args);
494 	return res;                /* Assumed ok, if error flagged above. */
495 }
496 
497 /* init an mp_init for a given size */
498 static int
mp_init_size(mp_int * a,int size)499 mp_init_size(mp_int * a, int size)
500 {
501 	/* pad size so there are always extra digits */
502 	size += (MP_PREC * 2) - (size % MP_PREC);
503 
504 	/* alloc mem */
505 	a->dp = allocate(1, sizeof(*a->dp) * size);
506 	if (a->dp == NULL) {
507 		return MP_MEM;
508 	}
509 
510 	/* set the members */
511 	a->used  = 0;
512 	a->alloc = size;
513 	a->sign  = MP_ZPOS;
514 
515 	/* zero the digits */
516 	memset(a->dp, 0x0, size * sizeof(*a->dp));
517 	return MP_OKAY;
518 }
519 
520 /* creates "a" then copies b into it */
521 static int
mp_init_copy(mp_int * a,mp_int * b)522 mp_init_copy(mp_int * a, mp_int * b)
523 {
524 	int     res;
525 
526 	if ((res = mp_init(a)) != MP_OKAY) {
527 		return res;
528 	}
529 	return mp_copy(b, a);
530 }
531 
532 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
533 static int
basic_add(mp_int * a,mp_int * b,mp_int * c)534 basic_add(mp_int * a, mp_int * b, mp_int * c)
535 {
536 	mp_int *x;
537 	int     olduse, res, min, max;
538 
539 	/* find sizes, we let |a| <= |b| which means we have to sort
540 	* them.  "x" will point to the input with the most digits
541 	*/
542 	if (a->used > b->used) {
543 		min = b->used;
544 		max = a->used;
545 		x = a;
546 	} else {
547 		min = a->used;
548 		max = b->used;
549 		x = b;
550 	}
551 
552 	/* init result */
553 	if (c->alloc < max + 1) {
554 		if ((res = mp_grow(c, max + 1)) != MP_OKAY) {
555 			return res;
556 		}
557 	}
558 
559 	/* get old used digit count and set new one */
560 	olduse = c->used;
561 	c->used = max + 1;
562 
563 	{
564 		mp_digit carry, *tmpa, *tmpb, *tmpc;
565 		int i;
566 
567 		/* alias for digit pointers */
568 
569 		/* first input */
570 		tmpa = a->dp;
571 
572 		/* second input */
573 		tmpb = b->dp;
574 
575 		/* destination */
576 		tmpc = c->dp;
577 
578 		/* zero the carry */
579 		carry = 0;
580 		for (i = 0; i < min; i++) {
581 			/* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
582 			*tmpc = *tmpa++ + *tmpb++ + carry;
583 
584 			/* U = carry bit of T[i] */
585 			carry = *tmpc >> ((mp_digit)DIGIT_BIT);
586 
587 			/* take away carry bit from T[i] */
588 			*tmpc++ &= MP_MASK;
589 		}
590 
591 		/* now copy higher words if any, that is in A+B
592 		* if A or B has more digits add those in
593 		*/
594 		if (min != max) {
595 			for (; i < max; i++) {
596 				/* T[i] = X[i] + U */
597 				*tmpc = x->dp[i] + carry;
598 
599 				/* U = carry bit of T[i] */
600 				carry = *tmpc >> ((mp_digit)DIGIT_BIT);
601 
602 				/* take away carry bit from T[i] */
603 				*tmpc++ &= MP_MASK;
604 			}
605 		}
606 
607 		/* add carry */
608 		*tmpc++ = carry;
609 
610 		/* clear digits above oldused */
611 		if (olduse > c->used) {
612 			memset(tmpc, 0x0, (olduse - c->used) * sizeof(*c->dp));
613 		}
614 	}
615 
616 	trim_unused_digits(c);
617 	return MP_OKAY;
618 }
619 
620 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
621 static int
basic_subtract(mp_int * a,mp_int * b,mp_int * c)622 basic_subtract(mp_int * a, mp_int * b, mp_int * c)
623 {
624 	int     olduse, res, min, max;
625 
626 	/* find sizes */
627 	min = b->used;
628 	max = a->used;
629 
630 	/* init result */
631 	if (c->alloc < max) {
632 		if ((res = mp_grow(c, max)) != MP_OKAY) {
633 			return res;
634 		}
635 	}
636 	olduse = c->used;
637 	c->used = max;
638 
639 	{
640 		mp_digit carry, *tmpa, *tmpb, *tmpc;
641 		int i;
642 
643 		/* alias for digit pointers */
644 		tmpa = a->dp;
645 		tmpb = b->dp;
646 		tmpc = c->dp;
647 
648 		/* set carry to zero */
649 		carry = 0;
650 		for (i = 0; i < min; i++) {
651 			/* T[i] = A[i] - B[i] - U */
652 			*tmpc = *tmpa++ - *tmpb++ - carry;
653 
654 			/* U = carry bit of T[i]
655 			* Note this saves performing an AND operation since
656 			* if a carry does occur it will propagate all the way to the
657 			* MSB.  As a result a single shift is enough to get the carry
658 			*/
659 			carry = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof(mp_digit) - 1));
660 
661 			/* Clear carry from T[i] */
662 			*tmpc++ &= MP_MASK;
663 		}
664 
665 		/* now copy higher words if any, e.g. if A has more digits than B  */
666 		for (; i < max; i++) {
667 			/* T[i] = A[i] - U */
668 			*tmpc = *tmpa++ - carry;
669 
670 			/* U = carry bit of T[i] */
671 			carry = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof(mp_digit) - 1));
672 
673 			/* Clear carry from T[i] */
674 			*tmpc++ &= MP_MASK;
675 		}
676 
677 		/* clear digits above used (since we may not have grown result above) */
678 		if (olduse > c->used) {
679 			memset(tmpc, 0x0, (olduse - c->used) * sizeof(*a->dp));
680 		}
681 	}
682 
683 	trim_unused_digits(c);
684 	return MP_OKAY;
685 }
686 
687 /* high level subtraction (handles signs) */
688 static int
signed_subtract(mp_int * a,mp_int * b,mp_int * c)689 signed_subtract(mp_int * a, mp_int * b, mp_int * c)
690 {
691 	int     sa, sb, res;
692 
693 	sa = a->sign;
694 	sb = b->sign;
695 
696 	if (sa != sb) {
697 		/* subtract a negative from a positive, OR */
698 		/* subtract a positive from a negative. */
699 		/* In either case, ADD their magnitudes, */
700 		/* and use the sign of the first number. */
701 		c->sign = sa;
702 		res = basic_add(a, b, c);
703 	} else {
704 		/* subtract a positive from a positive, OR */
705 		/* subtract a negative from a negative. */
706 		/* First, take the difference between their */
707 		/* magnitudes, then... */
708 		if (compare_magnitude(a, b) != MP_LT) {
709 			/* Copy the sign from the first */
710 			c->sign = sa;
711 			/* The first has a larger or equal magnitude */
712 			res = basic_subtract(a, b, c);
713 		} else {
714 			/* The result has the *opposite* sign from */
715 			/* the first number. */
716 			c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
717 			/* The second has a larger magnitude */
718 			res = basic_subtract(b, a, c);
719 		}
720 	}
721 	return res;
722 }
723 
724 /* shift right a certain amount of digits */
725 static int
rshift_digits(mp_int * a,int b)726 rshift_digits(mp_int * a, int b)
727 {
728 	/* if b <= 0 then ignore it */
729 	if (b <= 0) {
730 		return 0;
731 	}
732 
733 	/* if b > used then simply zero it and return */
734 	if (a->used <= b) {
735 		mp_zero(a);
736 		return 0;
737 	}
738 
739 	/* this is implemented as a sliding window where
740 	* the window is b-digits long and digits from
741 	* the top of the window are copied to the bottom
742 	*
743 	* e.g.
744 
745 	b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
746 		 /\                   |      ---->
747 		  \-------------------/      ---->
748 	*/
749 	memmove(a->dp, &a->dp[b], (a->used - b) * sizeof(*a->dp));
750 	memset(&a->dp[a->used - b], 0x0, b * sizeof(*a->dp));
751 
752 	/* remove excess digits */
753 	a->used -= b;
754 	return 1;
755 }
756 
757 /* multiply by a digit */
758 static int
multiply_digit(mp_int * a,mp_digit b,mp_int * c)759 multiply_digit(mp_int * a, mp_digit b, mp_int * c)
760 {
761 	mp_digit carry, *tmpa, *tmpc;
762 	mp_word  r;
763 	int      ix, res, olduse;
764 
765 	/* make sure c is big enough to hold a*b */
766 	if (c->alloc < a->used + 1) {
767 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
768 			return res;
769 		}
770 	}
771 
772 	/* get the original destinations used count */
773 	olduse = c->used;
774 
775 	/* set the sign */
776 	c->sign = a->sign;
777 
778 	/* alias for a->dp [source] */
779 	tmpa = a->dp;
780 
781 	/* alias for c->dp [dest] */
782 	tmpc = c->dp;
783 
784 	/* zero carry */
785 	carry = 0;
786 
787 	/* compute columns */
788 	for (ix = 0; ix < a->used; ix++) {
789 		/* compute product and carry sum for this term */
790 		r = ((mp_word) carry) + ((mp_word)*tmpa++) * ((mp_word)b);
791 
792 		/* mask off higher bits to get a single digit */
793 		*tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
794 
795 		/* send carry into next iteration */
796 		carry = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
797 	}
798 
799 	/* store final carry [if any] and increment ix offset  */
800 	*tmpc++ = carry;
801 	++ix;
802 	if (olduse > ix) {
803 		memset(tmpc, 0x0, (olduse - ix) * sizeof(*tmpc));
804 	}
805 
806 	/* set used count */
807 	c->used = a->used + 1;
808 	trim_unused_digits(c);
809 
810 	return MP_OKAY;
811 }
812 
813 /* high level addition (handles signs) */
814 static int
signed_add(mp_int * a,mp_int * b,mp_int * c)815 signed_add(mp_int * a, mp_int * b, mp_int * c)
816 {
817 	int     asign, bsign, res;
818 
819 	/* get sign of both inputs */
820 	asign = a->sign;
821 	bsign = b->sign;
822 
823 	/* handle two cases, not four */
824 	if (asign == bsign) {
825 		/* both positive or both negative */
826 		/* add their magnitudes, copy the sign */
827 		c->sign = asign;
828 		res = basic_add(a, b, c);
829 	} else {
830 		/* one positive, the other negative */
831 		/* subtract the one with the greater magnitude from */
832 		/* the one of the lesser magnitude.  The result gets */
833 		/* the sign of the one with the greater magnitude. */
834 		if (compare_magnitude(a, b) == MP_LT) {
835 			c->sign = bsign;
836 			res = basic_subtract(b, a, c);
837 		} else {
838 			c->sign = asign;
839 			res = basic_subtract(a, b, c);
840 		}
841 	}
842 	return res;
843 }
844 
845 /* swap the elements of two integers, for cases where you can't simply swap the
846  * mp_int pointers around
847  */
848 static void
mp_exch(mp_int * a,mp_int * b)849 mp_exch(mp_int *a, mp_int *b)
850 {
851 	mp_int  t;
852 
853 	t  = *a;
854 	*a = *b;
855 	*b = t;
856 }
857 
858 /* calc a value mod 2**b */
859 static int
modulo_2_to_power(mp_int * a,int b,mp_int * c)860 modulo_2_to_power(mp_int * a, int b, mp_int * c)
861 {
862 	int     x, res;
863 
864 	/* if b is <= 0 then zero the int */
865 	if (b <= 0) {
866 		mp_zero(c);
867 		return MP_OKAY;
868 	}
869 
870 	/* if the modulus is larger than the value than return */
871 	if (b >= (int) (a->used * DIGIT_BIT)) {
872 		res = mp_copy(a, c);
873 		return res;
874 	}
875 
876 	/* copy */
877 	if ((res = mp_copy(a, c)) != MP_OKAY) {
878 		return res;
879 	}
880 
881 	/* zero digits above the last digit of the modulus */
882 	for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
883 		c->dp[x] = 0;
884 	}
885 	/* clear the digit that is not completely outside/inside the modulus */
886 	c->dp[b / DIGIT_BIT] &=
887 		(mp_digit) ((((mp_digit) 1) << (((mp_digit) b) % DIGIT_BIT)) - ((mp_digit) 1));
888 	trim_unused_digits(c);
889 	return MP_OKAY;
890 }
891 
892 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
893 static int
rshift_bits(mp_int * a,int b,mp_int * c,mp_int * d)894 rshift_bits(mp_int * a, int b, mp_int * c, mp_int * d)
895 {
896 	mp_digit D, r, rr;
897 	int     x, res;
898 	mp_int  t;
899 
900 
901 	/* if the shift count is <= 0 then we do no work */
902 	if (b <= 0) {
903 		res = mp_copy(a, c);
904 		if (d != NULL) {
905 			mp_zero(d);
906 		}
907 		return res;
908 	}
909 
910 	if ((res = mp_init(&t)) != MP_OKAY) {
911 		return res;
912 	}
913 
914 	/* get the remainder */
915 	if (d != NULL) {
916 		if ((res = modulo_2_to_power(a, b, &t)) != MP_OKAY) {
917 			mp_clear(&t);
918 			return res;
919 		}
920 	}
921 
922 	/* copy */
923 	if ((res = mp_copy(a, c)) != MP_OKAY) {
924 		mp_clear(&t);
925 		return res;
926 	}
927 
928 	/* shift by as many digits in the bit count */
929 	if (b >= (int)DIGIT_BIT) {
930 		rshift_digits(c, b / DIGIT_BIT);
931 	}
932 
933 	/* shift any bit count < DIGIT_BIT */
934 	D = (mp_digit) (b % DIGIT_BIT);
935 	if (D != 0) {
936 		mp_digit *tmpc, mask, shift;
937 
938 		/* mask */
939 		mask = (((mp_digit)1) << D) - 1;
940 
941 		/* shift for lsb */
942 		shift = DIGIT_BIT - D;
943 
944 		/* alias */
945 		tmpc = c->dp + (c->used - 1);
946 
947 		/* carry */
948 		r = 0;
949 		for (x = c->used - 1; x >= 0; x--) {
950 			/* get the lower  bits of this word in a temp */
951 			rr = *tmpc & mask;
952 
953 			/* shift the current word and mix in the carry bits from the previous word */
954 			*tmpc = (*tmpc >> D) | (r << shift);
955 			--tmpc;
956 
957 			/* set the carry to the carry bits of the current word found above */
958 			r = rr;
959 		}
960 	}
961 	trim_unused_digits(c);
962 	if (d != NULL) {
963 		mp_exch(&t, d);
964 	}
965 	mp_clear(&t);
966 	return MP_OKAY;
967 }
968 
969 /* integer signed division.
970  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
971  * HAC pp.598 Algorithm 14.20
972  *
973  * Note that the description in HAC is horribly
974  * incomplete.  For example, it doesn't consider
975  * the case where digits are removed from 'x' in
976  * the inner loop.  It also doesn't consider the
977  * case that y has fewer than three digits, etc..
978  *
979  * The overall algorithm is as described as
980  * 14.20 from HAC but fixed to treat these cases.
981 */
982 static int
signed_divide(mp_int * c,mp_int * d,mp_int * a,mp_int * b)983 signed_divide(mp_int *c, mp_int *d, mp_int *a, mp_int *b)
984 {
985 	mp_int  q, x, y, t1, t2;
986 	int     res, n, t, i, norm, neg;
987 
988 	/* is divisor zero ? */
989 	if (MP_ISZERO(b) == MP_YES) {
990 		return MP_VAL;
991 	}
992 
993 	/* if a < b then q=0, r = a */
994 	if (compare_magnitude(a, b) == MP_LT) {
995 		if (d != NULL) {
996 			res = mp_copy(a, d);
997 		} else {
998 			res = MP_OKAY;
999 		}
1000 		if (c != NULL) {
1001 			mp_zero(c);
1002 		}
1003 		return res;
1004 	}
1005 
1006 	if ((res = mp_init_size(&q, a->used + 2)) != MP_OKAY) {
1007 		return res;
1008 	}
1009 	q.used = a->used + 2;
1010 
1011 	if ((res = mp_init(&t1)) != MP_OKAY) {
1012 		goto LBL_Q;
1013 	}
1014 
1015 	if ((res = mp_init(&t2)) != MP_OKAY) {
1016 		goto LBL_T1;
1017 	}
1018 
1019 	if ((res = mp_init_copy(&x, a)) != MP_OKAY) {
1020 		goto LBL_T2;
1021 	}
1022 
1023 	if ((res = mp_init_copy(&y, b)) != MP_OKAY) {
1024 		goto LBL_X;
1025 	}
1026 
1027 	/* fix the sign */
1028 	neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1029 	x.sign = y.sign = MP_ZPOS;
1030 
1031 	/* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1032 	norm = mp_count_bits(&y) % DIGIT_BIT;
1033 	if (norm < (int)(DIGIT_BIT-1)) {
1034 		norm = (DIGIT_BIT-1) - norm;
1035 		if ((res = lshift_bits(&x, norm, &x)) != MP_OKAY) {
1036 			goto LBL_Y;
1037 		}
1038 		if ((res = lshift_bits(&y, norm, &y)) != MP_OKAY) {
1039 			goto LBL_Y;
1040 		}
1041 	} else {
1042 		norm = 0;
1043 	}
1044 
1045 	/* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1046 	n = x.used - 1;
1047 	t = y.used - 1;
1048 
1049 	/* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1050 	if ((res = lshift_digits(&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1051 		goto LBL_Y;
1052 	}
1053 
1054 	while (signed_compare(&x, &y) != MP_LT) {
1055 		++(q.dp[n - t]);
1056 		if ((res = signed_subtract(&x, &y, &x)) != MP_OKAY) {
1057 			goto LBL_Y;
1058 		}
1059 	}
1060 
1061 	/* reset y by shifting it back down */
1062 	rshift_digits(&y, n - t);
1063 
1064 	/* step 3. for i from n down to (t + 1) */
1065 	for (i = n; i >= (t + 1); i--) {
1066 		if (i > x.used) {
1067 			continue;
1068 		}
1069 
1070 		/* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1071 		* otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1072 		if (x.dp[i] == y.dp[t]) {
1073 			q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1074 		} else {
1075 			mp_word tmp;
1076 			tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1077 			tmp |= ((mp_word) x.dp[i - 1]);
1078 			tmp /= ((mp_word) y.dp[t]);
1079 			if (tmp > (mp_word) MP_MASK) {
1080 				tmp = MP_MASK;
1081 			}
1082 			q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1083 		}
1084 
1085 		/* while (q{i-t-1} * (yt * b + y{t-1})) >
1086 		     xi * b**2 + xi-1 * b + xi-2
1087 			do q{i-t-1} -= 1;
1088 		*/
1089 		q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1090 		do {
1091 			q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1092 
1093 			/* find left hand */
1094 			mp_zero(&t1);
1095 			t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1096 			t1.dp[1] = y.dp[t];
1097 			t1.used = 2;
1098 			if ((res = multiply_digit(&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1099 				goto LBL_Y;
1100 			}
1101 
1102 			/* find right hand */
1103 			t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1104 			t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1105 			t2.dp[2] = x.dp[i];
1106 			t2.used = 3;
1107 		} while (compare_magnitude(&t1, &t2) == MP_GT);
1108 
1109 		/* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1110 		if ((res = multiply_digit(&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1111 			goto LBL_Y;
1112 		}
1113 
1114 		if ((res = lshift_digits(&t1, i - t - 1)) != MP_OKAY) {
1115 			goto LBL_Y;
1116 		}
1117 
1118 		if ((res = signed_subtract(&x, &t1, &x)) != MP_OKAY) {
1119 			goto LBL_Y;
1120 		}
1121 
1122 		/* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1123 		if (x.sign == MP_NEG) {
1124 			if ((res = mp_copy(&y, &t1)) != MP_OKAY) {
1125 				goto LBL_Y;
1126 			}
1127 			if ((res = lshift_digits(&t1, i - t - 1)) != MP_OKAY) {
1128 				goto LBL_Y;
1129 			}
1130 			if ((res = signed_add(&x, &t1, &x)) != MP_OKAY) {
1131 				goto LBL_Y;
1132 			}
1133 
1134 			q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1135 		}
1136 	}
1137 
1138 	/* now q is the quotient and x is the remainder
1139 	* [which we have to normalize]
1140 	*/
1141 
1142 	/* get sign before writing to c */
1143 	x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1144 
1145 	if (c != NULL) {
1146 		trim_unused_digits(&q);
1147 		mp_exch(&q, c);
1148 		c->sign = neg;
1149 	}
1150 
1151 	if (d != NULL) {
1152 		rshift_bits(&x, norm, &x, NULL);
1153 		mp_exch(&x, d);
1154 	}
1155 
1156 	res = MP_OKAY;
1157 
1158 LBL_Y:
1159 	mp_clear(&y);
1160 LBL_X:
1161 	mp_clear(&x);
1162 LBL_T2:
1163 	mp_clear(&t2);
1164 LBL_T1:
1165 	mp_clear(&t1);
1166 LBL_Q:
1167 	mp_clear(&q);
1168 	return res;
1169 }
1170 
1171 /* c = a mod b, 0 <= c < b */
1172 static int
modulo(mp_int * a,mp_int * b,mp_int * c)1173 modulo(mp_int * a, mp_int * b, mp_int * c)
1174 {
1175 	mp_int  t;
1176 	int     res;
1177 
1178 	if ((res = mp_init(&t)) != MP_OKAY) {
1179 		return res;
1180 	}
1181 
1182 	if ((res = signed_divide(NULL, &t, a, b)) != MP_OKAY) {
1183 		mp_clear(&t);
1184 		return res;
1185 	}
1186 
1187 	if (t.sign != b->sign) {
1188 		res = signed_add(b, &t, c);
1189 	} else {
1190 		res = MP_OKAY;
1191 		mp_exch(&t, c);
1192 	}
1193 
1194 	mp_clear(&t);
1195 	return res;
1196 }
1197 
1198 /* set to a digit */
1199 static void
set_word(mp_int * a,mp_digit b)1200 set_word(mp_int * a, mp_digit b)
1201 {
1202 	mp_zero(a);
1203 	a->dp[0] = b & MP_MASK;
1204 	a->used = (a->dp[0] != 0) ? 1 : 0;
1205 }
1206 
1207 /* b = a/2 */
1208 static int
half(mp_int * a,mp_int * b)1209 half(mp_int * a, mp_int * b)
1210 {
1211 	int     x, res, oldused;
1212 
1213 	/* copy */
1214 	if (b->alloc < a->used) {
1215 		if ((res = mp_grow(b, a->used)) != MP_OKAY) {
1216 			return res;
1217 		}
1218 	}
1219 
1220 	oldused = b->used;
1221 	b->used = a->used;
1222 	{
1223 		mp_digit r, rr, *tmpa, *tmpb;
1224 
1225 		/* source alias */
1226 		tmpa = a->dp + b->used - 1;
1227 
1228 		/* dest alias */
1229 		tmpb = b->dp + b->used - 1;
1230 
1231 		/* carry */
1232 		r = 0;
1233 		for (x = b->used - 1; x >= 0; x--) {
1234 			/* get the carry for the next iteration */
1235 			rr = *tmpa & 1;
1236 
1237 			/* shift the current digit, add in carry and store */
1238 			*tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1239 
1240 			/* forward carry to next iteration */
1241 			r = rr;
1242 		}
1243 
1244 		/* zero excess digits */
1245 		tmpb = b->dp + b->used;
1246 		for (x = b->used; x < oldused; x++) {
1247 			*tmpb++ = 0;
1248 		}
1249 	}
1250 	b->sign = a->sign;
1251 	trim_unused_digits(b);
1252 	return MP_OKAY;
1253 }
1254 
1255 /* compare a digit */
1256 static int
compare_digit(mp_int * a,mp_digit b)1257 compare_digit(mp_int * a, mp_digit b)
1258 {
1259 	/* compare based on sign */
1260 	if (a->sign == MP_NEG) {
1261 		return MP_LT;
1262 	}
1263 
1264 	/* compare based on magnitude */
1265 	if (a->used > 1) {
1266 		return MP_GT;
1267 	}
1268 
1269 	/* compare the only digit of a to b */
1270 	if (a->dp[0] > b) {
1271 		return MP_GT;
1272 	} else if (a->dp[0] < b) {
1273 		return MP_LT;
1274 	} else {
1275 		return MP_EQ;
1276 	}
1277 }
1278 
1279 static void
mp_clear_multi(mp_int * mp,...)1280 mp_clear_multi(mp_int *mp, ...)
1281 {
1282 	mp_int* next_mp = mp;
1283 	va_list args;
1284 
1285 	va_start(args, mp);
1286 	while (next_mp != NULL) {
1287 		mp_clear(next_mp);
1288 		next_mp = va_arg(args, mp_int*);
1289 	}
1290 	va_end(args);
1291 }
1292 
1293 /* computes the modular inverse via binary extended euclidean algorithm,
1294  * that is c = 1/a mod b
1295  *
1296  * Based on slow invmod except this is optimized for the case where b is
1297  * odd as per HAC Note 14.64 on pp. 610
1298  */
1299 static int
fast_modular_inverse(mp_int * a,mp_int * b,mp_int * c)1300 fast_modular_inverse(mp_int * a, mp_int * b, mp_int * c)
1301 {
1302 	mp_int  x, y, u, v, B, D;
1303 	int     res, neg;
1304 
1305 	/* 2. [modified] b must be odd   */
1306 	if (MP_ISZERO(b) == MP_YES) {
1307 		return MP_VAL;
1308 	}
1309 
1310 	/* init all our temps */
1311 	if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
1312 		return res;
1313 	}
1314 
1315 	/* x == modulus, y == value to invert */
1316 	if ((res = mp_copy(b, &x)) != MP_OKAY) {
1317 		goto LBL_ERR;
1318 	}
1319 
1320 	/* we need y = |a| */
1321 	if ((res = modulo(a, b, &y)) != MP_OKAY) {
1322 		goto LBL_ERR;
1323 	}
1324 
1325 	/* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
1326 	if ((res = mp_copy(&x, &u)) != MP_OKAY) {
1327 		goto LBL_ERR;
1328 	}
1329 	if ((res = mp_copy(&y, &v)) != MP_OKAY) {
1330 		goto LBL_ERR;
1331 	}
1332 	set_word(&D, 1);
1333 
1334 top:
1335 	/* 4.  while u is even do */
1336 	while (BN_is_even(&u) == 1) {
1337 		/* 4.1 u = u/2 */
1338 		if ((res = half(&u, &u)) != MP_OKAY) {
1339 			goto LBL_ERR;
1340 		}
1341 		/* 4.2 if B is odd then */
1342 		if (BN_is_odd(&B) == 1) {
1343 			if ((res = signed_subtract(&B, &x, &B)) != MP_OKAY) {
1344 				goto LBL_ERR;
1345 			}
1346 		}
1347 		/* B = B/2 */
1348 		if ((res = half(&B, &B)) != MP_OKAY) {
1349 			goto LBL_ERR;
1350 		}
1351 	}
1352 
1353 	/* 5.  while v is even do */
1354 	while (BN_is_even(&v) == 1) {
1355 		/* 5.1 v = v/2 */
1356 		if ((res = half(&v, &v)) != MP_OKAY) {
1357 			goto LBL_ERR;
1358 		}
1359 		/* 5.2 if D is odd then */
1360 		if (BN_is_odd(&D) == 1) {
1361 			/* D = (D-x)/2 */
1362 			if ((res = signed_subtract(&D, &x, &D)) != MP_OKAY) {
1363 				goto LBL_ERR;
1364 			}
1365 		}
1366 		/* D = D/2 */
1367 		if ((res = half(&D, &D)) != MP_OKAY) {
1368 			goto LBL_ERR;
1369 		}
1370 	}
1371 
1372 	/* 6.  if u >= v then */
1373 	if (signed_compare(&u, &v) != MP_LT) {
1374 		/* u = u - v, B = B - D */
1375 		if ((res = signed_subtract(&u, &v, &u)) != MP_OKAY) {
1376 			goto LBL_ERR;
1377 		}
1378 
1379 		if ((res = signed_subtract(&B, &D, &B)) != MP_OKAY) {
1380 			goto LBL_ERR;
1381 		}
1382 	} else {
1383 		/* v - v - u, D = D - B */
1384 		if ((res = signed_subtract(&v, &u, &v)) != MP_OKAY) {
1385 			goto LBL_ERR;
1386 		}
1387 
1388 		if ((res = signed_subtract(&D, &B, &D)) != MP_OKAY) {
1389 			goto LBL_ERR;
1390 		}
1391 	}
1392 
1393 	/* if not zero goto step 4 */
1394 	if (MP_ISZERO(&u) == MP_NO) {
1395 		goto top;
1396 	}
1397 
1398 	/* now a = C, b = D, gcd == g*v */
1399 
1400 	/* if v != 1 then there is no inverse */
1401 	if (compare_digit(&v, 1) != MP_EQ) {
1402 		res = MP_VAL;
1403 		goto LBL_ERR;
1404 	}
1405 
1406 	/* b is now the inverse */
1407 	neg = a->sign;
1408 	while (D.sign == MP_NEG) {
1409 		if ((res = signed_add(&D, b, &D)) != MP_OKAY) {
1410 			goto LBL_ERR;
1411 		}
1412 	}
1413 	mp_exch(&D, c);
1414 	c->sign = neg;
1415 	res = MP_OKAY;
1416 
1417 LBL_ERR:
1418 	mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
1419 	return res;
1420 }
1421 
/* hac 14.61, pp608
 *
 * Computes c = 1/a mod b with the general binary extended Euclidean
 * algorithm (HAC Algorithm 14.61).  Unlike fast_modular_inverse it does
 * not require b to be odd; modular_inverse uses it for even moduli.
 *
 * Returns MP_VAL when b is negative or zero, when a mod b and b are both
 * even, or when gcd(a mod b, b) != 1.  On success c holds the inverse,
 * reduced into [0, b), and MP_OKAY is returned.
 */
static int
slow_modular_inverse(mp_int * a, mp_int * b, mp_int * c)
{
	mp_int  x, y, u, v, A, B, C, D;
	int     res;

	/* b cannot be negative (or zero) */
	if (b->sign == MP_NEG || MP_ISZERO(b) == MP_YES) {
		return MP_VAL;
	}

	/* init temps */
	if ((res = mp_init_multi(&x, &y, &u, &v,
		   &A, &B, &C, &D, NULL)) != MP_OKAY) {
		return res;
	}

	/* x = a mod b (0 <= x < b since b is positive), y = b */
	if ((res = modulo(a, b, &x)) != MP_OKAY) {
		goto LBL_ERR;
	}
	if ((res = mp_copy(b, &y)) != MP_OKAY) {
		goto LBL_ERR;
	}

	/* 2. [modified] if x,y are both even then return an error!
	 * (a common factor of 2 means no inverse can exist) */
	if (BN_is_even(&x) == 1 && BN_is_even(&y) == 1) {
		res = MP_VAL;
		goto LBL_ERR;
	}

	/* 3. u=x, v=y, A=1, B=0, C=0, D=1
	 * The loop keeps the HAC invariants A*x + B*y == u and
	 * C*x + D*y == v. */
	if ((res = mp_copy(&x, &u)) != MP_OKAY) {
		goto LBL_ERR;
	}
	if ((res = mp_copy(&y, &v)) != MP_OKAY) {
		goto LBL_ERR;
	}
	set_word(&A, 1);
	set_word(&D, 1);

top:
	/* 4.  while u is even do */
	while (BN_is_even(&u) == 1) {
		/* 4.1 u = u/2 */
		if ((res = half(&u, &u)) != MP_OKAY) {
			goto LBL_ERR;
		}
		/* 4.2 if A or B is odd then */
		if (BN_is_odd(&A) == 1 || BN_is_odd(&B) == 1) {
			/* A = (A+y)/2, B = (B-x)/2 (the halving follows below) */
			if ((res = signed_add(&A, &y, &A)) != MP_OKAY) {
				 goto LBL_ERR;
			}
			if ((res = signed_subtract(&B, &x, &B)) != MP_OKAY) {
				 goto LBL_ERR;
			}
		}
		/* A = A/2, B = B/2 */
		if ((res = half(&A, &A)) != MP_OKAY) {
			goto LBL_ERR;
		}
		if ((res = half(&B, &B)) != MP_OKAY) {
			goto LBL_ERR;
		}
	}

	/* 5.  while v is even do */
	while (BN_is_even(&v) == 1) {
		/* 5.1 v = v/2 */
		if ((res = half(&v, &v)) != MP_OKAY) {
			goto LBL_ERR;
		}
		/* 5.2 if C or D is odd then */
		if (BN_is_odd(&C) == 1 || BN_is_odd(&D) == 1) {
			/* C = (C+y)/2, D = (D-x)/2 (the halving follows below) */
			if ((res = signed_add(&C, &y, &C)) != MP_OKAY) {
				 goto LBL_ERR;
			}
			if ((res = signed_subtract(&D, &x, &D)) != MP_OKAY) {
				 goto LBL_ERR;
			}
		}
		/* C = C/2, D = D/2 */
		if ((res = half(&C, &C)) != MP_OKAY) {
			goto LBL_ERR;
		}
		if ((res = half(&D, &D)) != MP_OKAY) {
			goto LBL_ERR;
		}
	}

	/* 6.  if u >= v then */
	if (signed_compare(&u, &v) != MP_LT) {
		/* u = u - v, A = A - C, B = B - D */
		if ((res = signed_subtract(&u, &v, &u)) != MP_OKAY) {
			goto LBL_ERR;
		}

		if ((res = signed_subtract(&A, &C, &A)) != MP_OKAY) {
			goto LBL_ERR;
		}

		if ((res = signed_subtract(&B, &D, &B)) != MP_OKAY) {
			goto LBL_ERR;
		}
	} else {
		/* v = v - u, C = C - A, D = D - B */
		if ((res = signed_subtract(&v, &u, &v)) != MP_OKAY) {
			goto LBL_ERR;
		}

		if ((res = signed_subtract(&C, &A, &C)) != MP_OKAY) {
			goto LBL_ERR;
		}

		if ((res = signed_subtract(&D, &B, &D)) != MP_OKAY) {
			goto LBL_ERR;
		}
	}

	/* if not zero goto step 4 */
	if (BN_is_zero(&u) == 0) {
		goto top;
	}
	/* u == 0: now C*x + D*y == v == gcd(x, y) (HAC: a = C, b = D), so
	 * C*x == v (mod b) */

	/* if v != 1 then there is no inverse */
	if (compare_digit(&v, 1) != MP_EQ) {
		res = MP_VAL;
		goto LBL_ERR;
	}

	/* if its too low: add b until C is non-negative */
	while (compare_digit(&C, 0) == MP_LT) {
		if ((res = signed_add(&C, b, &C)) != MP_OKAY) {
			 goto LBL_ERR;
		}
	}

	/* too big: subtract b until C < b */
	while (compare_magnitude(&C, b) != MP_LT) {
		if ((res = signed_subtract(&C, b, &C)) != MP_OKAY) {
			 goto LBL_ERR;
		}
	}

	/* C is now the inverse */
	mp_exch(&C, c);
	res = MP_OKAY;
LBL_ERR:
	mp_clear_multi(&x, &y, &u, &v, &A, &B, &C, &D, NULL);
	return res;
}
1577 
1578 static int
modular_inverse(mp_int * c,mp_int * a,mp_int * b)1579 modular_inverse(mp_int *c, mp_int *a, mp_int *b)
1580 {
1581 	/* b cannot be negative */
1582 	if (b->sign == MP_NEG || MP_ISZERO(b) == MP_YES) {
1583 		return MP_VAL;
1584 	}
1585 
1586 	/* if the modulus is odd we can use a faster routine instead */
1587 	if (BN_is_odd(b) == 1) {
1588 		return fast_modular_inverse(a, b, c);
1589 	}
1590 	return slow_modular_inverse(a, b, c);
1591 }
1592 
1593 /* b = |a|
1594  *
1595  * Simple function copies the input and fixes the sign to positive
1596  */
1597 static int
absolute(mp_int * a,mp_int * b)1598 absolute(mp_int * a, mp_int * b)
1599 {
1600 	int     res;
1601 
1602 	/* copy a to b */
1603 	if (a != b) {
1604 		if ((res = mp_copy(a, b)) != MP_OKAY) {
1605 			return res;
1606 		}
1607 	}
1608 
1609 	/* force the sign of b to positive */
1610 	b->sign = MP_ZPOS;
1611 
1612 	return MP_OKAY;
1613 }
1614 
1615 /* determines if reduce_2k_l can be used */
1616 static int
mp_reduce_is_2k_l(mp_int * a)1617 mp_reduce_is_2k_l(mp_int *a)
1618 {
1619 	int ix, iy;
1620 
1621 	if (a->used == 0) {
1622 		return MP_NO;
1623 	} else if (a->used == 1) {
1624 		return MP_YES;
1625 	} else if (a->used > 1) {
1626 		/* if more than half of the digits are -1 we're sold */
1627 		for (iy = ix = 0; ix < a->used; ix++) {
1628 			if (a->dp[ix] == MP_MASK) {
1629 				++iy;
1630 			}
1631 		}
1632 		return (iy >= (a->used/2)) ? MP_YES : MP_NO;
1633 
1634 	}
1635 	return MP_NO;
1636 }
1637 
1638 /* computes a = 2**b
1639  *
1640  * Simple algorithm which zeroes the int, grows it then just sets one bit
1641  * as required.
1642  */
1643 static int
mp_2expt(mp_int * a,int b)1644 mp_2expt(mp_int * a, int b)
1645 {
1646 	int     res;
1647 
1648 	/* zero a as per default */
1649 	mp_zero(a);
1650 
1651 	/* grow a to accomodate the single bit */
1652 	if ((res = mp_grow(a, b / DIGIT_BIT + 1)) != MP_OKAY) {
1653 		return res;
1654 	}
1655 
1656 	/* set the used count of where the bit will go */
1657 	a->used = b / DIGIT_BIT + 1;
1658 
1659 	/* put the single bit in its place */
1660 	a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
1661 
1662 	return MP_OKAY;
1663 }
1664 
1665 /* pre-calculate the value required for Barrett reduction
1666  * For a given modulus "b" it calulates the value required in "a"
1667  */
1668 static int
mp_reduce_setup(mp_int * a,mp_int * b)1669 mp_reduce_setup(mp_int * a, mp_int * b)
1670 {
1671 	int     res;
1672 
1673 	if ((res = mp_2expt(a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
1674 		return res;
1675 	}
1676 	return signed_divide(a, NULL, a, b);
1677 }
1678 
1679 /* b = a*2 */
1680 static int
doubled(mp_int * a,mp_int * b)1681 doubled(mp_int * a, mp_int * b)
1682 {
1683 	int     x, res, oldused;
1684 
1685 	/* grow to accomodate result */
1686 	if (b->alloc < a->used + 1) {
1687 		if ((res = mp_grow(b, a->used + 1)) != MP_OKAY) {
1688 			return res;
1689 		}
1690 	}
1691 
1692 	oldused = b->used;
1693 	b->used = a->used;
1694 
1695 	{
1696 		mp_digit r, rr, *tmpa, *tmpb;
1697 
1698 		/* alias for source */
1699 		tmpa = a->dp;
1700 
1701 		/* alias for dest */
1702 		tmpb = b->dp;
1703 
1704 		/* carry */
1705 		r = 0;
1706 		for (x = 0; x < a->used; x++) {
1707 
1708 			/* get what will be the *next* carry bit from the
1709 			* MSB of the current digit
1710 			*/
1711 			rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
1712 
1713 			/* now shift up this digit, add in the carry [from the previous] */
1714 			*tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
1715 
1716 			/* copy the carry that would be from the source
1717 			* digit into the next iteration
1718 			*/
1719 			r = rr;
1720 		}
1721 
1722 		/* new leading digit? */
1723 		if (r != 0) {
1724 			/* add a MSB which is always 1 at this point */
1725 			*tmpb = 1;
1726 			++(b->used);
1727 		}
1728 
1729 		/* now zero any excess digits on the destination
1730 		* that we didn't write to
1731 		*/
1732 		tmpb = b->dp + b->used;
1733 		for (x = b->used; x < oldused; x++) {
1734 			*tmpb++ = 0;
1735 		}
1736 	}
1737 	b->sign = a->sign;
1738 	return MP_OKAY;
1739 }
1740 
1741 /* divide by three (based on routine from MPI and the GMP manual) */
1742 static int
third(mp_int * a,mp_int * c,mp_digit * d)1743 third(mp_int * a, mp_int *c, mp_digit * d)
1744 {
1745 	mp_int   q;
1746 	mp_word  w, t;
1747 	mp_digit b;
1748 	int      res, ix;
1749 
1750 	/* b = 2**DIGIT_BIT / 3 */
1751 	b = (((mp_word)1) << ((mp_word)DIGIT_BIT)) / ((mp_word)3);
1752 
1753 	if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1754 		return res;
1755 	}
1756 
1757 	q.used = a->used;
1758 	q.sign = a->sign;
1759 	w = 0;
1760 	for (ix = a->used - 1; ix >= 0; ix--) {
1761 		w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1762 
1763 		if (w >= 3) {
1764 			/* multiply w by [1/3] */
1765 			t = (w * ((mp_word)b)) >> ((mp_word)DIGIT_BIT);
1766 
1767 			/* now subtract 3 * [w/3] from w, to get the remainder */
1768 			w -= t+t+t;
1769 
1770 			/* fixup the remainder as required since
1771 			* the optimization is not exact.
1772 			*/
1773 			while (w >= 3) {
1774 				t += 1;
1775 				w -= 3;
1776 			}
1777 		} else {
1778 			t = 0;
1779 		}
1780 		q.dp[ix] = (mp_digit)t;
1781 	}
1782 
1783 	/* [optional] store the remainder */
1784 	if (d != NULL) {
1785 		*d = (mp_digit)w;
1786 	}
1787 
1788 	/* [optional] store the quotient */
1789 	if (c != NULL) {
1790 		trim_unused_digits(&q);
1791 		mp_exch(&q, c);
1792 	}
1793 	mp_clear(&q);
1794 
1795 	return res;
1796 }
1797 
/* multiplication using the Toom-Cook 3-way algorithm
 *
 * Much more complicated than Karatsuba but has a lower
 * asymptotic running time of O(N**1.464).  This algorithm is
 * only particularly useful on VERY large inputs
 * (we're talking 1000s of digits here...).
 *
 * c = a * b.  With B = MIN(a->used, b->used) / 3 digits and
 * X = 2**(DIGIT_BIT * B), each operand is split as
 *     a = a2 * X**2 + a1 * X + a0    (a2 keeps every digit above 2*B)
 * and likewise for b.  Five products w0..w4 of combinations of the
 * pieces are formed, a 5x5 system is solved for the coefficients of
 * the product, and those are shifted into place and summed into c.
 *
 * Returns MP_OKAY, or the first error code reported by a checked helper.
 */
static int
toom_cook_multiply(mp_int *a, mp_int *b, mp_int *c)
{
	mp_int w0, w1, w2, w3, w4, tmp1, tmp2, a0, a1, a2, b0, b1, b2;
	int res, B;

	/* init temps */
	if ((res = mp_init_multi(&w0, &w1, &w2, &w3, &w4,
			&a0, &a1, &a2, &b0, &b1,
			&b2, &tmp1, &tmp2, NULL)) != MP_OKAY) {
		return res;
	}

	/* B: size in digits of the low and middle pieces */
	B = MIN(a->used, b->used) / 3;

	/* a = a2 * X**2 + a1 * X + a0 */
	/* a0 = low B digits of a */
	if ((res = modulo_2_to_power(a, DIGIT_BIT * B, &a0)) != MP_OKAY) {
		goto ERR;
	}

	/* a1 = digits B .. 2B-1 of a: shift down B digits, keep B digits.
	 * The in-place modulo_2_to_power result is not checked. */
	if ((res = mp_copy(a, &a1)) != MP_OKAY) {
		goto ERR;
	}
	rshift_digits(&a1, B);
	modulo_2_to_power(&a1, DIGIT_BIT * B, &a1);

	/* a2 = everything from digit 2B upward */
	if ((res = mp_copy(a, &a2)) != MP_OKAY) {
		goto ERR;
	}
	rshift_digits(&a2, B*2);

	/* b = b2 * X**2 + b1 * X + b0, split the same way */
	if ((res = modulo_2_to_power(b, DIGIT_BIT * B, &b0)) != MP_OKAY) {
		goto ERR;
	}

	if ((res = mp_copy(b, &b1)) != MP_OKAY) {
		goto ERR;
	}
	rshift_digits(&b1, B);
	modulo_2_to_power(&b1, DIGIT_BIT * B, &b1);

	if ((res = mp_copy(b, &b2)) != MP_OKAY) {
		goto ERR;
	}
	rshift_digits(&b2, B*2);

	/* w0 = a0*b0 (product of the constant terms) */
	if ((res = signed_multiply(&a0, &b0, &w0)) != MP_OKAY) {
		goto ERR;
	}

	/* w4 = a2 * b2 (product of the leading terms) */
	if ((res = signed_multiply(&a2, &b2, &w4)) != MP_OKAY) {
		goto ERR;
	}

	/* w1 = (a2 + 2(a1 + 2a0))(b2 + 2(b1 + 2b0)),
	 * i.e. (4a0 + 2a1 + a2)(4b0 + 2b1 + b2) */
	if ((res = doubled(&a0, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, &a2, &tmp1)) != MP_OKAY) {
		goto ERR;
	}

	if ((res = doubled(&b0, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp2, &b1, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = doubled(&tmp2, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp2, &b2, &tmp2)) != MP_OKAY) {
		goto ERR;
	}

	if ((res = signed_multiply(&tmp1, &tmp2, &w1)) != MP_OKAY) {
		goto ERR;
	}

	/* w3 = (a0 + 2(a1 + 2a2))(b0 + 2(b1 + 2b2)),
	 * i.e. both split polynomials evaluated at X = 2 */
	if ((res = doubled(&a2, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
		goto ERR;
	}

	if ((res = doubled(&b2, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp2, &b1, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = doubled(&tmp2, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp2, &b0, &tmp2)) != MP_OKAY) {
		goto ERR;
	}

	if ((res = signed_multiply(&tmp1, &tmp2, &w3)) != MP_OKAY) {
		goto ERR;
	}


	/* w2 = (a2 + a1 + a0)(b2 + b1 + b0), both evaluated at X = 1 */
	if ((res = signed_add(&a2, &a1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&b2, &b1, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp2, &b0, &tmp2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_multiply(&tmp1, &tmp2, &w2)) != MP_OKAY) {
		goto ERR;
	}

	/* now solve the matrix

	0  0  0  0  1
	1  2  4  8  16
	1  1  1  1  1
	16 8  4  2  1
	1  0  0  0  0

	using 12 subtractions, 4 shifts,
	2 small divisions and 1 small multiplication

	Rows are w0..w4; columns are the product's coefficients of
	X**4 .. X**0.  Afterwards wN holds the coefficient of X**N.
	*/

	/* r1 - r4 */
	if ((res = signed_subtract(&w1, &w4, &w1)) != MP_OKAY) {
		goto ERR;
	}
	/* r3 - r0 */
	if ((res = signed_subtract(&w3, &w0, &w3)) != MP_OKAY) {
		goto ERR;
	}
	/* r1/2 */
	if ((res = half(&w1, &w1)) != MP_OKAY) {
		goto ERR;
	}
	/* r3/2 */
	if ((res = half(&w3, &w3)) != MP_OKAY) {
		goto ERR;
	}
	/* r2 - r0 - r4 */
	if ((res = signed_subtract(&w2, &w0, &w2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_subtract(&w2, &w4, &w2)) != MP_OKAY) {
		goto ERR;
	}
	/* r1 - r2 */
	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
		goto ERR;
	}
	/* r3 - r2 */
	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
		goto ERR;
	}
	/* r1 - 8r0 */
	if ((res = lshift_bits(&w0, 3, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_subtract(&w1, &tmp1, &w1)) != MP_OKAY) {
		goto ERR;
	}
	/* r3 - 8r4 */
	if ((res = lshift_bits(&w4, 3, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_subtract(&w3, &tmp1, &w3)) != MP_OKAY) {
		goto ERR;
	}
	/* 3r2 - r1 - r3 */
	if ((res = multiply_digit(&w2, 3, &w2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_subtract(&w2, &w1, &w2)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_subtract(&w2, &w3, &w2)) != MP_OKAY) {
		goto ERR;
	}
	/* r1 - r2 */
	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
		goto ERR;
	}
	/* r3 - r2 */
	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
		goto ERR;
	}
	/* r1/3 */
	if ((res = third(&w1, &w1, NULL)) != MP_OKAY) {
		goto ERR;
	}
	/* r3/3 */
	if ((res = third(&w3, &w3, NULL)) != MP_OKAY) {
		goto ERR;
	}

	/* at this point shift W[n] by B*n digits (i.e. multiply by X**n) */
	if ((res = lshift_digits(&w1, 1*B)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = lshift_digits(&w2, 2*B)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = lshift_digits(&w3, 3*B)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = lshift_digits(&w4, 4*B)) != MP_OKAY) {
		goto ERR;
	}

	/* c = w0 + w1 + w2 + w3 + w4 */
	if ((res = signed_add(&w0, &w1, c)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&w2, &w3, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&w4, &tmp1, &tmp1)) != MP_OKAY) {
		goto ERR;
	}
	if ((res = signed_add(&tmp1, c, c)) != MP_OKAY) {
		goto ERR;
	}

ERR:
	mp_clear_multi(&w0, &w1, &w2, &w3, &w4,
		&a0, &a1, &a2, &b0, &b1,
		&b2, &tmp1, &tmp2, NULL);
	return res;
}
2060 
/* signed_multiply picks Karatsuba once the smaller operand has at
 * least this many digits... */
#define KARATSUBA_MUL_CUTOFF 80
/* ...and Toom-Cook 3-way once it has at least this many */
#define TOOM_MUL_CUTOFF	350
2063 
/* c = |a| * |b| using Karatsuba Multiplication using
 * three half size multiplications
 *
 * Let B represent the radix [e.g. 2**DIGIT_BIT] and
 * let n represent half of the number of digits in
 * the min(a,b)
 *
 * a = a1 * B**n + a0
 * b = b1 * B**n + b0
 *
 * Then, a * b =>
 *   a1b1 * B**2n + ((a1 + a0)(b1 + b0) - (a0b0 + a1b1)) * B**n + a0b0
 *
 * Note that a1b1 and a0b0 are used twice and only need to be
 * computed once.  So in total three half size (half # of
 * digit) multiplications are performed, a0b0, a1b1 and
 * (a1+a0)(b1+b0)
 *
 * Note that a multiplication of half the digits requires
 * 1/4th the number of single precision multiplications so in
 * total after one call 25% of the single precision multiplications
 * are saved.  Note also that the call to signed_multiply can end up back
 * in this function if the a0, a1, b0, or b1 are above the threshold.
 * This is known as divide-and-conquer and leads to the famous
 * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
 * the standard O(N**2) that the baseline/comba methods use.
 * Generally though the overhead of this method doesn't pay off
 * until a certain size (N ~ 80) is reached.
 *
 * The pieces are built from raw digits only, so the product is of the
 * magnitudes; the caller is responsible for the sign of c.  Returns
 * MP_OKAY on success and MP_MEM for any failure.
 */
static int
karatsuba_multiply(mp_int * a, mp_int * b, mp_int * c)
{
	mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
	int     B;
	int     err;

	/* default the return code to an error */
	err = MP_MEM;

	/* min # of digits */
	B = MIN(a->used, b->used);

	/* now divide in two */
	B = (int)((unsigned)B >> 1);

	/* init copy all the temps */
	if (mp_init_size(&x0, B) != MP_OKAY) {
		goto ERR;
	}
	if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
		goto X0;
	}
	if (mp_init_size(&y0, B) != MP_OKAY) {
		goto X1;
	}
	if (mp_init_size(&y1, b->used - B) != MP_OKAY) {
		goto Y0;
	}
	/* init temps */
	if (mp_init_size(&t1, B * 2) != MP_OKAY) {
		goto Y1;
	}
	if (mp_init_size(&x0y0, B * 2) != MP_OKAY) {
		goto T1;
	}
	if (mp_init_size(&x1y1, B * 2) != MP_OKAY) {
		goto X0Y0;
	}
	/* now shift the digits */
	x0.used = y0.used = B;
	x1.used = a->used - B;
	y1.used = b->used - B;

	{
		int x;
		mp_digit *tmpa, *tmpb, *tmpx, *tmpy;

		/* we copy the digits directly instead of using higher level functions
		* since we also need to shift the digits
		*/
		tmpa = a->dp;
		tmpb = b->dp;

		/* low B digits of a and b go to x0 and y0 */
		tmpx = x0.dp;
		tmpy = y0.dp;
		for (x = 0; x < B; x++) {
			*tmpx++ = *tmpa++;
			*tmpy++ = *tmpb++;
		}

		/* the remaining high digits go to x1 and y1 */
		tmpx = x1.dp;
		for (x = B; x < a->used; x++) {
			*tmpx++ = *tmpa++;
		}

		tmpy = y1.dp;
		for (x = B; x < b->used; x++) {
			*tmpy++ = *tmpb++;
		}
	}

	/* only need to clamp the lower words since by definition the
	* upper words x1/y1 must have a known number of digits
	*/
	trim_unused_digits(&x0);
	trim_unused_digits(&y0);

	/* now calc the products x0y0 and x1y1 */
	/* after this x0 is no longer required and is reused below as a
	 * scratch temporary (called t2 in the comments) */
	if (signed_multiply(&x0, &y0, &x0y0) != MP_OKAY)  {
		goto X1Y1;          /* x0y0 = x0*y0 */
	}
	if (signed_multiply(&x1, &y1, &x1y1) != MP_OKAY) {
		goto X1Y1;          /* x1y1 = x1*y1 */
	}
	/* now calc x1+x0 and y1+y0 */
	if (basic_add(&x1, &x0, &t1) != MP_OKAY) {
		goto X1Y1;          /* t1 = x1 + x0 */
	}
	if (basic_add(&y1, &y0, &x0) != MP_OKAY) {
		goto X1Y1;          /* t2 = y1 + y0 */
	}
	if (signed_multiply(&t1, &x0, &t1) != MP_OKAY) {
		goto X1Y1;          /* t1 = (x1 + x0) * (y1 + y0) */
	}
	/* add x0y0 */
	if (signed_add(&x0y0, &x1y1, &x0) != MP_OKAY) {
		goto X1Y1;          /* t2 = x0y0 + x1y1 */
	}
	if (basic_subtract(&t1, &x0, &t1) != MP_OKAY) {
		goto X1Y1;          /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
	}
	/* shift by B */
	if (lshift_digits(&t1, B) != MP_OKAY) {
		goto X1Y1;          /* t1 = ((x1+x0)*(y1+y0) - (x0y0 + x1y1)) << B digits */
	}
	if (lshift_digits(&x1y1, B * 2) != MP_OKAY) {
		goto X1Y1;          /* x1y1 = x1y1 << 2*B digits */
	}
	if (signed_add(&x0y0, &t1, &t1) != MP_OKAY) {
		goto X1Y1;          /* t1 = x0y0 + t1 */
	}
	if (signed_add(&t1, &x1y1, c) != MP_OKAY) {
		goto X1Y1;          /* c = x0y0 + t1 + x1y1 */
	}
	/* Algorithm succeeded set the return code to MP_OKAY */
	err = MP_OKAY;

	/* cleanup ladder: each label frees what was initialised before it */
X1Y1:
	mp_clear(&x1y1);
X0Y0:
	mp_clear(&x0y0);
T1:
	mp_clear(&t1);
Y1:
	mp_clear(&y1);
Y0:
	mp_clear(&y0);
X1:
	mp_clear(&x1);
X0:
	mp_clear(&x0);
ERR:
	return err;
}
2229 
2230 /* Fast (comba) multiplier
2231  *
2232  * This is the fast column-array [comba] multiplier.  It is
2233  * designed to compute the columns of the product first
2234  * then handle the carries afterwards.  This has the effect
2235  * of making the nested loops that compute the columns very
2236  * simple and schedulable on super-scalar processors.
2237  *
2238  * This has been modified to produce a variable number of
2239  * digits of output so if say only a half-product is required
2240  * you don't have to compute the upper half (a feature
2241  * required for fast Barrett reduction).
2242  *
2243  * Based on Algorithm 14.12 on pp.595 of HAC.
2244  *
2245  */
2246 static int
fast_col_array_multiply(mp_int * a,mp_int * b,mp_int * c,int digs)2247 fast_col_array_multiply(mp_int * a, mp_int * b, mp_int * c, int digs)
2248 {
2249 	int     olduse, res, pa, ix, iz;
2250 	/*LINTED*/
2251 	mp_digit W[MP_WARRAY];
2252 	mp_word  _W;
2253 
2254 	/* grow the destination as required */
2255 	if (c->alloc < digs) {
2256 		if ((res = mp_grow(c, digs)) != MP_OKAY) {
2257 			return res;
2258 		}
2259 	}
2260 
2261 	/* number of output digits to produce */
2262 	pa = MIN(digs, a->used + b->used);
2263 
2264 	/* clear the carry */
2265 	_W = 0;
2266 	for (ix = 0; ix < pa; ix++) {
2267 		int      tx, ty;
2268 		int      iy;
2269 		mp_digit *tmpx, *tmpy;
2270 
2271 		/* get offsets into the two bignums */
2272 		ty = MIN(b->used-1, ix);
2273 		tx = ix - ty;
2274 
2275 		/* setup temp aliases */
2276 		tmpx = a->dp + tx;
2277 		tmpy = b->dp + ty;
2278 
2279 		/* this is the number of times the loop will iterrate, essentially
2280 		while (tx++ < a->used && ty-- >= 0) { ... }
2281 		*/
2282 		iy = MIN(a->used-tx, ty+1);
2283 
2284 		/* execute loop */
2285 		for (iz = 0; iz < iy; ++iz) {
2286 			_W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
2287 
2288 		}
2289 
2290 		/* store term */
2291 		W[ix] = ((mp_digit)_W) & MP_MASK;
2292 
2293 		/* make next carry */
2294 		_W = _W >> ((mp_word)DIGIT_BIT);
2295 	}
2296 
2297 	/* setup dest */
2298 	olduse  = c->used;
2299 	c->used = pa;
2300 
2301 	{
2302 		mp_digit *tmpc;
2303 		tmpc = c->dp;
2304 		for (ix = 0; ix < pa+1; ix++) {
2305 			/* now extract the previous digit [below the carry] */
2306 			*tmpc++ = W[ix];
2307 		}
2308 
2309 		/* clear unused digits [that existed in the old copy of c] */
2310 		for (; ix < olduse; ix++) {
2311 			*tmpc++ = 0;
2312 		}
2313 	}
2314 	trim_unused_digits(c);
2315 	return MP_OKAY;
2316 }
2317 
2318 /* return 1 if we can use fast column array multiply */
2319 /*
2320 * The fast multiplier can be used if the output will
2321 * have less than MP_WARRAY digits and the number of
2322 * digits won't affect carry propagation
2323 */
2324 static inline int
can_use_fast_column_array(int ndigits,int used)2325 can_use_fast_column_array(int ndigits, int used)
2326 {
2327 	return (((unsigned)ndigits < MP_WARRAY) &&
2328 		used < (1 << (unsigned)((CHAR_BIT * sizeof(mp_word)) - (2 * DIGIT_BIT))));
2329 }
2330 
2331 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_mul_digs.c,v $ */
2332 /* Revision: 1.2 $ */
2333 /* Date: 2011/03/18 16:22:09 $ */
2334 
2335 
2336 /* multiplies |a| * |b| and only computes upto digs digits of result
2337  * HAC pp. 595, Algorithm 14.12  Modified so you can control how
2338  * many digits of output are created.
2339  */
2340 static int
basic_multiply_partial_lower(mp_int * a,mp_int * b,mp_int * c,int digs)2341 basic_multiply_partial_lower(mp_int * a, mp_int * b, mp_int * c, int digs)
2342 {
2343 	mp_int  t;
2344 	int     res, pa, pb, ix, iy;
2345 	mp_digit u;
2346 	mp_word r;
2347 	mp_digit tmpx, *tmpt, *tmpy;
2348 
2349 	/* can we use the fast multiplier? */
2350 	if (can_use_fast_column_array(digs, MIN(a->used, b->used))) {
2351 		return fast_col_array_multiply(a, b, c, digs);
2352 	}
2353 
2354 	if ((res = mp_init_size(&t, digs)) != MP_OKAY) {
2355 		return res;
2356 	}
2357 	t.used = digs;
2358 
2359 	/* compute the digits of the product directly */
2360 	pa = a->used;
2361 	for (ix = 0; ix < pa; ix++) {
2362 		/* set the carry to zero */
2363 		u = 0;
2364 
2365 		/* limit ourselves to making digs digits of output */
2366 		pb = MIN(b->used, digs - ix);
2367 
2368 		/* setup some aliases */
2369 		/* copy of the digit from a used within the nested loop */
2370 		tmpx = a->dp[ix];
2371 
2372 		/* an alias for the destination shifted ix places */
2373 		tmpt = t.dp + ix;
2374 
2375 		/* an alias for the digits of b */
2376 		tmpy = b->dp;
2377 
2378 		/* compute the columns of the output and propagate the carry */
2379 		for (iy = 0; iy < pb; iy++) {
2380 			/* compute the column as a mp_word */
2381 			r = ((mp_word)*tmpt) +
2382 				((mp_word)tmpx) * ((mp_word)*tmpy++) +
2383 				((mp_word) u);
2384 
2385 			/* the new column is the lower part of the result */
2386 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2387 
2388 			/* get the carry word from the result */
2389 			u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2390 		}
2391 		/* set carry if it is placed below digs */
2392 		if (ix + iy < digs) {
2393 			*tmpt = u;
2394 		}
2395 	}
2396 
2397 	trim_unused_digits(&t);
2398 	mp_exch(&t, c);
2399 
2400 	mp_clear(&t);
2401 	return MP_OKAY;
2402 }
2403 
2404 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_mul_digs.c,v $ */
2405 /* Revision: 1.3 $ */
2406 /* Date: 2011/03/18 16:43:04 $ */
2407 
2408 /* high level multiplication (handles sign) */
2409 static int
signed_multiply(mp_int * a,mp_int * b,mp_int * c)2410 signed_multiply(mp_int * a, mp_int * b, mp_int * c)
2411 {
2412 	int     res, neg;
2413 
2414 	neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
2415 	/* use Toom-Cook? */
2416 	if (MIN(a->used, b->used) >= TOOM_MUL_CUTOFF) {
2417 		res = toom_cook_multiply(a, b, c);
2418 	} else if (MIN(a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
2419 		/* use Karatsuba? */
2420 		res = karatsuba_multiply(a, b, c);
2421 	} else {
2422 		/* can we use the fast multiplier? */
2423 		int     digs = a->used + b->used + 1;
2424 
2425 		if (can_use_fast_column_array(digs, MIN(a->used, b->used))) {
2426 			res = fast_col_array_multiply(a, b, c, digs);
2427 		} else  {
2428 			res = basic_multiply_partial_lower(a, b, c, (a)->used + (b)->used + 1);
2429 		}
2430 	}
2431 	c->sign = (c->used > 0) ? neg : MP_ZPOS;
2432 	return res;
2433 }
2434 
2435 /* this is a modified version of fast_s_mul_digs that only produces
2436  * output digits *above* digs.  See the comments for fast_s_mul_digs
2437  * to see how it works.
2438  *
2439  * This is used in the Barrett reduction since for one of the multiplications
2440  * only the higher digits were needed.  This essentially halves the work.
2441  *
2442  * Based on Algorithm 14.12 on pp.595 of HAC.
2443  */
2444 static int
fast_basic_multiply_partial_upper(mp_int * a,mp_int * b,mp_int * c,int digs)2445 fast_basic_multiply_partial_upper(mp_int * a, mp_int * b, mp_int * c, int digs)
2446 {
2447 	int     olduse, res, pa, ix, iz;
2448 	mp_digit W[MP_WARRAY];
2449 	mp_word  _W;
2450 
2451 	/* grow the destination as required */
2452 	pa = a->used + b->used;
2453 	if (c->alloc < pa) {
2454 		if ((res = mp_grow(c, pa)) != MP_OKAY) {
2455 			return res;
2456 		}
2457 	}
2458 
2459 	/* number of output digits to produce */
2460 	pa = a->used + b->used;
2461 	_W = 0;
2462 	for (ix = digs; ix < pa; ix++) {
2463 		int      tx, ty, iy;
2464 		mp_digit *tmpx, *tmpy;
2465 
2466 		/* get offsets into the two bignums */
2467 		ty = MIN(b->used-1, ix);
2468 		tx = ix - ty;
2469 
2470 		/* setup temp aliases */
2471 		tmpx = a->dp + tx;
2472 		tmpy = b->dp + ty;
2473 
2474 		/* this is the number of times the loop will iterrate, essentially its
2475 		 while (tx++ < a->used && ty-- >= 0) { ... }
2476 		*/
2477 		iy = MIN(a->used-tx, ty+1);
2478 
2479 		/* execute loop */
2480 		for (iz = 0; iz < iy; iz++) {
2481 			 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
2482 		}
2483 
2484 		/* store term */
2485 		W[ix] = ((mp_digit)_W) & MP_MASK;
2486 
2487 		/* make next carry */
2488 		_W = _W >> ((mp_word)DIGIT_BIT);
2489 	}
2490 
2491 	/* setup dest */
2492 	olduse  = c->used;
2493 	c->used = pa;
2494 
2495 	{
2496 		mp_digit *tmpc;
2497 
2498 		tmpc = c->dp + digs;
2499 		for (ix = digs; ix < pa; ix++) {
2500 			/* now extract the previous digit [below the carry] */
2501 			*tmpc++ = W[ix];
2502 		}
2503 
2504 		/* clear unused digits [that existed in the old copy of c] */
2505 		for (; ix < olduse; ix++) {
2506 			*tmpc++ = 0;
2507 		}
2508 	}
2509 	trim_unused_digits(c);
2510 	return MP_OKAY;
2511 }
2512 
2513 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_mul_high_digs.c,v $ */
2514 /* Revision: 1.1.1.1 $ */
2515 /* Date: 2011/03/12 22:58:18 $ */
2516 
2517 /* multiplies |a| * |b| and does not compute the lower digs digits
2518  * [meant to get the higher part of the product]
2519  */
2520 static int
basic_multiply_partial_upper(mp_int * a,mp_int * b,mp_int * c,int digs)2521 basic_multiply_partial_upper(mp_int * a, mp_int * b, mp_int * c, int digs)
2522 {
2523 	mp_int  t;
2524 	int     res, pa, pb, ix, iy;
2525 	mp_digit carry;
2526 	mp_word r;
2527 	mp_digit tmpx, *tmpt, *tmpy;
2528 
2529 	/* can we use the fast multiplier? */
2530 	if (can_use_fast_column_array(a->used + b->used + 1, MIN(a->used, b->used))) {
2531 		return fast_basic_multiply_partial_upper(a, b, c, digs);
2532 	}
2533 
2534 	if ((res = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) {
2535 		return res;
2536 	}
2537 	t.used = a->used + b->used + 1;
2538 
2539 	pa = a->used;
2540 	pb = b->used;
2541 	for (ix = 0; ix < pa; ix++) {
2542 		/* clear the carry */
2543 		carry = 0;
2544 
2545 		/* left hand side of A[ix] * B[iy] */
2546 		tmpx = a->dp[ix];
2547 
2548 		/* alias to the address of where the digits will be stored */
2549 		tmpt = &(t.dp[digs]);
2550 
2551 		/* alias for where to read the right hand side from */
2552 		tmpy = b->dp + (digs - ix);
2553 
2554 		for (iy = digs - ix; iy < pb; iy++) {
2555 			/* calculate the double precision result */
2556 			r = ((mp_word)*tmpt) +
2557 				((mp_word)tmpx) * ((mp_word)*tmpy++) +
2558 				((mp_word) carry);
2559 
2560 			/* get the lower part */
2561 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2562 
2563 			/* carry the carry */
2564 			carry = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2565 		}
2566 		*tmpt = carry;
2567 	}
2568 	trim_unused_digits(&t);
2569 	mp_exch(&t, c);
2570 	mp_clear(&t);
2571 	return MP_OKAY;
2572 }
2573 
2574 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_mul_high_digs.c,v $ */
2575 /* Revision: 1.3 $ */
2576 /* Date: 2011/03/18 16:43:04 $ */
2577 
2578 /* reduces x mod m, assumes 0 < x < m**2, mu is
2579  * precomputed via mp_reduce_setup.
2580  * From HAC pp.604 Algorithm 14.42
2581  */
2582 static int
mp_reduce(mp_int * x,mp_int * m,mp_int * mu)2583 mp_reduce(mp_int * x, mp_int * m, mp_int * mu)
2584 {
2585 	mp_int  q;
2586 	int     res, um = m->used;
2587 
2588 	/* q = x */
2589 	if ((res = mp_init_copy(&q, x)) != MP_OKAY) {
2590 		return res;
2591 	}
2592 
2593 	/* q1 = x / b**(k-1)  */
2594 	rshift_digits(&q, um - 1);
2595 
2596 	/* according to HAC this optimization is ok */
2597 	if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
2598 		if ((res = signed_multiply(&q, mu, &q)) != MP_OKAY) {
2599 			goto CLEANUP;
2600 		}
2601 	} else {
2602 		if ((res = basic_multiply_partial_upper(&q, mu, &q, um)) != MP_OKAY) {
2603 			goto CLEANUP;
2604 		}
2605 	}
2606 
2607 	/* q3 = q2 / b**(k+1) */
2608 	rshift_digits(&q, um + 1);
2609 
2610 	/* x = x mod b**(k+1), quick (no division) */
2611 	if ((res = modulo_2_to_power(x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
2612 		goto CLEANUP;
2613 	}
2614 
2615 	/* q = q * m mod b**(k+1), quick (no division) */
2616 	if ((res = basic_multiply_partial_lower(&q, m, &q, um + 1)) != MP_OKAY) {
2617 		goto CLEANUP;
2618 	}
2619 
2620 	/* x = x - q */
2621 	if ((res = signed_subtract(x, &q, x)) != MP_OKAY) {
2622 		goto CLEANUP;
2623 	}
2624 
2625 	/* If x < 0, add b**(k+1) to it */
2626 	if (compare_digit(x, 0) == MP_LT) {
2627 		set_word(&q, 1);
2628 		if ((res = lshift_digits(&q, um + 1)) != MP_OKAY) {
2629 			goto CLEANUP;
2630 		}
2631 		if ((res = signed_add(x, &q, x)) != MP_OKAY) {
2632 			goto CLEANUP;
2633 		}
2634 	}
2635 
2636 	/* Back off if it's too big */
2637 	while (signed_compare(x, m) != MP_LT) {
2638 		if ((res = basic_subtract(x, m, x)) != MP_OKAY) {
2639 			goto CLEANUP;
2640 		}
2641 	}
2642 
2643 CLEANUP:
2644 	mp_clear(&q);
2645 
2646 	return res;
2647 }
2648 
2649 /* determines the setup value */
2650 static int
mp_reduce_2k_setup_l(mp_int * a,mp_int * d)2651 mp_reduce_2k_setup_l(mp_int *a, mp_int *d)
2652 {
2653 	int    res;
2654 	mp_int tmp;
2655 
2656 	if ((res = mp_init(&tmp)) != MP_OKAY) {
2657 		return res;
2658 	}
2659 
2660 	if ((res = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
2661 		goto ERR;
2662 	}
2663 
2664 	if ((res = basic_subtract(&tmp, a, d)) != MP_OKAY) {
2665 		goto ERR;
2666 	}
2667 
2668 ERR:
2669 	mp_clear(&tmp);
2670 	return res;
2671 }
2672 
2673 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_setup_l.c,v $ */
2674 /* Revision: 1.1.1.1 $ */
2675 /* Date: 2011/03/12 22:58:18 $ */
2676 
2677 /* reduces a modulo n where n is of the form 2**p - d
2678    This differs from reduce_2k since "d" can be larger
2679    than a single digit.
2680 */
2681 static int
mp_reduce_2k_l(mp_int * a,mp_int * n,mp_int * d)2682 mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
2683 {
2684 	mp_int q;
2685 	int    p, res;
2686 
2687 	if ((res = mp_init(&q)) != MP_OKAY) {
2688 		return res;
2689 	}
2690 
2691 	p = mp_count_bits(n);
2692 top:
2693 	/* q = a/2**p, a = a mod 2**p */
2694 	if ((res = rshift_bits(a, p, &q, a)) != MP_OKAY) {
2695 		goto ERR;
2696 	}
2697 
2698 	/* q = q * d */
2699 	if ((res = signed_multiply(&q, d, &q)) != MP_OKAY) {
2700 		goto ERR;
2701 	}
2702 
2703 	/* a = a + q */
2704 	if ((res = basic_add(a, &q, a)) != MP_OKAY) {
2705 		goto ERR;
2706 	}
2707 
2708 	if (compare_magnitude(a, n) != MP_LT) {
2709 		basic_subtract(a, n, a);
2710 		goto top;
2711 	}
2712 
2713 ERR:
2714 	mp_clear(&q);
2715 	return res;
2716 }
2717 
2718 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_l.c,v $ */
2719 /* Revision: 1.1.1.1 $ */
2720 /* Date: 2011/03/12 22:58:18 $ */
2721 
2722 /* squaring using Toom-Cook 3-way algorithm */
2723 static int
toom_cook_square(mp_int * a,mp_int * b)2724 toom_cook_square(mp_int *a, mp_int *b)
2725 {
2726 	mp_int w0, w1, w2, w3, w4, tmp1, a0, a1, a2;
2727 	int res, B;
2728 
2729 	/* init temps */
2730 	if ((res = mp_init_multi(&w0, &w1, &w2, &w3, &w4, &a0, &a1, &a2, &tmp1, NULL)) != MP_OKAY) {
2731 		return res;
2732 	}
2733 
2734 	/* B */
2735 	B = a->used / 3;
2736 
2737 	/* a = a2 * B**2 + a1 * B + a0 */
2738 	if ((res = modulo_2_to_power(a, DIGIT_BIT * B, &a0)) != MP_OKAY) {
2739 		goto ERR;
2740 	}
2741 
2742 	if ((res = mp_copy(a, &a1)) != MP_OKAY) {
2743 		goto ERR;
2744 	}
2745 	rshift_digits(&a1, B);
2746 	modulo_2_to_power(&a1, DIGIT_BIT * B, &a1);
2747 
2748 	if ((res = mp_copy(a, &a2)) != MP_OKAY) {
2749 		goto ERR;
2750 	}
2751 	rshift_digits(&a2, B*2);
2752 
2753 	/* w0 = a0*a0 */
2754 	if ((res = square(&a0, &w0)) != MP_OKAY) {
2755 		goto ERR;
2756 	}
2757 
2758 	/* w4 = a2 * a2 */
2759 	if ((res = square(&a2, &w4)) != MP_OKAY) {
2760 		goto ERR;
2761 	}
2762 
2763 	/* w1 = (a2 + 2(a1 + 2a0))**2 */
2764 	if ((res = doubled(&a0, &tmp1)) != MP_OKAY) {
2765 		goto ERR;
2766 	}
2767 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
2768 		goto ERR;
2769 	}
2770 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
2771 		goto ERR;
2772 	}
2773 	if ((res = signed_add(&tmp1, &a2, &tmp1)) != MP_OKAY) {
2774 		goto ERR;
2775 	}
2776 
2777 	if ((res = square(&tmp1, &w1)) != MP_OKAY) {
2778 		goto ERR;
2779 	}
2780 
2781 	/* w3 = (a0 + 2(a1 + 2a2))**2 */
2782 	if ((res = doubled(&a2, &tmp1)) != MP_OKAY) {
2783 		goto ERR;
2784 	}
2785 	if ((res = signed_add(&tmp1, &a1, &tmp1)) != MP_OKAY) {
2786 		goto ERR;
2787 	}
2788 	if ((res = doubled(&tmp1, &tmp1)) != MP_OKAY) {
2789 		goto ERR;
2790 	}
2791 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
2792 		goto ERR;
2793 	}
2794 
2795 	if ((res = square(&tmp1, &w3)) != MP_OKAY) {
2796 		goto ERR;
2797 	}
2798 
2799 
2800 	/* w2 = (a2 + a1 + a0)**2 */
2801 	if ((res = signed_add(&a2, &a1, &tmp1)) != MP_OKAY) {
2802 		goto ERR;
2803 	}
2804 	if ((res = signed_add(&tmp1, &a0, &tmp1)) != MP_OKAY) {
2805 		goto ERR;
2806 	}
2807 	if ((res = square(&tmp1, &w2)) != MP_OKAY) {
2808 		goto ERR;
2809 	}
2810 
2811 	/* now solve the matrix
2812 
2813 	0  0  0  0  1
2814 	1  2  4  8  16
2815 	1  1  1  1  1
2816 	16 8  4  2  1
2817 	1  0  0  0  0
2818 
2819 	using 12 subtractions, 4 shifts, 2 small divisions and 1 small multiplication.
2820 	*/
2821 
2822 	/* r1 - r4 */
2823 	if ((res = signed_subtract(&w1, &w4, &w1)) != MP_OKAY) {
2824 		goto ERR;
2825 	}
2826 	/* r3 - r0 */
2827 	if ((res = signed_subtract(&w3, &w0, &w3)) != MP_OKAY) {
2828 		goto ERR;
2829 	}
2830 	/* r1/2 */
2831 	if ((res = half(&w1, &w1)) != MP_OKAY) {
2832 		goto ERR;
2833 	}
2834 	/* r3/2 */
2835 	if ((res = half(&w3, &w3)) != MP_OKAY) {
2836 		goto ERR;
2837 	}
2838 	/* r2 - r0 - r4 */
2839 	if ((res = signed_subtract(&w2, &w0, &w2)) != MP_OKAY) {
2840 		goto ERR;
2841 	}
2842 	if ((res = signed_subtract(&w2, &w4, &w2)) != MP_OKAY) {
2843 		goto ERR;
2844 	}
2845 	/* r1 - r2 */
2846 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
2847 		goto ERR;
2848 	}
2849 	/* r3 - r2 */
2850 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
2851 		goto ERR;
2852 	}
2853 	/* r1 - 8r0 */
2854 	if ((res = lshift_bits(&w0, 3, &tmp1)) != MP_OKAY) {
2855 		goto ERR;
2856 	}
2857 	if ((res = signed_subtract(&w1, &tmp1, &w1)) != MP_OKAY) {
2858 		goto ERR;
2859 	}
2860 	/* r3 - 8r4 */
2861 	if ((res = lshift_bits(&w4, 3, &tmp1)) != MP_OKAY) {
2862 		goto ERR;
2863 	}
2864 	if ((res = signed_subtract(&w3, &tmp1, &w3)) != MP_OKAY) {
2865 		goto ERR;
2866 	}
2867 	/* 3r2 - r1 - r3 */
2868 	if ((res = multiply_digit(&w2, 3, &w2)) != MP_OKAY) {
2869 		goto ERR;
2870 	}
2871 	if ((res = signed_subtract(&w2, &w1, &w2)) != MP_OKAY) {
2872 		goto ERR;
2873 	}
2874 	if ((res = signed_subtract(&w2, &w3, &w2)) != MP_OKAY) {
2875 		goto ERR;
2876 	}
2877 	/* r1 - r2 */
2878 	if ((res = signed_subtract(&w1, &w2, &w1)) != MP_OKAY) {
2879 		goto ERR;
2880 	}
2881 	/* r3 - r2 */
2882 	if ((res = signed_subtract(&w3, &w2, &w3)) != MP_OKAY) {
2883 		goto ERR;
2884 	}
2885 	/* r1/3 */
2886 	if ((res = third(&w1, &w1, NULL)) != MP_OKAY) {
2887 		goto ERR;
2888 	}
2889 	/* r3/3 */
2890 	if ((res = third(&w3, &w3, NULL)) != MP_OKAY) {
2891 		goto ERR;
2892 	}
2893 
2894 	/* at this point shift W[n] by B*n */
2895 	if ((res = lshift_digits(&w1, 1*B)) != MP_OKAY) {
2896 		goto ERR;
2897 	}
2898 	if ((res = lshift_digits(&w2, 2*B)) != MP_OKAY) {
2899 		goto ERR;
2900 	}
2901 	if ((res = lshift_digits(&w3, 3*B)) != MP_OKAY) {
2902 		goto ERR;
2903 	}
2904 	if ((res = lshift_digits(&w4, 4*B)) != MP_OKAY) {
2905 		goto ERR;
2906 	}
2907 
2908 	if ((res = signed_add(&w0, &w1, b)) != MP_OKAY) {
2909 		goto ERR;
2910 	}
2911 	if ((res = signed_add(&w2, &w3, &tmp1)) != MP_OKAY) {
2912 		goto ERR;
2913 	}
2914 	if ((res = signed_add(&w4, &tmp1, &tmp1)) != MP_OKAY) {
2915 		goto ERR;
2916 	}
2917 	if ((res = signed_add(&tmp1, b, b)) != MP_OKAY) {
2918 		goto ERR;
2919 	}
2920 
2921 ERR:
2922 	mp_clear_multi(&w0, &w1, &w2, &w3, &w4, &a0, &a1, &a2, &tmp1, NULL);
2923 	return res;
2924 }
2925 
2926 
2927 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_toom_sqr.c,v $ */
2928 /* Revision: 1.1.1.1 $ */
2929 /* Date: 2011/03/12 22:58:18 $ */
2930 
2931 /* Karatsuba squaring, computes b = a*a using three
2932  * half size squarings
2933  *
2934  * See comments of karatsuba_mul for details.  It
2935  * is essentially the same algorithm but merely
2936  * tuned to perform recursive squarings.
2937  */
2938 static int
karatsuba_square(mp_int * a,mp_int * b)2939 karatsuba_square(mp_int * a, mp_int * b)
2940 {
2941 	mp_int  x0, x1, t1, t2, x0x0, x1x1;
2942 	int     B, err;
2943 
2944 	err = MP_MEM;
2945 
2946 	/* min # of digits */
2947 	B = a->used;
2948 
2949 	/* now divide in two */
2950 	B = (unsigned)B >> 1;
2951 
2952 	/* init copy all the temps */
2953 	if (mp_init_size(&x0, B) != MP_OKAY) {
2954 		goto ERR;
2955 	}
2956 	if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
2957 		goto X0;
2958 	}
2959 	/* init temps */
2960 	if (mp_init_size(&t1, a->used * 2) != MP_OKAY) {
2961 		goto X1;
2962 	}
2963 	if (mp_init_size(&t2, a->used * 2) != MP_OKAY) {
2964 		goto T1;
2965 	}
2966 	if (mp_init_size(&x0x0, B * 2) != MP_OKAY) {
2967 		goto T2;
2968 	}
2969 	if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY) {
2970 		goto X0X0;
2971 	}
2972 
2973 	memcpy(x0.dp, a->dp, B * sizeof(*x0.dp));
2974 	memcpy(x1.dp, &a->dp[B], (a->used - B) * sizeof(*x1.dp));
2975 
2976 	x0.used = B;
2977 	x1.used = a->used - B;
2978 
2979 	trim_unused_digits(&x0);
2980 
2981 	/* now calc the products x0*x0 and x1*x1 */
2982 	if (square(&x0, &x0x0) != MP_OKAY) {
2983 		goto X1X1;           /* x0x0 = x0*x0 */
2984 	}
2985 	if (square(&x1, &x1x1) != MP_OKAY) {
2986 		goto X1X1;           /* x1x1 = x1*x1 */
2987 	}
2988 	/* now calc (x1+x0)**2 */
2989 	if (basic_add(&x1, &x0, &t1) != MP_OKAY) {
2990 		goto X1X1;           /* t1 = x1 - x0 */
2991 	}
2992 	if (square(&t1, &t1) != MP_OKAY) {
2993 		goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2994 	}
2995 	/* add x0y0 */
2996 	if (basic_add(&x0x0, &x1x1, &t2) != MP_OKAY) {
2997 		goto X1X1;           /* t2 = x0x0 + x1x1 */
2998 	}
2999 	if (basic_subtract(&t1, &t2, &t1) != MP_OKAY) {
3000 		goto X1X1;           /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
3001 	}
3002 	/* shift by B */
3003 	if (lshift_digits(&t1, B) != MP_OKAY) {
3004 		goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
3005 	}
3006 	if (lshift_digits(&x1x1, B * 2) != MP_OKAY) {
3007 		goto X1X1;           /* x1x1 = x1x1 << 2*B */
3008 	}
3009 	if (signed_add(&x0x0, &t1, &t1) != MP_OKAY) {
3010 		goto X1X1;           /* t1 = x0x0 + t1 */
3011 	}
3012 	if (signed_add(&t1, &x1x1, b) != MP_OKAY) {
3013 		goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
3014 	}
3015 	err = MP_OKAY;
3016 
3017 X1X1:
3018 	mp_clear(&x1x1);
3019 X0X0:
3020 	mp_clear(&x0x0);
3021 T2:
3022 	mp_clear(&t2);
3023 T1:
3024 	mp_clear(&t1);
3025 X1:
3026 	mp_clear(&x1);
3027 X0:
3028 	mp_clear(&x0);
3029 ERR:
3030 	return err;
3031 }
3032 
3033 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_karatsuba_sqr.c,v $ */
3034 /* Revision: 1.2 $ */
3035 /* Date: 2011/03/12 23:43:54 $ */
3036 
3037 /* the jist of squaring...
3038  * you do like mult except the offset of the tmpx [one that
3039  * starts closer to zero] can't equal the offset of tmpy.
3040  * So basically you set up iy like before then you min it with
3041  * (ty-tx) so that it never happens.  You double all those
3042  * you add in the inner loop
3043 
3044 After that loop you do the squares and add them in.
3045 */
3046 
3047 static int
fast_basic_square(mp_int * a,mp_int * b)3048 fast_basic_square(mp_int * a, mp_int * b)
3049 {
3050 	int       olduse, res, pa, ix, iz;
3051 	mp_digit   W[MP_WARRAY], *tmpx;
3052 	mp_word   W1;
3053 
3054 	/* grow the destination as required */
3055 	pa = a->used + a->used;
3056 	if (b->alloc < pa) {
3057 		if ((res = mp_grow(b, pa)) != MP_OKAY) {
3058 			return res;
3059 		}
3060 	}
3061 
3062 	/* number of output digits to produce */
3063 	W1 = 0;
3064 	for (ix = 0; ix < pa; ix++) {
3065 		int      tx, ty, iy;
3066 		mp_word  _W;
3067 		mp_digit *tmpy;
3068 
3069 		/* clear counter */
3070 		_W = 0;
3071 
3072 		/* get offsets into the two bignums */
3073 		ty = MIN(a->used-1, ix);
3074 		tx = ix - ty;
3075 
3076 		/* setup temp aliases */
3077 		tmpx = a->dp + tx;
3078 		tmpy = a->dp + ty;
3079 
3080 		/* this is the number of times the loop will iterrate, essentially
3081 		 while (tx++ < a->used && ty-- >= 0) { ... }
3082 		*/
3083 		iy = MIN(a->used-tx, ty+1);
3084 
3085 		/* now for squaring tx can never equal ty
3086 		* we halve the distance since they approach at a rate of 2x
3087 		* and we have to round because odd cases need to be executed
3088 		*/
3089 		iy = MIN(iy, (int)((unsigned)(ty-tx+1)>>1));
3090 
3091 		/* execute loop */
3092 		for (iz = 0; iz < iy; iz++) {
3093 			 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
3094 		}
3095 
3096 		/* double the inner product and add carry */
3097 		_W = _W + _W + W1;
3098 
3099 		/* even columns have the square term in them */
3100 		if ((ix&1) == 0) {
3101 			 _W += ((mp_word)a->dp[(unsigned)ix>>1])*((mp_word)a->dp[(unsigned)ix>>1]);
3102 		}
3103 
3104 		/* store it */
3105 		W[ix] = (mp_digit)(_W & MP_MASK);
3106 
3107 		/* make next carry */
3108 		W1 = _W >> ((mp_word)DIGIT_BIT);
3109 	}
3110 
3111 	/* setup dest */
3112 	olduse  = b->used;
3113 	b->used = a->used+a->used;
3114 
3115 	{
3116 		mp_digit *tmpb;
3117 		tmpb = b->dp;
3118 		for (ix = 0; ix < pa; ix++) {
3119 			*tmpb++ = W[ix] & MP_MASK;
3120 		}
3121 
3122 		/* clear unused digits [that existed in the old copy of c] */
3123 		for (; ix < olduse; ix++) {
3124 			*tmpb++ = 0;
3125 		}
3126 	}
3127 	trim_unused_digits(b);
3128 	return MP_OKAY;
3129 }
3130 
3131 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_s_mp_sqr.c,v $ */
3132 /* Revision: 1.3 $ */
3133 /* Date: 2011/03/18 16:43:04 $ */
3134 
3135 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
3136 static int
basic_square(mp_int * a,mp_int * b)3137 basic_square(mp_int * a, mp_int * b)
3138 {
3139 	mp_int  t;
3140 	int     res, ix, iy, pa;
3141 	mp_word r;
3142 	mp_digit carry, tmpx, *tmpt;
3143 
3144 	pa = a->used;
3145 	if ((res = mp_init_size(&t, 2*pa + 1)) != MP_OKAY) {
3146 		return res;
3147 	}
3148 
3149 	/* default used is maximum possible size */
3150 	t.used = 2*pa + 1;
3151 
3152 	for (ix = 0; ix < pa; ix++) {
3153 		/* first calculate the digit at 2*ix */
3154 		/* calculate double precision result */
3155 		r = ((mp_word) t.dp[2*ix]) +
3156 		((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
3157 
3158 		/* store lower part in result */
3159 		t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
3160 
3161 		/* get the carry */
3162 		carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3163 
3164 		/* left hand side of A[ix] * A[iy] */
3165 		tmpx = a->dp[ix];
3166 
3167 		/* alias for where to store the results */
3168 		tmpt = t.dp + (2*ix + 1);
3169 
3170 		for (iy = ix + 1; iy < pa; iy++) {
3171 			/* first calculate the product */
3172 			r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
3173 
3174 			/* now calculate the double precision result, note we use
3175 			* addition instead of *2 since it's easier to optimize
3176 			*/
3177 			r = ((mp_word) *tmpt) + r + r + ((mp_word) carry);
3178 
3179 			/* store lower part */
3180 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3181 
3182 			/* get carry */
3183 			carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3184 		}
3185 		/* propagate upwards */
3186 		while (carry != ((mp_digit) 0)) {
3187 			r = ((mp_word) *tmpt) + ((mp_word) carry);
3188 			*tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3189 			carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3190 		}
3191 	}
3192 
3193 	trim_unused_digits(&t);
3194 	mp_exch(&t, b);
3195 	mp_clear(&t);
3196 	return MP_OKAY;
3197 }
3198 
3199 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_s_mp_sqr.c,v $ */
3200 /* Revision: 1.1.1.1 $ */
3201 /* Date: 2011/03/12 22:58:18 $ */
3202 
3203 #define TOOM_SQR_CUTOFF      400
3204 #define KARATSUBA_SQR_CUTOFF 120
3205 
3206 /* computes b = a*a */
3207 static int
square(mp_int * a,mp_int * b)3208 square(mp_int * a, mp_int * b)
3209 {
3210 	int     res;
3211 
3212 	/* use Toom-Cook? */
3213 	if (a->used >= TOOM_SQR_CUTOFF) {
3214 		res = toom_cook_square(a, b);
3215 		/* Karatsuba? */
3216 	} else if (a->used >= KARATSUBA_SQR_CUTOFF) {
3217 		res = karatsuba_square(a, b);
3218 	} else {
3219 		/* can we use the fast comba multiplier? */
3220 		if (can_use_fast_column_array(a->used + a->used + 1, a->used)) {
3221 			res = fast_basic_square(a, b);
3222 		} else {
3223 			res = basic_square(a, b);
3224 		}
3225 	}
3226 	b->sign = MP_ZPOS;
3227 	return res;
3228 }
3229 
3230 /* find window size */
3231 static inline int
find_window_size(mp_int * X)3232 find_window_size(mp_int *X)
3233 {
3234 	int	x;
3235 
3236 	x = mp_count_bits(X);
3237 	return (x <= 7) ? 2 : (x <= 36) ? 3 : (x <= 140) ? 4 : (x <= 450) ? 5 : (x <= 1303) ? 6 : (x <= 3529) ? 7 : 8;
3238 }
3239 
3240 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_sqr.c,v $ */
3241 /* Revision: 1.3 $ */
3242 /* Date: 2011/03/18 16:43:04 $ */
3243 
3244 #define TAB_SIZE 256
3245 
3246 static int
basic_exponent_mod(mp_int * G,mp_int * X,mp_int * P,mp_int * Y,int redmode)3247 basic_exponent_mod(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
3248 {
3249 	mp_digit buf;
3250 	mp_int  M[TAB_SIZE], res, mu;
3251 	int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3252 	int	(*redux)(mp_int*,mp_int*,mp_int*);
3253 
3254 	winsize = find_window_size(X);
3255 
3256 	/* init M array */
3257 	/* init first cell */
3258 	if ((err = mp_init(&M[1])) != MP_OKAY) {
3259 		return err;
3260 	}
3261 
3262 	/* now init the second half of the array */
3263 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3264 		if ((err = mp_init(&M[x])) != MP_OKAY) {
3265 			for (y = 1<<(winsize-1); y < x; y++) {
3266 				mp_clear(&M[y]);
3267 			}
3268 			mp_clear(&M[1]);
3269 			return err;
3270 		}
3271 	}
3272 
3273 	/* create mu, used for Barrett reduction */
3274 	if ((err = mp_init(&mu)) != MP_OKAY) {
3275 		goto LBL_M;
3276 	}
3277 
3278 	if (redmode == 0) {
3279 		if ((err = mp_reduce_setup(&mu, P)) != MP_OKAY) {
3280 			goto LBL_MU;
3281 		}
3282 		redux = mp_reduce;
3283 	} else {
3284 		if ((err = mp_reduce_2k_setup_l(P, &mu)) != MP_OKAY) {
3285 			goto LBL_MU;
3286 		}
3287 		redux = mp_reduce_2k_l;
3288 	}
3289 
3290 	/* create M table
3291 	*
3292 	* The M table contains powers of the base,
3293 	* e.g. M[x] = G**x mod P
3294 	*
3295 	* The first half of the table is not
3296 	* computed though accept for M[0] and M[1]
3297 	*/
3298 	if ((err = modulo(G, P, &M[1])) != MP_OKAY) {
3299 		goto LBL_MU;
3300 	}
3301 
3302 	/* compute the value at M[1<<(winsize-1)] by squaring
3303 	* M[1] (winsize-1) times
3304 	*/
3305 	if ((err = mp_copy( &M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
3306 		goto LBL_MU;
3307 	}
3308 
3309 	for (x = 0; x < (winsize - 1); x++) {
3310 		/* square it */
3311 		if ((err = square(&M[1 << (winsize - 1)],
3312 		       &M[1 << (winsize - 1)])) != MP_OKAY) {
3313 			goto LBL_MU;
3314 		}
3315 
3316 		/* reduce modulo P */
3317 		if ((err = redux(&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
3318 			goto LBL_MU;
3319 		}
3320 	}
3321 
3322 	/* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
3323 	* for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
3324 	*/
3325 	for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
3326 		if ((err = signed_multiply(&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
3327 			goto LBL_MU;
3328 		}
3329 		if ((err = redux(&M[x], P, &mu)) != MP_OKAY) {
3330 			goto LBL_MU;
3331 		}
3332 	}
3333 
3334 	/* setup result */
3335 	if ((err = mp_init(&res)) != MP_OKAY) {
3336 		goto LBL_MU;
3337 	}
3338 	set_word(&res, 1);
3339 
3340 	/* set initial mode and bit cnt */
3341 	mode = 0;
3342 	bitcnt = 1;
3343 	buf = 0;
3344 	digidx = X->used - 1;
3345 	bitcpy = 0;
3346 	bitbuf = 0;
3347 
3348 	for (;;) {
3349 		/* grab next digit as required */
3350 		if (--bitcnt == 0) {
3351 			/* if digidx == -1 we are out of digits */
3352 			if (digidx == -1) {
3353 				break;
3354 			}
3355 			/* read next digit and reset the bitcnt */
3356 			buf = X->dp[digidx--];
3357 			bitcnt = (int) DIGIT_BIT;
3358 		}
3359 
3360 		/* grab the next msb from the exponent */
3361 		y = (unsigned)(buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
3362 		buf <<= (mp_digit)1;
3363 
3364 		/* if the bit is zero and mode == 0 then we ignore it
3365 		* These represent the leading zero bits before the first 1 bit
3366 		* in the exponent.  Technically this opt is not required but it
3367 		* does lower the # of trivial squaring/reductions used
3368 		*/
3369 		if (mode == 0 && y == 0) {
3370 			continue;
3371 		}
3372 
3373 		/* if the bit is zero and mode == 1 then we square */
3374 		if (mode == 1 && y == 0) {
3375 			if ((err = square(&res, &res)) != MP_OKAY) {
3376 				goto LBL_RES;
3377 			}
3378 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3379 				goto LBL_RES;
3380 			}
3381 			continue;
3382 		}
3383 
3384 		/* else we add it to the window */
3385 		bitbuf |= (y << (winsize - ++bitcpy));
3386 		mode = 2;
3387 
3388 		if (bitcpy == winsize) {
3389 			/* ok window is filled so square as required and multiply  */
3390 			/* square first */
3391 			for (x = 0; x < winsize; x++) {
3392 				if ((err = square(&res, &res)) != MP_OKAY) {
3393 					goto LBL_RES;
3394 				}
3395 				if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3396 					goto LBL_RES;
3397 				}
3398 			}
3399 
3400 			/* then multiply */
3401 			if ((err = signed_multiply(&res, &M[bitbuf], &res)) != MP_OKAY) {
3402 				goto LBL_RES;
3403 			}
3404 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3405 				goto LBL_RES;
3406 			}
3407 
3408 			/* empty window and reset */
3409 			bitcpy = 0;
3410 			bitbuf = 0;
3411 			mode = 1;
3412 		}
3413 	}
3414 
3415 	/* if bits remain then square/multiply */
3416 	if (mode == 2 && bitcpy > 0) {
3417 		/* square then multiply if the bit is set */
3418 		for (x = 0; x < bitcpy; x++) {
3419 			if ((err = square(&res, &res)) != MP_OKAY) {
3420 				goto LBL_RES;
3421 			}
3422 			if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3423 				goto LBL_RES;
3424 			}
3425 
3426 			bitbuf <<= 1;
3427 			if ((bitbuf & (1 << winsize)) != 0) {
3428 				/* then multiply */
3429 				if ((err = signed_multiply(&res, &M[1], &res)) != MP_OKAY) {
3430 					goto LBL_RES;
3431 				}
3432 				if ((err = redux(&res, P, &mu)) != MP_OKAY) {
3433 					goto LBL_RES;
3434 				}
3435 			}
3436 		}
3437 	}
3438 
3439 	mp_exch(&res, Y);
3440 	err = MP_OKAY;
3441 LBL_RES:
3442 	mp_clear(&res);
3443 LBL_MU:
3444 	mp_clear(&mu);
3445 LBL_M:
3446 	mp_clear(&M[1]);
3447 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3448 		mp_clear(&M[x]);
3449 	}
3450 	return err;
3451 }
3452 
3453 /* determines if a number is a valid DR modulus */
3454 static int
is_diminished_radix_modulus(mp_int * a)3455 is_diminished_radix_modulus(mp_int *a)
3456 {
3457 	int ix;
3458 
3459 	/* must be at least two digits */
3460 	if (a->used < 2) {
3461 		return 0;
3462 	}
3463 
3464 	/* must be of the form b**k - a [a <= b] so all
3465 	* but the first digit must be equal to -1 (mod b).
3466 	*/
3467 	for (ix = 1; ix < a->used; ix++) {
3468 		if (a->dp[ix] != MP_MASK) {
3469 			  return 0;
3470 		}
3471 	}
3472 	return 1;
3473 }
3474 
3475 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_is_modulus.c,v $ */
3476 /* Revision: 1.1.1.1 $ */
3477 /* Date: 2011/03/12 22:58:18 $ */
3478 
3479 /* determines if mp_reduce_2k can be used */
3480 static int
mp_reduce_is_2k(mp_int * a)3481 mp_reduce_is_2k(mp_int *a)
3482 {
3483 	int ix, iy, iw;
3484 	mp_digit iz;
3485 
3486 	if (a->used == 0) {
3487 		return MP_NO;
3488 	}
3489 	if (a->used == 1) {
3490 		return MP_YES;
3491 	}
3492 	if (a->used > 1) {
3493 		iy = mp_count_bits(a);
3494 		iz = 1;
3495 		iw = 1;
3496 
3497 		/* Test every bit from the second digit up, must be 1 */
3498 		for (ix = DIGIT_BIT; ix < iy; ix++) {
3499 			if ((a->dp[iw] & iz) == 0) {
3500 				return MP_NO;
3501 			}
3502 			iz <<= 1;
3503 			if (iz > (mp_digit)MP_MASK) {
3504 				++iw;
3505 				iz = 1;
3506 			}
3507 		}
3508 	}
3509 	return MP_YES;
3510 }
3511 
3512 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_is_2k.c,v $ */
3513 /* Revision: 1.1.1.1 $ */
3514 /* Date: 2011/03/12 22:58:18 $ */
3515 
3516 
3517 /* d = a * b (mod c) */
3518 static int
multiply_modulo(mp_int * d,mp_int * a,mp_int * b,mp_int * c)3519 multiply_modulo(mp_int *d, mp_int * a, mp_int * b, mp_int * c)
3520 {
3521 	mp_int  t;
3522 	int     res;
3523 
3524 	if ((res = mp_init(&t)) != MP_OKAY) {
3525 		return res;
3526 	}
3527 
3528 	if ((res = signed_multiply(a, b, &t)) != MP_OKAY) {
3529 		mp_clear(&t);
3530 		return res;
3531 	}
3532 	res = modulo(&t, c, d);
3533 	mp_clear(&t);
3534 	return res;
3535 }
3536 
3537 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_mulmod.c,v $ */
3538 /* Revision: 1.1.1.1 $ */
3539 /* Date: 2011/03/12 22:58:18 $ */
3540 
3541 /* setups the montgomery reduction stuff */
3542 static int
mp_montgomery_setup(mp_int * n,mp_digit * rho)3543 mp_montgomery_setup(mp_int * n, mp_digit * rho)
3544 {
3545 	mp_digit x, b;
3546 
3547 	/* fast inversion mod 2**k
3548 	*
3549 	* Based on the fact that
3550 	*
3551 	* XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3552 	*                    =>  2*X*A - X*X*A*A = 1
3553 	*                    =>  2*(1) - (1)     = 1
3554 	*/
3555 	b = n->dp[0];
3556 
3557 	if ((b & 1) == 0) {
3558 		return MP_VAL;
3559 	}
3560 
3561 	x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3562 	x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3563 	x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3564 	x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3565 	if (/*CONSTCOND*/sizeof(mp_digit) == 8) {
3566 		x *= 2 - b * x;	/* here x*a==1 mod 2**64 */
3567 	}
3568 
3569 	/* rho = -1/m mod b */
3570 	*rho = (unsigned long)(((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3571 
3572 	return MP_OKAY;
3573 }
3574 
3575 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_setup.c,v $ */
3576 /* Revision: 1.1.1.1 $ */
3577 /* Date: 2011/03/12 22:58:18 $ */
3578 
3579 /* computes xR**-1 == x (mod N) via Montgomery Reduction
3580  *
3581  * This is an optimized implementation of montgomery_reduce
3582  * which uses the comba method to quickly calculate the columns of the
3583  * reduction.
3584  *
3585  * Based on Algorithm 14.32 on pp.601 of HAC.
3586 */
3587 static int
fast_mp_montgomery_reduce(mp_int * x,mp_int * n,mp_digit rho)3588 fast_mp_montgomery_reduce(mp_int * x, mp_int * n, mp_digit rho)
3589 {
3590 	int     ix, res, olduse;
3591 	/*LINTED*/
3592 	mp_word W[MP_WARRAY];
3593 
3594 	/* get old used count */
3595 	olduse = x->used;
3596 
3597 	/* grow a as required */
3598 	if (x->alloc < n->used + 1) {
3599 		if ((res = mp_grow(x, n->used + 1)) != MP_OKAY) {
3600 			return res;
3601 		}
3602 	}
3603 
3604 	/* first we have to get the digits of the input into
3605 	* an array of double precision words W[...]
3606 	*/
3607 	{
3608 		mp_word *_W;
3609 		mp_digit *tmpx;
3610 
3611 		/* alias for the W[] array */
3612 		_W = W;
3613 
3614 		/* alias for the digits of  x*/
3615 		tmpx = x->dp;
3616 
3617 		/* copy the digits of a into W[0..a->used-1] */
3618 		for (ix = 0; ix < x->used; ix++) {
3619 			*_W++ = *tmpx++;
3620 		}
3621 
3622 		/* zero the high words of W[a->used..m->used*2] */
3623 		for (; ix < n->used * 2 + 1; ix++) {
3624 			*_W++ = 0;
3625 		}
3626 	}
3627 
3628 	/* now we proceed to zero successive digits
3629 	* from the least significant upwards
3630 	*/
3631 	for (ix = 0; ix < n->used; ix++) {
3632 		/* mu = ai * m' mod b
3633 		*
3634 		* We avoid a double precision multiplication (which isn't required)
3635 		* by casting the value down to a mp_digit.  Note this requires
3636 		* that W[ix-1] have  the carry cleared (see after the inner loop)
3637 		*/
3638 		mp_digit mu;
3639 		mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
3640 
3641 		/* a = a + mu * m * b**i
3642 		*
3643 		* This is computed in place and on the fly.  The multiplication
3644 		* by b**i is handled by offseting which columns the results
3645 		* are added to.
3646 		*
3647 		* Note the comba method normally doesn't handle carries in the
3648 		* inner loop In this case we fix the carry from the previous
3649 		* column since the Montgomery reduction requires digits of the
3650 		* result (so far) [see above] to work.  This is
3651 		* handled by fixing up one carry after the inner loop.  The
3652 		* carry fixups are done in order so after these loops the
3653 		* first m->used words of W[] have the carries fixed
3654 		*/
3655 		{
3656 			int iy;
3657 			mp_digit *tmpn;
3658 			mp_word *_W;
3659 
3660 			/* alias for the digits of the modulus */
3661 			tmpn = n->dp;
3662 
3663 			/* Alias for the columns set by an offset of ix */
3664 			_W = W + ix;
3665 
3666 			/* inner loop */
3667 			for (iy = 0; iy < n->used; iy++) {
3668 				  *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
3669 			}
3670 		}
3671 
3672 		/* now fix carry for next digit, W[ix+1] */
3673 		W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
3674 	}
3675 
3676 	/* now we have to propagate the carries and
3677 	* shift the words downward [all those least
3678 	* significant digits we zeroed].
3679 	*/
3680 	{
3681 		mp_digit *tmpx;
3682 		mp_word *_W, *_W1;
3683 
3684 		/* nox fix rest of carries */
3685 
3686 		/* alias for current word */
3687 		_W1 = W + ix;
3688 
3689 		/* alias for next word, where the carry goes */
3690 		_W = W + ++ix;
3691 
3692 		for (; ix <= n->used * 2 + 1; ix++) {
3693 			*_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
3694 		}
3695 
3696 		/* copy out, A = A/b**n
3697 		*
3698 		* The result is A/b**n but instead of converting from an
3699 		* array of mp_word to mp_digit than calling rshift_digits
3700 		* we just copy them in the right order
3701 		*/
3702 
3703 		/* alias for destination word */
3704 		tmpx = x->dp;
3705 
3706 		/* alias for shifted double precision result */
3707 		_W = W + n->used;
3708 
3709 		for (ix = 0; ix < n->used + 1; ix++) {
3710 			*tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
3711 		}
3712 
3713 		/* zero oldused digits, if the input a was larger than
3714 		* m->used+1 we'll have to clear the digits
3715 		*/
3716 		for (; ix < olduse; ix++) {
3717 			*tmpx++ = 0;
3718 		}
3719 	}
3720 
3721 	/* set the max used and clamp */
3722 	x->used = n->used + 1;
3723 	trim_unused_digits(x);
3724 
3725 	/* if A >= m then A = A - m */
3726 	if (compare_magnitude(x, n) != MP_LT) {
3727 		return basic_subtract(x, n, x);
3728 	}
3729 	return MP_OKAY;
3730 }
3731 
3732 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_mp_montgomery_reduce.c,v $ */
3733 /* Revision: 1.2 $ */
3734 /* Date: 2011/03/18 16:22:09 $ */
3735 
3736 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
3737 static int
mp_montgomery_reduce(mp_int * x,mp_int * n,mp_digit rho)3738 mp_montgomery_reduce(mp_int * x, mp_int * n, mp_digit rho)
3739 {
3740 	int     ix, res, digs;
3741 	mp_digit mu;
3742 
3743 	/* can the fast reduction [comba] method be used?
3744 	*
3745 	* Note that unlike in mul you're safely allowed *less*
3746 	* than the available columns [255 per default] since carries
3747 	* are fixed up in the inner loop.
3748 	*/
3749 	digs = n->used * 2 + 1;
3750 	if (can_use_fast_column_array(digs, n->used)) {
3751 		return fast_mp_montgomery_reduce(x, n, rho);
3752 	}
3753 
3754 	/* grow the input as required */
3755 	if (x->alloc < digs) {
3756 		if ((res = mp_grow(x, digs)) != MP_OKAY) {
3757 			return res;
3758 		}
3759 	}
3760 	x->used = digs;
3761 
3762 	for (ix = 0; ix < n->used; ix++) {
3763 		/* mu = ai * rho mod b
3764 		*
3765 		* The value of rho must be precalculated via
3766 		* montgomery_setup() such that
3767 		* it equals -1/n0 mod b this allows the
3768 		* following inner loop to reduce the
3769 		* input one digit at a time
3770 		*/
3771 		mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
3772 
3773 		/* a = a + mu * m * b**i */
3774 		{
3775 			int iy;
3776 			mp_digit *tmpn, *tmpx, carry;
3777 			mp_word r;
3778 
3779 			/* alias for digits of the modulus */
3780 			tmpn = n->dp;
3781 
3782 			/* alias for the digits of x [the input] */
3783 			tmpx = x->dp + ix;
3784 
3785 			/* set the carry to zero */
3786 			carry = 0;
3787 
3788 			/* Multiply and add in place */
3789 			for (iy = 0; iy < n->used; iy++) {
3790 				/* compute product and sum */
3791 				r = ((mp_word)mu) * ((mp_word)*tmpn++) +
3792 					  ((mp_word) carry) + ((mp_word) * tmpx);
3793 
3794 				/* get carry */
3795 				carry = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3796 
3797 				/* fix digit */
3798 				*tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
3799 			}
3800 			/* At this point the ix'th digit of x should be zero */
3801 
3802 
3803 			/* propagate carries upwards as required*/
3804 			while (carry) {
3805 				*tmpx += carry;
3806 				carry = *tmpx >> DIGIT_BIT;
3807 				*tmpx++ &= MP_MASK;
3808 			}
3809 		}
3810 	}
3811 
3812 	/* at this point the n.used'th least
3813 	* significant digits of x are all zero
3814 	* which means we can shift x to the
3815 	* right by n.used digits and the
3816 	* residue is unchanged.
3817 	*/
3818 
3819 	/* x = x/b**n.used */
3820 	trim_unused_digits(x);
3821 	rshift_digits(x, n->used);
3822 
3823 	/* if x >= n then x = x - n */
3824 	if (compare_magnitude(x, n) != MP_LT) {
3825 		return basic_subtract(x, n, x);
3826 	}
3827 
3828 	return MP_OKAY;
3829 }
3830 
3831 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_reduce.c,v $ */
3832 /* Revision: 1.3 $ */
3833 /* Date: 2011/03/18 16:43:04 $ */
3834 
3835 /* determines the setup value */
3836 static void
diminished_radix_setup(mp_int * a,mp_digit * d)3837 diminished_radix_setup(mp_int *a, mp_digit *d)
3838 {
3839 	/* the casts are required if DIGIT_BIT is one less than
3840 	* the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
3841 	*/
3842 	*d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
3843 		((mp_word)a->dp[0]));
3844 }
3845 
3846 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_setup.c,v $ */
3847 /* Revision: 1.1.1.1 $ */
3848 /* Date: 2011/03/12 22:58:18 $ */
3849 
3850 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
3851  *
3852  * Based on algorithm from the paper
3853  *
3854  * "Generating Efficient Primes for Discrete Log Cryptosystems"
3855  *                 Chae Hoon Lim, Pil Joong Lee,
3856  *          POSTECH Information Research Laboratories
3857  *
3858  * The modulus must be of a special format [see manual]
3859  *
3860  * Has been modified to use algorithm 7.10 from the LTM book instead
3861  *
3862  * Input x must be in the range 0 <= x <= (n-1)**2
3863  */
3864 static int
diminished_radix_reduce(mp_int * x,mp_int * n,mp_digit k)3865 diminished_radix_reduce(mp_int * x, mp_int * n, mp_digit k)
3866 {
3867 	int      err, i, m;
3868 	mp_word  r;
3869 	mp_digit mu, *tmpx1, *tmpx2;
3870 
3871 	/* m = digits in modulus */
3872 	m = n->used;
3873 
3874 	/* ensure that "x" has at least 2m digits */
3875 	if (x->alloc < m + m) {
3876 		if ((err = mp_grow(x, m + m)) != MP_OKAY) {
3877 			return err;
3878 		}
3879 	}
3880 
3881 	/* top of loop, this is where the code resumes if
3882 	* another reduction pass is required.
3883 	*/
3884 top:
3885 	/* aliases for digits */
3886 	/* alias for lower half of x */
3887 	tmpx1 = x->dp;
3888 
3889 	/* alias for upper half of x, or x/B**m */
3890 	tmpx2 = x->dp + m;
3891 
3892 	/* set carry to zero */
3893 	mu = 0;
3894 
3895 	/* compute (x mod B**m) + k * [x/B**m] inline and inplace */
3896 	for (i = 0; i < m; i++) {
3897 		r = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
3898 		*tmpx1++  = (mp_digit)(r & MP_MASK);
3899 		mu = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
3900 	}
3901 
3902 	/* set final carry */
3903 	*tmpx1++ = mu;
3904 
3905 	/* zero words above m */
3906 	for (i = m + 1; i < x->used; i++) {
3907 		*tmpx1++ = 0;
3908 	}
3909 
3910 	/* clamp, sub and return */
3911 	trim_unused_digits(x);
3912 
3913 	/* if x >= n then subtract and reduce again
3914 	* Each successive "recursion" makes the input smaller and smaller.
3915 	*/
3916 	if (compare_magnitude(x, n) != MP_LT) {
3917 		basic_subtract(x, n, x);
3918 		goto top;
3919 	}
3920 	return MP_OKAY;
3921 }
3922 
3923 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_dr_reduce.c,v $ */
3924 /* Revision: 1.1.1.1 $ */
3925 /* Date: 2011/03/12 22:58:18 $ */
3926 
3927 /* determines the setup value */
3928 static int
mp_reduce_2k_setup(mp_int * a,mp_digit * d)3929 mp_reduce_2k_setup(mp_int *a, mp_digit *d)
3930 {
3931 	int res, p;
3932 	mp_int tmp;
3933 
3934 	if ((res = mp_init(&tmp)) != MP_OKAY) {
3935 		return res;
3936 	}
3937 
3938 	p = mp_count_bits(a);
3939 	if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3940 		mp_clear(&tmp);
3941 		return res;
3942 	}
3943 
3944 	if ((res = basic_subtract(&tmp, a, &tmp)) != MP_OKAY) {
3945 		mp_clear(&tmp);
3946 		return res;
3947 	}
3948 
3949 	*d = tmp.dp[0];
3950 	mp_clear(&tmp);
3951 	return MP_OKAY;
3952 }
3953 
3954 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k_setup.c,v $ */
3955 /* Revision: 1.1.1.1 $ */
3956 /* Date: 2011/03/12 22:58:18 $ */
3957 
3958 /* reduces a modulo n where n is of the form 2**p - d */
3959 static int
mp_reduce_2k(mp_int * a,mp_int * n,mp_digit d)3960 mp_reduce_2k(mp_int *a, mp_int *n, mp_digit d)
3961 {
3962 	mp_int q;
3963 	int    p, res;
3964 
3965 	if ((res = mp_init(&q)) != MP_OKAY) {
3966 		return res;
3967 	}
3968 
3969 	p = mp_count_bits(n);
3970 top:
3971 	/* q = a/2**p, a = a mod 2**p */
3972 	if ((res = rshift_bits(a, p, &q, a)) != MP_OKAY) {
3973 		goto ERR;
3974 	}
3975 
3976 	if (d != 1) {
3977 		/* q = q * d */
3978 		if ((res = multiply_digit(&q, d, &q)) != MP_OKAY) {
3979 			 goto ERR;
3980 		}
3981 	}
3982 
3983 	/* a = a + q */
3984 	if ((res = basic_add(a, &q, a)) != MP_OKAY) {
3985 		goto ERR;
3986 	}
3987 
3988 	if (compare_magnitude(a, n) != MP_LT) {
3989 		basic_subtract(a, n, a);
3990 		goto top;
3991 	}
3992 
3993 ERR:
3994 	mp_clear(&q);
3995 	return res;
3996 }
3997 
3998 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_reduce_2k.c,v $ */
3999 /* Revision: 1.1.1.1 $ */
4000 /* Date: 2011/03/12 22:58:18 $ */
4001 
4002 /*
4003  * shifts with subtractions when the result is greater than b.
4004  *
4005  * The method is slightly modified to shift B unconditionally upto just under
4006  * the leading bit of b.  This saves alot of multiple precision shifting.
4007  */
4008 static int
mp_montgomery_calc_normalization(mp_int * a,mp_int * b)4009 mp_montgomery_calc_normalization(mp_int * a, mp_int * b)
4010 {
4011 	int     x, bits, res;
4012 
4013 	/* how many bits of last digit does b use */
4014 	bits = mp_count_bits(b) % DIGIT_BIT;
4015 
4016 	if (b->used > 1) {
4017 		if ((res = mp_2expt(a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
4018 			return res;
4019 		}
4020 	} else {
4021 		set_word(a, 1);
4022 		bits = 1;
4023 	}
4024 
4025 
4026 	/* now compute C = A * B mod b */
4027 	for (x = bits - 1; x < (int)DIGIT_BIT; x++) {
4028 		if ((res = doubled(a, a)) != MP_OKAY) {
4029 			return res;
4030 		}
4031 		if (compare_magnitude(a, b) != MP_LT) {
4032 			if ((res = basic_subtract(a, b, a)) != MP_OKAY) {
4033 				return res;
4034 			}
4035 		}
4036 	}
4037 
4038 	return MP_OKAY;
4039 }
4040 
4041 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_mp_montgomery_calc_normalization.c,v $ */
4042 /* Revision: 1.1.1.1 $ */
4043 /* Date: 2011/03/12 22:58:18 $ */
4044 
4045 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
4046  *
4047  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
4048  * The value of k changes based on the size of the exponent.
4049  *
4050  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
4051  */
4052 
4053 #define TAB_SIZE 256
4054 
4055 static int
fast_exponent_modulo(mp_int * G,mp_int * X,mp_int * P,mp_int * Y,int redmode)4056 fast_exponent_modulo(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
4057 {
4058 	mp_int  M[TAB_SIZE], res;
4059 	mp_digit buf, mp;
4060 	int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
4061 
4062 	/* use a pointer to the reduction algorithm.  This allows us to use
4063 	* one of many reduction algorithms without modding the guts of
4064 	* the code with if statements everywhere.
4065 	*/
4066 	int     (*redux)(mp_int*,mp_int*,mp_digit);
4067 
4068 #if defined(__minix)
4069 	mp = 0; /* LSC: Fix -Os compilation: -Werror=maybe-uninitialized */
4070 #endif /* defined(__minix) */
4071 	winsize = find_window_size(X);
4072 
4073 	/* init M array */
4074 	/* init first cell */
4075 	if ((err = mp_init(&M[1])) != MP_OKAY) {
4076 		return err;
4077 	}
4078 
4079 	/* now init the second half of the array */
4080 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4081 		if ((err = mp_init(&M[x])) != MP_OKAY) {
4082 			for (y = 1<<(winsize-1); y < x; y++) {
4083 				mp_clear(&M[y]);
4084 			}
4085 			mp_clear(&M[1]);
4086 			return err;
4087 		}
4088 	}
4089 
4090 	/* determine and setup reduction code */
4091 	if (redmode == 0) {
4092 		/* now setup montgomery  */
4093 		if ((err = mp_montgomery_setup(P, &mp)) != MP_OKAY) {
4094 			goto LBL_M;
4095 		}
4096 
4097 		/* automatically pick the comba one if available (saves quite a few calls/ifs) */
4098 		if (can_use_fast_column_array(P->used + P->used + 1, P->used)) {
4099 			redux = fast_mp_montgomery_reduce;
4100 		} else {
4101 			/* use slower baseline Montgomery method */
4102 			redux = mp_montgomery_reduce;
4103 		}
4104 	} else if (redmode == 1) {
4105 		/* setup DR reduction for moduli of the form B**k - b */
4106 		diminished_radix_setup(P, &mp);
4107 		redux = diminished_radix_reduce;
4108 	} else {
4109 		/* setup DR reduction for moduli of the form 2**k - b */
4110 		if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
4111 			goto LBL_M;
4112 		}
4113 		redux = mp_reduce_2k;
4114 	}
4115 
4116 	/* setup result */
4117 	if ((err = mp_init(&res)) != MP_OKAY) {
4118 		goto LBL_M;
4119 	}
4120 
4121 	/* create M table
4122 	*
4123 
4124 	*
4125 	* The first half of the table is not computed though accept for M[0] and M[1]
4126 	*/
4127 
4128 	if (redmode == 0) {
4129 		/* now we need R mod m */
4130 		if ((err = mp_montgomery_calc_normalization(&res, P)) != MP_OKAY) {
4131 			goto LBL_RES;
4132 		}
4133 
4134 		/* now set M[1] to G * R mod m */
4135 		if ((err = multiply_modulo(&M[1], G, &res, P)) != MP_OKAY) {
4136 			goto LBL_RES;
4137 		}
4138 	} else {
4139 		set_word(&res, 1);
4140 		if ((err = modulo(G, P, &M[1])) != MP_OKAY) {
4141 			goto LBL_RES;
4142 		}
4143 	}
4144 
4145 	/* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
4146 	if ((err = mp_copy( &M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4147 		goto LBL_RES;
4148 	}
4149 
4150 	for (x = 0; x < (winsize - 1); x++) {
4151 		if ((err = square(&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
4152 			goto LBL_RES;
4153 		}
4154 		if ((err = (*redux)(&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
4155 			goto LBL_RES;
4156 		}
4157 	}
4158 
4159 	/* create upper table */
4160 	for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4161 		if ((err = signed_multiply(&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4162 			goto LBL_RES;
4163 		}
4164 		if ((err = (*redux)(&M[x], P, mp)) != MP_OKAY) {
4165 			goto LBL_RES;
4166 		}
4167 	}
4168 
4169 	/* set initial mode and bit cnt */
4170 	mode = 0;
4171 	bitcnt = 1;
4172 	buf = 0;
4173 	digidx = X->used - 1;
4174 	bitcpy = 0;
4175 	bitbuf = 0;
4176 
4177 	for (;;) {
4178 		/* grab next digit as required */
4179 		if (--bitcnt == 0) {
4180 			/* if digidx == -1 we are out of digits so break */
4181 			if (digidx == -1) {
4182 				break;
4183 			}
4184 			/* read next digit and reset bitcnt */
4185 			buf = X->dp[digidx--];
4186 			bitcnt = (int)DIGIT_BIT;
4187 		}
4188 
4189 		/* grab the next msb from the exponent */
4190 		y = (int)(mp_digit)((mp_digit)buf >> (unsigned)(DIGIT_BIT - 1)) & 1;
4191 		buf <<= (mp_digit)1;
4192 
4193 		/* if the bit is zero and mode == 0 then we ignore it
4194 		* These represent the leading zero bits before the first 1 bit
4195 		* in the exponent.  Technically this opt is not required but it
4196 		* does lower the # of trivial squaring/reductions used
4197 		*/
4198 		if (mode == 0 && y == 0) {
4199 			continue;
4200 		}
4201 
4202 		/* if the bit is zero and mode == 1 then we square */
4203 		if (mode == 1 && y == 0) {
4204 			if ((err = square(&res, &res)) != MP_OKAY) {
4205 				goto LBL_RES;
4206 			}
4207 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4208 				goto LBL_RES;
4209 			}
4210 			continue;
4211 		}
4212 
4213 		/* else we add it to the window */
4214 		bitbuf |= (y << (winsize - ++bitcpy));
4215 		mode = 2;
4216 
4217 		if (bitcpy == winsize) {
4218 			/* ok window is filled so square as required and multiply  */
4219 			/* square first */
4220 			for (x = 0; x < winsize; x++) {
4221 				if ((err = square(&res, &res)) != MP_OKAY) {
4222 					goto LBL_RES;
4223 				}
4224 				if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4225 					goto LBL_RES;
4226 				}
4227 			}
4228 
4229 			/* then multiply */
4230 			if ((err = signed_multiply(&res, &M[bitbuf], &res)) != MP_OKAY) {
4231 				goto LBL_RES;
4232 			}
4233 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4234 				goto LBL_RES;
4235 			}
4236 
4237 			/* empty window and reset */
4238 			bitcpy = 0;
4239 			bitbuf = 0;
4240 			mode = 1;
4241 		}
4242 	}
4243 
4244 	/* if bits remain then square/multiply */
4245 	if (mode == 2 && bitcpy > 0) {
4246 		/* square then multiply if the bit is set */
4247 		for (x = 0; x < bitcpy; x++) {
4248 			if ((err = square(&res, &res)) != MP_OKAY) {
4249 				goto LBL_RES;
4250 			}
4251 			if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4252 				goto LBL_RES;
4253 			}
4254 
4255 			/* get next bit of the window */
4256 			bitbuf <<= 1;
4257 			if ((bitbuf & (1 << winsize)) != 0) {
4258 				/* then multiply */
4259 				if ((err = signed_multiply(&res, &M[1], &res)) != MP_OKAY) {
4260 					goto LBL_RES;
4261 				}
4262 				if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4263 					goto LBL_RES;
4264 				}
4265 			}
4266 		}
4267 	}
4268 
4269 	if (redmode == 0) {
4270 		/* fixup result if Montgomery reduction is used
4271 		* recall that any value in a Montgomery system is
4272 		* actually multiplied by R mod n.  So we have
4273 		* to reduce one more time to cancel out the factor
4274 		* of R.
4275 		*/
4276 		if ((err = (*redux)(&res, P, mp)) != MP_OKAY) {
4277 			goto LBL_RES;
4278 		}
4279 	}
4280 
4281 	/* swap res with Y */
4282 	mp_exch(&res, Y);
4283 	err = MP_OKAY;
4284 LBL_RES:
4285 	mp_clear(&res);
4286 LBL_M:
4287 	mp_clear(&M[1]);
4288 	for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4289 		mp_clear(&M[x]);
4290 	}
4291 	return err;
4292 }
4293 
4294 /* Source: /usr/cvsroot/libtommath/dist/libtommath/bn_fast_exponent_modulo.c,v $ */
4295 /* Revision: 1.4 $ */
4296 /* Date: 2011/03/18 16:43:04 $ */
4297 
4298 /* this is a shell function that calls either the normal or Montgomery
4299  * exptmod functions.  Originally the call to the montgomery code was
4300  * embedded in the normal function but that wasted alot of stack space
4301  * for nothing (since 99% of the time the Montgomery code would be called)
4302  */
4303 static int
exponent_modulo(mp_int * G,mp_int * X,mp_int * P,mp_int * Y)4304 exponent_modulo(mp_int * G, mp_int * X, mp_int * P, mp_int *Y)
4305 {
4306 	int diminished_radix;
4307 
4308 	/* modulus P must be positive */
4309 	if (P->sign == MP_NEG) {
4310 		return MP_VAL;
4311 	}
4312 
4313 	/* if exponent X is negative we have to recurse */
4314 	if (X->sign == MP_NEG) {
4315 		mp_int tmpG, tmpX;
4316 		int err;
4317 
4318 		/* first compute 1/G mod P */
4319 		if ((err = mp_init(&tmpG)) != MP_OKAY) {
4320 			return err;
4321 		}
4322 		if ((err = modular_inverse(&tmpG, G, P)) != MP_OKAY) {
4323 			mp_clear(&tmpG);
4324 			return err;
4325 		}
4326 
4327 		/* now get |X| */
4328 		if ((err = mp_init(&tmpX)) != MP_OKAY) {
4329 			mp_clear(&tmpG);
4330 			return err;
4331 		}
4332 		if ((err = absolute(X, &tmpX)) != MP_OKAY) {
4333 			mp_clear_multi(&tmpG, &tmpX, NULL);
4334 			return err;
4335 		}
4336 
4337 		/* and now compute (1/G)**|X| instead of G**X [X < 0] */
4338 		err = exponent_modulo(&tmpG, &tmpX, P, Y);
4339 		mp_clear_multi(&tmpG, &tmpX, NULL);
4340 		return err;
4341 	}
4342 
4343 	/* modified diminished radix reduction */
4344 	if (mp_reduce_is_2k_l(P) == MP_YES) {
4345 		return basic_exponent_mod(G, X, P, Y, 1);
4346 	}
4347 
4348 	/* is it a DR modulus? */
4349 	diminished_radix = is_diminished_radix_modulus(P);
4350 
4351 	/* if not, is it a unrestricted DR modulus? */
4352 	if (!diminished_radix) {
4353 		diminished_radix = mp_reduce_is_2k(P) << 1;
4354 	}
4355 
4356 	/* if the modulus is odd or diminished_radix, use the montgomery method */
4357 	if (BN_is_odd(P) == 1 || diminished_radix) {
4358 		return fast_exponent_modulo(G, X, P, Y, diminished_radix);
4359 	}
4360 	/* otherwise use the generic Barrett reduction technique */
4361 	return basic_exponent_mod(G, X, P, Y, 0);
4362 }
4363 
/* Reverse the first len bytes of s in place (used by the radix code).
 * Lengths below 2 leave s untouched. */
static void
bn_reverse(unsigned char *s, int len)
{
	unsigned char *front, *back;
	uint8_t        swap;

	if (len < 2) {
		return;
	}
	for (front = s, back = s + len - 1; front < back; front++, back--) {
		swap = *front;
		*front = *back;
		*back = swap;
	}
}
4377 
4378 static inline int
is_power_of_two(mp_digit b,int * p)4379 is_power_of_two(mp_digit b, int *p)
4380 {
4381 	int x;
4382 
4383 	/* fast return if no power of two */
4384 	if ((b==0) || (b & (b-1))) {
4385 		return 0;
4386 	}
4387 
4388 	for (x = 0; x < DIGIT_BIT; x++) {
4389 		if (b == (((mp_digit)1)<<x)) {
4390 			*p = x;
4391 			return 1;
4392 		}
4393 	}
4394 	return 0;
4395 }
4396 
4397 /* single digit division (based on routine from MPI) */
4398 static int
signed_divide_word(mp_int * a,mp_digit b,mp_int * c,mp_digit * d)4399 signed_divide_word(mp_int *a, mp_digit b, mp_int *c, mp_digit *d)
4400 {
4401 	mp_int  q;
4402 	mp_word w;
4403 	mp_digit t;
4404 	int     res, ix;
4405 
4406 	/* cannot divide by zero */
4407 	if (b == 0) {
4408 		return MP_VAL;
4409 	}
4410 
4411 	/* quick outs */
4412 	if (b == 1 || MP_ISZERO(a) == 1) {
4413 		if (d != NULL) {
4414 			*d = 0;
4415 		}
4416 		if (c != NULL) {
4417 			return mp_copy(a, c);
4418 		}
4419 		return MP_OKAY;
4420 	}
4421 
4422 	/* power of two ? */
4423 	if (is_power_of_two(b, &ix) == 1) {
4424 		if (d != NULL) {
4425 			*d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
4426 		}
4427 		if (c != NULL) {
4428 			return rshift_bits(a, ix, c, NULL);
4429 		}
4430 		return MP_OKAY;
4431 	}
4432 
4433 	/* three? */
4434 	if (b == 3) {
4435 		return third(a, c, d);
4436 	}
4437 
4438 	/* no easy answer [c'est la vie].  Just division */
4439 	if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
4440 		return res;
4441 	}
4442 
4443 	q.used = a->used;
4444 	q.sign = a->sign;
4445 	w = 0;
4446 	for (ix = a->used - 1; ix >= 0; ix--) {
4447 		w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
4448 
4449 		if (w >= b) {
4450 			t = (mp_digit)(w / b);
4451 			w -= ((mp_word)t) * ((mp_word)b);
4452 		} else {
4453 			t = 0;
4454 		}
4455 		q.dp[ix] = (mp_digit)t;
4456 	}
4457 
4458 	if (d != NULL) {
4459 		*d = (mp_digit)w;
4460 	}
4461 
4462 	if (c != NULL) {
4463 		trim_unused_digits(&q);
4464 		mp_exch(&q, c);
4465 	}
4466 	mp_clear(&q);
4467 
4468 	return res;
4469 }
4470 
4471 static const mp_digit ltm_prime_tab[] = {
4472 	0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
4473 	0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
4474 	0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
4475 	0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F,
4476 #ifndef MP_8BIT
4477 	0x0083,
4478 	0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
4479 	0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
4480 	0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
4481 	0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
4482 
4483 	0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
4484 	0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
4485 	0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
4486 	0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
4487 	0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
4488 	0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
4489 	0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
4490 	0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
4491 
4492 	0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
4493 	0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
4494 	0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
4495 	0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
4496 	0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
4497 	0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
4498 	0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
4499 	0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
4500 
4501 	0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
4502 	0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
4503 	0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
4504 	0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
4505 	0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
4506 	0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
4507 	0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
4508 	0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
4509 #endif
4510 };
4511 
4512 #define PRIME_SIZE	__arraycount(ltm_prime_tab)
4513 
4514 static inline int
mp_prime_is_divisible(mp_int * a,int * result)4515 mp_prime_is_divisible(mp_int *a, int *result)
4516 {
4517 	int     err, ix;
4518 	mp_digit res;
4519 
4520 	/* default to not */
4521 	*result = MP_NO;
4522 
4523 	for (ix = 0; ix < (int)PRIME_SIZE; ix++) {
4524 		/* what is a mod LBL_prime_tab[ix] */
4525 		if ((err = signed_divide_word(a, ltm_prime_tab[ix], NULL, &res)) != MP_OKAY) {
4526 			return err;
4527 		}
4528 
4529 		/* is the residue zero? */
4530 		if (res == 0) {
4531 			*result = MP_YES;
4532 			return MP_OKAY;
4533 		}
4534 	}
4535 
4536 	return MP_OKAY;
4537 }
4538 
4539 /* single digit addition */
4540 static int
add_single_digit(mp_int * a,mp_digit b,mp_int * c)4541 add_single_digit(mp_int *a, mp_digit b, mp_int *c)
4542 {
4543 	int     res, ix, oldused;
4544 	mp_digit *tmpa, *tmpc, mu;
4545 
4546 	/* grow c as required */
4547 	if (c->alloc < a->used + 1) {
4548 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4549 			return res;
4550 		}
4551 	}
4552 
4553 	/* if a is negative and |a| >= b, call c = |a| - b */
4554 	if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
4555 		/* temporarily fix sign of a */
4556 		a->sign = MP_ZPOS;
4557 
4558 		/* c = |a| - b */
4559 		res = signed_subtract_word(a, b, c);
4560 
4561 		/* fix sign  */
4562 		a->sign = c->sign = MP_NEG;
4563 
4564 		/* clamp */
4565 		trim_unused_digits(c);
4566 
4567 		return res;
4568 	}
4569 
4570 	/* old number of used digits in c */
4571 	oldused = c->used;
4572 
4573 	/* sign always positive */
4574 	c->sign = MP_ZPOS;
4575 
4576 	/* source alias */
4577 	tmpa = a->dp;
4578 
4579 	/* destination alias */
4580 	tmpc = c->dp;
4581 
4582 	/* if a is positive */
4583 	if (a->sign == MP_ZPOS) {
4584 		/* add digit, after this we're propagating
4585 		* the carry.
4586 		*/
4587 		*tmpc = *tmpa++ + b;
4588 		mu = *tmpc >> DIGIT_BIT;
4589 		*tmpc++ &= MP_MASK;
4590 
4591 		/* now handle rest of the digits */
4592 		for (ix = 1; ix < a->used; ix++) {
4593 			*tmpc = *tmpa++ + mu;
4594 			mu = *tmpc >> DIGIT_BIT;
4595 			*tmpc++ &= MP_MASK;
4596 		}
4597 		/* set final carry */
4598 		ix++;
4599 		*tmpc++  = mu;
4600 
4601 		/* setup size */
4602 		c->used = a->used + 1;
4603 	} else {
4604 		/* a was negative and |a| < b */
4605 		c->used  = 1;
4606 
4607 		/* the result is a single digit */
4608 		if (a->used == 1) {
4609 			*tmpc++  =  b - a->dp[0];
4610 		} else {
4611 			*tmpc++  =  b;
4612 		}
4613 
4614 		/* setup count so the clearing of oldused
4615 		* can fall through correctly
4616 		*/
4617 		ix = 1;
4618 	}
4619 
4620 	/* now zero to oldused */
4621 	while (ix++ < oldused) {
4622 		*tmpc++ = 0;
4623 	}
4624 	trim_unused_digits(c);
4625 
4626 	return MP_OKAY;
4627 }
4628 
4629 /* single digit subtraction */
4630 static int
signed_subtract_word(mp_int * a,mp_digit b,mp_int * c)4631 signed_subtract_word(mp_int *a, mp_digit b, mp_int *c)
4632 {
4633 	mp_digit *tmpa, *tmpc, mu;
4634 	int       res, ix, oldused;
4635 
4636 	/* grow c as required */
4637 	if (c->alloc < a->used + 1) {
4638 		if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4639 			return res;
4640 		}
4641 	}
4642 
4643 	/* if a is negative just do an unsigned
4644 	* addition [with fudged signs]
4645 	*/
4646 	if (a->sign == MP_NEG) {
4647 		a->sign = MP_ZPOS;
4648 		res = add_single_digit(a, b, c);
4649 		a->sign = c->sign = MP_NEG;
4650 
4651 		/* clamp */
4652 		trim_unused_digits(c);
4653 
4654 		return res;
4655 	}
4656 
4657 	/* setup regs */
4658 	oldused = c->used;
4659 	tmpa = a->dp;
4660 	tmpc = c->dp;
4661 
4662 	/* if a <= b simply fix the single digit */
4663 	if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
4664 		if (a->used == 1) {
4665 			*tmpc++ = b - *tmpa;
4666 		} else {
4667 			*tmpc++ = b;
4668 		}
4669 		ix = 1;
4670 
4671 		/* negative/1digit */
4672 		c->sign = MP_NEG;
4673 		c->used = 1;
4674 	} else {
4675 		/* positive/size */
4676 		c->sign = MP_ZPOS;
4677 		c->used = a->used;
4678 
4679 		/* subtract first digit */
4680 		*tmpc = *tmpa++ - b;
4681 		mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4682 		*tmpc++ &= MP_MASK;
4683 
4684 		/* handle rest of the digits */
4685 		for (ix = 1; ix < a->used; ix++) {
4686 			*tmpc = *tmpa++ - mu;
4687 			mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4688 			*tmpc++ &= MP_MASK;
4689 		}
4690 	}
4691 
4692 	/* zero excess digits */
4693 	while (ix++ < oldused) {
4694 		*tmpc++ = 0;
4695 	}
4696 	trim_unused_digits(c);
4697 	return MP_OKAY;
4698 }
4699 
4700 static const int lnz[16] = {
4701 	4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
4702 };
4703 
4704 /* Counts the number of lsbs which are zero before the first zero bit */
4705 static int
mp_cnt_lsb(mp_int * a)4706 mp_cnt_lsb(mp_int *a)
4707 {
4708 	int x;
4709 	mp_digit q, qq;
4710 
4711 	/* easy out */
4712 	if (MP_ISZERO(a) == 1) {
4713 		return 0;
4714 	}
4715 
4716 	/* scan lower digits until non-zero */
4717 	for (x = 0; x < a->used && a->dp[x] == 0; x++) {
4718 	}
4719 	q = a->dp[x];
4720 	x *= DIGIT_BIT;
4721 
4722 	/* now scan this digit until a 1 is found */
4723 	if ((q & 1) == 0) {
4724 		do {
4725 			 qq  = q & 15;
4726 			 /* LINTED previous op ensures range of qq */
4727 			 x  += lnz[qq];
4728 			 q >>= 4;
4729 		} while (qq == 0);
4730 	}
4731 	return x;
4732 }
4733 
4734 /* c = a * a (mod b) */
4735 static int
square_modulo(mp_int * a,mp_int * b,mp_int * c)4736 square_modulo(mp_int *a, mp_int *b, mp_int *c)
4737 {
4738 	int     res;
4739 	mp_int  t;
4740 
4741 	if ((res = mp_init(&t)) != MP_OKAY) {
4742 		return res;
4743 	}
4744 
4745 	if ((res = square(a, &t)) != MP_OKAY) {
4746 		mp_clear(&t);
4747 		return res;
4748 	}
4749 	res = modulo(&t, b, c);
4750 	mp_clear(&t);
4751 	return res;
4752 }
4753 
4754 static int
mp_prime_miller_rabin(mp_int * a,mp_int * b,int * result)4755 mp_prime_miller_rabin(mp_int *a, mp_int *b, int *result)
4756 {
4757 	mp_int  n1, y, r;
4758 	int     s, j, err;
4759 
4760 	/* default */
4761 	*result = MP_NO;
4762 
4763 	/* ensure b > 1 */
4764 	if (compare_digit(b, 1) != MP_GT) {
4765 		return MP_VAL;
4766 	}
4767 
4768 	/* get n1 = a - 1 */
4769 	if ((err = mp_init_copy(&n1, a)) != MP_OKAY) {
4770 		return err;
4771 	}
4772 	if ((err = signed_subtract_word(&n1, 1, &n1)) != MP_OKAY) {
4773 		goto LBL_N1;
4774 	}
4775 
4776 	/* set 2**s * r = n1 */
4777 	if ((err = mp_init_copy(&r, &n1)) != MP_OKAY) {
4778 		goto LBL_N1;
4779 	}
4780 
4781 	/* count the number of least significant bits
4782 	* which are zero
4783 	*/
4784 	s = mp_cnt_lsb(&r);
4785 
4786 	/* now divide n - 1 by 2**s */
4787 	if ((err = rshift_bits(&r, s, &r, NULL)) != MP_OKAY) {
4788 		goto LBL_R;
4789 	}
4790 
4791 	/* compute y = b**r mod a */
4792 	if ((err = mp_init(&y)) != MP_OKAY) {
4793 		goto LBL_R;
4794 	}
4795 	if ((err = exponent_modulo(b, &r, a, &y)) != MP_OKAY) {
4796 		goto LBL_Y;
4797 	}
4798 
4799 	/* if y != 1 and y != n1 do */
4800 	if (compare_digit(&y, 1) != MP_EQ && signed_compare(&y, &n1) != MP_EQ) {
4801 		j = 1;
4802 		/* while j <= s-1 and y != n1 */
4803 		while ((j <= (s - 1)) && signed_compare(&y, &n1) != MP_EQ) {
4804 			if ((err = square_modulo(&y, a, &y)) != MP_OKAY) {
4805 				goto LBL_Y;
4806 			}
4807 
4808 			/* if y == 1 then composite */
4809 			if (compare_digit(&y, 1) == MP_EQ) {
4810 				goto LBL_Y;
4811 			}
4812 
4813 			++j;
4814 		}
4815 
4816 		/* if y != n1 then composite */
4817 		if (signed_compare(&y, &n1) != MP_EQ) {
4818 			goto LBL_Y;
4819 		}
4820 	}
4821 
4822 	/* probably prime now */
4823 	*result = MP_YES;
4824 LBL_Y:
4825 	mp_clear(&y);
4826 LBL_R:
4827 	mp_clear(&r);
4828 LBL_N1:
4829 	mp_clear(&n1);
4830 	return err;
4831 }
4832 
4833 /* performs a variable number of rounds of Miller-Rabin
4834  *
4835  * Probability of error after t rounds is no more than
4836 
4837  *
4838  * Sets result to 1 if probably prime, 0 otherwise
4839  */
4840 static int
mp_prime_is_prime(mp_int * a,int t,int * result)4841 mp_prime_is_prime(mp_int *a, int t, int *result)
4842 {
4843 	mp_int  b;
4844 	int     ix, err, res;
4845 
4846 	/* default to no */
4847 	*result = MP_NO;
4848 
4849 	/* valid value of t? */
4850 	if (t <= 0 || t > (int)PRIME_SIZE) {
4851 		return MP_VAL;
4852 	}
4853 
4854 	/* is the input equal to one of the primes in the table? */
4855 	for (ix = 0; ix < (int)PRIME_SIZE; ix++) {
4856 		if (compare_digit(a, ltm_prime_tab[ix]) == MP_EQ) {
4857 			*result = 1;
4858 			return MP_OKAY;
4859 		}
4860 	}
4861 
4862 	/* first perform trial division */
4863 	if ((err = mp_prime_is_divisible(a, &res)) != MP_OKAY) {
4864 		return err;
4865 	}
4866 
4867 	/* return if it was trivially divisible */
4868 	if (res == MP_YES) {
4869 		return MP_OKAY;
4870 	}
4871 
4872 	/* now perform the miller-rabin rounds */
4873 	if ((err = mp_init(&b)) != MP_OKAY) {
4874 		return err;
4875 	}
4876 
4877 	for (ix = 0; ix < t; ix++) {
4878 		/* set the prime */
4879 		set_word(&b, ltm_prime_tab[ix]);
4880 
4881 		if ((err = mp_prime_miller_rabin(a, &b, &res)) != MP_OKAY) {
4882 			goto LBL_B;
4883 		}
4884 
4885 		if (res == MP_NO) {
4886 			goto LBL_B;
4887 		}
4888 	}
4889 
4890 	/* passed the test */
4891 	*result = MP_YES;
4892 LBL_B:
4893 	mp_clear(&b);
4894 	return err;
4895 }
4896 
4897 /* returns size of ASCII reprensentation */
4898 static int
mp_radix_size(mp_int * a,int radix,int * size)4899 mp_radix_size(mp_int *a, int radix, int *size)
4900 {
4901 	int     res, digs;
4902 	mp_int  t;
4903 	mp_digit d;
4904 
4905 	*size = 0;
4906 
4907 	/* special case for binary */
4908 	if (radix == 2) {
4909 		*size = mp_count_bits(a) + (a->sign == MP_NEG ? 1 : 0) + 1;
4910 		return MP_OKAY;
4911 	}
4912 
4913 	/* make sure the radix is in range */
4914 	if (radix < 2 || radix > 64) {
4915 		return MP_VAL;
4916 	}
4917 
4918 	if (MP_ISZERO(a) == MP_YES) {
4919 		*size = 2;
4920 		return MP_OKAY;
4921 	}
4922 
4923 	/* digs is the digit count */
4924 	digs = 0;
4925 
4926 	/* if it's negative add one for the sign */
4927 	if (a->sign == MP_NEG) {
4928 		++digs;
4929 	}
4930 
4931 	/* init a copy of the input */
4932 	if ((res = mp_init_copy(&t, a)) != MP_OKAY) {
4933 		return res;
4934 	}
4935 
4936 	/* force temp to positive */
4937 	t.sign = MP_ZPOS;
4938 
4939 	/* fetch out all of the digits */
4940 	while (MP_ISZERO(&t) == MP_NO) {
4941 		if ((res = signed_divide_word(&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
4942 			mp_clear(&t);
4943 			return res;
4944 		}
4945 		++digs;
4946 	}
4947 	mp_clear(&t);
4948 
4949 	/* return digs + 1, the 1 is for the NULL byte that would be required. */
4950 	*size = digs + 1;
4951 	return MP_OKAY;
4952 }
4953 
4954 static const char *mp_s_rmap = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+/";
4955 
4956 /* stores a bignum as a ASCII string in a given radix (2..64)
4957  *
4958  * Stores upto maxlen-1 chars and always a NULL byte
4959  */
4960 static int
mp_toradix_n(mp_int * a,char * str,int radix,int maxlen)4961 mp_toradix_n(mp_int * a, char *str, int radix, int maxlen)
4962 {
4963 	int     res, digs;
4964 	mp_int  t;
4965 	mp_digit d;
4966 	char   *_s = str;
4967 
4968 	/* check range of the maxlen, radix */
4969 	if (maxlen < 2 || radix < 2 || radix > 64) {
4970 		return MP_VAL;
4971 	}
4972 
4973 	/* quick out if its zero */
4974 	if (MP_ISZERO(a) == MP_YES) {
4975 		*str++ = '0';
4976 		*str = '\0';
4977 		return MP_OKAY;
4978 	}
4979 
4980 	if ((res = mp_init_copy(&t, a)) != MP_OKAY) {
4981 		return res;
4982 	}
4983 
4984 	/* if it is negative output a - */
4985 	if (t.sign == MP_NEG) {
4986 		/* we have to reverse our digits later... but not the - sign!! */
4987 		++_s;
4988 
4989 		/* store the flag and mark the number as positive */
4990 		*str++ = '-';
4991 		t.sign = MP_ZPOS;
4992 
4993 		/* subtract a char */
4994 		--maxlen;
4995 	}
4996 
4997 	digs = 0;
4998 	while (MP_ISZERO(&t) == 0) {
4999 		if (--maxlen < 1) {
5000 			/* no more room */
5001 			break;
5002 		}
5003 		if ((res = signed_divide_word(&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
5004 			mp_clear(&t);
5005 			return res;
5006 		}
5007 		/* LINTED -- radix' range is checked above, limits d's range */
5008 		*str++ = mp_s_rmap[d];
5009 		++digs;
5010 	}
5011 
5012 	/* reverse the digits of the string.  In this case _s points
5013 	* to the first digit [exluding the sign] of the number
5014 	*/
5015 	bn_reverse((unsigned char *)_s, digs);
5016 
5017 	/* append a NULL so the string is properly terminated */
5018 	*str = '\0';
5019 
5020 	mp_clear(&t);
5021 	return MP_OKAY;
5022 }
5023 
5024 static char *
formatbn(const BIGNUM * a,const int radix)5025 formatbn(const BIGNUM *a, const int radix)
5026 {
5027 	char	*s;
5028 	int	 len;
5029 
5030 	if (mp_radix_size(__UNCONST(a), radix, &len) != MP_OKAY) {
5031 		return NULL;
5032 	}
5033 	if ((s = allocate(1, (size_t)len)) != NULL) {
5034 		if (mp_toradix_n(__UNCONST(a), s, radix, len) != MP_OKAY) {
5035 			deallocate(s, (size_t)len);
5036 			return NULL;
5037 		}
5038 	}
5039 	return s;
5040 }
5041 
5042 static int
mp_getradix_num(mp_int * a,int radix,char * s)5043 mp_getradix_num(mp_int *a, int radix, char *s)
5044 {
5045 	int err, ch, neg, y;
5046 
5047 	/* clear a */
5048 	mp_zero(a);
5049 
5050 	/* if first digit is - then set negative */
5051 	if ((ch = *s++) == '-') {
5052 		neg = MP_NEG;
5053 		ch = *s++;
5054 	} else {
5055 		neg = MP_ZPOS;
5056 	}
5057 
5058 	for (;;) {
5059 		/* find y in the radix map */
5060 		for (y = 0; y < radix; y++) {
5061 			if (mp_s_rmap[y] == ch) {
5062 				break;
5063 			}
5064 		}
5065 		if (y == radix) {
5066 			break;
5067 		}
5068 
5069 		/* shift up and add */
5070 		if ((err = multiply_digit(a, radix, a)) != MP_OKAY) {
5071 			return err;
5072 		}
5073 		if ((err = add_single_digit(a, y, a)) != MP_OKAY) {
5074 			return err;
5075 		}
5076 
5077 		ch = *s++;
5078 	}
5079 	if (compare_digit(a, 0) != MP_EQ) {
5080 		a->sign = neg;
5081 	}
5082 
5083 	return MP_OKAY;
5084 }
5085 
5086 static int
getbn(BIGNUM ** a,const char * str,int radix)5087 getbn(BIGNUM **a, const char *str, int radix)
5088 {
5089 	int	len;
5090 
5091 	if (a == NULL || str == NULL || (*a = BN_new()) == NULL) {
5092 		return 0;
5093 	}
5094 	if (mp_getradix_num(*a, radix, __UNCONST(str)) != MP_OKAY) {
5095 		return 0;
5096 	}
5097 	mp_radix_size(__UNCONST(*a), radix, &len);
5098 	return len - 1;
5099 }
5100 
5101 /* d = a - b (mod c) */
5102 static int
subtract_modulo(mp_int * a,mp_int * b,mp_int * c,mp_int * d)5103 subtract_modulo(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
5104 {
5105 	int     res;
5106 	mp_int  t;
5107 
5108 
5109 	if ((res = mp_init(&t)) != MP_OKAY) {
5110 		return res;
5111 	}
5112 
5113 	if ((res = signed_subtract(a, b, &t)) != MP_OKAY) {
5114 		mp_clear(&t);
5115 		return res;
5116 	}
5117 	res = modulo(&t, c, d);
5118 	mp_clear(&t);
5119 	return res;
5120 }
5121 
5122 /**************************************************************************/
5123 
5124 /* BIGNUM emulation layer */
5125 
/* essentially, these are just wrappers around the libtommath functions */
/* usually the order of args changes */
/* the BIGNUM API tends to have more const qualifiers */
/* these wrappers also check the arguments passed for sanity */
5130 
5131 BIGNUM *
BN_bin2bn(const uint8_t * data,int len,BIGNUM * ret)5132 BN_bin2bn(const uint8_t *data, int len, BIGNUM *ret)
5133 {
5134 	if (data == NULL) {
5135 		return BN_new();
5136 	}
5137 	if (ret == NULL) {
5138 		ret = BN_new();
5139 	}
5140 	return (mp_read_unsigned_bin(ret, data, len) == MP_OKAY) ? ret : NULL;
5141 }
5142 
5143 /* store in unsigned [big endian] format */
5144 int
BN_bn2bin(const BIGNUM * a,unsigned char * b)5145 BN_bn2bin(const BIGNUM *a, unsigned char *b)
5146 {
5147 	BIGNUM	t;
5148 	int    	x;
5149 
5150 	if (a == NULL || b == NULL) {
5151 		return -1;
5152 	}
5153 	if (mp_init_copy (&t, __UNCONST(a)) != MP_OKAY) {
5154 		return -1;
5155 	}
5156 	for (x = 0; !BN_is_zero(&t) ; ) {
5157 		b[x++] = (unsigned char) (t.dp[0] & 0xff);
5158 		if (rshift_bits(&t, 8, &t, NULL) != MP_OKAY) {
5159 			mp_clear(&t);
5160 			return -1;
5161 		}
5162 	}
5163 	bn_reverse(b, x);
5164 	mp_clear(&t);
5165 	return x;
5166 }
5167 
5168 void
BN_init(BIGNUM * a)5169 BN_init(BIGNUM *a)
5170 {
5171 	if (a != NULL) {
5172 		mp_init(a);
5173 	}
5174 }
5175 
5176 BIGNUM *
BN_new(void)5177 BN_new(void)
5178 {
5179 	BIGNUM	*a;
5180 
5181 	if ((a = allocate(1, sizeof(*a))) != NULL) {
5182 		mp_init(a);
5183 	}
5184 	return a;
5185 }
5186 
5187 /* copy, b = a */
5188 int
BN_copy(BIGNUM * b,const BIGNUM * a)5189 BN_copy(BIGNUM *b, const BIGNUM *a)
5190 {
5191 	if (a == NULL || b == NULL) {
5192 		return MP_VAL;
5193 	}
5194 	return mp_copy(__UNCONST(a), b);
5195 }
5196 
5197 BIGNUM *
BN_dup(const BIGNUM * a)5198 BN_dup(const BIGNUM *a)
5199 {
5200 	BIGNUM	*ret;
5201 
5202 	if (a == NULL) {
5203 		return NULL;
5204 	}
5205 	if ((ret = BN_new()) != NULL) {
5206 		BN_copy(ret, a);
5207 	}
5208 	return ret;
5209 }
5210 
5211 void
BN_swap(BIGNUM * a,BIGNUM * b)5212 BN_swap(BIGNUM *a, BIGNUM *b)
5213 {
5214 	if (a && b) {
5215 		mp_exch(a, b);
5216 	}
5217 }
5218 
5219 int
BN_lshift(BIGNUM * r,const BIGNUM * a,int n)5220 BN_lshift(BIGNUM *r, const BIGNUM *a, int n)
5221 {
5222 	if (r == NULL || a == NULL || n < 0) {
5223 		return 0;
5224 	}
5225 	BN_copy(r, a);
5226 	return lshift_digits(r, n) == MP_OKAY;
5227 }
5228 
5229 int
BN_lshift1(BIGNUM * r,BIGNUM * a)5230 BN_lshift1(BIGNUM *r, BIGNUM *a)
5231 {
5232 	if (r == NULL || a == NULL) {
5233 		return 0;
5234 	}
5235 	BN_copy(r, a);
5236 	return lshift_digits(r, 1) == MP_OKAY;
5237 }
5238 
5239 int
BN_rshift(BIGNUM * r,const BIGNUM * a,int n)5240 BN_rshift(BIGNUM *r, const BIGNUM *a, int n)
5241 {
5242 	if (r == NULL || a == NULL || n < 0) {
5243 		return MP_VAL;
5244 	}
5245 	BN_copy(r, a);
5246 	return rshift_digits(r, n) == MP_OKAY;
5247 }
5248 
5249 int
BN_rshift1(BIGNUM * r,BIGNUM * a)5250 BN_rshift1(BIGNUM *r, BIGNUM *a)
5251 {
5252 	if (r == NULL || a == NULL) {
5253 		return 0;
5254 	}
5255 	BN_copy(r, a);
5256 	return rshift_digits(r, 1) == MP_OKAY;
5257 }
5258 
5259 int
BN_set_word(BIGNUM * a,BN_ULONG w)5260 BN_set_word(BIGNUM *a, BN_ULONG w)
5261 {
5262 	if (a == NULL) {
5263 		return 0;
5264 	}
5265 	set_word(a, w);
5266 	return 1;
5267 }
5268 
5269 int
BN_add(BIGNUM * r,const BIGNUM * a,const BIGNUM * b)5270 BN_add(BIGNUM *r, const BIGNUM *a, const BIGNUM *b)
5271 {
5272 	if (a == NULL || b == NULL || r == NULL) {
5273 		return 0;
5274 	}
5275 	return signed_add(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5276 }
5277 
5278 int
BN_sub(BIGNUM * r,const BIGNUM * a,const BIGNUM * b)5279 BN_sub(BIGNUM *r, const BIGNUM *a, const BIGNUM *b)
5280 {
5281 	if (a == NULL || b == NULL || r == NULL) {
5282 		return 0;
5283 	}
5284 	return signed_subtract(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5285 }
5286 
5287 int
BN_mul(BIGNUM * r,const BIGNUM * a,const BIGNUM * b,BN_CTX * ctx)5288 BN_mul(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx)
5289 {
5290 	if (a == NULL || b == NULL || r == NULL) {
5291 		return 0;
5292 	}
5293 	USE_ARG(ctx);
5294 	return signed_multiply(__UNCONST(a), __UNCONST(b), r) == MP_OKAY;
5295 }
5296 
5297 int
BN_div(BIGNUM * dv,BIGNUM * rem,const BIGNUM * a,const BIGNUM * d,BN_CTX * ctx)5298 BN_div(BIGNUM *dv, BIGNUM *rem, const BIGNUM *a, const BIGNUM *d, BN_CTX *ctx)
5299 {
5300 	if ((dv == NULL && rem == NULL) || a == NULL || d == NULL) {
5301 		return 0;
5302 	}
5303 	USE_ARG(ctx);
5304 	return signed_divide(dv, rem, __UNCONST(a), __UNCONST(d)) == MP_OKAY;
5305 }
5306 
5307 /* perform a bit operation on the 2 bignums */
5308 int
BN_bitop(BIGNUM * r,const BIGNUM * a,char op,const BIGNUM * b)5309 BN_bitop(BIGNUM *r, const BIGNUM *a, char op, const BIGNUM *b)
5310 {
5311 	unsigned	ndigits;
5312 	mp_digit	ad;
5313 	mp_digit	bd;
5314 	int		i;
5315 
5316 	if (a == NULL || b == NULL || r == NULL) {
5317 		return 0;
5318 	}
5319 	if (BN_cmp(__UNCONST(a), __UNCONST(b)) >= 0) {
5320 		BN_copy(r, a);
5321 		ndigits = a->used;
5322 	} else {
5323 		BN_copy(r, b);
5324 		ndigits = b->used;
5325 	}
5326 	for (i = 0 ; i < (int)ndigits ; i++) {
5327 		ad = (i > a->used) ? 0 : a->dp[i];
5328 		bd = (i > b->used) ? 0 : b->dp[i];
5329 		switch(op) {
5330 		case '&':
5331 			r->dp[i] = (ad & bd);
5332 			break;
5333 		case '|':
5334 			r->dp[i] = (ad | bd);
5335 			break;
5336 		case '^':
5337 			r->dp[i] = (ad ^ bd);
5338 			break;
5339 		default:
5340 			break;
5341 		}
5342 	}
5343 	return 1;
5344 }
5345 
5346 void
BN_free(BIGNUM * a)5347 BN_free(BIGNUM *a)
5348 {
5349 	if (a) {
5350 		mp_clear(a);
5351 	}
5352 }
5353 
5354 void
BN_clear(BIGNUM * a)5355 BN_clear(BIGNUM *a)
5356 {
5357 	if (a) {
5358 		mp_clear(a);
5359 	}
5360 }
5361 
5362 void
BN_clear_free(BIGNUM * a)5363 BN_clear_free(BIGNUM *a)
5364 {
5365 	if (a) {
5366 		mp_clear(a);
5367 	}
5368 }
5369 
5370 int
BN_num_bytes(const BIGNUM * a)5371 BN_num_bytes(const BIGNUM *a)
5372 {
5373 	if (a == NULL) {
5374 		return MP_VAL;
5375 	}
5376 	return mp_unsigned_bin_size(__UNCONST(a));
5377 }
5378 
5379 int
BN_num_bits(const BIGNUM * a)5380 BN_num_bits(const BIGNUM *a)
5381 {
5382 	if (a == NULL) {
5383 		return 0;
5384 	}
5385 	return mp_count_bits(a);
5386 }
5387 
5388 void
BN_set_negative(BIGNUM * a,int n)5389 BN_set_negative(BIGNUM *a, int n)
5390 {
5391 	if (a) {
5392 		a->sign = (n) ? MP_NEG : 0;
5393 	}
5394 }
5395 
5396 int
BN_cmp(BIGNUM * a,BIGNUM * b)5397 BN_cmp(BIGNUM *a, BIGNUM *b)
5398 {
5399 	if (a == NULL || b == NULL) {
5400 		return MP_VAL;
5401 	}
5402 	switch(signed_compare(a, b)) {
5403 	case MP_LT:
5404 		return -1;
5405 	case MP_GT:
5406 		return 1;
5407 	case MP_EQ:
5408 	default:
5409 		return 0;
5410 	}
5411 }
5412 
5413 int
BN_mod_exp(BIGNUM * Y,BIGNUM * G,BIGNUM * X,BIGNUM * P,BN_CTX * ctx)5414 BN_mod_exp(BIGNUM *Y, BIGNUM *G, BIGNUM *X, BIGNUM *P, BN_CTX *ctx)
5415 {
5416 	if (Y == NULL || G == NULL || X == NULL || P == NULL) {
5417 		return MP_VAL;
5418 	}
5419 	USE_ARG(ctx);
5420 	return exponent_modulo(G, X, P, Y) == MP_OKAY;
5421 }
5422 
5423 BIGNUM *
BN_mod_inverse(BIGNUM * r,BIGNUM * a,const BIGNUM * n,BN_CTX * ctx)5424 BN_mod_inverse(BIGNUM *r, BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
5425 {
5426 	USE_ARG(ctx);
5427 	if (r == NULL || a == NULL || n == NULL) {
5428 		return NULL;
5429 	}
5430 	return (modular_inverse(r, a, __UNCONST(n)) == MP_OKAY) ? r : NULL;
5431 }
5432 
5433 int
BN_mod_mul(BIGNUM * ret,BIGNUM * a,BIGNUM * b,const BIGNUM * m,BN_CTX * ctx)5434 BN_mod_mul(BIGNUM *ret, BIGNUM *a, BIGNUM *b, const BIGNUM *m, BN_CTX *ctx)
5435 {
5436 	USE_ARG(ctx);
5437 	if (ret == NULL || a == NULL || b == NULL || m == NULL) {
5438 		return 0;
5439 	}
5440 	return multiply_modulo(ret, a, b, __UNCONST(m)) == MP_OKAY;
5441 }
5442 
5443 BN_CTX *
BN_CTX_new(void)5444 BN_CTX_new(void)
5445 {
5446 	return allocate(1, sizeof(BN_CTX));
5447 }
5448 
5449 void
BN_CTX_init(BN_CTX * c)5450 BN_CTX_init(BN_CTX *c)
5451 {
5452 	if (c != NULL) {
5453 		c->arraysize = 15;
5454 		if ((c->v = allocate(sizeof(*c->v), c->arraysize)) == NULL) {
5455 			c->arraysize = 0;
5456 		}
5457 	}
5458 }
5459 
5460 BIGNUM *
BN_CTX_get(BN_CTX * ctx)5461 BN_CTX_get(BN_CTX *ctx)
5462 {
5463 	if (ctx == NULL || ctx->v == NULL || ctx->arraysize == 0 || ctx->count == ctx->arraysize - 1) {
5464 		return NULL;
5465 	}
5466 	return ctx->v[ctx->count++] = BN_new();
5467 }
5468 
5469 void
BN_CTX_start(BN_CTX * ctx)5470 BN_CTX_start(BN_CTX *ctx)
5471 {
5472 	BN_CTX_init(ctx);
5473 }
5474 
5475 void
BN_CTX_free(BN_CTX * c)5476 BN_CTX_free(BN_CTX *c)
5477 {
5478 	unsigned	i;
5479 
5480 	if (c != NULL && c->v != NULL) {
5481 		for (i = 0 ; i < c->count ; i++) {
5482 			BN_clear_free(c->v[i]);
5483 		}
5484 		deallocate(c->v, sizeof(*c->v) * c->arraysize);
5485 	}
5486 }
5487 
5488 void
BN_CTX_end(BN_CTX * ctx)5489 BN_CTX_end(BN_CTX *ctx)
5490 {
5491 	BN_CTX_free(ctx);
5492 }
5493 
5494 char *
BN_bn2hex(const BIGNUM * a)5495 BN_bn2hex(const BIGNUM *a)
5496 {
5497 	return (a == NULL) ? NULL : formatbn(a, 16);
5498 }
5499 
5500 char *
BN_bn2dec(const BIGNUM * a)5501 BN_bn2dec(const BIGNUM *a)
5502 {
5503 	return (a == NULL) ? NULL : formatbn(a, 10);
5504 }
5505 
5506 char *
BN_bn2radix(const BIGNUM * a,unsigned radix)5507 BN_bn2radix(const BIGNUM *a, unsigned radix)
5508 {
5509 	return (a == NULL) ? NULL : formatbn(a, (int)radix);
5510 }
5511 
5512 #ifndef _KERNEL
5513 int
BN_print_fp(FILE * fp,const BIGNUM * a)5514 BN_print_fp(FILE *fp, const BIGNUM *a)
5515 {
5516 	char	*s;
5517 	int	 ret;
5518 
5519 	if (fp == NULL || a == NULL) {
5520 		return 0;
5521 	}
5522 	s = BN_bn2hex(a);
5523 	ret = fprintf(fp, "%s", s);
5524 	deallocate(s, strlen(s) + 1);
5525 	return ret;
5526 }
5527 #endif
5528 
#ifdef BN_RAND_NEEDED
/*
 * Fill rnd with a random non-negative number of at most 'bits' bits.
 *   top == -1:   no constraint on the high bits
 *   top == 0:    bit (bits - 1) is set
 *   top == 1:    bits (bits - 1) and (bits - 2) are set
 *   bottom != 0: the number is odd
 * rnd must be an initialised BIGNUM (e.g. from BN_new).
 * Returns 1 on success, 0 on bad arguments or allocation failure.
 */
int
BN_rand(BIGNUM *rnd, int bits, int top, int bottom)
{
	uint64_t	r;
	int		digits;
	int		topbit;
	int		i;

	if (rnd == NULL || bits < 0) {
		return 0;
	}
	if (bits == 0) {
		mp_zero(rnd);
		return 1;
	}
	digits = howmany(bits, DIGIT_BIT);
	if (rnd->alloc < digits && mp_grow(rnd, digits) != MP_OKAY) {
		return 0;
	}
	for (i = 0 ; i < digits ; i++) {
		r = (uint64_t)arc4random();
		r <<= 32;
		r |= arc4random();
		rnd->dp[i] = (mp_digit)(r & MP_MASK);
	}
	/* clear stale digits above the new length */
	for ( ; i < rnd->used ; i++) {
		rnd->dp[i] = 0;
	}
	rnd->used = digits;
	rnd->sign = MP_ZPOS;

	/* keep only bits 0 .. bits-1 of the top digit */
	topbit = (bits - 1) % DIGIT_BIT;
	rnd->dp[digits - 1] &= MP_MASK >> (DIGIT_BIT - 1 - topbit);
	if (top >= 0) {
		rnd->dp[digits - 1] |= ((mp_digit)1) << topbit;
	}
	if (top == 1 && bits > 1) {
		/* the second-highest bit may live in the digit below */
		i = bits - 2;
		rnd->dp[i / DIGIT_BIT] |= ((mp_digit)1) << (i % DIGIT_BIT);
	}
	if (bottom) {
		rnd->dp[0] |= 0x1;
	}
	trim_unused_digits(rnd);
	return 1;
}

/*
 * Fill rnd with a uniformly distributed random number in [0, range).
 * Returns 1 on success, 0 on bad arguments (including range <= 0) or
 * failure.
 */
int
BN_rand_range(BIGNUM *rnd, BIGNUM *range)
{
	int	bits;

	if (rnd == NULL || range == NULL || BN_is_zero(range) ||
	    range->sign == MP_NEG) {
		return 0;
	}
	bits = BN_num_bits(range);
	/*
	 * Rejection sampling: a uniform bits-wide candidate is below range
	 * with probability at least 1/2, and unlike reducing mod range the
	 * accepted value is unbiased.
	 */
	do {
		if (!BN_rand(rnd, bits, -1, 0)) {
			return 0;
		}
	} while (BN_cmp(rnd, range) >= 0);
	return 1;
}
#endif
5570 
5571 int
BN_is_prime(const BIGNUM * a,int checks,void (* callback)(int,int,void *),BN_CTX * ctx,void * cb_arg)5572 BN_is_prime(const BIGNUM *a, int checks, void (*callback)(int, int, void *), BN_CTX *ctx, void *cb_arg)
5573 {
5574 	int	primality;
5575 
5576 	if (a == NULL) {
5577 		return 0;
5578 	}
5579 	USE_ARG(ctx);
5580 	USE_ARG(cb_arg);
5581 	USE_ARG(callback);
5582 	return (mp_prime_is_prime(__UNCONST(a), checks, &primality) == MP_OKAY) ? primality : 0;
5583 }
5584 
5585 const BIGNUM *
BN_value_one(void)5586 BN_value_one(void)
5587 {
5588 	static mp_digit		digit = 1UL;
5589 	static const BIGNUM	one = { &digit, 1, 1, 0 };
5590 
5591 	return &one;
5592 }
5593 
5594 int
BN_hex2bn(BIGNUM ** a,const char * str)5595 BN_hex2bn(BIGNUM **a, const char *str)
5596 {
5597 	return getbn(a, str, 16);
5598 }
5599 
5600 int
BN_dec2bn(BIGNUM ** a,const char * str)5601 BN_dec2bn(BIGNUM **a, const char *str)
5602 {
5603 	return getbn(a, str, 10);
5604 }
5605 
5606 int
BN_radix2bn(BIGNUM ** a,const char * str,unsigned radix)5607 BN_radix2bn(BIGNUM **a, const char *str, unsigned radix)
5608 {
5609 	return getbn(a, str, (int)radix);
5610 }
5611 
5612 int
BN_mod_sub(BIGNUM * r,BIGNUM * a,BIGNUM * b,const BIGNUM * m,BN_CTX * ctx)5613 BN_mod_sub(BIGNUM *r, BIGNUM *a, BIGNUM *b, const BIGNUM *m, BN_CTX *ctx)
5614 {
5615 	USE_ARG(ctx);
5616 	if (r == NULL || a == NULL || b == NULL || m == NULL) {
5617 		return 0;
5618 	}
5619 	return subtract_modulo(a, b, __UNCONST(m), r) == MP_OKAY;
5620 }
5621 
5622 int
BN_is_bit_set(const BIGNUM * a,int n)5623 BN_is_bit_set(const BIGNUM *a, int n)
5624 {
5625 	if (a == NULL || n < 0 || n >= a->used * DIGIT_BIT) {
5626 		return 0;
5627 	}
5628 	return (a->dp[n / DIGIT_BIT] & (1 << (n % DIGIT_BIT))) ? 1 : 0;
5629 }
5630 
5631 /* raise 'a' to power of 'b' */
5632 int
BN_raise(BIGNUM * res,BIGNUM * a,BIGNUM * b)5633 BN_raise(BIGNUM *res, BIGNUM *a, BIGNUM *b)
5634 {
5635 	uint64_t	 exponent;
5636 	BIGNUM		*power;
5637 	BIGNUM		*temp;
5638 	char		*t;
5639 
5640 	t = BN_bn2dec(b);
5641 	exponent = (uint64_t)strtoull(t, NULL, 10);
5642 	free(t);
5643 	if (exponent == 0) {
5644 		BN_copy(res, BN_value_one());
5645 	} else {
5646 		power = BN_dup(a);
5647 		for ( ; (exponent & 1) == 0 ; exponent >>= 1) {
5648 			BN_mul(power, power, power, NULL);
5649 		}
5650 		temp = BN_dup(power);
5651 		for (exponent >>= 1 ; exponent > 0 ; exponent >>= 1) {
5652 			BN_mul(power, power, power, NULL);
5653 			if (exponent & 1) {
5654 				BN_mul(temp, power, temp, NULL);
5655 			}
5656 		}
5657 		BN_copy(res, temp);
5658 		BN_free(power);
5659 		BN_free(temp);
5660 	}
5661 	return 1;
5662 }
5663 
5664 /* compute the factorial */
5665 int
BN_factorial(BIGNUM * res,BIGNUM * f)5666 BN_factorial(BIGNUM *res, BIGNUM *f)
5667 {
5668 	BIGNUM	*one;
5669 	BIGNUM	*i;
5670 
5671 	i = BN_dup(f);
5672 	one = __UNCONST(BN_value_one());
5673 	BN_sub(i, i, one);
5674 	BN_copy(res, f);
5675 	while (BN_cmp(i, one) > 0) {
5676 		BN_mul(res, res, i, NULL);
5677 		BN_sub(i, i, one);
5678 	}
5679 	BN_free(i);
5680 	return 1;
5681 }
5682