wasmer_wasix/runners/dproxy/
handler.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::Poll;
4
5use futures::{Future, FutureExt};
6use http::{Request, Response, Uri};
7use http_body_util::BodyExt;
8use tower::Service;
9
10use super::super::Body;
11use crate::Runtime;
12use crate::runners::dproxy::shard::Shard;
13
14use super::Config;
15use super::factory::DProxyInstanceFactory;
16
17#[derive(Debug)]
18pub struct SharedState {
19    pub(crate) config: Config,
20    pub(crate) command_name: String,
21    pub(crate) runtime: Arc<dyn Runtime + Send + Sync>,
22    pub(crate) factory: DProxyInstanceFactory,
23}
24
25/// Handler which will process DProxy requests
26#[derive(Clone, Debug)]
27pub struct Handler(Arc<SharedState>);
28
29impl Handler {
30    pub(crate) fn new(
31        config: Config,
32        command_name: String,
33        factory: DProxyInstanceFactory,
34        runtime: Arc<dyn Runtime + Send + Sync>,
35    ) -> Self {
36        Handler(Arc::new(SharedState {
37            config,
38            command_name,
39            runtime,
40            factory,
41        }))
42    }
43
44    #[tracing::instrument(level = "debug", skip_all, err)]
45    pub(crate) async fn handle<T>(
46        &self,
47        mut req: Request<hyper::body::Incoming>,
48        _token: T,
49    ) -> anyhow::Result<Response<Body>>
50    where
51        T: Send + 'static,
52    {
53        tracing::debug!(headers=?req.headers());
54
55        // Determine the shard we are using
56        let shard = req
57            .headers()
58            .get("X-Shard")
59            .map(|v| String::from_utf8_lossy(v.as_bytes()))
60            .map(|s| match s.parse::<u64>() {
61                Ok(id) => Ok(Shard::ById(id)),
62                Err(err) => Err(err),
63            })
64            .unwrap_or(Ok(Shard::Singleton))?;
65
66        // Modify the request URI so that it will work with the hyper proxy
67        let mut new_uri = Uri::builder()
68            .scheme("http")
69            .authority(
70                req.uri()
71                    .authority()
72                    .cloned()
73                    .unwrap_or_else(|| "localhost".parse().unwrap()),
74            )
75            .path_and_query(
76                req.uri()
77                    .path_and_query()
78                    .cloned()
79                    .unwrap_or_else(|| "/".parse().unwrap()),
80            )
81            .build()
82            .unwrap();
83        std::mem::swap(req.uri_mut(), &mut new_uri);
84
85        // Acquire a DProxy instance
86        tracing::debug!("Acquiring DProxy instance instance");
87        let instance = self.factory.acquire(self, shard).await?;
88
89        tracing::debug!("Calling into the DProxy instance");
90        let client = instance.client.clone();
91
92        // Perform the request
93        let resp = client.request(req).await?;
94        let (parts, body) = resp.into_parts();
95        let body = body
96            .collect()
97            .await?
98            .map_err(|_| anyhow::anyhow!("Infallible"))
99            .boxed();
100
101        Ok(Response::from_parts(parts, body))
102    }
103}
104
105impl std::ops::Deref for Handler {
106    type Target = Arc<SharedState>;
107
108    fn deref(&self) -> &Self::Target {
109        &self.0
110    }
111}
112
113impl Service<Request<hyper::body::Incoming>> for Handler {
114    type Response = Response<Body>;
115    type Error = anyhow::Error;
116    type Future = Pin<Box<dyn Future<Output = anyhow::Result<Response<Body>>> + Send>>;
117
118    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
119        Poll::Ready(Ok(()))
120    }
121
122    fn call(&mut self, request: Request<hyper::body::Incoming>) -> Self::Future {
123        // Note: all fields are reference-counted so cloning is pretty cheap
124        let handler = self.clone();
125        let fut = async move { handler.handle(request, ()).await };
126        fut.boxed()
127    }
128}