1*8ccd4a63SDavid du Colombier #include "os.h"
2*8ccd4a63SDavid du Colombier #include <mp.h>
3*8ccd4a63SDavid du Colombier #include "dat.h"
4*8ccd4a63SDavid du Colombier
//
//  from knuth's 1969 seminumerical algorithms, pp 233-235 and pp 258-260
//
//  mpvecdigmuladd (which may be provided as an assembly language
//  routine) performs the inner loop.
//
//  the karatsuba trade off is set empirically by measuring the algs on
//  a 400 MHz Pentium II.
//
14*8ccd4a63SDavid du Colombier
15*8ccd4a63SDavid du Colombier // karatsuba like (see knuth pg 258)
16*8ccd4a63SDavid du Colombier // prereq: p is already zeroed
17*8ccd4a63SDavid du Colombier static void
mpkaratsuba(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)18*8ccd4a63SDavid du Colombier mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
19*8ccd4a63SDavid du Colombier {
20*8ccd4a63SDavid du Colombier mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
21*8ccd4a63SDavid du Colombier int u0len, u1len, v0len, v1len, reslen;
22*8ccd4a63SDavid du Colombier int sign, n;
23*8ccd4a63SDavid du Colombier
24*8ccd4a63SDavid du Colombier // divide each piece in half
25*8ccd4a63SDavid du Colombier n = alen/2;
26*8ccd4a63SDavid du Colombier if(alen&1)
27*8ccd4a63SDavid du Colombier n++;
28*8ccd4a63SDavid du Colombier u0len = n;
29*8ccd4a63SDavid du Colombier u1len = alen-n;
30*8ccd4a63SDavid du Colombier if(blen > n){
31*8ccd4a63SDavid du Colombier v0len = n;
32*8ccd4a63SDavid du Colombier v1len = blen-n;
33*8ccd4a63SDavid du Colombier } else {
34*8ccd4a63SDavid du Colombier v0len = blen;
35*8ccd4a63SDavid du Colombier v1len = 0;
36*8ccd4a63SDavid du Colombier }
37*8ccd4a63SDavid du Colombier u0 = a;
38*8ccd4a63SDavid du Colombier u1 = a + u0len;
39*8ccd4a63SDavid du Colombier v0 = b;
40*8ccd4a63SDavid du Colombier v1 = b + v0len;
41*8ccd4a63SDavid du Colombier
42*8ccd4a63SDavid du Colombier // room for the partial products
43*8ccd4a63SDavid du Colombier t = mallocz(Dbytes*5*(2*n+1), 1);
44*8ccd4a63SDavid du Colombier if(t == nil)
45*8ccd4a63SDavid du Colombier sysfatal("mpkaratsuba: %r");
46*8ccd4a63SDavid du Colombier u0v0 = t;
47*8ccd4a63SDavid du Colombier u1v1 = t + (2*n+1);
48*8ccd4a63SDavid du Colombier diffprod = t + 2*(2*n+1);
49*8ccd4a63SDavid du Colombier res = t + 3*(2*n+1);
50*8ccd4a63SDavid du Colombier reslen = 4*n+1;
51*8ccd4a63SDavid du Colombier
52*8ccd4a63SDavid du Colombier // t[0] = (u1-u0)
53*8ccd4a63SDavid du Colombier sign = 1;
54*8ccd4a63SDavid du Colombier if(mpveccmp(u1, u1len, u0, u0len) < 0){
55*8ccd4a63SDavid du Colombier sign = -1;
56*8ccd4a63SDavid du Colombier mpvecsub(u0, u0len, u1, u1len, u0v0);
57*8ccd4a63SDavid du Colombier } else
58*8ccd4a63SDavid du Colombier mpvecsub(u1, u1len, u0, u1len, u0v0);
59*8ccd4a63SDavid du Colombier
60*8ccd4a63SDavid du Colombier // t[1] = (v0-v1)
61*8ccd4a63SDavid du Colombier if(mpveccmp(v0, v0len, v1, v1len) < 0){
62*8ccd4a63SDavid du Colombier sign *= -1;
63*8ccd4a63SDavid du Colombier mpvecsub(v1, v1len, v0, v1len, u1v1);
64*8ccd4a63SDavid du Colombier } else
65*8ccd4a63SDavid du Colombier mpvecsub(v0, v0len, v1, v1len, u1v1);
66*8ccd4a63SDavid du Colombier
67*8ccd4a63SDavid du Colombier // t[4:5] = (u1-u0)*(v0-v1)
68*8ccd4a63SDavid du Colombier mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
69*8ccd4a63SDavid du Colombier
70*8ccd4a63SDavid du Colombier // t[0:1] = u1*v1
71*8ccd4a63SDavid du Colombier memset(t, 0, 2*(2*n+1)*Dbytes);
72*8ccd4a63SDavid du Colombier if(v1len > 0)
73*8ccd4a63SDavid du Colombier mpvecmul(u1, u1len, v1, v1len, u1v1);
74*8ccd4a63SDavid du Colombier
75*8ccd4a63SDavid du Colombier // t[2:3] = u0v0
76*8ccd4a63SDavid du Colombier mpvecmul(u0, u0len, v0, v0len, u0v0);
77*8ccd4a63SDavid du Colombier
78*8ccd4a63SDavid du Colombier // res = u0*v0<<n + u0*v0
79*8ccd4a63SDavid du Colombier mpvecadd(res, reslen, u0v0, u0len+v0len, res);
80*8ccd4a63SDavid du Colombier mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
81*8ccd4a63SDavid du Colombier
82*8ccd4a63SDavid du Colombier // res += u1*v1<<n + u1*v1<<2*n
83*8ccd4a63SDavid du Colombier if(v1len > 0){
84*8ccd4a63SDavid du Colombier mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
85*8ccd4a63SDavid du Colombier mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
86*8ccd4a63SDavid du Colombier }
87*8ccd4a63SDavid du Colombier
88*8ccd4a63SDavid du Colombier // res += (u1-u0)*(v0-v1)<<n
89*8ccd4a63SDavid du Colombier if(sign < 0)
90*8ccd4a63SDavid du Colombier mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
91*8ccd4a63SDavid du Colombier else
92*8ccd4a63SDavid du Colombier mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
93*8ccd4a63SDavid du Colombier memmove(p, res, (alen+blen)*Dbytes);
94*8ccd4a63SDavid du Colombier
95*8ccd4a63SDavid du Colombier free(t);
96*8ccd4a63SDavid du Colombier }
97*8ccd4a63SDavid du Colombier
98*8ccd4a63SDavid du Colombier #define KARATSUBAMIN 32
99*8ccd4a63SDavid du Colombier
100*8ccd4a63SDavid du Colombier void
mpvecmul(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)101*8ccd4a63SDavid du Colombier mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
102*8ccd4a63SDavid du Colombier {
103*8ccd4a63SDavid du Colombier int i;
104*8ccd4a63SDavid du Colombier mpdigit d;
105*8ccd4a63SDavid du Colombier mpdigit *t;
106*8ccd4a63SDavid du Colombier
107*8ccd4a63SDavid du Colombier // both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
108*8ccd4a63SDavid du Colombier if(alen < blen){
109*8ccd4a63SDavid du Colombier i = alen;
110*8ccd4a63SDavid du Colombier alen = blen;
111*8ccd4a63SDavid du Colombier blen = i;
112*8ccd4a63SDavid du Colombier t = a;
113*8ccd4a63SDavid du Colombier a = b;
114*8ccd4a63SDavid du Colombier b = t;
115*8ccd4a63SDavid du Colombier }
116*8ccd4a63SDavid du Colombier if(blen == 0){
117*8ccd4a63SDavid du Colombier memset(p, 0, Dbytes*(alen+blen));
118*8ccd4a63SDavid du Colombier return;
119*8ccd4a63SDavid du Colombier }
120*8ccd4a63SDavid du Colombier
121*8ccd4a63SDavid du Colombier if(alen >= KARATSUBAMIN && blen > 1){
122*8ccd4a63SDavid du Colombier // O(n^1.585)
123*8ccd4a63SDavid du Colombier mpkaratsuba(a, alen, b, blen, p);
124*8ccd4a63SDavid du Colombier } else {
125*8ccd4a63SDavid du Colombier // O(n^2)
126*8ccd4a63SDavid du Colombier for(i = 0; i < blen; i++){
127*8ccd4a63SDavid du Colombier d = b[i];
128*8ccd4a63SDavid du Colombier if(d != 0)
129*8ccd4a63SDavid du Colombier mpvecdigmuladd(a, alen, d, &p[i]);
130*8ccd4a63SDavid du Colombier }
131*8ccd4a63SDavid du Colombier }
132*8ccd4a63SDavid du Colombier }
133*8ccd4a63SDavid du Colombier
134*8ccd4a63SDavid du Colombier void
mpmul(mpint * b1,mpint * b2,mpint * prod)135*8ccd4a63SDavid du Colombier mpmul(mpint *b1, mpint *b2, mpint *prod)
136*8ccd4a63SDavid du Colombier {
137*8ccd4a63SDavid du Colombier mpint *oprod;
138*8ccd4a63SDavid du Colombier
139*8ccd4a63SDavid du Colombier oprod = nil;
140*8ccd4a63SDavid du Colombier if(prod == b1 || prod == b2){
141*8ccd4a63SDavid du Colombier oprod = prod;
142*8ccd4a63SDavid du Colombier prod = mpnew(0);
143*8ccd4a63SDavid du Colombier }
144*8ccd4a63SDavid du Colombier
145*8ccd4a63SDavid du Colombier prod->top = 0;
146*8ccd4a63SDavid du Colombier mpbits(prod, (b1->top+b2->top+1)*Dbits);
147*8ccd4a63SDavid du Colombier mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
148*8ccd4a63SDavid du Colombier prod->top = b1->top+b2->top+1;
149*8ccd4a63SDavid du Colombier prod->sign = b1->sign*b2->sign;
150*8ccd4a63SDavid du Colombier mpnorm(prod);
151*8ccd4a63SDavid du Colombier
152*8ccd4a63SDavid du Colombier if(oprod != nil){
153*8ccd4a63SDavid du Colombier mpassign(prod, oprod);
154*8ccd4a63SDavid du Colombier mpfree(prod);
155*8ccd4a63SDavid du Colombier }
156*8ccd4a63SDavid du Colombier }
157