core/slice/sort/unstable/heapsort.rs

//! This module contains a branchless heapsort, used as the fallback for the unstable quicksort.

3use crate::{cmp, intrinsics, ptr};
4
5/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case.
6///
7/// Never inline this, it sits the main hot-loop in `recurse` and is meant as unlikely algorithmic
8/// fallback.
9#[inline(never)]
10pub(crate) fn heapsort<T, F>(v: &mut [T], is_less: &mut F)
11where
12    F: FnMut(&T, &T) -> bool,
13{
14    let len = v.len();
15
16    for i in (0..len + len / 2).rev() {
17        let sift_idx = if i >= len {
18            i - len
19        } else {
20            v.swap(0, i);
21            0
22        };
23
24        // SAFETY: The above calculation ensures that `sift_idx` is either 0 or
25        // `(len..(len + (len / 2))) - len`, which simplifies to `0..(len / 2)`.
26        // This guarantees the required `sift_idx <= len`.
27        unsafe {
28            sift_down(&mut v[..cmp::min(i, len)], sift_idx, is_less);
29        }
30    }
31}
32
33// This binary heap respects the invariant `parent >= child`.
34//
35// SAFETY: The caller has to guarantee that `node <= v.len()`.
36#[inline(always)]
37unsafe fn sift_down<T, F>(v: &mut [T], mut node: usize, is_less: &mut F)
38where
39    F: FnMut(&T, &T) -> bool,
40{
41    // SAFETY: See function safety.
42    unsafe {
43        intrinsics::assume(node <= v.len());
44    }
45
46    let len = v.len();
47
48    let v_base = v.as_mut_ptr();
49
50    loop {
51        // Children of `node`.
52        let mut child = 2 * node + 1;
53        if child >= len {
54            break;
55        }
56
57        // SAFETY: The invariants and checks guarantee that both node and child are in-bounds.
58        unsafe {
59            // Choose the greater child.
60            if child + 1 < len {
61                // We need a branch to be sure not to out-of-bounds index,
62                // but it's highly predictable.  The comparison, however,
63                // is better done branchless, especially for primitives.
64                child += is_less(&*v_base.add(child), &*v_base.add(child + 1)) as usize;
65            }
66
67            // Stop if the invariant holds at `node`.
68            if !is_less(&*v_base.add(node), &*v_base.add(child)) {
69                break;
70            }
71
72            ptr::swap_nonoverlapping(v_base.add(node), v_base.add(child), 1);
73        }
74
75        node = child;
76    }
77}