core/slice/sort/stable/quicksort.rs
//! This module contains a stable quicksort and partition implementation.
2
3use crate::mem::{ManuallyDrop, MaybeUninit};
4use crate::slice::sort::shared::FreezeMarker;
5use crate::slice::sort::shared::pivot::choose_pivot;
6use crate::slice::sort::shared::smallsort::StableSmallSortTypeImpl;
7use crate::{intrinsics, ptr};
8
9/// Sorts `v` recursively using quicksort.
10/// `scratch.len()` must be at least `max(v.len() - v.len() / 2, SMALL_SORT_GENERAL_SCRATCH_LEN)`
11/// otherwise the implementation may abort.
12///
13/// `limit` when initialized with `c*log(v.len())` for some c ensures we do not
14/// overflow the stack or go quadratic.
15#[inline(never)]
16pub fn quicksort<T, F: FnMut(&T, &T) -> bool>(
17 mut v: &mut [T],
18 scratch: &mut [MaybeUninit<T>],
19 mut limit: u32,
20 mut left_ancestor_pivot: Option<&T>,
21 is_less: &mut F,
22) {
23 loop {
24 let len = v.len();
25
26 if len <= T::small_sort_threshold() {
27 T::small_sort(v, scratch, is_less);
28 return;
29 }
30
31 if limit == 0 {
32 // We have had too many bad pivots, switch to O(n log n) fallback
33 // algorithm. In our case that is driftsort in eager mode.
34 crate::slice::sort::stable::drift::sort(v, scratch, true, is_less);
35 return;
36 }
37 limit -= 1;
38
39 let pivot_pos = choose_pivot(v, is_less);
40 // SAFETY: choose_pivot promises to return a valid pivot index.
41 unsafe {
42 intrinsics::assume(pivot_pos < v.len());
43 }
44
45 // SAFETY: We only access the temporary copy for Freeze types, otherwise
46 // self-modifications via `is_less` would not be observed and this would
47 // be unsound. Our temporary copy does not escape this scope.
48 let pivot_copy = unsafe { ManuallyDrop::new(ptr::read(&v[pivot_pos])) };
49 let pivot_ref = (!has_direct_interior_mutability::<T>()).then_some(&*pivot_copy);
50
51 // We choose a pivot, and check if this pivot is equal to our left
52 // ancestor. If true, we do a partition putting equal elements on the
53 // left and do not recurse on it. This gives O(n log k) sorting for k
54 // distinct values, a strategy borrowed from pdqsort. For types with
55 // interior mutability we can't soundly create a temporary copy of the
56 // ancestor pivot, and use left_partition_len == 0 as our method for
57 // detecting when we re-use a pivot, which means we do at most three
58 // partition operations with pivot p instead of the optimal two.
59 let mut perform_equal_partition = false;
60 if let Some(la_pivot) = left_ancestor_pivot {
61 perform_equal_partition = !is_less(la_pivot, &v[pivot_pos]);
62 }
63
64 let mut left_partition_len = 0;
65 if !perform_equal_partition {
66 left_partition_len = stable_partition(v, scratch, pivot_pos, false, is_less);
67 perform_equal_partition = left_partition_len == 0;
68 }
69
70 if perform_equal_partition {
71 let mid_eq = stable_partition(v, scratch, pivot_pos, true, &mut |a, b| !is_less(b, a));
72 v = &mut v[mid_eq..];
73 left_ancestor_pivot = None;
74 continue;
75 }
76
77 // Process left side with the next loop iter, right side with recursion.
78 let (left, right) = v.split_at_mut(left_partition_len);
79 quicksort(right, scratch, limit, pivot_ref, is_less);
80 v = left;
81 }
82}
83
84/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
85/// elements less than `p`. The relative order of elements that compare < p and
86/// those that compare >= p is preserved - it is a stable partition.
87///
88/// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`,
89/// or `pivot_pos >= v.len()`, the result and `v`'s state is sound but unspecified.
90fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
91 v: &mut [T],
92 scratch: &mut [MaybeUninit<T>],
93 pivot_pos: usize,
94 pivot_goes_left: bool,
95 is_less: &mut F,
96) -> usize {
97 let len = v.len();
98
99 if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) {
100 core::intrinsics::abort()
101 }
102
103 let v_base = v.as_ptr();
104 let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);
105
106 // The core idea is to write the values that compare as less-than to the left
107 // side of `scratch`, while the values that compared as greater or equal than
108 // `v[pivot_pos]` go to the right side of `scratch` in reverse. See
109 // PartitionState for details.
110
111 // SAFETY: see individual comments.
112 unsafe {
113 // SAFETY: we made sure the scratch has length >= len and that pivot_pos
114 // is in-bounds. v and scratch are disjoint slices.
115 let pivot = v_base.add(pivot_pos);
116 let mut state = PartitionState::new(v_base, scratch_base, len);
117
118 let mut pivot_in_scratch = ptr::null_mut();
119 let mut loop_end_pos = pivot_pos;
120
121 // SAFETY: this loop is equivalent to calling state.partition_one
122 // exactly len times.
123 loop {
124 // Ideally the outer loop won't be unrolled, to save binary size,
125 // but we do want the inner loop to be unrolled for small types, as
126 // this gave significant performance boosts in benchmarks. Unrolling
127 // through for _ in 0..UNROLL_LEN { .. } instead of manually improves
128 // compile times but has a ~10-20% performance penalty on opt-level=s.
129 if const { size_of::<T>() <= 16 } {
130 const UNROLL_LEN: usize = 4;
131 let unroll_end = v_base.add(loop_end_pos.saturating_sub(UNROLL_LEN - 1));
132 while state.scan < unroll_end {
133 state.partition_one(is_less(&*state.scan, &*pivot));
134 state.partition_one(is_less(&*state.scan, &*pivot));
135 state.partition_one(is_less(&*state.scan, &*pivot));
136 state.partition_one(is_less(&*state.scan, &*pivot));
137 }
138 }
139
140 let loop_end = v_base.add(loop_end_pos);
141 while state.scan < loop_end {
142 state.partition_one(is_less(&*state.scan, &*pivot));
143 }
144
145 if loop_end_pos == len {
146 break;
147 }
148
149 // We avoid comparing pivot with itself, as this could create deadlocks for
150 // certain comparison operators. We also store its location later for later.
151 pivot_in_scratch = state.partition_one(pivot_goes_left);
152
153 loop_end_pos = len;
154 }
155
156 // `pivot` must be copied into its correct position again, because a
157 // comparison operator might have modified it.
158 if has_direct_interior_mutability::<T>() {
159 ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);
160 }
161
162 // SAFETY: partition_one being called exactly len times guarantees that scratch
163 // is initialized with a permuted copy of `v`, and that num_left <= v.len().
164 // Copying scratch[0..num_left] and scratch[num_left..v.len()] back is thus
165 // sound, as the values in scratch will never be read again, meaning our copies
166 // semantically act as moves, permuting `v`.
167
168 // Copy all the elements < p directly from swap to v.
169 let v_base = v.as_mut_ptr();
170 ptr::copy_nonoverlapping(scratch_base, v_base, state.num_left);
171
172 // Copy the elements >= p in reverse order.
173 for i in 0..len - state.num_left {
174 ptr::copy_nonoverlapping(
175 scratch_base.add(len - 1 - i),
176 v_base.add(state.num_left + i),
177 1,
178 );
179 }
180
181 state.num_left
182 }
183}
184
/// Bookkeeping for a single stable partition pass that copies elements from a
/// source buffer into a scratch buffer: elements going left fill the scratch
/// front-to-back, elements going right fill it back-to-front.
struct PartitionState<T> {
    // The start of the scratch auxiliary memory.
    scratch_base: *mut T,
    // The current element that is being looked at, scans left to right through slice.
    scan: *const T,
    // Counts the number of elements that went to the left side, also works around:
    // https://github.com/rust-lang/rust/issues/117128
    num_left: usize,
    // Reverse scratch output pointer.
    scratch_rev: *mut T,
}

