1use core::borrow::Borrow;
2use core::cmp::Ordering;
3use core::ops::{Bound, RangeBounds};
4
5use SearchBound::*;
6use SearchResult::*;
7
8use super::node::ForceResult::*;
9use super::node::{Handle, NodeRef, marker};
10
11pub(super) enum SearchBound<T> {
12 Included(T),
14 Excluded(T),
16 AllIncluded,
18 AllExcluded,
20}
21
22impl<T> SearchBound<T> {
23 pub(super) fn from_range(range_bound: Bound<T>) -> Self {
24 match range_bound {
25 Bound::Included(t) => Included(t),
26 Bound::Excluded(t) => Excluded(t),
27 Bound::Unbounded => AllIncluded,
28 }
29 }
30}
31
32pub(super) enum SearchResult<BorrowType, K, V, FoundType, GoDownType> {
33 Found(Handle<NodeRef<BorrowType, K, V, FoundType>, marker::KV>),
34 GoDown(Handle<NodeRef<BorrowType, K, V, GoDownType>, marker::Edge>),
35}
36
37pub(super) enum IndexResult {
38 KV(usize),
39 Edge(usize),
40}
41
impl<BorrowType: marker::BorrowType, K, V> NodeRef<BorrowType, K, V, marker::LeafOrInternal> {
    /// Looks up `key` in the tree rooted at this node, descending one level at a time.
    ///
    /// Returns `Found` with the handle of the key-value pair whose key equals `key`
    /// (in whichever node it was found), or `GoDown` with the leaf edge where the search
    /// ended: the edge between the neighbouring keys where `key` belongs.
    ///
    /// Comparisons use `Q: Ord`, with stored keys viewed through `K: Borrow<Q>`.
    pub(super) fn search_tree<Q: ?Sized>(
        mut self,
        key: &Q,
    ) -> SearchResult<BorrowType, K, V, marker::LeafOrInternal, marker::Leaf>
    where
        Q: Ord,
        K: Borrow<Q>,
    {
        loop {
            self = match self.search_node(key) {
                Found(handle) => return Found(handle),
                GoDown(handle) => match handle.force() {
                    // No match and nowhere further to descend: report the leaf edge.
                    Leaf(leaf) => return GoDown(leaf),
                    // Keep searching in the child subtree below this edge.
                    Internal(internal) => internal.descend(),
                },
            }
        }
    }

