xref: /llvm-project/libcxx/include/__algorithm/nth_element.h (revision 77a00c0d546cd4aa8311b5b9031ae9ea8cdb050c)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___ALGORITHM_NTH_ELEMENT_H
10 #define _LIBCPP___ALGORITHM_NTH_ELEMENT_H
11 
12 #include <__algorithm/comp.h>
13 #include <__algorithm/comp_ref_type.h>
14 #include <__algorithm/iterator_operations.h>
15 #include <__algorithm/sort.h>
16 #include <__assert>
17 #include <__config>
18 #include <__debug_utils/randomize_range.h>
19 #include <__iterator/iterator_traits.h>
20 #include <__utility/move.h>
21 
22 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
23 #  pragma GCC system_header
24 #endif
25 
26 _LIBCPP_BEGIN_NAMESPACE_STD
27 
28 template<class _Compare, class _RandomAccessIterator>
29 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 bool
30 __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
31                          _RandomAccessIterator __m, _Compare __comp)
32 {
33     // manually guard downward moving __j against __i
34     while (true) {
35         if (__i == --__j) {
36             return false;
37         }
38         if (__comp(*__j, *__m)) {
39             return true;  // found guard for downward moving __j, now use unguarded partition
40         }
41     }
42 }
43 
// Core selection routine behind std::nth_element: a quickselect whose tail recursion has been turned
// into a loop that narrows [__first, __last) to the side containing __nth.
//
// On return, *__nth holds the element that would occupy that position if [__first, __last) were sorted
// by __comp; no element of [__first, __nth) compares greater than *__nth and no element of
// (__nth, __last) compares less than it. Nothing is done when __nth == __last.
//
// Each iteration does one of two things:
//   - Finishes a small range directly: lengths 0-3 inline, lengths up to __limit via __selection_sort.
//   - Partitions the range around the median of *__first, *(__first + __len/2), *(__last - 1),
//     then continues on the side that contains __nth.
//
// __comp is expected to induce a strict weak ordering. The _LIBCPP_ASSERT_UNCATEGORIZED checks below
// exist to catch comparators that violate it before an unguarded scan would leave the range.
template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
    // Policy-dependent iterator operations (used below for iter_swap).
    using _Ops = _IterOps<_AlgPolicy>;

    // _Compare is known to be a reference type
    typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
    // Ranges of at most this many elements are finished with a selection sort instead of partitioning.
    const difference_type __limit = 7;
    while (true)
    {
        // Selecting the end position is a no-op.
        if (__nth == __last)
            return;
        difference_type __len = __last - __first;
        // Tiny ranges are simply sorted outright.
        switch (__len)
        {
        case 0:
        case 1:
            return;
        case 2:
            if (__comp(*--__last, *__first))
                _Ops::iter_swap(__first, __last);
            return;
        case 3:
            {
            _RandomAccessIterator __m = __first;
            std::__sort3<_AlgPolicy, _Compare>(__first, ++__m, --__last, __comp);
            return;
            }
        }
        if (__len <= __limit)
        {
            std::__selection_sort<_AlgPolicy, _Compare>(__first, __last, __comp);
            return;
        }
        // __len > __limit >= 3
        // Median-of-three pivot: order *__first, *__m, *(__last - 1). __sort3's result seeds the swap
        // counter, which is incremented on every later swap and checked after partitioning.
        _RandomAccessIterator __m = __first + __len/2;
        _RandomAccessIterator __lm1 = __last;
        unsigned __n_swaps = std::__sort3<_AlgPolicy, _Compare>(__first, __m, --__lm1, __comp);
        // *__m is median
        // partition [__first, __m) < *__m and *__m <= [__m, __last)
        // (this inhibits tossing elements equivalent to __m around unnecessarily)
        _RandomAccessIterator __i = __first;
        _RandomAccessIterator __j = __lm1;
        // j points beyond range to be tested; *__m <= *__lm1 is known (ordered by __sort3 above)
        // The search going up is known to be guarded but the search coming down isn't.
        // Prime the downward search with a guard.
        if (!__comp(*__i, *__m))  // i.e. !(*__first < *__m); with *__first <= *__m from __sort3, *__first == *__m
        {
            // *__first == *__m, *__first doesn't go in first part
            // Look below __lm1 for an element < *__m that can serve as the guard.
            if (std::__nth_element_find_guard<_Compare>(__i, __j, __m, __comp)) {
                // Move the guard to __first so the downward scan cannot run past it.
                _Ops::iter_swap(__i, __j);
                ++__n_swaps;
            } else {
                // *__first == *__m, *__m <= all other elements
                // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
                ++__i;  // __first + 1
                __j = __last;
                if (!__comp(*__first, *--__j)) {  // we need a guard if *__first == *(__last-1)
                    // Find an element greater than *__first and park it at __last - 1 as the guard.
                    while (true) {
                        if (__i == __j) {
                            return;  // [__first, __last) all equivalent elements
                        } else if (__comp(*__first, *__i)) {
                            _Ops::iter_swap(__i, __j);
                            ++__n_swaps;
                            ++__i;
                            break;
                        }
                        ++__i;
                    }
                }
                // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
                // If nothing is left between them, the whole range is already sorted.
                if (__i == __j) {
                    return;
                }
                // Partition the remainder: elements equivalent to *__first go left, greater ones go right.
                while (true) {
                    while (!__comp(*__first, *__i)) {
                        ++__i;
                        _LIBCPP_ASSERT_UNCATEGORIZED(
                            __i != __last,
                            "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
                    }
                    do {
                        _LIBCPP_ASSERT_UNCATEGORIZED(
                            __j != __first,
                            "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
                        --__j;
                    } while (__comp(*__first, *__j));
                    if (__i >= __j)
                        break;
                    _Ops::iter_swap(__i, __j);
                    ++__n_swaps;
                    ++__i;
                }
                // [__first, __i) == *__first and *__first < [__i, __last)
                // The first part consists of equivalent elements and is therefore sorted, so if __nth
                // lies in it we are done.
                if (__nth < __i) {
                    return;
                }
                // __nth_element the second part
                // std::__nth_element<_Compare>(__i, __nth, __last, __comp);
                // (expressed as a loop iteration instead of a recursive call)
                __first = __i;
                continue;
            }
        }
        // *__first < *__m now holds, so __first already belongs to the left part; skip it.
        ++__i;
        // j points beyond range to be tested; *__m <= *__lm1 is known (ordered by __sort3 above)
        // if not yet partitioned...
        if (__i < __j)
        {
            // known that *(__i - 1) < *__m
            while (true)
            {
                // __m still guards upward moving __i
                while (__comp(*__i, *__m)) {
                    ++__i;
                    _LIBCPP_ASSERT_UNCATEGORIZED(
                        __i != __last,
                        "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
                }
                // It is now known that a guard exists for downward moving __j
                do {
                    _LIBCPP_ASSERT_UNCATEGORIZED(
                        __j != __first,
                        "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
                    --__j;
                } while (!__comp(*__j, *__m));
                if (__i >= __j)
                    break;
                _Ops::iter_swap(__i, __j);
                ++__n_swaps;
                // It is known that __m != __j
                // If __m just moved, follow it
                if (__m == __i)
                    __m = __j;
                ++__i;
            }
        }
        // [__first, __i) < *__m and *__m <= [__i, __last)
        // Place the pivot at the partition point __i.
        if (__i != __m && __comp(*__m, *__i))
        {
            _Ops::iter_swap(__i, __m);
            ++__n_swaps;
        }
        // [__first, __i) < *__i and *__i <= [__i+1, __last)
        if (__nth == __i)
            return;
        if (__n_swaps == 0)
        {
            // We were given a perfectly partitioned sequence.  Coincidence?
            // If the side holding __nth is already sorted, __nth is already correct and we can stop.
            if (__nth < __i)
            {
                // Check for [__first, __i) already sorted
                __j = __m = __first;
                while (true) {
                    if (++__j == __i) {
                        // [__first, __i) sorted
                        return;
                    }
                    if (__comp(*__j, *__m)) {
                        // not yet sorted, so sort
                        break;
                    }
                    __m = __j;
                }
            }
            else
            {
                // Check for [__i, __last) already sorted
                __j = __m = __i;
                while (true) {
                    if (++__j == __last) {
                        // [__i, __last) sorted
                        return;
                    }
                    if (__comp(*__j, *__m)) {
                        // not yet sorted, so sort
                        break;
                    }
                    __m = __j;
                }
            }
        }
        // __nth_element on range containing __nth
        // (the commented calls show the equivalent recursion; the loop replaces it)
        if (__nth < __i)
        {
            // std::__nth_element<_Compare>(__first, __nth, __i, __comp);
            __last = __i;
        }
        else
        {
            // std::__nth_element<_Compare>(__i+1, __nth, __last, __comp);
            __first = ++__i;
        }
    }
}
241 
242 template <class _AlgPolicy, class _RandomAccessIterator, class _Compare>
243 inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20
244 void __nth_element_impl(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last,
245                         _Compare& __comp) {
246   if (__nth == __last)
247     return;
248 
249   std::__debug_randomize_range<_AlgPolicy>(__first, __last);
250 
251   std::__nth_element<_AlgPolicy, __comp_ref_type<_Compare> >(__first, __nth, __last, __comp);
252 
253   std::__debug_randomize_range<_AlgPolicy>(__first, __nth);
254   if (__nth != __last) {
255     std::__debug_randomize_range<_AlgPolicy>(++__nth, __last);
256   }
257 }
258 
259 template <class _RandomAccessIterator, class _Compare>
260 inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20
261 void nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last,
262                  _Compare __comp) {
263   std::__nth_element_impl<_ClassicAlgPolicy>(std::move(__first), std::move(__nth), std::move(__last), __comp);
264 }
265 
266 template <class _RandomAccessIterator>
267 inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20
268 void nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last) {
269   std::nth_element(std::move(__first), std::move(__nth), std::move(__last), __less<>());
270 }
271 
272 _LIBCPP_END_NAMESPACE_STD
273 
274 #endif // _LIBCPP___ALGORITHM_NTH_ELEMENT_H
275