1use std::cell::{Cell, OnceCell};
2use std::num::NonZero;
3use std::ops::Deref;
4use std::ptr;
5use std::sync::Arc;
67use parking_lot::Mutex;
89use crate::outline;
10use crate::sync::CacheAligned;
1112/// A pointer to the `RegistryData` which uniquely identifies a registry.
13/// This identifier can be reused if the registry gets freed.
14#[derive(Clone, Copy, PartialEq)]
15struct RegistryId(*const RegistryData);
1617impl RegistryId {
18#[inline(always)]
19/// Verifies that the current thread is associated with the registry and returns its unique
20 /// index within the registry. This panics if the current thread is not associated with this
21 /// registry.
22 ///
23 /// Note that there's a race possible where the identifier in `THREAD_DATA` could be reused
24 /// so this can succeed from a different registry.
25fn verify(self) -> usize {
26let (id, index) = THREAD_DATA.with(|data| (data.registry_id.get(), data.index.get()));
2728if id == self { index } else { outline(|| panic!("Unable to verify registry association")) }
29 }
30}
3132struct RegistryData {
33 thread_limit: NonZero<usize>,
34 threads: Mutex<usize>,
35}
3637/// Represents a list of threads which can access worker locals.
38#[derive(Clone)]
39pub struct Registry(Arc<RegistryData>);
4041thread_local! {
42/// The registry associated with the thread.
43 /// This allows the `WorkerLocal` type to clone the registry in its constructor.
44static REGISTRY: OnceCell<Registry> = const { OnceCell::new() };
45}
4647struct ThreadData {
48 registry_id: Cell<RegistryId>,
49 index: Cell<usize>,
50}
5152thread_local! {
53/// A thread local which contains the identifier of `REGISTRY` but allows for faster access.
54 /// It also holds the index of the current thread.
55static THREAD_DATA: ThreadData = const { ThreadData {
56 registry_id: Cell::new(RegistryId(ptr::null())),
57 index: Cell::new(0),
58 }};
59}
6061impl Registry {
62/// Creates a registry which can hold up to `thread_limit` threads.
63pub fn new(thread_limit: NonZero<usize>) -> Self {
64Registry(Arc::new(RegistryData { thread_limit, threads: Mutex::new(0) }))
65 }
6667/// Gets the registry associated with the current thread. Panics if there's no such registry.
68pub fn current() -> Self {
69REGISTRY.with(|registry| registry.get().cloned().expect("No associated registry"))
70 }
7172/// Registers the current thread with the registry so worker locals can be used on it.
73 /// Panics if the thread limit is hit or if the thread already has an associated registry.
74pub fn register(&self) {
75let mut threads = self.0.threads.lock();
76if *threads < self.0.thread_limit.get() {
77REGISTRY.with(|registry| {
78if registry.get().is_some() {
79drop(threads);
80panic!("Thread already has a registry");
81 }
82registry.set(self.clone()).ok();
83THREAD_DATA.with(|data| {
84data.registry_id.set(self.id());
85data.index.set(*threads);
86 });
87*threads += 1;
88 });
89 } else {
90drop(threads);
91panic!("Thread limit reached");
92 }
93 }
9495/// Gets the identifier of this registry.
96fn id(&self) -> RegistryId {
97RegistryId(&*self.0)
98 }
99}
100101/// Holds worker local values for each possible thread in a registry. You can only access the
102/// worker local value through the `Deref` impl on the registry associated with the thread it was
103/// created on. It will panic otherwise.
104pub struct WorkerLocal<T> {
105 locals: Box<[CacheAligned<T>]>,
106 registry: Registry,
107}
108109// This is safe because the `deref` call will return a reference to a `T` unique to each thread
110// or it will panic for threads without an associated local. So there isn't a need for `T` to do
111// it's own synchronization. The `verify` method on `RegistryId` has an issue where the id
112// can be reused, but `WorkerLocal` has a reference to `Registry` which will prevent any reuse.
113unsafe impl<T: Send> Syncfor WorkerLocal<T> {}
114115impl<T> WorkerLocal<T> {
116/// Creates a new worker local where the `initial` closure computes the
117 /// value this worker local should take for each thread in the registry.
118#[inline]
119pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
120let registry = Registry::current();
121WorkerLocal {
122 locals: (0..registry.0.thread_limit.get()).map(|i| CacheAligned(initial(i))).collect(),
123registry,
124 }
125 }
126127/// Returns the worker-local values for each thread
128#[inline]
129pub fn into_inner(self) -> impl Iterator<Item = T> {
130self.locals.into_vec().into_iter().map(|local| local.0)
131 }
132}
133134impl<T> Dereffor WorkerLocal<T> {
135type Target = T;
136137#[inline(always)]
138fn deref(&self) -> &T {
139// This is safe because `verify` will only return values less than
140 // `self.registry.thread_limit` which is the size of the `self.locals` array.
141unsafe { &self.locals.get_unchecked(self.registry.id().verify()).0 }
142 }
143}
144145impl<T: Default> Defaultfor WorkerLocal<T> {
146fn default() -> Self {
147WorkerLocal::new(|_| T::default())
148 }
149}