core/slice/sort/unstable/quicksort.rs
//! This module contains an unstable quicksort and two partition implementations.

#[cfg(not(feature = "optimize_for_size"))]
use crate::mem;
use crate::mem::ManuallyDrop;
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::shared::pivot::choose_pivot;
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::shared::smallsort::UnstableSmallSortTypeImpl;
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::unstable::heapsort;
use crate::{cfg_match, intrinsics, ptr};

/// Sorts `v` recursively.
///
/// If the slice had a predecessor in the original array, it is specified as `ancestor_pivot`.
///
/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
/// this function will immediately switch to `heapsort`.
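///
/// (Illustrative note, not from the original sources: in the introsort tradition, a caller
/// would pass `ancestor_pivot == None` and a `limit` proportional to `log2(v.len())`, e.g.
/// something like `2 * (v.len() | 1).ilog2()`, so that repeated bad pivot choices degrade to
/// `heapsort` after `O(log(N))` imbalanced partitions.)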
#[cfg(not(feature = "optimize_for_size"))]
pub(crate) fn quicksort<'a, T, F>(
    mut v: &'a mut [T],
    mut ancestor_pivot: Option<&'a T>,
    mut limit: u32,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    loop {
        if v.len() <= T::small_sort_threshold() {
            T::small_sort(v, is_less);
            return;
        }

        // If too many bad pivot choices were made, simply fall back to heapsort in order to
        // guarantee `O(N x log(N))` worst-case.
        if limit == 0 {
            heapsort::heapsort(v, is_less);
            return;
        }

        limit -= 1;

        // Choose a pivot and try guessing whether the slice is already sorted.
        let pivot_pos = choose_pivot(v, is_less);

        // If the chosen pivot is equal to the predecessor, then it's the smallest element in the
        // slice. Partition the slice into elements equal to and elements greater than the pivot.
        // This case is usually hit when the slice contains many duplicate elements.
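        // (Illustrative example, not from the original sources: in a slice dominated by one
        // repeated value, once that value has served as a pivot, a descendant sub-slice whose
        // chosen pivot compares equal to the ancestor pivot takes this branch, and the `a <= b`
        // partition below sweeps every remaining duplicate to the left in a single O(N) pass.)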
        if let Some(p) = ancestor_pivot {
            // SAFETY: We assume choose_pivot yields an in-bounds position.
            if !is_less(p, unsafe { v.get_unchecked(pivot_pos) }) {
                let num_lt = partition(v, pivot_pos, &mut |a, b| !is_less(b, a));

                // Continue sorting elements greater than the pivot. We know that `v[num_lt]` is
                // the pivot itself, so we can continue after it.
                v = &mut v[(num_lt + 1)..];
                ancestor_pivot = None;
                continue;
            }
        }

        // Partition the slice.
        let num_lt = partition(v, pivot_pos, is_less);
        // SAFETY: partition ensures that `num_lt` will be in-bounds.
        unsafe { intrinsics::assume(num_lt < v.len()) };

        // Split the slice into `left`, `pivot`, and `right`.
        let (left, right) = v.split_at_mut(num_lt);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &pivot[0];

        // Recurse into the left side. We have a fixed recursion limit; testing shows no real
        // benefit for recursing into the shorter side.
        quicksort(left, ancestor_pivot, limit, is_less);

        // Continue with the right side.
        v = right;
        ancestor_pivot = Some(pivot);
    }
}
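
// (Illustrative note, not from the original sources: the `v = right` at the bottom of the loop
// is manual tail-call elimination, so only the left side recurses. Because `limit` is
// decremented on every loop iteration before any recursive call, the recursion depth is bounded
// by the initial `limit`, i.e. `O(log(N))` stack usage for an introsort-style choice of limit.)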

/// Takes the input slice `v` and re-arranges elements such that when the call returns normally
/// all elements that compare true for `is_less(elem, pivot)` where `pivot == v[pivot_pos]` are
/// on the left side of `v` followed by the other elements, notionally considered greater or
/// equal to `pivot`.
///
/// Returns the number of elements that compare true for `is_less(elem, pivot)`.
///
/// If `is_less` does not implement a total order the resulting order and return value are
/// unspecified. All original elements will remain in `v` and any possible modifications via
/// interior mutability will be observable. The same is true if `is_less` panics.
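///
/// (Illustrative example, not from the original sources: `partition(&mut [5, 1, 4, 2], 2,
/// &mut |a, b| a < b)` picks `4` as the pivot and returns `2`; afterwards the slice holds `1`
/// and `2` in unspecified order, followed by `4`, then `5`.)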
pub(crate) fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();

    // Allows for panic-free code-gen by proving this property to the compiler.
    if len == 0 {
        return 0;
    }

    if pivot >= len {
        intrinsics::abort();
    }

    // SAFETY: We checked that `pivot` is in-bounds.
    unsafe {
        // Place the pivot at the beginning of the slice.
        v.swap_unchecked(0, pivot);
    }
    let (pivot, v_without_pivot) = v.split_at_mut(1);

    // Assuming that Rust generates noalias LLVM IR we can be sure that a partition function
    // signature of the form `(v: &mut [T], pivot: &T)` guarantees that pivot and v can't alias.
    // Having this guarantee is crucial for optimizations. It's possible to copy the pivot value
    // into a stack value, but this creates issues for types with interior mutability mandating
    // a drop guard.
    let pivot = &mut pivot[0];

    // This construct is used to limit the LLVM IR generated, which saves large amounts of
    // compile-time by only instantiating the code that is needed. Idea by Frank Steffahn.
    let num_lt = (const { inst_partition::<T, F>() })(v_without_pivot, pivot, is_less);

    if num_lt >= len {
        intrinsics::abort();
    }

    // SAFETY: We checked that `num_lt` is in-bounds.
    unsafe {
        // Place the pivot between the two partitions.
        v.swap_unchecked(0, num_lt);
    }

    num_lt
}

const fn inst_partition<T, F: FnMut(&T, &T) -> bool>() -> fn(&mut [T], &T, &mut F) -> usize {
    const MAX_BRANCHLESS_PARTITION_SIZE: usize = 96;
    if size_of::<T>() <= MAX_BRANCHLESS_PARTITION_SIZE {
        // Specialize for types that are relatively cheap to copy, where branchless optimizations
        // have large leverage, e.g. `u64` and `String`.
        cfg_match! {
            feature = "optimize_for_size" => {
                partition_lomuto_branchless_simple::<T, F>
            }
            _ => {
                partition_lomuto_branchless_cyclic::<T, F>
            }
        }
    } else {
        partition_hoare_branchy_cyclic::<T, F>
    }
}
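
// (Illustrative note, not from the original sources: evaluating the selection inside a `const`
// block, as `partition` does with `(const { inst_partition::<T, F>() })(..)`, forces the
// function-pointer choice to happen at compile time, so for a given `(T, F)` pair only the
// selected implementation is monomorphized and lowered to LLVM IR rather than every candidate.)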

/// See [`partition`].
fn partition_hoare_branchy_cyclic<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();

    if len == 0 {
        return 0;
    }

    // Optimized for large types that are expensive to move. Not optimized for integers. Optimized
    // for small code-gen, assuming that `is_less` is an expensive operation that generates
    // substantial amounts of code or a call, and that copying elements will likely be a call to
    // `memcpy`. Using two `ptr::copy_nonoverlapping` calls has the chance to be faster than
    // `ptr::swap_nonoverlapping` because `memcpy` can use wide SIMD based on runtime feature
    // detection. Benchmarks support this analysis.

    let mut gap_opt: Option<GapGuard<T>> = None;

    // SAFETY: The left-to-right scanning loop performs a bounds check, where we know that `left >=
    // v_base && left < right && right <= v_base.add(len)`. The right-to-left scanning loop performs
    // a bounds check ensuring that `right` is in-bounds. We checked that `len` is more than zero,
    // which means that unconditional `right = right.sub(1)` is safe to do. The exit check makes
    // sure that `left` and `right` never alias, making `ptr::copy_nonoverlapping` safe. The
    // drop-guard `gap` ensures that should `is_less` panic we always overwrite the duplicate in the
    // input. `gap.pos` stores the previous value of `right` and starts at `right` and so it too is
    // in-bounds. We never pass the saved `gap.value` to `is_less` while it is inside the
    // `GapGuard`, thus any changes via interior mutability will be observed.
    unsafe {
        let v_base = v.as_mut_ptr();

        let mut left = v_base;
        let mut right = v_base.add(len);

        loop {
            // Find the first element greater than or equal to the pivot.
            while left < right && is_less(&*left, pivot) {
                left = left.add(1);
            }

            // Find the last element less than the pivot.
            loop {
                right = right.sub(1);
                if left >= right || is_less(&*right, pivot) {
                    break;
                }
            }

            if left >= right {
                break;
            }

            // Swap the found pair of out-of-order elements via cyclic permutation.
            let is_first_swap_pair = gap_opt.is_none();

            if is_first_swap_pair {
                gap_opt = Some(GapGuard { pos: right, value: ManuallyDrop::new(ptr::read(left)) });
            }

            let gap = gap_opt.as_mut().unwrap_unchecked();

            // Single place where we instantiate `ptr::copy_nonoverlapping` in the partition.
            if !is_first_swap_pair {
                ptr::copy_nonoverlapping(left, gap.pos, 1);
            }
            gap.pos = right;
            ptr::copy_nonoverlapping(right, left, 1);

            left = left.add(1);
        }

        left.offset_from_unsigned(v_base)

        // `gap_opt` goes out of scope and overwrites the last wrong-side element on the right side
        // with the first wrong-side element of the left side that was initially overwritten by the
        // first wrong-side element on the right side.
    }
}
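
// (Illustrative note, not from the original sources: with out-of-order pairs at `(l0, r0)`,
// `(l1, r1)`, ..., the cyclic scheme above performs `tmp = *l0; *l0 = *r0; *r0 = *l1;
// *l1 = *r1; ...`, and the `GapGuard` drop finally writes `tmp` into the last hole. Every pair
// after the first thus costs two copies instead of the three moves of a plain swap, and the
// guard doubles as panic cleanup should `is_less` unwind mid-permutation.)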

#[cfg(not(feature = "optimize_for_size"))]
struct PartitionState<T> {
    // The current element that is being looked at, scans left to right through the slice.
    right: *mut T,
    // Counts the number of elements that compared less-than, also works around:
    // https://github.com/rust-lang/rust/issues/117128
    num_lt: usize,
    // Gap guard that tracks the temporary duplicate in the input.
    gap: GapGuardRaw<T>,
}

#[cfg(not(feature = "optimize_for_size"))]
fn partition_lomuto_branchless_cyclic<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    // Novel partition implementation by Lukas Bergdoll and Orson Peters. Branchless Lomuto
    // partition paired with a cyclic permutation.
    // https://github.com/Voultapher/sort-research-rs/blob/main/writeup/lomcyc_partition/text.md

    let len = v.len();
    let v_base = v.as_mut_ptr();

    if len == 0 {
        return 0;
    }

    // SAFETY: We checked that `len` is more than zero, which means that reading `v_base` is safe to
    // do. From there we have a bounded loop where `v_base.add(i)` is guaranteed in-bounds. `v` and
    // `pivot` can't alias because of type system rules. The drop-guard `gap` ensures that should
    // `is_less` panic we always overwrite the duplicate in the input. `gap.pos` stores the previous
    // value of `right` and starts at `v_base` and so it too is in-bounds. Given `unroll_len == 2`
    // after the main loop we either have A) the last element in `v` that has not yet been processed
    // because `len % 2 != 0`, or B) all elements have been processed except the gap value that was
    // saved at the beginning with `ptr::read(v_base)`. In the case A) the loop will iterate twice,
    // first performing loop_body to take care of the last element that didn't fit into the unroll.
    // After that the behavior is the same as for B) where we use the saved value as `right` to
    // overwrite the duplicate. If this very last call to `is_less` panics the saved value will be
    // copied back including all possible changes via interior mutability. If `is_less` does not
    // panic and the code continues we overwrite the duplicate and do `right = right.add(1)`; this
    // is safe to do with `&mut *gap.value` because `T` is the same as `[T; 1]` and generating a
    // pointer one past the allocation is safe.
    unsafe {
        let mut loop_body = |state: &mut PartitionState<T>| {
            let right_is_lt = is_less(&*state.right, pivot);
            let left = v_base.add(state.num_lt);

            ptr::copy(left, state.gap.pos, 1);
            ptr::copy_nonoverlapping(state.right, left, 1);

            state.gap.pos = state.right;
            state.num_lt += right_is_lt as usize;

            state.right = state.right.add(1);
        };

        // Ideally we could just use GapGuard in PartitionState, but the reference that is
        // materialized with `&mut state` when calling `loop_body` would create a mutable reference
        // to the parent struct that contains the gap value, invalidating the reference pointer
        // created from a reference to the gap value in the cleanup loop. This is only an issue
        // under Stacked Borrows; Tree Borrows accepts the intuitive code using GapGuard as valid.
        let mut gap_value = ManuallyDrop::new(ptr::read(v_base));

        let mut state = PartitionState {
            num_lt: 0,
            right: v_base.add(1),

            gap: GapGuardRaw { pos: v_base, value: &mut *gap_value },
        };

        // Manual unrolling that works well on x86, Arm and with opt-level=s without murdering
        // compile-times. Leaving this to the compiler yields ok to bad results.
        let unroll_len = const { if size_of::<T>() <= 16 { 2 } else { 1 } };

        let unroll_end = v_base.add(len - (unroll_len - 1));
        while state.right < unroll_end {
            if unroll_len == 2 {
                loop_body(&mut state);
                loop_body(&mut state);
            } else {
                loop_body(&mut state);
            }
        }

        // Instantiate `loop_body` once for both the unroll cleanup and the cyclic permutation
        // cleanup. Optimizes binary-size and compile-time.
        let end = v_base.add(len);
        loop {
            let is_done = state.right == end;
            state.right = if is_done { state.gap.value } else { state.right };

            loop_body(&mut state);

            if is_done {
                mem::forget(state.gap);
                break;
            }
        }

        state.num_lt
    }
}
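
// (Illustrative note, not from the original sources: the state machine above is, in effect, the
// branchless Lomuto loop of `partition_lomuto_branchless_simple` below, with each three-move
// swap replaced by two `ptr::copy`s routed through the gap opened by saving `v[0]` up front;
// the saved value is only written back in the final cleanup iteration, or by `GapGuardRaw` if
// `is_less` panics.)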

#[cfg(feature = "optimize_for_size")]
fn partition_lomuto_branchless_simple<T, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    pivot: &T,
    is_less: &mut F,
) -> usize {
    let mut left = 0;

    for right in 0..v.len() {
        // SAFETY: `left` can at max be incremented by 1 each loop iteration, which implies that
        // `left <= right` and that both are in-bounds.
        unsafe {
            let right_is_lt = is_less(v.get_unchecked(right), pivot);
            v.swap_unchecked(left, right);
            left += right_is_lt as usize;
        }
    }

    left
}
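
// (Illustrative note, not from the original sources: the swap above is deliberately
// unconditional to keep the loop branchless. `v[left]` is always the leftmost element known to
// be greater than or equal to the pivot, so when `right_is_lt` is false the swap merely
// exchanges two elements within the right partition (or is a no-op while `left == right`), and
// when it is true the less-than element lands on the partition boundary before `left` advances.)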

struct GapGuard<T> {
    pos: *mut T,
    value: ManuallyDrop<T>,
}

impl<T> Drop for GapGuard<T> {
    fn drop(&mut self) {
        // SAFETY: `self` MUST be constructed in a way that makes copying the gap value into
        // `self.pos` sound.
        unsafe {
            ptr::copy_nonoverlapping(&*self.value, self.pos, 1);
        }
    }
}

/// Ideally this wouldn't be needed and we could just use the regular GapGuard.
/// See comment in [`partition_lomuto_branchless_cyclic`].
#[cfg(not(feature = "optimize_for_size"))]
struct GapGuardRaw<T> {
    pos: *mut T,
    value: *mut T,
}

#[cfg(not(feature = "optimize_for_size"))]
impl<T> Drop for GapGuardRaw<T> {
    fn drop(&mut self) {
        // SAFETY: `self` MUST be constructed in a way that makes copying the gap value into
        // `self.pos` sound.
        unsafe {
            ptr::copy_nonoverlapping(self.value, self.pos, 1);
        }
    }
}