xref: /llvm-project/pstl/test/std/algorithms/alg.sorting/alg.set.operations/set.pass.cpp (revision b5e896c0493d44c2d9819d584fe2814af7cfd476)
1 // -*- C++ -*-
2 //===-- set.pass.cpp ------------------------------------------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // UNSUPPORTED: c++03, c++11, c++14
11 
12 #include "support/pstl_test_config.h"
13 
14 #include <algorithm>
15 #include <chrono>
16 #include <cmath>
17 #include <execution>
18 #include <functional>
19 
20 #include "support/utils.h"
21 
22 using namespace TestUtils;
23 
template <typename T>
struct Num
{
    // Wrapped value; all comparisons and conversions operate on it.
    T val;

    // Value-initializes val (zero for arithmetic T).
    Num() : val{} {}

    // Intentionally implicit: raw values produced by the input generators need to
    // convert to Num (presumably relied on by Sequence's generator constructor — TODO confirm).
    Num(const T& v) : val(v) {}

    // For "includes" checks: orders this Num against a Num of any underlying type T1
    // by comparing the wrapped values directly.
    template <typename T1>
    bool
    operator<(const Num<T1>& v1) const
    {
        return val < v1.val;
    }

    // The set algorithms' comparator requirements say that for types Type1 and Type2, an
    // object of type InputIt can be dereferenced and then implicitly converted to both.
    // That requirement is why this conversion is deliberately implicit.
    // The wrapped value is converted with a named cast, not a C-style cast.
    template <typename T1>
    operator Num<T1>() const
    {
        return Num<T1>(static_cast<T1>(val));
    }

