// std/sync/mpmc/array.rs

//! Bounded channel based on a preallocated array.
//!
//! This flavor has a fixed, positive capacity.
//!
//! The implementation is based on Dmitry Vyukov's bounded MPMC queue.
//!
//! Sources:
//!   - <http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue>
//!   - <https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub>

11use super::context::Context;
12use super::error::*;
13use super::select::{Operation, Selected, Token};
14use super::utils::{Backoff, CachePadded};
15use super::waker::SyncWaker;
16use crate::cell::UnsafeCell;
17use crate::mem::MaybeUninit;
18use crate::ptr;
19use crate::sync::atomic::{self, AtomicUsize, Ordering};
20use crate::time::Instant;
21
/// One cell of the channel's ring buffer.
struct Slot<T> {
    /// Sequence stamp of this cell.
    ///
    /// Senders compare it with the tail and receivers compare it with the head to decide
    /// whether the cell is free, holds a message, or is still being updated by another thread.
    stamp: AtomicUsize,

    /// Storage for one message. It holds a value only between a successful `write` and the
    /// matching `read`; any message still stored when the last receiver goes away is dropped
    /// by `discard_all_messages`.
    msg: UnsafeCell<MaybeUninit<T>>,
}

/// Per-operation state for the array flavor, handed from `start_send`/`start_recv` to the
/// matching `write`/`read`.
#[derive(Debug)]
pub(crate) struct ArrayToken {
    /// Type-erased pointer to the reserved `Slot<T>`; null when the channel was found
    /// disconnected (or nothing was reserved yet).
    slot: *const u8,

    /// Stamp to publish into the slot once the read or write has completed.
    stamp: usize,
}

