1 /* Schoenhage's fast multiplication modulo 2^N+1.
2 
3    Contributed by Paul Zimmermann.
4 
5    THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
6    SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
7    GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
8 
9 Copyright 1998-2010, 2012, 2013, 2018 Free Software Foundation, Inc.
10 
11 This file is part of the GNU MP Library.
12 
13 The GNU MP Library is free software; you can redistribute it and/or modify
14 it under the terms of either:
15 
16   * the GNU Lesser General Public License as published by the Free
17     Software Foundation; either version 3 of the License, or (at your
18     option) any later version.
19 
20 or
21 
22   * the GNU General Public License as published by the Free Software
23     Foundation; either version 2 of the License, or (at your option) any
24     later version.
25 
26 or both in parallel, as here.
27 
28 The GNU MP Library is distributed in the hope that it will be useful, but
29 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
30 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
31 for more details.
32 
33 You should have received copies of the GNU General Public License and the
34 GNU Lesser General Public License along with the GNU MP Library.  If not,
35 see https://www.gnu.org/licenses/.  */
36 
37 
38 /* References:
39 
40    Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
41    Strassen, Computing 7, p. 281-292, 1971.
42 
43    Asymptotically fast algorithms for the numerical multiplication and division
44    of polynomials with complex coefficients, by Arnold Schoenhage, Computer
45    Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
46 
47    Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
48    Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
49 
50    TODO:
51 
52    Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
53    Zimmermann.
54 
55    It might be possible to avoid a small number of MPN_COPYs by using a
56    rotating temporary or two.
57 
58    Cleanup and simplify the code!
59 */
60 
61 #ifdef TRACE
62 #undef TRACE
63 #define TRACE(x) x
64 #include <stdio.h>
65 #else
66 #define TRACE(x)
67 #endif
68 
69 #include "gmp-impl.h"
70 
71 #ifdef WANT_ADDSUB
72 #include "generic/add_n_sub_n.c"
73 #define HAVE_NATIVE_mpn_add_n_sub_n 1
74 #endif
75 
76 static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
77 				       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
78 				       mp_size_t, mp_size_t, int **, mp_ptr, int);
79 static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
80 				   mp_size_t, mp_size_t, mp_size_t, mp_ptr);
81 
82 
83 /* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */
87 
88 
89 /*****************************************************************************/
90 
91 #if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))
92 
93 #ifndef FFT_TABLE3_SIZE		/* When tuning this is defined in gmp-impl.h */
94 #if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
95 #if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
96 #define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
97 #else
98 #define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
99 #endif
100 #endif
101 #endif
102 
103 #ifndef FFT_TABLE3_SIZE
104 #define FFT_TABLE3_SIZE 200
105 #endif
106 
107 FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
108 {
109   MUL_FFT_TABLE3,
110   SQR_FFT_TABLE3
111 };
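
/* Each entry above is an {n,k} pair; mpn_fft_best_k below returns the k of
   entry i for sizes up to (n of entry i+1) << (k of entry i).  */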
112 
113 int
114 mpn_fft_best_k (mp_size_t n, int sqr)
115 {
116   const struct fft_table_nk *fft_tab, *tab;
117   mp_size_t tab_n, thres;
118   int last_k;
119 
120   fft_tab = mpn_fft_table3[sqr];
121   last_k = fft_tab->k;
122   for (tab = fft_tab + 1; ; tab++)
123     {
124       tab_n = tab->n;
125       thres = tab_n << last_k;
126       if (n <= thres)
127 	break;
128       last_k = tab->k;
129     }
130   return last_k;
131 }
132 
133 #define MPN_FFT_BEST_READY 1
134 #endif
135 
136 /*****************************************************************************/
137 
138 #if ! defined (MPN_FFT_BEST_READY)
139 FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
140 {
141   MUL_FFT_TABLE,
142   SQR_FFT_TABLE
143 };
144 
145 int
146 mpn_fft_best_k (mp_size_t n, int sqr)
147 {
148   int i;
149 
150   for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
151     if (n < mpn_fft_table[sqr][i])
152       return i + FFT_FIRST_K;
153 
154   /* treat 4*last as one further entry */
155   if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
156     return i + FFT_FIRST_K;
157   else
158     return i + FFT_FIRST_K + 1;
159 }
160 #endif
161 
162 /*****************************************************************************/
163 
164 
/* Returns smallest possible number of limbs >= pl for an fft of size 2^k,
166    i.e. smallest multiple of 2^k >= pl.
167 
168    Don't declare static: needed by tuneup.
169 */
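/* For instance, mpn_fft_next_size (1000, 4) returns 63 * 2^4 = 1008.  */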
170 
171 mp_size_t
172 mpn_fft_next_size (mp_size_t pl, int k)
173 {
174   pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
175   return pl << k;
176 }
177 
178 
179 /* Initialize l[i][j] with bitrev(j) */
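/* For example, for k = 3 this gives l[3] = {0, 4, 2, 6, 1, 5, 3, 7}, the
   bit-reversal permutation of 0..7.  */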
180 static void
181 mpn_fft_initl (int **l, int k)
182 {
183   int i, j, K;
184   int *li;
185 
186   l[0][0] = 0;
187   for (i = 1, K = 1; i <= k; i++, K *= 2)
188     {
189       li = l[i];
190       for (j = 0; j < K; j++)
191 	{
192 	  li[j] = 2 * l[i - 1][j];
193 	  li[K + j] = 1 + li[j];
194 	}
195     }
196 }
197 
198 
199 /* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
200    Assumes a is semi-normalized, i.e. a[n] <= 1.
201    r and a must have n+1 limbs, and not overlap.
202 */
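/* Since 2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1, a shift by
   d >= n*GMP_NUMB_BITS amounts to a shift by d - n*GMP_NUMB_BITS followed by
   a negation, which is what the m >= n branch below does.  */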
203 static void
204 mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
205 {
206   unsigned int sh;
207   mp_size_t m;
208   mp_limb_t cc, rd;
209 
210   sh = d % GMP_NUMB_BITS;
211   m = d / GMP_NUMB_BITS;
212 
213   if (m >= n)			/* negate */
214     {
215       /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
216 	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */
217 
218       m -= n;
219       if (sh != 0)
220 	{
221 	  /* no out shift below since a[n] <= 1 */
222 	  mpn_lshift (r, a + n - m, m + 1, sh);
223 	  rd = r[m];
224 	  cc = mpn_lshiftc (r + m, a, n - m, sh);
225 	}
226       else
227 	{
228 	  MPN_COPY (r, a + n - m, m);
229 	  rd = a[n];
230 	  mpn_com (r + m, a, n - m);
231 	  cc = 0;
232 	}
233 
234       /* add cc to r[0], and add rd to r[m] */
235 
236       /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */
237 
238       r[n] = 0;
239       /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
240       cc++;
241       mpn_incr_u (r, cc);
242 
243       rd++;
244       /* rd might overflow when sh=GMP_NUMB_BITS-1 */
245       cc = (rd == 0) ? 1 : rd;
246       r = r + m + (rd == 0);
247       mpn_incr_u (r, cc);
248     }
249   else
250     {
251       /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
252 	 r[m..n-1]  <-- lshift(a[0]..a[n-m-1],  sh)  */
253       if (sh != 0)
254 	{
255 	  /* no out bits below since a[n] <= 1 */
256 	  mpn_lshiftc (r, a + n - m, m + 1, sh);
257 	  rd = ~r[m];
	  /* {r, m+1} = ~({a+n-m, m+1} << sh) */
259 	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
260 	}
261       else
262 	{
263 	  /* r[m] is not used below, but we save a test for m=0 */
264 	  mpn_com (r, a + n - m, m + 1);
265 	  rd = a[n];
266 	  MPN_COPY (r + m, a, n - m);
267 	  cc = 0;
268 	}
269 
270       /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */
271 
272       /* if m=0 we just have r[0]=a[n] << sh */
273       if (m != 0)
274 	{
275 	  /* now add 1 in r[0], subtract 1 in r[m] */
276 	  if (cc-- == 0) /* then add 1 to r[0] */
277 	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
278 	  cc = mpn_sub_1 (r, r, m, cc) + 1;
279 	  /* add 1 to cc instead of rd since rd might overflow */
280 	}
281 
282       /* now subtract cc and rd from r[m..n] */
283 
284       r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
285       r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
286       if (r[n] & GMP_LIMB_HIGHBIT)
287 	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
288     }
289 }
290 
291 #if HAVE_NATIVE_mpn_add_n_sub_n
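/* Ai <- A0 - tp and A0 <- A0 + tp, both mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes the operands are semi-normalized (top limb at most 1).  */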
292 static inline void
293 mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
294 {
295   mp_limb_t cyas, c, x;
296 
297   cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);
298 
299   c = A0[n] - tp[n] - (cyas & 1);
300   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
301   Ai[n] = x + c;
302   MPN_INCR_U (Ai, n + 1, x);
303 
304   c = A0[n] + tp[n] + (cyas >> 1);
305   x = (c - 1) & -(c != 0);
306   A0[n] = c - x;
307   MPN_DECR_U (A0, n + 1, x);
308 }
309 
310 #else /* ! HAVE_NATIVE_mpn_add_n_sub_n  */
311 
312 /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
313    Assumes a and b are semi-normalized.
314 */
315 static inline void
316 mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
317 {
318   mp_limb_t c, x;
319 
320   c = a[n] + b[n] + mpn_add_n (r, a, b, n);
321   /* 0 <= c <= 3 */
322 
323 #if 1
  /* GCC 4.1 outsmarts itself on most expressions here, and generates a 50%
     taken branch.  The result is slower code, of course.  But the following
     formulation outsmarts GCC.  */
326   x = (c - 1) & -(c != 0);
327   r[n] = c - x;
328   MPN_DECR_U (r, n + 1, x);
329 #endif
330 #if 0
331   if (c > 1)
332     {
333       r[n] = 1;                       /* r[n] - c = 1 */
334       MPN_DECR_U (r, n + 1, c - 1);
335     }
336   else
337     {
338       r[n] = c;
339     }
340 #endif
341 }
342 
343 /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
344    Assumes a and b are semi-normalized.
345 */
346 static inline void
347 mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
348 {
349   mp_limb_t c, x;
350 
351   c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
352   /* -2 <= c <= 1 */
353 
354 #if 1
  /* GCC 4.1 outsmarts itself on most expressions here, and generates a 50%
     taken branch.  The result is slower code, of course.  But the following
     formulation outsmarts GCC.  */
357   x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
358   r[n] = x + c;
359   MPN_INCR_U (r, n + 1, x);
360 #endif
361 #if 0
362   if ((c & GMP_LIMB_HIGHBIT) != 0)
363     {
364       r[n] = 0;
365       MPN_INCR_U (r, n + 1, -c);
366     }
367   else
368     {
369       r[n] = c;
370     }
371 #endif
372 }
373 #endif /* HAVE_NATIVE_mpn_add_n_sub_n */
374 
375 /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
376 	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
377    output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
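/* K must be a power of 2; the transform is computed in place, with the output
   left in bit-reversed order (hence the l[k][i] indexing above).  */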
378 
379 static void
380 mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
381 	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
382 {
383   if (K == 2)
384     {
385       mp_limb_t cy;
386 #if HAVE_NATIVE_mpn_add_n_sub_n
387       cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
388 #else
389       MPN_COPY (tp, Ap[0], n + 1);
390       mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
391       cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
392 #endif
393       if (Ap[0][n] > 1) /* can be 2 or 3 */
394 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
395       if (cy) /* Ap[inc][n] can be -1 or -2 */
396 	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
397     }
398   else
399     {
400       mp_size_t j, K2 = K >> 1;
401       int *lk = *ll;
402 
403       mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
404       mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
405       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
406 	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
407       for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
408 	{
409 	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
410 	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
411 	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
412 #if HAVE_NATIVE_mpn_add_n_sub_n
413 	  mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
414 #else
415 	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
416 	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
417 #endif
418 	}
419     }
420 }
421 
428 
429 /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
430    by subtracting that modulus if necessary.
431 
432    If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
433    borrow and the limbs must be zeroed out again.  This will occur very
434    infrequently.  */
435 
436 static inline void
437 mpn_fft_normalize (mp_ptr ap, mp_size_t n)
438 {
439   if (ap[n] != 0)
440     {
441       MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
442       if (ap[n] == 0)
443 	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thus to verify that this code path is correct.  */
446 	  MPN_ZERO (ap, n);
447 	  ap[n] = 1;
448 	}
449       else
450 	ap[n] = 0;
451     }
452 }
453 
454 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
455 static void
456 mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
457 {
458   int i;
459   int sqr = (ap == bp);
460   TMP_DECL;
461 
462   TMP_MARK;
463 
464   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
465     {
466       mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
467       int k;
468       int **fft_l, *tmp;
469       mp_ptr *Ap, *Bp, A, B, T;
470 
471       k = mpn_fft_best_k (n, sqr);
472       K2 = (mp_size_t) 1 << k;
473       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
474       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
475       M2 = n * GMP_NUMB_BITS >> k;
476       l = n >> k;
477       Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
479       nprime2 = Nprime2 / GMP_NUMB_BITS;
480 
481       /* we should ensure that nprime2 is a multiple of the next K */
482       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
483 	{
484 	  mp_size_t K3;
485 	  for (;;)
486 	    {
487 	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
488 	      if ((nprime2 & (K3 - 1)) == 0)
489 		break;
490 	      nprime2 = (nprime2 + K3 - 1) & -K3;
491 	      Nprime2 = nprime2 * GMP_LIMB_BITS;
492 	      /* warning: since nprime2 changed, K3 may change too! */
493 	    }
494 	}
495       ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
496 
497       Mp2 = Nprime2 >> k;
498 
499       Ap = TMP_BALLOC_MP_PTRS (K2);
500       Bp = TMP_BALLOC_MP_PTRS (K2);
501       A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
502       T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
503       B = A + ((nprime2 + 1) << k);
504       fft_l = TMP_BALLOC_TYPE (k + 1, int *);
505       tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
506       for (i = 0; i <= k; i++)
507 	{
508 	  fft_l[i] = tmp;
509 	  tmp += (mp_size_t) 1 << i;
510 	}
511 
512       mpn_fft_initl (fft_l, k);
513 
514       TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
515 		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
516       for (i = 0; i < K; i++, ap++, bp++)
517 	{
518 	  mp_limb_t cy;
519 	  mpn_fft_normalize (*ap, n);
520 	  if (!sqr)
521 	    mpn_fft_normalize (*bp, n);
522 
523 	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
524 	  if (!sqr)
525 	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);
526 
527 	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
528 				     l, Mp2, fft_l, T, sqr);
529 	  (*ap)[n] = cy;
530 	}
531     }
532   else
533     {
534       mp_ptr a, b, tp, tpn;
535       mp_limb_t cc;
536       mp_size_t n2 = 2 * n;
537       tp = TMP_BALLOC_LIMBS (n2);
538       tpn = tp + n;
539       TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
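      /* Basecase: compute the full 2n-limb product {tp, 2n}, then reduce it
	 mod 2^(n*GMP_NUMB_BITS)+1 using 2^(n*GMP_NUMB_BITS) = -1, i.e. form
	 {tp, n} - {tp+n, n}; the top limbs a[n] and b[n] (each at most 1) are
	 accounted for by the extra additions below.  */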
540       for (i = 0; i < K; i++)
541 	{
542 	  a = *ap++;
543 	  b = *bp++;
544 	  if (sqr)
545 	    mpn_sqr (tp, a, n);
546 	  else
547 	    mpn_mul_n (tp, b, a, n);
548 	  if (a[n] != 0)
549 	    cc = mpn_add_n (tpn, tpn, b, n);
550 	  else
551 	    cc = 0;
552 	  if (b[n] != 0)
553 	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
554 	  if (cc != 0)
555 	    {
556 	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
557 	      cc = mpn_add_1 (tp, tp, n2, cc);
558 	      ASSERT (cc == 0);
559 	    }
560 	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
561 	}
562     }
563   TMP_FREE;
564 }
565 
566 
567 /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
568    output: K*A[0] K*A[K-1] ... K*A[1].
569    Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
570    This condition is also fulfilled at exit.
571 */
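/* The division by K = 2^k which completes the inverse transform is not done
   here; the caller performs it afterwards via mpn_fft_div_2exp_modF.  */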
572 static void
573 mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
574 {
575   if (K == 2)
576     {
577       mp_limb_t cy;
578 #if HAVE_NATIVE_mpn_add_n_sub_n
579       cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
580 #else
581       MPN_COPY (tp, Ap[0], n + 1);
582       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
583       cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
584 #endif
585       if (Ap[0][n] > 1) /* can be 2 or 3 */
586 	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
587       if (cy) /* Ap[1][n] can be -1 or -2 */
588 	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
589     }
590   else
591     {
592       mp_size_t j, K2 = K >> 1;
593 
594       mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
595       mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
596       /* A[j]     <- A[j] + omega^j A[j+K/2]
597 	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
598       for (j = 0; j < K2; j++, Ap++)
599 	{
600 	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
601 	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
602 	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
603 #if HAVE_NATIVE_mpn_add_n_sub_n
604 	  mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
605 #else
606 	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
607 	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
608 #endif
609 	}
610     }
611 }
612 
613 
614 /* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
615 static void
616 mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
617 {
618   mp_bitcnt_t i;
619 
620   ASSERT (r != a);
621   i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
622   mpn_fft_mul_2exp_modF (r, a, i, n);
623   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
624   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
625   mpn_fft_normalize (r, n);
626 }
627 
628 
629 /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
630    Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
631    then {rp,n}=0.
632 */
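/* Writing {ap, an} = lo + mid * 2^(n*GMP_NUMB_BITS) + hi * 2^(2*n*GMP_NUMB_BITS)
   and using 2^(n*GMP_NUMB_BITS) = -1 and 2^(2*n*GMP_NUMB_BITS) = 1 mod
   2^(n*GMP_NUMB_BITS)+1, the result is lo - mid + hi, which is what the code
   below computes.  */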
633 static mp_size_t
634 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
635 {
636   mp_size_t l, m, rpn;
637   mp_limb_t cc;
638 
639   ASSERT ((n <= an) && (an <= 3 * n));
640   m = an - 2 * n;
641   if (m > 0)
642     {
643       l = n;
644       /* add {ap, m} and {ap+2n, m} in {rp, m} */
645       cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m}, adding the carry cc */
647       rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
648     }
649   else
650     {
651       l = an - n; /* l <= n */
652       MPN_COPY (rp, ap, n);
653       rpn = 0;
654     }
655 
656   /* remains to subtract {ap+n, l} from {rp, n+1} */
657   cc = mpn_sub_n (rp, rp, ap + n, l);
658   rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
659   if (rpn < 0) /* necessarily rpn = -1 */
660     rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
661   return rpn;
662 }
663 
664 /* store in A[0..nprime] the first M bits from {n, nl},
665    in A[nprime+1..] the following M bits, ...
666    Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
667    T must have space for at least (nprime + 1) limbs.
668    We must have nl <= 2*K*l.
669 */
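/* In addition, chunk i is multiplied by 2^(i*Mp) mod 2^(nprime*GMP_NUMB_BITS)+1;
   this weighting turns the cyclic convolution computed by the transforms into
   the negacyclic one needed for multiplication mod 2^N+1.  */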
670 static void
671 mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
672 		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
673 		       mp_ptr T)
674 {
675   mp_size_t i, j;
676   mp_ptr tmp;
677   mp_size_t Kl = K * l;
678   TMP_DECL;
679   TMP_MARK;
680 
681   if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
682     {
683       mp_size_t dif = nl - Kl;
684       mp_limb_signed_t cy;
685 
686       tmp = TMP_BALLOC_LIMBS(Kl + 1);
687 
688       if (dif > Kl)
689 	{
690 	  int subp = 0;
691 
692 	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
693 	  n += 2 * Kl;
694 	  dif -= Kl;
695 
696 	  /* now dif > 0 */
697 	  while (dif > Kl)
698 	    {
699 	      if (subp)
700 		cy += mpn_sub_n (tmp, tmp, n, Kl);
701 	      else
702 		cy -= mpn_add_n (tmp, tmp, n, Kl);
703 	      subp ^= 1;
704 	      n += Kl;
705 	      dif -= Kl;
706 	    }
707 	  /* now dif <= Kl */
708 	  if (subp)
709 	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
710 	  else
711 	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
712 	  if (cy >= 0)
713 	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
714 	  else
715 	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
716 	}
717       else /* dif <= Kl, i.e. nl <= 2 * Kl */
718 	{
719 	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
720 	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
721 	}
722       tmp[Kl] = cy;
723       nl = Kl + 1;
724       n = tmp;
725     }
726   for (i = 0; i < K; i++)
727     {
728       Ap[i] = A;
729       /* store the next M bits of n into A[0..nprime] */
730       if (nl > 0) /* nl is the number of remaining limbs */
731 	{
732 	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
733 	  nl -= j;
734 	  MPN_COPY (T, n, j);
735 	  MPN_ZERO (T + j, nprime + 1 - j);
736 	  n += l;
737 	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
738 	}
739       else
740 	MPN_ZERO (A, nprime + 1);
741       A += nprime + 1;
742     }
743   ASSERT_ALWAYS (nl == 0);
744   TMP_FREE;
745 }
746 
747 /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs; the carry out (0 or 1) is returned.
749    One must have pl = mpn_fft_next_size (pl, k).
750    T must have space for 2 * (nprime + 1) limbs.
751 */
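/* Outline: forward fft's of Ap (and of Bp unless sqr), term-to-term products
   mod 2^Nprime+1, inverse fft, division of each term by 2^k and by the power
   of 2 undoing the weighting applied in mpn_mul_fft_decompose, then addition
   of the overlapping terms into p and a final reduction mod 2^N+1 into
   {op, pl}.  */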
752 
753 static mp_limb_t
754 mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
755 		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
756 		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
757 		      int **fft_l, mp_ptr T, int sqr)
758 {
759   mp_size_t K, i, pla, lo, sh, j;
760   mp_ptr p;
761   mp_limb_t cc;
762 
763   K = (mp_size_t) 1 << k;
764 
765   /* direct fft's */
766   mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
767   if (!sqr)
768     mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);
769 
770   /* term to term multiplications */
771   mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);
772 
773   /* inverse fft's */
774   mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);
775 
776   /* division of terms after inverse fft */
777   Bp[0] = T + nprime + 1;
778   mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
779   for (i = 1; i < K; i++)
780     {
781       Bp[i] = Ap[i - 1];
782       mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
783     }
784 
785   /* addition of terms in result p */
786   MPN_ZERO (T, nprime + 1);
787   pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
788   p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
789   MPN_ZERO (p, pla);
790   cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
792     {
793       mp_ptr n = p + sh;
794 
795       j = (K - i) & (K - 1);
796 
797       if (mpn_add_n (n, n, Bp[j], nprime + 1))
798 	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
799 			  pla - sh - nprime - 1, CNST_LIMB(1));
800       T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
801       if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
802 	{ /* subtract 2^N'+1 */
803 	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
804 	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
805 	}
806     }
807   if (cc == -CNST_LIMB(1))
808     {
809       if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
810 	{
811 	  /* p[pla-pl]...p[pla-1] are all zero */
812 	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
813 	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
814 	}
815     }
816   else if (cc == 1)
817     {
818       if (pla >= 2 * pl)
819 	{
820 	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
821 	    ;
822 	}
823       else
824 	{
825 	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
826 	  ASSERT (cc == 0);
827 	}
828     }
829   else
830     ASSERT (cc == 0);
831 
832   /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
833      < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
834      < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
835   return mpn_fft_norm_modF (op, pl, p, pla);
836 }
837 
838 /* return the lcm of a and 2^k */
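/* For instance, when called with a = GMP_NUMB_BITS = 64, this returns 64 for
   k <= 6 and 2^k for larger k.  */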
839 static mp_bitcnt_t
840 mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
841 {
842   mp_bitcnt_t l = k;
843 
844   while (a % 2 == 0 && k > 0)
845     {
846       a >>= 1;
847       k --;
848     }
849   return a << l;
850 }
851 
852 
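/* op <- {n, nl} * {m, ml} mod 2^N+1, with N = pl*GMP_NUMB_BITS and an fft of
   size 2^k.  op has pl limbs; the carry out (0 or 1) is returned.
   pl must satisfy pl = mpn_fft_next_size (pl, k), as asserted below.  */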
853 mp_limb_t
854 mpn_mul_fft (mp_ptr op, mp_size_t pl,
855 	     mp_srcptr n, mp_size_t nl,
856 	     mp_srcptr m, mp_size_t ml,
857 	     int k)
858 {
859   int i;
860   mp_size_t K, maxLK;
861   mp_size_t N, Nprime, nprime, M, Mp, l;
862   mp_ptr *Ap, *Bp, A, T, B;
863   int **fft_l, *tmp;
864   int sqr = (n == m && nl == ml);
865   mp_limb_t h;
866   TMP_DECL;
867 
868   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
869   ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
870 
871   TMP_MARK;
872   N = pl * GMP_NUMB_BITS;
873   fft_l = TMP_BALLOC_TYPE (k + 1, int *);
874   tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
875   for (i = 0; i <= k; i++)
876     {
877       fft_l[i] = tmp;
878       tmp += (mp_size_t) 1 << i;
879     }
880 
881   mpn_fft_initl (fft_l, k);
882   K = (mp_size_t) 1 << k;
883   M = N >> k;	/* N = 2^k M */
884   l = 1 + (M - 1) / GMP_NUMB_BITS;
885   maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
886 
887   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
888   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
889   nprime = Nprime / GMP_NUMB_BITS;
890   TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
891 		 N, K, M, l, maxLK, Nprime, nprime));
892   /* we should ensure that recursively, nprime is a multiple of the next K */
893   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
894     {
895       mp_size_t K2;
896       for (;;)
897 	{
898 	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
899 	  if ((nprime & (K2 - 1)) == 0)
900 	    break;
901 	  nprime = (nprime + K2 - 1) & -K2;
902 	  Nprime = nprime * GMP_LIMB_BITS;
903 	  /* warning: since nprime changed, K2 may change too! */
904 	}
905       TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
906     }
907   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
908 
909   T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
910   Mp = Nprime >> k;
911 
912   TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
913 		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
914 	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
915 
916   A = TMP_BALLOC_LIMBS (K * (nprime + 1));
917   Ap = TMP_BALLOC_MP_PTRS (K);
918   mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
919   if (sqr)
920     {
921       mp_size_t pla;
922       pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
923       B = TMP_BALLOC_LIMBS (pla);
924       Bp = TMP_BALLOC_MP_PTRS (K);
925     }
926   else
927     {
928       B = TMP_BALLOC_LIMBS (K * (nprime + 1));
929       Bp = TMP_BALLOC_MP_PTRS (K);
930       mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
931     }
932   h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);
933 
934   TMP_FREE;
935   return h;
936 }
937 
938 #if WANT_OLD_FFT_FULL
939 /* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
940 void
941 mpn_mul_fft_full (mp_ptr op,
942 		  mp_srcptr n, mp_size_t nl,
943 		  mp_srcptr m, mp_size_t ml)
944 {
945   mp_ptr pad_op;
946   mp_size_t pl, pl2, pl3, l;
947   mp_size_t cc, c2, oldcc;
948   int k2, k3;
949   int sqr = (n == m && nl == ml);
950 
951   pl = nl + ml; /* total number of limbs of the result */
952 
  /* perform an fft mod 2^(2N)+1 and one mod 2^(3N)+1.
954      We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
955      pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
956      and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
957      (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
958      We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
959      which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
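
  /* The exact product, which has at most pl <= pl2 + pl3 limbs, is then
     recovered below from the two residues mu (mod 2^(pl3*GMP_NUMB_BITS)+1)
     and lambda (mod 2^(pl2*GMP_NUMB_BITS)+1).  */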
960 
961   /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
962 
963   pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
964   do
965     {
966       pl2++;
967       k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
968       pl2 = mpn_fft_next_size (pl2, k2);
969       pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
970 			    thus pl2 / 2 is exact */
971       k3 = mpn_fft_best_k (pl3, sqr);
972     }
973   while (mpn_fft_next_size (pl3, k3) != pl3);
974 
975   TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
976 		 nl, ml, pl2, pl3, k2));
977 
978   ASSERT_ALWAYS(pl3 <= pl);
979   cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
980   ASSERT(cc == 0);
981   pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
982   cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
983   cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
984   /* 0 <= cc <= 1 */
985   ASSERT(0 <= cc && cc <= 1);
986   l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
987   c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
988   cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
989   ASSERT(-1 <= cc && cc <= 1);
990   if (cc < 0)
991     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
992   ASSERT(0 <= cc && cc <= 1);
993   /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
994   oldcc = cc;
995 #if HAVE_NATIVE_mpn_add_n_sub_n
996   c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
997   cc += c2 >> 1; /* carry out from high <- low + high */
998   c2 = c2 & 1; /* borrow out from low <- low - high */
999 #else
1000   {
1001     mp_ptr tmp;
1002     TMP_DECL;
1003 
1004     TMP_MARK;
1005     tmp = TMP_BALLOC_LIMBS (l);
1006     MPN_COPY (tmp, pad_op, l);
1007     c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
1008     cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
1009     TMP_FREE;
1010   }
1011 #endif
1012   c2 += oldcc;
1013   /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
1014      at pad_op + l, cc is the carry at pad_op + pl2 */
1015   /* 0 <= cc <= 2 */
1016   cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
1017   /* -1 <= cc <= 2 */
1018   if (cc > 0)
1019     cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
1020   /* now -1 <= cc <= 0 */
1021   if (cc < 0)
1022     cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
1023   /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
1024   if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
1025     cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
1026   /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
1027      out below */
1028   mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
1029   if (cc) /* then cc=1 */
1030     pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
1031   /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
1032      mod 2^(pl2*GMP_NUMB_BITS) + 1 */
1033   c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
1035   MPN_COPY (op + pl3, pad_op, pl - pl3);
1036   ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
1037   __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
1038   /* since the final result has at most pl limbs, no carry out below */
1039   mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
1040 }
1041 #endif
1042