1use std::borrow::Borrow;
2use std::hash::{Hash, Hasher};
3use std::{iter, mem};
4
5use either::Either;
6use hashbrown::hash_table::{Entry, HashTable};
7
8use crate::fx::FxHasher;
9use crate::sync::{CacheAligned, Lock, LockGuard, Mode, is_dyn_thread_safe};
10
/// Number of hash bits used to pick a shard; `get_shard_hash` extracts them from a hash.
const SHARD_BITS: usize = 5;

/// Number of shards in the multi-threaded configuration, i.e. `2^SHARD_BITS`.
///
/// Being a power of two lets shard indices be reduced with the mask `SHARDS - 1`.
const SHARDS: usize = 1_usize << SHARD_BITS;
17
/// A value of type `T` protected either by one [`Lock`], or split into `SHARDS`
/// independently locked shards (one `T` per shard).
///
/// `Sharded::new` picks the variant from `is_dyn_thread_safe()`.
/// NOTE(review): the variants are public, so a `Sharded` can also be built by hand. The
/// `unsafe` `lock_assume` calls in `impl<T> Sharded<T>` rely on the variant matching what
/// `Sharded::new` would have chosen. Confirm that no caller constructs it directly.
pub enum Sharded<T> {
    /// A single lock. `Sharded::new` creates this when `is_dyn_thread_safe()` is false.
    Single(Lock<T>),
    /// `SHARDS` locks, each wrapped in `CacheAligned` (presumably to keep shards on separate
    /// cache lines; confirm). `Sharded::new` creates this when `is_dyn_thread_safe()` is true.
    Shards(Box<[CacheAligned<Lock<T>>; SHARDS]>),
}
24
25impl<T: Default> Default for Sharded<T> {
26 #[inline]
27 fn default() -> Self {
28 Self::new(T::default)
29 }
30}
31
32impl<T> Sharded<T> {
33 #[inline]
34 pub fn new(mut value: impl FnMut() -> T) -> Self {
35 if is_dyn_thread_safe() {
36 return Sharded::Shards(Box::new(
37 [(); SHARDS].map(|()| CacheAligned(Lock::new(value()))),
38 ));
39 }
40
41 Sharded::Single(Lock::new(value()))
42 }
43
44 #[inline]
46 pub fn get_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> &Lock<T> {
47 match self {
48 Self::Single(single) => single,
49 Self::Shards(..) => self.get_shard_by_hash(make_hash(val)),
50 }
51 }
52
53 #[inline]
54 pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
55 self.get_shard_by_index(get_shard_hash(hash))
56 }
57
58 #[inline]
59 pub fn get_shard_by_index(&self, i: usize) -> &Lock<T> {
60 match self {
61 Self::Single(single) => single,
62 Self::Shards(shards) => {
63 unsafe { &shards.get_unchecked(i & (SHARDS - 1)).0 }
65 }
66 }
67 }
68
69 #[inline]
71 #[track_caller]
72 pub fn lock_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> LockGuard<'_, T> {
73 match self {
74 Self::Single(single) => {
75 unsafe { single.lock_assume(Mode::NoSync) }
81 }
82 Self::Shards(..) => self.lock_shard_by_hash(make_hash(val)),
83 }
84 }
85
86 #[inline]
87 #[track_caller]
88 pub fn lock_shard_by_hash(&self, hash: u64) -> LockGuard<'_, T> {
89 self.lock_shard_by_index(get_shard_hash(hash))
90 }
91
92 #[inline]
93 #[track_caller]
94 pub fn lock_shard_by_index(&self, i: usize) -> LockGuard<'_, T> {
95 match self {
96 Self::Single(single) => {
97 unsafe { single.lock_assume(Mode::NoSync) }
103 }
104 Self::Shards(shards) => {
105 unsafe { shards.get_unchecked(i & (SHARDS - 1)).0.lock_assume(Mode::Sync) }
113 }
114 }
115 }
116
117 #[inline]
118 pub fn lock_shards(&self) -> impl Iterator<Item = LockGuard<'_, T>> {
119 match self {
120 Self::Single(single) => Either::Left(iter::once(single.lock())),
121 Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.lock())),
122 }
123 }
124
125 #[inline]
126 pub fn try_lock_shards(&self) -> impl Iterator<Item = Option<LockGuard<'_, T>>> {
127 match self {
128 Self::Single(single) => Either::Left(iter::once(single.try_lock())),
129 Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.try_lock())),
130 }
131 }
132}
133
134#[inline]
135pub fn shards() -> usize {
136 if is_dyn_thread_safe() {
137 return SHARDS;
138 }
139
140 1
141}
142
/// A hash map spread over the shards of a [`Sharded`].
///
/// Each shard is a hashbrown `HashTable` of `(key, value)` pairs behind its own lock.
pub type ShardedHashMap<K, V> = Sharded<HashTable<(K, V)>>;
144
145impl<K: Eq, V> ShardedHashMap<K, V> {
146 pub fn with_capacity(cap: usize) -> Self {
147 Self::new(|| HashTable::with_capacity(cap))
148 }
149 pub fn len(&self) -> usize {
150 self.lock_shards().map(|shard| shard.len()).sum()
151 }
152}
153
154impl<K: Eq + Hash, V> ShardedHashMap<K, V> {
155 #[inline]
156 pub fn get<Q>(&self, key: &Q) -> Option<V>
157 where
158 K: Borrow<Q>,
159 Q: Hash + Eq,
160 V: Clone,
161 {
162 let hash = make_hash(key);
163 let shard = self.lock_shard_by_hash(hash);
164 let (_, value) = shard.find(hash, |(k, _)| k.borrow() == key)?;
165 Some(value.clone())
166 }
167
168 #[inline]
169 pub fn get_or_insert_with(&self, key: K, default: impl FnOnce() -> V) -> V
170 where
171 V: Copy,
172 {
173 let hash = make_hash(&key);
174 let mut shard = self.lock_shard_by_hash(hash);
175
176 match table_entry(&mut shard, hash, &key) {
177 Entry::Occupied(e) => e.get().1,
178 Entry::Vacant(e) => {
179 let value = default();
180 e.insert((key, value));
181 value
182 }
183 }
184 }
185
186 #[inline]
187 pub fn insert(&self, key: K, value: V) -> Option<V> {
188 let hash = make_hash(&key);
189 let mut shard = self.lock_shard_by_hash(hash);
190
191 match table_entry(&mut shard, hash, &key) {
192 Entry::Occupied(e) => {
193 let previous = mem::replace(&mut e.into_mut().1, value);
194 Some(previous)
195 }
196 Entry::Vacant(e) => {
197 e.insert((key, value));
198 None
199 }
200 }
201 }
202}
203
204impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
205 #[inline]
206 pub fn intern_ref<Q: ?Sized>(&self, value: &Q, make: impl FnOnce() -> K) -> K
207 where
208 K: Borrow<Q>,
209 Q: Hash + Eq,
210 {
211 let hash = make_hash(value);
212 let mut shard = self.lock_shard_by_hash(hash);
213
214 match table_entry(&mut shard, hash, value) {
215 Entry::Occupied(e) => e.get().0,
216 Entry::Vacant(e) => {
217 let v = make();
218 e.insert((v, ()));
219 v
220 }
221 }
222 }
223
224 #[inline]
225 pub fn intern<Q>(&self, value: Q, make: impl FnOnce(Q) -> K) -> K
226 where
227 K: Borrow<Q>,
228 Q: Hash + Eq,
229 {
230 let hash = make_hash(&value);
231 let mut shard = self.lock_shard_by_hash(hash);
232
233 match table_entry(&mut shard, hash, &value) {
234 Entry::Occupied(e) => e.get().0,
235 Entry::Vacant(e) => {
236 let v = make(value);
237 e.insert((v, ()));
238 v
239 }
240 }
241 }
242}
243
/// Conversion to a type-erased raw pointer.
///
/// `ShardedHashMap::contains_pointer_to` uses this to compare entries by address.
pub trait IntoPointer {
    /// Returns a raw pointer identifying the data behind `self`.
    ///
    /// NOTE(review): in this file the pointer is only compared for equality, never
    /// dereferenced. What it must point to, and how long it must stay valid, is not
    /// established here; confirm with the implementors.
    fn into_pointer(&self) -> *const ();
}
248
249impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
250 pub fn contains_pointer_to<T: Hash + IntoPointer>(&self, value: &T) -> bool {
251 let hash = make_hash(&value);
252 let shard = self.lock_shard_by_hash(hash);
253 let value = value.into_pointer();
254 shard.find(hash, |(k, ())| k.into_pointer() == value).is_some()
255 }
256}
257
258#[inline]
259pub fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
260 let mut state = FxHasher::default();
261 val.hash(&mut state);
262 state.finish()
263}
264
265#[inline]
266fn table_entry<'a, K, V, Q>(
267 table: &'a mut HashTable<(K, V)>,
268 hash: u64,
269 key: &Q,
270) -> Entry<'a, (K, V)>
271where
272 K: Hash + Borrow<Q>,
273 Q: ?Sized + Eq,
274{
275 table.entry(hash, move |(k, _)| k.borrow() == key, |(k, _)| make_hash(k))
276}
277
278#[inline]
284fn get_shard_hash(hash: u64) -> usize {
285 let hash_len = size_of::<usize>();
286 (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize
289}