core/slice/sort/stable/merge.rs

1//! This module contains logic for performing a merge of two sorted sub-slices.
2
3use crate::mem::MaybeUninit;
4use crate::{cmp, ptr};
5
6/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `scratch` as
7/// temporary storage, and stores the result into `v[..]`.
8pub fn merge<T, F: FnMut(&T, &T) -> bool>(
9    v: &mut [T],
10    scratch: &mut [MaybeUninit<T>],
11    mid: usize,
12    is_less: &mut F,
13) {
14    let len = v.len();
15
16    if mid == 0 || mid >= len || scratch.len() < cmp::min(mid, len - mid) {
17        return;
18    }
19
20    // SAFETY: We checked that the two slices are non-empty and `mid` is in-bounds.
21    // We checked that the buffer `scratch` has enough capacity to hold a copy of
22    // the shorter slice. `merge_up` and `merge_down` are written in such a way that
23    // they uphold the contract described in `MergeState::drop`.
24    unsafe {
25        // The merge process first copies the shorter run into `buf`. Then it traces
26        // the newly copied run and the longer run forwards (or backwards), comparing
27        // their next unconsumed elements and copying the lesser (or greater) one into `v`.
28        //
29        // As soon as the shorter run is fully consumed, the process is done. If the
30        // longer run gets consumed first, then we must copy whatever is left of the
31        // shorter run into the remaining gap in `v`.
32        //
33        // Intermediate state of the process is always tracked by `gap`, which serves
34        // two purposes:
35        //  1. Protects integrity of `v` from panics in `is_less`.
36        //  2. Fills the remaining gap in `v` if the longer run gets consumed first.
37
38        let buf = MaybeUninit::slice_as_mut_ptr(scratch);
39
40        let v_base = v.as_mut_ptr();
41        let v_mid = v_base.add(mid);
42        let v_end = v_base.add(len);
43
44        let left_len = mid;
45        let right_len = len - mid;
46
47        let left_is_shorter = left_len <= right_len;
48        let save_base = if left_is_shorter { v_base } else { v_mid };
49        let save_len = if left_is_shorter { left_len } else { right_len };
50
51        ptr::copy_nonoverlapping(save_base, buf, save_len);
52
53        let mut merge_state = MergeState { start: buf, end: buf.add(save_len), dst: save_base };
54
55        if left_is_shorter {
56            merge_state.merge_up(v_mid, v_end, is_less);
57        } else {
58            merge_state.merge_down(v_base, buf, v_end, is_less);
59        }
60        // Finally, `merge_state` gets dropped. If the shorter run was not fully
61        // consumed, whatever remains of it will now be copied into the hole in `v`.
62    }
63}
64
/// Bookkeeping for an in-progress merge.
///
/// `start..end` is the not-yet-merged part of the scratch buffer, which holds a
/// bitwise copy of one run. `dst` is where that remainder belongs in `v`. When
/// dropped, the range `start..end` is copied into `dst..`. This happens after a
/// completed merge, or while unwinding from a panic in `is_less`. Either way, `v`
/// ends up holding each of its original elements exactly once.
struct MergeState<T> {
    /// First not-yet-merged element in the scratch buffer.
    start: *mut T,
    /// One past the last not-yet-merged element in the scratch buffer.
    end: *mut T,
    /// Start of the gap in `v` that the remaining `start..end` elements must fill.
    dst: *mut T,
}

