core/slice/sort/unstable/quicksort.rs

//! This module contains an unstable quicksort and two partition implementations.

use crate::mem::{self, ManuallyDrop};
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::shared::pivot::choose_pivot;
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::shared::smallsort::UnstableSmallSortTypeImpl;
#[cfg(not(feature = "optimize_for_size"))]
use crate::slice::sort::unstable::heapsort;
use crate::{intrinsics, ptr};

/// Sorts `v` recursively.
///
/// If the slice had a predecessor in the original array, it is specified as `ancestor_pivot`.
///
/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
/// this function will immediately switch to heapsort.
#[cfg(not(feature = "optimize_for_size"))]
pub(crate) fn quicksort<'a, T, F>(
    mut v: &'a mut [T],
    mut ancestor_pivot: Option<&'a T>,
    mut limit: u32,
    is_less: &mut F,
) where
    F: FnMut(&T, &T) -> bool,
{
    loop {
        if v.len() <= T::small_sort_threshold() {
            T::small_sort(v, is_less);
            return;
        }

        // If too many bad pivot choices were made, simply fall back to heapsort in order to
        // guarantee `O(N x log(N))` worst-case.
        if limit == 0 {
            heapsort::heapsort(v, is_less);
            return;
        }

        limit -= 1;

        // Choose a pivot and try guessing whether the slice is already sorted.
        let pivot_pos = choose_pivot(v, is_less);

        // If the chosen pivot is equal to the predecessor, then it's the smallest element in the
        // slice. Partition the slice into elements equal to and elements greater than the pivot.
        // This case is usually hit when the slice contains many duplicate elements.
        if let Some(p) = ancestor_pivot {
            // SAFETY: We assume choose_pivot yields an in-bounds position.
            if !is_less(p, unsafe { v.get_unchecked(pivot_pos) }) {
                let num_lt = partition(v, pivot_pos, &mut |a, b| !is_less(b, a));

                // Continue sorting elements greater than the pivot. We know that the element at
                // index `num_lt` is the pivot itself, so we can continue sorting after it.
                v = &mut v[(num_lt + 1)..];
                ancestor_pivot = None;
                continue;
            }
        }

        // Partition the slice.
        let num_lt = partition(v, pivot_pos, is_less);
        // SAFETY: partition ensures that `num_lt` will be in-bounds.
        unsafe { intrinsics::assume(num_lt < v.len()) };

        // Split the slice into `left`, `pivot`, and `right`.
        let (left, right) = v.split_at_mut(num_lt);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &pivot[0];

        // Recurse into the left side. We have a fixed recursion limit; testing shows no real
        // benefit for recursing into the shorter side.
        quicksort(left, ancestor_pivot, limit, is_less);

        // Continue with the right side.
        v = right;
        ancestor_pivot = Some(pivot);
    }
}
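
// For illustration only: callers are expected to derive `limit` from the input length, so
// that pathological inputs degrade to heapsort after `O(log(N))` imbalanced partitions. A
// plausible call-site sketch (the exact formula used by real callers may differ):
//
//     let limit = 2 * (v.len() | 1).ilog2();
//     quicksort(v, None, limit, &mut is_less);
//
// The `| 1` avoids taking `ilog2` of zero; `None` means the full slice has no ancestor pivot.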

/// Takes the input slice `v` and re-arranges elements such that when the call returns normally
/// all elements that compare true for `is_less(elem, pivot)` where `pivot == v[pivot_pos]` are
/// on the left side of `v` followed by the other elements, notionally considered greater or
/// equal to `pivot`.
///
/// Returns the number of elements that are compared true for `is_less(elem, pivot)`.
///
/// If `is_less` does not implement a total order the resulting order and return value are
/// unspecified. All original elements will remain in `v`, and any possible modifications via
/// interior mutability will be observable. The same is true if `is_less` panics.
pub(crate) fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();

    // Allows for panic-free code-gen by proving this property to the compiler.
    if len == 0 {
        return 0;
    }

    if pivot >= len {
        intrinsics::abort();
    }

    // SAFETY: We checked that `pivot` is in-bounds.
    unsafe {
        // Place the pivot at the beginning of the slice.
        v.swap_unchecked(0, pivot);
    }
    let (pivot, v_without_pivot) = v.split_at_mut(1);

    // Assuming that Rust generates noalias LLVM IR we can be sure that a partition function
    // signature of the form `(v: &mut [T], pivot: &T)` guarantees that pivot and v can't alias.
    // Having this guarantee is crucial for optimizations. It's possible to copy the pivot value
    // into a stack value, but this creates issues for types with interior mutability mandating
    // a drop guard.
    let pivot = &mut pivot[0];

    // This construct is used to limit the LLVM IR generated, which saves large amounts of
    // compile-time by only instantiating the code that is needed. Idea by Frank Steffahn.
    let num_lt = (const { inst_partition::<T, F>() })(v_without_pivot, pivot, is_less);

    if num_lt >= len {
        intrinsics::abort();
    }

    // SAFETY: We checked that `num_lt` is in-bounds.
    unsafe {
        // Place the pivot between the two partitions.
        v.swap_unchecked(0, num_lt);
    }

    num_lt
}
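
// A hypothetical walk-through of `partition` on a small slice, choosing the pivot value `7`
// at index 1 with a plain less-than comparison (element order within each side is
// unspecified):
//
//     let mut v = [3, 7, 1, 7, 2];
//     let num_lt = partition(&mut v, 1, &mut |a, b| a < b);
//     // num_lt == 3; v might now be [2, 3, 1, 7, 7]: the three elements comparing less
//     // than 7 come first, v[num_lt] holds the pivot, and the rest compare >= 7.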

const fn inst_partition<T, F: FnMut(&T, &T) -> bool>() -> fn(&mut [T], &T, &mut F) -> usize {
    const MAX_BRANCHLESS_PARTITION_SIZE: usize = 96;
    if mem::size_of::<T>() <= MAX_BRANCHLESS_PARTITION_SIZE {
        // Specialize for types that are relatively cheap to copy, where branchless optimizations
        // have large leverage e.g. `u64` and `String`.
        cfg_if! {
            if #[cfg(feature = "optimize_for_size")] {
                partition_lomuto_branchless_simple::<T, F>
            } else {
                partition_lomuto_branchless_cyclic::<T, F>
            }
        }
    } else {
        partition_hoare_branchy_cyclic::<T, F>
    }
}
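
// A minimal sketch of the instantiation trick above, using illustrative names: the `const`
// block evaluates to a plain `fn` pointer at compile time, so for each `(T, F)` pair only
// the selected implementation is monomorphized and handed to LLVM:
//
//     const fn select_impl<T>() -> fn(&T) -> bool {
//         if mem::size_of::<T>() <= 8 { small_impl::<T> } else { large_impl::<T> }
//     }
//     // Call-site: (const { select_impl::<T>() })(&value);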

/// See [`partition`].
fn partition_hoare_branchy_cyclic<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();

    if len == 0 {
        return 0;
    }

    // Optimized for large types that are expensive to move. Not optimized for integers. Optimized
    // for small code-gen, assuming that `is_less` is an expensive operation that generates
    // substantial amounts of code or a call, and that copying elements will likely be a call to
    // memcpy. Using two `ptr::copy_nonoverlapping` calls has the chance to be faster than
    // `ptr::swap_nonoverlapping` because `memcpy` can use wide SIMD based on runtime feature
    // detection. Benchmarks support this analysis.

    let mut gap_opt: Option<GapGuard<T>> = None;

    // SAFETY: The left-to-right scanning loop performs a bounds check, where we know that `left >=
    // v_base && left < right && right <= v_base.add(len)`. The right-to-left scanning loop performs
    // a bounds check ensuring that `right` is in-bounds. We checked that `len` is more than zero,
    // which means that unconditional `right = right.sub(1)` is safe to do. The exit check makes
    // sure that `left` and `right` never alias, making `ptr::copy_nonoverlapping` safe. The
    // drop-guard `gap` ensures that should `is_less` panic we always overwrite the duplicate in the
    // input. `gap.pos` stores the previous value of `right` and starts at `right` and so it too is
    // in-bounds. We never pass the saved `gap.value` to `is_less` while it is inside the `GapGuard`,
    // thus any changes via interior mutability will be observed.
    unsafe {
        let v_base = v.as_mut_ptr();

        let mut left = v_base;
        let mut right = v_base.add(len);

        loop {
            // Find the first element greater than or equal to the pivot.
            while left < right && is_less(&*left, pivot) {
                left = left.add(1);
            }

            // Find the last element that is less than the pivot.
            loop {
                right = right.sub(1);
                if left >= right || is_less(&*right, pivot) {
                    break;
                }
            }

            if left >= right {
                break;
            }

            // Swap the found pair of out-of-order elements via cyclic permutation.
            let is_first_swap_pair = gap_opt.is_none();

            if is_first_swap_pair {
                gap_opt = Some(GapGuard { pos: right, value: ManuallyDrop::new(ptr::read(left)) });
            }

            let gap = gap_opt.as_mut().unwrap_unchecked();

            // Single place where we instantiate ptr::copy_nonoverlapping in the partition.
            if !is_first_swap_pair {
                ptr::copy_nonoverlapping(left, gap.pos, 1);
            }
            gap.pos = right;
            ptr::copy_nonoverlapping(right, left, 1);

            left = left.add(1);
        }

        left.sub_ptr(v_base)

        // `gap_opt` goes out of scope and overwrites the last wrong-side element on the right
        // side with the first wrong-side element of the left side, which was initially
        // overwritten by the first wrong-side element on the right side.
    }
}
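
// A sketch of the cyclic permutation above for three hypothetical out-of-order pairs at
// positions (l0, r0), (l1, r1), (l2, r2). Three plain swaps would need three copies each;
// the cycle needs only two copies per pair, bracketed by one read and one write-back:
//
//     tmp   = v[l0];   // first pair: `ptr::read` into the `GapGuard`
//     v[l0] = v[r0];
//     v[r0] = v[l1];   v[l1] = v[r1];
//     v[r1] = v[l2];   v[l2] = v[r2];
//     v[r2] = tmp;     // `GapGuard::drop` fills the final hole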

#[cfg(not(feature = "optimize_for_size"))]
struct PartitionState<T> {
    // The current element that is being looked at; scans left to right through the slice.
    right: *mut T,
    // Counts the number of elements that compared less-than; also works around:
    // https://github.com/rust-lang/rust/issues/117128
    num_lt: usize,
    // Gap guard that tracks the temporary duplicate in the input.
    gap: GapGuardRaw<T>,
}

#[cfg(not(feature = "optimize_for_size"))]
fn partition_lomuto_branchless_cyclic<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    // Novel partition implementation by Lukas Bergdoll and Orson Peters. Branchless Lomuto
    // partition paired with a cyclic permutation.
    // https://github.com/Voultapher/sort-research-rs/blob/main/writeup/lomcyc_partition/text.md

    let len = v.len();
    let v_base = v.as_mut_ptr();

    if len == 0 {
        return 0;
    }

    // SAFETY: We checked that `len` is more than zero, which means that reading `v_base` is safe to
    // do. From there we have a bounded loop where `v_base.add(i)` is guaranteed in-bounds. `v` and
    // `pivot` can't alias because of type system rules. The drop-guard `gap` ensures that should
    // `is_less` panic we always overwrite the duplicate in the input. `gap.pos` stores the previous
    // value of `right` and starts at `v_base` and so it too is in-bounds. Given `unroll_len == 2`
    // after the main loop we either have A) the last element in `v` that has not yet been processed
    // because `len % 2 != 0`, or B) all elements have been processed except the gap value that was
    // saved at the beginning with `ptr::read(v_base)`. In case A) the loop will iterate twice,
    // first performing loop_body to take care of the last element that didn't fit into the unroll.
    // After that the behavior is the same as for B) where we use the saved value as `right` to
    // overwrite the duplicate. If this very last call to `is_less` panics the saved value will be
    // copied back including all possible changes via interior mutability. If `is_less` does not
    // panic and the code continues we overwrite the duplicate and do `right = right.add(1)`, this
    // is safe to do with `&mut *gap.value` because `T` is the same as `[T; 1]` and generating a
    // pointer one past the allocation is safe.
    unsafe {
        let mut loop_body = |state: &mut PartitionState<T>| {
            let right_is_lt = is_less(&*state.right, pivot);
            let left = v_base.add(state.num_lt);

            ptr::copy(left, state.gap.pos, 1);
            ptr::copy_nonoverlapping(state.right, left, 1);

            state.gap.pos = state.right;
            state.num_lt += right_is_lt as usize;

            state.right = state.right.add(1);
        };

        // Ideally we could just use GapGuard in PartitionState, but the reference that is
        // materialized with `&mut state` when calling `loop_body` would create a mutable reference
        // to the parent struct that contains the gap value, invalidating the reference pointer
        // created from a reference to the gap value in the cleanup loop. This is only an issue
        // under Stacked Borrows; Tree Borrows accepts the intuitive code using GapGuard as valid.
        let mut gap_value = ManuallyDrop::new(ptr::read(v_base));

        let mut state = PartitionState {
            num_lt: 0,
            right: v_base.add(1),

            gap: GapGuardRaw { pos: v_base, value: &mut *gap_value },
        };

        // Manual unrolling that works well on x86, Arm and with opt-level=s without murdering
        // compile-times. Leaving this to the compiler yields ok to bad results.
        let unroll_len = const { if mem::size_of::<T>() <= 16 { 2 } else { 1 } };

        let unroll_end = v_base.add(len - (unroll_len - 1));
        while state.right < unroll_end {
            if unroll_len == 2 {
                loop_body(&mut state);
                loop_body(&mut state);
            } else {
                loop_body(&mut state);
            }
        }
        // Instantiate `loop_body` only once, shared by the unroll cleanup and the cyclic
        // permutation cleanup. This optimizes binary-size and compile-time.
        let end = v_base.add(len);
        loop {
            let is_done = state.right == end;
            state.right = if is_done { state.gap.value } else { state.right };

            loop_body(&mut state);

            if is_done {
                mem::forget(state.gap);
                break;
            }
        }

        state.num_lt
    }
}
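
// Per-iteration sketch of `loop_body` above in array notation (hypothetical state, with
// `left == v_base + num_lt` and `gap.pos` marking the current hole): the element at `right`
// lands in the less-than region exactly when `right_is_lt`, without emitting a branch:
//
//     v[gap.pos] = v[left];               // refill the previous hole
//     v[left]    = v[right];              // candidate moves to the partition boundary
//     gap.pos    = right;                 // the hole is now at `right`
//     num_lt    += right_is_lt as usize;  // boundary advances only on "less than"
//     right     += 1;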

#[cfg(feature = "optimize_for_size")]
fn partition_lomuto_branchless_simple<T, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    pivot: &T,
    is_less: &mut F,
) -> usize {
    let mut left = 0;

    for right in 0..v.len() {
        // SAFETY: `left` is incremented by at most 1 each loop iteration, which implies that
        // `left <= right` and that both are in-bounds.
        unsafe {
            let right_is_lt = is_less(v.get_unchecked(right), pivot);
            v.swap_unchecked(left, right);
            left += right_is_lt as usize;
        }
    }

    left
}
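
// For intuition, a safe-Rust sketch that produces an equally valid partition: when the
// element is not less than the pivot, the unconditional swap above exchanges two elements
// of the greater-or-equal region (or is a no-op when `left == right`), so the loop can skip
// the branch entirely and let the compiler emit a conditional increment instead:
//
//     for right in 0..v.len() {
//         if is_less(&v[right], pivot) {
//             v.swap(left, right);
//             left += 1;
//         }
//     }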

struct GapGuard<T> {
    pos: *mut T,
    value: ManuallyDrop<T>,
}

impl<T> Drop for GapGuard<T> {
    fn drop(&mut self) {
        // SAFETY: `self` MUST be constructed in a way that makes copying the gap value into
        // `self.pos` sound.
        unsafe {
            ptr::copy_nonoverlapping(&*self.value, self.pos, 1);
        }
    }
}
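
// The gap-guard pattern in miniature (illustrative; `hole` and `final_hole` are
// hypothetical): an element is lifted out with `ptr::read`, leaving a logical hole that
// temporarily duplicates it. If user code such as `is_less` panics before the hole is
// refilled, `Drop` restores the saved value, so unwinding observes every original element
// exactly once:
//
//     let mut gap = GapGuard { pos: hole, value: ManuallyDrop::new(ptr::read(hole)) };
//     // ... user code that may panic ...
//     gap.pos = final_hole;  // the guard tracks wherever the hole moves
//     // On scope exit the guard copies `value` into `pos`, closing the hole.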

/// Ideally this wouldn't be needed and we could just use the regular GapGuard.
/// See comment in [`partition_lomuto_branchless_cyclic`].
#[cfg(not(feature = "optimize_for_size"))]
struct GapGuardRaw<T> {
    pos: *mut T,
    value: *mut T,
}

#[cfg(not(feature = "optimize_for_size"))]
impl<T> Drop for GapGuardRaw<T> {
    fn drop(&mut self) {
        // SAFETY: `self` MUST be constructed in a way that makes copying the gap value into
        // `self.pos` sound.
        unsafe {
            ptr::copy_nonoverlapping(self.value, self.pos, 1);
        }
    }
}