std/sync/reentrant_lock.rs

1use cfg_if::cfg_if;
2
3use crate::cell::UnsafeCell;
4use crate::fmt;
5use crate::ops::Deref;
6use crate::panic::{RefUnwindSafe, UnwindSafe};
7use crate::sys::sync as sys;
8use crate::thread::{ThreadId, current_id};
9
10/// A re-entrant mutual exclusion lock
11///
12/// This lock will block *other* threads waiting for the lock to become
13/// available. The thread which has already locked the mutex can lock it
14/// multiple times without blocking, preventing a common source of deadlocks.
15///
16/// # Examples
17///
18/// Allow recursively calling a function needing synchronization from within
19/// a callback (this is how [`StdoutLock`](crate::io::StdoutLock) is currently
20/// implemented):
21///
22/// ```
23/// #![feature(reentrant_lock)]
24///
25/// use std::cell::RefCell;
26/// use std::sync::ReentrantLock;
27///
28/// pub struct Log {
29///     data: RefCell<String>,
30/// }
31///
32/// impl Log {
33///     pub fn append(&self, msg: &str) {
34///         self.data.borrow_mut().push_str(msg);
35///     }
36/// }
37///
38/// static LOG: ReentrantLock<Log> = ReentrantLock::new(Log { data: RefCell::new(String::new()) });
39///
40/// pub fn with_log<R>(f: impl FnOnce(&Log) -> R) -> R {
41///     let log = LOG.lock();
42///     f(&*log)
43/// }
44///
45/// with_log(|log| {
46///     log.append("Hello");
47///     with_log(|log| log.append(" there!"));
48/// });
49/// ```
50///
51// # Implementation details
52//
53// The 'owner' field tracks which thread has locked the mutex.
54//
55// We use thread::current_id() as the thread identifier, which is just the
56// current thread's ThreadId, so it's unique across the process lifetime.
57//
58// If `owner` is set to the identifier of the current thread,
59// we assume the mutex is already locked and instead of locking it again,
60// we increment `lock_count`.
61//
62// When unlocking, we decrement `lock_count`, and only unlock the mutex when
63// it reaches zero.
64//
65// `lock_count` is protected by the mutex and only accessed by the thread that has
66// locked the mutex, so needs no synchronization.
67//
68// `owner` can be checked by other threads that want to see if they already
69// hold the lock, so needs to be atomic. If it compares equal, we're on the
70// same thread that holds the mutex and memory access can use relaxed ordering
71// since we're not dealing with multiple threads. If it's not equal,
72// synchronization is left to the mutex, making relaxed memory ordering for
73// the `owner` field fine in all cases.
74//
75// On systems without 64 bit atomics we also store the address of a TLS variable
76// along the 64-bit TID. We then first check that address against the address
77// of that variable on the current thread, and only if they compare equal do we
78// compare the actual TIDs. Because we only ever read the TID on the same thread
79// that it was written on (or a thread sharing the TLS block with that writer thread),
80// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
81// non-atomic accesses.
82#[unstable(feature = "reentrant_lock", issue = "121440")]
83pub struct ReentrantLock<T: ?Sized> {
84    mutex: sys::Mutex,
85    owner: Tid,
86    lock_count: UnsafeCell<u32>,
87    data: T,
88}
89
90cfg_if!(
91    if #[cfg(target_has_atomic = "64")] {
92        use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};
93
94        struct Tid(AtomicU64);
95
96        impl Tid {
97            const fn new() -> Self {
98                Self(AtomicU64::new(0))
99            }
100
101            #[inline]
102            fn contains(&self, owner: ThreadId) -> bool {
103                owner.as_u64().get() == self.0.load(Relaxed)
104            }
105
106            #[inline]
107            // This is just unsafe to match the API of the Tid type below.
108            unsafe fn set(&self, tid: Option<ThreadId>) {
109                let value = tid.map_or(0, |tid| tid.as_u64().get());
110                self.0.store(value, Relaxed);
111            }
112        }
113    } else {
114        /// Returns the address of a TLS variable. This is guaranteed to
115        /// be unique across all currently alive threads.
116        fn tls_addr() -> usize {
117            thread_local! { static X: u8 = const { 0u8 } };
118
119            X.with(|p| <*const u8>::addr(p))
120        }
121
122        use crate::sync::atomic::{
123            AtomicUsize,
124            Ordering,
125        };
126
127        struct Tid {
128            // When a thread calls `set()`, this value gets updated to
129            // the address of a thread local on that thread. This is
130            // used as a first check in `contains()`; if the `tls_addr`
131            // doesn't match the TLS address of the current thread, then
132            // the ThreadId also can't match. Only if the TLS addresses do
133            // match do we read out the actual TID.
134            // Note also that we can use relaxed atomic operations here, because
135            // we only ever read from the tid if `tls_addr` matches the current
136            // TLS address. In that case, either the tid has been set by
137            // the current thread, or by a thread that has terminated before
138            // the current thread was created. In either case, no further
139            // synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
140            tls_addr: AtomicUsize,
141            tid: UnsafeCell<u64>,
142        }
143
144        unsafe impl Send for Tid {}
145        unsafe impl Sync for Tid {}
146
147        impl Tid {
148            const fn new() -> Self {
149                Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
150            }
151
152            #[inline]
153            // NOTE: This assumes that `owner` is the ID of the current
154            // thread, and may spuriously return `false` if that's not the case.
155            fn contains(&self, owner: ThreadId) -> bool {
156                // SAFETY: See the comments in the struct definition.
157                self.tls_addr.load(Ordering::Relaxed) == tls_addr()
158                    && unsafe { *self.tid.get() } == owner.as_u64().get()
159            }
160
161            #[inline]
162            // This may only be called by one thread at a time, and can lead to
163            // race conditions otherwise.
164            unsafe fn set(&self, tid: Option<ThreadId>) {
165                // It's important that we set `self.tls_addr` to 0 if the tid is
166                // cleared. Otherwise, there might be race conditions between
167                // `set()` and `get()`.
168                let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
169                let value = tid.map_or(0, |tid| tid.as_u64().get());
170                self.tls_addr.store(tls_addr, Ordering::Relaxed);
171                unsafe { *self.tid.get() = value };
172            }
173        }
174    }
175);
176
177#[unstable(feature = "reentrant_lock", issue = "121440")]
178unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
179#[unstable(feature = "reentrant_lock", issue = "121440")]
180unsafe impl<T: Send + ?Sized> Sync for ReentrantLock<T> {}
181
182// Because of the `UnsafeCell`, these traits are not implemented automatically
183#[unstable(feature = "reentrant_lock", issue = "121440")]
184impl<T: UnwindSafe + ?Sized> UnwindSafe for ReentrantLock<T> {}
185#[unstable(feature = "reentrant_lock", issue = "121440")]
186impl<T: RefUnwindSafe + ?Sized> RefUnwindSafe for ReentrantLock<T> {}
187
188/// An RAII implementation of a "scoped lock" of a re-entrant lock. When this
189/// structure is dropped (falls out of scope), the lock will be unlocked.
190///
191/// The data protected by the mutex can be accessed through this guard via its
192/// [`Deref`] implementation.
193///
194/// This structure is created by the [`lock`](ReentrantLock::lock) method on
195/// [`ReentrantLock`].
196///
197/// # Mutability
198///
199/// Unlike [`MutexGuard`](super::MutexGuard), `ReentrantLockGuard` does not
200/// implement [`DerefMut`](crate::ops::DerefMut), because implementation of
201/// the trait would violate Rust’s reference aliasing rules. Use interior
202/// mutability (usually [`RefCell`](crate::cell::RefCell)) in order to mutate
203/// the guarded data.
204#[must_use = "if unused the ReentrantLock will immediately unlock"]
205#[unstable(feature = "reentrant_lock", issue = "121440")]
206pub struct ReentrantLockGuard<'a, T: ?Sized + 'a> {
207    lock: &'a ReentrantLock<T>,
208}
209
210#[unstable(feature = "reentrant_lock", issue = "121440")]
211impl<T: ?Sized> !Send for ReentrantLockGuard<'_, T> {}
212
213#[unstable(feature = "reentrant_lock", issue = "121440")]
214unsafe impl<T: ?Sized + Sync> Sync for ReentrantLockGuard<'_, T> {}
215
216#[unstable(feature = "reentrant_lock", issue = "121440")]
217impl<T> ReentrantLock<T> {
218    /// Creates a new re-entrant lock in an unlocked state ready for use.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// #![feature(reentrant_lock)]
224    /// use std::sync::ReentrantLock;
225    ///
226    /// let lock = ReentrantLock::new(0);
227    /// ```
228    pub const fn new(t: T) -> ReentrantLock<T> {
229        ReentrantLock {
230            mutex: sys::Mutex::new(),
231            owner: Tid::new(),
232            lock_count: UnsafeCell::new(0),
233            data: t,
234        }
235    }
236
237    /// Consumes this lock, returning the underlying data.
238    ///
239    /// # Examples
240    ///
241    /// ```
242    /// #![feature(reentrant_lock)]
243    ///
244    /// use std::sync::ReentrantLock;
245    ///
246    /// let lock = ReentrantLock::new(0);
247    /// assert_eq!(lock.into_inner(), 0);
248    /// ```
249    pub fn into_inner(self) -> T {
250        self.data
251    }
252}
253
254#[unstable(feature = "reentrant_lock", issue = "121440")]
255impl<T: ?Sized> ReentrantLock<T> {
256    /// Acquires the lock, blocking the current thread until it is able to do
257    /// so.
258    ///
259    /// This function will block the caller until it is available to acquire
260    /// the lock. Upon returning, the thread is the only thread with the lock
261    /// held. When the thread calling this method already holds the lock, the
262    /// call succeeds without blocking.
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// #![feature(reentrant_lock)]
268    /// use std::cell::Cell;
269    /// use std::sync::{Arc, ReentrantLock};
270    /// use std::thread;
271    ///
272    /// let lock = Arc::new(ReentrantLock::new(Cell::new(0)));
273    /// let c_lock = Arc::clone(&lock);
274    ///
275    /// thread::spawn(move || {
276    ///     c_lock.lock().set(10);
277    /// }).join().expect("thread::spawn failed");
278    /// assert_eq!(lock.lock().get(), 10);
279    /// ```
280    pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
281        let this_thread = current_id();
282        // Safety: We only touch lock_count when we own the inner mutex.
283        // Additionally, we only call `self.owner.set()` while holding
284        // the inner mutex, so no two threads can call it concurrently.
285        unsafe {
286            if self.owner.contains(this_thread) {
287                self.increment_lock_count().expect("lock count overflow in reentrant mutex");
288            } else {
289                self.mutex.lock();
290                self.owner.set(Some(this_thread));
291                debug_assert_eq!(*self.lock_count.get(), 0);
292                *self.lock_count.get() = 1;
293            }
294        }
295        ReentrantLockGuard { lock: self }
296    }
297
298    /// Returns a mutable reference to the underlying data.
299    ///
300    /// Since this call borrows the `ReentrantLock` mutably, no actual locking
301    /// needs to take place -- the mutable borrow statically guarantees no locks
302    /// exist.
303    ///
304    /// # Examples
305    ///
306    /// ```
307    /// #![feature(reentrant_lock)]
308    /// use std::sync::ReentrantLock;
309    ///
310    /// let mut lock = ReentrantLock::new(0);
311    /// *lock.get_mut() = 10;
312    /// assert_eq!(*lock.lock(), 10);
313    /// ```
314    pub fn get_mut(&mut self) -> &mut T {
315        &mut self.data
316    }
317
318    /// Attempts to acquire this lock.
319    ///
320    /// If the lock could not be acquired at this time, then `None` is returned.
321    /// Otherwise, an RAII guard is returned.
322    ///
323    /// This function does not block.
324    // FIXME maybe make it a public part of the API?
325    #[unstable(issue = "none", feature = "std_internals")]
326    #[doc(hidden)]
327    pub fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
328        let this_thread = current_id();
329        // Safety: We only touch lock_count when we own the inner mutex.
330        // Additionally, we only call `self.owner.set()` while holding
331        // the inner mutex, so no two threads can call it concurrently.
332        unsafe {
333            if self.owner.contains(this_thread) {
334                self.increment_lock_count()?;
335                Some(ReentrantLockGuard { lock: self })
336            } else if self.mutex.try_lock() {
337                self.owner.set(Some(this_thread));
338                debug_assert_eq!(*self.lock_count.get(), 0);
339                *self.lock_count.get() = 1;
340                Some(ReentrantLockGuard { lock: self })
341            } else {
342                None
343            }
344        }
345    }
346
347    unsafe fn increment_lock_count(&self) -> Option<()> {
348        unsafe {
349            *self.lock_count.get() = (*self.lock_count.get()).checked_add(1)?;
350        }
351        Some(())
352    }
353}
354
355#[unstable(feature = "reentrant_lock", issue = "121440")]
356impl<T: fmt::Debug + ?Sized> fmt::Debug for ReentrantLock<T> {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        let mut d = f.debug_struct("ReentrantLock");
359        match self.try_lock() {
360            Some(v) => d.field("data", &&*v),
361            None => d.field("data", &format_args!("<locked>")),
362        };
363        d.finish_non_exhaustive()
364    }
365}
366
367#[unstable(feature = "reentrant_lock", issue = "121440")]
368impl<T: Default> Default for ReentrantLock<T> {
369    fn default() -> Self {
370        Self::new(T::default())
371    }
372}
373
374#[unstable(feature = "reentrant_lock", issue = "121440")]
375impl<T> From<T> for ReentrantLock<T> {
376    fn from(t: T) -> Self {
377        Self::new(t)
378    }
379}
380
381#[unstable(feature = "reentrant_lock", issue = "121440")]
382impl<T: ?Sized> Deref for ReentrantLockGuard<'_, T> {
383    type Target = T;
384
385    fn deref(&self) -> &T {
386        &self.lock.data
387    }
388}
389
390#[unstable(feature = "reentrant_lock", issue = "121440")]
391impl<T: fmt::Debug + ?Sized> fmt::Debug for ReentrantLockGuard<'_, T> {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        (**self).fmt(f)
394    }
395}
396
397#[unstable(feature = "reentrant_lock", issue = "121440")]
398impl<T: fmt::Display + ?Sized> fmt::Display for ReentrantLockGuard<'_, T> {
399    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400        (**self).fmt(f)
401    }
402}
403
404#[unstable(feature = "reentrant_lock", issue = "121440")]
405impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
406    #[inline]
407    fn drop(&mut self) {
408        // Safety: We own the lock.
409        unsafe {
410            *self.lock.lock_count.get() -= 1;
411            if *self.lock.lock_count.get() == 0 {
412                self.lock.owner.set(None);
413                self.lock.mutex.unlock();
414            }
415        }
416    }
417}