xref: /llvm-project/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n_pred.pass.cpp (revision 257831582c759048ecce5c393993268b71a602e0)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // <algorithm>
10 
11 // template<class ForwardIterator, class Size, class T, class BinaryPredicate>
12 //   constexpr ForwardIterator     // constexpr after C++17
13 //   search_n(ForwardIterator first, ForwardIterator last, Size count,
14 //            const T& value, BinaryPredicate pred);
15 
16 #include <algorithm>
17 #include <cassert>
18 
19 #include "test_macros.h"
20 #include "test_iterators.h"
21 #include "user_defined_integral.h"
22 
#if TEST_STD_VER > 17
// Comparison used by the constexpr test. It is passed to search_n by name, so
// the predicate is a plain function pointer.
TEST_CONSTEXPR bool eq(int a, int b) { return a == b; }

// Checks search_n with a predicate during constant evaluation; main() invokes
// it through static_assert.
TEST_CONSTEXPR bool test_constexpr() {
    int ia[] = {0, 0, 1, 1, 2, 2};
    // Compute the end pointer directly instead of using std::begin/std::end:
    // those live in <iterator>, which this file does not include.
    int* const last = ia + sizeof(ia) / sizeof(ia[0]);
    return    (std::search_n(ia, last, 0, 0, eq) == ia)       // zero count -> first
           && (std::search_n(ia, last, 1, 0, eq) == ia)
           && (std::search_n(ia, last, 2, 1, eq) == ia + 2)
           && (std::search_n(ia, last, 3, 1, eq) == last)     // run too short
           && (std::search_n(ia, last, 1, 3, eq) == last)     // value absent
           ;
}
#endif
34 
// Equality predicate that tallies every invocation in a shared counter, so the
// tests can bound how many comparisons search_n performs.
struct count_equal {
  // Number of calls since the counter was last reset by the caller.
  static unsigned count;

  template <class T>
  bool operator()(const T& lhs, const T& rhs) const {
    count += 1;  // record the call before comparing
    const bool same = (lhs == rhs);
    return same;
  }
};

