1use serde::{Deserialize, Serialize};
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::{
4 collections::HashMap,
5 ops::{Deref, DerefMut},
6 sync::{Arc, Condvar, Mutex, Weak},
7 task::Waker,
8};
9
10use bytes::{Bytes, BytesMut};
11use wasmer::{ExportError, InstantiationError, MemoryError};
12use wasmer_wasix_types::{
13 types::Signal,
14 wasi::{Errno, ExitCode},
15 wasix::ThreadStartType,
16};
17
18use crate::{
19 WasiRuntimeError,
20 os::task::process::{WasiProcessId, WasiProcessInner},
21 state::LinkError,
22 syscalls::HandleRewindType,
23};
24
25use super::{
26 control_plane::TaskCountGuard,
27 task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
28};
29
30#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
32pub struct WasiThreadId(u32);
33
34impl WasiThreadId {
35 pub fn raw(&self) -> u32 {
36 self.0
37 }
38
39 pub fn inc(&mut self) -> WasiThreadId {
40 let ret = *self;
41 self.0 += 1;
42 ret
43 }
44}
45
46impl From<i32> for WasiThreadId {
47 fn from(id: i32) -> Self {
48 Self(id as u32)
49 }
50}
51
52impl From<WasiThreadId> for i32 {
53 fn from(val: WasiThreadId) -> Self {
54 val.0 as i32
55 }
56}
57
58impl From<u32> for WasiThreadId {
59 fn from(id: u32) -> Self {
60 Self(id)
61 }
62}
63
64impl From<WasiThreadId> for u32 {
65 fn from(t: WasiThreadId) -> u32 {
66 t.0
67 }
68}
69
70impl std::fmt::Display for WasiThreadId {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(f, "{}", self.0)
73 }
74}
75
76impl std::fmt::Debug for WasiThreadId {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(f, "{}", self.0)
79 }
80}
81
82#[derive(Debug, Clone)]
84struct ThreadSnapshot {
85 call_stack: Bytes,
86 store_data: Bytes,
87}
88
89#[derive(Debug, Clone, Default)]
91pub struct ThreadStack {
92 memory_stack: Vec<u8>,
93 memory_stack_corrected: Vec<u8>,
94 snapshots: HashMap<u128, ThreadSnapshot>,
95 next: Option<Box<ThreadStack>>,
96}
97
98#[derive(Clone, Debug)]
101pub struct WasiThread {
102 state: Arc<WasiThreadState>,
103 layout: WasiMemoryLayout,
104 start: ThreadStartType,
105
106 rewind: Option<RewindResult>,
108}
109
110impl WasiThread {
111 pub fn id(&self) -> WasiThreadId {
112 self.state.id
113 }
114
115 pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
117 self.rewind.replace(rewind);
118 }
119
120 pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
122 self.rewind.take()
123 }
124
125 pub fn thread_start_type(&self) -> ThreadStartType {
127 self.start
128 }
129
130 pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
133 match type_ {
134 HandleRewindType::ResultDriven => match &self.rewind {
135 Some(rewind) => match rewind.rewind_result {
136 RewindResultType::RewindRestart => true,
137 RewindResultType::RewindWithoutResult => false,
138 RewindResultType::RewindWithResult(_) => true,
139 },
140 None => false,
141 },
142 HandleRewindType::ResultLess => match &self.rewind {
143 Some(rewind) => match rewind.rewind_result {
144 RewindResultType::RewindRestart => true,
145 RewindResultType::RewindWithoutResult => true,
146 RewindResultType::RewindWithResult(_) => false,
147 },
148 None => false,
149 },
150 }
151 }
152
153 pub(crate) fn set_deep_sleeping(&self, val: bool) {
156 self.state.deep_sleeping.store(val, Ordering::SeqCst);
157 }
158
159 pub(crate) fn is_deep_sleeping(&self) -> bool {
162 self.state.deep_sleeping.load(Ordering::SeqCst)
163 }
164
165 #[cfg(feature = "journal")]
168 pub(crate) fn set_checkpointing(&self, val: bool) {
169 self.state.check_pointing.store(val, Ordering::SeqCst);
170 }
171
172 #[cfg(feature = "journal")]
175 pub(crate) fn is_check_pointing(&self) -> bool {
176 self.state.check_pointing.load(Ordering::SeqCst)
177 }
178
179 #[allow(dead_code)]
181 pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
182 &self.layout
183 }
184
185 pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
187 self.layout = layout;
188 }
189}
190
191pub struct WasiThreadRunGuard {
198 pub thread: WasiThread,
199}
200
201impl WasiThreadRunGuard {
202 pub fn new(thread: WasiThread) -> Self {
203 Self { thread }
204 }
205}
206
207impl Drop for WasiThreadRunGuard {
208 fn drop(&mut self) {
209 self.thread
210 .set_status_finished(Err(
211 crate::RuntimeError::new("Thread manager disconnected").into()
212 ));
213 }
214}
215
216pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
218
219#[derive(Clone, Debug)]
220pub enum RewindResultType {
221 RewindRestart,
223 RewindWithoutResult,
225 RewindWithResult(Bytes),
227}
228
229#[derive(Clone, Debug)]
231pub(crate) struct RewindResult {
232 pub memory_stack: Option<Bytes>,
234 pub rewind_result: RewindResultType,
237}
238
239#[derive(Debug)]
240struct WasiThreadState {
241 is_main: bool,
242 pid: WasiProcessId,
243 id: WasiThreadId,
244 signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
245 stack: Mutex<ThreadStack>,
246 status: Arc<OwnedTaskStatus>,
247 #[cfg(feature = "journal")]
248 check_pointing: AtomicBool,
249 deep_sleeping: AtomicBool,
250
251 _task_count_guard: TaskCountGuard,
254}
255
/// Empty buffer used to mark an incoming stack as fully consumed.
static NO_MORE_BYTES: [u8; 0] = [];
257
258impl WasiThread {
259 pub fn new(
260 pid: WasiProcessId,
261 id: WasiThreadId,
262 is_main: bool,
263 status: Arc<OwnedTaskStatus>,
264 guard: TaskCountGuard,
265 layout: WasiMemoryLayout,
266 start: ThreadStartType,
267 ) -> Self {
268 Self {
269 state: Arc::new(WasiThreadState {
270 is_main,
271 pid,
272 id,
273 status,
274 signals: Mutex::new((Vec::new(), Vec::new())),
275 stack: Mutex::new(ThreadStack::default()),
276 #[cfg(feature = "journal")]
277 check_pointing: AtomicBool::new(false),
278 deep_sleeping: AtomicBool::new(false),
279 _task_count_guard: guard,
280 }),
281 layout,
282 start,
283 rewind: None,
284 }
285 }
286
287 pub fn pid(&self) -> WasiProcessId {
289 self.state.pid
290 }
291
292 pub fn tid(&self) -> WasiThreadId {
294 self.state.id
295 }
296
297 pub fn is_main(&self) -> bool {
299 self.state.is_main
300 }
301
302 pub fn join_handle(&self) -> TaskJoinHandle {
304 self.state.status.handle()
305 }
306
307 pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
309 &self.state.signals
310 }
311
312 pub fn set_status_running(&self) {
313 self.state.status.set_running();
314 }
315
316 pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
320 let default_exitcode: ExitCode = match sig {
321 Signal::Sigquit | Signal::Sigabrt => Errno::Success.into(),
322 Signal::Sigpipe => Errno::Pipe.into(),
323 _ => Errno::Intr.into(),
324 };
325 self.set_status_finished(Ok(default_exitcode));
327 self.try_join()
328 .map(|r| r.unwrap_or(default_exitcode))
329 .unwrap_or(default_exitcode)
330 }
331
332 pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
335 self.state.status.set_finished(res.map_err(Arc::new));
336 }
337
338 pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
340 self.state.status.await_termination().await
341 }
342
343 pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
345 self.state.status.status().into_finished()
346 }
347
348 pub fn signal(&self, signal: Signal) {
350 let tid = self.tid();
351 tracing::trace!(%tid, "signal-thread({:?})", signal);
352
353 let mut guard = self.state.signals.lock().unwrap();
354 if !guard.0.contains(&signal) {
355 guard.0.push(signal);
356 }
357 guard.1.drain(..).for_each(|w| w.wake());
358 }
359
360 pub fn has_signal(&self, signals: &[Signal]) -> bool {
362 let guard = self.state.signals.lock().unwrap();
363 for s in guard.0.iter() {
364 if signals.contains(s) {
365 return true;
366 }
367 }
368 false
369 }
370
371 pub async fn wait_for_signal(&self) {
373 struct SignalPoller<'a> {
375 thread: &'a WasiThread,
376 }
377 impl std::future::Future for SignalPoller<'_> {
378 type Output = ();
379 fn poll(
380 self: std::pin::Pin<&mut Self>,
381 cx: &mut std::task::Context<'_>,
382 ) -> std::task::Poll<Self::Output> {
383 if self.thread.has_signals_or_subscribe(cx.waker()) {
384 return std::task::Poll::Ready(());
385 }
386 std::task::Poll::Pending
387 }
388 }
389 SignalPoller { thread: self }.await
390 }
391
392 pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
394 let mut guard = self.state.signals.lock().unwrap();
395 let mut ret = Vec::new();
396 std::mem::swap(&mut ret, &mut guard.0);
397 match ret.is_empty() {
398 true => {
399 if !guard.1.iter().any(|w| w.will_wake(waker)) {
400 guard.1.push(waker.clone());
401 }
402 None
403 }
404 false => Some(ret),
405 }
406 }
407
408 pub fn signals_subscribe(&self, waker: &Waker) {
409 let mut guard = self.state.signals.lock().unwrap();
410 if !guard.1.iter().any(|w| w.will_wake(waker)) {
411 guard.1.push(waker.clone());
412 }
413 }
414
415 pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
417 let mut guard = self.state.signals.lock().unwrap();
418 let has_signals = !guard.0.is_empty();
419 if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
420 guard.1.push(waker.clone());
421 }
422 has_signals
423 }
424
425 pub fn pop_signals(&self) -> Vec<Signal> {
427 let mut guard = self.state.signals.lock().unwrap();
428 let mut ret = Vec::new();
429 std::mem::swap(&mut ret, &mut guard.0);
430 ret
431 }
432
433 pub fn add_snapshot(
435 &self,
436 mut memory_stack: &[u8],
437 mut memory_stack_corrected: &[u8],
438 hash: u128,
439 rewind_stack: &[u8],
440 store_data: &[u8],
441 ) {
442 let mut stack = self.state.stack.lock().unwrap();
444 let mut pstack = stack.deref_mut();
445 loop {
446 let memory_stack_before = pstack.memory_stack.len();
448 let memory_stack_after = memory_stack.len();
449 if memory_stack_before > memory_stack_after
450 || (!pstack
451 .memory_stack
452 .iter()
453 .zip(memory_stack.iter())
454 .any(|(a, b)| *a == *b)
455 && !pstack
456 .memory_stack_corrected
457 .iter()
458 .zip(memory_stack.iter())
459 .any(|(a, b)| *a == *b))
460 {
461 let mut new_stack = ThreadStack {
463 memory_stack: memory_stack.to_vec(),
464 memory_stack_corrected: memory_stack_corrected.to_vec(),
465 ..Default::default()
466 };
467 std::mem::swap(pstack, &mut new_stack);
468 memory_stack = &NO_MORE_BYTES[..];
469 memory_stack_corrected = &NO_MORE_BYTES[..];
470
471 let mut disown = Some(Box::new(new_stack));
473 if let Some(disown) = disown.as_ref() {
474 if !disown.snapshots.is_empty() {
475 tracing::trace!(
476 "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
477 self.pid(),
478 memory_stack_before,
479 memory_stack_after
480 );
481 }
482 }
483 let mut total_forgotten = 0usize;
484 while let Some(disowned) = disown {
485 for _hash in disowned.snapshots.keys() {
486 total_forgotten += 1;
487 }
488 disown = disowned.next;
489 }
490 if total_forgotten > 0 {
491 tracing::trace!(
492 "wasi[{}]::stack has been forgotten (cnt={})",
493 self.pid(),
494 total_forgotten
495 );
496 }
497 } else {
498 memory_stack = &memory_stack[pstack.memory_stack.len()..];
499 memory_stack_corrected =
500 &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
501 }
502
503 if memory_stack.is_empty() {
505 break;
506 }
507
508 if pstack.next.is_none() {
510 let new_stack = ThreadStack {
511 memory_stack: memory_stack.to_vec(),
512 memory_stack_corrected: memory_stack_corrected.to_vec(),
513 ..Default::default()
514 };
515 pstack.next.replace(Box::new(new_stack));
516 }
517 pstack = pstack.next.as_mut().unwrap();
518 }
519
520 pstack.snapshots.insert(
522 hash,
523 ThreadSnapshot {
524 call_stack: BytesMut::from(rewind_stack).freeze(),
525 store_data: BytesMut::from(store_data).freeze(),
526 },
527 );
528 }
529
530 pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
532 let mut memory_stack = BytesMut::new();
533
534 let stack = self.state.stack.lock().unwrap();
535 let mut pstack = stack.deref();
536 loop {
537 memory_stack.extend(pstack.memory_stack_corrected.iter());
538 if let Some(snapshot) = pstack.snapshots.get(&hash) {
539 return Some((
540 memory_stack,
541 snapshot.call_stack.clone(),
542 snapshot.store_data.clone(),
543 ));
544 }
545 if let Some(next) = pstack.next.as_ref() {
546 pstack = next.deref();
547 } else {
548 return None;
549 }
550 }
551 }
552
553 pub fn copy_stack_from(&self, other: &WasiThread) {
555 let mut stack = {
556 let stack_guard = other.state.stack.lock().unwrap();
557 stack_guard.clone()
558 };
559
560 let mut stack_guard = self.state.stack.lock().unwrap();
561 std::mem::swap(stack_guard.deref_mut(), &mut stack);
562 }
563}
564
565#[derive(Debug)]
566pub struct WasiThreadHandleProtected {
567 thread: WasiThread,
568 inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
569}
570
571#[derive(Debug, Clone)]
572pub struct WasiThreadHandle {
573 protected: Arc<WasiThreadHandleProtected>,
574}
575
576impl WasiThreadHandle {
577 pub(crate) fn new(
578 thread: WasiThread,
579 inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
580 ) -> WasiThreadHandle {
581 Self {
582 protected: Arc::new(WasiThreadHandleProtected {
583 thread,
584 inner: Arc::downgrade(inner),
585 }),
586 }
587 }
588
589 pub fn id(&self) -> WasiThreadId {
590 self.protected.thread.tid()
591 }
592
593 pub fn as_thread(&self) -> WasiThread {
594 self.protected.thread.clone()
595 }
596}
597
598impl Drop for WasiThreadHandleProtected {
599 fn drop(&mut self) {
600 let id = self.thread.tid();
601 if let Some(inner) = Weak::upgrade(&self.inner) {
602 let mut inner = inner.0.lock().unwrap();
603 if let Some(ctrl) = inner.threads.remove(&id) {
604 ctrl.set_status_finished(Ok(Errno::Success.into()));
605 }
606 inner.thread_count -= 1;
607 }
608 }
609}
610
611impl std::ops::Deref for WasiThreadHandle {
612 type Target = WasiThread;
613
614 fn deref(&self) -> &Self::Target {
615 &self.protected.thread
616 }
617}
618
619#[derive(thiserror::Error, Debug, Clone)]
620pub enum WasiThreadError {
621 #[error("Multithreading is not supported")]
622 Unsupported,
623 #[error("The method named is not an exported function")]
624 MethodNotFound,
625 #[error("Failed to create the requested memory - {0}")]
626 MemoryCreateFailed(MemoryError),
627 #[error("{0}")]
628 ExportError(ExportError),
629 #[error("Linker error: {0}")]
630 LinkError(Arc<LinkError>),
631 #[error("Failed to create the instance - {0}")]
632 InstanceCreateFailed(Box<InstantiationError>),
634 #[error("Initialization function failed - {0}")]
635 InitFailed(Arc<anyhow::Error>),
636 #[error("WASM context is invalid")]
638 InvalidWasmContext,
639}
640
641impl From<WasiThreadError> for Errno {
642 fn from(a: WasiThreadError) -> Errno {
643 match a {
644 WasiThreadError::Unsupported => Errno::Notsup,
645 WasiThreadError::MethodNotFound => Errno::Inval,
646 WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
647 WasiThreadError::ExportError(_) => Errno::Noexec,
648 WasiThreadError::LinkError(_) => Errno::Noexec,
649 WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
650 WasiThreadError::InitFailed(_) => Errno::Noexec,
651 WasiThreadError::InvalidWasmContext => Errno::Noexec,
652 }
653 }
654}