1use std::collections::HashMap;
7use std::io::{Cursor, Read};
8use std::str::FromStr;
9use std::sync::mpsc::{self, Receiver, Sender};
10use std::thread::JoinHandle;
11use std::time::Duration;
12
13use curl::easy::WriteError;
14use curl::easy::{Easy2, Handler, InfoType};
15use curl::multi::{Easy2Handle, Multi};
16
17use crate::util::network::http::HandleConfiguration;
18use futures::channel::oneshot;
19use tracing::{debug, error, trace};
20
21type Response = http::Response<Vec<u8>>;
22type Request = http::Request<Vec<u8>>;
23type HttpResult<T> = std::result::Result<T, Error>;
24
25#[derive(Debug, Clone, thiserror::Error)]
26#[non_exhaustive]
27pub enum Error {
28 #[error("curl multi failed")]
29 Multi(#[from] curl::MultiError),
30
31 #[error("curl failed")]
32 Easy(#[from] curl::Error),
33
34 #[error("failed to convert header value of `{name}` to string: {bytes:?}")]
35 BadHeader { name: String, bytes: Vec<u8> },
36}
37
38struct Message {
39 easy: Easy2<Collector>,
40 sender: oneshot::Sender<HttpResult<Response>>,
41}
42
43pub struct Client {
46 channel: Option<Sender<Message>>,
47 thread_handle: Option<JoinHandle<()>>,
48 handle_config: HandleConfiguration,
49}
50
51impl Client {
52 pub fn new(handle_config: HandleConfiguration) -> Client {
54 let (tx, rx) = mpsc::channel();
55 let handle = std::thread::spawn(move || WorkerServer::run(rx, handle_config.multiplexing));
56 Client {
57 channel: Some(tx),
58 thread_handle: Some(handle),
59 handle_config,
60 }
61 }
62
63 pub async fn request(&self, request: Request) -> HttpResult<Response> {
65 let url = request.uri().to_string();
66 debug!(target: "network::fetch", url);
67 let mut collector = Collector::new();
68 let (parts, body) = request.into_parts();
69 let body_len = body.len();
70 collector.request_body = Cursor::new(body);
71 collector.debug = self.handle_config.verbose;
72 let mut handle = curl::easy::Easy2::new(collector);
73 self.handle_config.configure2(&mut handle)?;
74
75 handle.url(&url)?;
76 handle.follow_location(true)?;
77
78 match parts.method {
79 http::Method::HEAD => handle.nobody(true)?,
80 http::Method::GET => handle.get(true)?,
81 http::Method::POST => {
82 handle.post_field_size(body_len as u64)?;
83 handle.post(true)?;
84 }
85 http::Method::PUT => {
86 handle.in_filesize(body_len as u64)?;
87 handle.put(true)?;
88 }
89 method => {
90 handle.upload(true)?;
91 handle.in_filesize(body_len as u64)?;
92 handle.custom_request(method.as_str())?;
93 }
94 }
95
96 let mut headers = curl::easy::List::new();
97 for (name, value) in parts.headers {
98 if let Some(name) = name {
99 let value: &str = value.to_str().map_err(|_| Error::BadHeader {
100 name: name.to_string(),
101 bytes: value.as_bytes().to_owned(),
102 })?;
103 headers.append(&format!("{}: {}", name, value))?;
104 }
105 }
106 handle.http_headers(headers)?;
107
108 let (sender, receiver) = oneshot::channel();
109 let req = Message {
110 easy: handle,
111 sender,
112 };
113
114 self.channel.as_ref().unwrap().send(req).unwrap();
115 receiver.await.unwrap()
116 }
117}
118
119impl Drop for Client {
120 fn drop(&mut self) {
121 drop(self.channel.take().unwrap());
123 let _ = self.thread_handle.take().unwrap().join();
125 }
126}
127
128impl std::fmt::Debug for Client {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("http_async::Client").finish()
131 }
132}
133
134struct WorkerServer {
137 incoming_work: Receiver<Message>,
138 multi: Multi,
139 handles: HashMap<
140 usize,
141 (
142 Easy2Handle<Collector>,
143 oneshot::Sender<HttpResult<Response>>,
144 ),
145 >,
146 token: usize,
147}
148
149impl WorkerServer {
150 fn run(incoming_work: Receiver<Message>, multiplex: bool) {
151 let mut multi = Multi::new();
152 if let Err(e) = multi.set_max_host_connections(2) {
154 error!("failed to set max host connections in curl: {e}");
155 }
156 if let Err(e) = multi.pipelining(false, multiplex) {
157 error!("failed to enable multiplexing/pipelining in curl: {e}");
158 }
159
160 let mut worker = Self {
161 incoming_work,
162 multi,
163 handles: HashMap::new(),
164 token: 0,
165 };
166 worker.worker_loop();
167 }
168
169 fn fail_and_drain(&mut self, e: &Error) {
170 for (_token, (_handle, sender)) in self.handles.drain() {
171 let _ = sender.send(Err(e.clone()));
172 }
173 }
174
175 fn worker_loop(&mut self) {
176 const INITIAL_DELAY: Duration = Duration::from_millis(1);
177 let mut wait_backoff = INITIAL_DELAY;
178 loop {
179 while let Ok(msg) = self.incoming_work.try_recv() {
181 self.enqueue_request(msg);
182 wait_backoff = INITIAL_DELAY;
183 }
184
185 match self.multi.perform() {
186 Err(e) if e.is_call_perform() => {
187 }
189 Err(e) => {
190 self.fail_and_drain(&Error::Multi(e));
191 }
192 Ok(running) => {
193 self.multi.messages(|msg| {
194 let t = msg.token().expect("all handles have tokens");
195 trace!(token = t, "finish");
196 let Some((handle, sender)) = self.handles.remove(&t) else {
197 error!("missing entry {t} in handle table");
198 return;
199 };
200 let result = msg.result_for2(&handle).expect("handle must have a result");
201 let mut easy = self.multi.remove2(handle).expect("handle must be in multi");
202 let mut response = std::mem::replace(
203 &mut easy.get_mut().response,
204 Response::new(Vec::new()),
205 );
206 if let Ok(status) = easy.response_code()
207 && status != 0
208 && let Ok(status) = http::StatusCode::from_u16(status as u16)
209 {
210 *response.status_mut() = status;
211 }
212 let extensions = Extensions {
214 client_ip: easy.primary_ip().ok().flatten().map(str::to_string),
215 };
216 response.extensions_mut().insert(extensions);
217 let _ = sender.send(result.map(|()| response).map_err(Into::into));
218 });
219
220 if running > 0 {
221 let max_timeout = Duration::from_millis(1000);
222 let mut timeout = self
223 .multi
224 .get_timeout()
225 .ok()
226 .flatten()
227 .unwrap_or(max_timeout)
228 .min(max_timeout);
229 if timeout.is_zero() {
230 continue;
232 }
233 if wait_backoff < timeout {
245 wait_backoff *= 2;
246 timeout = wait_backoff
247 }
248 trace!(
249 pending = self.handles.len(),
250 timeout = timeout.as_millis(),
251 "curl wait"
252 );
253 if let Err(e) = self.multi.wait(&mut [], timeout) {
254 self.fail_and_drain(&Error::Multi(e));
255 }
256 } else {
257 trace!("all work completed");
259 match self.incoming_work.recv() {
260 Ok(msg) => {
261 trace!("resuming work");
262 self.enqueue_request(msg);
263 wait_backoff = INITIAL_DELAY;
264 }
265 Err(_) => {
266 break;
268 }
269 }
270 }
271 }
272 }
273 }
274 }
275
276 fn enqueue_request(&mut self, message: Message) {
278 match self.multi.add2(message.easy) {
279 Ok(mut handle) => {
280 self.token = self.token.wrapping_add(1);
281 handle.set_token(self.token).ok();
282 self.handles.insert(self.token, (handle, message.sender));
283 }
284 Err(e) => {
285 let _ = message.sender.send(Err(e.into()));
286 }
287 }
288 }
289}
290
291struct Collector {
293 response: Response,
294 request_body: Cursor<Vec<u8>>,
295 debug: bool,
296}
297
298impl Collector {
299 fn new() -> Self {
300 Collector {
301 response: Response::new(Vec::new()),
302 request_body: Cursor::new(Vec::new()),
303 debug: false,
304 }
305 }
306}
307
308impl Handler for Collector {
309 fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
310 self.response.body_mut().extend_from_slice(data);
311 Ok(data.len())
312 }
313
314 fn header(&mut self, data: &[u8]) -> bool {
315 if let Some((name, value)) = handle_http_header(data)
316 && let Ok(name) = http::HeaderName::from_str(name)
317 && let Ok(value) = http::HeaderValue::from_str(value)
318 {
319 self.response.headers_mut().append(name, value);
320 }
321 true
322 }
323
324 fn read(&mut self, data: &mut [u8]) -> Result<usize, curl::easy::ReadError> {
325 Ok(self.request_body.read(data).unwrap())
326 }
327
328 fn debug(&mut self, kind: InfoType, data: &[u8]) {
329 if self.debug {
330 super::http::debug(kind, data);
331 }
332 }
333
334 fn progress(&mut self, _dltotal: f64, _dlnow: f64, _ultotal: f64, _ulnow: f64) -> bool {
335 true
336 }
337}
338
/// Extra per-response data stored in the response's extension map.
#[derive(Clone)]
struct Extensions {
    /// Value of curl's `primary_ip()` for the finished transfer, if available.
    client_ip: Option<String>,
}
344
/// Accessors for the extra data this module attaches to responses.
pub trait ResponsePartsExtensions {
    /// Returns the IP address recorded for the transfer, if one was stored.
    fn client_ip(&self) -> Option<&str>;
}
348
349impl ResponsePartsExtensions for http::response::Parts {
350 fn client_ip(&self) -> Option<&str> {
351 self.extensions
352 .get::<Extensions>()
353 .and_then(|extensions| extensions.client_ip.as_deref())
354 }
355}
356
357impl ResponsePartsExtensions for Response {
358 fn client_ip(&self) -> Option<&str> {
359 self.extensions()
360 .get::<Extensions>()
361 .and_then(|extensions| extensions.client_ip.as_deref())
362 }
363}
364
/// Splits a raw `NAME: VALUE` header line into its name and value.
///
/// Trailing whitespace (including the line terminator) is removed first and
/// the value is trimmed on both sides; the name is returned untouched.
/// Returns `None` for empty input, invalid UTF-8, input that still contains a
/// newline after trimming, or a line without a `:`.
fn handle_http_header(buf: &[u8]) -> Option<(&str, &str)> {
    let line = std::str::from_utf8(buf)
        .ok()
        .filter(|text| !text.is_empty())?
        .trim_end();
    match line.split_once(':') {
        // More than one line in a single callback is not a header we accept.
        Some((tag, rest)) if !line.contains('\n') => Some((tag, rest.trim())),
        _ => None,
    }
}