Skip to main content

rustc_index/
interval.rs

1use std::iter::Step;
2use std::marker::PhantomData;
3use std::ops::{Bound, Range, RangeBounds};
4
5use smallvec::SmallVec;
6
7use crate::idx::Idx;
8use crate::vec::IndexVec;
9
10#[cfg(test)]
11mod tests;
12
13/// Stores a set of intervals on the indices.
14///
15/// The elements in `map` are sorted and non-adjacent, which means
16/// the second value of the previous element is *greater* than the
17/// first value of the following element.
18#[derive(#[automatically_derived]
impl<I: ::core::fmt::Debug> ::core::fmt::Debug for IntervalSet<I> {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field3_finish(f, "IntervalSet",
            "map", &self.map, "domain", &self.domain, "_data", &&self._data)
    }
}Debug, #[automatically_derived]
impl<I: ::core::clone::Clone> ::core::clone::Clone for IntervalSet<I> {
    #[inline]
    fn clone(&self) -> IntervalSet<I> {
        IntervalSet {
            map: ::core::clone::Clone::clone(&self.map),
            domain: ::core::clone::Clone::clone(&self.domain),
            _data: ::core::clone::Clone::clone(&self._data),
        }
    }
}Clone)]
19pub struct IntervalSet<I> {
20    // Start, end (both inclusive)
21    map: SmallVec<[(u32, u32); 2]>,
22    domain: usize,
23    _data: PhantomData<I>,
24}
25
26#[inline]
27fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
28    match range.start_bound() {
29        Bound::Included(start) => start.index() as u32,
30        Bound::Excluded(start) => start.index() as u32 + 1,
31        Bound::Unbounded => 0,
32    }
33}
34
35#[inline]
36fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
37    let end = match range.end_bound() {
38        Bound::Included(end) => end.index() as u32,
39        Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
40        Bound::Unbounded => domain.checked_sub(1)? as u32,
41    };
42    Some(end)
43}
44
45impl<I: Idx> IntervalSet<I> {
46    pub fn new(domain: usize) -> IntervalSet<I> {
47        IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
48    }
49
50    pub fn clear(&mut self) {
51        self.map.clear();
52    }
53
54    pub fn iter(&self) -> impl Iterator<Item = I>
55    where
56        I: Step,
57    {
58        self.iter_intervals().flatten()
59    }
60
61    /// Iterates through intervals stored in the set, in order.
62    pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>>
63    where
64        I: Step,
65    {
66        self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
67    }
68
69    /// Returns true if we increased the number of elements present.
70    pub fn insert(&mut self, point: I) -> bool {
71        self.insert_range(point..=point)
72    }
73
74    /// Returns true if we increased the number of elements present.
75    pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
76        let start = inclusive_start(range.clone());
77        let Some(end) = inclusive_end(self.domain, range) else {
78            // empty range
79            return false;
80        };
81        if start > end {
82            return false;
83        }
84
85        // This condition looks a bit weird, but actually makes sense.
86        //
87        // if r.0 == end + 1, then we're actually adjacent, so we want to
88        // continue to the next range. We're looking here for the first
89        // range which starts *non-adjacently* to our end.
90        let next = self.map.partition_point(|r| r.0 <= end + 1);
91        let result = if let Some(right) = next.checked_sub(1) {
92            let (prev_start, prev_end) = self.map[right];
93            if prev_end + 1 >= start {
94                // If the start for the inserted range is adjacent to the
95                // end of the previous, we can extend the previous range.
96                if start < prev_start {
97                    // The first range which ends *non-adjacently* to our start.
98                    // And we can ensure that left <= right.
99                    let left = self.map.partition_point(|l| l.1 + 1 < start);
100                    let min = std::cmp::min(self.map[left].0, start);
101                    let max = std::cmp::max(prev_end, end);
102                    self.map[right] = (min, max);
103                    if left != right {
104                        self.map.drain(left..right);
105                    }
106                    true
107                } else {
108                    // We overlap with the previous range, increase it to
109                    // include us.
110                    //
111                    // Make sure we're actually going to *increase* it though --
112                    // it may be that end is just inside the previously existing
113                    // set.
114                    if end > prev_end {
115                        self.map[right].1 = end;
116                        true
117                    } else {
118                        false
119                    }
120                }
121            } else {
122                // Otherwise, we don't overlap, so just insert
123                self.map.insert(right + 1, (start, end));
124                true
125            }
126        } else {
127            if self.map.is_empty() {
128                // Quite common in practice, and expensive to call memcpy
129                // with length zero.
130                self.map.push((start, end));
131            } else {
132                self.map.insert(next, (start, end));
133            }
134            true
135        };
136        if true {
    if !self.check_invariants() {
        {
            ::core::panicking::panic_fmt(format_args!("wrong intervals after insert {0:?}..={1:?} to {2:?}",
                    start, end, self));
        }
    };
};debug_assert!(
137            self.check_invariants(),
138            "wrong intervals after insert {start:?}..={end:?} to {self:?}"
139        );
140        result
141    }
142
143    /// Specialized version of `insert` when we know that the inserted point is *after* any
144    /// contained.
145    pub fn append(&mut self, point: I) {
146        let point = point.index() as u32;
147
148        if let Some((_, last_end)) = self.map.last_mut() {
149            if !(*last_end <= point) {
    ::core::panicking::panic("assertion failed: *last_end <= point")
};assert!(*last_end <= point);
150            if point == *last_end {
151                // The point is already in the set.
152            } else if point == *last_end + 1 {
153                *last_end = point;
154            } else {
155                self.map.push((point, point));
156            }
157        } else {
158            self.map.push((point, point));
159        }
160
161        if true {
    if !self.check_invariants() {
        {
            ::core::panicking::panic_fmt(format_args!("wrong intervals after append {0:?} to {1:?}",
                    point, self));
        }
    };
};debug_assert!(
162            self.check_invariants(),
163            "wrong intervals after append {point:?} to {self:?}"
164        );
165    }
166
167    /// Specialized version of `insert_range` when we know that the inserted point is *after* any
168    /// contained.
169    pub fn append_range(&mut self, range: impl RangeBounds<I> + Clone) {
170        let start = inclusive_start(range.clone());
171        let Some(end) = inclusive_end(self.domain, range) else {
172            // empty range
173            return;
174        };
175        if start > end {
176            return;
177        }
178
179        if let Some((_, last_end)) = self.map.last_mut() {
180            if !(*last_end <= start) {
    ::core::panicking::panic("assertion failed: *last_end <= start")
};assert!(*last_end <= start);
181            // The start is already adjacent to the set.
182            if start <= *last_end + 1 {
183                *last_end = end;
184            } else {
185                self.map.push((start, end));
186            }
187        } else {
188            self.map.push((start, end));
189        }
190
191        if true {
    if !self.check_invariants() {
        {
            ::core::panicking::panic_fmt(format_args!("wrong intervals after append {0:?}..={1:?} to {2:?}",
                    start, end, self));
        }
    };
};debug_assert!(
192            self.check_invariants(),
193            "wrong intervals after append {start:?}..={end:?} to {self:?}"
194        );
195    }
196
197    pub fn contains(&self, needle: I) -> bool {
198        let needle = needle.index() as u32;
199        let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
200            // All ranges in the map start after the new range's end
201            return false;
202        };
203        let (_, prev_end) = &self.map[last];
204        needle <= *prev_end
205    }
206
207    pub fn superset(&self, other: &IntervalSet<I>) -> bool
208    where
209        I: Step,
210    {
211        let mut sup_iter = self.iter_intervals();
212        let mut current = None;
213        let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
214            if sup.end < sub.start {
215                // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
216                None // continue to the next sup
217            } else if sup.end >= sub.end && sup.start <= sub.start {
218                *current = Some(sup); // save the current sup
219                Some(true)
220            } else {
221                Some(false)
222            }
223        };
224        other.iter_intervals().all(|sub| {
225            current
226                .take()
227                .and_then(|sup| contains(sup, sub.clone(), &mut current))
228                .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
229                .unwrap_or(false)
230        })
231    }
232
233    pub fn disjoint(&self, other: &IntervalSet<I>) -> bool
234    where
235        I: Step,
236    {
237        let helper = move || {
238            let mut self_iter = self.iter_intervals();
239            let mut other_iter = other.iter_intervals();
240
241            let mut self_current = self_iter.next()?;
242            let mut other_current = other_iter.next()?;
243
244            loop {
245                if self_current.end <= other_current.start {
246                    self_current = self_iter.next()?;
247                    continue;
248                }
249                if other_current.end <= self_current.start {
250                    other_current = other_iter.next()?;
251                    continue;
252                }
253                return Some(false);
254            }
255        };
256        helper().unwrap_or(true)
257    }
258
259    pub fn is_empty(&self) -> bool {
260        self.map.is_empty()
261    }
262
263    /// Equivalent to `range.iter().find(|i| !self.contains(i))`.
264    pub fn first_unset_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
265        let start = inclusive_start(range.clone());
266        let Some(end) = inclusive_end(self.domain, range) else {
267            // empty range
268            return None;
269        };
270        if start > end {
271            return None;
272        }
273        let Some(last) = self.map.partition_point(|r| r.0 <= start).checked_sub(1) else {
274            // All ranges in the map start after the new range's end
275            return Some(I::new(start as usize));
276        };
277        let (_, prev_end) = self.map[last];
278        if start > prev_end {
279            Some(I::new(start as usize))
280        } else if prev_end < end {
281            Some(I::new(prev_end as usize + 1))
282        } else {
283            None
284        }
285    }
286
287    /// Returns the maximum (last) element present in the set from `range`.
288    pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
289        let start = inclusive_start(range.clone());
290        let Some(end) = inclusive_end(self.domain, range) else {
291            // empty range
292            return None;
293        };
294        if start > end {
295            return None;
296        }
297        let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
298            // All ranges in the map start after the new range's end
299            return None;
300        };
301        let (_, prev_end) = &self.map[last];
302        if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
303    }
304
305    pub fn insert_all(&mut self) {
306        self.clear();
307        if let Some(end) = self.domain.checked_sub(1) {
308            self.map.push((0, end.try_into().unwrap()));
309        }
310        if true {
    if !self.check_invariants() {
        ::core::panicking::panic("assertion failed: self.check_invariants()")
    };
};debug_assert!(self.check_invariants());
311    }
312
313    pub fn union(&mut self, other: &IntervalSet<I>) -> bool
314    where
315        I: Step,
316    {
317        match (&self.domain, &other.domain) {
    (left_val, right_val) => {
        if !(*left_val == *right_val) {
            let kind = ::core::panicking::AssertKind::Eq;
            ::core::panicking::assert_failed(kind, &*left_val, &*right_val,
                ::core::option::Option::None);
        }
    }
};assert_eq!(self.domain, other.domain);
318        if self.map.len() < other.map.len() {
319            let backup = self.clone();
320            self.map.clone_from(&other.map);
321            return self.union(&backup);
322        }
323
324        let mut did_insert = false;
325        for range in other.iter_intervals() {
326            did_insert |= self.insert_range(range);
327        }
328        if true {
    if !self.check_invariants() {
        ::core::panicking::panic("assertion failed: self.check_invariants()")
    };
};debug_assert!(self.check_invariants());
329        did_insert
330    }
331
332    // Check the intervals are valid, sorted and non-adjacent
333    fn check_invariants(&self) -> bool {
334        let mut current: Option<u32> = None;
335        for (start, end) in &self.map {
336            if start > end || current.is_some_and(|x| x + 1 >= *start) {
337                return false;
338            }
339            current = Some(*end);
340        }
341        current.is_none_or(|x| x < self.domain as u32)
342    }
343}
344
345/// This data structure optimizes for cases where the stored bits in each row
346/// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast
347/// to BitMatrix and SparseBitMatrix which are optimized for
348/// "random"/non-contiguous bits and cheap(er) point queries at the expense of
349/// memory usage.
350#[derive(#[automatically_derived]
impl<R: ::core::clone::Clone, C: ::core::clone::Clone> ::core::clone::Clone
    for SparseIntervalMatrix<R, C> where R: Idx, C: Idx {
    #[inline]
    fn clone(&self) -> SparseIntervalMatrix<R, C> {
        SparseIntervalMatrix {
            rows: ::core::clone::Clone::clone(&self.rows),
            column_size: ::core::clone::Clone::clone(&self.column_size),
        }
    }
}Clone)]
351pub struct SparseIntervalMatrix<R, C>
352where
353    R: Idx,
354    C: Idx,
355{
356    rows: IndexVec<R, IntervalSet<C>>,
357    column_size: usize,
358}
359
360impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
361    pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
362        SparseIntervalMatrix { rows: IndexVec::new(), column_size }
363    }
364
365    pub fn rows(&self) -> impl Iterator<Item = R> {
366        self.rows.indices()
367    }
368
369    pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
370        self.rows.get(row)
371    }
372
373    fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
374        self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size))
375    }
376
377    pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
378    where
379        C: Step,
380    {
381        self.ensure_row(row).union(from)
382    }
383
384    pub fn union_rows(&mut self, read: R, write: R) -> bool
385    where
386        C: Step,
387    {
388        if read == write || self.rows.get(read).is_none() {
389            return false;
390        }
391        self.ensure_row(write);
392        let (read_row, write_row) = self.rows.pick2_mut(read, write);
393        write_row.union(read_row)
394    }
395
396    pub fn insert_all_into_row(&mut self, row: R) {
397        self.ensure_row(row).insert_all();
398    }
399
400    pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
401        self.ensure_row(row).insert_range(range);
402    }
403
404    pub fn insert(&mut self, row: R, point: C) -> bool {
405        self.ensure_row(row).insert(point)
406    }
407
408    pub fn append(&mut self, row: R, point: C) {
409        self.ensure_row(row).append(point)
410    }
411
412    pub fn append_range(&mut self, row: R, point: impl RangeBounds<C> + Clone) {
413        self.ensure_row(row).append_range(point)
414    }
415
416    pub fn contains(&self, row: R, point: C) -> bool {
417        self.row(row).is_some_and(|r| r.contains(point))
418    }
419}