rustc_index/bit_set.rs

1use std::marker::PhantomData;
2#[cfg(not(feature = "nightly"))]
3use std::mem;
4use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Bound, Not, Range, RangeBounds, Shl};
5use std::rc::Rc;
6use std::{fmt, iter, slice};
7
8use Chunk::*;
9#[cfg(feature = "nightly")]
10use rustc_macros::{Decodable_NoContext, Encodable_NoContext};
11
12use crate::{Idx, IndexVec};
13
14#[cfg(test)]
15mod tests;
16
17type Word = u64;
18const WORD_BYTES: usize = size_of::<Word>();
19const WORD_BITS: usize = WORD_BYTES * 8;
20
21// The choice of chunk size has some trade-offs.
22//
23// A big chunk size tends to favour cases where many large `ChunkedBitSet`s are
24// present, because they require fewer `Chunk`s, reducing the number of
25// allocations and reducing peak memory usage. Also, fewer chunk operations are
26// required, though more of them might be `Mixed`.
27//
28// A small chunk size tends to favour cases where many small `ChunkedBitSet`s
29// are present, because less space is wasted at the end of the final chunk (if
30// it's not full).
31const CHUNK_WORDS: usize = 32;
32const CHUNK_BITS: usize = CHUNK_WORDS * WORD_BITS; // 2048 bits
33
34/// ChunkSize is small to keep `Chunk` small. The static assertion ensures it's
35/// not too small.
36type ChunkSize = u16;
37const _: () = assert!(CHUNK_BITS <= ChunkSize::MAX as usize);
38
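/// Bitwise relations between two bitsets. Each operation writes its result
/// into `self` and returns `true` if `self` changed.
///
/// A minimal usage sketch (illustrative only; `usize` as the index type),
/// going through the inherent wrappers generated by
/// `bit_relations_inherent_impls!` below:
///
/// ```
/// # use rustc_index::bit_set::DenseBitSet;
/// let mut a: DenseBitSet<usize> = DenseBitSet::new_empty(8);
/// let mut b: DenseBitSet<usize> = DenseBitSet::new_empty(8);
/// a.insert(1);
/// b.insert(1);
/// b.insert(2);
/// assert!(a.union(&b));    // `a` is now {1, 2}
/// assert!(!a.union(&b));   // no change the second time
/// assert!(a.subtract(&b)); // `a` is now empty
/// ```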
39pub trait BitRelations<Rhs> {
40    fn union(&mut self, other: &Rhs) -> bool;
41    fn subtract(&mut self, other: &Rhs) -> bool;
42    fn intersect(&mut self, other: &Rhs) -> bool;
43}
44
45#[inline]
46fn inclusive_start_end<T: Idx>(
47    range: impl RangeBounds<T>,
48    domain: usize,
49) -> Option<(usize, usize)> {
50    // Both start and end are inclusive.
51    let start = match range.start_bound().cloned() {
52        Bound::Included(start) => start.index(),
53        Bound::Excluded(start) => start.index() + 1,
54        Bound::Unbounded => 0,
55    };
56    let end = match range.end_bound().cloned() {
57        Bound::Included(end) => end.index(),
58        Bound::Excluded(end) => end.index().checked_sub(1)?,
59        Bound::Unbounded => domain - 1,
60    };
61    assert!(end < domain);
62    if start > end {
63        return None;
64    }
65    Some((start, end))
66}
67
68macro_rules! bit_relations_inherent_impls {
69    () => {
70        /// Sets `self = self | other` and returns `true` if `self` changed
71        /// (i.e., if new bits were added).
72        pub fn union<Rhs>(&mut self, other: &Rhs) -> bool
73        where
74            Self: BitRelations<Rhs>,
75        {
76            <Self as BitRelations<Rhs>>::union(self, other)
77        }
78
        /// Sets `self = self - other` and returns `true` if `self` changed
        /// (i.e., if any bits were removed).
81        pub fn subtract<Rhs>(&mut self, other: &Rhs) -> bool
82        where
83            Self: BitRelations<Rhs>,
84        {
85            <Self as BitRelations<Rhs>>::subtract(self, other)
86        }
87
        /// Sets `self = self & other` and returns `true` if `self` changed
        /// (i.e., if any bits were removed).
90        pub fn intersect<Rhs>(&mut self, other: &Rhs) -> bool
91        where
92            Self: BitRelations<Rhs>,
93        {
94            <Self as BitRelations<Rhs>>::intersect(self, other)
95        }
96    };
97}
98
99/// A fixed-size bitset type with a dense representation.
100///
101/// Note 1: Since this bitset is dense, if your domain is big, and/or relatively
102/// homogeneous (for example, with long runs of bits set or unset), then it may
103/// be preferable to instead use a [MixedBitSet], or an
104/// [IntervalSet](crate::interval::IntervalSet). They should be more suited to
105/// sparse, or highly-compressible, domains.
106///
107/// Note 2: Use [`GrowableBitSet`] if you need support for resizing after creation.
108///
109/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
110/// just be `usize`.
111///
112/// All operations that involve an element will panic if the element is equal
113/// to or greater than the domain size. All operations that involve two bitsets
114/// will panic if the bitsets have differing domain sizes.
115///
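/// # Examples
///
/// A minimal sketch of typical use (illustrative only; `usize` as the index
/// type):
///
/// ```
/// # use rustc_index::bit_set::DenseBitSet;
/// let mut set: DenseBitSet<usize> = DenseBitSet::new_empty(128);
/// assert!(set.insert(30));
/// assert!(set.contains(30));
/// assert!(!set.insert(30)); // no change the second time
/// assert_eq!(set.iter().next(), Some(30));
/// assert!(set.remove(30));
/// assert!(set.is_empty());
/// ```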
116#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
117#[derive(Eq, PartialEq, Hash)]
118pub struct DenseBitSet<T> {
119    domain_size: usize,
120    words: Vec<Word>,
121    marker: PhantomData<T>,
122}
123
124impl<T> DenseBitSet<T> {
125    /// Gets the domain size.
126    pub fn domain_size(&self) -> usize {
127        self.domain_size
128    }
129}
130
131impl<T: Idx> DenseBitSet<T> {
132    /// Creates a new, empty bitset with a given `domain_size`.
133    #[inline]
134    pub fn new_empty(domain_size: usize) -> DenseBitSet<T> {
135        let num_words = num_words(domain_size);
136        DenseBitSet { domain_size, words: vec![0; num_words], marker: PhantomData }
137    }
138
139    /// Creates a new, filled bitset with a given `domain_size`.
140    #[inline]
141    pub fn new_filled(domain_size: usize) -> DenseBitSet<T> {
142        let num_words = num_words(domain_size);
143        let mut result =
144            DenseBitSet { domain_size, words: vec![!0; num_words], marker: PhantomData };
145        result.clear_excess_bits();
146        result
147    }
148
149    /// Clear all elements.
150    #[inline]
151    pub fn clear(&mut self) {
152        self.words.fill(0);
153    }
154
155    /// Clear excess bits in the final word.
156    fn clear_excess_bits(&mut self) {
157        clear_excess_bits_in_final_word(self.domain_size, &mut self.words);
158    }
159
160    /// Count the number of set bits in the set.
161    pub fn count(&self) -> usize {
162        count_ones(&self.words)
163    }
164
165    /// Returns `true` if `self` contains `elem`.
166    #[inline]
167    pub fn contains(&self, elem: T) -> bool {
168        assert!(elem.index() < self.domain_size);
169        let (word_index, mask) = word_index_and_mask(elem);
170        (self.words[word_index] & mask) != 0
171    }
172
    /// Is `self` a (non-strict) superset of `other`?
174    #[inline]
175    pub fn superset(&self, other: &DenseBitSet<T>) -> bool {
176        assert_eq!(self.domain_size, other.domain_size);
177        self.words.iter().zip(&other.words).all(|(a, b)| (a & b) == *b)
178    }
179
180    /// Is the set empty?
181    #[inline]
182    pub fn is_empty(&self) -> bool {
183        self.words.iter().all(|a| *a == 0)
184    }
185
186    /// Insert `elem`. Returns whether the set has changed.
187    #[inline]
188    pub fn insert(&mut self, elem: T) -> bool {
189        assert!(
190            elem.index() < self.domain_size,
191            "inserting element at index {} but domain size is {}",
192            elem.index(),
193            self.domain_size,
194        );
195        let (word_index, mask) = word_index_and_mask(elem);
196        let word_ref = &mut self.words[word_index];
197        let word = *word_ref;
198        let new_word = word | mask;
199        *word_ref = new_word;
200        new_word != word
201    }
202
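    /// Inserts all elements in the (possibly empty) range `elems`.
    ///
    /// A minimal sketch (illustrative only; `usize` as the index type):
    ///
    /// ```
    /// # use rustc_index::bit_set::DenseBitSet;
    /// let mut set: DenseBitSet<usize> = DenseBitSet::new_empty(100);
    /// set.insert_range(10..20);
    /// assert!(set.contains(10) && set.contains(19) && !set.contains(20));
    /// assert_eq!(set.count(), 10);
    /// ```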
203    #[inline]
204    pub fn insert_range(&mut self, elems: impl RangeBounds<T>) {
205        let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else {
206            return;
207        };
208
209        let (start_word_index, start_mask) = word_index_and_mask(start);
210        let (end_word_index, end_mask) = word_index_and_mask(end);
211
        // Set all words in between start and end (exclusive of both).
213        for word_index in (start_word_index + 1)..end_word_index {
214            self.words[word_index] = !0;
215        }
216
217        if start_word_index != end_word_index {
218            // Start and end are in different words, so we handle each in turn.
219            //
220            // We set all leading bits. This includes the start_mask bit.
221            self.words[start_word_index] |= !(start_mask - 1);
222            // And all trailing bits (i.e. from 0..=end) in the end word,
223            // including the end.
224            self.words[end_word_index] |= end_mask | (end_mask - 1);
225        } else {
226            self.words[start_word_index] |= end_mask | (end_mask - start_mask);
227        }
228    }
229
230    /// Sets all bits to true.
231    pub fn insert_all(&mut self) {
232        self.words.fill(!0);
233        self.clear_excess_bits();
234    }
235
236    /// Checks whether any bit in the given range is a 1.
237    #[inline]
238    pub fn contains_any(&self, elems: impl RangeBounds<T>) -> bool {
239        let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else {
240            return false;
241        };
242        let (start_word_index, start_mask) = word_index_and_mask(start);
243        let (end_word_index, end_mask) = word_index_and_mask(end);
244
245        if start_word_index == end_word_index {
246            self.words[start_word_index] & (end_mask | (end_mask - start_mask)) != 0
247        } else {
248            if self.words[start_word_index] & !(start_mask - 1) != 0 {
249                return true;
250            }
251
252            let remaining = start_word_index + 1..end_word_index;
253            if remaining.start <= remaining.end {
254                self.words[remaining].iter().any(|&w| w != 0)
255                    || self.words[end_word_index] & (end_mask | (end_mask - 1)) != 0
256            } else {
257                false
258            }
259        }
260    }
261
262    /// Returns `true` if the set has changed.
263    #[inline]
264    pub fn remove(&mut self, elem: T) -> bool {
265        assert!(elem.index() < self.domain_size);
266        let (word_index, mask) = word_index_and_mask(elem);
267        let word_ref = &mut self.words[word_index];
268        let word = *word_ref;
269        let new_word = word & !mask;
270        *word_ref = new_word;
271        new_word != word
272    }
273
274    /// Iterates over the indices of set bits in a sorted order.
275    #[inline]
276    pub fn iter(&self) -> BitIter<'_, T> {
277        BitIter::new(&self.words)
278    }
279
280    pub fn last_set_in(&self, range: impl RangeBounds<T>) -> Option<T> {
281        let (start, end) = inclusive_start_end(range, self.domain_size)?;
282        let (start_word_index, _) = word_index_and_mask(start);
283        let (end_word_index, end_mask) = word_index_and_mask(end);
284
285        let end_word = self.words[end_word_index] & (end_mask | (end_mask - 1));
286        if end_word != 0 {
287            let pos = max_bit(end_word) + WORD_BITS * end_word_index;
288            if start <= pos {
289                return Some(T::new(pos));
290            }
291        }
292
        // We exclude end_word_index from the range here, because we don't want
        // to limit ourselves to *just* the last word: the bits set in it may be
        // after `end`, so it may not work out.
296        if let Some(offset) =
297            self.words[start_word_index..end_word_index].iter().rposition(|&w| w != 0)
298        {
299            let word_idx = start_word_index + offset;
300            let start_word = self.words[word_idx];
301            let pos = max_bit(start_word) + WORD_BITS * word_idx;
302            if start <= pos {
303                return Some(T::new(pos));
304            }
305        }
306
307        None
308    }
309
310    bit_relations_inherent_impls! {}
311
312    /// Sets `self = self | !other`.
313    ///
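    /// A minimal sketch (illustrative only; `usize` as the index type):
    ///
    /// ```
    /// # use rustc_index::bit_set::DenseBitSet;
    /// let mut this: DenseBitSet<usize> = DenseBitSet::new_empty(8);
    /// let mut other: DenseBitSet<usize> = DenseBitSet::new_empty(8);
    /// other.insert(1);
    /// other.insert(2);
    /// // `this` becomes the complement of `other` within the domain.
    /// this.union_not(&other);
    /// assert_eq!(this.count(), 6);
    /// assert!(this.contains(0) && !this.contains(1));
    /// ```
    ///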
314    /// FIXME: Incorporate this into [`BitRelations`] and fill out
315    /// implementations for other bitset types, if needed.
316    pub fn union_not(&mut self, other: &DenseBitSet<T>) {
317        assert_eq!(self.domain_size, other.domain_size);
318
319        // FIXME(Zalathar): If we were to forcibly _set_ all excess bits before
320        // the bitwise update, and then clear them again afterwards, we could
321        // quickly and accurately detect whether the update changed anything.
322        // But that's only worth doing if there's an actual use-case.
323
324        bitwise(&mut self.words, &other.words, |a, b| a | !b);
325        // The bitwise update `a | !b` can result in the last word containing
326        // out-of-domain bits, so we need to clear them.
327        self.clear_excess_bits();
328    }
329}
330
331// dense REL dense
332impl<T: Idx> BitRelations<DenseBitSet<T>> for DenseBitSet<T> {
333    fn union(&mut self, other: &DenseBitSet<T>) -> bool {
334        assert_eq!(self.domain_size, other.domain_size);
335        bitwise(&mut self.words, &other.words, |a, b| a | b)
336    }
337
338    fn subtract(&mut self, other: &DenseBitSet<T>) -> bool {
339        assert_eq!(self.domain_size, other.domain_size);
340        bitwise(&mut self.words, &other.words, |a, b| a & !b)
341    }
342
343    fn intersect(&mut self, other: &DenseBitSet<T>) -> bool {
344        assert_eq!(self.domain_size, other.domain_size);
345        bitwise(&mut self.words, &other.words, |a, b| a & b)
346    }
347}
348
349impl<T: Idx> From<GrowableBitSet<T>> for DenseBitSet<T> {
350    fn from(bit_set: GrowableBitSet<T>) -> Self {
351        bit_set.bit_set
352    }
353}
354
355impl<T> Clone for DenseBitSet<T> {
356    fn clone(&self) -> Self {
357        DenseBitSet {
358            domain_size: self.domain_size,
359            words: self.words.clone(),
360            marker: PhantomData,
361        }
362    }
363
364    fn clone_from(&mut self, from: &Self) {
365        self.domain_size = from.domain_size;
366        self.words.clone_from(&from.words);
367    }
368}
369
370impl<T: Idx> fmt::Debug for DenseBitSet<T> {
371    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
372        w.debug_list().entries(self.iter()).finish()
373    }
374}
375
376impl<T: Idx> ToString for DenseBitSet<T> {
377    fn to_string(&self) -> String {
378        let mut result = String::new();
379        let mut sep = '[';
380
381        // Note: this is a little endian printout of bytes.
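        // For example (a worked illustration): a set with domain size 20 and
        // bits 0, 9 and 19 set prints as "[01-02-08]".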
382
383        // i tracks how many bits we have printed so far.
384        let mut i = 0;
385        for word in &self.words {
386            let mut word = *word;
387            for _ in 0..WORD_BYTES {
388                // for each byte in `word`:
389                let remain = self.domain_size - i;
390                // If less than a byte remains, then mask just that many bits.
391                let mask = if remain <= 8 { (1 << remain) - 1 } else { 0xFF };
392                assert!(mask <= 0xFF);
393                let byte = word & mask;
394
395                result.push_str(&format!("{sep}{byte:02x}"));
396
397                if remain <= 8 {
398                    break;
399                }
400                word >>= 8;
401                i += 8;
402                sep = '-';
403            }
404            sep = '|';
405        }
406        result.push(']');
407
408        result
409    }
410}
411
412pub struct BitIter<'a, T: Idx> {
413    /// A copy of the current word, but with any already-visited bits cleared.
414    /// (This lets us use `trailing_zeros()` to find the next set bit.) When it
415    /// is reduced to 0, we move onto the next word.
416    word: Word,
417
418    /// The offset (measured in bits) of the current word.
419    offset: usize,
420
421    /// Underlying iterator over the words.
422    iter: slice::Iter<'a, Word>,
423
424    marker: PhantomData<T>,
425}
426
427impl<'a, T: Idx> BitIter<'a, T> {
428    #[inline]
429    fn new(words: &'a [Word]) -> BitIter<'a, T> {
430        // We initialize `word` and `offset` to degenerate values. On the first
431        // call to `next()` we will fall through to getting the first word from
432        // `iter`, which sets `word` to the first word (if there is one) and
433        // `offset` to 0. Doing it this way saves us from having to maintain
434        // additional state about whether we have started.
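        // Concretely, with `WORD_BITS == 64` the initial offset is
        // `usize::MAX - 63`, so the first `wrapping_add(WORD_BITS)` in
        // `next()` wraps it around to exactly 0.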
435        BitIter {
436            word: 0,
437            offset: usize::MAX - (WORD_BITS - 1),
438            iter: words.iter(),
439            marker: PhantomData,
440        }
441    }
442}
443
444impl<'a, T: Idx> Iterator for BitIter<'a, T> {
445    type Item = T;
446    fn next(&mut self) -> Option<T> {
447        loop {
448            if self.word != 0 {
449                // Get the position of the next set bit in the current word,
450                // then clear the bit.
451                let bit_pos = self.word.trailing_zeros() as usize;
452                self.word ^= 1 << bit_pos;
453                return Some(T::new(bit_pos + self.offset));
454            }
455
456            // Move onto the next word. `wrapping_add()` is needed to handle
457            // the degenerate initial value given to `offset` in `new()`.
458            self.word = *self.iter.next()?;
459            self.offset = self.offset.wrapping_add(WORD_BITS);
460        }
461    }
462}
463
464/// A fixed-size bitset type with a partially dense, partially sparse
465/// representation. The bitset is broken into chunks, and chunks that are all
466/// zeros or all ones are represented and handled very efficiently.
467///
468/// This type is especially efficient for sets that typically have a large
469/// `domain_size` with significant stretches of all zeros or all ones, and also
470/// some stretches with lots of 0s and 1s mixed in a way that causes trouble
471/// for `IntervalSet`.
472///
473/// Best used via `MixedBitSet`, rather than directly, because `MixedBitSet`
474/// has better performance for small bitsets.
475///
476/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
477/// just be `usize`.
478///
479/// All operations that involve an element will panic if the element is equal
480/// to or greater than the domain size. All operations that involve two bitsets
481/// will panic if the bitsets have differing domain sizes.
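///
/// # Example
///
/// A minimal sketch (illustrative only; the domain size is chosen to span
/// several 2048-bit chunks, and `usize` is used as the index type):
///
/// ```
/// # use rustc_index::bit_set::ChunkedBitSet;
/// let mut set: ChunkedBitSet<usize> = ChunkedBitSet::new_empty(10_000);
/// assert!(set.insert(3));
/// assert!(set.insert(9_999));
/// assert_eq!(set.count(), 2);
/// assert!(set.contains(3) && !set.contains(4));
/// assert!(set.remove(3));
/// assert_eq!(set.count(), 1);
/// ```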
482#[derive(PartialEq, Eq)]
483pub struct ChunkedBitSet<T> {
484    domain_size: usize,
485
486    /// The chunks. Each one contains exactly CHUNK_BITS values, except the
487    /// last one which contains 1..=CHUNK_BITS values.
488    chunks: Box<[Chunk]>,
489
490    marker: PhantomData<T>,
491}
492
493// NOTE: The chunk size is computed on-the-fly on each manipulation of a chunk.
494// This avoids storing it, as it's almost always CHUNK_BITS except for the last one.
495#[derive(Clone, Debug, PartialEq, Eq)]
496enum Chunk {
497    /// A chunk that is all zeros; we don't represent the zeros explicitly.
498    Zeros,
499
500    /// A chunk that is all ones; we don't represent the ones explicitly.
501    Ones,
502
503    /// A chunk that has a mix of zeros and ones, which are represented
504    /// explicitly and densely. It never has all zeros or all ones.
505    ///
506    /// If this is the final chunk there may be excess, unused words. This
    /// turns out to be both simpler and faster than
508    /// allocating the minimum number of words, largely because we avoid having
509    /// to store the length, which would make this type larger. These excess
510    /// words are always zero, as are any excess bits in the final in-use word.
511    ///
512    /// The `ChunkSize` field is the count of 1s set in the chunk, and
513    /// must satisfy `0 < count < chunk_domain_size`.
514    ///
515    /// The words are within an `Rc` because it's surprisingly common to
516    /// duplicate an entire chunk, e.g. in `ChunkedBitSet::clone_from()`, or
517    /// when a `Mixed` chunk is union'd into a `Zeros` chunk. When we do need
518    /// to modify a chunk we use `Rc::make_mut`.
519    Mixed(ChunkSize, Rc<[Word; CHUNK_WORDS]>),
520}
521
522// This type is used a lot. Make sure it doesn't unintentionally get bigger.
523#[cfg(target_pointer_width = "64")]
524crate::static_assert_size!(Chunk, 16);
525
526impl<T> ChunkedBitSet<T> {
527    pub fn domain_size(&self) -> usize {
528        self.domain_size
529    }
530
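    /// The domain size of the final chunk: `domain_size % CHUNK_BITS`, or
    /// `CHUNK_BITS` if the domain size is an exact multiple of `CHUNK_BITS`.
    /// For example (illustrative numbers), a domain size of 5000 needs three
    /// chunks, the last of which covers `5000 - 2 * 2048 = 904` bits.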
531    #[inline]
532    fn last_chunk_size(&self) -> ChunkSize {
533        let n = self.domain_size % CHUNK_BITS;
534        if n == 0 { CHUNK_BITS as ChunkSize } else { n as ChunkSize }
535    }
536
    /// All the chunks have a chunk_domain_size of `CHUNK_BITS`, except possibly the final one.
538    #[inline]
539    fn chunk_domain_size(&self, chunk: usize) -> ChunkSize {
540        if chunk == self.chunks.len() - 1 {
541            self.last_chunk_size()
542        } else {
543            CHUNK_BITS as ChunkSize
544        }
545    }
546
547    #[cfg(test)]
548    fn assert_valid(&self) {
549        if self.domain_size == 0 {
550            assert!(self.chunks.is_empty());
551            return;
552        }
553
554        assert!((self.chunks.len() - 1) * CHUNK_BITS <= self.domain_size);
555        assert!(self.chunks.len() * CHUNK_BITS >= self.domain_size);
556        for (chunk_index, chunk) in self.chunks.iter().enumerate() {
557            let chunk_domain_size = self.chunk_domain_size(chunk_index);
558            chunk.assert_valid(chunk_domain_size);
559        }
560    }
561}
562
563impl<T: Idx> ChunkedBitSet<T> {
564    /// Creates a new bitset with a given `domain_size` and chunk kind.
565    fn new(domain_size: usize, is_empty: bool) -> Self {
566        let chunks = if domain_size == 0 {
567            Box::new([])
568        } else {
569            vec![if is_empty { Zeros } else { Ones }; num_chunks(domain_size)].into_boxed_slice()
570        };
571        ChunkedBitSet { domain_size, chunks, marker: PhantomData }
572    }
573
574    /// Creates a new, empty bitset with a given `domain_size`.
575    #[inline]
576    pub fn new_empty(domain_size: usize) -> Self {
577        ChunkedBitSet::new(domain_size, /* is_empty */ true)
578    }
579
580    /// Creates a new, filled bitset with a given `domain_size`.
581    #[inline]
582    pub fn new_filled(domain_size: usize) -> Self {
583        ChunkedBitSet::new(domain_size, /* is_empty */ false)
584    }
585
586    pub fn clear(&mut self) {
587        self.chunks.fill_with(|| Chunk::Zeros);
588    }
589
590    #[cfg(test)]
591    fn chunks(&self) -> &[Chunk] {
592        &self.chunks
593    }
594
595    /// Count the number of bits in the set.
596    pub fn count(&self) -> usize {
597        self.chunks
598            .iter()
599            .enumerate()
600            .map(|(index, chunk)| chunk.count(self.chunk_domain_size(index)))
601            .sum()
602    }
603
604    pub fn is_empty(&self) -> bool {
605        self.chunks.iter().all(|chunk| matches!(chunk, Zeros))
606    }
607
608    /// Returns `true` if `self` contains `elem`.
609    #[inline]
610    pub fn contains(&self, elem: T) -> bool {
611        assert!(elem.index() < self.domain_size);
612        let chunk = &self.chunks[chunk_index(elem)];
613        match &chunk {
614            Zeros => false,
615            Ones => true,
616            Mixed(_, words) => {
617                let (word_index, mask) = chunk_word_index_and_mask(elem);
618                (words[word_index] & mask) != 0
619            }
620        }
621    }
622
623    #[inline]
624    pub fn iter(&self) -> ChunkedBitIter<'_, T> {
625        ChunkedBitIter::new(self)
626    }
627
628    /// Insert `elem`. Returns whether the set has changed.
629    pub fn insert(&mut self, elem: T) -> bool {
630        assert!(elem.index() < self.domain_size);
631        let chunk_index = chunk_index(elem);
632        let chunk_domain_size = self.chunk_domain_size(chunk_index);
633        let chunk = &mut self.chunks[chunk_index];
634        match *chunk {
635            Zeros => {
636                if chunk_domain_size > 1 {
637                    #[cfg(feature = "nightly")]
638                    let mut words = {
639                        // We take some effort to avoid copying the words.
640                        let words = Rc::<[Word; CHUNK_WORDS]>::new_zeroed();
641                        // SAFETY: `words` can safely be all zeroes.
642                        unsafe { words.assume_init() }
643                    };
644                    #[cfg(not(feature = "nightly"))]
645                    let mut words = {
646                        // FIXME: unconditionally use `Rc::new_zeroed` once it is stable (#129396).
647                        let words = mem::MaybeUninit::<[Word; CHUNK_WORDS]>::zeroed();
648                        // SAFETY: `words` can safely be all zeroes.
649                        let words = unsafe { words.assume_init() };
650                        // Unfortunate possibly-large copy
651                        Rc::new(words)
652                    };
653                    let words_ref = Rc::get_mut(&mut words).unwrap();
654
655                    let (word_index, mask) = chunk_word_index_and_mask(elem);
656                    words_ref[word_index] |= mask;
657                    *chunk = Mixed(1, words);
658                } else {
659                    *chunk = Ones;
660                }
661                true
662            }
663            Ones => false,
664            Mixed(ref mut count, ref mut words) => {
665                // We skip all the work if the bit is already set.
666                let (word_index, mask) = chunk_word_index_and_mask(elem);
667                if (words[word_index] & mask) == 0 {
668                    *count += 1;
669                    if *count < chunk_domain_size {
670                        let words = Rc::make_mut(words);
671                        words[word_index] |= mask;
672                    } else {
673                        *chunk = Ones;
674                    }
675                    true
676                } else {
677                    false
678                }
679            }
680        }
681    }
682
683    /// Sets all bits to true.
684    pub fn insert_all(&mut self) {
685        self.chunks.fill_with(|| Chunk::Ones);
686    }
687
688    /// Returns `true` if the set has changed.
689    pub fn remove(&mut self, elem: T) -> bool {
690        assert!(elem.index() < self.domain_size);
691        let chunk_index = chunk_index(elem);
692        let chunk_domain_size = self.chunk_domain_size(chunk_index);
693        let chunk = &mut self.chunks[chunk_index];
694        match *chunk {
695            Zeros => false,
696            Ones => {
697                if chunk_domain_size > 1 {
698                    #[cfg(feature = "nightly")]
699                    let mut words = {
700                        // We take some effort to avoid copying the words.
701                        let words = Rc::<[Word; CHUNK_WORDS]>::new_zeroed();
702                        // SAFETY: `words` can safely be all zeroes.
703                        unsafe { words.assume_init() }
704                    };
705                    #[cfg(not(feature = "nightly"))]
706                    let mut words = {
707                        // FIXME: unconditionally use `Rc::new_zeroed` once it is stable (#129396).
708                        let words = mem::MaybeUninit::<[Word; CHUNK_WORDS]>::zeroed();
709                        // SAFETY: `words` can safely be all zeroes.
710                        let words = unsafe { words.assume_init() };
711                        // Unfortunate possibly-large copy
712                        Rc::new(words)
713                    };
714                    let words_ref = Rc::get_mut(&mut words).unwrap();
715
716                    // Set only the bits in use.
717                    let num_words = num_words(chunk_domain_size as usize);
718                    words_ref[..num_words].fill(!0);
719                    clear_excess_bits_in_final_word(
720                        chunk_domain_size as usize,
721                        &mut words_ref[..num_words],
722                    );
723                    let (word_index, mask) = chunk_word_index_and_mask(elem);
724                    words_ref[word_index] &= !mask;
725                    *chunk = Mixed(chunk_domain_size - 1, words);
726                } else {
727                    *chunk = Zeros;
728                }
729                true
730            }
731            Mixed(ref mut count, ref mut words) => {
732                // We skip all the work if the bit is already clear.
733                let (word_index, mask) = chunk_word_index_and_mask(elem);
734                if (words[word_index] & mask) != 0 {
735                    *count -= 1;
736                    if *count > 0 {
737                        let words = Rc::make_mut(words);
738                        words[word_index] &= !mask;
739                    } else {
740                        *chunk = Zeros
741                    }
742                    true
743                } else {
744                    false
745                }
746            }
747        }
748    }
749
750    fn chunk_iter(&self, chunk_index: usize) -> ChunkIter<'_> {
751        let chunk_domain_size = self.chunk_domain_size(chunk_index);
752        match self.chunks.get(chunk_index) {
753            Some(Zeros) => ChunkIter::Zeros,
754            Some(Ones) => ChunkIter::Ones(0..chunk_domain_size as usize),
755            Some(Mixed(_, words)) => {
756                let num_words = num_words(chunk_domain_size as usize);
757                ChunkIter::Mixed(BitIter::new(&words[0..num_words]))
758            }
759            None => ChunkIter::Finished,
760        }
761    }
762
763    bit_relations_inherent_impls! {}
764}
765
766impl<T: Idx> BitRelations<ChunkedBitSet<T>> for ChunkedBitSet<T> {
767    fn union(&mut self, other: &ChunkedBitSet<T>) -> bool {
768        assert_eq!(self.domain_size, other.domain_size);
769
770        let num_chunks = self.chunks.len();
771        debug_assert_eq!(num_chunks, other.chunks.len());
772
773        let last_chunk_size = self.last_chunk_size();
774        debug_assert_eq!(last_chunk_size, other.last_chunk_size());
775
776        let mut changed = false;
777        for (chunk_index, (mut self_chunk, other_chunk)) in
778            self.chunks.iter_mut().zip(other.chunks.iter()).enumerate()
779        {
780            let chunk_domain_size = if chunk_index + 1 == num_chunks {
781                last_chunk_size
782            } else {
783                CHUNK_BITS as ChunkSize
784            };
785
786            match (&mut self_chunk, &other_chunk) {
787                (_, Zeros) | (Ones, _) => {}
788                (Zeros, _) | (Mixed(..), Ones) => {
789                    // `other_chunk` fully overwrites `self_chunk`
790                    *self_chunk = other_chunk.clone();
791                    changed = true;
792                }
793                (
794                    Mixed(self_chunk_count, self_chunk_words),
795                    Mixed(_other_chunk_count, other_chunk_words),
796                ) => {
797                    // First check if the operation would change
798                    // `self_chunk.words`. If not, we can avoid allocating some
799                    // words, and this happens often enough that it's a
800                    // performance win. Also, we only need to operate on the
801                    // in-use words, hence the slicing.
802                    let num_words = num_words(chunk_domain_size as usize);
803
804                    // If both sides are the same, nothing will change. This
805                    // case is very common and it's a pretty fast check, so
806                    // it's a performance win to do it.
807                    if self_chunk_words[0..num_words] == other_chunk_words[0..num_words] {
808                        continue;
809                    }
810
811                    // Do a more precise "will anything change?" test. Also a
812                    // performance win.
813                    let op = |a, b| a | b;
814                    if !bitwise_changes(
815                        &self_chunk_words[0..num_words],
816                        &other_chunk_words[0..num_words],
817                        op,
818                    ) {
819                        continue;
820                    }
821
822                    // If we reach here, `self_chunk_words` is definitely changing.
823                    let self_chunk_words = Rc::make_mut(self_chunk_words);
824                    let has_changed = bitwise(
825                        &mut self_chunk_words[0..num_words],
826                        &other_chunk_words[0..num_words],
827                        op,
828                    );
829                    debug_assert!(has_changed);
830                    *self_chunk_count = count_ones(&self_chunk_words[0..num_words]) as ChunkSize;
831                    if *self_chunk_count == chunk_domain_size {
832                        *self_chunk = Ones;
833                    }
834                    changed = true;
835                }
836            }
837        }
838        changed
839    }
840
841    fn subtract(&mut self, other: &ChunkedBitSet<T>) -> bool {
842        assert_eq!(self.domain_size, other.domain_size);
843
844        let num_chunks = self.chunks.len();
845        debug_assert_eq!(num_chunks, other.chunks.len());
846
847        let last_chunk_size = self.last_chunk_size();
848        debug_assert_eq!(last_chunk_size, other.last_chunk_size());
849
850        let mut changed = false;
851        for (chunk_index, (mut self_chunk, other_chunk)) in
852            self.chunks.iter_mut().zip(other.chunks.iter()).enumerate()
853        {
854            let chunk_domain_size = if chunk_index + 1 == num_chunks {
855                last_chunk_size
856            } else {
857                CHUNK_BITS as ChunkSize
858            };
859
860            match (&mut self_chunk, &other_chunk) {
861                (Zeros, _) | (_, Zeros) => {}
862                (Ones | Mixed(..), Ones) => {
863                    changed = true;
864                    *self_chunk = Zeros;
865                }
866                (Ones, Mixed(other_chunk_count, other_chunk_words)) => {
867                    changed = true;
868                    let num_words = num_words(chunk_domain_size as usize);
869                    debug_assert!(num_words > 0 && num_words <= CHUNK_WORDS);
                    // Mask covering the bits in use in the final in-use word of the chunk.
                    let tail_bits = chunk_domain_size as usize - (num_words - 1) * WORD_BITS;
                    let mut tail_mask: Word =
                        if tail_bits == WORD_BITS { !0 } else { (1 << tail_bits) - 1 };
872                    let mut self_chunk_words = **other_chunk_words;
873                    for word in self_chunk_words[0..num_words].iter_mut().rev() {
874                        *word = !*word & tail_mask;
875                        tail_mask = Word::MAX;
876                    }
877                    let self_chunk_count = chunk_domain_size - *other_chunk_count;
878                    debug_assert_eq!(
879                        self_chunk_count,
880                        count_ones(&self_chunk_words[0..num_words]) as ChunkSize
881                    );
882                    *self_chunk = Mixed(self_chunk_count, Rc::new(self_chunk_words));
883                }
884                (
885                    Mixed(self_chunk_count, self_chunk_words),
886                    Mixed(_other_chunk_count, other_chunk_words),
887                ) => {
888                    // See `ChunkedBitSet::union` for details on what is happening here.
889                    let num_words = num_words(chunk_domain_size as usize);
890                    let op = |a: Word, b: Word| a & !b;
891                    if !bitwise_changes(
892                        &self_chunk_words[0..num_words],
893                        &other_chunk_words[0..num_words],
894                        op,
895                    ) {
896                        continue;
897                    }
898
899                    let self_chunk_words = Rc::make_mut(self_chunk_words);
900                    let has_changed = bitwise(
901                        &mut self_chunk_words[0..num_words],
902                        &other_chunk_words[0..num_words],
903                        op,
904                    );
905                    debug_assert!(has_changed);
906                    *self_chunk_count = count_ones(&self_chunk_words[0..num_words]) as ChunkSize;
907                    if *self_chunk_count == 0 {
908                        *self_chunk = Zeros;
909                    }
910                    changed = true;
911                }
912            }
913        }
914        changed
915    }
916
917    fn intersect(&mut self, other: &ChunkedBitSet<T>) -> bool {
918        assert_eq!(self.domain_size, other.domain_size);
919
920        let num_chunks = self.chunks.len();
921        debug_assert_eq!(num_chunks, other.chunks.len());
922
923        let last_chunk_size = self.last_chunk_size();
924        debug_assert_eq!(last_chunk_size, other.last_chunk_size());
925
926        let mut changed = false;
927        for (chunk_index, (mut self_chunk, other_chunk)) in
928            self.chunks.iter_mut().zip(other.chunks.iter()).enumerate()
929        {
930            let chunk_domain_size = if chunk_index + 1 == num_chunks {
931                last_chunk_size
932            } else {
933                CHUNK_BITS as ChunkSize
934            };
935
936            match (&mut self_chunk, &other_chunk) {
937                (Zeros, _) | (_, Ones) => {}
938                (Ones, Zeros | Mixed(..)) | (Mixed(..), Zeros) => {
939                    changed = true;
940                    *self_chunk = other_chunk.clone();
941                }
942                (
943                    Mixed(self_chunk_count, self_chunk_words),
944                    Mixed(_other_chunk_count, other_chunk_words),
945                ) => {
946                    // See `ChunkedBitSet::union` for details on what is happening here.
947                    let num_words = num_words(chunk_domain_size as usize);
948                    let op = |a, b| a & b;
949                    if !bitwise_changes(
950                        &self_chunk_words[0..num_words],
951                        &other_chunk_words[0..num_words],
952                        op,
953                    ) {
954                        continue;
955                    }
956
957                    let self_chunk_words = Rc::make_mut(self_chunk_words);
958                    let has_changed = bitwise(
959                        &mut self_chunk_words[0..num_words],
960                        &other_chunk_words[0..num_words],
961                        op,
962                    );
963                    debug_assert!(has_changed);
964                    *self_chunk_count = count_ones(&self_chunk_words[0..num_words]) as ChunkSize;
965                    if *self_chunk_count == 0 {
966                        *self_chunk = Zeros;
967                    }
968                    changed = true;
969                }
970            }
971        }
972
973        changed
974    }
975}
976
977impl<T> Clone for ChunkedBitSet<T> {
978    fn clone(&self) -> Self {
979        ChunkedBitSet {
980            domain_size: self.domain_size,
981            chunks: self.chunks.clone(),
982            marker: PhantomData,
983        }
984    }
985
986    /// WARNING: this implementation of clone_from will panic if the two
987    /// bitsets have different domain sizes. This constraint is not inherent to
988    /// `clone_from`, but it works with the existing call sites and allows a
989    /// faster implementation, which is important because this function is hot.
990    fn clone_from(&mut self, from: &Self) {
991        assert_eq!(self.domain_size, from.domain_size);
992        debug_assert_eq!(self.chunks.len(), from.chunks.len());
993
994        self.chunks.clone_from(&from.chunks)
995    }
996}
997
998pub struct ChunkedBitIter<'a, T: Idx> {
999    bit_set: &'a ChunkedBitSet<T>,
1000
1001    // The index of the current chunk.
1002    chunk_index: usize,
1003
1004    // The sub-iterator for the current chunk.
1005    chunk_iter: ChunkIter<'a>,
1006}
1007
1008impl<'a, T: Idx> ChunkedBitIter<'a, T> {
1009    #[inline]
1010    fn new(bit_set: &'a ChunkedBitSet<T>) -> ChunkedBitIter<'a, T> {
1011        ChunkedBitIter { bit_set, chunk_index: 0, chunk_iter: bit_set.chunk_iter(0) }
1012    }
1013}
1014
1015impl<'a, T: Idx> Iterator for ChunkedBitIter<'a, T> {
1016    type Item = T;
1017
1018    fn next(&mut self) -> Option<T> {
1019        loop {
1020            match &mut self.chunk_iter {
1021                ChunkIter::Zeros => {}
1022                ChunkIter::Ones(iter) => {
1023                    if let Some(next) = iter.next() {
1024                        return Some(T::new(next + self.chunk_index * CHUNK_BITS));
1025                    }
1026                }
1027                ChunkIter::Mixed(iter) => {
1028                    if let Some(next) = iter.next() {
1029                        return Some(T::new(next + self.chunk_index * CHUNK_BITS));
1030                    }
1031                }
1032                ChunkIter::Finished => return None,
1033            }
1034            self.chunk_index += 1;
1035            self.chunk_iter = self.bit_set.chunk_iter(self.chunk_index);
1036        }
1037    }
1038}
1039
1040impl Chunk {
1041    #[cfg(test)]
1042    fn assert_valid(&self, chunk_domain_size: ChunkSize) {
1043        assert!(chunk_domain_size as usize <= CHUNK_BITS);
1044        match *self {
1045            Zeros | Ones => {}
1046            Mixed(count, ref words) => {
1047                assert!(0 < count && count < chunk_domain_size);
1048
1049                // Check the number of set bits matches `count`.
1050                assert_eq!(count_ones(words.as_slice()) as ChunkSize, count);
1051
1052                // Check the not-in-use words are all zeroed.
1053                let num_words = num_words(chunk_domain_size as usize);
1054                if num_words < CHUNK_WORDS {
1055                    assert_eq!(count_ones(&words[num_words..]) as ChunkSize, 0);
1056                }
1057            }
1058        }
1059    }
1060
1061    /// Count the number of 1s in the chunk.
1062    fn count(&self, chunk_domain_size: ChunkSize) -> usize {
1063        match *self {
1064            Zeros => 0,
1065            Ones => chunk_domain_size as usize,
1066            Mixed(count, _) => count as usize,
1067        }
1068    }
1069}
1070
1071enum ChunkIter<'a> {
1072    Zeros,
1073    Ones(Range<usize>),
1074    Mixed(BitIter<'a, usize>),
1075    Finished,
1076}
1077
1078impl<T: Idx> fmt::Debug for ChunkedBitSet<T> {
1079    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
1080        w.debug_list().entries(self.iter()).finish()
1081    }
1082}
1083
1084/// Sets `out_vec[i] = op(out_vec[i], in_vec[i])` for each index `i` in both
1085/// slices. The slices must have the same length.
1086///
1087/// Returns true if at least one bit in `out_vec` was changed.
1088///
1089/// ## Warning
1090/// Some bitwise operations (e.g. union-not, xor) can set output bits that were
/// unset in both inputs. If this happens in the last word/chunk of a bitset,
1092/// it can cause the bitset to contain out-of-domain values, which need to
1093/// be cleared with `clear_excess_bits_in_final_word`. This also makes the
1094/// "changed" return value unreliable, because the change might have only
1095/// affected excess bits.
1096#[inline]
1097fn bitwise<Op>(out_vec: &mut [Word], in_vec: &[Word], op: Op) -> bool
1098where
1099    Op: Fn(Word, Word) -> Word,
1100{
1101    assert_eq!(out_vec.len(), in_vec.len());
1102    let mut changed = 0;
1103    for (out_elem, in_elem) in iter::zip(out_vec, in_vec) {
1104        let old_val = *out_elem;
1105        let new_val = op(old_val, *in_elem);
1106        *out_elem = new_val;
1107        // This is essentially equivalent to a != with changed being a bool, but
1108        // in practice this code gets auto-vectorized by the compiler for most
1109        // operators. Using != here causes us to generate quite poor code as the
1110        // compiler tries to go back to a boolean on each loop iteration.
1111        changed |= old_val ^ new_val;
1112    }
1113    changed != 0
1114}
1115
1116/// Does this bitwise operation change `out_vec`?
1117#[inline]
1118fn bitwise_changes<Op>(out_vec: &[Word], in_vec: &[Word], op: Op) -> bool
1119where
1120    Op: Fn(Word, Word) -> Word,
1121{
1122    assert_eq!(out_vec.len(), in_vec.len());
1123    for (out_elem, in_elem) in iter::zip(out_vec, in_vec) {
1124        let old_val = *out_elem;
1125        let new_val = op(old_val, *in_elem);
1126        if old_val != new_val {
1127            return true;
1128        }
1129    }
1130    false
1131}
1132
1133/// A bitset with a mixed representation, using `DenseBitSet` for small and
1134/// medium bitsets, and `ChunkedBitSet` for large bitsets, i.e. those with
1135/// enough bits for at least two chunks. This is a good choice for many bitsets
1136/// that can have large domain sizes (e.g. 5000+).
1137///
1138/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
1139/// just be `usize`.
1140///
1141/// All operations that involve an element will panic if the element is equal
1142/// to or greater than the domain size. All operations that involve two bitsets
1143/// will panic if the bitsets have differing domain sizes.
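///
/// # Example
///
/// A minimal sketch (illustrative only; `usize` as the index type):
///
/// ```
/// # use rustc_index::bit_set::MixedBitSet;
/// // Domains of up to one chunk (2048 bits) use the dense representation...
/// let mut small: MixedBitSet<usize> = MixedBitSet::new_empty(100);
/// // ...and larger domains use the chunked representation.
/// let mut large: MixedBitSet<usize> = MixedBitSet::new_empty(1_000_000);
/// small.insert(5);
/// large.insert(999_999);
/// assert!(small.contains(5) && large.contains(999_999));
/// ```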
1144#[derive(PartialEq, Eq)]
1145pub enum MixedBitSet<T> {
1146    Small(DenseBitSet<T>),
1147    Large(ChunkedBitSet<T>),
1148}
1149
1150impl<T> MixedBitSet<T> {
1151    pub fn domain_size(&self) -> usize {
1152        match self {
1153            MixedBitSet::Small(set) => set.domain_size(),
1154            MixedBitSet::Large(set) => set.domain_size(),
1155        }
1156    }
1157}
1158
1159impl<T: Idx> MixedBitSet<T> {
1160    #[inline]
1161    pub fn new_empty(domain_size: usize) -> MixedBitSet<T> {
1162        if domain_size <= CHUNK_BITS {
1163            MixedBitSet::Small(DenseBitSet::new_empty(domain_size))
1164        } else {
1165            MixedBitSet::Large(ChunkedBitSet::new_empty(domain_size))
1166        }
1167    }
1168
1169    #[inline]
1170    pub fn is_empty(&self) -> bool {
1171        match self {
1172            MixedBitSet::Small(set) => set.is_empty(),
1173            MixedBitSet::Large(set) => set.is_empty(),
1174        }
1175    }
1176
1177    #[inline]
1178    pub fn contains(&self, elem: T) -> bool {
1179        match self {
1180            MixedBitSet::Small(set) => set.contains(elem),
1181            MixedBitSet::Large(set) => set.contains(elem),
1182        }
1183    }
1184
1185    #[inline]
1186    pub fn insert(&mut self, elem: T) -> bool {
1187        match self {
1188            MixedBitSet::Small(set) => set.insert(elem),
1189            MixedBitSet::Large(set) => set.insert(elem),
1190        }
1191    }
1192
1193    pub fn insert_all(&mut self) {
1194        match self {
1195            MixedBitSet::Small(set) => set.insert_all(),
1196            MixedBitSet::Large(set) => set.insert_all(),
1197        }
1198    }
1199
1200    #[inline]
1201    pub fn remove(&mut self, elem: T) -> bool {
1202        match self {
1203            MixedBitSet::Small(set) => set.remove(elem),
1204            MixedBitSet::Large(set) => set.remove(elem),
1205        }
1206    }
1207
1208    pub fn iter(&self) -> MixedBitIter<'_, T> {
1209        match self {
1210            MixedBitSet::Small(set) => MixedBitIter::Small(set.iter()),
1211            MixedBitSet::Large(set) => MixedBitIter::Large(set.iter()),
1212        }
1213    }
1214
1215    #[inline]
1216    pub fn clear(&mut self) {
1217        match self {
1218            MixedBitSet::Small(set) => set.clear(),
1219            MixedBitSet::Large(set) => set.clear(),
1220        }
1221    }
1222
1223    bit_relations_inherent_impls! {}
1224}
1225
1226impl<T> Clone for MixedBitSet<T> {
1227    fn clone(&self) -> Self {
1228        match self {
1229            MixedBitSet::Small(set) => MixedBitSet::Small(set.clone()),
1230            MixedBitSet::Large(set) => MixedBitSet::Large(set.clone()),
1231        }
1232    }
1233
1234    /// WARNING: this implementation of clone_from may panic if the two
1235    /// bitsets have different domain sizes. This constraint is not inherent to
1236    /// `clone_from`, but it works with the existing call sites and allows a
1237    /// faster implementation, which is important because this function is hot.
1238    fn clone_from(&mut self, from: &Self) {
1239        match (self, from) {
1240            (MixedBitSet::Small(set), MixedBitSet::Small(from)) => set.clone_from(from),
1241            (MixedBitSet::Large(set), MixedBitSet::Large(from)) => set.clone_from(from),
1242            _ => panic!("MixedBitSet size mismatch"),
1243        }
1244    }
1245}
1246
1247impl<T: Idx> BitRelations<MixedBitSet<T>> for MixedBitSet<T> {
1248    fn union(&mut self, other: &MixedBitSet<T>) -> bool {
1249        match (self, other) {
1250            (MixedBitSet::Small(set), MixedBitSet::Small(other)) => set.union(other),
1251            (MixedBitSet::Large(set), MixedBitSet::Large(other)) => set.union(other),
1252            _ => panic!("MixedBitSet size mismatch"),
1253        }
1254    }
1255
1256    fn subtract(&mut self, other: &MixedBitSet<T>) -> bool {
1257        match (self, other) {
1258            (MixedBitSet::Small(set), MixedBitSet::Small(other)) => set.subtract(other),
1259            (MixedBitSet::Large(set), MixedBitSet::Large(other)) => set.subtract(other),
1260            _ => panic!("MixedBitSet size mismatch"),
1261        }
1262    }
1263
1264    fn intersect(&mut self, _other: &MixedBitSet<T>) -> bool {
1265        unimplemented!("implement if/when necessary");
1266    }
1267}
1268
1269impl<T: Idx> fmt::Debug for MixedBitSet<T> {
1270    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
1271        match self {
1272            MixedBitSet::Small(set) => set.fmt(w),
1273            MixedBitSet::Large(set) => set.fmt(w),
1274        }
1275    }
1276}
1277
1278pub enum MixedBitIter<'a, T: Idx> {
1279    Small(BitIter<'a, T>),
1280    Large(ChunkedBitIter<'a, T>),
1281}
1282
1283impl<'a, T: Idx> Iterator for MixedBitIter<'a, T> {
1284    type Item = T;
1285    fn next(&mut self) -> Option<T> {
1286        match self {
1287            MixedBitIter::Small(iter) => iter.next(),
1288            MixedBitIter::Large(iter) => iter.next(),
1289        }
1290    }
1291}
1292
1293/// A resizable bitset type with a dense representation.
1294///
1295/// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
1296/// just be `usize`.
1297///
/// Unlike `DenseBitSet`, element operations do not panic on out-of-range
/// elements: `insert` and `remove` grow the set as needed, and `contains`
/// returns `false` for elements beyond the current capacity.
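///
/// # Example
///
/// A minimal sketch (illustrative only; `usize` as the index type):
///
/// ```
/// # use rustc_index::bit_set::GrowableBitSet;
/// let mut set: GrowableBitSet<usize> = GrowableBitSet::new_empty();
/// assert!(set.insert(100)); // grows the underlying bitset as needed
/// assert!(set.contains(100));
/// assert!(!set.contains(7));
/// assert_eq!(set.len(), 1);
/// ```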
1300#[derive(Clone, Debug, PartialEq)]
1301pub struct GrowableBitSet<T: Idx> {
1302    bit_set: DenseBitSet<T>,
1303}
1304
1305impl<T: Idx> Default for GrowableBitSet<T> {
1306    fn default() -> Self {
1307        GrowableBitSet::new_empty()
1308    }
1309}
1310
1311impl<T: Idx> GrowableBitSet<T> {
1312    /// Ensure that the set can hold at least `min_domain_size` elements.
1313    pub fn ensure(&mut self, min_domain_size: usize) {
1314        if self.bit_set.domain_size < min_domain_size {
1315            self.bit_set.domain_size = min_domain_size;
1316        }
1317
1318        let min_num_words = num_words(min_domain_size);
1319        if self.bit_set.words.len() < min_num_words {
1320            self.bit_set.words.resize(min_num_words, 0)
1321        }
1322    }
1323
1324    pub fn new_empty() -> GrowableBitSet<T> {
1325        GrowableBitSet { bit_set: DenseBitSet::new_empty(0) }
1326    }
1327
1328    pub fn with_capacity(capacity: usize) -> GrowableBitSet<T> {
1329        GrowableBitSet { bit_set: DenseBitSet::new_empty(capacity) }
1330    }
1331
1332    /// Returns `true` if the set has changed.
1333    #[inline]
1334    pub fn insert(&mut self, elem: T) -> bool {
1335        self.ensure(elem.index() + 1);
1336        self.bit_set.insert(elem)
1337    }
1338
1339    /// Returns `true` if the set has changed.
1340    #[inline]
1341    pub fn remove(&mut self, elem: T) -> bool {
1342        self.ensure(elem.index() + 1);
1343        self.bit_set.remove(elem)
1344    }
1345
1346    #[inline]
1347    pub fn is_empty(&self) -> bool {
1348        self.bit_set.is_empty()
1349    }
1350
1351    #[inline]
1352    pub fn contains(&self, elem: T) -> bool {
1353        let (word_index, mask) = word_index_and_mask(elem);
1354        self.bit_set.words.get(word_index).is_some_and(|word| (word & mask) != 0)
1355    }
1356
1357    #[inline]
1358    pub fn iter(&self) -> BitIter<'_, T> {
1359        self.bit_set.iter()
1360    }
1361
1362    #[inline]
1363    pub fn len(&self) -> usize {
1364        self.bit_set.count()
1365    }
1366}
1367
1368impl<T: Idx> From<DenseBitSet<T>> for GrowableBitSet<T> {
1369    fn from(bit_set: DenseBitSet<T>) -> Self {
1370        Self { bit_set }
1371    }
1372}
1373
1374/// A fixed-size 2D bit matrix type with a dense representation.
1375///
1376/// `R` and `C` are index types used to identify rows and columns respectively;
1377/// typically newtyped `usize` wrappers, but they can also just be `usize`.
1378///
1379/// All operations that involve a row and/or column index will panic if the
1380/// index exceeds the relevant bound.
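///
/// # Example
///
/// A minimal sketch (illustrative only; `usize` for both row and column
/// indices):
///
/// ```
/// # use rustc_index::bit_set::BitMatrix;
/// let mut matrix: BitMatrix<usize, usize> = BitMatrix::new(3, 10);
/// matrix.insert(1, 7);
/// assert!(matrix.contains(1, 7));
/// // Union row 1 into row 2, as when propagating reachability.
/// assert!(matrix.union_rows(1, 2));
/// assert!(matrix.contains(2, 7));
/// assert_eq!(matrix.count(2), 1);
/// ```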
1381#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
1382#[derive(Clone, Eq, PartialEq, Hash)]
1383pub struct BitMatrix<R: Idx, C: Idx> {
1384    num_rows: usize,
1385    num_columns: usize,
1386    words: Vec<Word>,
1387    marker: PhantomData<(R, C)>,
1388}
1389
1390impl<R: Idx, C: Idx> BitMatrix<R, C> {
1391    /// Creates a new `rows x columns` matrix, initially empty.
1392    pub fn new(num_rows: usize, num_columns: usize) -> BitMatrix<R, C> {
        // Each row needs one bit per column. Round up to a whole number of
        // words per row.
1395        let words_per_row = num_words(num_columns);
1396        BitMatrix {
1397            num_rows,
1398            num_columns,
1399            words: vec![0; num_rows * words_per_row],
1400            marker: PhantomData,
1401        }
1402    }
1403
1404    /// Creates a new matrix, with `row` used as the value for every row.
1405    pub fn from_row_n(row: &DenseBitSet<C>, num_rows: usize) -> BitMatrix<R, C> {
1406        let num_columns = row.domain_size();
1407        let words_per_row = num_words(num_columns);
1408        assert_eq!(words_per_row, row.words.len());
1409        BitMatrix {
1410            num_rows,
1411            num_columns,
1412            words: iter::repeat(&row.words).take(num_rows).flatten().cloned().collect(),
1413            marker: PhantomData,
1414        }
1415    }
1416
1417    pub fn rows(&self) -> impl Iterator<Item = R> {
1418        (0..self.num_rows).map(R::new)
1419    }
1420
1421    /// The range of bits for a given row.
1422    fn range(&self, row: R) -> (usize, usize) {
1423        let words_per_row = num_words(self.num_columns);
1424        let start = row.index() * words_per_row;
1425        (start, start + words_per_row)
1426    }
1427
1428    /// Sets the cell at `(row, column)` to true. Put another way, insert
1429    /// `column` to the bitset for `row`.
1430    ///
1431    /// Returns `true` if this changed the matrix.
1432    pub fn insert(&mut self, row: R, column: C) -> bool {
1433        assert!(row.index() < self.num_rows && column.index() < self.num_columns);
1434        let (start, _) = self.range(row);
1435        let (word_index, mask) = word_index_and_mask(column);
1436        let words = &mut self.words[..];
1437        let word = words[start + word_index];
1438        let new_word = word | mask;
1439        words[start + word_index] = new_word;
1440        word != new_word
1441    }
1442
1443    /// Do the bits from `row` contain `column`? Put another way, is
1444    /// the matrix cell at `(row, column)` true?  Put yet another way,
1445    /// if the matrix represents (transitive) reachability, can
1446    /// `row` reach `column`?
1447    pub fn contains(&self, row: R, column: C) -> bool {
1448        assert!(row.index() < self.num_rows && column.index() < self.num_columns);
1449        let (start, _) = self.range(row);
1450        let (word_index, mask) = word_index_and_mask(column);
1451        (self.words[start + word_index] & mask) != 0
1452    }
1453
    /// Returns those indices that are true in both `row1` and `row2`. This
    /// is an *O*(*n*) operation where *n* is the number of columns, largely
    /// independent of the actual size of the intersection.
    pub fn intersect_rows(&self, row1: R, row2: R) -> Vec<C> {
        assert!(row1.index() < self.num_rows && row2.index() < self.num_rows);
        let (row1_start, row1_end) = self.range(row1);
        let (row2_start, row2_end) = self.range(row2);
        let mut result = Vec::with_capacity(self.num_columns);
        for (base, (i, j)) in (row1_start..row1_end).zip(row2_start..row2_end).enumerate() {
            let mut v = self.words[i] & self.words[j];
            for bit in 0..WORD_BITS {
                if v == 0 {
                    break;
                }
                if v & 0x1 != 0 {
                    result.push(C::new(base * WORD_BITS + bit));
                }
                v >>= 1;
            }
        }
        result
    }

    /// Adds the bits from row `read` to the bits from row `write`, and
    /// returns `true` if anything changed.
    ///
    /// This is used when computing transitive reachability: if there is
    /// an edge `write -> read`, then `write` can reach everything that
    /// `read` can (and potentially more).
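    ///
    /// A small reachability-flavoured sketch (assuming plain `usize` index
    /// types):
    ///
    /// ```
    /// use rustc_index::bit_set::BitMatrix;
    ///
    /// let mut reach = BitMatrix::<usize, usize>::new(2, 4);
    /// reach.insert(0, 2); // node 0 reaches column 2
    /// reach.insert(1, 3); // node 1 reaches column 3
    /// // With an edge 0 -> 1, node 0 reaches everything node 1 does.
    /// assert!(reach.union_rows(1, 0));
    /// assert!(reach.contains(0, 3));
    /// ```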
    pub fn union_rows(&mut self, read: R, write: R) -> bool {
        assert!(read.index() < self.num_rows && write.index() < self.num_rows);
        let (read_start, read_end) = self.range(read);
        let (write_start, write_end) = self.range(write);
        let words = &mut self.words[..];
        let mut changed = 0;
        for (read_index, write_index) in iter::zip(read_start..read_end, write_start..write_end) {
            let word = words[write_index];
            let new_word = word | words[read_index];
            words[write_index] = new_word;
            // See `bitwise` for the rationale.
            changed |= word ^ new_word;
        }
        changed != 0
    }

    /// Adds the bits from `with` to the bits from row `write`, and
    /// returns `true` if anything changed.
    pub fn union_row_with(&mut self, with: &DenseBitSet<C>, write: R) -> bool {
        assert!(write.index() < self.num_rows);
        assert_eq!(with.domain_size(), self.num_columns);
        let (write_start, write_end) = self.range(write);
        bitwise(&mut self.words[write_start..write_end], &with.words, |a, b| a | b)
    }

    /// Sets every cell in `row` to true.
    pub fn insert_all_into_row(&mut self, row: R) {
        assert!(row.index() < self.num_rows);
        let (start, end) = self.range(row);
        let words = &mut self.words[..];
        for index in start..end {
            words[index] = !0;
        }
        clear_excess_bits_in_final_word(self.num_columns, &mut self.words[..end]);
    }

    /// Gets a slice of the underlying words.
    pub fn words(&self) -> &[Word] {
        &self.words
    }

    /// Iterates through all the columns set to true in a given row of
    /// the matrix.
    pub fn iter(&self, row: R) -> BitIter<'_, C> {
        assert!(row.index() < self.num_rows);
        let (start, end) = self.range(row);
        BitIter::new(&self.words[start..end])
    }

    /// Returns the number of elements in `row`.
    pub fn count(&self, row: R) -> usize {
        let (start, end) = self.range(row);
        count_ones(&self.words[start..end])
    }
}

impl<R: Idx, C: Idx> fmt::Debug for BitMatrix<R, C> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Forces its contents to print in regular mode instead of alternate mode.
        struct OneLinePrinter<T>(T);
        impl<T: fmt::Debug> fmt::Debug for OneLinePrinter<T> {
            fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(fmt, "{:?}", self.0)
            }
        }

        write!(fmt, "BitMatrix({}x{}) ", self.num_rows, self.num_columns)?;
        let items = self.rows().flat_map(|r| self.iter(r).map(move |c| (r, c)));
        fmt.debug_set().entries(items.map(OneLinePrinter)).finish()
    }
}

/// A fixed-column-size, variable-row-size 2D bit matrix with a moderately
/// sparse representation.
///
/// Initially, every row has no explicit representation. If any bit within a row
/// is set, the entire row is instantiated as `Some(<DenseBitSet>)`.
/// Furthermore, any previously uninstantiated rows prior to it will be
/// instantiated as `None`. Those prior rows may themselves become fully
/// instantiated later on if any of their bits are set.
///
/// `R` and `C` are index types used to identify rows and columns respectively;
/// typically newtyped `usize` wrappers, but they can also just be `usize`.
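///
/// An illustrative sketch of the lazy row instantiation (assuming plain
/// `usize` index types):
///
/// ```
/// use rustc_index::bit_set::SparseBitMatrix;
///
/// let mut matrix = SparseBitMatrix::<usize, usize>::new(16);
/// assert!(!matrix.contains(2, 3)); // no rows exist yet
/// // Inserting into row 2 instantiates rows 0 and 1 as `None` and row 2 as
/// // an (initially empty) dense bitset.
/// matrix.insert(2, 3);
/// assert!(matrix.contains(2, 3));
/// assert!(matrix.row(0).is_none());
/// assert!(matrix.row(2).is_some());
/// ```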
#[derive(Clone, Debug)]
pub struct SparseBitMatrix<R, C>
where
    R: Idx,
    C: Idx,
{
    num_columns: usize,
    rows: IndexVec<R, Option<DenseBitSet<C>>>,
}

impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
    /// Creates a new empty sparse bit matrix with the given number of columns
    /// and no rows.
    pub fn new(num_columns: usize) -> Self {
        Self { num_columns, rows: IndexVec::new() }
    }

    fn ensure_row(&mut self, row: R) -> &mut DenseBitSet<C> {
        // Extend the vector with `None` for any missing rows up to and including `row`,
        // then ensure that row `row` itself holds a (possibly empty) `DenseBitSet`.
        self.rows.get_or_insert_with(row, || DenseBitSet::new_empty(self.num_columns))
    }

    /// Sets the cell at `(row, column)` to true. Put another way, inserts
    /// `column` into the bitset for `row`.
    ///
    /// Returns `true` if this changed the matrix.
    pub fn insert(&mut self, row: R, column: C) -> bool {
        self.ensure_row(row).insert(column)
    }

    /// Sets the cell at `(row, column)` to false. Put another way, deletes
    /// `column` from the bitset for `row`. Has no effect if `row` does not
    /// exist.
    ///
    /// Returns `true` if this changed the matrix.
    pub fn remove(&mut self, row: R, column: C) -> bool {
        match self.rows.get_mut(row) {
            Some(Some(row)) => row.remove(column),
            _ => false,
        }
    }

    /// Sets all columns at `row` to false. Has no effect if `row` does
    /// not exist.
    pub fn clear(&mut self, row: R) {
        if let Some(Some(row)) = self.rows.get_mut(row) {
            row.clear();
        }
    }

    /// Do the bits from `row` contain `column`? Put another way, is
    /// the matrix cell at `(row, column)` true?  Put yet another way,
    /// if the matrix represents (transitive) reachability, can
    /// `row` reach `column`?
    pub fn contains(&self, row: R, column: C) -> bool {
        self.row(row).is_some_and(|r| r.contains(column))
    }

    /// Adds the bits from row `read` to the bits from row `write`, and
    /// returns `true` if anything changed.
    ///
    /// This is used when computing transitive reachability: if there is
    /// an edge `write -> read`, then `write` can reach everything that
    /// `read` can (and potentially more).
    pub fn union_rows(&mut self, read: R, write: R) -> bool {
        if read == write || self.row(read).is_none() {
            return false;
        }

        self.ensure_row(write);
        if let (Some(read_row), Some(write_row)) = self.rows.pick2_mut(read, write) {
            write_row.union(read_row)
        } else {
            unreachable!()
        }
    }

    /// Insert all bits in the given row.
    pub fn insert_all_into_row(&mut self, row: R) {
        self.ensure_row(row).insert_all();
    }

    pub fn rows(&self) -> impl Iterator<Item = R> {
        self.rows.indices()
    }

    /// Iterates through all the columns set to true in a given row of
    /// the matrix.
    pub fn iter(&self, row: R) -> impl Iterator<Item = C> {
        self.row(row).into_iter().flat_map(|r| r.iter())
    }

    pub fn row(&self, row: R) -> Option<&DenseBitSet<C>> {
        self.rows.get(row)?.as_ref()
    }

    /// Intersects `row` with `set`. `set` can be either `DenseBitSet` or
    /// `ChunkedBitSet`. Has no effect if `row` does not exist.
    ///
    /// Returns `true` if the row was changed.
    pub fn intersect_row<Set>(&mut self, row: R, set: &Set) -> bool
    where
        DenseBitSet<C>: BitRelations<Set>,
    {
        match self.rows.get_mut(row) {
            Some(Some(row)) => row.intersect(set),
            _ => false,
        }
    }

    /// Subtracts `set` from `row`. `set` can be either `DenseBitSet` or
    /// `ChunkedBitSet`. Has no effect if `row` does not exist.
    ///
    /// Returns `true` if the row was changed.
    pub fn subtract_row<Set>(&mut self, row: R, set: &Set) -> bool
    where
        DenseBitSet<C>: BitRelations<Set>,
    {
        match self.rows.get_mut(row) {
            Some(Some(row)) => row.subtract(set),
            _ => false,
        }
    }

    /// Unions `row` with `set`. `set` can be either `DenseBitSet` or
    /// `ChunkedBitSet`.
    ///
    /// Returns `true` if the row was changed.
    pub fn union_row<Set>(&mut self, row: R, set: &Set) -> bool
    where
        DenseBitSet<C>: BitRelations<Set>,
    {
        self.ensure_row(row).union(set)
    }
}

#[inline]
fn num_words<T: Idx>(domain_size: T) -> usize {
    domain_size.index().div_ceil(WORD_BITS)
}

#[inline]
fn num_chunks<T: Idx>(domain_size: T) -> usize {
    assert!(domain_size.index() > 0);
    domain_size.index().div_ceil(CHUNK_BITS)
}

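// Maps an element to the index of the `Word` containing it, plus a mask with
// just that element's bit set. For example, with `WORD_BITS == 64`, element
// 70 lives in word 1 and its mask is `1 << 6`.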
#[inline]
fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
    let elem = elem.index();
    let word_index = elem / WORD_BITS;
    let mask = 1 << (elem % WORD_BITS);
    (word_index, mask)
}

#[inline]
fn chunk_index<T: Idx>(elem: T) -> usize {
    elem.index() / CHUNK_BITS
}

#[inline]
fn chunk_word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
    let chunk_elem = elem.index() % CHUNK_BITS;
    word_index_and_mask(chunk_elem)
}

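// Zeroes any bits in the final word that lie beyond `domain_size`. For
// example, with `WORD_BITS == 64` and `domain_size == 70`, only the low six
// bits of the final word are kept.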
fn clear_excess_bits_in_final_word(domain_size: usize, words: &mut [Word]) {
    let num_bits_in_final_word = domain_size % WORD_BITS;
    if num_bits_in_final_word > 0 {
        let mask = (1 << num_bits_in_final_word) - 1;
        words[words.len() - 1] &= mask;
    }
}

#[inline]
fn max_bit(word: Word) -> usize {
    WORD_BITS - 1 - word.leading_zeros() as usize
}

#[inline]
fn count_ones(words: &[Word]) -> usize {
    words.iter().map(|word| word.count_ones() as usize).sum()
}

/// Integral type used to represent the bit set.
pub trait FiniteBitSetTy:
    BitAnd<Output = Self>
    + BitAndAssign
    + BitOrAssign
    + Clone
    + Copy
    + Shl
    + Not<Output = Self>
    + PartialEq
    + Sized
{
    /// Size of the domain representable by this type, e.g. 64 for `u64`.
    const DOMAIN_SIZE: u32;

    /// Value which represents the `FiniteBitSet` having every bit set.
    const FILLED: Self;
    /// Value which represents the `FiniteBitSet` having no bits set.
    const EMPTY: Self;

    /// Value for one as the integral type.
    const ONE: Self;
    /// Value for zero as the integral type.
    const ZERO: Self;

    /// Perform a checked left shift on the integral type.
    fn checked_shl(self, rhs: u32) -> Option<Self>;
    /// Perform a checked right shift on the integral type.
    fn checked_shr(self, rhs: u32) -> Option<Self>;
}

impl FiniteBitSetTy for u32 {
    const DOMAIN_SIZE: u32 = 32;

    const FILLED: Self = Self::MAX;
    const EMPTY: Self = Self::MIN;

    const ONE: Self = 1u32;
    const ZERO: Self = 0u32;

    fn checked_shl(self, rhs: u32) -> Option<Self> {
        self.checked_shl(rhs)
    }

    fn checked_shr(self, rhs: u32) -> Option<Self> {
        self.checked_shr(rhs)
    }
}

impl std::fmt::Debug for FiniteBitSet<u32> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:032b}", self.0)
    }
}

/// A fixed-size bitset type represented by an integer type. Indices outside
/// the range representable by `T` are considered set.
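///
/// A brief illustrative sketch using the `u32` backing type:
///
/// ```
/// use rustc_index::bit_set::FiniteBitSet;
///
/// let mut set = FiniteBitSet::<u32>::new_empty();
/// set.set(3);
/// assert_eq!(set.contains(3), Some(true));
/// assert_eq!(set.contains(4), Some(false));
/// // Indices outside the 32-bit domain yield `None`.
/// assert_eq!(set.contains(40), None);
/// ```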
#[cfg_attr(feature = "nightly", derive(Decodable_NoContext, Encodable_NoContext))]
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct FiniteBitSet<T: FiniteBitSetTy>(pub T);

impl<T: FiniteBitSetTy> FiniteBitSet<T> {
    /// Creates a new, empty bitset.
    pub fn new_empty() -> Self {
        Self(T::EMPTY)
    }

    /// Sets the `index`th bit.
    pub fn set(&mut self, index: u32) {
        self.0 |= T::ONE.checked_shl(index).unwrap_or(T::ZERO);
    }

    /// Unsets the `index`th bit.
    pub fn clear(&mut self, index: u32) {
        self.0 &= !T::ONE.checked_shl(index).unwrap_or(T::ZERO);
    }

    /// Sets the bits in the half-open `range`, i.e. bits `range.start..range.end`.
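    ///
    /// An illustrative sketch:
    ///
    /// ```
    /// use rustc_index::bit_set::FiniteBitSet;
    ///
    /// let mut set = FiniteBitSet::<u32>::new_empty();
    /// set.set_range(2..5); // sets bits 2, 3 and 4, but not 5
    /// assert_eq!(set.contains(4), Some(true));
    /// assert_eq!(set.contains(5), Some(false));
    /// ```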
    pub fn set_range(&mut self, range: Range<u32>) {
        let bits = T::FILLED
            .checked_shl(range.end - range.start)
            .unwrap_or(T::ZERO)
            .not()
            .checked_shl(range.start)
            .unwrap_or(T::ZERO);
        self.0 |= bits;
    }

    /// Is the set empty?
    pub fn is_empty(&self) -> bool {
        self.0 == T::EMPTY
    }

    /// Returns `true` if `index` is within the domain representable by the backing type.
    pub fn within_domain(&self, index: u32) -> bool {
        index < T::DOMAIN_SIZE
    }

    /// Returns whether the `index`th bit is set, or `None` if `index` is outside the domain.
    pub fn contains(&self, index: u32) -> Option<bool> {
        self.within_domain(index)
            .then(|| ((self.0.checked_shr(index).unwrap_or(T::ONE)) & T::ONE) == T::ONE)
    }
}

impl<T: FiniteBitSetTy> Default for FiniteBitSet<T> {
    fn default() -> Self {
        Self::new_empty()
    }
}