wasmer_wasix/runtime/task_manager/
tokio.rs

1use std::sync::Mutex;
2use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};
3
4use futures::{Future, future::BoxFuture};
5use tokio::runtime::{Handle, Runtime};
6use virtual_mio::InlineWaker;
7use wasmer::AsStoreMut;
8
9use crate::runtime::SpawnType;
10use crate::{WasiFunctionEnv, os::task::thread::WasiThreadError};
11
12use super::{SpawnMemoryTypeOrStore, TaskWasm, TaskWasmRunProperties, VirtualTaskManager};
13
/// Access to a tokio runtime, either through a plain [`Handle`] or through a
/// [`Runtime`] that this value (and its clones) keeps alive.
#[derive(Debug, Clone)]
pub enum RuntimeOrHandle {
    /// Only a handle; dropping this variant does not shut anything down.
    Handle(Handle),
    /// A handle plus the runtime it belongs to. The runtime sits behind
    /// `Arc<Mutex<Option<_>>>` so that clones share it and the `Drop` impl can
    /// take it out by value in order to shut it down.
    Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
}
19impl From<Handle> for RuntimeOrHandle {
20    fn from(value: Handle) -> Self {
21        Self::Handle(value)
22    }
23}
24impl From<Runtime> for RuntimeOrHandle {
25    fn from(value: Runtime) -> Self {
26        Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
27    }
28}
29
30impl Drop for RuntimeOrHandle {
31    fn drop(&mut self) {
32        if let Self::Runtime(_, runtime) = self {
33            if let Some(h) = runtime.lock().unwrap().take() {
34                h.shutdown_timeout(Duration::from_secs(0))
35            }
36        }
37    }
38}
39
40impl RuntimeOrHandle {
41    pub fn handle(&self) -> &Handle {
42        match self {
43            Self::Handle(h) => h,
44            Self::Runtime(h, _) => h,
45        }
46    }
47}
48
/// Thin wrapper around [`rusty_pool::ThreadPool`]; this file implements
/// `Deref` and `Debug` for it.
#[derive(Clone)]
pub struct ThreadPool {
    // The wrapped pool; reachable from outside through `Deref`.
    inner: rusty_pool::ThreadPool,
}
53
54impl std::ops::Deref for ThreadPool {
55    type Target = rusty_pool::ThreadPool;
56
57    fn deref(&self) -> &Self::Target {
58        &self.inner
59    }
60}
61
62impl std::fmt::Debug for ThreadPool {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("ThreadPool")
65            .field("name", &self.get_name())
66            .field("current_worker_count", &self.get_current_worker_count())
67            .field("idle_worker_count", &self.get_idle_worker_count())
68            .finish()
69    }
70}
71
/// A task manager that uses tokio to spawn tasks.
///
/// Async work is spawned on the tokio runtime behind `rt`; blocking work
/// (`task_wasm` callbacks and `task_dedicated`) is executed on `pool`.
#[derive(Clone, Debug)]
pub struct TokioTaskManager {
    // Handle to (and possibly ownership of) the tokio runtime.
    rt: RuntimeOrHandle,
    // Thread pool shared between clones of this manager.
    pool: Arc<ThreadPool>,
}
78
79impl TokioTaskManager {
80    pub fn new<I>(rt: I) -> Self
81    where
82        I: Into<RuntimeOrHandle>,
83    {
84        let concurrency = std::thread::available_parallelism()
85            .unwrap_or(NonZeroUsize::new(1).unwrap())
86            .get();
87        let max_threads = 200usize.max(concurrency * 100);
88
89        Self {
90            rt: rt.into(),
91            pool: Arc::new(ThreadPool {
92                inner: rusty_pool::Builder::new()
93                    .name("TokioTaskManager Thread Pool".to_string())
94                    .core_size(max_threads)
95                    .max_size(max_threads)
96                    .build(),
97            }),
98        }
99    }
100
101    pub fn runtime_handle(&self) -> tokio::runtime::Handle {
102        self.rt.handle().clone()
103    }
104
105    pub fn pool_handle(&self) -> Arc<ThreadPool> {
106        self.pool.clone()
107    }
108}
109
110impl Default for TokioTaskManager {
111    fn default() -> Self {
112        Self::new(Handle::current())
113    }
114}
115
116impl VirtualTaskManager for TokioTaskManager {
117    /// See [`VirtualTaskManager::sleep_now`].
118    fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
119        let handle = self.runtime_handle();
120        Box::pin(async move {
121            SleepNow::default()
122                .enter(handle, time)
123                .await
124                .ok()
125                .unwrap_or(())
126        })
127    }
128
129    /// See [`VirtualTaskManager::task_shared`].
130    fn task_shared(
131        &self,
132        task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
133    ) -> Result<(), WasiThreadError> {
134        self.rt.handle().spawn(async move {
135            let fut = task();
136            fut.await
137        });
138        Ok(())
139    }
140
    /// See [`VirtualTaskManager::task_wasm`].
    ///
    /// Builds the store / function environment for the task on the calling
    /// thread, then hands the `run` callback to the thread pool:
    ///
    /// * with a trigger, an async task first waits for the trigger (servicing
    ///   signals and snapshot requests in the meantime) and only then queues
    ///   `run` on the pool with `trigger_result: Some(..)`;
    /// * without a trigger, `run` is queued on the pool right away with
    ///   `trigger_result: None`, and this call waits until the pool thread has
    ///   reported whether the environment was created successfully.
    ///
    /// # Errors
    ///
    /// Propagates errors from `build_memory` and from
    /// `WasiFunctionEnv::new_with_store`. Returns
    /// `WasiThreadError::InvalidWasmContext` if the pool task goes away
    /// without reporting back.
    fn task_wasm(&self, task: TaskWasm) -> Result<(), WasiThreadError> {
        // Move the pieces needed by the spawned closures out of `task` up
        // front; the remaining fields are consumed further down.
        let run = task.run;
        let recycle = task.recycle;
        let env = task.env;
        let pre_run = task.pre_run;

        // Decide how the new instance gets its memory. Shared / copied memory
        // needs a store right now, so it is created here and passed along.
        let make_memory: SpawnMemoryTypeOrStore = match &task.spawn_type {
            SpawnType::CreateMemory | SpawnType::NewLinkerInstanceGroup(..) => {
                SpawnMemoryTypeOrStore::New
            }
            SpawnType::CreateMemoryOfType(t) => SpawnMemoryTypeOrStore::Type(*t),
            SpawnType::ShareMemory(_, _) | SpawnType::CopyMemory(_, _) => {
                let mut store = env.runtime().new_store();
                let memory = self.build_memory(&mut store.as_store_mut(), &task.spawn_type)?;
                SpawnMemoryTypeOrStore::StoreAndMemory(store, memory)
            }
        };

        // This should actually run in the blocking thread, just like the task itself.
        // See the comment below for why we can't do it there yet.
        //
        // For now block_in_place at least ensures that we don't block the async runtime
        //
        // NOTE(review): tokio documents that `block_in_place` panics on a
        // current-thread runtime - confirm callers only use multi-thread ones.
        let ret = tokio::task::block_in_place(move || {
            // A new linker instance group additionally passes its linker and
            // parent function environment; everything else passes `None`.
            if let SpawnType::NewLinkerInstanceGroup(linker, func_env, mut store) = task.spawn_type
            {
                WasiFunctionEnv::new_with_store(
                    task.module,
                    env,
                    task.globals,
                    make_memory,
                    task.update_layout,
                    task.call_initialize,
                    Some((linker, &mut func_env.into_mut(&mut store))),
                )
            } else {
                WasiFunctionEnv::new_with_store(
                    task.module,
                    env,
                    task.globals,
                    make_memory,
                    task.update_layout,
                    task.call_initialize,
                    None,
                )
            }
        });

        if let Some(trigger) = task.trigger {
            tracing::trace!("spawning task_wasm trigger in async pool");
            // In principle, we'd need to create this in the `pool.execute` function below, that is
            //
            // ```
            // 227: pool.execute(move || {
            // ...:      let (ctx, mut store) = WasiFunctionEnv::new_with_store(
            // ...:      ...
            // ```
            //
            // However, in the loop spawned below we need to have a `FunctionEnvMut<WasiEnv>`, which
            // must be created with a mutable reference to the store. We can't, however since
            // ```
            // pool.execute(move || {
            //      let (ctx, mut store) = WasiFunctionEnv::new_with_store(
            //      ...
            //      tx.send(store.as_store_mut())
            // ```
            // or
            // ```
            // pool.execute(move || {
            //      let (ctx, mut store) = WasiFunctionEnv::new_with_store(
            //      ...
            //      tx.send(ctx.env.clone().into_mut(&mut store.as_store_mut()))
            // ```
            // Since the reference would outlive the owned value.
            //
            // So, we create the store (and memory, and instance) outside the execution thread (the
            // pool's one), and let it fail for runtimes that don't support entities created in a
            // thread that's not the one in which execution happens in; this until we can clone
            // stores.
            let (mut ctx, mut store) = ret?;

            let mut trigger = trigger();
            let pool = self.pool.clone();
            self.rt.handle().spawn(async move {
                // We wait for either the trigger or for a snapshot to take place
                let result = loop {
                    let env = ctx.data(&store);
                    // Only the trigger arm (or an exit request while handling
                    // signals) breaks out of the loop; the others `continue`.
                    break tokio::select! {
                        r = &mut trigger => r,
                        _ = env.thread.wait_for_signal() => {
                            tracing::debug!("wait-for-signal(triggered)");
                            let mut ctx = ctx.env.clone().into_mut(&mut store);
                            if let Err(err) =
                                crate::WasiEnv::do_pending_link_operations(
                                    &mut ctx,
                                    false
                                ).and_then(|()|
                                    crate::WasiEnv::process_signals_and_exit(&mut ctx)
                                )
                            {
                                match err {
                                    // An exit request ends the wait with the exit code.
                                    crate::WasiError::Exit(code) => Err(code),
                                    // Anything else is logged and the wait goes on.
                                    err => {
                                        tracing::error!("failed to process signals - {}", err);
                                        continue;
                                    }
                                }
                            } else {
                                continue;
                            }
                        }
                        _ = crate::wait_for_snapshot(env) => {
                            tracing::debug!("wait-for-snapshot(triggered)");
                            let mut ctx = ctx.env.clone().into_mut(&mut store);
                            crate::os::task::WasiProcessInner::do_checkpoints_from_outside(&mut ctx);
                            continue;
                        }
                    };
                };

                // Optional hook that runs before the callback, still on the async runtime.
                if let Some(pre_run) = pre_run {
                    pre_run(&mut ctx, &mut store).await;
                }

                // Build the task that will go on the callback
                pool.execute(move || {
                    // Invoke the callback
                    run(TaskWasmRunProperties {
                        ctx,
                        store,
                        trigger_result: Some(result),
                        recycle,
                    });
                });
            });
        } else {
            tracing::trace!("spawning task_wasm in blocking thread");

            // Used by the pool thread to report whether environment creation succeeded.
            let (sx, rx) = std::sync::mpsc::channel();

            // Run the callback on a dedicated thread
            self.pool.execute(move || {
                tracing::trace!("task_wasm started in blocking thread");
                let (mut ctx, mut store) = match ret {
                    Ok(x) => {
                        sx.send(Ok(())).unwrap();
                        x
                    }
                    Err(c) => {
                        sx.send(Err(c)).unwrap();
                        return;
                    }
                };

                // No async runtime on this thread, so the hook is driven to
                // completion with `InlineWaker::block_on`.
                if let Some(pre_run) = pre_run {
                    InlineWaker::block_on(pre_run(&mut ctx, &mut store));
                }

                // Invoke the callback
                run(TaskWasmRunProperties {
                    ctx,
                    store,
                    trigger_result: None,
                    recycle,
                });
            });

            // Blocks the caller until the pool thread has picked the task up.
            // The outer `?` covers a dropped sender, the inner one the
            // reported creation error.
            rx.recv()
                .map_err(|_| WasiThreadError::InvalidWasmContext)??;
        }
        Ok(())
    }