    /// Descends from this node to the first node in which the lower and upper bounds of
    /// `range` select different edges, i.e. where the range "bifurcates".
    ///
    /// On success, returns `Ok((node, lower_edge_idx, upper_edge_idx, lower_child_bound,
    /// upper_child_bound))` with `lower_edge_idx < upper_edge_idx`. The two child bounds
    /// are the ones to use when continuing the lower and upper searches in the subtrees
    /// below those edges.
    ///
    /// If both bounds keep selecting the same edge all the way down to a leaf, returns
    /// that leaf edge as `Err`; in that case no key of the tree lies within `range`.
    ///
    /// # Panics
    ///
    /// Panics if the range start is greater than the range end, or if start and end are
    /// equal and both excluded. The message names `BTreeSet` or `BTreeMap` according to
    /// `IsSetVal::is_set_val()` for `V`.
    pub(super) fn search_tree_for_bifurcation<'r, Q: ?Sized, R>(
        mut self,
        range: &'r R,
    ) -> Result<
        (
            NodeRef<BorrowType, K, V, marker::LeafOrInternal>,
            usize,
            usize,
            SearchBound<&'r Q>,
            SearchBound<&'r Q>,
        ),
        Handle<NodeRef<BorrowType, K, V, marker::Leaf>, marker::Edge>,
    >
    where
        Q: Ord,
        K: Borrow<Q>,
        R: RangeBounds<Q>,
    {
        // Only used to name the right collection in the panic messages below.
        let is_set = <V as super::set_val::IsSetVal>::is_set_val();

        // Query each bound exactly once and reuse the results. `R` is an arbitrary
        // `RangeBounds` implementation, so repeated calls are not guaranteed to agree,
        // and the validation below must check the very bounds the search then uses.
        let (start, end) = (range.start_bound(), range.end_bound());
        // Reject ranges that can never be satisfied.
        match (start, end) {
            (Bound::Excluded(s), Bound::Excluded(e)) if s == e => {
                if is_set {
                    panic!("range start and end are equal and excluded in BTreeSet")
                } else {
                    panic!("range start and end are equal and excluded in BTreeMap")
                }
            }
            (Bound::Included(s) | Bound::Excluded(s), Bound::Included(e) | Bound::Excluded(e))
                if s > e =>
            {
                if is_set {
                    panic!("range start is greater than range end in BTreeSet")
                } else {
                    panic!("range start is greater than range end in BTreeMap")
                }
            }
            _ => {}
        }
        let mut lower_bound = SearchBound::from_range(start);
        let mut upper_bound = SearchBound::from_range(end);
        loop {
            let (lower_edge_idx, lower_child_bound) = self.find_lower_bound_index(lower_bound);
            // SAFETY: `find_lower_bound_index` only returns edge indices in `0..=self.len()`,
            // so `lower_edge_idx` is a valid start index. Because the upper search starts
            // there, `upper_edge_idx >= lower_edge_idx` holds whatever the key ordering does.
            let (upper_edge_idx, upper_child_bound) =
                unsafe { self.find_upper_bound_index(upper_bound, lower_edge_idx) };
            if lower_edge_idx < upper_edge_idx {
                // The bounds diverge in this node: this is the bifurcation point.
                return Ok((
                    self,
                    lower_edge_idx,
                    upper_edge_idx,
                    lower_child_bound,
                    upper_child_bound,
                ));
            }
            // Both bounds select the same edge, so the whole range lies below it.
            debug_assert_eq!(lower_edge_idx, upper_edge_idx);
            // SAFETY: `lower_edge_idx <= self.len()`, as established above.
            let common_edge = unsafe { Handle::new_edge(self, lower_edge_idx) };
            match common_edge.force() {
                // Reached a leaf without diverging: the range selects no key.
                Leaf(common_edge) => return Err(common_edge),
                Internal(common_edge) => {
                    // Continue in the shared child with the bounds narrowed for it.
                    self = common_edge.descend();
                    lower_bound = lower_child_bound;
                    upper_bound = upper_child_bound;
                }
            }
        }
    }

    /// Finds the edge of this node at which a range with lower bound `bound` begins.
    ///
    /// Returns that edge together with the bound to apply when continuing the search in
    /// the child below it, if there is one.
    pub(super) fn find_lower_bound_edge<'r, Q>(
        self,
        bound: SearchBound<&'r Q>,
    ) -> (Handle<Self, marker::Edge>, SearchBound<&'r Q>)
    where
        Q: ?Sized + Ord,
        K: Borrow<Q>,
    {
        let (edge_idx, bound) = self.find_lower_bound_index(bound);
        // SAFETY: `find_lower_bound_index` only returns edge indices in `0..=self.len()`.
        let edge = unsafe { Handle::new_edge(self, edge_idx) };
        (edge, bound)
    }

    /// Finds the edge of this node at which a range with upper bound `bound` ends.
    ///
    /// Returns that edge together with the bound to apply when continuing the search in
    /// the child below it, if there is one.
    pub(super) fn find_upper_bound_edge<'r, Q>(
        self,
        bound: SearchBound<&'r Q>,
    ) -> (Handle<Self, marker::Edge>, SearchBound<&'r Q>)
    where
        Q: ?Sized + Ord,
        K: Borrow<Q>,
    {
        // SAFETY: a start index of 0 is always valid.
        let (edge_idx, bound) = unsafe { self.find_upper_bound_index(bound, 0) };
        // SAFETY: `find_upper_bound_index` only returns edge indices in `0..=self.len()`.
        let edge = unsafe { Handle::new_edge(self, edge_idx) };
        (edge, bound)
    }
}
186
187impl<BorrowType, K, V, Type> NodeRef<BorrowType, K, V, Type> {
188 pub(super) fn search_node<Q: ?Sized>(
196 self,
197 key: &Q,
198 ) -> SearchResult<BorrowType, K, V, Type, Type>
199 where
200 Q: Ord,
201 K: Borrow<Q>,
202 {
203 match unsafe { self.find_key_index(key, 0) } {
204 IndexResult::KV(idx) => Found(unsafe { Handle::new_kv(self, idx) }),
205 IndexResult::Edge(idx) => GoDown(unsafe { Handle::new_edge(self, idx) }),
206 }
207 }
208
209 unsafe fn find_key_index<Q: ?Sized>(&self, key: &Q, start_index: usize) -> IndexResult
218 where
219 Q: Ord,
220 K: Borrow<Q>,
221 {
222 let node = self.reborrow();
223 let keys = node.keys();
224 debug_assert!(start_index <= keys.len());
225 for (offset, k) in unsafe { keys.get_unchecked(start_index..) }.iter().enumerate() {
226 match key.cmp(k.borrow()) {
227 Ordering::Greater => {}
228 Ordering::Equal => return IndexResult::KV(start_index + offset),
229 Ordering::Less => return IndexResult::Edge(start_index + offset),
230 }
231 }
232 IndexResult::Edge(keys.len())
233 }
234
235 fn find_lower_bound_index<'r, Q>(
241 &self,
242 bound: SearchBound<&'r Q>,
243 ) -> (usize, SearchBound<&'r Q>)
244 where
245 Q: ?Sized + Ord,
246 K: Borrow<Q>,
247 {
248 match bound {
249 Included(key) => match unsafe { self.find_key_index(key, 0) } {
250 IndexResult::KV(idx) => (idx, AllExcluded),
251 IndexResult::Edge(idx) => (idx, bound),
252 },
253 Excluded(key) => match unsafe { self.find_key_index(key, 0) } {
254 IndexResult::KV(idx) => (idx + 1, AllIncluded),
255 IndexResult::Edge(idx) => (idx, bound),
256 },
257 AllIncluded => (0, AllIncluded),
258 AllExcluded => (self.len(), AllExcluded),
259 }
260 }
261
262 unsafe fn find_upper_bound_index<'r, Q>(
268 &self,
269 bound: SearchBound<&'r Q>,
270 start_index: usize,
271 ) -> (usize, SearchBound<&'r Q>)
272 where
273 Q: ?Sized + Ord,
274 K: Borrow<Q>,
275 {
276 match bound {
277 Included(key) => match unsafe { self.find_key_index(key, start_index) } {
278 IndexResult::KV(idx) => (idx + 1, AllExcluded),
279 IndexResult::Edge(idx) => (idx, bound),
280 },
281 Excluded(key) => match unsafe { self.find_key_index(key, start_index) } {
282 IndexResult::KV(idx) => (idx, AllIncluded),
283 IndexResult::Edge(idx) => (idx, bound),
284 },
285 AllIncluded => (self.len(), AllIncluded),
286 AllExcluded => (start_index, AllExcluded),
287 }
288 }
289}