/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013, 2018 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and
   division of polynomials with complex coefficients, by Arnold Schoenhage,
   Computer Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
                                       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
                                       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
                                   mp_size_t, mp_size_t, mp_size_t, mp_ptr);

/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. the smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
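
/* Worked example (illustrative only): for k = 3, mpn_fft_initl builds
     l[1] = {0, 1}
     l[2] = {0, 2, 1, 3}
     l[3] = {0, 4, 2, 6, 1, 5, 3, 7}
   i.e. l[i][j] is j bit-reversed on i bits.  Likewise
   mpn_fft_next_size (100, 3) returns 104, the smallest multiple of
   2^3 that is >= 100.  */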

/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n)			/* negate */
    {
      /* r[0..m-1]  <--  lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */

      m -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - m, m + 1, sh);
          rd = r[m];
          cc = mpn_lshiftc (r + m, a, n - m, sh);
        }
      else
        {
          MPN_COPY (r, a + n - m, m);
          rd = a[n];
          mpn_com (r + m, a, n - m);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <--  lshift(a[0]..a[n-m-1],  sh) */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - m, m + 1, sh);
          rd = ~r[m];
          /* {r, m+1} = {a+n-m, m+1} << sh */
          cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
        }
      else
        {
          /* r[m] is not used below, but we save a test for m=0 */
          mpn_com (r, a + n - m, m + 1);
          rd = a[n];
          MPN_COPY (r + m, a, n - m);
          cc = 0;
        }

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[m] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, m, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
        r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}

#if HAVE_NATIVE_mpn_add_n_sub_n
static inline void
mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
{
  mp_limb_t cyas, c, x;

  cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);

  c = A0[n] - tp[n] - (cyas & 1);
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  Ai[n] = x + c;
  MPN_INCR_U (Ai, n + 1, x);

  c = A0[n] + tp[n] + (cyas >> 1);
  x = (c - 1) & -(c != 0);
  A0[n] = c - x;
  MPN_DECR_U (A0, n + 1, x);
}

#else /* ! HAVE_NATIVE_mpn_add_n_sub_n */

/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
#endif /* HAVE_NATIVE_mpn_add_n_sub_n */
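
/* Illustration of the branch-free reduction used in mpn_fft_add_sub_modF /
   mpn_fft_add_modF above: the high carry satisfies 0 <= c <= 3, and
   x = (c - 1) & -(c != 0) equals max(c-1, 0).  Storing r[n] = c - x and then
   decrementing {r, n+1} by x subtracts exactly x copies of 2^N+1, so e.g.
   c = 3 gives r = a + b - 2*(2^N+1), again semi-normalized, while c <= 1
   just stores c in r[n].  mpn_fft_sub_modF likewise adds (-c) copies of
   2^N+1, branch-free, when the limb-wise subtraction went negative.  */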

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1.
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap,       K2, ll - 1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap + inc, K2, ll - 1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
          mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
#else
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
        }
    }
}


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS ((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          mp_size_t K3;
          for (;;)
            {
              K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS (nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
        {
          fft_l[i] = tmp;
          tmp += (mp_size_t) 1 << i;
        }

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0 * (double) n / nprime2 / K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}
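
/* Note on the last line of the base case above (illustrative): the 2n-limb
   product is tp = lo + hi*2^N with N = n*GMP_NUMB_BITS, and since
   2^N = -1 mod 2^N+1 the residue is lo - hi.  When the subtraction borrows,
   adding back 2^N+1 splits as "+2^N" (the borrow already left lo - hi + 2^N
   in {a, n}) plus the explicit mpn_add_1 of 1; e.g. hi = lo + 1 gives the
   residue 2^N, stored as {a, n} = 0 with a[n] = 1.  */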


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
          mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
#else
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
        }
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
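
/* Arithmetic behind the two helpers above (illustrative): with
   N = n*GMP_NUMB_BITS we have 2^N = -1 and 2^(2N) = 1 mod 2^N+1, so
   mpn_fft_div_2exp_modF divides by 2^k by multiplying with 2^(2N-k), and
   mpn_fft_norm_modF folds {ap, an} viewed as a0 + a1*2^N + a2*2^(2N) into
   a0 - a1 + a2.  Toy check with N = 4 (modulus 17): 1/2 = 2^7 = 128 = 9
   mod 17, and indeed 2*9 = 18 = 1 mod 17.  */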

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
                       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
                       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS (Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
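
/* Why each slice is multiplied by 2^(i*Mp) above (illustrative): the weight
   theta = 2^Mp satisfies theta^K = 2^Nprime = -1 mod 2^Nprime+1, so the
   weighted slices turn the length-K cyclic convolution performed by the
   transforms into the negacyclic convolution that a product modulo
   2^(K*M)+1 requires.  */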

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                         pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
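
/* Example (illustrative): mpn_mul_fft_lcm (64, 3) strips three factors of 2
   and returns 8 << 3 = 64 = lcm(64, 2^3); mpn_mul_fft_lcm (64, 8) strips six
   (until a becomes odd) and returns 1 << 8 = 256 = lcm(64, 2^8).  */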


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
        {
          K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
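
#if 0
/* Illustrative caller sketch, not compiled in: forming a plain 2n-limb
   product with mpn_mul_fft by choosing pl >= 2n, so that the reduction mod
   2^(pl*GMP_NUMB_BITS)+1 never wraps.  The helper name example_fft_mul is
   made up for this sketch; real callers (such as mpn_mul_fft_full below,
   when enabled) use a more elaborate size selection.  */
static void
example_fft_mul (mp_ptr rp, mp_srcptr ap, mp_srcptr bp, mp_size_t n)
{
  int k = mpn_fft_best_k (2 * n, 0);            /* sqr = 0: multiply */
  mp_size_t pl = mpn_fft_next_size (2 * n, k);  /* multiple of 2^k, >= 2n */
  mp_limb_t h;

  /* rp must have room for pl limbs; the product occupies the low 2n limbs */
  h = mpn_mul_fft (rp, pl, ap, n, bp, n, k);
  ASSERT (h == 0);
}
#endif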

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform one FFT mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3); /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1;   /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif