xref: /netbsd-src/external/lgpl3/mpfr/dist/src/random_deviate.c (revision ba125506a622fe649968631a56eba5d42ff57863)
1 /* random_deviate routines for mpfr_erandom and mpfr_nrandom.
2 
3 Copyright 2013-2023 Free Software Foundation, Inc.
4 Contributed by Charles Karney <charles@karney.com>, SRI International.
5 
6 This file is part of the GNU MPFR Library.
7 
8 The GNU MPFR Library is free software; you can redistribute it and/or modify
9 it under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 3 of the License, or (at your
11 option) any later version.
12 
13 The GNU MPFR Library is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
16 License for more details.
17 
18 You should have received a copy of the GNU Lesser General Public License
19 along with the GNU MPFR Library; see the file COPYING.LESSER.  If not, see
20 https://www.gnu.org/licenses/ or write to the Free Software Foundation, Inc.,
21 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA. */
22 
/*
 * An mpfr_random_deviate represents the first e bits of the binary fraction
 * of a random deviate uniformly distributed in (0,1) as
 *
 *  typedef struct {
 *    mpfr_random_size_t e;       // bits in the fraction
 *    unsigned long h;            // the high W bits of the fraction
 *    mpz_t f;                    // the rest of the fraction
 *  } mpfr_random_deviate_t[1];
 *
 * e is always a multiple of RANDOM_CHUNK.  The first RANDOM_CHUNK bits, the
 * high fraction, are held in an unsigned long, h, and the rest are held in an
 * mpz_t, f.  The data in h is undefined if e == 0 and, similarly, the data in
 * f is undefined if e <= RANDOM_CHUNK.
 */
38 
39 #define MPFR_NEED_LONGLONG_H
40 #include "random_deviate.h"
41 
/*
 * RANDOM_CHUNK can be picked in the range 1 <= RANDOM_CHUNK <= 64, provided
 * that an unsigned long can hold RANDOM_CHUNK bits.  Low values of
 * RANDOM_CHUNK are good for testing, since they are more likely to make bugs
 * obvious.  For portability, pick RANDOM_CHUNK <= 32 (since an unsigned long
 * may only hold 32 bits).  For reproducibility across platforms, standardize
 * on RANDOM_CHUNK = 32.
 *
 * When RANDOM_CHUNK = 32, this representation largely avoids manipulating
 * mpz's (until the final conversion to an mpfr is done).  In addition,
 * mpfr_random_deviate_less usually entails just a single comparison of
 * unsigned longs.  In this way, we can stick with the published interface for
 * extracting portions of an mpz (namely through mpz_tstbit) without hurting
 * efficiency.
 */
/* Default chunk size.  MPFR could match it to the number of bits per limb
   (32 or 64), but 32 is used everywhere so that the same generator state
   produces the same results on 32-bit and 64-bit computers.  Any value from
   1 up to the width of an unsigned long is accepted; that upper bound is
   checked at compile time in random_deviate_generate.  */
#ifndef RANDOM_CHUNK
# define RANDOM_CHUNK 32
#endif

