1use super::{
2 control_plane::TaskCountGuard,
3 task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
4};
5use crate::{
6 WasiRuntimeError,
7 os::task::process::{WasiProcessId, WasiProcessInner},
8 state::LinkError,
9 syscalls::HandleRewindType,
10};
11use bytes::{Bytes, BytesMut};
12use serde::{Deserialize, Serialize};
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::{
15 collections::HashMap,
16 ops::{Deref, DerefMut},
17 sync::{Arc, Condvar, Mutex, Weak},
18 task::Waker,
19};
20use wasmer::{ExportError, InstantiationError, MemoryError};
21use wasmer_wasix_types::{
22 types::Signal,
23 wasi::{Errno, ExitCode},
24 wasix::ThreadStartType,
25};
26
/// Numeric identifier of a WASI thread.
///
/// Converts losslessly to and from `u32`, and bit-for-bit to and from `i32`.
#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct WasiThreadId(u32);
30
31impl WasiThreadId {
32 pub fn raw(&self) -> u32 {
33 self.0
34 }
35
36 pub fn inc(&mut self) -> WasiThreadId {
37 let ret = *self;
38 self.0 += 1;
39 ret
40 }
41}
42
43impl From<i32> for WasiThreadId {
44 fn from(id: i32) -> Self {
45 Self(id as u32)
46 }
47}
48
49impl From<WasiThreadId> for i32 {
50 fn from(val: WasiThreadId) -> Self {
51 val.0 as i32
52 }
53}
54
55impl From<u32> for WasiThreadId {
56 fn from(id: u32) -> Self {
57 Self(id)
58 }
59}
60
61impl From<WasiThreadId> for u32 {
62 fn from(t: WasiThreadId) -> u32 {
63 t.0
64 }
65}
66
67impl std::fmt::Display for WasiThreadId {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
73impl std::fmt::Debug for WasiThreadId {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "{}", self.0)
76 }
77}
78
/// A snapshot recorded by `WasiThread::add_snapshot` against one stack segment.
#[derive(Debug, Clone)]
struct ThreadSnapshot {
    /// Copy of the `rewind_stack` bytes passed to `add_snapshot`.
    call_stack: Bytes,
    /// Copy of the `store_data` bytes passed to `add_snapshot`.
    store_data: Bytes,
}
85
/// One segment of a thread's chain of stack segments.
///
/// Each segment stores the stack bytes that follow the bytes stored by the segments before
/// it in the chain. `WasiThread::get_snapshot` concatenates the corrected bytes of every
/// segment from the root down to the segment holding the requested snapshot.
#[derive(Debug, Clone, Default)]
pub struct ThreadStack {
    /// This segment's stack bytes as passed to `add_snapshot`.
    memory_stack: Vec<u8>,
    /// This segment's bytes in "corrected" form; these are the bytes `get_snapshot` returns.
    /// NOTE(review): what the correction consists of is decided by callers of
    /// `add_snapshot`, which are not visible here.
    memory_stack_corrected: Vec<u8>,
    /// Snapshots recorded at this segment, keyed by the hash passed to `add_snapshot`.
    snapshots: HashMap<u128, ThreadSnapshot>,
    /// The following segment in the chain, if any.
    next: Option<Box<ThreadStack>>,
}
94
/// Handle to a WASI thread.
///
/// Clones share the same `WasiThreadState` through an `Arc`. The `layout`, `start` and
/// pending `rewind` fields are copied per clone.
#[derive(Clone, Debug)]
pub struct WasiThread {
    /// State shared by every clone (ids, signals, snapshot stack, task status, flags).
    state: Arc<WasiThreadState>,
    /// Memory layout of this thread, replaceable via `set_memory_layout`.
    layout: WasiMemoryLayout,
    /// How this thread was started (see `thread_start_type`).
    start: ThreadStartType,

