/* mpn_toom32_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 1.5
   times as large as bn.  Or more accurately, bn < an < 3bn.

   Contributed to the GNU project by Torbjorn Granlund.
   Improvements by Marco Bodrato and Niels Möller.

   The idea of applying Toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 2006-2010 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */

#include "gmp-impl.h"

/* Evaluate in: -1, 0, +1, +inf

      <-s-><--n--><--n-->
       ___ ______ ______
      |a2_|___a1_|___a0_|
            |_b1_|___b0_|
            <-t--><--n-->

  v0  =  a0         * b0      #   A(0)*B(0)
  v1  = (a0+ a1+ a2)*(b0+ b1) #   A(1)*B(1)      ah  <= 2  bh <= 1
  vm1 = (a0- a1+ a2)*(b0- b1) #  A(-1)*B(-1)    |ah| <= 1  bh = 0
  vinf=          a2 *     b1  # A(inf)*B(inf)
*/
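
/* With the product written as x0 + x1 B + x2 B^2 + x3 B^3, where
   B = 2^(n*GMP_NUMB_BITS) and the xi are the coefficient products, the
   pointwise values above satisfy

     v0   = x0
     v1   = x0 + x1 + x2 + x3
     vm1  = x0 - x1 + x2 - x3
     vinf = x3

   These are the identities that the interpolation steps below invert; the
   same x0..x3 naming is used in the comments further down.  */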

#define TOOM32_MUL_N_REC(p, a, b, n, ws)	\
  do {						\
    mpn_mul_n (p, a, b, n);			\
  } while (0)
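
/* Note: ws is accepted but unused here; the wrapper presumably exists so
   that the recursive multiplications could be redirected to a routine that
   needs scratch space, without touching the call sites below.  */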

void
mpn_toom32_mul (mp_ptr pp,
                mp_srcptr ap, mp_size_t an,
                mp_srcptr bp, mp_size_t bn,
                mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;
  mp_limb_t cy;
  mp_limb_signed_t hi;
  mp_limb_t ap1_hi, bp1_hi;

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2 * n)
#define b0  bp
#define b1  (bp + n)
  /* Required, to ensure that s + t >= n.  */
  ASSERT (bn + 2 <= an && an + 6 <= 3*bn);

  n = 1 + (2 * an >= 3 * bn ? (an - 1) / (size_t) 3 : (bn - 1) >> 1);

  s = an - 2 * n;
  t = bn - n;

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);
  ASSERT (s + t >= n);
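
  /* Worked example (sizes chosen only for illustration): an = 10, bn = 7
     gives 2*an = 20 < 3*bn = 21, so n = 1 + ((bn - 1) >> 1) = 4, and then
     s = an - 2*n = 2, t = bn - n = 3, with s + t = 5 >= n = 4.  */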

  /* Product area of size an + bn = 3*n + s + t >= 4*n + 2.  */
#define ap1 (pp)		/* n, most significant limb in ap1_hi */
#define bp1 (pp + n)		/* n, most significant bit in bp1_hi */
#define am1 (pp + 2*n)		/* n, most significant bit in hi */
#define bm1 (pp + 3*n)		/* n */
#define v1 (scratch)		/* 2n + 1 */
#define vm1 (pp)		/* 2n + 1 */
#define scratch_out (scratch + 2*n + 1) /* Currently unused. */

  /* Scratch need: 2*n + 1 + scratch for the recursive multiplications.  */
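
  /* Callers size the scratch area with mpn_toom32_mul_itch from
     gmp-impl.h.  A minimal calling sketch, assuming GMP's usual TMP
     allocation idiom (ws is a hypothetical local):

       mp_ptr ws;
       TMP_DECL;
       TMP_MARK;
       ws = TMP_ALLOC_LIMBS (mpn_toom32_mul_itch (an, bn));
       mpn_toom32_mul (pp, ap, an, bp, bn, ws);
       TMP_FREE;
  */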

  /* FIXME: Keep v1[2*n] and vm1[2*n] in scalar variables?  */

  /* Compute ap1 = a0 + a1 + a2, am1 = a0 - a1 + a2.  */
  ap1_hi = mpn_add (ap1, a0, n, a2, s);
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ap1_hi = mpn_add_n_sub_n (ap1, am1, a1, ap1, n) >> 1;
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      cy = mpn_add_n_sub_n (ap1, am1, ap1, a1, n);
      hi = ap1_hi - (cy & 1);
      ap1_hi += (cy >> 1);
      vm1_neg = 0;
    }
#else
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ASSERT_NOCARRY (mpn_sub_n (am1, a1, ap1, n));
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      hi = ap1_hi - mpn_sub_n (am1, ap1, a1, n);
      vm1_neg = 0;
    }
  ap1_hi += mpn_add_n (ap1, ap1, a1, n);
#endif
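
  /* Here 0 <= ap1_hi <= 2 (a0 + a2 yields at most one carry limb, adding
     a1 at most one more), and hi is 0 or 1: whenever a0 + a2 < a1 the
     difference is computed the other way round and the sign recorded in
     vm1_neg instead.  */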

  /* Compute bp1 = b0 + b1 and bm1 = b0 - b1.  */
  if (t == n)
    {
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b1, b0, n);
	  vm1_neg ^= 1;
	}
      else
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b0, b1, n);
	}
      bp1_hi = cy >> 1;
#else
      bp1_hi = mpn_add_n (bp1, b0, b1, n);

      if (mpn_cmp (b0, b1, n) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, n));
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b0, b1, n));
	}
#endif
    }
  else
    {
      /* FIXME: Should still use mpn_add_n_sub_n for the main part.  */
      bp1_hi = mpn_add (bp1, b0, n, b1, t);

      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, t));
	  MPN_ZERO (bm1 + t, n - t);
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub (bm1, b0, n, b1, t));
	}
    }
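
  /* In either branch 0 <= bp1_hi <= 1, and {bm1, n} now holds |b0 - b1|
     with its sign folded into vm1_neg.  */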

  TOOM32_MUL_N_REC (v1, ap1, bp1, n, scratch_out);
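  /* {v1, 2n} holds the product of the low n-limb parts; fold in the cross
     terms ap1_hi * bp1 and bp1_hi * ap1 at position B^n, and the product
     of the two high parts (via the carry cy), completing the (2n+1)-limb
     value (ap1_hi B^n + ap1) * (bp1_hi B^n + bp1).  */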
  if (ap1_hi == 1)
    {
      cy = bp1_hi + mpn_add_n (v1 + n, v1 + n, bp1, n);
    }
  else if (ap1_hi == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bp1_hi + mpn_addlsh1_n (v1 + n, v1 + n, bp1, n);
#else
      cy = 2 * bp1_hi + mpn_addmul_1 (v1 + n, bp1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bp1_hi != 0)
    cy += mpn_add_n (v1 + n, v1 + n, ap1, n);
  v1[2 * n] = cy;

  TOOM32_MUL_N_REC (vm1, am1, bm1, n, scratch_out);
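  /* Same cross-term fixup for vm1: hi is 0 or 1 and bm1 has no high limb,
     so a single conditional add of bm1 at position B^n suffices.  */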
  if (hi)
    hi = mpn_add_n (vm1 + n, vm1 + n, bm1, n);

  vm1[2*n] = hi;

  /* v1 <-- (v1 + vm1) / 2 = x0 + x2.  The sum (or difference, when the
     stored vm1 is negated) equals 2*(x0 + x2), so it is even and the
     right shift is exact.  */
  if (vm1_neg)
    {
#if HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (v1, v1, vm1, 2*n+1);
#else
      mpn_sub_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }
  else
    {
#if HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (v1, v1, vm1, 2*n+1);
#else
      mpn_add_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }

  /* We get x1 + x3 = (x0 + x2) - (x0 - x1 + x2 - x3), and hence

     y = x1 + x3 + (x0 + x2) * B
       = (x0 + x2) * B + (x0 + x2) - vm1.

     y is 3*n + 1 limbs, y = y0 + y1 B + y2 B^2.  We store them as
     follows: y0 at scratch, y1 at pp + 2*n, and y2 at scratch + n
     (already in place, except for carry propagation).

     We thus add

       B^3  B^2   B    1
        |    |    |    |
       +-----+----+
     + |  x0 + x2  |
       +----+-----+----+
     +      |  x0 + x2  |
            +----------+
     -      |    vm1    |
      --+----++----+----+-
        | y2  | y1 | y0 |
        +-----+----+----+

     Since we store y0 at the same location as the low half of x0 + x2, we
     need to do the middle sum first.  */

  hi = vm1[2*n];
  cy = mpn_add_n (pp + 2*n, v1, v1 + n, n);
  MPN_INCR_U (v1 + n, n + 1, cy + v1[2*n]);

  /* FIXME: Can we get rid of this second vm1_neg conditional by
     swapping the location of +1 and -1 values?  */
  if (vm1_neg)
    {
      cy = mpn_add_n (v1, v1, vm1, n);
      hi += mpn_add_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_INCR_U (v1 + n, n+1, hi);
    }
  else
    {
      cy = mpn_sub_n (v1, v1, vm1, n);
      hi += mpn_sub_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_DECR_U (v1 + n, n+1, hi);
    }

  TOOM32_MUL_N_REC (pp, a0, b0, n, scratch_out);

  /* vinf, s+t limbs.  Use mpn_mul for now, to handle unbalanced operands;
     mpn_mul requires the first operand to be at least as large as the
     second, hence the branch on s > t.  */
  if (s > t)
    mpn_mul (pp + 3*n, a2, s, b1, t);
  else
    mpn_mul (pp + 3*n, b1, t, a2, s);

  /* Remaining interpolation.

     y * B + x0 + x3 B^3 - x0 B^2 - x3 B
     = (x1 + x3) B + (x0 + x2) B^2 + x0 + x3 B^3 - x0 B^2 - x3 B
     = y0 B + y1 B^2 + y2 B^3 + L x0 + H x0 B
       + L x3 B^3 + H x3 B^4 - L x0 B^2 - H x0 B^3 - L x3 B - H x3 B^2
     = L x0 + (y0 + H x0 - L x3) B + (y1 - L x0 - H x3) B^2
       + (y2 - (H x0 - L x3)) B^3 + H x3 B^4

          B^4       B^3       B^2       B         1
      |         |         |         |         |         |
      +---------+                   +---------+---------+
      |   Hx3   |                   | Hx0-Lx3 |   Lx0   |
      +---------+---------+---------+---------+---------+
                |    y2   |    y1   |    y0   |
                +---------+---------+---------+
               -| Hx0-Lx3 |  - Lx0  |
                +---------+---------+
                          |  - Hx3  |
                          +---------+

     We must take into account the carry from Hx0 - Lx3.  */
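
  /* At this point {pp, 2n} holds x0 (= v0), {pp + 3n, s+t} holds x3
     (= vinf), and y0, y1, y2 sit at scratch, pp + 2n, and scratch + n
     respectively, as described above.  */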

  cy = mpn_sub_n (pp + n, pp + n, pp + 3*n, n);
  hi = scratch[2*n] + cy;

  cy = mpn_sub_nc (pp + 2*n, pp + 2*n, pp, n, cy);
  hi -= mpn_sub_nc (pp + 3*n, scratch + n, pp + n, n, cy);

  hi += mpn_add (pp + n, pp + n, 3*n, scratch, n);

  /* FIXME: Is support for s + t == n needed?  */
  if (LIKELY (s + t > n))
    {
      hi -= mpn_sub (pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

      if (hi < 0)
	MPN_DECR_U (pp + 4*n, s+t-n, -hi);
      else
	MPN_INCR_U (pp + 4*n, s+t-n, hi);
    }
  else
    ASSERT (hi == 0);
}