/* W is a shorter name for RANDOM_CHUNK.  */
#define W RANDOM_CHUNK
64 
65 /* allocate and set to (0,1) */
66 void
mpfr_random_deviate_init(mpfr_random_deviate_ptr x)67 mpfr_random_deviate_init (mpfr_random_deviate_ptr x)
68 {
69   mpz_init (x->f);
70   x->e = 0;
71 }
72 
73 /* reset to (0,1) */
74 void
mpfr_random_deviate_reset(mpfr_random_deviate_ptr x)75 mpfr_random_deviate_reset (mpfr_random_deviate_ptr x)
76 {
77   x->e = 0;
78 }
79 
80 /* deallocate */
81 void
mpfr_random_deviate_clear(mpfr_random_deviate_ptr x)82 mpfr_random_deviate_clear (mpfr_random_deviate_ptr x)
83 {
84   mpz_clear (x->f);
85 }
86 
87 /* swap two random deviates */
88 void
mpfr_random_deviate_swap(mpfr_random_deviate_ptr x,mpfr_random_deviate_ptr y)89 mpfr_random_deviate_swap (mpfr_random_deviate_ptr x,
90                           mpfr_random_deviate_ptr y)
91 {
92   mpfr_random_size_t s;
93   unsigned long t;
94 
95   /* swap x->e and y->e */
96   s = x->e;
97   x->e = y->e;
98   y->e = s;
99 
100   /* swap x->h and y->h */
101   t = x->h;
102   x->h = y->h;
103   y->h = t;
104 
105   /* swap x->f and y->f */
106   mpz_swap (x->f, y->f);
107 }
108 
109 /* ensure x has at least k bits */
110 static void
random_deviate_generate(mpfr_random_deviate_ptr x,mpfr_random_size_t k,gmp_randstate_t r,mpz_t t)111 random_deviate_generate (mpfr_random_deviate_ptr x, mpfr_random_size_t k,
112                          gmp_randstate_t r, mpz_t t)
113 {
114   /* Various compile time checks on mpfr_random_deviate_t */
115 
116   /* Check that the h field of a mpfr_random_deviate_t can hold W bits */
117   MPFR_STAT_STATIC_ASSERT (W > 0 && W <= sizeof (unsigned long) * CHAR_BIT);
118 
119   /* Check mpfr_random_size_t can hold 32 bits and a mpfr_uprec_t.  This
120    * ensures that max(mpfr_random_size_t) exceeds MPFR_PREC_MAX by at least
121    * 2^31 because mpfr_prec_t is a signed version of mpfr_uprec_t.  This allows
122    * random deviates with many leading zeros in the fraction to be handled
123    * correctly. */
124   MPFR_STAT_STATIC_ASSERT (sizeof (mpfr_random_size_t) * CHAR_BIT >= 32 &&
125                            sizeof (mpfr_random_size_t) >=
126                            sizeof (mpfr_uprec_t));
127 
128   /* Finally, at run time, check that k is not too big.  e is set to ceil(k/W)*W
129    * and we require that this allows x->e + 1 in random_deviate_leading_bit to
130    * be computed without overflow. */
131   MPFR_ASSERTN (k <= (mpfr_random_size_t)(-((int) W + 1)));
132 
133   /* if t is non-null, it is used as a temporary */
134   if (x->e >= k)
135     return;
136 
137   if (x->e == 0)
138     {
139       x->h = gmp_urandomb_ui (r, W); /* Generate the high fraction */
140       x->e = W;
141       if (x->e >= k)
142         return;    /* Maybe that's it? */
143     }
144 
145   if (t)
146     {
147       /* passed a mpz_t so compute needed bits in one call to mpz_urandomb */
148       k = ((k + (W-1)) / W) * W;  /* Round up to multiple of W */
149       k -= x->e;                  /* The number of new bits */
150       mpz_urandomb (x->e == W ? x->f : t, r, k); /* Copy directly to x->f? */
151       if (x->e > W)
152         {
153           mpz_mul_2exp (x->f, x->f, k);
154           mpz_add (x->f, x->f, t);
155         }
156       x->e += k;
157     }
158   else
159     {
160       /* no mpz_t so compute the bits W at a time via gmp_urandomb_ui */
161       while (x->e < k)
162         {
163           unsigned long w = gmp_urandomb_ui (r, W);
164           if (x->e == W)
165             mpz_set_ui (x->f, w);
166           else
167             {
168               mpz_mul_2exp (x->f, x->f, W);
169               mpz_add_ui (x->f, x->f, w);
170             }
171           x->e += W;
172         }
173     }
174 }
175 
#ifndef MPFR_LONG_WITHIN_LIMB /* a long does not fit in a mp_limb_t */
/*
 * Return the index, in [0..127], of the highest bit set in x: 0 if x = 1,
 * 2 if x = 4, etc.  Requires x > 0.  This is a portable binary search (from
 * "Algorithms for programmers" by Joerg Arndt); no single shift is as wide
 * as a 32-bit unsigned long.
 */
static int
highest_bit_idx (unsigned long x)
{
  unsigned long upper;
  int idx = 0;

  MPFR_ASSERTD(x > 0);
  MPFR_STAT_STATIC_ASSERT (sizeof (unsigned long) * CHAR_BIT <= 128);

  /* Fold bits 64..127 down, if any.  The shift is split up because a single
     x >> 64 is undefined for a 64-bit unsigned long; a compiler with value
     range propagation (like GCC) generates no code for this step when
     unsigned long has at most 64 value bits. */
  upper = ((x >> 16) >> 24) >> 24;
  if (upper != 0)
    {
      x = upper;
      idx = 64;
    }

  /* Now x < 2^64: halve the search window at each step. */
  if (((x >> 16) >> 16) != 0)
    {
      x = (x >> 16) >> 16;
      idx += 32;
    }
  if ((x >> 16) != 0)
    {
      x >>= 16;
      idx += 16;
    }
  if ((x >> 8) != 0)
    {
      x >>= 8;
      idx += 8;
    }
  if ((x >> 4) != 0)
    {
      x >>= 4;
      idx += 4;
    }
  if ((x >> 2) != 0)
    {
      x >>= 2;
      idx += 2;
    }
  if ((x >> 1) != 0)
    idx += 1;

  return idx;
}
#else /* a long fits in a mp_limb_t */
/*
 * Return the index, in [0..63], of the highest bit set in x: 0 if x = 1,
 * 63 if x = ~0 (for a 64-bit unsigned long).  Requires x > 0.  Uses
 * count_leading_zeros on x widened to a limb; see the portable variant in
 * the other preprocessor branch.
 */
