//! `rustc_data_structures/sharded.rs`

1use std::borrow::Borrow;
2use std::hash::{Hash, Hasher};
3use std::{iter, mem};
4
5use either::Either;
6use hashbrown::hash_table::{Entry, HashTable};
7
8use crate::fx::FxHasher;
9use crate::sync::{CacheAligned, Lock, LockGuard, Mode, is_dyn_thread_safe};
10
// 32 shards is sufficient to reduce contention on an 8-core Ryzen 7 1700,
// but this should be tested on higher core count CPUs. How the `Sharded` type gets used
// may also affect the ideal number of shards.
const SHARD_BITS: usize = 5;

// Number of shards in the sharded case. Kept a power of two so shard
// selection can be a cheap mask with `SHARDS - 1` (see `get_shard_by_index`).
const SHARDS: usize = 1 << SHARD_BITS;
17
/// An array of cache-line aligned inner locked structures with convenience methods.
/// A single field is used when the compiler uses only one thread.
pub enum Sharded<T> {
    /// Used when `is_dyn_thread_safe()` was false at construction: a single
    /// lock, avoiding the memory overhead of the shard array.
    Single(Lock<T>),
    /// Used when `is_dyn_thread_safe()` was true: `SHARDS` cache-line aligned
    /// locks, so accesses to different shards avoid lock contention and
    /// false sharing between adjacent locks.
    Shards(Box<[CacheAligned<Lock<T>>; SHARDS]>),
}
24
25impl<T: Default> Default for Sharded<T> {
26    #[inline]
27    fn default() -> Self {
28        Self::new(T::default)
29    }
30}
31
impl<T> Sharded<T> {
    /// Creates a new `Sharded`, invoking `value` once per inner lock:
    /// `SHARDS` times when `is_dyn_thread_safe()` is true, otherwise once.
    ///
    /// Whether the result is `Single` or `Shards` is decided here and never
    /// changes afterwards; the `lock_*` methods below rely on that to pick
    /// the matching `Mode` for `lock_assume`.
    #[inline]
    pub fn new(mut value: impl FnMut() -> T) -> Self {
        if is_dyn_thread_safe() {
            return Sharded::Shards(Box::new(
                [(); SHARDS].map(|()| CacheAligned(Lock::new(value()))),
            ));
        }

        Sharded::Single(Lock::new(value()))
    }

    /// Returns (without locking) the shard responsible for `val`.
    /// The shard is selected by hashing `val` with `FxHasher`.
    #[inline]
    pub fn get_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> &Lock<T> {
        match self {
            // With a single shard there is nothing to select, so skip hashing.
            Self::Single(single) => single,
            Self::Shards(..) => self.get_shard_by_hash(make_hash(val)),
        }
    }