impl<T> PartitionState<T> {
    /// Creates the state for partitioning `len` elements starting at `scan` into `scratch`.
    ///
    /// # Safety
    ///
    /// `scan` and `scratch` must point to valid disjoint buffers of length `len`. The
    /// scan buffer must be initialized.
    unsafe fn new(scan: *const T, scratch: *mut T, len: usize) -> Self {
        // SAFETY: See function safety comment.
        unsafe { Self { scratch_base: scratch, scan, num_left: 0, scratch_rev: scratch.add(len) } }
    }

    /// Depending on the value of `towards_left` this function will write a value
    /// to the growing left or right side of the scratch memory. This forms the
    /// branchless core of the partition.
    ///
    /// Returns the address in scratch that the current element was copied to, and
    /// advances `scan` to the next element.
    ///
    /// # Safety
    ///
    /// This function may be called at most `len` times. If it is called exactly
    /// `len` times the scratch buffer then contains a copy of each element from
    /// the scan buffer exactly once - a permutation, and num_left <= len.
    unsafe fn partition_one(&mut self, towards_left: bool) -> *mut T {
        // SAFETY: see individual comments.
        unsafe {
            // SAFETY: in-bounds because this function is called at most len times, and thus
            // `scratch_rev` has been decremented at most len - 1 times before this call.
            // Similarly, num_left < len and num_right < len, where
            // num_right == i - num_left at the start of the ith iteration (zero-indexed).
            self.scratch_rev = self.scratch_rev.sub(1);

            // SAFETY: now we have scratch_rev == base + len - (i + 1). This means
            // scratch_rev + num_left == base + len - 1 - num_right < base + len.
            let dst_base = if towards_left { self.scratch_base } else { self.scratch_rev };
            let dst = dst_base.add(self.num_left);
            ptr::copy_nonoverlapping(self.scan, dst, 1);

            self.num_left += towards_left as usize;
            self.scan = self.scan.add(1);
            dst
        }
    }
}
237
238trait IsFreeze {
239 fn is_freeze() -> bool;
240}
241
242impl<T> IsFreeze for T {
243 default fn is_freeze() -> bool {
244 false
245 }
246}
247impl<T: FreezeMarker> IsFreeze for T {
248 fn is_freeze() -> bool {
249 true
250 }
251}
252
253#[must_use]
254fn has_direct_interior_mutability<T>() -> bool {
255 // If a type has interior mutability it may alter itself during comparison
256 // in a way that must be preserved after the sort operation concludes.
257 // Otherwise a type like Mutex<Option<Box<str>>> could lead to double free.
258 !T::is_freeze()
259}