Skip to main content

cargo/util/
local_poll_adapter.rs

1use futures::{FutureExt, future::LocalBoxFuture, stream::FuturesUnordered};
2use std::{collections::HashMap, hash::Hash, ops::Deref, task::Poll};
3
4/// A local (!Send) adapter for caching and executing an async method
5/// from a non-async context.
6///
7/// The `self_parameter`, `key`, and successful (Ok) results must all be cheap to `clone`.
8///
9/// Ensures at most one in-flight computation per key. Results are:
10/// - cached on success
11/// - not retained on error
12pub struct LocalPollAdapter<'a, S, K, R> {
13    pool: FuturesUnordered<LocalBoxFuture<'a, (K, R)>>,
14    cache: HashMap<K, Poll<R>>,
15    self_parameter: S,
16}
17
18impl<'a, S, K, V, E> LocalPollAdapter<'a, S, K, Result<V, E>>
19where
20    S: Clone + Deref + 'a,
21    K: Clone + Hash + Eq + 'a,
22    V: Clone,
23{
24    pub fn new(self_parameter: S) -> Self {
25        Self {
26            pool: FuturesUnordered::new(),
27            cache: HashMap::new(),
28            self_parameter,
29        }
30    }
31
32    /// Polls the result for `key`, spawning work if needed.
33    ///
34    /// If this function returns [`Poll::Pending`], call [`LocalPollAdapter::wait`]
35    /// to execute the work, then call this function again with the same key
36    /// to pick up the result.
37    ///
38    /// Futures that complete immediately are not queued.
39    pub fn poll<F>(&mut self, f: F, key: K) -> Poll<Result<V, E>>
40    where
41        F: AsyncFn(&S::Target, &K) -> Result<V, E> + 'a,
42    {
43        match self.cache.get(&key) {
44            // We have a cached success value, clone it and return.
45            Some(Poll::Ready(Ok(v))) => return Poll::Ready(Ok(v.clone())),
46            // We have a cached error value, remove it and return.
47            // Errors are not Clone, so they are only stored once.
48            Some(Poll::Ready(Err(_))) => return self.cache.remove(&key).unwrap(),
49            // This key is already pending.
50            Some(Poll::Pending) => return Poll::Pending,
51            // Looks like we have work to do!
52            None => {}
53        }
54
55        // Created a pinned future that executes the function,
56        // returning the key and the result.
57        let mut future = {
58            let key = key.clone();
59            let self_parameter = self.self_parameter.clone();
60            async move {
61                let v = f(self_parameter.deref(), &key).await;
62                (key, v)
63            }
64            .boxed_local()
65        };
66
67        // Attempt to run the future immediately. If it has no `await` yields,
68        // it will return here.
69        if let Some((k, v)) = (&mut future).now_or_never() {
70            if let Ok(success) = &v {
71                // Only cache successful results.
72                self.cache.insert(k, Poll::Ready(Ok(success.clone())));
73            }
74            return Poll::Ready(v);
75        }
76
77        // Insert Pending into the cache so we avoid queuing the same future twice.
78        self.cache.insert(key.clone(), Poll::Pending);
79
80        // Add the future to the pending queue.
81        self.pool.push(future);
82        Poll::Pending
83    }
84
85    /// Returns the number of pending futures.
86    pub fn pending_count(&self) -> usize {
87        self.pool.len()
88    }
89
90    /// Run all pending futures. Returns true if there was no work to do.
91    pub fn wait(&mut self) -> bool {
92        let is_empty = self.pool.is_empty();
93        for (k, v) in crate::util::block_on_stream(&mut self.pool) {
94            *self
95                .cache
96                .get_mut(&k)
97                .expect("all pending work is in the cache") = Poll::Ready(v);
98        }
99        is_empty
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::LocalPollAdapter;
    use std::{rc::Rc, task::Poll, time::Duration};

    /// Stand-in type whose async method is driven through the adapter.
    struct Thing {}

    impl Thing {
        /// Widens `i` to `i64`, failing for odd inputs.
        ///
        /// Takes `&i32` rather than `i32` because the adapter hands the key
        /// to the async function by reference.
        async fn widen(&self, i: &i32) -> Result<i64, ()> {
            if *i > 10 {
                // Big numbers take longer to process (need to test futures that yield).
                futures_timer::Delay::new(Duration::from_millis(1)).await;
            }
            if *i % 2 != 0 {
                // Odd numbers are not supported (need to test errors).
                return Err(());
            }
            Ok(i64::from(*i))
        }
    }

    /// Poll wrapper around `Thing`
    struct PolledThing<'a> {
        poller: LocalPollAdapter<'a, Rc<Thing>, i32, Result<i64, ()>>,
    }

    impl<'a> PolledThing<'a> {
        fn new() -> Self {
            Self {
                poller: LocalPollAdapter::new(Rc::new(Thing {})),
            }
        }

        // Non-async version of the widen method.
        fn widen(&mut self, i: &i32) -> Poll<Result<i64, ()>> {
            // `i32` is `Copy`, so dereference instead of calling `clone`.
            self.poller.poll(Thing::widen, *i)
        }

        fn wait(&mut self) -> bool {
            self.poller.wait()
        }
    }

    #[test]
    fn immediate_success() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&2), Poll::Ready(Ok(2)));
        assert!(p.wait());
    }

    #[test]
    fn immediate_error() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&1), Poll::Ready(Err(())));
        assert!(p.wait());
    }

    #[test]
    fn deferred_error() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&1001), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&1001), Poll::Ready(Err(())));
        assert!(p.wait());
        // Errors are not cached
        assert_eq!(p.widen(&1001), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&1001), Poll::Ready(Err(())));
        assert!(p.wait());
    }

    #[test]
    fn deferred_success() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&50), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
        assert!(p.wait());
        // Success is cached.
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
        assert!(p.wait());
    }

    #[test]
    fn pending_key_is_queued_once() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&50), Poll::Pending);
        // Polling the same key again must not queue a second future.
        assert_eq!(p.widen(&50), Poll::Pending);
        assert_eq!(p.poller.pending_count(), 1);
        assert!(!p.wait());
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
    }
}