313
314    /// See [`VirtualTaskManager::task_dedicated`].
315    fn task_dedicated(
316        &self,
317        task: Box<dyn FnOnce() + Send + 'static>,
318    ) -> Result<(), WasiThreadError> {
319        self.pool.execute(move || {
320            task();
321        });
322        Ok(())
323    }
324
325    /// See [`VirtualTaskManager::thread_parallelism`].
326    fn thread_parallelism(&self) -> Result<usize, WasiThreadError> {
327        Ok(std::thread::available_parallelism()
328            .map(usize::from)
329            .unwrap_or(8))
330    }
331}
332
/// Used by [`VirtualTaskManager::sleep_now`] to abort the spawned sleep task
/// when the value is dropped.
#[derive(Default)]
struct SleepNow {
    // Abort handle of the spawned sleep task; `None` until `enter` has spawned it.
    abort_handle: Option<tokio::task::AbortHandle>,
}
338
339impl SleepNow {
340    async fn enter(
341        &mut self,
342        handle: tokio::runtime::Handle,
343        time: Duration,
344    ) -> Result<(), tokio::task::JoinError> {
345        let handle = handle.spawn(async move {
346            if time == Duration::ZERO {
347                tokio::task::yield_now().await;
348            } else {
349                tokio::time::sleep(time).await;
350            }
351        });
352        self.abort_handle = Some(handle.abort_handle());
353        handle.await
354    }
355}
356
357impl Drop for SleepNow {
358    fn drop(&mut self) {
359        if let Some(h) = self.abort_handle.as_ref() {
360            h.abort()
361        }
362    }
363}