1use super::{
2 control_plane::TaskCountGuard,
3 task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
4};
5use crate::{
6 WasiRuntimeError,
7 os::task::process::{WasiProcessId, WasiProcessInner},
8 state::LinkError,
9 syscalls::HandleRewindType,
10};
11use bytes::{Bytes, BytesMut};
12use serde::{Deserialize, Serialize};
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::{
15 collections::HashMap,
16 ops::{Deref, DerefMut},
17 sync::{Arc, Condvar, Mutex, Weak},
18 task::Waker,
19};
20use wasmer::{ExportError, InstantiationError, MemoryError};
21use wasmer_wasix_types::{
22 types::Signal,
23 wasi::{Errno, ExitCode},
24 wasix::ThreadStartType,
25};
26
27#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
29pub struct WasiThreadId(u32);
30
31impl WasiThreadId {
32 pub fn raw(&self) -> u32 {
33 self.0
34 }
35
36 pub fn inc(&mut self) -> WasiThreadId {
37 let ret = *self;
38 self.0 += 1;
39 ret
40 }
41}
42
43impl From<i32> for WasiThreadId {
44 fn from(id: i32) -> Self {
45 Self(id as u32)
46 }
47}
48
49impl From<WasiThreadId> for i32 {
50 fn from(val: WasiThreadId) -> Self {
51 val.0 as i32
52 }
53}
54
55impl From<u32> for WasiThreadId {
56 fn from(id: u32) -> Self {
57 Self(id)
58 }
59}
60
61impl From<WasiThreadId> for u32 {
62 fn from(t: WasiThreadId) -> u32 {
63 t.0
64 }
65}
66
67impl std::fmt::Display for WasiThreadId {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
73impl std::fmt::Debug for WasiThreadId {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "{}", self.0)
76 }
77}
78
79#[derive(Debug, Clone)]
81struct ThreadSnapshot {
82 call_stack: Bytes,
83 store_data: Bytes,
84}
85
86#[derive(Debug, Clone, Default)]
88pub struct ThreadStack {
89 memory_stack: Vec<u8>,
90 memory_stack_corrected: Vec<u8>,
91 snapshots: HashMap<u128, ThreadSnapshot>,
92 next: Option<Box<ThreadStack>>,
93}
94
95#[derive(Clone, Debug)]
98pub struct WasiThread {
99 state: Arc<WasiThreadState>,
100 layout: WasiMemoryLayout,
101 start: ThreadStartType,
102
103 rewind: Option<RewindResult>,
105}
106
107impl WasiThread {
108 pub fn id(&self) -> WasiThreadId {
109 self.state.id
110 }
111
112 pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
114 self.rewind.replace(rewind);
115 }
116
117 pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
119 self.rewind.take()
120 }
121
122 pub fn thread_start_type(&self) -> ThreadStartType {
124 self.start
125 }
126
127 pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
130 match type_ {
131 HandleRewindType::ResultDriven => match &self.rewind {
132 Some(rewind) => match rewind.rewind_result {
133 RewindResultType::RewindRestart => true,
134 RewindResultType::RewindWithoutResult => false,
135 RewindResultType::RewindWithResult(_) => true,
136 },
137 None => false,
138 },
139 HandleRewindType::ResultLess => match &self.rewind {
140 Some(rewind) => match rewind.rewind_result {
141 RewindResultType::RewindRestart => true,
142 RewindResultType::RewindWithoutResult => true,
143 RewindResultType::RewindWithResult(_) => false,
144 },
145 None => false,
146 },
147 }
148 }
149
150 pub(crate) fn set_deep_sleeping(&self, val: bool) {
153 self.state.deep_sleeping.store(val, Ordering::SeqCst);
154 }
155
156 pub(crate) fn is_deep_sleeping(&self) -> bool {
159 self.state.deep_sleeping.load(Ordering::SeqCst)
160 }
161
162 #[cfg(feature = "journal")]
165 pub(crate) fn set_checkpointing(&self, val: bool) {
166 self.state.check_pointing.store(val, Ordering::SeqCst);
167 }
168
169 #[cfg(feature = "journal")]
172 pub(crate) fn is_check_pointing(&self) -> bool {
173 self.state.check_pointing.load(Ordering::SeqCst)
174 }
175
176 #[allow(dead_code)]
178 pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
179 &self.layout
180 }
181
182 pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
184 self.layout = layout;
185 }
186}
187
188pub struct WasiThreadRunGuard {
195 pub thread: WasiThread,
196}
197
198impl WasiThreadRunGuard {
199 pub fn new(thread: WasiThread) -> Self {
200 Self { thread }
201 }
202}
203
204impl Drop for WasiThreadRunGuard {
205 fn drop(&mut self) {
206 self.thread
207 .set_status_finished(Err(
208 crate::RuntimeError::new("Thread manager disconnected").into()
209 ));
210 }
211}
212
213pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
215
216#[derive(Clone, Debug)]
217pub enum RewindResultType {
218 RewindRestart,
220 RewindWithoutResult,
222 RewindWithResult(Bytes),
224}
225
226#[derive(Clone, Debug)]
228pub(crate) struct RewindResult {
229 pub memory_stack: Option<Bytes>,
231 pub rewind_result: RewindResultType,
234}
235
236#[derive(Debug)]
237struct WasiThreadState {
238 is_main: bool,
239 pid: WasiProcessId,
240 id: WasiThreadId,
241 signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
242 stack: Mutex<ThreadStack>,
243 status: Arc<OwnedTaskStatus>,
244 #[cfg(feature = "journal")]
245 check_pointing: AtomicBool,
246 deep_sleeping: AtomicBool,
247
248 _task_count_guard: TaskCountGuard,
251}
252
/// Empty byte slice used by `add_snapshot` to mark that no incoming stack bytes remain.
static NO_MORE_BYTES: [u8; 0] = [];
254
255impl WasiThread {
256 pub fn new(
257 pid: WasiProcessId,
258 id: WasiThreadId,
259 is_main: bool,
260 status: Arc<OwnedTaskStatus>,
261 guard: TaskCountGuard,
262 layout: WasiMemoryLayout,
263 start: ThreadStartType,
264 ) -> Self {
265 Self {
266 state: Arc::new(WasiThreadState {
267 is_main,
268 pid,
269 id,
270 status,
271 signals: Mutex::new((Vec::new(), Vec::new())),
272 stack: Mutex::new(ThreadStack::default()),
273 #[cfg(feature = "journal")]
274 check_pointing: AtomicBool::new(false),
275 deep_sleeping: AtomicBool::new(false),
276 _task_count_guard: guard,
277 }),
278 layout,
279 start,
280 rewind: None,
281 }
282 }
283
284 pub fn pid(&self) -> WasiProcessId {
286 self.state.pid
287 }
288
289 pub fn tid(&self) -> WasiThreadId {
291 self.state.id
292 }
293
294 pub fn is_main(&self) -> bool {
296 self.state.is_main
297 }
298
299 pub fn join_handle(&self) -> TaskJoinHandle {
301 self.state.status.handle()
302 }
303
304 pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
306 &self.state.signals
307 }
308
309 pub fn set_status_running(&self) {
310 self.state.status.set_running();
311 }
312
313 pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
317 let default_exitcode: ExitCode = match sig {
318 Signal::Sigquit | Signal::Sigabrt => Errno::Success.into(),
319 Signal::Sigpipe => Errno::Pipe.into(),
320 _ => Errno::Intr.into(),
321 };
322 self.set_status_finished(Ok(default_exitcode));
324 self.try_join()
325 .map(|r| r.unwrap_or(default_exitcode))
326 .unwrap_or(default_exitcode)
327 }
328
329 pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
332 self.state.status.set_finished(res.map_err(Arc::new));
333 }
334
335 pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
337 self.state.status.await_termination().await
338 }
339
340 pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
342 self.state.status.status().into_finished()
343 }
344
345 pub fn signal(&self, signal: Signal) {
347 let tid = self.tid();
348 tracing::trace!(%tid, "signal-thread({:?})", signal);
349
350 let mut guard = self.state.signals.lock().unwrap();
351 if !guard.0.contains(&signal) {
352 guard.0.push(signal);
353 }
354 guard.1.drain(..).for_each(|w| w.wake());
355 }
356
357 pub fn has_signal(&self, signals: &[Signal]) -> bool {
359 let guard = self.state.signals.lock().unwrap();
360 for s in guard.0.iter() {
361 if signals.contains(s) {
362 return true;
363 }
364 }
365 false
366 }
367
368 pub async fn wait_for_signal(&self) {
370 struct SignalPoller<'a> {
372 thread: &'a WasiThread,
373 }
374 impl std::future::Future for SignalPoller<'_> {
375 type Output = ();
376 fn poll(
377 self: std::pin::Pin<&mut Self>,
378 cx: &mut std::task::Context<'_>,
379 ) -> std::task::Poll<Self::Output> {
380 if self.thread.has_signals_or_subscribe(cx.waker()) {
381 return std::task::Poll::Ready(());
382 }
383 std::task::Poll::Pending
384 }
385 }
386 SignalPoller { thread: self }.await
387 }
388
389 pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
391 let mut guard = self.state.signals.lock().unwrap();
392 let mut ret = Vec::new();
393 std::mem::swap(&mut ret, &mut guard.0);
394 match ret.is_empty() {
395 true => {
396 if !guard.1.iter().any(|w| w.will_wake(waker)) {
397 guard.1.push(waker.clone());
398 }
399 None
400 }
401 false => Some(ret),
402 }
403 }
404
405 pub fn signals_subscribe(&self, waker: &Waker) {
406 let mut guard = self.state.signals.lock().unwrap();
407 if !guard.1.iter().any(|w| w.will_wake(waker)) {
408 guard.1.push(waker.clone());
409 }
410 }
411
412 pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
414 let mut guard = self.state.signals.lock().unwrap();
415 let has_signals = !guard.0.is_empty();
416 if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
417 guard.1.push(waker.clone());
418 }
419 has_signals
420 }
421
422 pub fn pop_signals(&self) -> Vec<Signal> {
424 let mut guard = self.state.signals.lock().unwrap();
425 let mut ret = Vec::new();
426 std::mem::swap(&mut ret, &mut guard.0);
427 ret
428 }
429
430 pub fn add_snapshot(
432 &self,
433 mut memory_stack: &[u8],
434 mut memory_stack_corrected: &[u8],
435 hash: u128,
436 rewind_stack: &[u8],
437 store_data: &[u8],
438 ) {
439 let mut stack = self.state.stack.lock().unwrap();
441 let mut pstack = stack.deref_mut();
442 loop {
443 let memory_stack_before = pstack.memory_stack.len();
445 let memory_stack_after = memory_stack.len();
446 if memory_stack_before > memory_stack_after
447 || (!pstack
448 .memory_stack
449 .iter()
450 .zip(memory_stack.iter())
451 .any(|(a, b)| *a == *b)
452 && !pstack
453 .memory_stack_corrected
454 .iter()
455 .zip(memory_stack.iter())
456 .any(|(a, b)| *a == *b))
457 {
458 let mut new_stack = ThreadStack {
460 memory_stack: memory_stack.to_vec(),
461 memory_stack_corrected: memory_stack_corrected.to_vec(),
462 ..Default::default()
463 };
464 std::mem::swap(pstack, &mut new_stack);
465 memory_stack = &NO_MORE_BYTES[..];
466 memory_stack_corrected = &NO_MORE_BYTES[..];
467
468 let mut disown = Some(Box::new(new_stack));
470 if let Some(disown) = disown.as_ref()
471 && !disown.snapshots.is_empty()
472 {
473 tracing::trace!(
474 "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
475 self.pid(),
476 memory_stack_before,
477 memory_stack_after
478 );
479 }
480 let mut total_forgotten = 0usize;
481 while let Some(disowned) = disown {
482 for _hash in disowned.snapshots.keys() {
483 total_forgotten += 1;
484 }
485 disown = disowned.next;
486 }
487 if total_forgotten > 0 {
488 tracing::trace!(
489 "wasi[{}]::stack has been forgotten (cnt={})",
490 self.pid(),
491 total_forgotten
492 );
493 }
494 } else {
495 memory_stack = &memory_stack[pstack.memory_stack.len()..];
496 memory_stack_corrected =
497 &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
498 }
499
500 if memory_stack.is_empty() {
502 break;
503 }
504
505 if pstack.next.is_none() {
507 let new_stack = ThreadStack {
508 memory_stack: memory_stack.to_vec(),
509 memory_stack_corrected: memory_stack_corrected.to_vec(),
510 ..Default::default()
511 };
512 pstack.next.replace(Box::new(new_stack));
513 }
514 pstack = pstack.next.as_mut().unwrap();
515 }
516
517 pstack.snapshots.insert(
519 hash,
520 ThreadSnapshot {
521 call_stack: BytesMut::from(rewind_stack).freeze(),
522 store_data: BytesMut::from(store_data).freeze(),
523 },
524 );
525 }
526
527 pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
529 let mut memory_stack = BytesMut::new();
530
531 let stack = self.state.stack.lock().unwrap();
532 let mut pstack = stack.deref();
533 loop {
534 memory_stack.extend(pstack.memory_stack_corrected.iter());
535 if let Some(snapshot) = pstack.snapshots.get(&hash) {
536 return Some((
537 memory_stack,
538 snapshot.call_stack.clone(),
539 snapshot.store_data.clone(),
540 ));
541 }
542 if let Some(next) = pstack.next.as_ref() {
543 pstack = next.deref();
544 } else {
545 return None;
546 }
547 }
548 }
549
550 pub fn copy_stack_from(&self, other: &WasiThread) {
552 let mut stack = {
553 let stack_guard = other.state.stack.lock().unwrap();
554 stack_guard.clone()
555 };
556
557 let mut stack_guard = self.state.stack.lock().unwrap();
558 std::mem::swap(stack_guard.deref_mut(), &mut stack);
559 }
560}
561
562#[derive(Debug)]
563pub struct WasiThreadHandleProtected {
564 thread: WasiThread,
565 inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
566}
567
568#[derive(Debug, Clone)]
569pub struct WasiThreadHandle {
570 protected: Arc<WasiThreadHandleProtected>,
571}
572
573impl WasiThreadHandle {
574 pub(crate) fn new(
575 thread: WasiThread,
576 inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
577 ) -> WasiThreadHandle {
578 Self {
579 protected: Arc::new(WasiThreadHandleProtected {
580 thread,
581 inner: Arc::downgrade(inner),
582 }),
583 }
584 }
585
586 pub fn id(&self) -> WasiThreadId {
587 self.protected.thread.tid()
588 }
589
590 pub fn as_thread(&self) -> WasiThread {
591 self.protected.thread.clone()
592 }
593}
594
595impl Drop for WasiThreadHandleProtected {
596 fn drop(&mut self) {
597 let id = self.thread.tid();
598 if let Some(inner) = Weak::upgrade(&self.inner) {
599 let mut inner = inner.0.lock().unwrap();
600 if let Some(ctrl) = inner.threads.remove(&id) {
601 ctrl.set_status_finished(Ok(Errno::Success.into()));
602 }
603 inner.thread_count -= 1;
604 }
605 }
606}
607
608impl std::ops::Deref for WasiThreadHandle {
609 type Target = WasiThread;
610
611 fn deref(&self) -> &Self::Target {
612 &self.protected.thread
613 }
614}
615
616#[derive(thiserror::Error, Debug, Clone)]
617pub enum WasiThreadError {
618 #[error("Multithreading is not supported")]
619 Unsupported,
620 #[error("The method named is not an exported function")]
621 MethodNotFound,
622 #[error("Failed to create the requested memory - {0}")]
623 MemoryCreateFailed(MemoryError),
624 #[error("{0}")]
625 ExportError(ExportError),
626 #[error("Linker error: {0}")]
627 LinkError(Arc<LinkError>),
628 #[error("Failed to create the instance - {0}")]
629 InstanceCreateFailed(Box<InstantiationError>),
631 #[error("Initialization function failed - {0}")]
632 InitFailed(Arc<anyhow::Error>),
633 #[error("WASM context is invalid")]
635 InvalidWasmContext,
636}
637
638impl From<WasiThreadError> for Errno {
639 fn from(a: WasiThreadError) -> Errno {
640 match a {
641 WasiThreadError::Unsupported => Errno::Notsup,
642 WasiThreadError::MethodNotFound => Errno::Inval,
643 WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
644 WasiThreadError::ExportError(_) => Errno::Noexec,
645 WasiThreadError::LinkError(_) => Errno::Noexec,
646 WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
647 WasiThreadError::InitFailed(_) => Errno::Noexec,
648 WasiThreadError::InvalidWasmContext => Errno::Noexec,
649 }
650 }
651}