// wasmer_wasix/runtime/task_manager/tokio.rs
use std::sync::Mutex;
2use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};
3
4use futures::{Future, future::BoxFuture};
5use tokio::runtime::{Handle, Runtime};
6use virtual_mio::block_on;
7
8use crate::runtime::{SpawnType, task_manager::TaskWasmCallbacks};
9use crate::{WasiFunctionEnv, os::task::thread::WasiThreadError};
10
11use super::{SpawnMemoryTypeOrStore, TaskWasm, TaskWasmRunProperties, VirtualTaskManager};
12
13#[derive(Debug, Clone)]
14pub enum RuntimeOrHandle {
15 Handle(Handle),
16 Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
17}
18impl From<Handle> for RuntimeOrHandle {
19 fn from(value: Handle) -> Self {
20 Self::Handle(value)
21 }
22}
23impl From<Runtime> for RuntimeOrHandle {
24 fn from(value: Runtime) -> Self {
25 Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
26 }
27}
28
29impl Drop for RuntimeOrHandle {
30 fn drop(&mut self) {
31 if let Self::Runtime(_, runtime) = self
32 && let Some(h) = runtime.lock().unwrap().take()
33 {
34 h.shutdown_timeout(Duration::from_secs(0))
35 }
36 }
37}
38
39impl RuntimeOrHandle {
40 pub fn handle(&self) -> &Handle {
41 match self {
42 Self::Handle(h) => h,
43 Self::Runtime(h, _) => h,
44 }
45 }
46}
47
48#[derive(Clone)]
49pub struct ThreadPool {
50 inner: rusty_pool::ThreadPool,
51}
52
53impl std::ops::Deref for ThreadPool {
54 type Target = rusty_pool::ThreadPool;
55
56 fn deref(&self) -> &Self::Target {
57 &self.inner
58 }
59}
60
61impl std::fmt::Debug for ThreadPool {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("ThreadPool")
64 .field("name", &self.get_name())
65 .field("current_worker_count", &self.get_current_worker_count())
66 .field("idle_worker_count", &self.get_idle_worker_count())
67 .finish()
68 }
69}
70
71#[derive(Clone, Debug)]
73pub struct TokioTaskManager {
74 rt: RuntimeOrHandle,
75 pool: Arc<ThreadPool>,
76}
77
78impl TokioTaskManager {
79 pub fn new<I>(rt: I) -> Self
80 where
81 I: Into<RuntimeOrHandle>,
82 {
83 let concurrency = std::thread::available_parallelism()
84 .unwrap_or(NonZeroUsize::new(1).unwrap())
85 .get();
86 let max_threads = 200usize.max(concurrency * 100);
87
88 Self {
89 rt: rt.into(),
90 pool: Arc::new(ThreadPool {
91 inner: rusty_pool::Builder::new()
92 .name("TokioTaskManager Thread Pool".to_string())
93 .core_size(max_threads)
94 .max_size(max_threads)
95 .build(),
96 }),
97 }
98 }
99
100 pub fn runtime_handle(&self) -> tokio::runtime::Handle {
101 self.rt.handle().clone()
102 }
103
104 pub fn pool_handle(&self) -> Arc<ThreadPool> {
105 self.pool.clone()
106 }
107}
108
109impl Default for TokioTaskManager {
110 fn default() -> Self {
111 Self::new(Handle::current())
112 }
113}
114
115impl VirtualTaskManager for TokioTaskManager {
116 fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
118 let handle = self.runtime_handle();
119 Box::pin(async move {
120 SleepNow::default()
121 .enter(handle, time)
122 .await
123 .ok()
124 .unwrap_or(())
125 })
126 }
127
128 fn task_shared(
130 &self,
131 task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
132 ) -> Result<(), WasiThreadError> {
133 self.rt.handle().spawn(async move {
134 let fut = task();
135 fut.await
136 });
137 Ok(())
138 }
139
140 fn task_wasm(&self, task: TaskWasm) -> Result<(), WasiThreadError> {
142 fn env_and_store(
143 task: TaskWasm,
144 ) -> Result<(WasiFunctionEnv, wasmer::Store, TaskWasmCallbacks), WasiThreadError> {
145 let (make_memory, instance_group_data) = match task.spawn_type {
146 SpawnType::CreateMemory => (SpawnMemoryTypeOrStore::New, None),
147 SpawnType::NewLinkerInstanceGroup(instance_group_data) => {
148 (SpawnMemoryTypeOrStore::New, Some(instance_group_data))
149 }
150 SpawnType::CreateMemoryOfType(t) => (SpawnMemoryTypeOrStore::Type(t), None),
151 SpawnType::AttachMemory(mem) => {
152 let mut store = task.env.runtime().new_store();
153 let memory = mem.attach(&mut store);
154 (SpawnMemoryTypeOrStore::StoreAndMemory(store, memory), None)
155 }
156 };
157
158 let (env, store) = WasiFunctionEnv::new_with_store(
159 task.module,
160 task.env,
161 task.globals,
162 make_memory,
163 task.update_layout,
164 task.call_initialize,
165 instance_group_data,
166 )?;
167 Ok((env, store, task.callbacks))
168 }
169
170 let (sx, rx) = std::sync::mpsc::channel();
171
172 if task.callbacks.trigger.is_some() {
173 tracing::trace!("spawning task_wasm trigger in async pool");
174 self.pool.execute(move || {
175 let (mut ctx, mut store, callbacks) = match env_and_store(task) {
176 Ok(x) => {
177 sx.send(Ok(())).unwrap();
178 x
179 }
180 Err(c) => {
181 tracing::error!("failed to prepare environment for task_wasm trigger: {c}");
182 sx.send(Err(c)).unwrap();
183 return;
184 }
185 };
186
187 let result = {
188 let mut trigger = (callbacks.trigger.unwrap())();
189 let pre_run = callbacks.pre_run;
190 let ctx = &mut ctx;
191 let store = &mut store;
192 block_on(async move {
193 let result = loop {
195 let env = ctx.data(store);
196 break tokio::select! {
197 r = &mut trigger => r,
198 _ = env.thread.wait_for_signal() => {
199 tracing::debug!("wait-for-signal(triggered)");
200 let mut ctx = ctx.env.clone().into_mut(store);
201 if let Err(err) =
202 crate::WasiEnv::do_pending_link_operations(
203 &mut ctx,
204 false
205 ).and_then(|()|
206 crate::WasiEnv::process_signals_and_exit(&mut ctx)
207 )
208 {
209 match err {
210 crate::WasiError::Exit(code) => Err(code),
211 err => {
212 tracing::error!("failed to process signals - {}", err);
213 continue;
214 }
215 }
216 } else {
217 continue;
218 }
219 }
220 _ = crate::wait_for_snapshot(env) => {
221 tracing::debug!("wait-for-snapshot(triggered)");
222 let mut ctx = ctx.env.clone().into_mut(store);
223 crate::os::task::WasiProcessInner::do_checkpoints_from_outside(&mut ctx);
224 continue;
225 }
226 };
227 };
228
229 if let Some(pre_run) = pre_run {
230 pre_run(ctx, store).await;
231 }
232
233 result
234 })
235 };
236
237 (callbacks.run)(TaskWasmRunProperties {
239 ctx,
240 store,
241 trigger_result: Some(result),
242 recycle: callbacks.recycle,
243 });
244 });
245 } else {
246 tracing::trace!("spawning task_wasm in blocking thread");
247
248 self.pool.execute(move || {
250 tracing::trace!("task_wasm started in blocking thread");
251 let (mut ctx, mut store, callbacks) = match env_and_store(task) {
252 Ok(x) => {
253 sx.send(Ok(())).unwrap();
254 x
255 }
256 Err(c) => {
257 sx.send(Err(c)).unwrap();
258 return;
259 }
260 };
261
262 if let Some(pre_run) = callbacks.pre_run {
263 block_on(pre_run(&mut ctx, &mut store));
264 }
265
266 (callbacks.run)(TaskWasmRunProperties {
268 ctx,
269 store,
270 trigger_result: None,
271 recycle: callbacks.recycle,
272 });
273 });
274 }
275
276 rx.recv()
277 .map_err(|_| WasiThreadError::InvalidWasmContext)??;
278
279 Ok(())
280 }
281
282 fn task_dedicated(
284 &self,
285 task: Box<dyn FnOnce() + Send + 'static>,
286 ) -> Result<(), WasiThreadError> {
287 self.pool.execute(move || {
288 task();
289 });
290 Ok(())
291 }
292
293 fn thread_parallelism(&self) -> Result<usize, WasiThreadError> {
295 Ok(std::thread::available_parallelism()
296 .map(usize::from)
297 .unwrap_or(8))
298 }
299}
300
301#[derive(Default)]
303pub struct SleepNow {
304 abort_handle: Option<tokio::task::AbortHandle>,
305}
306
307impl SleepNow {
308 pub async fn enter(
309 &mut self,
310 handle: tokio::runtime::Handle,
311 time: Duration,
312 ) -> Result<(), tokio::task::JoinError> {
313 let handle = handle.spawn(async move {
314 if time == Duration::ZERO {
315 tokio::task::yield_now().await;
316 } else {
317 tokio::time::sleep(time).await;
318 }
319 });
320 self.abort_handle = Some(handle.abort_handle());
321 handle.await
322 }
323}
324
325impl Drop for SleepNow {
326 fn drop(&mut self) {
327 if let Some(h) = self.abort_handle.as_ref() {
328 h.abort()
329 }
330 }
331}