wasmer_wasix/runners/wcgi/
handler.rs1use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
2
3use anyhow::Error;
4use bytes::Bytes;
5use futures::{Future, FutureExt};
6use http::{Request, Response, StatusCode};
7use http_body_util::BodyExt;
8use hyper::body::Frame;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
10use tracing::Instrument;
11use virtual_mio::InlineWaker;
12use wasmer::Module;
13use wasmer_wasix_types::wasi::ExitCode;
14use wcgi_host::CgiDialect;
15
16use super::super::Body;
17
18use crate::{
19 Runtime, VirtualTaskManager, WasiEnvBuilder,
20 bin_factory::run_exec,
21 os::task::OwnedTaskStatus,
22 runners::{
23 body_from_data, body_from_stream,
24 wcgi::{
25 Callbacks,
26 callbacks::{CreateEnvConfig, RecycleEnvConfig},
27 },
28 },
29 runtime::task_manager::{TaskWasm, TaskWasmRecycleProperties},
30};
31use wasmer_types::ModuleHash;
32
/// Cheaply clonable request handler for a WCGI program.
///
/// Every clone shares the same [`SharedState`] through an `Arc`, so cloning a
/// handler per request only bumps a reference count.
#[derive(Clone, Debug)]
pub(crate) struct Handler(Arc<SharedState>);
37
38impl Handler {
39 pub(crate) fn new(state: Arc<SharedState>) -> Self {
40 Handler(state)
41 }
42
43 #[tracing::instrument(level = "debug", skip_all, err)]
44 pub(crate) async fn handle<T>(
45 &self,
46 req: Request<hyper::body::Incoming>,
47 token: T,
48 ) -> Result<Response<Body>, Error>
49 where
50 T: Send + 'static,
51 {
52 tracing::debug!(headers=?req.headers());
53
54 let (parts, body) = req.into_parts();
55
56 let mut request_specific_env = HashMap::new();
60 request_specific_env.insert("REQUEST_METHOD".to_string(), parts.method.to_string());
61 request_specific_env.insert("SCRIPT_NAME".to_string(), parts.uri.path().to_string());
62 if let Some(query) = parts.uri.query() {
63 request_specific_env.insert("QUERY_STRING".to_string(), query.to_string());
64 }
65 self.dialect
66 .prepare_environment_variables(parts, &mut request_specific_env);
67
68 let create = self
69 .callbacks
70 .create_env(CreateEnvConfig {
71 env: request_specific_env,
72 program_name: self.program_name.clone(),
73 module: self.module.clone(),
74 module_hash: self.module_hash,
75 runtime: self.runtime.clone(),
76 setup_builder: self.setup_builder.clone(),
77 })
78 .await?;
79
80 tracing::debug!(
81 dialect=%self.dialect,
82 "Calling into the WCGI executable",
83 );
84
85 let task_manager = self.runtime.task_manager();
86 let env = create.env;
87 let module = self.module.clone();
88
89 let callbacks = Arc::clone(&self.callbacks);
91 let recycle = {
92 let callbacks = callbacks.clone();
93 move |props: TaskWasmRecycleProperties| {
94 InlineWaker::block_on(callbacks.recycle_env(RecycleEnvConfig {
95 env: props.env,
96 store: props.store,
97 memory: props.memory,
98 }));
99
100 drop(token);
104 }
105 };
106 let finished = env.process.finished.clone();
107
108 task_manager
128 .task_wasm(
129 TaskWasm::new(Box::new(run_exec), env, module, false, false)
130 .with_recycle(Box::new(recycle)),
132 )
133 .map_err(|err| {
134 tracing::warn!("failed to execute WCGI thread - {}", err);
135 err
136 })?;
137
138 let mut res_body_receiver = tokio::io::BufReader::new(create.body_receiver);
139
140 let stderr_receiver = create.stderr_receiver;
141 let propagate_stderr = self.propagate_stderr;
142 let work_consume_stderr = {
143 let callbacks = callbacks.clone();
144 async move { consume_stderr(stderr_receiver, callbacks, propagate_stderr).await }
145 .in_current_span()
146 };
147
148 tracing::trace!(
149 dialect=%self.dialect,
150 "spawning request forwarder",
151 );
152
153 let req_body_sender = create.body_sender;
154 let ret = drive_request_to_completion(finished, body, req_body_sender).await;
155
156 if propagate_stderr {
161 if let Some(stderr) = work_consume_stderr.await {
162 if !stderr.is_empty() {
163 return Ok(Response::builder()
164 .status(StatusCode::INTERNAL_SERVER_ERROR)
165 .body(body_from_data(stderr))?);
166 }
167 }
168 } else {
169 task_manager
170 .task_shared(Box::new(move || {
171 Box::pin(async move {
172 work_consume_stderr.await;
173 })
174 }))
175 .ok();
176 }
177
178 match ret {
179 Ok(_) => {}
180 Err(e) => {
181 let e = e.to_string();
182 tracing::error!(error = e, "Unable to drive the request to completion");
183 return Ok(Response::builder()
184 .status(StatusCode::INTERNAL_SERVER_ERROR)
185 .body(body_from_data(Bytes::from(e)))?);
186 }
187 }
188
189 tracing::trace!(
190 dialect=%self.dialect,
191 "extracting response parts",
192 );
193
194 let parts = self
195 .dialect
196 .extract_response_header(&mut res_body_receiver)
197 .await;
198 let parts = parts?;
199
200 tracing::trace!(
201 dialect=%self.dialect,
202 status=%parts.status,
203 "received response parts",
204 );
205
206 let chunks = futures::stream::try_unfold(res_body_receiver, |mut r| async move {
207 match r.fill_buf().await {
208 Ok([]) => Ok(None),
209 Ok(chunk) => {
210 let chunk: bytes::Bytes = chunk.to_vec().into();
211 r.consume(chunk.len());
212 Ok(Some((Frame::data(chunk), r)))
213 }
214 Err(e) => Err(anyhow::Error::from(e)),
215 }
216 });
217 let body = body_from_stream(chunks);
218
219 tracing::trace!(
220 dialect=%self.dialect,
221 "returning response with body stream",
222 );
223
224 let response = hyper::Response::from_parts(parts, body);
225 Ok(response)
226 }
227}
228
229impl Deref for Handler {
230 type Target = Arc<SharedState>;
231
232 fn deref(&self) -> &Self::Target {
233 &self.0
234 }
235}
236
237async fn drive_request_to_completion(
240 finished: Arc<OwnedTaskStatus>,
241 mut request_body: hyper::body::Incoming,
242 mut instance_stdin: impl AsyncWrite + Send + Sync + Unpin + 'static,
243) -> Result<ExitCode, Error> {
244 let request_body_send = async move {
245 let mut request_size = 0;
249 while let Some(res) = request_body.frame().await {
250 let chunk = res?;
253 if let Some(data) = chunk.data_ref() {
254 request_size += data.len();
255 instance_stdin.write_all(data.as_ref()).await?;
256 } else {
257 }
259 }
260
261 instance_stdin.shutdown().await?;
262 tracing::debug!(
263 request_size,
264 "Finished forwarding the request to the WCGI server"
265 );
266
267 Ok::<(), Error>(())
268 }
269 .in_current_span();
270
271 let (ret, _) = futures::try_join!(finished.await_termination_anyhow(), request_body_send)?;
272 Ok(ret)
273}
274
275async fn consume_stderr(
279 stderr: impl AsyncRead + Send + Unpin + 'static,
280 callbacks: Arc<dyn Callbacks>,
281 propagate_stderr: bool,
282) -> Option<Vec<u8>> {
283 let mut stderr = tokio::io::BufReader::new(stderr);
284
285 let mut propagate = match propagate_stderr {
286 true => Some(Vec::new()),
287 false => None,
288 };
289
290 loop {
294 match stderr.fill_buf().await {
295 Ok([]) => {
296 break;
298 }
299 Ok(chunk) => {
300 tracing::trace!("received stderr (len={})", chunk.len());
301 if let Some(propogate) = propagate.as_mut() {
302 propogate.write_all(chunk).await.ok();
303 }
304 callbacks.on_stderr(chunk);
305 let bytes_read = chunk.len();
306 stderr.consume(bytes_read);
307 }
308 Err(e) => {
309 tracing::trace!("received stderr (err={})", e);
310 callbacks.on_stderr_error(e);
311 break;
312 }
313 }
314 }
315
316 propagate
317}
318
319pub type SetupBuilder = Arc<dyn Fn(&mut WasiEnvBuilder) -> Result<(), anyhow::Error> + Send + Sync>;
320
/// Configuration and services shared by every clone of a `Handler`.
#[derive(derive_more::Debug)]
pub(crate) struct SharedState {
    /// The WCGI module; cloned for every request that is executed.
    pub(crate) module: Module,
    /// Hash of `module`, forwarded to `Callbacks::create_env`.
    pub(crate) module_hash: ModuleHash,
    /// CGI dialect used to build the request environment and to parse the
    /// response header from the guest's stdout.
    pub(crate) dialect: CgiDialect,
    /// Program name forwarded to `Callbacks::create_env`.
    pub(crate) program_name: String,
    /// When true, non-empty guest stderr replaces the response with a 500
    /// carrying that output.
    pub(crate) propagate_stderr: bool,
    /// Environment-builder hook. It is skipped in `Debug` output because
    /// `dyn Fn` has no `Debug` implementation.
    #[debug(ignore)]
    pub(crate) setup_builder: SetupBuilder,
    /// Hooks for environment creation/recycling and stderr forwarding.
    pub(crate) callbacks: Arc<dyn Callbacks>,
    /// Runtime whose task manager executes the module.
    pub(crate) runtime: Arc<dyn Runtime + Send + Sync>,
}
333
334impl tower::Service<Request<hyper::body::Incoming>> for Handler {
335 type Response = Response<Body>;
336 type Error = Error;
337 type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + Send>>;
338
339 fn poll_ready(
340 &mut self,
341 _cx: &mut std::task::Context<'_>,
342 ) -> std::task::Poll<Result<(), Self::Error>> {
343 std::task::Poll::Ready(Ok(()))
344 }
345
346 fn call(&mut self, request: Request<hyper::body::Incoming>) -> Self::Future {
347 let handler = self.clone();
349 let fut = async move { handler.handle(request, ()).await };
350 fut.boxed()
351 }
352}