wasmer_wasix/http/reqwest.rs

1use std::time::Duration;
2
3use anyhow::Context;
4use futures::{TryStreamExt, future::BoxFuture};
5use std::convert::TryFrom;
6use tokio::runtime::Handle;
7
8use super::{HttpRequest, HttpResponse};
9
10#[derive(Clone, Debug)]
11pub struct ReqwestHttpClient {
12    handle: Handle,
13    connect_timeout: Duration,
14    response_body_chunk_timeout: Option<std::time::Duration>,
15}
16
17impl Default for ReqwestHttpClient {
18    fn default() -> Self {
19        Self {
20            handle: Handle::current(),
21            connect_timeout: Self::DEFAULT_CONNECT_TIMEOUT,
22            response_body_chunk_timeout: None,
23        }
24    }
25}
26
27impl ReqwestHttpClient {
28    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
29
30    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
31        self.connect_timeout = timeout;
32        self
33    }
34
35    pub fn with_response_body_chunk_timeout(mut self, timeout: std::time::Duration) -> Self {
36        self.response_body_chunk_timeout = Some(timeout);
37        self
38    }
39
40    #[tracing::instrument(skip_all, fields(method=?request.method, url=%request.url))]
41    async fn request(&self, request: HttpRequest) -> Result<HttpResponse, anyhow::Error> {
42        let method = reqwest::Method::try_from(request.method.as_str())
43            .with_context(|| format!("Invalid http method {}", request.method))?;
44
45        // TODO: use persistent client?
46        let builder = {
47            let _guard = Handle::try_current().map_err(|_| self.handle.enter());
48            let mut builder = reqwest::ClientBuilder::new();
49            #[cfg(not(feature = "js"))]
50            {
51                builder = builder.connect_timeout(self.connect_timeout);
52            }
53            builder
54        };
55        let client = builder.build().context("failed to create reqwest client")?;
56
57        tracing::debug!("sending http request");
58        let mut builder = client.request(method, request.url.as_str());
59        for (header, val) in &request.headers {
60            builder = builder.header(header, val);
61        }
62
63        if let Some(body) = request.body {
64            builder = builder.body(reqwest::Body::from(body));
65        }
66
67        let request = builder
68            .build()
69            .context("Failed to construct http request")?;
70
71        let mut response = client.execute(request).await?;
72        let headers = std::mem::take(response.headers_mut());
73
74        let status = response.status();
75
76        tracing::debug!(status=?status, "received http response");
77
78        // Download the body.
79        #[cfg(not(feature = "js"))]
80        let data = if let Some(timeout_duration) = self.response_body_chunk_timeout {
81            // Download the body with a chunk timeout.
82            // The timeout prevents long stalls.
83
84            let mut stream = response.bytes_stream();
85            let mut buf = Vec::new();
86
87            // Creating tokio timeouts has overhead, so instead of a fresh
88            // timeout per chunk a shared timeout is used, and a chunk counter
89            // is kept. Only if no chunk was downloaded within the timeout a
90            // timeout error is raised.
91            'OUTER: loop {
92                let timeout = tokio::time::sleep(timeout_duration);
93                pin_utils::pin_mut!(timeout);
94
95                let mut chunk_count = 0;
96
97                loop {
98                    tokio::select! {
99                        // Biased because the timeout is secondary,
100                        // and chunks should always have priority.
101                        biased;
102
103                        res = stream.try_next() => {
104                            match res {
105                                Ok(Some(chunk)) => {
106                                    buf.extend_from_slice(&chunk);
107                                    chunk_count += 1;
108                                }
109                                Ok(None) => {
110                                    break 'OUTER;
111                                }
112                                Err(e) => {
113                                    return Err(e.into());
114                                }
115                            }
116                        }
117
118                        _ = &mut timeout => {
119                            if chunk_count == 0 {
120                                tracing::warn!(timeout= "timeout while downloading response body");
121                                return Err(anyhow::anyhow!("Timeout while downloading response body"));
122                            } else {
123                                tracing::debug!(downloaded_body_size_bytes=%buf.len(), "download progress");
124                                // Timeout, but chunks were downloaded, so
125                                // just continue with a fresh timeout.
126                                continue 'OUTER;
127                            }
128                        }
129                    }
130                }
131            }
132
133            buf
134        } else {
135            response.bytes().await?.to_vec()
136        };
137        #[cfg(feature = "js")]
138        let data = response.bytes().await?.to_vec();
139
140        tracing::debug!(body_size_bytes=%data.len(), "downloaded http response body");
141
142        Ok(HttpResponse {
143            status,
144            redirected: false,
145            body: Some(data),
146            headers,
147        })
148    }
149}
150
151impl super::HttpClient for ReqwestHttpClient {
152    #[cfg(not(feature = "js"))]
153    fn request(&self, request: HttpRequest) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
154        let client = self.clone();
155        let f = async move { client.request(request).await };
156        Box::pin(f)
157    }
158
159    #[cfg(feature = "js")]
160    fn request(&self, request: HttpRequest) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
161        let client = self.clone();
162        let (sender, receiver) = futures::channel::oneshot::channel();
163        wasm_bindgen_futures::spawn_local(async move {
164            let result = client.request(request).await;
165            let _ = sender.send(result);
166        });
167        Box::pin(async move {
168            match receiver.await {
169                Ok(result) => result,
170                Err(e) => Err(anyhow::Error::new(e)),
171            }
172        })
173    }
174}