unsigned count_equal::count(0);
45 
46 
47 template <class Iter>
48 void
49 test()
50 {
51     int ia[] = {0, 1, 2, 3, 4, 5};
52     const unsigned sa = sizeof(ia)/sizeof(ia[0]);
53     count_equal::count = 0;
54     assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 0, count_equal()) == Iter(ia));
55     assert(count_equal::count <= sa);
56     count_equal::count = 0;
57     assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 0, count_equal()) == Iter(ia+0));
58     assert(count_equal::count <= sa);
59     count_equal::count = 0;
60     assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 0, count_equal()) == Iter(ia+sa));
61     assert(count_equal::count <= sa);
62     count_equal::count = 0;
63     assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 0, count_equal()) == Iter(ia+sa));
64     assert(count_equal::count <= sa);
65     count_equal::count = 0;
66     assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 3, count_equal()) == Iter(ia));
67     assert(count_equal::count <= sa);
68     count_equal::count = 0;
69     assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 3, count_equal()) == Iter(ia+3));
70     assert(count_equal::count <= sa);
71     count_equal::count = 0;
72     assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 3, count_equal()) == Iter(ia+sa));
73     assert(count_equal::count <= sa);
74     count_equal::count = 0;
75     assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 3, count_equal()) == Iter(ia+sa));
76     assert(count_equal::count <= sa);
77     count_equal::count = 0;
78     assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 5, count_equal()) == Iter(ia));
79     assert(count_equal::count <= sa);
80     count_equal::count = 0;
81     assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 5, count_equal()) == Iter(ia+5));
82     assert(count_equal::count <= sa);
83     count_equal::count = 0;
84     assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 5, count_equal()) == Iter(ia+sa));
85     assert(count_equal::count <= sa);
86     count_equal::count = 0;
87     assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 5, count_equal()) == Iter(ia+sa));
88     assert(count_equal::count <= sa);
89     count_equal::count = 0;
90 
91     int ib[] = {0, 0, 1, 1, 2, 2};
92     const unsigned sb = sizeof(ib)/sizeof(ib[0]);
93     assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 0, count_equal()) == Iter(ib));
94     assert(count_equal::count <= sb);
95     count_equal::count = 0;
96     assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 0, count_equal()) == Iter(ib+0));
97     assert(count_equal::count <= sb);
98     count_equal::count = 0;
99     assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 0, count_equal()) == Iter(ib+0));
100     assert(count_equal::count <= sb);
101     count_equal::count = 0;
102     assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 0, count_equal()) == Iter(ib+sb));
103     assert(count_equal::count <= sb);
104     count_equal::count = 0;
105     assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 0, count_equal()) == Iter(ib+sb));
106     assert(count_equal::count <= sb);
107     count_equal::count = 0;
108     assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 1, count_equal()) == Iter(ib));
109     assert(count_equal::count <= sb);
110     count_equal::count = 0;
111     assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 1, count_equal()) == Iter(ib+2));
112     assert(count_equal::count <= sb);
113     count_equal::count = 0;
114     assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 1, count_equal()) == Iter(ib+2));
115     assert(count_equal::count <= sb);
116     count_equal::count = 0;
117     assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 1, count_equal()) == Iter(ib+sb));
118     assert(count_equal::count <= sb);
119     count_equal::count = 0;
120     assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 1, count_equal()) == Iter(ib+sb));
121     assert(count_equal::count <= sb);
122     count_equal::count = 0;
123     assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 2, count_equal()) == Iter(ib));
124     assert(count_equal::count <= sb);
125     count_equal::count = 0;
126     assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 2, count_equal()) == Iter(ib+4));
127     assert(count_equal::count <= sb);
128     count_equal::count = 0;
129     assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 2, count_equal()) == Iter(ib+4));
130     assert(count_equal::count <= sb);
131     count_equal::count = 0;
132     assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 2, count_equal()) == Iter(ib+sb));
133     assert(count_equal::count <= sb);
134     count_equal::count = 0;
135     assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 2, count_equal()) == Iter(ib+sb));
136     assert(count_equal::count <= sb);
137     count_equal::count = 0;
138 
139     int ic[] = {0, 0, 0};
140     const unsigned sc = sizeof(ic)/sizeof(ic[0]);
141     assert(std::search_n(Iter(ic), Iter(ic+sc), 0, 0, count_equal()) == Iter(ic));
142     assert(count_equal::count <= sc);
143     count_equal::count = 0;
144     assert(std::search_n(Iter(ic), Iter(ic+sc), 1, 0, count_equal()) == Iter(ic));
145     assert(count_equal::count <= sc);
146     count_equal::count = 0;
147     assert(std::search_n(Iter(ic), Iter(ic+sc), 2, 0, count_equal()) == Iter(ic));
148     assert(count_equal::count <= sc);
149     count_equal::count = 0;
150     assert(std::search_n(Iter(ic), Iter(ic+sc), 3, 0, count_equal()) == Iter(ic));
151     assert(count_equal::count <= sc);
152     count_equal::count = 0;
153     assert(std::search_n(Iter(ic), Iter(ic+sc), 4, 0, count_equal()) == Iter(ic+sc));
154     assert(count_equal::count <= sc);
155     count_equal::count = 0;
156 
157     // Check that we properly convert the size argument to an integral.
158     TEST_IGNORE_NODISCARD std::search_n(Iter(ic), Iter(ic+sc), UserDefinedIntegral<unsigned>(4), 0, count_equal());
159     count_equal::count = 0;
160 }
161 
// Element type for the heterogeneous-predicate regression test in main().
// It has no operator== and cannot be built from a single int, so it is not
// comparable with int. Comparisons must therefore go through Pred.
class A {
  int x_;
  int y_;

public:
  A(int x, int y) : x_(x), y_(y) {}

  int x() const { return x_; }
  int y() const { return y_; }
};

// Predicate taking (element, value) in that order: it matches an A whose x()
// equals the searched int and ignores y().
struct Pred {
  bool operator()(const A& elem, int value) const {
    const int key = elem.x();
    return key == value;
  }
};
176 
177 int main(int, char**)
178 {
179     test<forward_iterator<const int*> >();
180     test<bidirectional_iterator<const int*> >();
181     test<random_access_iterator<const int*> >();
182 
183     // test bug reported in https://reviews.llvm.org/D124079?#3661721
184     {
185         A a[]       = {A(1, 2), A(2, 3), A(2, 4)};
186         int value   = 2;
187         auto result = std::search_n(a, a + 3, 1, value, Pred());
188         assert(result == a + 1);
189     }
190 
191 #if TEST_STD_VER > 17
192     static_assert(test_constexpr());
193 #endif
194 
195   return 0;
196 }
197