    // Equality of the wrapped values (same instantiation only).
    friend bool
    operator==(const Num& v1, const Num& v2)
    {
        return v1.val == v2.val;
    }
};
53 
54 template <typename Type>
55 struct test_set_union
56 {
57     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
58     typename std::enable_if<!TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_union59     operator()(Policy&& exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2,
60                Compare comp)
61     {
62         using T1 = typename std::iterator_traits<InputIterator1>::value_type;
63 
64         auto n1 = std::distance(first1, last1);
65         auto n2 = std::distance(first2, last2);
66         auto n = n1 + n2;
67         Sequence<T1> expect(n);
68         Sequence<T1> out(n);
69 
70         auto expect_res = std::set_union(first1, last1, first2, last2, expect.begin(), comp);
71         auto res = std::set_union(exec, first1, last1, first2, last2, out.begin(), comp);
72 
73         EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_union");
74         EXPECT_EQ_N(expect.begin(), out.begin(), std::distance(out.begin(), res), "wrong set_union effect");
75     }
76 
77     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
78     typename std::enable_if<TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_union79     operator()(Policy&&, InputIterator1, InputIterator1, InputIterator2, InputIterator2, Compare)
80     {
81     }
82 };
83 
84 template <typename Type>
85 struct test_set_intersection
86 {
87     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
88     typename std::enable_if<!TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_intersection89     operator()(Policy&& exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2,
90                Compare comp)
91     {
92         using T1 = typename std::iterator_traits<InputIterator1>::value_type;
93 
94         auto n1 = std::distance(first1, last1);
95         auto n2 = std::distance(first2, last2);
96         auto n = n1 + n2;
97         Sequence<T1> expect(n);
98         Sequence<T1> out(n);
99 
100         auto expect_res = std::set_intersection(first1, last1, first2, last2, expect.begin(), comp);
101         auto res = std::set_intersection(exec, first1, last1, first2, last2, out.begin(), comp);
102 
103         EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_intersection");
104         EXPECT_EQ_N(expect.begin(), out.begin(), std::distance(out.begin(), res), "wrong set_intersection effect");
105     }
106 
107     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
108     typename std::enable_if<TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_intersection109     operator()(Policy&&, InputIterator1, InputIterator1, InputIterator2, InputIterator2, Compare)
110     {
111     }
112 };
113 
114 template <typename Type>
115 struct test_set_difference
116 {
117     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
118     typename std::enable_if<!TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_difference119     operator()(Policy&& exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2,
120                Compare comp)
121     {
122         using T1 = typename std::iterator_traits<InputIterator1>::value_type;
123 
124         auto n1 = std::distance(first1, last1);
125         auto n2 = std::distance(first2, last2);
126         auto n = n1 + n2;
127         Sequence<T1> expect(n);
128         Sequence<T1> out(n);
129 
130         auto expect_res = std::set_difference(first1, last1, first2, last2, expect.begin(), comp);
131         auto res = std::set_difference(exec, first1, last1, first2, last2, out.begin(), comp);
132 
133         EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_difference");
134         EXPECT_EQ_N(expect.begin(), out.begin(), std::distance(out.begin(), res), "wrong set_difference effect");
135     }
136 
137     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
138     typename std::enable_if<TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_difference139     operator()(Policy&&, InputIterator1, InputIterator1, InputIterator2, InputIterator2, Compare)
140     {
141     }
142 };
143 
144 template <typename Type>
145 struct test_set_symmetric_difference
146 {
147     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
148     typename std::enable_if<!TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_symmetric_difference149     operator()(Policy&& exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2,
150                Compare comp)
151     {
152         using T1 = typename std::iterator_traits<InputIterator1>::value_type;
153 
154         auto n1 = std::distance(first1, last1);
155         auto n2 = std::distance(first2, last2);
156         auto n = n1 + n2;
157         Sequence<T1> expect(n);
158         Sequence<T1> out(n);
159 
160         auto expect_res = std::set_symmetric_difference(first1, last1, first2, last2, expect.begin(), comp);
161         auto res = std::set_symmetric_difference(exec, first1, last1, first2, last2, out.begin(), comp);
162 
163         EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_symmetric_difference");
164         EXPECT_EQ_N(expect.begin(), out.begin(), std::distance(out.begin(), res),
165                     "wrong set_symmetric_difference effect");
166     }
167 
168     template <typename Policy, typename InputIterator1, typename InputIterator2, typename Compare>
169     typename std::enable_if<TestUtils::isReverse<InputIterator1>::value, void>::type
operator ()test_set_symmetric_difference170     operator()(Policy&&, InputIterator1, InputIterator1, InputIterator2, InputIterator2, Compare)
171     {
172     }
173 };
174 
175 template <typename T1, typename T2, typename Compare>
176 void
test_set(Compare compare)177 test_set(Compare compare)
178 {
179 
180     const std::size_t n_max = 100000;
181 
182     // The rand()%(2*n+1) encourages generation of some duplicates.
183     std::srand(4200);
184 
185     for (std::size_t n = 0; n < n_max; n = n <= 16 ? n + 1 : size_t(3.1415 * n))
186     {
187         for (std::size_t m = 0; m < n_max; m = m <= 16 ? m + 1 : size_t(2.71828 * m))
188         {
189             //prepare the input ranges
190             Sequence<T1> in1(n, [](std::size_t k) { return rand() % (2 * k + 1); });
191             Sequence<T2> in2(m, [m](std::size_t k) { return (m % 2) * rand() + rand() % (k + 1); });
192 
193             std::sort(in1.begin(), in1.end(), compare);
194             std::sort(in2.begin(), in2.end(), compare);
195 
196             invoke_on_all_policies(test_set_union<T1>(), in1.begin(), in1.end(), in2.cbegin(), in2.cend(),
197                                         compare);
198 
199             invoke_on_all_policies(test_set_intersection<T1>(), in1.begin(), in1.end(), in2.cbegin(), in2.cend(),
200                                         compare);
201 
202             invoke_on_all_policies(test_set_difference<T1>(), in1.begin(), in1.end(), in2.cbegin(), in2.cend(),
203                                         compare);
204 
205             invoke_on_all_policies(test_set_symmetric_difference<T1>(), in1.begin(), in1.end(), in2.cbegin(),
206                                         in2.cend(), compare);
207         }
208     }
209 }
210 
template <typename T>
struct test_non_const_set_difference
{
    // Checks that std::set_difference accepts the comparator produced by non_const(std::less<T>())
    // (presumably a wrapper exposing only a non-const call operator — see TestUtils).
    // All input ranges are empty, so the algorithm writes nothing through out_iter.
    template <typename Policy, typename InputIterator, typename OutputIterator>
    void
    operator()(Policy&& exec, InputIterator input_iter, OutputIterator out_iter)
    {
        // Qualified explicitly. The previous unqualified call relied on std::set_difference
        // being found indirectly (argument-dependent lookup) rather than by name.
        std::set_difference(exec, input_iter, input_iter, input_iter, input_iter, out_iter,
                            non_const(std::less<T>()));
    }
};
221 
template <typename T>
struct test_non_const_set_intersection
{
    // Checks that std::set_intersection accepts the comparator produced by non_const(std::less<T>())
    // (presumably a wrapper exposing only a non-const call operator — see TestUtils).
    // All input ranges are empty, so the algorithm writes nothing through out_iter.
    template <typename Policy, typename InputIterator, typename OutputIterator>
    void
    operator()(Policy&& exec, InputIterator input_iter, OutputIterator out_iter)
    {
        // Qualified explicitly. The previous unqualified call relied on std::set_intersection
        // being found indirectly (argument-dependent lookup) rather than by name.
        std::set_intersection(exec, input_iter, input_iter, input_iter, input_iter, out_iter,
                              non_const(std::less<T>()));
    }
};
232 
template <typename T>
struct test_non_const_set_symmetric_difference
{
    // Checks that std::set_symmetric_difference accepts the comparator produced by
    // non_const(std::less<T>()) (presumably a wrapper exposing only a non-const call
    // operator — see TestUtils). All input ranges are empty, so the algorithm writes
    // nothing through out_iter.
    template <typename Policy, typename InputIterator, typename OutputIterator>
    void
    operator()(Policy&& exec, InputIterator input_iter, OutputIterator out_iter)
    {
        // Qualified explicitly. The previous unqualified call relied on
        // std::set_symmetric_difference being found indirectly (argument-dependent lookup).
        std::set_symmetric_difference(exec, input_iter, input_iter, input_iter, input_iter, out_iter,
                                      non_const(std::less<T>()));
    }
};
244 
template <typename T>
struct test_non_const_set_union
{
    // Checks that std::set_union accepts the comparator produced by non_const(std::less<T>())
    // (presumably a wrapper exposing only a non-const call operator — see TestUtils).
    // All input ranges are empty, so the algorithm writes nothing through out_iter.
    template <typename Policy, typename InputIterator, typename OutputIterator>
    void
    operator()(Policy&& exec, InputIterator input_iter, OutputIterator out_iter)
    {
        // Qualified explicitly. The previous unqualified call relied on std::set_union
        // being found indirectly (argument-dependent lookup) rather than by name.
        std::set_union(exec, input_iter, input_iter, input_iter, input_iter, out_iter, non_const(std::less<T>()));
    }
};
255 
256 int
main()257 main()
258 {
259 
260     test_set<float64_t, float64_t>(std::less<>());
261     test_set<Num<int64_t>, Num<int32_t>>([](const Num<int64_t>& x, const Num<int32_t>& y) { return x < y; });
262 
263     test_set<MemoryChecker, MemoryChecker>([](const MemoryChecker& val1, const MemoryChecker& val2) -> bool {
264         return val1.value() < val2.value();
265     });
266     EXPECT_FALSE(MemoryChecker::alive_objects() < 0, "wrong effect from set algorithms: number of ctors calls < num of dtors calls");
267     EXPECT_FALSE(MemoryChecker::alive_objects() > 0, "wrong effect from set algorithms: number of ctors calls > num of dtors calls");
268 
269     test_algo_basic_double<int32_t>(run_for_rnd_fw<test_non_const_set_difference<int32_t>>());
270 
271     test_algo_basic_double<int32_t>(run_for_rnd_fw<test_non_const_set_intersection<int32_t>>());
272 
273     test_algo_basic_double<int32_t>(run_for_rnd_fw<test_non_const_set_symmetric_difference<int32_t>>());
274 
275     test_algo_basic_double<int32_t>(run_for_rnd_fw<test_non_const_set_union<int32_t>>());
276 
277     std::cout << done() << std::endl;
278 
279     return 0;
280 }
281