xref: /llvm-project/libcxx/test/std/numerics/rand/rand.dist/rand.dist.uni/rand.dist.uni.int/eval.pass.cpp (revision 09e3a360581dc36d0820d3fb6da9bd7cfed87b5d)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // REQUIRES: long_tests
10 
11 // <random>
12 
13 // template<class _IntType = int>
14 // class uniform_int_distribution
15 
16 // template<class _URNG> result_type operator()(_URNG& g);
17 
18 #include <random>
19 #include <cassert>
20 #include <climits>
21 #include <cmath>
22 #include <cstddef>
23 #include <cstdint>
24 #include <limits>
25 #include <numeric>
26 #include <vector>
27 
28 #include "test_macros.h"
29 
// Square of `value`, evaluated as one multiplication in type T
// (used below with T = double for the expected-variance formula).
template <typename T>
T sqr(T value) {
    const T squared = value * value;
    return squared;
}
34 
35 template <class ResultType, class EngineType>
36 void test_statistics(ResultType a, ResultType b) {
37     ASSERT_SAME_TYPE(typename std::uniform_int_distribution<ResultType>::result_type, ResultType);
38 
39     EngineType g;
40     std::uniform_int_distribution<ResultType> dist(a, b);
41     assert(dist.a() == a);
42     assert(dist.b() == b);
43     std::vector<ResultType> u;
44     for (int i = 0; i < 10000; ++i) {
45         ResultType v = dist(g);
46         assert(a <= v && v <= b);
47         u.push_back(v);
48     }
49 
50     // Quick check: The chance of getting *no* hits in any given tenth of the range
51     // is (0.9)^10000, or "ultra-astronomically low."
52     bool bottom_tenth = false;
53     bool top_tenth = false;
54     for (std::size_t i = 0; i < u.size(); ++i) {
55         bottom_tenth = bottom_tenth || (u[i] <= (a + (b / 10) - (a / 10)));
56         top_tenth = top_tenth || (u[i] >= (b - (b / 10) + (a / 10)));
57     }
58     assert(bottom_tenth);  // ...is populated
59     assert(top_tenth);  // ...is populated
60 
61     // Now do some more involved statistical math.
62     double mean = std::accumulate(u.begin(), u.end(), 0.0) / u.size();
63     double var = 0;
64     double skew = 0;
65     double kurtosis = 0;
66     for (std::size_t i = 0; i < u.size(); ++i) {
67         double dbl = (u[i] - mean);
68         double d2 = dbl * dbl;
69         var += d2;
70         skew += dbl * d2;
71         kurtosis += d2 * d2;
72     }
73     var /= u.size();
74     double dev = std::sqrt(var);
75     skew /= u.size() * dev * var;
76     kurtosis /= u.size() * var * var;
77 
78     double expected_mean = double(a) + double(b)/2 - double(a)/2;
79     double expected_var = (sqr(double(b) - double(a) + 1) - 1) / 12;
80 
81     double range = double(b) - double(a) + 1.0;
82     assert(range > range / 10);  // i.e., it's not infinity
83 
84     assert(std::abs(mean - expected_mean) < range / 100);
85     assert(std::abs(var - expected_var) < expected_var / 50);
86     assert(-0.1 < skew && skew < 0.1);
87     assert(1.6 < kurtosis && kurtosis < 2.0);
88 }
89 
// Convenience overload: runs the statistical checks over [0, max()] of
// ResultType, i.e. the whole range for unsigned types and the non-negative
// half for signed types.
template <class ResultType, class EngineType>
void test_statistics() {
    const ResultType lowest = 0;
    const ResultType highest = std::numeric_limits<ResultType>::max();
    test_statistics<ResultType, EngineType>(lowest, highest);
}
94 
95 int main(int, char**)
96 {
97     test_statistics<int, std::minstd_rand0>();
98     test_statistics<int, std::minstd_rand>();
99     test_statistics<int, std::mt19937>();
100     test_statistics<int, std::mt19937_64>();
101     test_statistics<int, std::ranlux24_base>();
102     test_statistics<int, std::ranlux48_base>();
103     test_statistics<int, std::ranlux24>();
104     test_statistics<int, std::ranlux48>();
105     test_statistics<int, std::knuth_b>();
106     test_statistics<int, std::minstd_rand0>(-6, 106);
107     test_statistics<int, std::minstd_rand>(5, 100);
108 
109     test_statistics<short, std::minstd_rand0>();
110     test_statistics<int, std::minstd_rand0>();
111     test_statistics<long, std::minstd_rand0>();
112     test_statistics<long long, std::minstd_rand0>();
113 
114     test_statistics<unsigned short, std::minstd_rand0>();
115     test_statistics<unsigned int, std::minstd_rand0>();
116     test_statistics<unsigned long, std::minstd_rand0>();
117     test_statistics<unsigned long long, std::minstd_rand0>();
118 
119     test_statistics<short, std::minstd_rand0>(SHRT_MIN, SHRT_MAX);
120 
121 #if defined(_LIBCPP_VERSION) // extension
122     test_statistics<std::int8_t, std::minstd_rand0>();
123     test_statistics<std::uint8_t, std::minstd_rand0>();
124 
125 #if !defined(TEST_HAS_NO_INT128)
126     test_statistics<__int128_t, std::minstd_rand0>();
127     test_statistics<__uint128_t, std::minstd_rand0>();
128 
129     test_statistics<__int128_t, std::minstd_rand0>(-100, 900);
130     test_statistics<__int128_t, std::minstd_rand0>(0, UINT64_MAX);
131     test_statistics<__int128_t, std::minstd_rand0>(std::numeric_limits<__int128_t>::min(), std::numeric_limits<__int128_t>::max());
132     test_statistics<__uint128_t, std::minstd_rand0>(0, UINT64_MAX);
133 #endif
134 #endif
135 
136     return 0;
137 }
138