    /// Returns the shard for a pre-computed hash; see `get_shard_hash` for
    /// which bits of `hash` select the shard.
    #[inline]
    pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
        self.get_shard_by_index(get_shard_hash(hash))
    }

    /// Returns the shard at index `i` (taken modulo `SHARDS`); in the
    /// unsharded case the single lock is returned regardless of `i`.
    #[inline]
    pub fn get_shard_by_index(&self, i: usize) -> &Lock<T> {
        match self {
            Self::Single(single) => single,
            Self::Shards(shards) => {
                // SAFETY: The index gets ANDed with the shard mask, ensuring it is always inbounds.
                unsafe { &shards.get_unchecked(i & (SHARDS - 1)).0 }
            }
        }
    }

    /// Locks and returns the shard responsible for `val`.
    /// The shard is selected by hashing `val` with `FxHasher`.
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> LockGuard<'_, T> {
        match self {
            Self::Single(single) => {
                // Synchronization is disabled so use the `lock_assume_no_sync` method optimized
                // for that case.

                // SAFETY: We know `is_dyn_thread_safe` was false when creating the lock thus
                // `might_be_dyn_thread_safe` was also false.
                unsafe { single.lock_assume(Mode::NoSync) }
            }
            Self::Shards(..) => self.lock_shard_by_hash(make_hash(val)),
        }
    }

    /// Locks and returns the shard for a pre-computed hash.
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_hash(&self, hash: u64) -> LockGuard<'_, T> {
        self.lock_shard_by_index(get_shard_hash(hash))
    }

    /// Locks and returns the shard at index `i` (taken modulo `SHARDS`).
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_index(&self, i: usize) -> LockGuard<'_, T> {
        match self {
            Self::Single(single) => {
                // Synchronization is disabled so use the `lock_assume_no_sync` method optimized
                // for that case.

                // SAFETY: We know `is_dyn_thread_safe` was false when creating the lock thus
                // `might_be_dyn_thread_safe` was also false.
                unsafe { single.lock_assume(Mode::NoSync) }
            }
            Self::Shards(shards) => {
                // Synchronization is enabled so use the `lock_assume_sync` method optimized
                // for that case.

                // SAFETY (get_unchecked): The index gets ANDed with the shard mask, ensuring it is
                // always inbounds.
                // SAFETY (lock_assume_sync): We know `is_dyn_thread_safe` was true when creating
                // the lock thus `might_be_dyn_thread_safe` was also true.
                unsafe { shards.get_unchecked(i & (SHARDS - 1)).0.lock_assume(Mode::Sync) }
            }
        }
    }

    /// Yields a guard for every shard. Locking happens lazily, one shard at a
    /// time as the iterator is advanced; collect the guards to hold all
    /// shards locked simultaneously.
    #[inline]
    pub fn lock_shards(&self) -> impl Iterator<Item = LockGuard<'_, T>> {
        match self {
            Self::Single(single) => Either::Left(iter::once(single.lock())),
            Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.lock())),
        }
    }

    /// Like `lock_shards`, but yields `None` for any shard whose `try_lock`
    /// failed instead of blocking on it.
    #[inline]
    pub fn try_lock_shards(&self) -> impl Iterator<Item = Option<LockGuard<'_, T>>> {
        match self {
            Self::Single(single) => Either::Left(iter::once(single.try_lock())),
            Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.try_lock())),
        }
    }
}
133
134#[inline]
135pub fn shards() -> usize {
136    if is_dyn_thread_safe() {
137        return SHARDS;
138    }
139
140    1
141}
142
/// A sharded hash map: each shard is a raw `HashTable` of `(key, value)`
/// pairs, with hashing performed externally via `make_hash`.
pub type ShardedHashMap<K, V> = Sharded<HashTable<(K, V)>>;
144
145impl<K: Eq, V> ShardedHashMap<K, V> {
146    pub fn with_capacity(cap: usize) -> Self {
147        Self::new(|| HashTable::with_capacity(cap))
148    }
149    pub fn len(&self) -> usize {
150        self.lock_shards().map(|shard| shard.len()).sum()
151    }
152}
153
154impl<K: Eq + Hash, V> ShardedHashMap<K, V> {
155    #[inline]
156    pub fn get<Q>(&self, key: &Q) -> Option<V>
157    where
158        K: Borrow<Q>,
159        Q: Hash + Eq,
160        V: Clone,
161    {
162        let hash = make_hash(key);
163        let shard = self.lock_shard_by_hash(hash);
164        let (_, value) = shard.find(hash, |(k, _)| k.borrow() == key)?;
165        Some(value.clone())
166    }
167
168    #[inline]
169    pub fn get_or_insert_with(&self, key: K, default: impl FnOnce() -> V) -> V
170    where
171        V: Copy,
172    {
173        let hash = make_hash(&key);
174        let mut shard = self.lock_shard_by_hash(hash);
175
176        match table_entry(&mut shard, hash, &key) {
177            Entry::Occupied(e) => e.get().1,
178            Entry::Vacant(e) => {
179                let value = default();
180                e.insert((key, value));
181                value
182            }
183        }
184    }
185
186    #[inline]
187    pub fn insert(&self, key: K, value: V) -> Option<V> {
188        let hash = make_hash(&key);
189        let mut shard = self.lock_shard_by_hash(hash);
190
191        match table_entry(&mut shard, hash, &key) {
192            Entry::Occupied(e) => {
193                let previous = mem::replace(&mut e.into_mut().1, value);
194                Some(previous)
195            }
196            Entry::Vacant(e) => {
197                e.insert((key, value));
198                None
199            }
200        }
201    }
202}
203
204impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
205    #[inline]
206    pub fn intern_ref<Q: ?Sized>(&self, value: &Q, make: impl FnOnce() -> K) -> K
207    where
208        K: Borrow<Q>,
209        Q: Hash + Eq,
210    {
211        let hash = make_hash(value);
212        let mut shard = self.lock_shard_by_hash(hash);
213
214        match table_entry(&mut shard, hash, value) {
215            Entry::Occupied(e) => e.get().0,
216            Entry::Vacant(e) => {
217                let v = make();
218                e.insert((v, ()));
219                v
220            }
221        }
222    }
223
224    #[inline]
225    pub fn intern<Q>(&self, value: Q, make: impl FnOnce(Q) -> K) -> K
226    where
227        K: Borrow<Q>,
228        Q: Hash + Eq,
229    {
230        let hash = make_hash(&value);
231        let mut shard = self.lock_shard_by_hash(hash);
232
233        match table_entry(&mut shard, hash, &value) {
234            Entry::Occupied(e) => e.get().0,
235            Entry::Vacant(e) => {
236                let v = make(value);
237                e.insert((v, ()));
238                v
239            }
240        }
241    }
242}
243
/// Supplies a pointer identity for a value, used by `contains_pointer_to`
/// to compare interned entries by address rather than by content.
pub trait IntoPointer {
    /// Returns a pointer which outlives `self`.
    fn into_pointer(&self) -> *const ();
}
248
249impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
250    pub fn contains_pointer_to<T: Hash + IntoPointer>(&self, value: &T) -> bool {
251        let hash = make_hash(&value);
252        let shard = self.lock_shard_by_hash(hash);
253        let value = value.into_pointer();
254        shard.find(hash, |(k, ())| k.into_pointer() == value).is_some()
255    }
256}
257
258#[inline]
259pub fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
260    let mut state = FxHasher::default();
261    val.hash(&mut state);
262    state.finish()
263}
264
/// Locates the entry for `key` in one raw `HashTable` shard.
///
/// `hash` must have been produced by `make_hash` over `key` — callers compute
/// it once to pick the shard and reuse it here. The final closure re-hashes
/// existing entries so the table can relocate them if this call grows it.
#[inline]
fn table_entry<'a, K, V, Q>(
    table: &'a mut HashTable<(K, V)>,
    hash: u64,
    key: &Q,
) -> Entry<'a, (K, V)>
where
    K: Hash + Borrow<Q>,
    Q: ?Sized + Eq,
{
    table.entry(hash, move |(k, _)| k.borrow() == key, |(k, _)| make_hash(k))
}
277
278/// Get a shard with a pre-computed hash value. If `get_shard_by_value` is
279/// ever used in combination with `get_shard_by_hash` on a single `Sharded`
280/// instance, then `hash` must be computed with `FxHasher`. Otherwise,
281/// `hash` can be computed with any hasher, so long as that hasher is used
282/// consistently for each `Sharded` instance.
283#[inline]
284fn get_shard_hash(hash: u64) -> usize {
285    let hash_len = size_of::<usize>();
286    // Ignore the top 7 bits as hashbrown uses these and get the next SHARD_BITS highest bits.
287    // hashbrown also uses the lowest bits, so we can't use those
288    (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize
289}