// std/sync/barrier.rs

1use crate::fmt;
2// FIXME(nonpoison_mutex,nonpoison_condvar): switch to nonpoison versions once they are available
3use crate::sync::{Condvar, Mutex};
4
5/// A barrier enables multiple threads to synchronize the beginning
6/// of some computation.
7///
8/// # Examples
9///
10/// ```
11/// use std::sync::Barrier;
12/// use std::thread;
13///
14/// let n = 10;
15/// let barrier = Barrier::new(n);
16/// thread::scope(|s| {
17///     for _ in 0..n {
18///         // The same messages will be printed together.
19///         // You will NOT see any interleaving.
20///         s.spawn(|| {
21///             println!("before wait");
22///             barrier.wait();
23///             println!("after wait");
24///         });
25///     }
26/// });
27/// ```
28#[stable(feature = "rust1", since = "1.0.0")]
29pub struct Barrier {
30    lock: Mutex<BarrierState>,
31    cvar: Condvar,
32    num_threads: usize,
33}
34
/// Mutable state of a `Barrier`, always accessed through the barrier's mutex.
///
/// The barrier is reusable. Rather than waiting for `count` to change,
/// blocked threads wait for `generation_id` to move past the value they
/// observed when they arrived. That stays correct even if the next round
/// has already started counting by the time a waiter wakes up.
struct BarrierState {
    /// Number of threads that have called `wait` in the current round.
    /// Reset to 0 by the thread that completes the round.
    count: usize,
    /// Identifier of the current round. Incremented with wrapping arithmetic
    /// each time a round completes; waiters only need it to differ from the
    /// value they saw on arrival.
    generation_id: usize,
}
40
41/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
42/// in the [`Barrier`] have rendezvoused.
43///
44/// # Examples
45///
46/// ```
47/// use std::sync::Barrier;
48///
49/// let barrier = Barrier::new(1);
50/// let barrier_wait_result = barrier.wait();
51/// ```
52#[stable(feature = "rust1", since = "1.0.0")]
53pub struct BarrierWaitResult(bool);
54
55#[stable(feature = "std_debug", since = "1.16.0")]
56impl fmt::Debug for Barrier {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        f.debug_struct("Barrier").finish_non_exhaustive()
59    }
60}
61
62impl Barrier {
63    /// Creates a new barrier that can block a given number of threads.
64    ///
65    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
66    /// up all threads at once when the `n`th thread calls [`wait()`].
67    ///
68    /// [`wait()`]: Barrier::wait
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use std::sync::Barrier;
74    ///
75    /// let barrier = Barrier::new(10);
76    /// ```
77    #[stable(feature = "rust1", since = "1.0.0")]
78    #[rustc_const_stable(feature = "const_barrier", since = "1.78.0")]
79    #[must_use]
80    #[inline]
81    pub const fn new(n: usize) -> Barrier {
82        Barrier {
83            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
84            cvar: Condvar::new(),
85            num_threads: n,
86        }
87    }
88
89    /// Blocks the current thread until all threads have rendezvoused here.
90    ///
91    /// Barriers are re-usable after all threads have rendezvoused once, and can
92    /// be used continuously.
93    ///
94    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
95    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
96    /// from this function, and all other threads will receive a result that
97    /// will return `false` from [`BarrierWaitResult::is_leader()`].
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// use std::sync::Barrier;
103    /// use std::thread;
104    ///
105    /// let n = 10;
106    /// let barrier = Barrier::new(n);
107    /// thread::scope(|s| {
108    ///     for _ in 0..n {
109    ///         // The same messages will be printed together.
110    ///         // You will NOT see any interleaving.
111    ///         s.spawn(|| {
112    ///             println!("before wait");
113    ///             barrier.wait();
114    ///             println!("after wait");
115    ///         });
116    ///     }
117    /// });
118    /// ```
119    #[stable(feature = "rust1", since = "1.0.0")]
120    pub fn wait(&self) -> BarrierWaitResult {
121        let mut lock = self.lock.lock().unwrap();
122        let local_gen = lock.generation_id;
123        lock.count += 1;
124        if lock.count < self.num_threads {
125            let _guard =
126                self.cvar.wait_while(lock, |state| local_gen == state.generation_id).unwrap();
127            BarrierWaitResult(false)
128        } else {
129            lock.count = 0;
130            lock.generation_id = lock.generation_id.wrapping_add(1);
131            self.cvar.notify_all();
132            BarrierWaitResult(true)
133        }
134    }
135}
136
137#[stable(feature = "std_debug", since = "1.16.0")]
138impl fmt::Debug for BarrierWaitResult {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
141    }
142}
143
144impl BarrierWaitResult {
145    /// Returns `true` if this thread is the "leader thread" for the call to
146    /// [`Barrier::wait()`].
147    ///
148    /// Only one thread will have `true` returned from their result, all other
149    /// threads will have `false` returned.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// use std::sync::Barrier;
155    ///
156    /// let barrier = Barrier::new(1);
157    /// let barrier_wait_result = barrier.wait();
158    /// println!("{:?}", barrier_wait_result.is_leader());
159    /// ```
160    #[stable(feature = "rust1", since = "1.0.0")]
161    #[must_use]
162    pub fn is_leader(&self) -> bool {
163        self.0
164    }
165}