xref: /plan9/sys/src/cmd/unix/drawterm/libmp/mpvecdigmuladd.c (revision 8ccd4a6360d974db7bd7bbd4f37e7018419ea908)
1 #include "os.h"
2 #include <mp.h>
3 #include "dat.h"
4 
5 #define LO(x) ((x) & ((1<<(Dbits/2))-1))
6 #define HI(x) ((x) >> (Dbits/2))
7 
8 static void
mpdigmul(mpdigit a,mpdigit b,mpdigit * p)9 mpdigmul(mpdigit a, mpdigit b, mpdigit *p)
10 {
11 	mpdigit x, ah, al, bh, bl, p1, p2, p3, p4;
12 	int carry;
13 
14 	// half digits
15 	ah = HI(a);
16 	al = LO(a);
17 	bh = HI(b);
18 	bl = LO(b);
19 
20 	// partial products
21 	p1 = ah*bl;
22 	p2 = bh*al;
23 	p3 = bl*al;
24 	p4 = ah*bh;
25 
26 	// p = ((p1+p2)<<(Dbits/2)) + (p4<<Dbits) + p3
27 	carry = 0;
28 	x = p1<<(Dbits/2);
29 	p3 += x;
30 	if(p3 < x)
31 		carry++;
32 	x = p2<<(Dbits/2);
33 	p3 += x;
34 	if(p3 < x)
35 		carry++;
36 	p4 += carry + HI(p1) + HI(p2);	// can't carry out of the high digit
37 	p[0] = p3;
38 	p[1] = p4;
39 }
40 
41 // prereq: p must have room for n+1 digits
42 void
mpvecdigmuladd(mpdigit * b,int n,mpdigit m,mpdigit * p)43 mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p)
44 {
45 	int i;
46 	mpdigit carry, x, y, part[2];
47 
48 	carry = 0;
49 	part[1] = 0;
50 	for(i = 0; i < n; i++){
51 		x = part[1] + carry;
52 		if(x < carry)
53 			carry = 1;
54 		else
55 			carry = 0;
56 		y = *p;
57 		mpdigmul(*b++, m, part);
58 		x += part[0];
59 		if(x < part[0])
60 			carry++;
61 		x += y;
62 		if(x < y)
63 			carry++;
64 		*p++ = x;
65 	}
66 	*p = part[1] + carry;
67 }
68 
69 // prereq: p must have room for n+1 digits
70 int
mpvecdigmulsub(mpdigit * b,int n,mpdigit m,mpdigit * p)71 mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p)
72 {
73 	int i;
74 	mpdigit x, y, part[2], borrow;
75 
76 	borrow = 0;
77 	part[1] = 0;
78 	for(i = 0; i < n; i++){
79 		x = *p;
80 		y = x - borrow;
81 		if(y > x)
82 			borrow = 1;
83 		else
84 			borrow = 0;
85 		x = part[1];
86 		mpdigmul(*b++, m, part);
87 		x += part[0];
88 		if(x < part[0])
89 			borrow++;
90 		x = y - x;
91 		if(x > y)
92 			borrow++;
93 		*p++ = x;
94 	}
95 
96 	x = *p;
97 	y = x - borrow - part[1];
98 	*p = y;
99 	if(y > x)
100 		return -1;
101 	else
102 		return 1;
103 }
104