impl<T> MergeState<T> {
    /// Merges front to back. Each step moves the smaller of the next buffered
    /// (left-run) element `*self.start` and the next right-run element `*right`
    /// to `self.dst`. It stops as soon as either run is exhausted. On ties, the
    /// buffered left element is taken first, which keeps the merge stable.
    ///
    /// Any elements still in `self.start..self.end` afterwards are written back
    /// when `self` is dropped.
    ///
    /// # Safety
    /// The caller MUST guarantee that:
    /// - `self.start..self.end` holds a bitwise copy of the left run and does not
    ///   overlap `v`.
    /// - `self.dst` points to where the left run began in `v`.
    /// - `self.dst.add(self.end.offset_from_unsigned(self.start)) == right`. That
    ///   is, the gap in `v` is exactly as long as the buffered run and is directly
    ///   followed by the right run.
    /// - `right..right_end` is the right run in `v` and is valid to be read.
    ///
    /// This function MUST only be called once.
    unsafe fn merge_up<F: FnMut(&T, &T) -> bool>(
        &mut self,
        mut right: *const T,
        right_end: *const T,
        is_less: &mut F,
    ) {
        // SAFETY: See function safety comment. The gap `self.dst..right` always has
        // exactly `self.end - self.start` slots:
        // - consuming a buffered element advances `start` and `dst`;
        // - consuming a right-run element advances `right` and `dst`.
        // While the loop runs, that gap is non-empty. So every write to `self.dst`
        // lands in a slot whose element was already saved or moved, and the source
        // and destination never overlap.
        unsafe {
            while self.start != self.end && right != right_end {
                let consume_left = !is_less(&*right, &*self.start);

                let src: *const T = if consume_left { self.start } else { right };
                ptr::copy_nonoverlapping(src, self.dst, 1);

                self.start = self.start.add(consume_left as usize);
                right = right.add(!consume_left as usize);
                self.dst = self.dst.add(1);
            }
        }
    }

    /// Merges back to front. Each step moves the larger of the last unmerged
    /// left-run element `*self.dst.sub(1)` and the last unmerged buffered
    /// (right-run) element `*self.end.sub(1)` into the slot just before `out`.
    /// It stops as soon as either run is exhausted. On ties, the buffered right
    /// element is placed last, which keeps the merge stable.
    ///
    /// Any elements still in `self.start..self.end` afterwards are written back
    /// to `self.dst..` when `self` is dropped.
    ///
    /// # Safety
    /// The caller MUST guarantee that:
    /// - `left_end..self.dst` is the left run in `v`, is valid to be read, and is
    ///   non-empty.
    /// - `self.start..self.end` holds a bitwise copy of the right run, does not
    ///   overlap `v`, is non-empty, and `right_end == self.start`.
    /// - `out` points one past the end of `v`.
    /// - `self.dst.add(self.end.offset_from_unsigned(self.start)) == out`. That is,
    ///   the gap `self.dst..out` is exactly as long as the buffered run.
    ///
    /// Both runs must be non-empty because each iteration reads one element from
    /// each run before checking for exhaustion.
    ///
    /// This function MUST only be called once.
    unsafe fn merge_down<F: FnMut(&T, &T) -> bool>(
        &mut self,
        left_end: *const T,
        right_end: *const T,
        mut out: *mut T,
        is_less: &mut F,
    ) {
        // SAFETY: See function safety comment. At the top of every iteration the
        // gap `self.dst..out` has exactly `self.end - self.start >= 1` slots. So the
        // write to `out.sub(1)` lands in a slot whose element was already saved or
        // moved, and never in the slot being read.
        unsafe {
            loop {
                let left = self.dst.sub(1);
                let right = self.end.sub(1);
                out = out.sub(1);

                let consume_left = is_less(&*right, &*left);

                let src = if consume_left { left } else { right };
                ptr::copy_nonoverlapping(src, out, 1);

                self.dst = left.add(!consume_left as usize);
                self.end = right.add(consume_left as usize);

                if self.dst as *const T == left_end || self.end as *const T == right_end {
                    break;
                }
            }
        }
    }
}

impl<T> Drop for MergeState<T> {
    /// Copies the unmerged remainder `start..end` of the buffered run into the gap
    /// that starts at `dst`. The remainder may be empty.
    fn drop(&mut self) {
        // SAFETY: The user of MergeState MUST ensure, that at any point this drop
        // impl MAY run, for example when the user provided `is_less` panics, that
        // copying the contiguous region between `start` and `end` to `dst` will
        // leave the input slice `v` with each original element and all possible
        // modifications observed.
        unsafe {
            let len = self.end.offset_from_unsigned(self.start);
            ptr::copy_nonoverlapping(self.start, self.dst, len);
        }
    }
}