    /// Pending rewind, set by `set_rewind` and removed by `take_rewind`.
    rewind: Option<RewindResult>,
}
106
107impl WasiThread {
108 pub fn id(&self) -> WasiThreadId {
109 self.state.id
110 }
111
112 pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
114 self.rewind.replace(rewind);
115 }
116
117 pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
119 self.rewind.take()
120 }
121
122 pub fn thread_start_type(&self) -> ThreadStartType {
124 self.start
125 }
126
127 pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
130 match type_ {
131 HandleRewindType::ResultDriven => match &self.rewind {
132 Some(rewind) => match rewind.rewind_result {
133 RewindResultType::RewindRestart => true,
134 RewindResultType::RewindWithoutResult => false,
135 RewindResultType::RewindWithResult(_) => true,
136 },
137 None => false,
138 },
139 HandleRewindType::ResultLess => match &self.rewind {
140 Some(rewind) => match rewind.rewind_result {
141 RewindResultType::RewindRestart => true,
142 RewindResultType::RewindWithoutResult => true,
143 RewindResultType::RewindWithResult(_) => false,
144 },
145 None => false,
146 },
147 }
148 }
149
150 pub(crate) fn set_deep_sleeping(&self, val: bool) {
153 self.state.deep_sleeping.store(val, Ordering::SeqCst);
154 }
155
156 pub(crate) fn is_deep_sleeping(&self) -> bool {
159 self.state.deep_sleeping.load(Ordering::SeqCst)
160 }
161
162 #[cfg(feature = "journal")]
165 pub(crate) fn set_checkpointing(&self, val: bool) {
166 self.state.check_pointing.store(val, Ordering::SeqCst);
167 }
168
169 #[cfg(feature = "journal")]
172 pub(crate) fn is_check_pointing(&self) -> bool {
173 self.state.check_pointing.load(Ordering::SeqCst)
174 }
175
176 #[allow(dead_code)]
178 pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
179 &self.layout
180 }
181
182 pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
184 self.layout = layout;
185 }
186}
187
/// Guard that owns a [`WasiThread`] while it runs.
///
/// When the guard is dropped, its `Drop` impl reports the thread as finished with a
/// "Thread manager disconnected" runtime error.
pub struct WasiThreadRunGuard {
    /// The guarded thread.
    pub thread: WasiThread,
}
197
198impl WasiThreadRunGuard {
199 pub fn new(thread: WasiThread) -> Self {
200 Self { thread }
201 }
202}
203
204impl Drop for WasiThreadRunGuard {
205 fn drop(&mut self) {
206 self.thread
207 .set_status_finished(Err(
208 crate::RuntimeError::new("Thread manager disconnected").into()
209 ));
210 }
211}
212
213pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
215
/// Kind of result a pending rewind carries.
///
/// `WasiThread::has_rewind_of_type` treats `RewindRestart` as matching both
/// `HandleRewindType::ResultDriven` and `HandleRewindType::ResultLess`. It treats
/// `RewindWithoutResult` as matching only `ResultLess`, and `RewindWithResult` as matching
/// only `ResultDriven`.
#[derive(Clone, Debug)]
pub enum RewindResultType {
    /// Restarting rewind; matches either handler kind.
    RewindRestart,
    /// Rewind that carries no result value.
    RewindWithoutResult,
    /// Rewind that carries a result payload as raw bytes.
    RewindWithResult(Bytes),
}
225
/// A pending rewind stored on a [`WasiThread`] via `set_rewind`.
#[derive(Clone, Debug)]
pub(crate) struct RewindResult {
    /// Optional memory-stack bytes associated with this rewind (presumably restored when the
    /// rewind is handled — TODO confirm against the rewind handlers).
    pub memory_stack: Option<Bytes>,
    /// What kind of result the rewind delivers.
    pub rewind_result: RewindResultType,
}
235
/// State shared by every clone of a [`WasiThread`].
#[derive(Debug)]
struct WasiThreadState {
    /// Whether this thread was created as its process's main thread.
    is_main: bool,
    /// Process this thread belongs to.
    pid: WasiProcessId,
    /// Identifier of this thread.
    id: WasiThreadId,
    /// Pending signals (deduplicated by `WasiThread::signal`) and the wakers that are woken,
    /// then cleared, whenever a signal is queued.
    signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
    /// Chain of stack segments and the snapshots recorded against them.
    stack: Mutex<ThreadStack>,
    /// Task status used to report (`set_finished`) and await (`await_termination`) completion.
    status: Arc<OwnedTaskStatus>,
    /// Journal checkpointing flag, toggled by `set_checkpointing`.
    #[cfg(feature = "journal")]
    check_pointing: AtomicBool,
    /// Deep-sleep flag, toggled by `set_deep_sleeping`.
    deep_sleeping: AtomicBool,

