xref: /llvm-project/libcxx/include/__algorithm/nth_element.h (revision 7ed7d4ccb8991e2b5b95334b508f8cec2faee737)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___ALGORITHM_NTH_ELEMENT_H
10 #define _LIBCPP___ALGORITHM_NTH_ELEMENT_H
11 
12 #include <__config>
13 #include <__algorithm/comp_ref_type.h>
14 #include <__algorithm/sort.h>
15 #include <__iterator/iterator_traits.h>
16 
17 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
18 #pragma GCC system_header
19 #endif
20 
21 _LIBCPP_PUSH_MACROS
22 #include <__undef_macros>
23 
24 _LIBCPP_BEGIN_NAMESPACE_STD
25 
26 template<class _Compare, class _RandomAccessIterator>
27 _LIBCPP_CONSTEXPR_AFTER_CXX11 bool
28 __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
29                          _RandomAccessIterator __m, _Compare __comp)
30 {
31     // manually guard downward moving __j against __i
32     while (true) {
33         if (__i == --__j) {
34             return false;
35         }
36         if (__comp(*__j, *__m)) {
37             return true;  // found guard for downward moving __j, now use unguarded partition
38         }
39     }
40 }
41 
// Internal selection routine behind std::nth_element.
//
// Rearranges [__first, __last) so that *__nth holds the element that would be
// at that position if the whole range were sorted by __comp. No element of
// [__first, __nth) is ordered after *__nth, and no element of
// (__nth, __last) is ordered before it. If __nth == __last nothing is done.
//
// Strategy (quickselect):
//  - Order *__first, the middle element and *(__last - 1) with __sort3, and
//    use the median as the pivot.
//  - Partition around the pivot, then narrow [__first, __last) to the side
//    that contains __nth.
//  - The recursion is expressed as a `while (true)` loop; the commented-out
//    recursive calls below show which call each iteration replaces.
//  - Ranges of length <= __limit are finished with __selection_sort.
//
// NOTE(review): no recursion-depth limit or fallback algorithm is visible
// here, so adversarial inputs can presumably drive this to quadratic time —
// confirm before relying on a worst-case bound.
//
// __comp: the public nth_element wrappers instantiate this with
// __comp_ref_type<_Compare>::type.
template <class _Compare, class _RandomAccessIterator>
_LIBCPP_CONSTEXPR_AFTER_CXX11 void
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
    // _Compare is known to be a reference type
    typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
    // Ranges no longer than this are simply selection-sorted.
    const difference_type __limit = 7;
    while (true)
    {
        if (__nth == __last)
            return;
        difference_type __len = __last - __first;
        // Tiny ranges are sorted outright.
        switch (__len)
        {
        case 0:
        case 1:
            return;
        case 2:
            if (__comp(*--__last, *__first))
                swap(*__first, *__last);
            return;
        case 3:
            {
            _RandomAccessIterator __m = __first;
            _VSTD::__sort3<_Compare>(__first, ++__m, --__last, __comp);
            return;
            }
        }
        if (__len <= __limit)
        {
            _VSTD::__selection_sort<_Compare>(__first, __last, __comp);
            return;
        }
        // __len > __limit >= 3
        _RandomAccessIterator __m = __first + __len/2;
        _RandomAccessIterator __lm1 = __last;
        // __sort3's result is treated as the number of swaps it performed.
        // __n_swaps keeps counting every swap below, so that an input which
        // was already partitioned (no swaps at all) can be detected later.
        unsigned __n_swaps = _VSTD::__sort3<_Compare>(__first, __m, --__lm1, __comp);
        // *__m is median, i.e. *__first <= *__m <= *__lm1
        // partition [__first, __m) < *__m and *__m <= [__m, __last)
        // (this inhibits tossing elements equivalent to __m around unnecessarily)
        _RandomAccessIterator __i = __first;
        _RandomAccessIterator __j = __lm1;
        // j points beyond range to be tested, *__lm1 is known to be >= *__m
        // (so it already belongs to the upper part and need not be examined).
        // The search going up is known to be guarded but the search coming down isn't.
        // Prime the downward search with a guard.
        if (!__comp(*__i, *__m))  // if *__first == *__m
        {
            // *__first == *__m, *__first doesn't go in first part
            // Look in (__first, __lm1) for an element < *__m to act as guard.
            if (_VSTD::__nth_element_find_guard<_Compare>(__i, __j, __m, __comp)) {
                // Now *__first < *__m (it is the guard), and *__j holds the old
                // *__first (equivalent to *__m), which belongs in the upper part.
                swap(*__i, *__j);
                ++__n_swaps;
            } else {
                // *__first == *__m, *__m <= all other elements
                // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
                ++__i;  // __first + 1
                __j = __last;
                if (!__comp(*__first, *--__j)) {  // we need a guard if *__first == *(__last-1)
                    // Move the first element > *__first into the last slot so
                    // that it guards the upward scan below.
                    while (true) {
                        if (__i == __j) {
                            return;  // [__first, __last) all equivalent elements
                        } else if (__comp(*__first, *__i)) {
                            swap(*__i, *__j);
                            ++__n_swaps;
                            ++__i;
                            break;
                        }
                        ++__i;
                    }
                }
                // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
                if (__i == __j) {
                    return;
                }
                // Upward scan is guarded by *(__last - 1) > *__first.
                // Downward scan is guarded by [__first, __i), all equivalent to *__first.
                while (true) {
                    while (!__comp(*__first, *__i))
                        ++__i;
                    while (__comp(*__first, *--__j))
                        ;
                    if (__i >= __j)
                        break;
                    swap(*__i, *__j);
                    ++__n_swaps;
                    ++__i;
                }
                // [__first, __i) == *__first and *__first < [__i, __last)
                // The first part consists of equivalent elements, so it is already sorted.
                if (__nth < __i) {
                    return;
                }
                // __nth_element the second part
                // _VSTD::__nth_element<_Compare>(__i, __nth, __last, __comp);
                __first = __i;
                continue;
            }
        }
        ++__i;
        // j points beyond range to be tested, *__lm1 is known to be >= *__m
        // if not yet partitioned...
        if (__i < __j)
        {
            // known that *(__i - 1) < *__m
            while (true)
            {
                // __m still guards upward moving __i
                while (__comp(*__i, *__m))
                    ++__i;
                // It is now known that a guard exists for downward moving __j
                while (!__comp(*--__j, *__m))
                    ;
                if (__i >= __j)
                    break;
                swap(*__i, *__j);
                ++__n_swaps;
                // It is known that __m != __j
                // If __m just moved, follow it
                if (__m == __i)
                    __m = __j;
                ++__i;
            }
        }
        // [__first, __i) < *__m and *__m <= [__i, __last)
        // Move the pivot to __i, which is its final sorted position.
        if (__i != __m && __comp(*__m, *__i))
        {
            swap(*__i, *__m);
            ++__n_swaps;
        }
        // [__first, __i) < *__i and *__i <= [__i+1, __last)
        if (__nth == __i)
            return;
        if (__n_swaps == 0)
        {
            // We were given a perfectly partitioned sequence.  Coincidence?
            // If the side that contains __nth is already sorted, __nth is
            // already in place and we can stop.
            if (__nth < __i)
            {
                // Check for [__first, __i) already sorted
                // (__m is reused here as the "previous element" cursor)
                __j = __m = __first;
                while (true) {
                    if (++__j == __i) {
                        // [__first, __i) sorted
                        return;
                    }
                    if (__comp(*__j, *__m)) {
                        // not yet sorted, so sort
                        break;
                    }
                    __m = __j;
                }
            }
            else
            {
                // Check for [__i, __last) already sorted
                __j = __m = __i;
                while (true) {
                    if (++__j == __last) {
                        // [__i, __last) sorted
                        return;
                    }
                    if (__comp(*__j, *__m)) {
                        // not yet sorted, so sort
                        break;
                    }
                    __m = __j;
                }
            }
        }
        // __nth_element on range containing __nth
        if (__nth < __i)
        {
            // _VSTD::__nth_element<_Compare>(__first, __nth, __i, __comp);
            __last = __i;
        }
        else
        {
            // _VSTD::__nth_element<_Compare>(__i+1, __nth, __last, __comp);
            __first = ++__i;
        }
    }
}
220 
221 template <class _RandomAccessIterator, class _Compare>
222 inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
223 void
224 nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
225 {
226     typedef typename __comp_ref_type<_Compare>::type _Comp_ref;
227     _VSTD::__nth_element<_Comp_ref>(__first, __nth, __last, __comp);
228 }
229 
230 template <class _RandomAccessIterator>
231 inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
232 void
233 nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last)
234 {
235     _VSTD::nth_element(__first, __nth, __last, __less<typename iterator_traits<_RandomAccessIterator>::value_type>());
236 }
237 
238 _LIBCPP_END_NAMESPACE_STD
239 
240 _LIBCPP_POP_MACROS
241 
242 #endif // _LIBCPP___ALGORITHM_NTH_ELEMENT_H
243