xref: /llvm-project/libcxx/test/std/algorithms/alg.modifying.operations/alg.random.sample/sample.pass.cpp (revision 5f26d8636f506b487962cd2a9b0e32940e93fa9b)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
// UNSUPPORTED: c++03, c++11, c++14

// <algorithm>

// template<class PopulationIterator, class SampleIterator,
//          class Distance, class UniformRandomBitGenerator>
//   SampleIterator sample(PopulationIterator first, PopulationIterator last,
//                         SampleIterator out, Distance n,
//                         UniformRandomBitGenerator&& g);
18 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <random>
#include <type_traits>
#include <utility>

#include "test_iterators.h"
#include "test_macros.h"
27 
// Expected libc++ results used when the population is traversed with a
// std::input_iterator_tag iterator. oa1 is the first os-element sample drawn
// from {1, ..., 10} with a default-constructed std::minstd_rand; oa2 is the
// second sample drawn from the same engine. The values are specific to libc++
// and are only compared under LIBCPP_ASSERT.
//
// The tables are constexpr (implicitly inline since C++17): they are read-only
// reference data and need no out-of-class definition.
struct ReservoirSampleExpectations {
  static constexpr unsigned os = 4;
  static constexpr int oa1[os] = {10, 5, 9, 4};
  static constexpr int oa2[os] = {5, 2, 10, 4};
};
36 
// Expected libc++ results used for every population iterator category other
// than std::input_iterator_tag. oa1 is the first os-element sample drawn from
// {1, ..., 10} with a default-constructed std::minstd_rand; oa2 is the second
// sample drawn from the same engine. The values are specific to libc++ and are
// only compared under LIBCPP_ASSERT.
//
// The tables are constexpr (implicitly inline since C++17): they are read-only
// reference data and need no out-of-class definition.
struct SelectionSampleExpectations {
  static constexpr unsigned os = 4;
  static constexpr int oa1[os] = {1, 4, 6, 7};
  static constexpr int oa2[os] = {1, 2, 6, 8};
};
45 
46 template <class IteratorCategory> struct TestExpectations
47     : public SelectionSampleExpectations {};
48 
49 template <>
50 struct TestExpectations<std::input_iterator_tag>
51     : public ReservoirSampleExpectations {};
52 
53 template <template<class...> class PopulationIteratorType, class PopulationItem,
54           template<class...> class SampleIteratorType, class SampleItem>
test()55 void test() {
56   typedef PopulationIteratorType<PopulationItem *> PopulationIterator;
57   typedef SampleIteratorType<SampleItem *> SampleIterator;
58   PopulationItem ia[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
59   const unsigned is = sizeof(ia) / sizeof(ia[0]);
60   typedef TestExpectations<typename std::iterator_traits<
61       PopulationIterator>::iterator_category> Expectations;
62   const unsigned os = Expectations::os;
63   SampleItem oa[os];
64   const int *oa1 = Expectations::oa1;
65   ((void)oa1); // Prevent unused warning
66   const int *oa2 = Expectations::oa2;
67   ((void)oa2); // Prevent unused warning
68   std::minstd_rand g;
69   SampleIterator end = std::sample(PopulationIterator(ia),
70                                    PopulationIterator(ia + is),
71                                    SampleIterator(oa), os, g);
72   assert(static_cast<std::size_t>(base(end) - oa) == std::min(os, is));
73   // sample() is deterministic but non-reproducible;
74   // its results can vary between implementations.
75   LIBCPP_ASSERT(std::equal(oa, oa + os, oa1));
76   end = std::sample(PopulationIterator(ia),
77                                   PopulationIterator(ia + is),
78                                   SampleIterator(oa), os, std::move(g));
79   assert(static_cast<std::size_t>(base(end) - oa) == std::min(os, is));
80   LIBCPP_ASSERT(std::equal(oa, oa + os, oa2));
81 }
82 
// Edge case: the population range is empty. Four elements are requested, yet
// nothing may be written and the returned iterator must equal the output start.
template <template<class...> class PopulationIteratorType, class PopulationItem,
          template<class...> class SampleIteratorType, class SampleItem>
void test_empty_population() {
  using PopulationIterator = PopulationIteratorType<PopulationItem *>;
  using SampleIterator = SampleIteratorType<SampleItem *>;

  const unsigned requested = 4;
  SampleItem output[requested];
  // A one-element array only provides a valid pointer; the range handed to
  // std::sample is [storage, storage), which is empty.
  PopulationItem storage[] = {42};
  std::minstd_rand engine;

  SampleIterator result = std::sample(PopulationIterator(storage),
                                      PopulationIterator(storage),
                                      SampleIterator(output), requested, engine);
  assert(base(result) == output);
}
97 
// Edge case: a requested sample size of zero. Nothing may be written and the
// returned iterator must equal the output start, even though the population
// holds ten elements.
template <template<class...> class PopulationIteratorType, class PopulationItem,
          template<class...> class SampleIteratorType, class SampleItem>
void test_empty_sample() {
  using PopulationIterator = PopulationIteratorType<PopulationItem *>;
  using SampleIterator = SampleIteratorType<SampleItem *>;

  PopulationItem population[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  const unsigned population_size = sizeof(population) / sizeof(population[0]);
  PopulationItem *const population_end = population + population_size;
  SampleItem output[1];
  std::minstd_rand engine;

  SampleIterator result = std::sample(PopulationIterator(population),
                                      PopulationIterator(population_end),
                                      SampleIterator(output), 0, engine);
  assert(base(result) == output);
}
112 
// Edge case: more elements are requested (8) than the population holds (5), so
// the whole population must be copied. With a forward-iterator population the
// copy must keep the original order. Otherwise only the same set of values, in
// any order, is required.
template <template<class...> class PopulationIteratorType, class PopulationItem,
          template<class...> class SampleIteratorType, class SampleItem>
void test_small_population() {
  using PopulationIterator = PopulationIteratorType<PopulationItem *>;
  using SampleIterator = SampleIteratorType<SampleItem *>;
  using PopulationCategory =
      typename std::iterator_traits<PopulationIterator>::iterator_category;

  PopulationItem population[] = {1, 2, 3, 4, 5};
  const unsigned population_size = sizeof(population) / sizeof(population[0]);
  const unsigned requested = 8;
  SampleItem output[requested];
  const SampleItem expected[] = {1, 2, 3, 4, 5};
  std::minstd_rand engine;

  SampleIterator result = std::sample(PopulationIterator(population),
                                      PopulationIterator(population + population_size),
                                      SampleIterator(output), requested, engine);
  auto written_end = base(result);
  assert(static_cast<std::size_t>(written_end - output) ==
         std::min(requested, population_size));

  if constexpr (std::is_base_of_v<std::forward_iterator_tag, PopulationCategory>) {
    assert(std::equal(output, written_end, expected));
  } else {
    assert(std::is_permutation(output, written_end, expected));
  }
}
136 
main(int,char **)137 int main(int, char**) {
138   test<cpp17_input_iterator, int, random_access_iterator, int>();
139   test<forward_iterator, int, cpp17_output_iterator, int>();
140   test<forward_iterator, int, random_access_iterator, int>();
141 
142   test<cpp17_input_iterator, int, random_access_iterator, double>();
143   test<forward_iterator, int, cpp17_output_iterator, double>();
144   test<forward_iterator, int, random_access_iterator, double>();
145 
146   test_empty_population<cpp17_input_iterator, int, random_access_iterator, int>();
147   test_empty_population<forward_iterator, int, cpp17_output_iterator, int>();
148   test_empty_population<forward_iterator, int, random_access_iterator, int>();
149 
150   test_empty_sample<cpp17_input_iterator, int, random_access_iterator, int>();
151   test_empty_sample<forward_iterator, int, cpp17_output_iterator, int>();
152   test_empty_sample<forward_iterator, int, random_access_iterator, int>();
153 
154   test_small_population<cpp17_input_iterator, int, random_access_iterator, int>();
155   test_small_population<forward_iterator, int, cpp17_output_iterator, int>();
156   test_small_population<forward_iterator, int, random_access_iterator, int>();
157 
158   return 0;
159 }
160