1*37da2899SCharles.Forsyth #include "os.h"
2*37da2899SCharles.Forsyth #include <mp.h>
3*37da2899SCharles.Forsyth #include "dat.h"
4*37da2899SCharles.Forsyth
//
// from Knuth's Seminumerical Algorithms (1969), pp 233-235 and pp 258-260
//
// the inner loop is done by mpvecdigmuladd, which may be supplied as an
// assembly language routine.
//
// the karatsuba trade-off is set empirically by measuring the algorithms
// on a 400 MHz Pentium II.
//
14*37da2899SCharles.Forsyth
15*37da2899SCharles.Forsyth // karatsuba like (see knuth pg 258)
16*37da2899SCharles.Forsyth // prereq: p is already zeroed
17*37da2899SCharles.Forsyth static void
mpkaratsuba(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)18*37da2899SCharles.Forsyth mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
19*37da2899SCharles.Forsyth {
20*37da2899SCharles.Forsyth mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
21*37da2899SCharles.Forsyth int u0len, u1len, v0len, v1len, reslen;
22*37da2899SCharles.Forsyth int sign, n;
23*37da2899SCharles.Forsyth
24*37da2899SCharles.Forsyth // divide each piece in half
25*37da2899SCharles.Forsyth n = alen/2;
26*37da2899SCharles.Forsyth if(alen&1)
27*37da2899SCharles.Forsyth n++;
28*37da2899SCharles.Forsyth u0len = n;
29*37da2899SCharles.Forsyth u1len = alen-n;
30*37da2899SCharles.Forsyth if(blen > n){
31*37da2899SCharles.Forsyth v0len = n;
32*37da2899SCharles.Forsyth v1len = blen-n;
33*37da2899SCharles.Forsyth } else {
34*37da2899SCharles.Forsyth v0len = blen;
35*37da2899SCharles.Forsyth v1len = 0;
36*37da2899SCharles.Forsyth }
37*37da2899SCharles.Forsyth u0 = a;
38*37da2899SCharles.Forsyth u1 = a + u0len;
39*37da2899SCharles.Forsyth v0 = b;
40*37da2899SCharles.Forsyth v1 = b + v0len;
41*37da2899SCharles.Forsyth
42*37da2899SCharles.Forsyth // room for the partial products
43*37da2899SCharles.Forsyth t = mallocz(Dbytes*5*(2*n+1), 1);
44*37da2899SCharles.Forsyth if(t == nil)
45*37da2899SCharles.Forsyth sysfatal("mpkaratsuba: %r");
46*37da2899SCharles.Forsyth u0v0 = t;
47*37da2899SCharles.Forsyth u1v1 = t + (2*n+1);
48*37da2899SCharles.Forsyth diffprod = t + 2*(2*n+1);
49*37da2899SCharles.Forsyth res = t + 3*(2*n+1);
50*37da2899SCharles.Forsyth reslen = 4*n+1;
51*37da2899SCharles.Forsyth
52*37da2899SCharles.Forsyth // t[0] = (u1-u0)
53*37da2899SCharles.Forsyth sign = 1;
54*37da2899SCharles.Forsyth if(mpveccmp(u1, u1len, u0, u0len) < 0){
55*37da2899SCharles.Forsyth sign = -1;
56*37da2899SCharles.Forsyth mpvecsub(u0, u0len, u1, u1len, u0v0);
57*37da2899SCharles.Forsyth } else
58*37da2899SCharles.Forsyth mpvecsub(u1, u1len, u0, u1len, u0v0);
59*37da2899SCharles.Forsyth
60*37da2899SCharles.Forsyth // t[1] = (v0-v1)
61*37da2899SCharles.Forsyth if(mpveccmp(v0, v0len, v1, v1len) < 0){
62*37da2899SCharles.Forsyth sign *= -1;
63*37da2899SCharles.Forsyth mpvecsub(v1, v1len, v0, v1len, u1v1);
64*37da2899SCharles.Forsyth } else
65*37da2899SCharles.Forsyth mpvecsub(v0, v0len, v1, v1len, u1v1);
66*37da2899SCharles.Forsyth
67*37da2899SCharles.Forsyth // t[4:5] = (u1-u0)*(v0-v1)
68*37da2899SCharles.Forsyth mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
69*37da2899SCharles.Forsyth
70*37da2899SCharles.Forsyth // t[0:1] = u1*v1
71*37da2899SCharles.Forsyth memset(t, 0, 2*(2*n+1)*Dbytes);
72*37da2899SCharles.Forsyth if(v1len > 0)
73*37da2899SCharles.Forsyth mpvecmul(u1, u1len, v1, v1len, u1v1);
74*37da2899SCharles.Forsyth
75*37da2899SCharles.Forsyth // t[2:3] = u0v0
76*37da2899SCharles.Forsyth mpvecmul(u0, u0len, v0, v0len, u0v0);
77*37da2899SCharles.Forsyth
78*37da2899SCharles.Forsyth // res = u0*v0<<n + u0*v0
79*37da2899SCharles.Forsyth mpvecadd(res, reslen, u0v0, u0len+v0len, res);
80*37da2899SCharles.Forsyth mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
81*37da2899SCharles.Forsyth
82*37da2899SCharles.Forsyth // res += u1*v1<<n + u1*v1<<2*n
83*37da2899SCharles.Forsyth if(v1len > 0){
84*37da2899SCharles.Forsyth mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
85*37da2899SCharles.Forsyth mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
86*37da2899SCharles.Forsyth }
87*37da2899SCharles.Forsyth
88*37da2899SCharles.Forsyth // res += (u1-u0)*(v0-v1)<<n
89*37da2899SCharles.Forsyth if(sign < 0)
90*37da2899SCharles.Forsyth mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
91*37da2899SCharles.Forsyth else
92*37da2899SCharles.Forsyth mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
93*37da2899SCharles.Forsyth memmove(p, res, (alen+blen)*Dbytes);
94*37da2899SCharles.Forsyth
95*37da2899SCharles.Forsyth free(t);
96*37da2899SCharles.Forsyth }
97*37da2899SCharles.Forsyth
98*37da2899SCharles.Forsyth #define KARATSUBAMIN 32
99*37da2899SCharles.Forsyth
100*37da2899SCharles.Forsyth void
mpvecmul(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)101*37da2899SCharles.Forsyth mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
102*37da2899SCharles.Forsyth {
103*37da2899SCharles.Forsyth int i;
104*37da2899SCharles.Forsyth mpdigit d;
105*37da2899SCharles.Forsyth mpdigit *t;
106*37da2899SCharles.Forsyth
107*37da2899SCharles.Forsyth // both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
108*37da2899SCharles.Forsyth if(alen < blen){
109*37da2899SCharles.Forsyth i = alen;
110*37da2899SCharles.Forsyth alen = blen;
111*37da2899SCharles.Forsyth blen = i;
112*37da2899SCharles.Forsyth t = a;
113*37da2899SCharles.Forsyth a = b;
114*37da2899SCharles.Forsyth b = t;
115*37da2899SCharles.Forsyth }
116*37da2899SCharles.Forsyth if(blen == 0){
117*37da2899SCharles.Forsyth memset(p, 0, Dbytes*(alen+blen));
118*37da2899SCharles.Forsyth return;
119*37da2899SCharles.Forsyth }
120*37da2899SCharles.Forsyth
121*37da2899SCharles.Forsyth if(alen >= KARATSUBAMIN && blen > 1){
122*37da2899SCharles.Forsyth // O(n^1.585)
123*37da2899SCharles.Forsyth mpkaratsuba(a, alen, b, blen, p);
124*37da2899SCharles.Forsyth } else {
125*37da2899SCharles.Forsyth // O(n^2)
126*37da2899SCharles.Forsyth for(i = 0; i < blen; i++){
127*37da2899SCharles.Forsyth d = b[i];
128*37da2899SCharles.Forsyth if(d != 0)
129*37da2899SCharles.Forsyth mpvecdigmuladd(a, alen, d, &p[i]);
130*37da2899SCharles.Forsyth }
131*37da2899SCharles.Forsyth }
132*37da2899SCharles.Forsyth }
133*37da2899SCharles.Forsyth
134*37da2899SCharles.Forsyth void
mpmul(mpint * b1,mpint * b2,mpint * prod)135*37da2899SCharles.Forsyth mpmul(mpint *b1, mpint *b2, mpint *prod)
136*37da2899SCharles.Forsyth {
137*37da2899SCharles.Forsyth mpint *oprod;
138*37da2899SCharles.Forsyth
139*37da2899SCharles.Forsyth oprod = nil;
140*37da2899SCharles.Forsyth if(prod == b1 || prod == b2){
141*37da2899SCharles.Forsyth oprod = prod;
142*37da2899SCharles.Forsyth prod = mpnew(0);
143*37da2899SCharles.Forsyth }
144*37da2899SCharles.Forsyth
145*37da2899SCharles.Forsyth prod->top = 0;
146*37da2899SCharles.Forsyth mpbits(prod, (b1->top+b2->top+1)*Dbits);
147*37da2899SCharles.Forsyth mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
148*37da2899SCharles.Forsyth prod->top = b1->top+b2->top+1;
149*37da2899SCharles.Forsyth prod->sign = b1->sign*b2->sign;
150*37da2899SCharles.Forsyth mpnorm(prod);
151*37da2899SCharles.Forsyth
152*37da2899SCharles.Forsyth if(oprod != nil){
153*37da2899SCharles.Forsyth mpassign(prod, oprod);
154*37da2899SCharles.Forsyth mpfree(prod);
155*37da2899SCharles.Forsyth }
156*37da2899SCharles.Forsyth }
157