    /// Never read; held only so that it is dropped together with this state (presumably to
    /// release a task-count slot — TODO confirm in `control_plane::TaskCountGuard`).
    _task_count_guard: TaskCountGuard,
}
252
/// Zero-length buffer that `WasiThread::add_snapshot` uses to mark the incoming stack as
/// fully consumed.
static NO_MORE_BYTES: [u8; 0] = [];
254
255impl WasiThread {
256 pub fn new(
257 pid: WasiProcessId,
258 id: WasiThreadId,
259 is_main: bool,
260 status: Arc<OwnedTaskStatus>,
261 guard: TaskCountGuard,
262 layout: WasiMemoryLayout,
263 start: ThreadStartType,
264 ) -> Self {
265 Self {
266 state: Arc::new(WasiThreadState {
267 is_main,
268 pid,
269 id,
270 status,
271 signals: Mutex::new((Vec::new(), Vec::new())),
272 stack: Mutex::new(ThreadStack::default()),
273 #[cfg(feature = "journal")]
274 check_pointing: AtomicBool::new(false),
275 deep_sleeping: AtomicBool::new(false),
276 _task_count_guard: guard,
277 }),
278 layout,
279 start,
280 rewind: None,
281 }
282 }
283
284 pub fn pid(&self) -> WasiProcessId {
286 self.state.pid
287 }
288
289 pub fn tid(&self) -> WasiThreadId {
291 self.state.id
292 }
293
294 pub fn is_main(&self) -> bool {
296 self.state.is_main
297 }
298
299 pub fn join_handle(&self) -> TaskJoinHandle {
301 self.state.status.handle()
302 }
303
304 pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
306 &self.state.signals
307 }
308
309 pub fn set_status_running(&self) {
310 self.state.status.set_running();
311 }
312
313 pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
317 let default_exitcode: ExitCode = match sig {
318 Signal::Sigquit => Errno::Success.into(),
319 Signal::Sigabrt => ExitCode::from(128 + sig as i32),
321 Signal::Sigpipe => Errno::Pipe.into(),
322 _ => Errno::Intr.into(),
323 };
324 self.set_status_finished(Ok(default_exitcode));
326 self.try_join()
327 .map(|r| r.unwrap_or(default_exitcode))
328 .unwrap_or(default_exitcode)
329 }
330
331 pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
334 self.state.status.set_finished(res.map_err(Arc::new));
335 }
336
337 pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
339 self.state.status.await_termination().await
340 }
341
342 pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
344 self.state.status.status().into_finished()
345 }
346
347 pub fn signal(&self, signal: Signal) {
349 let tid = self.tid();
350 tracing::trace!(%tid, "signal-thread({:?})", signal);
351
352 let mut guard = self.state.signals.lock().unwrap();
353 if !guard.0.contains(&signal) {
354 guard.0.push(signal);
355 }
356 guard.1.drain(..).for_each(|w| w.wake());
357 }
358
359 pub fn has_signal(&self, signals: &[Signal]) -> bool {
361 let guard = self.state.signals.lock().unwrap();
362 for s in guard.0.iter() {
363 if signals.contains(s) {
364 return true;
365 }
366 }
367 false
368 }
369
370 pub async fn wait_for_signal(&self) {
372 struct SignalPoller<'a> {
374 thread: &'a WasiThread,
375 }
376 impl std::future::Future for SignalPoller<'_> {
377 type Output = ();
378 fn poll(
379 self: std::pin::Pin<&mut Self>,
380 cx: &mut std::task::Context<'_>,
381 ) -> std::task::Poll<Self::Output> {
382 if self.thread.has_signals_or_subscribe(cx.waker()) {
383 return std::task::Poll::Ready(());
384 }
385 std::task::Poll::Pending
386 }
387 }
388 SignalPoller { thread: self }.await
389 }
390
391 pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
393 let mut guard = self.state.signals.lock().unwrap();
394 let mut ret = Vec::new();
395 std::mem::swap(&mut ret, &mut guard.0);
396 match ret.is_empty() {
397 true => {
398 if !guard.1.iter().any(|w| w.will_wake(waker)) {
399 guard.1.push(waker.clone());
400 }
401 None
402 }
403 false => Some(ret),
404 }
405 }
406
407 pub fn signals_subscribe(&self, waker: &Waker) {
408 let mut guard = self.state.signals.lock().unwrap();
409 if !guard.1.iter().any(|w| w.will_wake(waker)) {
410 guard.1.push(waker.clone());
411 }
412 }
413
414 pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
416 let mut guard = self.state.signals.lock().unwrap();
417 let has_signals = !guard.0.is_empty();
418 if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
419 guard.1.push(waker.clone());
420 }
421 has_signals
422 }
423
424 pub fn pop_signals(&self) -> Vec<Signal> {
426 let mut guard = self.state.signals.lock().unwrap();
427 let mut ret = Vec::new();
428 std::mem::swap(&mut ret, &mut guard.0);
429 ret
430 }
431
432 pub fn add_snapshot(
434 &self,
435 mut memory_stack: &[u8],
436 mut memory_stack_corrected: &[u8],
437 hash: u128,
438 rewind_stack: &[u8],
439 store_data: &[u8],
440 ) {
441 let mut stack = self.state.stack.lock().unwrap();
443 let mut pstack = stack.deref_mut();
444 loop {
445 let memory_stack_before = pstack.memory_stack.len();
447 let memory_stack_after = memory_stack.len();
448 if memory_stack_before > memory_stack_after
449 || (!pstack
450 .memory_stack
451 .iter()
452 .zip(memory_stack.iter())
453 .any(|(a, b)| *a == *b)
454 && !pstack
455 .memory_stack_corrected
456 .iter()
457 .zip(memory_stack.iter())
458 .any(|(a, b)| *a == *b))
459 {
460 let mut new_stack = ThreadStack {
462 memory_stack: memory_stack.to_vec(),
463 memory_stack_corrected: memory_stack_corrected.to_vec(),
464 ..Default::default()
465 };
466 std::mem::swap(pstack, &mut new_stack);
467 memory_stack = &NO_MORE_BYTES[..];
468 memory_stack_corrected = &NO_MORE_BYTES[..];
469
470 let mut disown = Some(Box::new(new_stack));
472 if let Some(disown) = disown.as_ref()
473 && !disown.snapshots.is_empty()
474 {
475 tracing::trace!(
476 "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
477 self.pid(),
478 memory_stack_before,
479 memory_stack_after
480 );
481 }
482 let mut total_forgotten = 0usize;
483 while let Some(disowned) = disown {
484 for _hash in disowned.snapshots.keys() {
485 total_forgotten += 1;
486 }
487 disown = disowned.next;
488 }
489 if total_forgotten > 0 {
490 tracing::trace!(
491 "wasi[{}]::stack has been forgotten (cnt={})",
492 self.pid(),
493 total_forgotten
494 );
495 }
496 } else {
497 memory_stack = &memory_stack[pstack.memory_stack.len()..];
498 memory_stack_corrected =
499 &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
500 }
501
502 if memory_stack.is_empty() {
504 break;
505 }
506
507 if pstack.next.is_none() {
509 let new_stack = ThreadStack {
510 memory_stack: memory_stack.to_vec(),
511 memory_stack_corrected: memory_stack_corrected.to_vec(),
512 ..Default::default()
513 };
514 pstack.next.replace(Box::new(new_stack));
515 }
516 pstack = pstack.next.as_mut().unwrap();
517 }
518
519 pstack.snapshots.insert(
521 hash,
522 ThreadSnapshot {
523 call_stack: BytesMut::from(rewind_stack).freeze(),
524 store_data: BytesMut::from(store_data).freeze(),
525 },
526 );
527 }
528
529 pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
531 let mut memory_stack = BytesMut::new();
532
533 let stack = self.state.stack.lock().unwrap();
534 let mut pstack = stack.deref();
535 loop {
536 memory_stack.extend(pstack.memory_stack_corrected.iter());
537 if let Some(snapshot) = pstack.snapshots.get(&hash) {
538 return Some((
539 memory_stack,
540 snapshot.call_stack.clone(),
541 snapshot.store_data.clone(),
542 ));
543 }
544 if let Some(next) = pstack.next.as_ref() {
545 pstack = next.deref();
546 } else {
547 return None;
548 }
549 }
550 }
551
552 pub fn copy_stack_from(&self, other: &WasiThread) {
554 let mut stack = {
555 let stack_guard = other.state.stack.lock().unwrap();
556 stack_guard.clone()
557 };
558
559 let mut stack_guard = self.state.stack.lock().unwrap();
560 std::mem::swap(stack_guard.deref_mut(), &mut stack);
561 }
562}
563
/// Inner part of a [`WasiThreadHandle`], shared by all clones of the handle.
///
/// When it is dropped (that is, when the last handle clone goes away) and the owning
/// process state is still alive, the thread is removed from the process's thread table.
/// The process's thread count is also decremented.
#[derive(Debug)]
pub struct WasiThreadHandleProtected {
    /// The thread this handle refers to.
    thread: WasiThread,
    /// Weak reference to the owning process state, so the handle does not keep it alive.
    inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
}
569
/// Cloneable handle to a [`WasiThread`] that is registered with a process.
///
/// All clones share one `WasiThreadHandleProtected`. Dropping the last clone unregisters
/// the thread from its process.
#[derive(Debug, Clone)]
pub struct WasiThreadHandle {
    protected: Arc<WasiThreadHandleProtected>,
}
574
575impl WasiThreadHandle {
576 pub(crate) fn new(
577 thread: WasiThread,
578 inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
579 ) -> WasiThreadHandle {
580 Self {
581 protected: Arc::new(WasiThreadHandleProtected {
582 thread,
583 inner: Arc::downgrade(inner),
584 }),
585 }
586 }
587
588 pub fn id(&self) -> WasiThreadId {
589 self.protected.thread.tid()
590 }
591
592 pub fn as_thread(&self) -> WasiThread {
593 self.protected.thread.clone()
594 }
595}
596
597impl Drop for WasiThreadHandleProtected {
598 fn drop(&mut self) {
599 let id = self.thread.tid();
600 if let Some(inner) = Weak::upgrade(&self.inner) {
601 let mut inner = inner.0.lock().unwrap();
602 if let Some(ctrl) = inner.threads.remove(&id) {
603 ctrl.set_status_finished(Ok(Errno::Success.into()));
604 }
605 inner.thread_count -= 1;
606 }
607 }
608}
609
610impl std::ops::Deref for WasiThreadHandle {
611 type Target = WasiThread;
612
613 fn deref(&self) -> &Self::Target {
614 &self.protected.thread
615 }
616}
617
/// Errors reported by WASI threading operations.
#[derive(thiserror::Error, Debug, Clone)]
pub enum WasiThreadError {
    /// Multithreading is not available.
    #[error("Multithreading is not supported")]
    Unsupported,
    /// The requested method is not an exported function.
    #[error("The method named is not an exported function")]
    MethodNotFound,
    /// Creating the requested memory failed.
    #[error("Failed to create the requested memory - {0}")]
    MemoryCreateFailed(MemoryError),
    /// An export lookup failed; displays the underlying error unchanged.
    #[error("{0}")]
    ExportError(ExportError),
    /// Creating additional imports failed.
    #[error("Failed to create additional imports - {0}")]
    AdditionalImportCreationFailed(Arc<anyhow::Error>),
    /// Linking failed.
    #[error("Linker error: {0}")]
    LinkError(Arc<LinkError>),
    /// Creating the instance failed.
    #[error("Failed to create the instance - {0}")]
    InstanceCreateFailed(Box<InstantiationError>),
    /// An initialization function failed.
    #[error("Initialization function failed - {0}")]
    InitFailed(Arc<anyhow::Error>),
    /// The WASM context is invalid.
    #[error("WASM context is invalid")]
    InvalidWasmContext,
}
641
642impl From<WasiThreadError> for Errno {
643 fn from(a: WasiThreadError) -> Errno {
644 match a {
645 WasiThreadError::Unsupported => Errno::Notsup,
646 WasiThreadError::MethodNotFound => Errno::Inval,
647 WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
648 WasiThreadError::ExportError(_) => Errno::Noexec,
649 WasiThreadError::AdditionalImportCreationFailed(_) => Errno::Noexec,
650 WasiThreadError::LinkError(_) => Errno::Noexec,
651 WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
652 WasiThreadError::InitFailed(_) => Errno::Noexec,
653 WasiThreadError::InvalidWasmContext => Errno::Noexec,
654 }
655 }
656}