wasmer_wasix/runtime/task_manager/
tokio.rs1use std::sync::Mutex;
2use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};
3
4use futures::{Future, future::BoxFuture};
5use tokio::runtime::{Handle, Runtime};
6use virtual_mio::InlineWaker;
7use wasmer::AsStoreMut;
8
9use crate::runtime::SpawnType;
10use crate::{WasiFunctionEnv, os::task::thread::WasiThreadError};
11
12use super::{SpawnMemoryTypeOrStore, TaskWasm, TaskWasmRunProperties, VirtualTaskManager};
13
14#[derive(Debug, Clone)]
15pub enum RuntimeOrHandle {
16 Handle(Handle),
17 Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
18}
19impl From<Handle> for RuntimeOrHandle {
20 fn from(value: Handle) -> Self {
21 Self::Handle(value)
22 }
23}
24impl From<Runtime> for RuntimeOrHandle {
25 fn from(value: Runtime) -> Self {
26 Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
27 }
28}
29
30impl Drop for RuntimeOrHandle {
31 fn drop(&mut self) {
32 if let Self::Runtime(_, runtime) = self {
33 if let Some(h) = runtime.lock().unwrap().take() {
34 h.shutdown_timeout(Duration::from_secs(0))
35 }
36 }
37 }
38}
39
40impl RuntimeOrHandle {
41 pub fn handle(&self) -> &Handle {
42 match self {
43 Self::Handle(h) => h,
44 Self::Runtime(h, _) => h,
45 }
46 }
47}
48
49#[derive(Clone)]
50pub struct ThreadPool {
51 inner: rusty_pool::ThreadPool,
52}
53
54impl std::ops::Deref for ThreadPool {
55 type Target = rusty_pool::ThreadPool;
56
57 fn deref(&self) -> &Self::Target {
58 &self.inner
59 }
60}
61
62impl std::fmt::Debug for ThreadPool {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("ThreadPool")
65 .field("name", &self.get_name())
66 .field("current_worker_count", &self.get_current_worker_count())
67 .field("idle_worker_count", &self.get_idle_worker_count())
68 .finish()
69 }
70}
71
72#[derive(Clone, Debug)]
74pub struct TokioTaskManager {
75 rt: RuntimeOrHandle,
76 pool: Arc<ThreadPool>,
77}
78
79impl TokioTaskManager {
80 pub fn new<I>(rt: I) -> Self
81 where
82 I: Into<RuntimeOrHandle>,
83 {
84 let concurrency = std::thread::available_parallelism()
85 .unwrap_or(NonZeroUsize::new(1).unwrap())
86 .get();
87 let max_threads = 200usize.max(concurrency * 100);
88
89 Self {
90 rt: rt.into(),
91 pool: Arc::new(ThreadPool {
92 inner: rusty_pool::Builder::new()
93 .name("TokioTaskManager Thread Pool".to_string())
94 .core_size(max_threads)
95 .max_size(max_threads)
96 .build(),
97 }),
98 }
99 }
100
101 pub fn runtime_handle(&self) -> tokio::runtime::Handle {
102 self.rt.handle().clone()
103 }
104
105 pub fn pool_handle(&self) -> Arc<ThreadPool> {
106 self.pool.clone()
107 }
108}
109
110impl Default for TokioTaskManager {
111 fn default() -> Self {
112 Self::new(Handle::current())
113 }
114}
115
116impl VirtualTaskManager for TokioTaskManager {
117 fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
119 let handle = self.runtime_handle();
120 Box::pin(async move {
121 SleepNow::default()
122 .enter(handle, time)
123 .await
124 .ok()
125 .unwrap_or(())
126 })
127 }
128
129 fn task_shared(
131 &self,
132 task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
133 ) -> Result<(), WasiThreadError> {
134 self.rt.handle().spawn(async move {
135 let fut = task();
136 fut.await
137 });
138 Ok(())
139 }
140
141 fn task_wasm(&self, task: TaskWasm) -> Result<(), WasiThreadError> {
143 let run = task.run;
144 let recycle = task.recycle;
145 let env = task.env;
146 let pre_run = task.pre_run;
147
148 let make_memory: SpawnMemoryTypeOrStore = match &task.spawn_type {
149 SpawnType::CreateMemory | SpawnType::NewLinkerInstanceGroup(..) => {
150 SpawnMemoryTypeOrStore::New
151 }
152 SpawnType::CreateMemoryOfType(t) => SpawnMemoryTypeOrStore::Type(*t),
153 SpawnType::ShareMemory(_, _) | SpawnType::CopyMemory(_, _) => {
154 let mut store = env.runtime().new_store();
155 let memory = self.build_memory(&mut store.as_store_mut(), &task.spawn_type)?;
156 SpawnMemoryTypeOrStore::StoreAndMemory(store, memory)
157 }
158 };
159
160 let ret = tokio::task::block_in_place(move || {
165 if let SpawnType::NewLinkerInstanceGroup(linker, func_env, mut store) = task.spawn_type
166 {
167 WasiFunctionEnv::new_with_store(
168 task.module,
169 env,
170 task.globals,
171 make_memory,
172 task.update_layout,
173 task.call_initialize,
174 Some((linker, &mut func_env.into_mut(&mut store))),
175 )
176 } else {
177 WasiFunctionEnv::new_with_store(
178 task.module,
179 env,
180 task.globals,
181 make_memory,
182 task.update_layout,
183 task.call_initialize,
184 None,
185 )
186 }
187 });
188
189 if let Some(trigger) = task.trigger {
190 tracing::trace!("spawning task_wasm trigger in async pool");
191 let (mut ctx, mut store) = ret?;
221
222 let mut trigger = trigger();
223 let pool = self.pool.clone();
224 self.rt.handle().spawn(async move {
225 let result = loop {
227 let env = ctx.data(&store);
228 break tokio::select! {
229 r = &mut trigger => r,
230 _ = env.thread.wait_for_signal() => {
231 tracing::debug!("wait-for-signal(triggered)");
232 let mut ctx = ctx.env.clone().into_mut(&mut store);
233 if let Err(err) =
234 crate::WasiEnv::do_pending_link_operations(
235 &mut ctx,
236 false
237 ).and_then(|()|
238 crate::WasiEnv::process_signals_and_exit(&mut ctx)
239 )
240 {
241 match err {
242 crate::WasiError::Exit(code) => Err(code),
243 err => {
244 tracing::error!("failed to process signals - {}", err);
245 continue;
246 }
247 }
248 } else {
249 continue;
250 }
251 }
252 _ = crate::wait_for_snapshot(env) => {
253 tracing::debug!("wait-for-snapshot(triggered)");
254 let mut ctx = ctx.env.clone().into_mut(&mut store);
255 crate::os::task::WasiProcessInner::do_checkpoints_from_outside(&mut ctx);
256 continue;
257 }
258 };
259 };
260
261 if let Some(pre_run) = pre_run {
262 pre_run(&mut ctx, &mut store).await;
263 }
264
265 pool.execute(move || {
267 run(TaskWasmRunProperties {
269 ctx,
270 store,
271 trigger_result: Some(result),
272 recycle,
273 });
274 });
275 });
276 } else {
277 tracing::trace!("spawning task_wasm in blocking thread");
278
279 let (sx, rx) = std::sync::mpsc::channel();
280
281 self.pool.execute(move || {
283 tracing::trace!("task_wasm started in blocking thread");
284 let (mut ctx, mut store) = match ret {
285 Ok(x) => {
286 sx.send(Ok(())).unwrap();
287 x
288 }
289 Err(c) => {
290 sx.send(Err(c)).unwrap();
291 return;
292 }
293 };
294
295 if let Some(pre_run) = pre_run {
296 InlineWaker::block_on(pre_run(&mut ctx, &mut store));
297 }
298
299 run(TaskWasmRunProperties {
301 ctx,
302 store,
303 trigger_result: None,
304 recycle,
305 });
306 });
307
308 rx.recv()
309 .map_err(|_| WasiThreadError::InvalidWasmContext)??;
310 }
311 Ok(())
312 }
313
314 fn task_dedicated(
316 &self,
317 task: Box<dyn FnOnce() + Send + 'static>,
318 ) -> Result<(), WasiThreadError> {
319 self.pool.execute(move || {
320 task();
321 });
322 Ok(())
323 }
324
325 fn thread_parallelism(&self) -> Result<usize, WasiThreadError> {
327 Ok(std::thread::available_parallelism()
328 .map(usize::from)
329 .unwrap_or(8))
330 }
331}
332
333#[derive(Default)]
335struct SleepNow {
336 abort_handle: Option<tokio::task::AbortHandle>,
337}
338
339impl SleepNow {
340 async fn enter(
341 &mut self,
342 handle: tokio::runtime::Handle,
343 time: Duration,
344 ) -> Result<(), tokio::task::JoinError> {
345 let handle = handle.spawn(async move {
346 if time == Duration::ZERO {
347 tokio::task::yield_now().await;
348 } else {
349 tokio::time::sleep(time).await;
350 }
351 });
352 self.abort_handle = Some(handle.abort_handle());
353 handle.await
354 }
355}
356
357impl Drop for SleepNow {
358 fn drop(&mut self) {
359 if let Some(h) = self.abort_handle.as_ref() {
360 h.abort()
361 }
362 }
363}