Skip to main content

cargo/util/network/
http_async.rs

//! Async wrapper around cURL for making and managing HTTP requests.
2//!
3//! Requests are executed in parallel using cURL [`Multi`] on
4//! a worker thread that is owned by the Client.
5
6use std::collections::HashMap;
7use std::io::{Cursor, Read};
8use std::str::FromStr;
9use std::sync::mpsc::{self, Receiver, Sender};
10use std::thread::JoinHandle;
11use std::time::Duration;
12
13use curl::easy::WriteError;
14use curl::easy::{Easy2, Handler, InfoType};
15use curl::multi::{Easy2Handle, Multi};
16
17use crate::util::network::http::HandleConfiguration;
18use futures::channel::oneshot;
19use tracing::{debug, error, trace};
20
/// Response type used by this module: an [`http::Response`] whose body is a byte vector.
type Response = http::Response<Vec<u8>>;
/// Request type used by this module: an [`http::Request`] whose body is a byte vector.
type Request = http::Request<Vec<u8>>;
/// Result alias whose error type is this module's [`Error`].
type HttpResult<T> = std::result::Result<T, Error>;
24
/// Errors produced while preparing or executing an HTTP request.
///
/// The type is `Clone` so that a single failure can be reported to every
/// in-flight request (see `WorkerServer::fail_and_drain`).
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// A failure reported by the cURL multi interface.
    #[error("curl multi failed")]
    Multi(#[from] curl::MultiError),

    /// A failure reported by a cURL easy handle.
    #[error("curl failed")]
    Easy(#[from] curl::Error),

    /// A request header value could not be converted to a string.
    /// `name` is the header name and `bytes` holds the raw offending value.
    #[error("failed to convert header value of `{name}` to string: {bytes:?}")]
    BadHeader { name: String, bytes: Vec<u8> },
}
37
/// A unit of work sent from `Client::request` to the worker thread.
struct Message {
    /// The configured cURL handle for the request.
    easy: Easy2<Collector>,
    /// One-shot channel on which the worker sends back the outcome of the request.
    sender: oneshot::Sender<HttpResult<Response>>,
}
42
/// HTTP Client. Creating a new client spawns a cURL `Multi` and
/// thread that is used for all HTTP requests by this client.
pub struct Client {
    /// Sending half of the work channel. Wrapped in `Option` so that `Drop`
    /// can take it and close the channel before joining the worker.
    channel: Option<Sender<Message>>,
    /// Handle of the worker thread. Wrapped in `Option` so that `Drop` can
    /// take it and join the thread.
    thread_handle: Option<JoinHandle<()>>,
    /// Configuration applied to every cURL handle created by `request`.
    handle_config: HandleConfiguration,
}
50
51impl Client {
52    /// Spawns a new worker thread where HTTP request execute.
53    pub fn new(handle_config: HandleConfiguration) -> Client {
54        let (tx, rx) = mpsc::channel();
55        let handle = std::thread::spawn(move || WorkerServer::run(rx, handle_config.multiplexing));
56        Client {
57            channel: Some(tx),
58            thread_handle: Some(handle),
59            handle_config,
60        }
61    }
62
63    /// Perform an HTTP request using this client.
64    pub async fn request(&self, request: Request) -> HttpResult<Response> {
65        let url = request.uri().to_string();
66        debug!(target: "network::fetch", url);
67        let mut collector = Collector::new();
68        let (parts, body) = request.into_parts();
69        let body_len = body.len();
70        collector.request_body = Cursor::new(body);
71        collector.debug = self.handle_config.verbose;
72        let mut handle = curl::easy::Easy2::new(collector);
73        self.handle_config.configure2(&mut handle)?;
74
75        handle.url(&url)?;
76        handle.follow_location(true)?;
77
78        match parts.method {
79            http::Method::HEAD => handle.nobody(true)?,
80            http::Method::GET => handle.get(true)?,
81            http::Method::POST => {
82                handle.post_field_size(body_len as u64)?;
83                handle.post(true)?;
84            }
85            http::Method::PUT => {
86                handle.in_filesize(body_len as u64)?;
87                handle.put(true)?;
88            }
89            method => {
90                handle.upload(true)?;
91                handle.in_filesize(body_len as u64)?;
92                handle.custom_request(method.as_str())?;
93            }
94        }
95
96        let mut headers = curl::easy::List::new();
97        for (name, value) in parts.headers {
98            if let Some(name) = name {
99                let value: &str = value.to_str().map_err(|_| Error::BadHeader {
100                    name: name.to_string(),
101                    bytes: value.as_bytes().to_owned(),
102                })?;
103                headers.append(&format!("{}: {}", name, value))?;
104            }
105        }
106        handle.http_headers(headers)?;
107
108        let (sender, receiver) = oneshot::channel();
109        let req = Message {
110            easy: handle,
111            sender,
112        };
113
114        self.channel.as_ref().unwrap().send(req).unwrap();
115        receiver.await.unwrap()
116    }
117}
118
119impl Drop for Client {
120    fn drop(&mut self) {
121        // Close the channel
122        drop(self.channel.take().unwrap());
123        // Join the thread
124        let _ = self.thread_handle.take().unwrap().join();
125    }
126}
127
128impl std::fmt::Debug for Client {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("http_async::Client").finish()
131    }
132}
133
/// Manages the cURL `Multi`. Processes incoming work sent over the
/// channel, and returns responses.
struct WorkerServer {
    /// Receiving half of the channel on which `Client::request` queues work.
    incoming_work: Receiver<Message>,
    /// The cURL multi handle that drives all transfers of this worker.
    multi: Multi,
    /// Transfers currently registered with `multi`, keyed by their token.
    /// Each entry pairs the cURL handle with the sender used to report
    /// the outcome of that transfer.
    handles: HashMap<
        usize,
        (
            Easy2Handle<Collector>,
            oneshot::Sender<HttpResult<Response>>,
        ),
    >,
    /// Most recently assigned token; incremented (wrapping) for every
    /// request that is added to `multi`.
    token: usize,
}
148
149impl WorkerServer {
150    fn run(incoming_work: Receiver<Message>, multiplex: bool) {
151        let mut multi = Multi::new();
152        // let's not flood the server with connections
153        if let Err(e) = multi.set_max_host_connections(2) {
154            error!("failed to set max host connections in curl: {e}");
155        }
156        if let Err(e) = multi.pipelining(false, multiplex) {
157            error!("failed to enable multiplexing/pipelining in curl: {e}");
158        }
159
160        let mut worker = Self {
161            incoming_work,
162            multi,
163            handles: HashMap::new(),
164            token: 0,
165        };
166        worker.worker_loop();
167    }
168
169    fn fail_and_drain(&mut self, e: &Error) {
170        for (_token, (_handle, sender)) in self.handles.drain() {
171            let _ = sender.send(Err(e.clone()));
172        }
173    }
174
175    fn worker_loop(&mut self) {
176        const INITIAL_DELAY: Duration = Duration::from_millis(1);
177        let mut wait_backoff = INITIAL_DELAY;
178        loop {
179            // Start any pending work.
180            while let Ok(msg) = self.incoming_work.try_recv() {
181                self.enqueue_request(msg);
182                wait_backoff = INITIAL_DELAY;
183            }
184
185            match self.multi.perform() {
186                Err(e) if e.is_call_perform() => {
187                    // cURL states if you receive `is_call_perform`, this means that you should call `perform` again.
188                }
189                Err(e) => {
190                    self.fail_and_drain(&Error::Multi(e));
191                }
192                Ok(running) => {
193                    self.multi.messages(|msg| {
194                        let t = msg.token().expect("all handles have tokens");
195                        trace!(token = t, "finish");
196                        let Some((handle, sender)) = self.handles.remove(&t) else {
197                            error!("missing entry {t} in handle table");
198                            return;
199                        };
200                        let result = msg.result_for2(&handle).expect("handle must have a result");
201                        let mut easy = self.multi.remove2(handle).expect("handle must be in multi");
202                        let mut response = std::mem::replace(
203                            &mut easy.get_mut().response,
204                            Response::new(Vec::new()),
205                        );
206                        if let Ok(status) = easy.response_code()
207                            && status != 0
208                            && let Ok(status) = http::StatusCode::from_u16(status as u16)
209                        {
210                            *response.status_mut() = status;
211                        }
212                        // Would be nice to set HTTP version via `response.version_mut()`, but `curl` doesn't have it exposed.
213                        let extensions = Extensions {
214                            client_ip: easy.primary_ip().ok().flatten().map(str::to_string),
215                        };
216                        response.extensions_mut().insert(extensions);
217                        let _ = sender.send(result.map(|()| response).map_err(Into::into));
218                    });
219
220                    if running > 0 {
221                        let max_timeout = Duration::from_millis(1000);
222                        let mut timeout = self
223                            .multi
224                            .get_timeout()
225                            .ok()
226                            .flatten()
227                            .unwrap_or(max_timeout)
228                            .min(max_timeout);
229                        if timeout.is_zero() {
230                            // curl said not to wait.
231                            continue;
232                        }
233                        // Ideally we would use `Multi::poll` + a `MultiWaker` instead of `Multi::wait`
234                        // to wake the thread when new work is queued. But it requires curl 7.68+,
235                        // which is not available everywhere we support.
236                        //
237                        // Instead, we use an exponential backoff approach so that as long as requests
238                        // are being queued, we poll quickly to allow the requests to be added sooner.
239                        // Without this, we end up sitting in `Multi::wait` too long while new work is
240                        // added to the channel.
241                        //
242                        // `get_timeout` says we should wait *at most* the timeout amount, so reducing
243                        // the wait time is fine.
244                        if wait_backoff < timeout {
245                            wait_backoff *= 2;
246                            timeout = wait_backoff
247                        }
248                        trace!(
249                            pending = self.handles.len(),
250                            timeout = timeout.as_millis(),
251                            "curl wait"
252                        );
253                        if let Err(e) = self.multi.wait(&mut [], timeout) {
254                            self.fail_and_drain(&Error::Multi(e));
255                        }
256                    } else {
257                        // Block, waiting for more work
258                        trace!("all work completed");
259                        match self.incoming_work.recv() {
260                            Ok(msg) => {
261                                trace!("resuming work");
262                                self.enqueue_request(msg);
263                                wait_backoff = INITIAL_DELAY;
264                            }
265                            Err(_) => {
266                                // The sending channel is closed. Shut down the worker.
267                                break;
268                            }
269                        }
270                    }
271                }
272            }
273        }
274    }
275
276    /// Adds the request to the `Multi`, or send an error back through the channel.
277    fn enqueue_request(&mut self, message: Message) {
278        match self.multi.add2(message.easy) {
279            Ok(mut handle) => {
280                self.token = self.token.wrapping_add(1);
281                handle.set_token(self.token).ok();
282                self.handles.insert(self.token, (handle, message.sender));
283            }
284            Err(e) => {
285                let _ = message.sender.send(Err(e.into()));
286            }
287        }
288    }
289}
290
/// Interface that cURL (`Easy2`) uses to make progress.
struct Collector {
    /// Response being assembled: headers and body are filled in by the
    /// `header` and `write` callbacks.
    response: Response,
    /// Request body, handed to cURL through the `read` callback.
    request_body: Cursor<Vec<u8>>,
    /// When `true`, cURL debug callbacks are forwarded to `super::http::debug`.
    debug: bool,
}
297
298impl Collector {
299    fn new() -> Self {
300        Collector {
301            response: Response::new(Vec::new()),
302            request_body: Cursor::new(Vec::new()),
303            debug: false,
304        }
305    }
306}
307
308impl Handler for Collector {
309    fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
310        self.response.body_mut().extend_from_slice(data);
311        Ok(data.len())
312    }
313
314    fn header(&mut self, data: &[u8]) -> bool {
315        if let Some((name, value)) = handle_http_header(data)
316            && let Ok(name) = http::HeaderName::from_str(name)
317            && let Ok(value) = http::HeaderValue::from_str(value)
318        {
319            self.response.headers_mut().append(name, value);
320        }
321        true
322    }
323
324    fn read(&mut self, data: &mut [u8]) -> Result<usize, curl::easy::ReadError> {
325        Ok(self.request_body.read(data).unwrap())
326    }
327
328    fn debug(&mut self, kind: InfoType, data: &[u8]) {
329        if self.debug {
330            super::http::debug(kind, data);
331        }
332    }
333
334    fn progress(&mut self, _dltotal: f64, _dlnow: f64, _ultotal: f64, _ulnow: f64) -> bool {
335        true
336    }
337}
338
/// Additional fields on an [`http::Response`].
///
/// Stored in the response's extension map by the worker once a transfer
/// has finished.
#[derive(Clone)]
struct Extensions {
    /// IP address reported by cURL's `primary_ip()` for the transfer, if available.
    client_ip: Option<String>,
}
344
/// Accessors for the additional data this module attaches to responses.
pub trait ResponsePartsExtensions {
    /// Returns the IP address stored with the response, or `None` if no
    /// address was recorded.
    fn client_ip(&self) -> Option<&str>;
}
348
349impl ResponsePartsExtensions for http::response::Parts {
350    fn client_ip(&self) -> Option<&str> {
351        self.extensions
352            .get::<Extensions>()
353            .and_then(|extensions| extensions.client_ip.as_deref())
354    }
355}
356
357impl ResponsePartsExtensions for Response {
358    fn client_ip(&self) -> Option<&str> {
359        self.extensions()
360            .get::<Extensions>()
361            .and_then(|extensions| extensions.client_ip.as_deref())
362    }
363}
364
/// Splits a raw `NAME: VALUE` header line into its name and value.
///
/// Trailing whitespace (including the line terminator) is removed first.
/// The name is returned as-is; the value has surrounding whitespace trimmed.
/// Returns `None` when the input is empty, is not valid UTF-8, still
/// contains a newline after trimming, or has no `:` separator.
fn handle_http_header(buf: &[u8]) -> Option<(&str, &str)> {
    let line = match std::str::from_utf8(buf) {
        Ok(text) if !buf.is_empty() => text.trim_end(),
        _ => return None,
    };
    // A newline left after trimming means the server tried to smuggle
    // more than one line into a single header.
    let single_line = !line.contains('\n');
    line.split_once(':')
        .filter(|_| single_line)
        .map(|(tag, value)| (tag, value.trim()))
}