xref: /plan9/sys/src/cmd/unix/drawterm/libmp/mpmul.c (revision 8ccd4a6360d974db7bd7bbd4f37e7018419ea908)
1 #include "os.h"
2 #include <mp.h>
3 #include "dat.h"
4 
5 //
6 //  from knuth's 1969 seminumberical algorithms, pp 233-235 and pp 258-260
7 //
8 //  mpvecmul is an assembly language routine that performs the inner
9 //  loop.
10 //
11 //  the karatsuba trade off is set empiricly by measuring the algs on
12 //  a 400 MHz Pentium II.
13 //
14 
15 // karatsuba like (see knuth pg 258)
16 // prereq: p is already zeroed
17 static void
mpkaratsuba(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)18 mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
19 {
20 	mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
21 	int u0len, u1len, v0len, v1len, reslen;
22 	int sign, n;
23 
24 	// divide each piece in half
25 	n = alen/2;
26 	if(alen&1)
27 		n++;
28 	u0len = n;
29 	u1len = alen-n;
30 	if(blen > n){
31 		v0len = n;
32 		v1len = blen-n;
33 	} else {
34 		v0len = blen;
35 		v1len = 0;
36 	}
37 	u0 = a;
38 	u1 = a + u0len;
39 	v0 = b;
40 	v1 = b + v0len;
41 
42 	// room for the partial products
43 	t = mallocz(Dbytes*5*(2*n+1), 1);
44 	if(t == nil)
45 		sysfatal("mpkaratsuba: %r");
46 	u0v0 = t;
47 	u1v1 = t + (2*n+1);
48 	diffprod = t + 2*(2*n+1);
49 	res = t + 3*(2*n+1);
50 	reslen = 4*n+1;
51 
52 	// t[0] = (u1-u0)
53 	sign = 1;
54 	if(mpveccmp(u1, u1len, u0, u0len) < 0){
55 		sign = -1;
56 		mpvecsub(u0, u0len, u1, u1len, u0v0);
57 	} else
58 		mpvecsub(u1, u1len, u0, u1len, u0v0);
59 
60 	// t[1] = (v0-v1)
61 	if(mpveccmp(v0, v0len, v1, v1len) < 0){
62 		sign *= -1;
63 		mpvecsub(v1, v1len, v0, v1len, u1v1);
64 	} else
65 		mpvecsub(v0, v0len, v1, v1len, u1v1);
66 
67 	// t[4:5] = (u1-u0)*(v0-v1)
68 	mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
69 
70 	// t[0:1] = u1*v1
71 	memset(t, 0, 2*(2*n+1)*Dbytes);
72 	if(v1len > 0)
73 		mpvecmul(u1, u1len, v1, v1len, u1v1);
74 
75 	// t[2:3] = u0v0
76 	mpvecmul(u0, u0len, v0, v0len, u0v0);
77 
78 	// res = u0*v0<<n + u0*v0
79 	mpvecadd(res, reslen, u0v0, u0len+v0len, res);
80 	mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
81 
82 	// res += u1*v1<<n + u1*v1<<2*n
83 	if(v1len > 0){
84 		mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
85 		mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
86 	}
87 
88 	// res += (u1-u0)*(v0-v1)<<n
89 	if(sign < 0)
90 		mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
91 	else
92 		mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
93 	memmove(p, res, (alen+blen)*Dbytes);
94 
95 	free(t);
96 }
97 
98 #define KARATSUBAMIN 32
99 
100 void
mpvecmul(mpdigit * a,int alen,mpdigit * b,int blen,mpdigit * p)101 mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
102 {
103 	int i;
104 	mpdigit d;
105 	mpdigit *t;
106 
107 	// both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
108 	if(alen < blen){
109 		i = alen;
110 		alen = blen;
111 		blen = i;
112 		t = a;
113 		a = b;
114 		b = t;
115 	}
116 	if(blen == 0){
117 		memset(p, 0, Dbytes*(alen+blen));
118 		return;
119 	}
120 
121 	if(alen >= KARATSUBAMIN && blen > 1){
122 		// O(n^1.585)
123 		mpkaratsuba(a, alen, b, blen, p);
124 	} else {
125 		// O(n^2)
126 		for(i = 0; i < blen; i++){
127 			d = b[i];
128 			if(d != 0)
129 				mpvecdigmuladd(a, alen, d, &p[i]);
130 		}
131 	}
132 }
133 
134 void
mpmul(mpint * b1,mpint * b2,mpint * prod)135 mpmul(mpint *b1, mpint *b2, mpint *prod)
136 {
137 	mpint *oprod;
138 
139 	oprod = nil;
140 	if(prod == b1 || prod == b2){
141 		oprod = prod;
142 		prod = mpnew(0);
143 	}
144 
145 	prod->top = 0;
146 	mpbits(prod, (b1->top+b2->top+1)*Dbits);
147 	mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
148 	prod->top = b1->top+b2->top+1;
149 	prod->sign = b1->sign*b2->sign;
150 	mpnorm(prod);
151 
152 	if(oprod != nil){
153 		mpassign(prod, oprod);
154 		mpfree(prod);
155 	}
156 }
157