/* mpn_toom53_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 5/3
   times as large as bn.  Or more accurately, (4/3)bn < an < (5/2)bn.

   Contributed to the GNU project by Torbjorn Granlund and Marco Bodrato.

   The idea of applying toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 2006-2008, 2012, 2014, 2015 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


#include "gmp-impl.h"

/* Evaluate in: 0, +1, -1, +2, -2, 1/2, +inf

  <-s-><--n--><--n--><--n--><--n-->
   ___ ______ ______ ______ ______
  |a4_|___a3_|___a2_|___a1_|___a0_|
               |__b2|___b1_|___b0_|
               <-t--><--n--><--n-->

  v0  =    a0                  *  b0          #    A(0)*B(0)
  v1  = (  a0+ a1+ a2+ a3+  a4)*( b0+ b1+ b2) #    A(1)*B(1)      ah  <= 4   bh <= 2
  vm1 = (  a0- a1+ a2- a3+  a4)*( b0- b1+ b2) #   A(-1)*B(-1)    |ah| <= 2   bh <= 1
  v2  = (  a0+2a1+4a2+8a3+16a4)*( b0+2b1+4b2) #    A(2)*B(2)      ah  <= 30  bh <= 6
  vm2 = (  a0-2a1+4a2-8a3+16a4)*( b0-2b1+4b2) #   A(-2)*B(-2)   -9<=ah<=20 -1<=bh<=4
  vh  = (16a0+8a1+4a2+2a3+  a4)*(4b0+2b1+ b2) #  A(1/2)*B(1/2)    ah  <= 30  bh <= 6
  vinf=                     a4 *          b2  #  A(inf)*B(inf)
*/
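
/* The product A(x)*B(x) has degree 6, so its seven coefficients are
   exactly determined by the seven values above; they are recovered by
   mpn_toom_interpolate_7pts at the end of mpn_toom53_mul.  */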

void
mpn_toom53_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  mp_size_t n, s, t;
  mp_limb_t cy;
  mp_ptr gp;
  mp_ptr as1, asm1, as2, asm2, ash;
  mp_ptr bs1, bsm1, bs2, bsm2, bsh;
  mp_ptr tmp;
  enum toom7_flags flags;
  TMP_DECL;

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2*n)
#define a3  (ap + 3*n)
#define a4  (ap + 4*n)
#define b0  bp
#define b1  (bp + n)
#define b2  (bp + 2*n)

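  /* n = ceil(max(an/5, bn/3)) is the common block size: A is split into
     five blocks and B into three, and 3*an >= 5*bn iff an/5 >= bn/3.  */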
  n = 1 + (3 * an >= 5 * bn ? (an - 1) / (size_t) 5 : (bn - 1) / (size_t) 3);

  s = an - 4 * n;
  t = bn - 2 * n;
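  /* For instance (sizes chosen only for illustration), an = 10 and bn = 6
     give n = 2, s = 2, t = 2: five 2-limb blocks of A against three
     2-limb blocks of B.  */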

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);

  TMP_MARK;

  tmp = TMP_ALLOC_LIMBS (10 * (n + 1));
  as1  = tmp; tmp += n + 1;
  asm1 = tmp; tmp += n + 1;
  as2  = tmp; tmp += n + 1;
  asm2 = tmp; tmp += n + 1;
  ash  = tmp; tmp += n + 1;
  bs1  = tmp; tmp += n + 1;
  bsm1 = tmp; tmp += n + 1;
  bs2  = tmp; tmp += n + 1;
  bsm2 = tmp; tmp += n + 1;
  bsh  = tmp; tmp += n + 1;

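  /* The low limbs of the product area pp are not written until the
     pointwise multiplications below, so they can double as evaluation
     scratch in the meantime.  */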
  gp = pp;

  /* Compute as1 and asm1.  */
  flags = (enum toom7_flags) (toom7_w3_neg & mpn_toom_eval_pm1 (as1, asm1, 4, ap, n, s, gp));

  /* Compute as2 and asm2. */
  flags = (enum toom7_flags) (flags | (toom7_w1_neg & mpn_toom_eval_pm2 (as2, asm2, 4, ap, n, s, gp)));
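  /* The evaluation helpers return a sign flag; toom7_w3_neg and
     toom7_w1_neg record whether A(-1) and A(-2) (and, below, B(-1) and
     B(-2)) came out negative, so that mpn_toom_interpolate_7pts can
     correct the signs of vm1 and vm2.  */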

  /* Compute ash = 16 a0 + 8 a1 + 4 a2 + 2 a3 + a4
     = 2*(2*(2*(2*a0 + a1) + a2) + a3) + a4  */
#if HAVE_NATIVE_mpn_addlsh1_n
  cy = mpn_addlsh1_n (ash, a1, a0, n);
  cy = 2*cy + mpn_addlsh1_n (ash, a2, ash, n);
  cy = 2*cy + mpn_addlsh1_n (ash, a3, ash, n);
  if (s < n)
    {
      mp_limb_t cy2;
      cy2 = mpn_addlsh1_n (ash, a4, ash, s);
      ash[n] = 2*cy + mpn_lshift (ash + s, ash + s, n - s, 1);
      MPN_INCR_U (ash + s, n+1-s, cy2);
    }
  else
    ash[n] = 2*cy + mpn_addlsh1_n (ash, a4, ash, n);
#else
  cy = mpn_lshift (ash, a0, n, 1);
  cy += mpn_add_n (ash, ash, a1, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  cy += mpn_add_n (ash, ash, a2, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  cy += mpn_add_n (ash, ash, a3, n);
  cy = 2*cy + mpn_lshift (ash, ash, n, 1);
  ash[n] = cy + mpn_add (ash, ash, n, a4, s);
#endif

  /* Compute bs1 and bsm1.  */
  bs1[n] = mpn_add (bs1, b0, n, b2, t);		/* b0 + b2 */
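  /* bs1 = b0 + b1 + b2 and bsm1 = |b0 - b1 + b2|; when b0 + b2 < b1 the
     difference is computed the other way around and the sign is folded
     into flags via toom7_w3_neg.  */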
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (bs1[n] == 0 && mpn_cmp (bs1, b1, n) < 0)
    {
      bs1[n] = mpn_add_n_sub_n (bs1, bsm1, b1, bs1, n) >> 1;
      bsm1[n] = 0;
      flags = (enum toom7_flags) (flags ^ toom7_w3_neg);
    }
  else
    {
      cy = mpn_add_n_sub_n (bs1, bsm1, bs1, b1, n);
      bsm1[n] = bs1[n] - (cy & 1);
      bs1[n] += (cy >> 1);
    }
#else
  if (bs1[n] == 0 && mpn_cmp (bs1, b1, n) < 0)
    {
      mpn_sub_n (bsm1, b1, bs1, n);
      bsm1[n] = 0;
      flags = (enum toom7_flags) (flags ^ toom7_w3_neg);
    }
  else
    {
      bsm1[n] = bs1[n] - mpn_sub_n (bsm1, bs1, b1, n);
    }
  bs1[n] += mpn_add_n (bs1, bs1, b1, n);  /* b0+b1+b2 */
#endif

  /* Compute bs2 and bsm2. */
#if HAVE_NATIVE_mpn_addlsh_n || HAVE_NATIVE_mpn_addlsh2_n
#if HAVE_NATIVE_mpn_addlsh2_n
  cy = mpn_addlsh2_n (bs2, b0, b2, t);
#else /* HAVE_NATIVE_mpn_addlsh_n */
  cy = mpn_addlsh_n (bs2, b0, b2, t, 2);
#endif
  if (t < n)
    cy = mpn_add_1 (bs2 + t, b0 + t, n - t, cy);
  bs2[n] = cy;
#else
  cy = mpn_lshift (gp, b2, t, 2);
  bs2[n] = mpn_add (bs2, b0, n, gp, t);
  MPN_INCR_U (bs2 + t, n+1-t, cy);
#endif

  gp[n] = mpn_lshift (gp, b1, n, 1);
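  /* bs2 now holds b0 + 4*b2 and gp holds 2*b1; their sum and absolute
     difference give bs2 = B(2) and bsm2 = |B(-2)|, with the sign of
     B(-2) again folded into flags.  */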

#if HAVE_NATIVE_mpn_add_n_sub_n
  if (mpn_cmp (bs2, gp, n+1) < 0)
    {
      ASSERT_NOCARRY (mpn_add_n_sub_n (bs2, bsm2, gp, bs2, n+1));
      flags = (enum toom7_flags) (flags ^ toom7_w1_neg);
    }
  else
    {
      ASSERT_NOCARRY (mpn_add_n_sub_n (bs2, bsm2, bs2, gp, n+1));
    }
#else
  if (mpn_cmp (bs2, gp, n+1) < 0)
    {
      ASSERT_NOCARRY (mpn_sub_n (bsm2, gp, bs2, n+1));
      flags = (enum toom7_flags) (flags ^ toom7_w1_neg);
    }
  else
    {
      ASSERT_NOCARRY (mpn_sub_n (bsm2, bs2, gp, n+1));
    }
  mpn_add_n (bs2, bs2, gp, n+1);
#endif

  /* Compute bsh = 4 b0 + 2 b1 + b2 = 2*(2*b0 + b1)+b2.  */
#if HAVE_NATIVE_mpn_addlsh1_n
  cy = mpn_addlsh1_n (bsh, b1, b0, n);
  if (t < n)
    {
      mp_limb_t cy2;
      cy2 = mpn_addlsh1_n (bsh, b2, bsh, t);
      bsh[n] = 2*cy + mpn_lshift (bsh + t, bsh + t, n - t, 1);
      MPN_INCR_U (bsh + t, n+1-t, cy2);
    }
  else
    bsh[n] = 2*cy + mpn_addlsh1_n (bsh, b2, bsh, n);
#else
  cy = mpn_lshift (bsh, b0, n, 1);
  cy += mpn_add_n (bsh, bsh, b1, n);
  cy = 2*cy + mpn_lshift (bsh, bsh, n, 1);
  bsh[n] = cy + mpn_add (bsh, bsh, n, b2, t);
#endif

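  /* Sanity-check the high limbs against the bounds listed in the
     evaluation table at the top of the file.  */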
  ASSERT (as1[n] <= 4);
  ASSERT (bs1[n] <= 2);
  ASSERT (asm1[n] <= 2);
  ASSERT (bsm1[n] <= 1);
  ASSERT (as2[n] <= 30);
  ASSERT (bs2[n] <= 6);
  ASSERT (asm2[n] <= 20);
  ASSERT (bsm2[n] <= 4);
  ASSERT (ash[n] <= 30);
  ASSERT (bsh[n] <= 6);

#define v0    pp				/* 2n */
#define v1    (pp + 2 * n)			/* 2n+1 */
#define vinf  (pp + 6 * n)			/* s+t */
#define v2    scratch				/* 2n+1 */
#define vm2   (scratch + 2 * n + 1)		/* 2n+1 */
#define vh    (scratch + 4 * n + 2)		/* 2n+1 */
#define vm1   (scratch + 6 * n + 3)		/* 2n+1 */
#define scratch_out (scratch + 8 * n + 4)	/* 2n+1 */
  /* Total scratch need: 10*n+5 */
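  /* Callers are expected to provide at least this much scratch; in GMP
     the required size is given by mpn_toom53_mul_itch in gmp-impl.h.  */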

  /* Must be in allocation order, as they overwrite one limb beyond
   * 2n+1. */
  mpn_mul_n (v2, as2, bs2, n + 1);		/* v2, 2n+1 limbs */
  mpn_mul_n (vm2, asm2, bsm2, n + 1);		/* vm2, 2n+1 limbs */
  mpn_mul_n (vh, ash, bsh, n + 1);		/* vh, 2n+1 limbs */

  /* vm1, 2n+1 limbs */
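  /* With SMALLER_RECURSION the recursive product uses only the low n
     limbs, and the contributions of the high limbs asm1[n] and bsm1[n]
     are added in by hand; otherwise we recurse on n+1 limbs whenever
     either high limb is nonzero.  The same applies to v1 below.  */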
#ifdef SMALLER_RECURSION
  mpn_mul_n (vm1, asm1, bsm1, n);
  if (asm1[n] == 1)
    {
      cy = bsm1[n] + mpn_add_n (vm1 + n, vm1 + n, bsm1, n);
    }
  else if (asm1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n_ip1
      cy = 2 * bsm1[n] + mpn_addlsh1_n_ip1 (vm1 + n, bsm1, n);
#else
      cy = 2 * bsm1[n] + mpn_addmul_1 (vm1 + n, bsm1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bsm1[n] != 0)
    cy += mpn_add_n (vm1 + n, vm1 + n, asm1, n);
  vm1[2 * n] = cy;
#else /* SMALLER_RECURSION */
  vm1[2 * n] = 0;
  mpn_mul_n (vm1, asm1, bsm1, n + ((asm1[n] | bsm1[n]) != 0));
#endif /* SMALLER_RECURSION */

  /* v1, 2n+1 limbs */
#ifdef SMALLER_RECURSION
  mpn_mul_n (v1, as1, bs1, n);
  if (as1[n] == 1)
    {
      cy = bs1[n] + mpn_add_n (v1 + n, v1 + n, bs1, n);
    }
  else if (as1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n_ip1
      cy = 2 * bs1[n] + mpn_addlsh1_n_ip1 (v1 + n, bs1, n);
#else
      cy = 2 * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, CNST_LIMB(2));
#endif
    }
  else if (as1[n] != 0)
    {
      cy = as1[n] * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, as1[n]);
    }
  else
    cy = 0;
  if (bs1[n] == 1)
    {
      cy += mpn_add_n (v1 + n, v1 + n, as1, n);
    }
  else if (bs1[n] == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n_ip1
      cy += mpn_addlsh1_n_ip1 (v1 + n, as1, n);
#else
      cy += mpn_addmul_1 (v1 + n, as1, n, CNST_LIMB(2));
#endif
    }
  v1[2 * n] = cy;
#else /* SMALLER_RECURSION */
  v1[2 * n] = 0;
  mpn_mul_n (v1, as1, bs1, n + ((as1[n] | bs1[n]) != 0));
#endif /* SMALLER_RECURSION */

  mpn_mul_n (v0, a0, b0, n);			/* v0, 2n limbs */

  /* vinf, s+t limbs */
  if (s > t)  mpn_mul (vinf, a4, s, b2, t);
  else        mpn_mul (vinf, b2, t, a4, s);
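  /* mpn_mul requires its first operand to be at least as long as its
     second, hence the branch on s > t.  */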

  mpn_toom_interpolate_7pts (pp, n, flags, vm2, vm1, v2, vh, s + t,
			     scratch_out);

  TMP_FREE;
}