xref: /llvm-project/libc/src/stdlib/quick_sort.h (revision a738d81cd2822698539b0482af48d49d91ea5a2e)
1 //===-- Implementation header for qsort utilities ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC_STDLIB_QUICK_SORT_H
10 #define LLVM_LIBC_SRC_STDLIB_QUICK_SORT_H
11 
12 #include "src/__support/CPP/bit.h"
13 #include "src/__support/CPP/cstddef.h"
14 #include "src/__support/macros/config.h"
15 #include "src/stdlib/qsort_pivot.h"
16 
17 #include <stdint.h>
18 
19 namespace LIBC_NAMESPACE_DECL {
20 namespace internal {
21 
// Lomuto-style partition written without data-dependent branches, based on
// the implementation by Lukas Bergdoll and Orson Peters:
// https://github.com/Voultapher/sort-research-rs/blob/main/writeup/lomcyc_partition/text.md.
// This variant is simplified so that nothing has to be stack allocated.
//
// Each element of `array` is compared against `pivot` exactly once, in index
// order. Returns how many elements satisfied `is_less(element, pivot)`.
// Assuming `array.swap(i, j)` exchanges elements `i` and `j`, those elements
// occupy the leading positions of `array` on return.
template <typename A, typename F>
LIBC_INLINE size_t partition_lomuto_branchless(const A &array,
                                               const void *pivot,
                                               const F &is_less) {
  const size_t num_elems = array.len();

  // Everything in [0, boundary) has already compared less than the pivot.
  size_t boundary = 0;

  for (size_t cursor = 0; cursor < num_elems; ++cursor) {
    // Evaluate the comparison first: the swap below moves the element.
    const bool belongs_left = is_less(array.get(cursor), pivot);
    // The swap is unconditional; only the boundary update depends on the
    // comparison result, and it is an addition rather than a branch.
    array.swap(boundary, cursor);
    boundary += static_cast<size_t>(belongs_left);
  }

  return boundary;
}
44 
// Hoare-style partition using ordinary branches.
//
// Optimized for large types that are expensive to move. Not optimized
// for integers. It's possible to use a cyclic permutation here for
// large types as done in ipnsort but the advantages of this are limited
// as `is_less` is a small wrapper around a call to a function pointer
// and won't incur much binary-size overhead. The other reason to use
// cyclic permutation is to have more efficient swapping, but we don't
// know the element size so this isn't applicable here either.
//
// Returns the number of elements for which `is_less(element, pivot)` holds.
// Assuming `array.swap(i, j)` exchanges elements `i` and `j`, those elements
// occupy the leading positions of `array` on return. An empty `array` yields
// 0 without touching any element.
template <typename A, typename F>
LIBC_INLINE size_t partition_hoare_branchy(const A &array, const void *pivot,
                                           const F &is_less) {
  const size_t array_len = array.len();

  // Nothing to partition. Without this guard the first `--right` below would
  // wrap around from 0 and `array.get(right)` would be asked for an
  // out-of-range element.
  if (array_len == 0)
    return 0;

  // `left` is the first index not yet known to be less than the pivot;
  // `right` is one past the last index not yet examined from the right.
  size_t left = 0;
  size_t right = array_len;

  while (true) {
    // Skip over the prefix that is already on the correct side.
    while (left < right && is_less(array.get(left), pivot))
      ++left;

    // Walk down from the right until the cursors meet or an element that
    // belongs on the left side is found. `right` is at least 1 here, so the
    // decrement cannot wrap.
    while (true) {
      --right;
      if (left >= right || is_less(array.get(right), pivot))
        break;
    }

    if (left >= right)
      break;

    // `array[left]` is not less than the pivot and `array[right]` is:
    // exchange them and move past the element that is now in place.
    array.swap(left, right);
    ++left;
  }

  return left;
}
80 
81 template <typename A, typename F>
82 LIBC_INLINE size_t partition(const A &array, size_t pivot_index,
83                              const F &is_less) {
84   // Place the pivot at the beginning of the array.
85   if (pivot_index != 0) {
86     array.swap(0, pivot_index);
87   }
88 
89   const A array_without_pivot = array.make_array(1, array.len() - 1);
90   const void *pivot = array.get(0);
91 
92   size_t num_lt;
93   if constexpr (A::has_fixed_size()) {
94     // Branchless Lomuto avoid branch misprediction penalties, but
95     // it also swaps more often which is only faster if the swap is a fast
96     // constant operation.
97     num_lt = partition_lomuto_branchless(array_without_pivot, pivot, is_less);
98   } else {
99     num_lt = partition_hoare_branchy(array_without_pivot, pivot, is_less);
100   }
101 
102   // Place the pivot between the two partitions.
103   array.swap(0, num_lt);
104 
105   return num_lt;
106 }
107 
108 template <typename A, typename F>
109 LIBC_INLINE void quick_sort_impl(A &array, const void *ancestor_pivot,
110                                  size_t limit, const F &is_less) {
111   while (true) {
112     const size_t array_len = array.len();
113     if (array_len <= 1)
114       return;
115 
116     // If too many bad pivot choices were made, simply fall back to
117     // heapsort in order to guarantee `O(N x log(N))` worst-case.
118     if (limit == 0) {
119       heap_sort(array, is_less);
120       return;
121     }
122 
123     limit -= 1;
124 
125     const size_t pivot_index = choose_pivot(array, is_less);
126 
127     // If the chosen pivot is equal to the predecessor, then it's the smallest
128     // element in the slice. Partition the slice into elements equal to and
129     // elements greater than the pivot. This case is usually hit when the slice
130     // contains many duplicate elements.
131     if (ancestor_pivot) {
132       if (!is_less(ancestor_pivot, array.get(pivot_index))) {
133         const size_t num_lt =
134             partition(array, pivot_index,
135                       [is_less](const void *a, const void *b) -> bool {
136                         return !is_less(b, a);
137                       });
138 
139         // Continue sorting elements greater than the pivot. We know that
140         // `num_lt` cont
141         array.reset_bounds(num_lt + 1, array.len() - (num_lt + 1));
142         ancestor_pivot = nullptr;
143         continue;
144       }
145     }
146 
147     size_t split_index = partition(array, pivot_index, is_less);
148 
149     if (array_len == 2)
150       // The partition operation sorts the two element array.
151       return;
152 
153     // Split the array into `left`, `pivot`, and `right`.
154     A left = array.make_array(0, split_index);
155     const void *pivot = array.get(split_index);
156     const size_t right_start = split_index + 1;
157     A right = array.make_array(right_start, array.len() - right_start);
158 
159     // Recurse into the left side. We have a fixed recursion limit,
160     // testing shows no real benefit for recursing into the shorter
161     // side.
162     quick_sort_impl(left, ancestor_pivot, limit, is_less);
163 
164     // Continue with the right side.
165     array = right;
166     ancestor_pivot = pivot;
167   }
168 }
169 
170 constexpr size_t ilog2(size_t n) { return cpp::bit_width(n) - 1; }
171 
172 template <typename A, typename F>
173 LIBC_INLINE void quick_sort(A &array, const F &is_less) {
174   const void *ancestor_pivot = nullptr;
175   // Limit the number of imbalanced partitions to `2 * floor(log2(len))`.
176   // The binary OR by one is used to eliminate the zero-check in the logarithm.
177   const size_t limit = 2 * ilog2((array.len() | 1));
178   quick_sort_impl(array, ancestor_pivot, limit, is_less);
179 }
180 
181 } // namespace internal
182 } // namespace LIBC_NAMESPACE_DECL
183 
184 #endif // LLVM_LIBC_SRC_STDLIB_QUICK_SORT_H
185