use crate::cell::UnsafeCell;
use crate::fmt;
use crate::ops::Deref;
use crate::panic::{RefUnwindSafe, UnwindSafe};
use crate::sys::sync as sys;
use crate::thread::{ThreadId, current_id};

/// A re-entrant mutual exclusion lock.
///
/// This lock will block *other* threads waiting for the lock to become
/// available. The thread which has already locked the mutex can lock it
/// multiple times without blocking, preventing a common source of deadlocks.
///
/// # Examples
///
/// Allow recursively calling a function needing synchronization from within
/// a callback (this is how [`StdoutLock`](crate::io::StdoutLock) is currently
/// implemented):
///
/// ```
/// #![feature(reentrant_lock)]
///
/// use std::cell::RefCell;
/// use std::sync::ReentrantLock;
///
/// pub struct Log {
///     data: RefCell<String>,
/// }
///
/// impl Log {
///     pub fn append(&self, msg: &str) {
///         self.data.borrow_mut().push_str(msg);
///     }
/// }
///
/// static LOG: ReentrantLock<Log> = ReentrantLock::new(Log { data: RefCell::new(String::new()) });
///
/// pub fn with_log<R>(f: impl FnOnce(&Log) -> R) -> R {
///     let log = LOG.lock();
///     f(&*log)
/// }
///
/// with_log(|log| {
///     log.append("Hello");
///     with_log(|log| log.append(" there!"));
/// });
/// ```
///
// # Implementation details
//
// The 'owner' field tracks which thread has locked the mutex.
//
// We use thread::current_id() as the thread identifier; it returns the
// current thread's ThreadId, which is unique across the process lifetime.
//
// If `owner` is set to the identifier of the current thread,
// we assume the mutex is already locked and instead of locking it again,
// we increment `lock_count`.
//
// When unlocking, we decrement `lock_count`, and only unlock the mutex when
// it reaches zero.
//
// `lock_count` is protected by the mutex and only accessed by the thread that has
// locked the mutex, so it needs no synchronization.
//
// `owner` can be checked by other threads that want to see if they already
// hold the lock, so it needs to be atomic. If it compares equal, we're on the
// same thread that holds the mutex and memory access can use relaxed ordering
// since we're not dealing with multiple threads. If it's not equal,
// synchronization is left to the mutex, making relaxed memory ordering for
// the `owner` field fine in all cases.
//
// On systems without 64-bit atomics, we also store the address of a TLS variable
// alongside the 64-bit TID. We then first check that address against the address
// of that variable on the current thread, and only if they compare equal do we
// compare the actual TIDs. Because we only ever read the TID on the same thread
// that it was written on (or a thread sharing the TLS block with that writer thread),
// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
// non-atomic accesses.
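//
// A sketch of the resulting lock/unlock flow (a restatement of the above;
// `me` stands for the current thread's ID):
//
//   lock():   if owner == me { lock_count += 1 }
//             else { mutex.lock(); owner = me; lock_count = 1 }
//   unlock(): lock_count -= 1;
//             if lock_count == 0 { owner = None; mutex.unlock() }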
#[unstable(feature = "reentrant_lock", issue = "121440")]
pub struct ReentrantLock<T: ?Sized> {
    mutex: sys::Mutex,
    owner: Tid,
    lock_count: UnsafeCell<u32>,
    data: T,
}

cfg_select!(
    target_has_atomic = "64" => {
        use crate::sync::atomic::{Atomic, AtomicU64, Ordering::Relaxed};

        struct Tid(Atomic<u64>);

        impl Tid {
            const fn new() -> Self {
                Self(AtomicU64::new(0))
            }

            #[inline]
            fn contains(&self, owner: ThreadId) -> bool {
                owner.as_u64().get() == self.0.load(Relaxed)
            }

            #[inline]
            // This is just unsafe to match the API of the Tid type below.
            unsafe fn set(&self, tid: Option<ThreadId>) {
                let value = tid.map_or(0, |tid| tid.as_u64().get());
                self.0.store(value, Relaxed);
            }
        }
    }
    _ => {
        /// Returns the address of a TLS variable. This is guaranteed to
        /// be unique across all currently alive threads.
        fn tls_addr() -> usize {
            thread_local! { static X: u8 = const { 0u8 } };

            X.with(|p| <*const u8>::addr(p))
        }

        use crate::sync::atomic::{Atomic, AtomicUsize, Ordering};

        struct Tid {
            // When a thread calls `set()`, this value gets updated to
            // the address of a thread local on that thread. This is
            // used as a first check in `contains()`; if the `tls_addr`
            // doesn't match the TLS address of the current thread, then
            // the ThreadId also can't match. Only if the TLS addresses do
            // match do we read out the actual TID.
            // Note also that we can use relaxed atomic operations here, because
            // we only ever read from the tid if `tls_addr` matches the current
            // TLS address. In that case, either the tid has been set by
            // the current thread, or by a thread that has terminated before
            // the current thread's `tls_addr` was allocated. In either case, no further
            // synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>).
            tls_addr: Atomic<usize>,
            tid: UnsafeCell<u64>,
        }

        unsafe impl Send for Tid {}
        unsafe impl Sync for Tid {}

        impl Tid {
            const fn new() -> Self {
                Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
            }

            #[inline]
            // NOTE: This assumes that `owner` is the ID of the current
            // thread, and may spuriously return `false` if that's not the case.
            fn contains(&self, owner: ThreadId) -> bool {
                // We must call `tls_addr()` *before* doing the load to ensure that if we reuse an
                // earlier thread's address, the `tls_addr.load()` below happens-after everything
                // that thread did.
                let tls_addr = tls_addr();
                // SAFETY: See the comments in the struct definition.
                self.tls_addr.load(Ordering::Relaxed) == tls_addr
                    && unsafe { *self.tid.get() } == owner.as_u64().get()
            }

            #[inline]
            // This may only be called by one thread at a time, and can lead to
            // race conditions otherwise.
            unsafe fn set(&self, tid: Option<ThreadId>) {
                // It's important that we set `self.tls_addr` to 0 if the tid is
                // cleared. Otherwise, there might be race conditions between
                // `set()` and `contains()`.
                let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
                let value = tid.map_or(0, |tid| tid.as_u64().get());
                self.tls_addr.store(tls_addr, Ordering::Relaxed);
                unsafe { *self.tid.get() = value };
            }
        }
    }
);

#[unstable(feature = "reentrant_lock", issue = "121440")]
unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
#[unstable(feature = "reentrant_lock", issue = "121440")]
unsafe impl<T: Send + ?Sized> Sync for ReentrantLock<T> {}

// Because of the `UnsafeCell`, these traits are not implemented automatically.
#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: UnwindSafe + ?Sized> UnwindSafe for ReentrantLock<T> {}
#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: RefUnwindSafe + ?Sized> RefUnwindSafe for ReentrantLock<T> {}

/// An RAII implementation of a "scoped lock" of a re-entrant lock. When this
/// structure is dropped (falls out of scope), the lock will be unlocked.
///
/// The data protected by the mutex can be accessed through this guard via its
/// [`Deref`] implementation.
///
/// This structure is created by the [`lock`](ReentrantLock::lock) method on
/// [`ReentrantLock`].
///
/// # Mutability
///
/// Unlike [`MutexGuard`](super::MutexGuard), `ReentrantLockGuard` does not
/// implement [`DerefMut`](crate::ops::DerefMut), because implementing it
/// would violate Rust’s reference aliasing rules. Use interior mutability
/// (usually [`RefCell`](crate::cell::RefCell)) to mutate the guarded data.
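///
/// For example, a small sketch of mutating guarded data through
/// [`RefCell`](crate::cell::RefCell):
///
/// ```
/// #![feature(reentrant_lock)]
/// use std::cell::RefCell;
/// use std::sync::ReentrantLock;
///
/// let lock = ReentrantLock::new(RefCell::new(0));
/// let guard = lock.lock();
/// // The guard only hands out `&RefCell<i32>`; mutation goes through
/// // the `RefCell`'s interior mutability.
/// *guard.borrow_mut() += 1;
/// assert_eq!(*guard.borrow(), 1);
/// ```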
#[must_use = "if unused the ReentrantLock will immediately unlock"]
#[unstable(feature = "reentrant_lock", issue = "121440")]
pub struct ReentrantLockGuard<'a, T: ?Sized + 'a> {
    lock: &'a ReentrantLock<T>,
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: ?Sized> !Send for ReentrantLockGuard<'_, T> {}

#[unstable(feature = "reentrant_lock", issue = "121440")]
unsafe impl<T: ?Sized + Sync> Sync for ReentrantLockGuard<'_, T> {}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T> ReentrantLock<T> {
    /// Creates a new re-entrant lock in an unlocked state ready for use.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(reentrant_lock)]
    /// use std::sync::ReentrantLock;
    ///
    /// let lock = ReentrantLock::new(0);
    /// ```
    pub const fn new(t: T) -> ReentrantLock<T> {
        ReentrantLock {
            mutex: sys::Mutex::new(),
            owner: Tid::new(),
            lock_count: UnsafeCell::new(0),
            data: t,
        }
    }

    /// Consumes this lock, returning the underlying data.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(reentrant_lock)]
    ///
    /// use std::sync::ReentrantLock;
    ///
    /// let lock = ReentrantLock::new(0);
    /// assert_eq!(lock.into_inner(), 0);
    /// ```
    pub fn into_inner(self) -> T {
        self.data
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: ?Sized> ReentrantLock<T> {
    /// Acquires the lock, blocking the current thread until it is able to do
    /// so.
    ///
    /// This function will block the caller until the lock becomes available.
    /// Upon returning, the calling thread is the only thread holding the
    /// lock. If the calling thread already holds the lock, the call succeeds
    /// without blocking.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(reentrant_lock)]
    /// use std::cell::Cell;
    /// use std::sync::{Arc, ReentrantLock};
    /// use std::thread;
    ///
    /// let lock = Arc::new(ReentrantLock::new(Cell::new(0)));
    /// let c_lock = Arc::clone(&lock);
    ///
    /// thread::spawn(move || {
    ///     c_lock.lock().set(10);
    /// }).join().expect("thread::spawn failed");
    /// assert_eq!(lock.lock().get(), 10);
    /// ```
    pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
        let this_thread = current_id();
        // Safety: We only touch lock_count when we own the inner mutex.
        // Additionally, we only call `self.owner.set()` while holding
        // the inner mutex, so no two threads can call it concurrently.
        unsafe {
            if self.owner.contains(this_thread) {
                self.increment_lock_count().expect("lock count overflow in reentrant mutex");
            } else {
                self.mutex.lock();
                self.owner.set(Some(this_thread));
                debug_assert_eq!(*self.lock_count.get(), 0);
                *self.lock_count.get() = 1;
            }
        }
        ReentrantLockGuard { lock: self }
    }

    /// Returns a mutable reference to the underlying data.
    ///
    /// Since this call borrows the `ReentrantLock` mutably, no actual locking
    /// needs to take place -- the mutable borrow statically guarantees no locks
    /// exist.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(reentrant_lock)]
    /// use std::sync::ReentrantLock;
    ///
    /// let mut lock = ReentrantLock::new(0);
    /// *lock.get_mut() = 10;
    /// assert_eq!(*lock.lock(), 10);
    /// ```
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Attempts to acquire this lock.
    ///
    /// If the lock could not be acquired at this time, then `None` is returned.
    /// Otherwise, an RAII guard is returned.
    ///
    /// This function does not block.
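    ///
    /// # Examples
    ///
    /// A usage sketch (marked `ignore` since this method is internal and
    /// hidden from the public API):
    ///
    /// ```ignore (internal API)
    /// use std::sync::ReentrantLock;
    ///
    /// let lock = ReentrantLock::new(0);
    /// let guard = lock.try_lock().expect("no other thread holds the lock");
    /// // Re-entrant: acquiring again on the same thread also succeeds.
    /// let guard2 = lock.try_lock().unwrap();
    /// drop((guard2, guard));
    /// ```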
    // FIXME maybe make it a public part of the API?
    #[unstable(issue = "none", feature = "std_internals")]
    #[doc(hidden)]
    pub fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
        let this_thread = current_id();
        // Safety: We only touch lock_count when we own the inner mutex.
        // Additionally, we only call `self.owner.set()` while holding
        // the inner mutex, so no two threads can call it concurrently.
        unsafe {
            if self.owner.contains(this_thread) {
                self.increment_lock_count()?;
                Some(ReentrantLockGuard { lock: self })
            } else if self.mutex.try_lock() {
                self.owner.set(Some(this_thread));
                debug_assert_eq!(*self.lock_count.get(), 0);
                *self.lock_count.get() = 1;
                Some(ReentrantLockGuard { lock: self })
            } else {
                None
            }
        }
    }

    /// Returns a raw pointer to the underlying data.
    ///
    /// The returned pointer is always non-null and properly aligned, but it is
    /// the user's responsibility to ensure that any reads through it are
    /// properly synchronized to avoid data races, and that it is not read
    /// through after the lock is dropped.
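    ///
    /// # Examples
    ///
    /// A sketch of a read through the raw pointer. This is sound here only
    /// because nothing mutates the data while the pointer is read:
    ///
    /// ```
    /// #![feature(reentrant_lock, reentrant_lock_data_ptr)]
    /// use std::sync::ReentrantLock;
    ///
    /// let lock = ReentrantLock::new(7);
    /// let ptr = lock.data_ptr();
    /// // SAFETY: the data is never mutated (no `get_mut`, no interior
    /// // mutability), so this read cannot race with a write.
    /// assert_eq!(unsafe { *ptr }, 7);
    /// ```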
    #[unstable(feature = "reentrant_lock_data_ptr", issue = "140368")]
    pub fn data_ptr(&self) -> *const T {
        &raw const self.data
    }

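    // Safety: the caller must ensure that the current thread owns the inner
    // mutex, as `lock_count` is only protected by that mutex.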
    unsafe fn increment_lock_count(&self) -> Option<()> {
        unsafe {
            *self.lock_count.get() = (*self.lock_count.get()).checked_add(1)?;
        }
        Some(())
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: fmt::Debug + ?Sized> fmt::Debug for ReentrantLock<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("ReentrantLock");
        match self.try_lock() {
            Some(v) => d.field("data", &&*v),
            None => d.field("data", &format_args!("<locked>")),
        };
        d.finish_non_exhaustive()
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: Default> Default for ReentrantLock<T> {
    fn default() -> Self {
        Self::new(T::default())
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T> From<T> for ReentrantLock<T> {
    fn from(t: T) -> Self {
        Self::new(t)
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: ?Sized> Deref for ReentrantLockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.lock.data
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: fmt::Debug + ?Sized> fmt::Debug for ReentrantLockGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: fmt::Display + ?Sized> fmt::Display for ReentrantLockGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

#[unstable(feature = "reentrant_lock", issue = "121440")]
impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        // Safety: We own the lock.
        unsafe {
            *self.lock.lock_count.get() -= 1;
            if *self.lock.lock_count.get() == 0 {
                self.lock.owner.set(None);
                self.lock.mutex.unlock();
            }
        }
    }
}