wasmer_wasix/runtime/task_manager/
tokio.rs

1use std::sync::Mutex;
2use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};
3
4use futures::{Future, future::BoxFuture};
5use tokio::runtime::{Handle, Runtime};
6use virtual_mio::block_on;
7
8use crate::runtime::{SpawnType, task_manager::TaskWasmCallbacks};
9use crate::{WasiFunctionEnv, os::task::thread::WasiThreadError};
10
11use super::{SpawnMemoryTypeOrStore, TaskWasm, TaskWasmRunProperties, VirtualTaskManager};
12
/// Either a bare tokio [`Handle`] or a [`Handle`] paired with the tokio
/// [`Runtime`] it belongs to.
///
/// The `Runtime` variant stores the handle next to the runtime so that
/// [`RuntimeOrHandle::handle`] can hand out a reference without locking.
#[derive(Debug, Clone)]
pub enum RuntimeOrHandle {
    /// Only a handle; this type's `Drop` implementation does nothing for it.
    Handle(Handle),
    /// A handle plus a shared, lockable slot holding the runtime itself.
    /// The runtime is taken out of the slot and shut down from this type's
    /// `Drop` implementation.
    Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
}
18impl From<Handle> for RuntimeOrHandle {
19    fn from(value: Handle) -> Self {
20        Self::Handle(value)
21    }
22}
23impl From<Runtime> for RuntimeOrHandle {
24    fn from(value: Runtime) -> Self {
25        Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
26    }
27}
28
29impl Drop for RuntimeOrHandle {
30    fn drop(&mut self) {
31        if let Self::Runtime(_, runtime) = self
32            && let Some(h) = runtime.lock().unwrap().take()
33        {
34            h.shutdown_timeout(Duration::from_secs(0))
35        }
36    }
37}
38
39impl RuntimeOrHandle {
40    pub fn handle(&self) -> &Handle {
41        match self {
42            Self::Handle(h) => h,
43            Self::Runtime(h, _) => h,
44        }
45    }
46}
47
/// Thin wrapper around [`rusty_pool::ThreadPool`].
///
/// It dereferences to the wrapped pool and carries a hand-written `Debug`
/// implementation.
#[derive(Clone)]
pub struct ThreadPool {
    // The wrapped pool; reached from outside through `Deref`.
    inner: rusty_pool::ThreadPool,
}
52
53impl std::ops::Deref for ThreadPool {
54    type Target = rusty_pool::ThreadPool;
55
56    fn deref(&self) -> &Self::Target {
57        &self.inner
58    }
59}
60
61impl std::fmt::Debug for ThreadPool {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("ThreadPool")
64            .field("name", &self.get_name())
65            .field("current_worker_count", &self.get_current_worker_count())
66            .field("idle_worker_count", &self.get_idle_worker_count())
67            .finish()
68    }
69}
70
/// A task manager that uses tokio to spawn tasks.
///
/// Async work is spawned on the tokio runtime behind `rt`, while WASM and
/// dedicated tasks are executed on the thread pool in `pool`.
#[derive(Clone, Debug)]
pub struct TokioTaskManager {
    // Tokio runtime (or just a handle to one) used for spawned futures.
    rt: RuntimeOrHandle,
    // Thread pool used by `task_wasm` and `task_dedicated`.
    pool: Arc<ThreadPool>,
}
77
78impl TokioTaskManager {
79    pub fn new<I>(rt: I) -> Self
80    where
81        I: Into<RuntimeOrHandle>,
82    {
83        let concurrency = std::thread::available_parallelism()
84            .unwrap_or(NonZeroUsize::new(1).unwrap())
85            .get();
86        let max_threads = 200usize.max(concurrency * 100);
87
88        Self {
89            rt: rt.into(),
90            pool: Arc::new(ThreadPool {
91                inner: rusty_pool::Builder::new()
92                    .name("TokioTaskManager Thread Pool".to_string())
93                    .core_size(max_threads)
94                    .max_size(max_threads)
95                    .build(),
96            }),
97        }
98    }
99
100    pub fn runtime_handle(&self) -> tokio::runtime::Handle {
101        self.rt.handle().clone()
102    }
103
104    pub fn pool_handle(&self) -> Arc<ThreadPool> {
105        self.pool.clone()
106    }
107}
108
109impl Default for TokioTaskManager {
110    fn default() -> Self {
111        Self::new(Handle::current())
112    }
113}
114
impl VirtualTaskManager for TokioTaskManager {
    /// See [`VirtualTaskManager::sleep_now`].
    ///
    /// The sleep itself runs as a task spawned on the tokio runtime (see
    /// [`SleepNow::enter`]). The temporary [`SleepNow`] aborts that task when
    /// it is dropped, which includes the case where the returned future is
    /// dropped before it completes. A join error from the spawned task is
    /// discarded.
    fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
        let handle = self.runtime_handle();
        Box::pin(async move {
            SleepNow::default()
                .enter(handle, time)
                .await
                .ok()
                .unwrap_or(())
        })
    }

    /// See [`VirtualTaskManager::task_shared`].
    ///
    /// The closure is invoked inside the spawned tokio task, and the future it
    /// returns is awaited there. The join handle is discarded, so this always
    /// returns `Ok(())` without waiting for the task.
    fn task_shared(
        &self,
        task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
    ) -> Result<(), WasiThreadError> {
        self.rt.handle().spawn(async move {
            let fut = task();
            fut.await
        });
        Ok(())
    }

    /// See [`VirtualTaskManager::task_wasm`].
    ///
    /// The environment and store are built on a thread-pool thread. The
    /// caller blocks until that setup has reported success or failure through
    /// a channel, then returns while the task keeps running on the pool.
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the environment, or
    /// [`WasiThreadError::InvalidWasmContext`] if the pool thread dropped the
    /// sender without reporting a result.
    fn task_wasm(&self, task: TaskWasm) -> Result<(), WasiThreadError> {
        // Translates the task's spawn type into a memory strategy and builds
        // the function environment and store for the task.
        fn env_and_store(
            task: TaskWasm,
        ) -> Result<(WasiFunctionEnv, wasmer::Store, TaskWasmCallbacks), WasiThreadError> {
            let (make_memory, instance_group_data) = match task.spawn_type {
                SpawnType::CreateMemory => (SpawnMemoryTypeOrStore::New, None),
                SpawnType::NewLinkerInstanceGroup(instance_group_data) => {
                    (SpawnMemoryTypeOrStore::New, Some(instance_group_data))
                }
                SpawnType::CreateMemoryOfType(t) => (SpawnMemoryTypeOrStore::Type(t), None),
                SpawnType::AttachMemory(mem) => {
                    // An existing memory needs a store to be attached to, so
                    // the store is created here and passed along with it.
                    let mut store = task.env.runtime().new_store();
                    let memory = mem.attach(&mut store);
                    (SpawnMemoryTypeOrStore::StoreAndMemory(store, memory), None)
                }
            };

            let (env, store) = WasiFunctionEnv::new_with_store(
                task.module,
                task.env,
                task.globals,
                make_memory,
                task.update_layout,
                task.call_initialize,
                instance_group_data,
            )?;
            Ok((env, store, task.callbacks))
        }

        // Carries the setup result from the pool thread back to the caller.
        let (sx, rx) = std::sync::mpsc::channel();

        if task.callbacks.trigger.is_some() {
            tracing::trace!("spawning task_wasm trigger in async pool");
            self.pool.execute(move || {
                let (mut ctx, mut store, callbacks) = match env_and_store(task) {
                    Ok(x) => {
                        // Unblock the caller as soon as setup has succeeded.
                        sx.send(Ok(())).unwrap();
                        x
                    }
                    Err(c) => {
                        tracing::error!("failed to prepare environment for task_wasm trigger: {c}");
                        sx.send(Err(c)).unwrap();
                        return;
                    }
                };

                let result = {
                    // `trigger` was checked to be `Some` before this closure
                    // was submitted to the pool.
                    let mut trigger = (callbacks.trigger.unwrap())();
                    let pre_run = callbacks.pre_run;
                    // Reborrow so the async block captures references and
                    // `ctx`/`store` stay available for the `run` callback.
                    let ctx = &mut ctx;
                    let store = &mut store;
                    block_on(async move {
                        // Wait for the trigger to resolve. Signals and
                        // snapshot requests that arrive in the meantime are
                        // serviced, after which the wait starts over.
                        let result = loop {
                            let env = ctx.data(store);
                            break tokio::select! {
                                r = &mut trigger => r,
                                _ = env.thread.wait_for_signal() => {
                                    tracing::debug!("wait-for-signal(triggered)");
                                    let mut ctx = ctx.env.clone().into_mut(store);
                                    if let Err(err) =
                                        crate::WasiEnv::do_pending_link_operations(
                                            &mut ctx,
                                            false
                                        ).and_then(|()|
                                            crate::WasiEnv::process_signals_and_exit(&mut ctx)
                                        )
                                    {
                                        match err {
                                            // An exit request ends the wait
                                            // with the exit code as the error.
                                            crate::WasiError::Exit(code) => Err(code),
                                            // Other errors are logged and the
                                            // wait continues.
                                            err => {
                                                tracing::error!("failed to process signals - {}", err);
                                                continue;
                                            }
                                        }
                                    } else {
                                        continue;
                                    }
                                }
                                _ = crate::wait_for_snapshot(env) => {
                                    tracing::debug!("wait-for-snapshot(triggered)");
                                    let mut ctx = ctx.env.clone().into_mut(store);
                                    crate::os::task::WasiProcessInner::do_checkpoints_from_outside(&mut ctx);
                                    continue;
                                }
                            };
                        };

                        // `pre_run` runs regardless of whether the wait ended
                        // in `Ok` or `Err`.
                        if let Some(pre_run) = pre_run {
                            pre_run(ctx, store).await;
                        }

                        result
                    })
                };

                // Invoke the callback
                (callbacks.run)(TaskWasmRunProperties {
                    ctx,
                    store,
                    trigger_result: Some(result),
                    recycle: callbacks.recycle,
                });
            });
        } else {
            tracing::trace!("spawning task_wasm in blocking thread");

            // Run the callback on a dedicated thread
            self.pool.execute(move || {
                tracing::trace!("task_wasm started in blocking thread");
                let (mut ctx, mut store, callbacks) = match env_and_store(task) {
                    Ok(x) => {
                        // Unblock the caller as soon as setup has succeeded.
                        sx.send(Ok(())).unwrap();
                        x
                    }
                    Err(c) => {
                        sx.send(Err(c)).unwrap();
                        return;
                    }
                };

                if let Some(pre_run) = callbacks.pre_run {
                    block_on(pre_run(&mut ctx, &mut store));
                }

                // Invoke the callback
                (callbacks.run)(TaskWasmRunProperties {
                    ctx,
                    store,
                    trigger_result: None,
                    recycle: callbacks.recycle,
                });
            });
        }

        // Block until the pool thread reports the setup result. The first `?`
        // handles a sender dropped without sending; the second propagates the
        // setup error itself.
        rx.recv()
            .map_err(|_| WasiThreadError::InvalidWasmContext)??;

        Ok(())
    }

    /// See [`VirtualTaskManager::task_dedicated`].
    ///
    /// The task is submitted to the thread pool and this returns `Ok(())`
    /// without waiting for it to finish.
    fn task_dedicated(
        &self,
        task: Box<dyn FnOnce() + Send + 'static>,
    ) -> Result<(), WasiThreadError> {
        self.pool.execute(move || {
            task();
        });
        Ok(())
    }

    /// See [`VirtualTaskManager::thread_parallelism`].
    ///
    /// Falls back to `8` when the available parallelism cannot be queried.
    fn thread_parallelism(&self) -> Result<usize, WasiThreadError> {
        Ok(std::thread::available_parallelism()
            .map(usize::from)
            .unwrap_or(8))
    }
}
300
/// Used by [`VirtualTaskManager::sleep_now`] to abort the spawned sleep task
/// when this value is dropped.
#[derive(Default)]
pub struct SleepNow {
    // Abort handle of the spawned sleep task; `None` until `enter` has run.
    abort_handle: Option<tokio::task::AbortHandle>,
}
306
307impl SleepNow {
308    pub async fn enter(
309        &mut self,
310        handle: tokio::runtime::Handle,
311        time: Duration,
312    ) -> Result<(), tokio::task::JoinError> {
313        let handle = handle.spawn(async move {
314            if time == Duration::ZERO {
315                tokio::task::yield_now().await;
316            } else {
317                tokio::time::sleep(time).await;
318            }
319        });
320        self.abort_handle = Some(handle.abort_handle());
321        handle.await
322    }
323}
324
325impl Drop for SleepNow {
326    fn drop(&mut self) {
327        if let Some(h) = self.abort_handle.as_ref() {
328            h.abort()
329        }
330    }
331}