42impl Default for ArrayToken {
43    #[inline]
44    fn default() -> Self {
45        ArrayToken { slot: ptr::null(), stamp: 0 }
46    }
47}
48
49/// Bounded channel based on a preallocated array.
50pub(crate) struct Channel<T> {
51    /// The head of the channel.
52    ///
53    /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
54    /// packed into a single `usize`. The lower bits represent the index, while the upper bits
55    /// represent the lap. The mark bit in the head is always zero.
56    ///
57    /// Messages are popped from the head of the channel.
58    head: CachePadded<AtomicUsize>,
59
60    /// The tail of the channel.
61    ///
62    /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
63    /// packed into a single `usize`. The lower bits represent the index, while the upper bits
64    /// represent the lap. The mark bit indicates that the channel is disconnected.
65    ///
66    /// Messages are pushed into the tail of the channel.
67    tail: CachePadded<AtomicUsize>,
68
69    /// The buffer holding slots.
70    buffer: Box<[Slot<T>]>,
71
72    /// The channel capacity.
73    cap: usize,
74
75    /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`.
76    one_lap: usize,
77
78    /// If this bit is set in the tail, that means the channel is disconnected.
79    mark_bit: usize,
80
81    /// Senders waiting while the channel is full.
82    senders: SyncWaker,
83
84    /// Receivers waiting while the channel is empty and not disconnected.
85    receivers: SyncWaker,
86}
87
impl<T> Channel<T> {
    /// Creates a bounded channel of capacity `cap`.
    ///
    /// The buffer is allocated up front with `cap` slots. Slot `i` starts with stamp `i`
    /// (lap 0, index `i`), which is exactly the value `start_send` expects to find when the
    /// tail first reaches that slot.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero.
    pub(crate) fn with_capacity(cap: usize) -> Self {
        assert!(cap > 0, "capacity must be positive");

        // Compute constants `mark_bit` and `one_lap`.
        //
        // `mark_bit` is the smallest power of two strictly greater than `cap`, so every index in
        // `0..cap` fits in the bits below it (`index = stamp & (mark_bit - 1)`). `one_lap` is the
        // next bit up; everything from that bit upward counts laps.
        let mark_bit = (cap + 1).next_power_of_two();
        let one_lap = mark_bit * 2;

        // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`.
        let head = 0;
        // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`.
        let tail = 0;

        // Allocate `cap` slots. Each message starts uninitialized and each stamp starts at
        // its own index.
        let buffer: Box<[Slot<T>]> = (0..cap)
            .map(|i| {
                // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
                Slot { stamp: AtomicUsize::new(i), msg: UnsafeCell::new(MaybeUninit::uninit()) }
            })
            .collect();

        Channel {
            buffer,
            cap,
            one_lap,
            mark_bit,
            head: CachePadded::new(AtomicUsize::new(head)),
            tail: CachePadded::new(AtomicUsize::new(tail)),
            senders: SyncWaker::new(),
            receivers: SyncWaker::new(),
        }
    }

    /// Attempts to reserve a slot for sending a message.
    ///
    /// Returns `true` when the follow-up `write` call may proceed, and fills in `token.array`:
    /// - if a slot was reserved, `token.array.slot` points to it and `token.array.stamp` is the
    ///   stamp `write` must publish (`tail + 1`);
    /// - if the channel is disconnected, `token.array.slot` is null, so `write` returns the
    ///   message back as an error.
    ///
    /// Returns `false` if the channel is full.
    fn start_send(&self, token: &mut Token) -> bool {
        let backoff = Backoff::new();
        let mut tail = self.tail.load(Ordering::Relaxed);

        loop {
            // Check if the channel is disconnected.
            if tail & self.mark_bit != 0 {
                token.array.slot = ptr::null();
                token.array.stamp = 0;
                return true;
            }

            // Deconstruct the tail.
            let index = tail & (self.mark_bit - 1);
            let lap = tail & !(self.one_lap - 1);

            // Inspect the corresponding slot.
            //
            // SAFETY (for `get_unchecked`): the index bits of `tail` start at 0 and are only
            // ever advanced to `index + 1` when that is `< self.cap`; otherwise they wrap to 0.
            // Therefore `index < self.cap == self.buffer.len()`.
            debug_assert!(index < self.buffer.len());
            let slot = unsafe { self.buffer.get_unchecked(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the tail and the stamp match, we may attempt to push.
            if tail == stamp {
                let new_tail = if index + 1 < self.cap {
                    // Same lap, incremented index.
                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
                    tail + 1
                } else {
                    // One lap forward, index wraps around to zero.
                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
                    lap.wrapping_add(self.one_lap)
                };

                // Try moving the tail. Winning this CAS gives this sender exclusive ownership
                // of `slot` until `write` publishes the new stamp.
                match self.tail.compare_exchange_weak(
                    tail,
                    new_tail,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Prepare the token for the follow-up call to `write`.
                        token.array.slot = slot as *const Slot<T> as *const u8;
                        token.array.stamp = tail + 1;
                        return true;
                    }
                    Err(_) => {
                        // The tail moved (another sender, or a disconnect setting the mark bit),
                        // or the weak CAS failed spuriously: reload and retry.
                        backoff.spin_light();
                        tail = self.tail.load(Ordering::Relaxed);
                    }
                }
            } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
                // The slot still carries the stamp `write` stored one lap ago
                // (`tail + 1 - one_lap`). No receiver has consumed that message yet, because
                // `read` would have set the stamp to `tail`. The channel may be full, so check
                // the head.
                //
                // NOTE(review): the SeqCst fence presumably orders this check against the SeqCst
                // CAS on `head` in `start_recv` (as in Vyukov's queue). Confirm before weakening.
                atomic::fence(Ordering::SeqCst);
                let head = self.head.load(Ordering::Relaxed);

                // If the head lags one lap behind the tail as well...
                if head.wrapping_add(self.one_lap) == tail {
                    // ...then the channel is full.
                    return false;
                }

                backoff.spin_light();
                tail = self.tail.load(Ordering::Relaxed);
            } else {
                // The stamp matches neither case. Presumably another thread is midway through
                // an operation on this slot, or our `tail` is stale.
                // Snooze because we need to wait for the stamp to get updated.
                backoff.spin_heavy();
                tail = self.tail.load(Ordering::Relaxed);
            }
        }
    }

    /// Writes a message into the channel.
    ///
    /// Returns `Err(msg)`, handing the message back, if `start_send` found the channel
    /// disconnected (null slot in the token).
    ///
    /// # Safety
    ///
    /// `token.array` must have been filled in by a call to `start_send` on this channel that
    /// returned `true`. A non-null `token.array.slot` is dereferenced as a `Slot<T>`, and its
    /// message cell is overwritten without dropping any previous contents.
    pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
        // If there is no slot, the channel is disconnected.
        if token.array.slot.is_null() {
            return Err(msg);
        }

        // Write the message into the slot and update the stamp.
        //
        // SAFETY: per the caller contract, the pointer refers to a slot of `self.buffer` that
        // `start_send` reserved for this sender alone (it won the CAS on `tail`). The `Release`
        // store of the stamp publishes the message to the receiver whose `Acquire` stamp load in
        // `start_recv` observes it.
        unsafe {
            let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
            slot.msg.get().write(MaybeUninit::new(msg));
            slot.stamp.store(token.array.stamp, Ordering::Release);
        }

        // Wake a sleeping receiver.
        self.receivers.notify();
        Ok(())
    }

    /// Attempts to reserve a slot for receiving a message.
    ///
    /// Returns `true` when the follow-up `read` call may proceed, and fills in `token.array`:
    /// - if a message is available, `token.array.slot` points to its slot and
    ///   `token.array.stamp` is the stamp `read` must publish (`head + one_lap`). That is the
    ///   stamp `start_send` expects when the tail reaches this slot on the next lap.
    /// - if the channel is empty and disconnected, `token.array.slot` is null, so `read`
    ///   reports the disconnection.
    ///
    /// Returns `false` if the channel is empty but still connected.
    fn start_recv(&self, token: &mut Token) -> bool {
        let backoff = Backoff::new();
        let mut head = self.head.load(Ordering::Relaxed);

        loop {
            // Deconstruct the head.
            let index = head & (self.mark_bit - 1);
            let lap = head & !(self.one_lap - 1);

            // Inspect the corresponding slot.
            //
            // SAFETY (for `get_unchecked`): the index bits of `head` start at 0 and are only
            // ever advanced to `index + 1` when that is `< self.cap`; otherwise they wrap to 0.
            // Therefore `index < self.cap == self.buffer.len()`.
            debug_assert!(index < self.buffer.len());
            let slot = unsafe { self.buffer.get_unchecked(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the stamp is ahead of the head by 1, we may attempt to pop.
            // (`write` stored `tail + 1` for this lap, and here `tail == head`.)
            if head + 1 == stamp {
                let new = if index + 1 < self.cap {
                    // Same lap, incremented index.
                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
                    head + 1
                } else {
                    // One lap forward, index wraps around to zero.
                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
                    lap.wrapping_add(self.one_lap)
                };

                // Try moving the head. Winning this CAS gives this receiver exclusive ownership
                // of the message in `slot`.
                match self.head.compare_exchange_weak(
                    head,
                    new,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Prepare the token for the follow-up call to `read`.
                        token.array.slot = slot as *const Slot<T> as *const u8;
                        token.array.stamp = head.wrapping_add(self.one_lap);
                        return true;
                    }
                    Err(_) => {
                        // Another receiver moved the head first, or the weak CAS failed
                        // spuriously: reload and retry.
                        backoff.spin_light();
                        head = self.head.load(Ordering::Relaxed);
                    }
                }
            } else if stamp == head {
                // The slot has not been written in this lap yet: its stamp is still the value a
                // sender expects. The channel may be empty, so check the tail.
                //
                // NOTE(review): the SeqCst fence presumably orders this check against the SeqCst
                // CAS on `tail` in `start_send` (as in Vyukov's queue). Confirm before weakening.
                atomic::fence(Ordering::SeqCst);
                let tail = self.tail.load(Ordering::Relaxed);

                // If the tail equals the head, that means the channel is empty.
                if (tail & !self.mark_bit) == head {
                    // If the channel is disconnected...
                    if tail & self.mark_bit != 0 {
                        // ...then receive an error.
                        token.array.slot = ptr::null();
                        token.array.stamp = 0;
                        return true;
                    } else {
                        // Otherwise, the receive operation is not ready.
                        return false;
                    }
                }

                backoff.spin_light();
                head = self.head.load(Ordering::Relaxed);
            } else {
                // The stamp matches neither case. Presumably a sender is midway through writing
                // this slot, or our `head` is stale.
                // Snooze because we need to wait for the stamp to get updated.
                backoff.spin_heavy();
                head = self.head.load(Ordering::Relaxed);
            }
        }
    }

    /// Reads a message from the channel.
    ///
    /// Returns `Err(())` if `start_recv` found the channel empty and disconnected (null slot in
    /// the token).
    ///
    /// # Safety
    ///
    /// `token.array` must have been filled in by a call to `start_recv` on this channel that
    /// returned `true`. A non-null `token.array.slot` is dereferenced as a `Slot<T>`, and its
    /// message is moved out with `assume_init`.
    pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
        if token.array.slot.is_null() {
            // The channel is disconnected.
            return Err(());
        }

        // Read the message from the slot and update the stamp.
        //
        // SAFETY: `start_recv` only reserves a slot whose stamp is `head + 1`, i.e. one that
        // `write` has initialized and published with `Release`. The receiver that won the CAS on
        // `head` is the only one reading it. Storing the new stamp with `Release` hands the now
        // empty slot back to senders.
        let msg = unsafe {
            let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);

            let msg = slot.msg.get().read().assume_init();
            slot.stamp.store(token.array.stamp, Ordering::Release);
            msg
        };

        // Wake a sleeping sender.
        self.senders.notify();
        Ok(msg)
    }

    /// Attempts to send a message into the channel.
    ///
    /// Does not wait for space. Returns `TrySendError::Full(msg)` when `start_send` reports the
    /// channel full, and `TrySendError::Disconnected(msg)` when it is disconnected.
    pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
        let token = &mut Token::default();
        if self.start_send(token) {
            // SAFETY: `start_send` returned `true`, so `token` is prepared for `write`.
            unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
        } else {
            Err(TrySendError::Full(msg))
        }
    }

    /// Sends a message into the channel.
    ///
    /// While the channel is full, the current thread registers as a waiting sender and blocks
    /// until it is woken or `deadline` passes, then retries. It is woken by `read` notifying
    /// `senders`, or by `disconnect_receivers`. With `deadline == None` there is no timeout
    /// check; presumably `wait_until(None)` waits indefinitely.
    ///
    /// Returns `SendTimeoutError::Timeout(msg)` once the deadline has passed, and
    /// `SendTimeoutError::Disconnected(msg)` if the channel is disconnected.
    pub(crate) fn send(
        &self,
        msg: T,
        deadline: Option<Instant>,
    ) -> Result<(), SendTimeoutError<T>> {
        let token = &mut Token::default();
        loop {
            // Try sending a message.
            if self.start_send(token) {
                // SAFETY: `start_send` returned `true`, so `token` is prepared for `write`.
                let res = unsafe { self.write(token, msg) };
                return res.map_err(SendTimeoutError::Disconnected);
            }

            if let Some(d) = deadline {
                if Instant::now() >= d {
                    return Err(SendTimeoutError::Timeout(msg));
                }
            }

            Context::with(|cx| {
                // Prepare for blocking until a receiver wakes us up.
                let oper = Operation::hook(token);
                self.senders.register(oper, cx);

                // Has the channel become ready just now?
                // Re-check after registering. If space appeared or the channel disconnected
                // before registration, abort the wait instead of sleeping.
                if !self.is_full() || self.is_disconnected() {
                    let _ = cx.try_select(Selected::Aborted);
                }

                // Block the current thread.
                // SAFETY: the context belongs to the current thread.
                let sel = unsafe { cx.wait_until(deadline) };

                match sel {
                    Selected::Waiting => unreachable!(),
                    Selected::Aborted | Selected::Disconnected => {
                        self.senders.unregister(oper).unwrap();
                    }
                    // Nothing to unregister here. Presumably the waker removed this entry
                    // itself when it selected the operation.
                    Selected::Operation(_) => {}
                }
            });
        }
    }

    /// Attempts to receive a message without blocking.
    ///
    /// Returns `TryRecvError::Empty` when no message is available and the channel is still
    /// connected. Returns `TryRecvError::Disconnected` when it is empty and disconnected.
    pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
        let token = &mut Token::default();

        if self.start_recv(token) {
            // SAFETY: `start_recv` returned `true`, so `token` is prepared for `read`.
            unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
        } else {
            Err(TryRecvError::Empty)
        }
    }

    /// Receives a message from the channel.
    ///
    /// While the channel is empty and connected, the current thread registers as a waiting
    /// receiver and blocks until it is woken or `deadline` passes, then retries. It is woken by
    /// `write` notifying `receivers`, or by `disconnect_senders`. Messages sent before a
    /// disconnect are still delivered, because `start_recv` only reports disconnection once the
    /// channel is empty.
    ///
    /// Returns `RecvTimeoutError::Timeout` once the deadline has passed, and
    /// `RecvTimeoutError::Disconnected` if the channel is empty and disconnected.
    pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
        let token = &mut Token::default();
        loop {
            // Try receiving a message.
            if self.start_recv(token) {
                // SAFETY: `start_recv` returned `true`, so `token` is prepared for `read`.
                let res = unsafe { self.read(token) };
                return res.map_err(|_| RecvTimeoutError::Disconnected);
            }

            if let Some(d) = deadline {
                if Instant::now() >= d {
                    return Err(RecvTimeoutError::Timeout);
                }
            }

            Context::with(|cx| {
                // Prepare for blocking until a sender wakes us up.
                let oper = Operation::hook(token);
                self.receivers.register(oper, cx);

                // Has the channel become ready just now?
                // Re-check after registering. If a message arrived or the channel disconnected
                // before registration, abort the wait instead of sleeping.
                if !self.is_empty() || self.is_disconnected() {
                    let _ = cx.try_select(Selected::Aborted);
                }

                // Block the current thread.
                // SAFETY: the context belongs to the current thread.
                let sel = unsafe { cx.wait_until(deadline) };

                match sel {
                    Selected::Waiting => unreachable!(),
                    Selected::Aborted | Selected::Disconnected => {
                        self.receivers.unregister(oper).unwrap();
                        // If the channel was disconnected, we still have to check for remaining
                        // messages.
                    }
                    // Nothing to unregister here. Presumably the waker removed this entry
                    // itself when it selected the operation.
                    Selected::Operation(_) => {}
                }
            });
        }
    }

    /// Returns the current number of messages inside the channel.
    ///
    /// The result is computed from separate atomic loads. Concurrent sends and receives may
    /// change it as soon as it is returned.
    pub(crate) fn len(&self) -> usize {
        loop {
            // Load the tail, then load the head.
            let tail = self.tail.load(Ordering::SeqCst);
            let head = self.head.load(Ordering::SeqCst);

            // If the tail didn't change, we've got consistent values to work with.
            if self.tail.load(Ordering::SeqCst) == tail {
                let hix = head & (self.mark_bit - 1);
                let tix = tail & (self.mark_bit - 1);

                // `hix < tix`: messages occupy `hix..tix`.
                // `hix > tix`: the occupied range wraps past the end of the buffer.
                // `hix == tix`: either empty (tail, ignoring the mark bit, equals head) or full
                // (tail is a lap ahead of head).
                return if hix < tix {
                    tix - hix
                } else if hix > tix {
                    self.cap - hix + tix
                } else if (tail & !self.mark_bit) == head {
                    0
                } else {
                    self.cap
                };
            }
        }
    }

    /// Returns the capacity of the channel.
    ///
    /// Always `Some(cap)` for this flavor.
    // Returning an `Option` that is always `Some` is deliberate, presumably so the signature
    // matches the other channel flavors. The `allow` silences clippy about it.
    #[allow(clippy::unnecessary_wraps)] // This is intentional.
    pub(crate) fn capacity(&self) -> Option<usize> {
        Some(self.cap)
    }

    /// Disconnects senders and wakes up all blocked receivers.
    ///
    /// Sets the mark bit in `tail`. Only the call that actually sets it wakes the receivers.
    ///
    /// Returns `true` if this call disconnected the channel.
    pub(crate) fn disconnect_senders(&self) -> bool {
        let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);

        if tail & self.mark_bit == 0 {
            self.receivers.disconnect();
            true
        } else {
            false
        }
    }

    /// Disconnects receivers and wakes up all blocked senders.
    ///
    /// Also drops every message still stored in the channel, even if the senders had already
    /// disconnected.
    ///
    /// Returns `true` if this call disconnected the channel.
    ///
    /// # Safety
    /// May only be called once upon dropping the last receiver. The
    /// destruction of all other receivers must have been observed with acquire
    /// ordering or stronger.
    pub(crate) unsafe fn disconnect_receivers(&self) -> bool {
        let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
        let disconnected = if tail & self.mark_bit == 0 {
            self.senders.disconnect();
            true
        } else {
            false
        };

        // SAFETY: forwarded from this function's own contract (last receiver). `tail` is the
        // value just before the mark bit was set. Since the mark bit is now set, no sender can
        // move the tail any further: `start_send` bails out and any CAS against an unmarked
        // value fails. `discard_all_messages` masks the mark bit off itself.
        unsafe { self.discard_all_messages(tail) };
        disconnected
    }

    /// Discards all messages.
    ///
    /// `tail` should be the current (and therefore last) value of `tail`.
    ///
    /// # Panicking
    /// If a destructor panics, the remaining messages are leaked, matching the
    /// behavior of the unbounded channel.
    ///
    /// # Safety
    /// This method must only be called when dropping the last receiver. The
    /// destruction of all other receivers must have been observed with acquire
    /// ordering or stronger.
    unsafe fn discard_all_messages(&self, tail: usize) {
        debug_assert!(self.is_disconnected());

        // Only receivers modify `head`, so since we are the last one,
        // this value will not change and will not be observed (since
        // no new messages can be sent after disconnection).
        let mut head = self.head.load(Ordering::Relaxed);
        let tail = tail & !self.mark_bit;

        let backoff = Backoff::new();
        loop {
            // Deconstruct the head.
            let index = head & (self.mark_bit - 1);
            let lap = head & !(self.one_lap - 1);

            // Inspect the corresponding slot.
            //
            // SAFETY (for `get_unchecked`): the index bits of `head` are always `< self.cap`.
            // They start at 0, and both here and in `start_recv` they only advance to
            // `index + 1` when that is `< self.cap`, wrapping to 0 otherwise.
            debug_assert!(index < self.buffer.len());
            let slot = unsafe { self.buffer.get_unchecked(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the stamp is ahead of the head by 1, we may drop the message.
            if head + 1 == stamp {
                head = if index + 1 < self.cap {
                    // Same lap, incremented index.
                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
                    head + 1
                } else {
                    // One lap forward, index wraps around to zero.
                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
                    lap.wrapping_add(self.one_lap)
                };

                // SAFETY: a stamp of `head + 1` means `write` initialized this message and
                // published it with `Release` (observed by the `Acquire` load above). As the
                // last receiver, nobody else can read it.
                unsafe {
                    (*slot.msg.get()).assume_init_drop();
                }
            // If the tail equals the head, that means the channel is empty.
            } else if tail == head {
                return;
            // Otherwise, a sender is about to write into the slot, so we need
            // to wait for it to update the stamp.
            } else {
                backoff.spin_heavy();
            }
        }
    }

    /// Returns `true` if the channel is disconnected (the mark bit is set in `tail`).
    pub(crate) fn is_disconnected(&self) -> bool {
        self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
    }

    /// Returns `true` if the channel is empty.
    pub(crate) fn is_empty(&self) -> bool {
        let head = self.head.load(Ordering::SeqCst);
        let tail = self.tail.load(Ordering::SeqCst);

        // Is the tail equal to the head?
        //
        // Note: If the head changes just before we load the tail, that means there was a moment
        // when the channel was not empty, so it is safe to just return `false`.
        (tail & !self.mark_bit) == head
    }

    /// Returns `true` if the channel is full.
    pub(crate) fn is_full(&self) -> bool {
        let tail = self.tail.load(Ordering::SeqCst);
        let head = self.head.load(Ordering::SeqCst);

        // Is the head lagging one lap behind tail?
        //
        // Note: If the tail changes just before we load the head, that means there was a moment
        // when the channel was not full, so it is safe to just return `false`.
        head.wrapping_add(self.one_lap) == tail & !self.mark_bit
    }
}
569}