static int
highest_bit_idx (unsigned long x)
{
  int lz;

  MPFR_ASSERTD(x > 0);
  count_leading_zeros (lz, (mp_limb_t) x);
  MPFR_ASSERTD (lz <= GMP_NUMB_BITS - 1);
  /* Bit indices count from 0 at the least significant end of the limb. */
  return (GMP_NUMB_BITS - 1) - lz;
}
#endif /* MPFR_LONG_WITHIN_LIMB */
224 
225 /* return position of leading bit, counting from 1 */
226 static mpfr_random_size_t
random_deviate_leading_bit(mpfr_random_deviate_ptr x,gmp_randstate_t r)227 random_deviate_leading_bit (mpfr_random_deviate_ptr x, gmp_randstate_t r)
228 {
229   mpfr_random_size_t l;
230   random_deviate_generate (x, W, r, 0);
231   if (x->h)
232     return W - highest_bit_idx (x->h);
233   random_deviate_generate (x, 2 * W, r, 0);
234   while (mpz_sgn (x->f) == 0)
235     random_deviate_generate (x, x->e + 1, r, 0);
236   l = x->e + 1 - mpz_sizeinbase (x->f, 2);
237   /* Guard against a ridiculously long string of leading zeros in the fraction;
238    * probability of this happening is 2^(-2^31).  In particular ensure that
239    * p + 1 + l in mpfr_random_deviate_value doesn't overflow with p =
240    * MPFR_PREC_MAX. */
241   MPFR_ASSERTN (l + 1 < (mpfr_random_size_t)(-MPFR_PREC_MAX));
242   return l;
243 }
244 
245 /* return kth bit of fraction, representing 2^-k */
246 int
mpfr_random_deviate_tstbit(mpfr_random_deviate_ptr x,mpfr_random_size_t k,gmp_randstate_t r)247 mpfr_random_deviate_tstbit (mpfr_random_deviate_ptr x, mpfr_random_size_t k,
248                             gmp_randstate_t r)
249 {
250   if (k == 0)
251     return 0;
252   random_deviate_generate (x, k, r, 0);
253   if (k <= W)
254     return (x->h >> (W - k)) & 1UL;
255   return mpz_tstbit (x->f, x->e - k);
256 }
257 
258 /* compare two random deviates, x < y */
259 int
mpfr_random_deviate_less(mpfr_random_deviate_ptr x,mpfr_random_deviate_ptr y,gmp_randstate_t r)260 mpfr_random_deviate_less (mpfr_random_deviate_ptr x,
261                           mpfr_random_deviate_ptr y,
262                           gmp_randstate_t r)
263 {
264   mpfr_random_size_t k = 1;
265 
266   if (x == y)
267     return 0;
268   random_deviate_generate (x, W, r, 0);
269   random_deviate_generate (y, W, r, 0);
270   if (x->h != y->h)
271     return x->h < y->h; /* Compare the high fractions */
272   k += W;
273   for (; ; ++k)
274     {             /* Compare the rest of the fraction bit by bit */
275       int a = mpfr_random_deviate_tstbit (x, k, r);
276       int b = mpfr_random_deviate_tstbit (y, k, r);
277       if (a != b)
278         return a < b;
279     }
280 }
281 
282 /* set mpfr_t z = (neg ? -1 : 1) * (n + x) */
283 int
mpfr_random_deviate_value(int neg,unsigned long n,mpfr_random_deviate_ptr x,mpfr_ptr z,gmp_randstate_t r,mpfr_rnd_t rnd)284 mpfr_random_deviate_value (int neg, unsigned long n,
285                            mpfr_random_deviate_ptr x, mpfr_ptr z,
286                            gmp_randstate_t r, mpfr_rnd_t rnd)
287 {
288   /* r is used to add as many bits as necessary to match the precision of z */
289   int s;
290   mpfr_random_size_t l;                     /* The leading bit is 2^(s*l) */
291   mpfr_random_size_t p = mpfr_get_prec (z); /* Number of bits in result */
292   mpz_t t;
293   int inex;
294   mpfr_exp_t negxe;
295 
296   if (n == 0)
297     {
298       s = -1;
299       l = random_deviate_leading_bit (x, r); /* l > 0 */
300     }
301   else
302     {
303       s = 1;
304       l = highest_bit_idx (n); /* l >= 0 */
305     }
306 
307   /*
308    * Leading bit is 2^(s*l); thus the trailing bit in result is 2^(s*l-p+1) =
309    * 2^-(p-1-s*l).  For the sake of illustration, take l = 0 and p = 4, thus
310    * bits through the 1/8 position need to be generated; assume that these bits
311    * are 1.010 = 10/8 which represents a deviate in the range (10,11)/8.
312    *
313    * If the rounding mode is one of RNDZ, RNDU, RNDD, RNDA, we add a 1 bit to
314    * the result to give 1.0101 = (10+1/2)/8.  When this is converted to a MPFR
315    * the result is rounded to 10/8, 11/8, 10/8, 11/8, respectively, and the
316    * inexact flag is set to -1, 1, -1, 1.
317    *
318    * If the rounding mode is RNDN, an additional random bit must be generated
319    * to determine if the result is in (10,10+1/2)/8 or (10+1/2,11)/8.  Assume
320    * that this random bit is 0, so the result is 1.0100 = (10+0/2)/8.  Then an
321    * additional 1 bit is added to give 1.010101 = (10+1/4)/8.  This last bit
322    * avoids the "round ties to even rule" (because there are no ties) and sets
323    * the inexact flag so that the result is 10/8 with the inexact flag = 1.
324    *
325    * Here we always generate at least 2 additional random bits, so that bit
326    * position 2^-(p+1-s*l) is generated.  (The result often contains more
327    * random bits than this because random bits are added in batches of W and
328    * because additional bits may have been required in the process of
329    * generating the random deviate.)  The integer and all the bits in the
330    * fraction are then copied into an mpz, the least significant bit is
331    * unconditionally set to 1, the sign is set, and the result together with
332    * the exponent -x->e is used to generate an mpfr using mpfr_set_z_2exp.
333    *
334    * If random bits were very expensive, we would only need to generate to the
335    * 2^-(p-1-s*l) bit (no extra bits) for the RNDZ, RNDU, RNDD, RNDA modes and
336    * to the 2^-(p-s*l) bit (1 extra bit) for RNDN.  By always generating 2 bits
337    * we save on some bit shuffling when formed the mpz to be converted to an
338    * mpfr.  The implementation of the RandomNumber class in RandomLib
339    * illustrates the more parsimonious approach (which was taken to allow
340    * accurate counts of the number of random digits to be made).
341    */
342   mpz_init (t);
343   /*
344    * This is the only call to random_deviate_generate where a mpz_t is passed
345    * (because an arbitrarily large number of bits may need to be generated).
346    */
347   if ((s > 0 && p + 1 > l) ||
348       (s < 0 && p + 1 + l > 0))
349     random_deviate_generate (x, s > 0 ? p + 1 - l : p + 1 + l, r, t);
350   if (n == 0)
351     {
352       /* Since the minimum prec is 2 we know that x->h has been generated. */
353       mpz_set_ui (t, x->h);        /* Set high fraction */
354     }
355   else
356     {
357       mpz_set_ui (t, n);           /* The integer part */
358       if (x->e > 0)
359         {
360           mpz_mul_2exp (t, t, W);    /* Shift to allow for high fraction */
361           mpz_add_ui (t, t, x->h);   /* Add high fraction */
362         }
363     }
364   if (x->e > W)
365     {
366       mpz_mul_2exp (t, t, x->e - W); /* Shift to allow for low fraction */
367       mpz_add (t, t, x->f);          /* Add low fraction */
368     }
369   /*
370    * We could trim off any excess bits here by shifting rightward.  This is an
371    * unnecessary complication.
372    */
373   mpz_setbit (t, 0);     /* Set the trailing bit so result is always inexact */
374   if (neg)
375     mpz_neg (t, t);
376   /* Portable version of the negation of x->e, with a check of overflow. */
377   if (MPFR_UNLIKELY (x->e > MPFR_EXP_MAX))
378     {
379       /* Overflow, except when x->e = MPFR_EXP_MAX + 1 = - MPFR_EXP_MIN. */
380       MPFR_ASSERTN (MPFR_EXP_MIN + MPFR_EXP_MAX == -1 &&
381                     x->e == (mpfr_random_size_t) MPFR_EXP_MAX + 1);
382       negxe = MPFR_EXP_MIN;
383     }
384   else
385     negxe = - (mpfr_exp_t) x->e;
386   /*
387    * Let mpfr_set_z_2exp do all the work of rounding to the requested
388    * precision, setting overflow/underflow flags, and returning the right
389    * inexact value.
390    */
391   inex = mpfr_set_z_2exp (z, t, negxe, rnd);
392   mpz_clear (t);
393   return inex;
394 }
395