wasmer_wasix/os/task/
thread.rs

1use super::{
2    control_plane::TaskCountGuard,
3    task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
4};
5use crate::{
6    WasiRuntimeError,
7    os::task::process::{WasiProcessId, WasiProcessInner},
8    state::LinkError,
9    syscalls::HandleRewindType,
10};
11use bytes::{Bytes, BytesMut};
12use serde::{Deserialize, Serialize};
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::{
15    collections::HashMap,
16    ops::{Deref, DerefMut},
17    sync::{Arc, Condvar, Mutex, Weak},
18    task::Waker,
19};
20use wasmer::{ExportError, InstantiationError, MemoryError};
21use wasmer_wasix_types::{
22    types::Signal,
23    wasi::{Errno, ExitCode},
24    wasix::ThreadStartType,
25};
26
27/// Represents the ID of a WASI thread
28#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
29pub struct WasiThreadId(u32);
30
31impl WasiThreadId {
32    pub fn raw(&self) -> u32 {
33        self.0
34    }
35
36    pub fn inc(&mut self) -> WasiThreadId {
37        let ret = *self;
38        self.0 += 1;
39        ret
40    }
41}
42
43impl From<i32> for WasiThreadId {
44    fn from(id: i32) -> Self {
45        Self(id as u32)
46    }
47}
48
49impl From<WasiThreadId> for i32 {
50    fn from(val: WasiThreadId) -> Self {
51        val.0 as i32
52    }
53}
54
55impl From<u32> for WasiThreadId {
56    fn from(id: u32) -> Self {
57        Self(id)
58    }
59}
60
61impl From<WasiThreadId> for u32 {
62    fn from(t: WasiThreadId) -> u32 {
63        t.0
64    }
65}
66
67impl std::fmt::Display for WasiThreadId {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73impl std::fmt::Debug for WasiThreadId {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        write!(f, "{}", self.0)
76    }
77}
78
79/// Represents a linked list of stack snapshots
80#[derive(Debug, Clone)]
81struct ThreadSnapshot {
82    call_stack: Bytes,
83    store_data: Bytes,
84}
85
86/// Represents a linked list of stack snapshots
87#[derive(Debug, Clone, Default)]
88pub struct ThreadStack {
89    memory_stack: Vec<u8>,
90    memory_stack_corrected: Vec<u8>,
91    snapshots: HashMap<u128, ThreadSnapshot>,
92    next: Option<Box<ThreadStack>>,
93}
94
95/// Represents a running thread which allows a joiner to
96/// wait for the thread to exit
97#[derive(Clone, Debug)]
98pub struct WasiThread {
99    state: Arc<WasiThreadState>,
100    layout: WasiMemoryLayout,
101    start: ThreadStartType,
102
103    // This is used for stack rewinds
104    rewind: Option<RewindResult>,
105}
106
107impl WasiThread {
108    pub fn id(&self) -> WasiThreadId {
109        self.state.id
110    }
111
112    /// Sets that a rewind will take place
113    pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
114        self.rewind.replace(rewind);
115    }
116
117    /// Pops any rewinds that need to take place
118    pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
119        self.rewind.take()
120    }
121
122    /// Gets the thread start type for this thread
123    pub fn thread_start_type(&self) -> ThreadStartType {
124        self.start
125    }
126
127    /// Returns true if a rewind of a particular type has been queued
128    /// for processed by a rewind operation
129    pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
130        match type_ {
131            HandleRewindType::ResultDriven => match &self.rewind {
132                Some(rewind) => match rewind.rewind_result {
133                    RewindResultType::RewindRestart => true,
134                    RewindResultType::RewindWithoutResult => false,
135                    RewindResultType::RewindWithResult(_) => true,
136                },
137                None => false,
138            },
139            HandleRewindType::ResultLess => match &self.rewind {
140                Some(rewind) => match rewind.rewind_result {
141                    RewindResultType::RewindRestart => true,
142                    RewindResultType::RewindWithoutResult => true,
143                    RewindResultType::RewindWithResult(_) => false,
144                },
145                None => false,
146            },
147        }
148    }
149
150    /// Sets a flag that tells others if this thread is currently
151    /// deep sleeping
152    pub(crate) fn set_deep_sleeping(&self, val: bool) {
153        self.state.deep_sleeping.store(val, Ordering::SeqCst);
154    }
155
156    /// Reads a flag that determines if this thread is currently
157    /// deep sleeping
158    pub(crate) fn is_deep_sleeping(&self) -> bool {
159        self.state.deep_sleeping.load(Ordering::SeqCst)
160    }
161
162    /// Sets a flag that tells others that this thread is currently
163    /// check pointing itself
164    #[cfg(feature = "journal")]
165    pub(crate) fn set_checkpointing(&self, val: bool) {
166        self.state.check_pointing.store(val, Ordering::SeqCst);
167    }
168
169    /// Reads a flag that determines if this thread is currently
170    /// check pointing itself or not
171    #[cfg(feature = "journal")]
172    pub(crate) fn is_check_pointing(&self) -> bool {
173        self.state.check_pointing.load(Ordering::SeqCst)
174    }
175
176    /// Gets the memory layout for this thread
177    #[allow(dead_code)]
178    pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
179        &self.layout
180    }
181
182    /// Gets the memory layout for this thread
183    pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
184        self.layout = layout;
185    }
186}
187
188/// A guard that ensures a thread is marked as terminated when dropped.
189///
190/// Normally the thread result should be manually registered with
191/// [`WasiThread::set_status_running`] or [`WasiThread::set_status_finished`],
192/// but this guard can ensure that the thread is marked as terminated even if
193/// this is forgotten or a panic occurs.
194pub struct WasiThreadRunGuard {
195    pub thread: WasiThread,
196}
197
198impl WasiThreadRunGuard {
199    pub fn new(thread: WasiThread) -> Self {
200        Self { thread }
201    }
202}
203
204impl Drop for WasiThreadRunGuard {
205    fn drop(&mut self) {
206        self.thread
207            .set_status_finished(Err(
208                crate::RuntimeError::new("Thread manager disconnected").into()
209            ));
210    }
211}
212
213/// Represents the memory layout of the parts that the thread itself uses
214pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
215
216#[derive(Clone, Debug)]
217pub enum RewindResultType {
218    // The rewind must restart the operation it had already started
219    RewindRestart,
220    // The rewind has been triggered and should be handled but has not result
221    RewindWithoutResult,
222    // The rewind has been triggered and should be handled with the supplied result
223    RewindWithResult(Bytes),
224}
225
226// Contains the result of a rewind operation
227#[derive(Clone, Debug)]
228pub(crate) struct RewindResult {
229    /// Memory stack used to restore the memory stack (thing that holds local variables) back to where it was
230    pub memory_stack: Option<Bytes>,
231    /// Generic serialized object passed back to the rewind resumption code
232    /// (uses the bincode serializer)
233    pub rewind_result: RewindResultType,
234}
235
236#[derive(Debug)]
237struct WasiThreadState {
238    is_main: bool,
239    pid: WasiProcessId,
240    id: WasiThreadId,
241    signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
242    stack: Mutex<ThreadStack>,
243    status: Arc<OwnedTaskStatus>,
244    #[cfg(feature = "journal")]
245    check_pointing: AtomicBool,
246    deep_sleeping: AtomicBool,
247
248    // Registers the task termination with the ControlPlane on drop.
249    // Never accessed, since it's a drop guard.
250    _task_count_guard: TaskCountGuard,
251}
252
/// Empty byte source used to mark "no stack bytes left" while walking stacks.
static NO_MORE_BYTES: [u8; 0] = [];
254
255impl WasiThread {
256    pub fn new(
257        pid: WasiProcessId,
258        id: WasiThreadId,
259        is_main: bool,
260        status: Arc<OwnedTaskStatus>,
261        guard: TaskCountGuard,
262        layout: WasiMemoryLayout,
263        start: ThreadStartType,
264    ) -> Self {
265        Self {
266            state: Arc::new(WasiThreadState {
267                is_main,
268                pid,
269                id,
270                status,
271                signals: Mutex::new((Vec::new(), Vec::new())),
272                stack: Mutex::new(ThreadStack::default()),
273                #[cfg(feature = "journal")]
274                check_pointing: AtomicBool::new(false),
275                deep_sleeping: AtomicBool::new(false),
276                _task_count_guard: guard,
277            }),
278            layout,
279            start,
280            rewind: None,
281        }
282    }
283
284    /// Returns the process ID
285    pub fn pid(&self) -> WasiProcessId {
286        self.state.pid
287    }
288
289    /// Returns the thread ID
290    pub fn tid(&self) -> WasiThreadId {
291        self.state.id
292    }
293
294    /// Returns true if this thread is the main thread
295    pub fn is_main(&self) -> bool {
296        self.state.is_main
297    }
298
299    /// Get a join handle to watch the task status.
300    pub fn join_handle(&self) -> TaskJoinHandle {
301        self.state.status.handle()
302    }
303
304    // TODO: this should be private, access should go through utility methods.
305    pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
306        &self.state.signals
307    }
308
309    pub fn set_status_running(&self) {
310        self.state.status.set_running();
311    }
312
313    /// Gets or sets the exit code based of a signal that was received
314    /// Note: if the exit code was already set earlier this method will
315    /// just return that earlier set exit code
316    pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
317        let default_exitcode: ExitCode = match sig {
318            Signal::Sigquit => Errno::Success.into(),
319            // Match the POSIX shell convention for signal termination.
320            Signal::Sigabrt => ExitCode::from(128 + sig as i32),
321            Signal::Sigpipe => Errno::Pipe.into(),
322            _ => Errno::Intr.into(),
323        };
324        // This will only set the status code if its not already set
325        self.set_status_finished(Ok(default_exitcode));
326        self.try_join()
327            .map(|r| r.unwrap_or(default_exitcode))
328            .unwrap_or(default_exitcode)
329    }
330
331    /// Marks the thread as finished (which will cause anyone that
332    /// joined on it to wake up)
333    pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
334        self.state.status.set_finished(res.map_err(Arc::new));
335    }
336
337    /// Waits until the thread is finished or the timeout is reached
338    pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
339        self.state.status.await_termination().await
340    }
341
342    /// Attempts to join on the thread
343    pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
344        self.state.status.status().into_finished()
345    }
346
347    /// Adds a signal for this thread to process
348    pub fn signal(&self, signal: Signal) {
349        let tid = self.tid();
350        tracing::trace!(%tid, "signal-thread({:?})", signal);
351
352        let mut guard = self.state.signals.lock().unwrap();
353        if !guard.0.contains(&signal) {
354            guard.0.push(signal);
355        }
356        guard.1.drain(..).for_each(|w| w.wake());
357    }
358
359    /// Returns all the signals that are waiting to be processed
360    pub fn has_signal(&self, signals: &[Signal]) -> bool {
361        let guard = self.state.signals.lock().unwrap();
362        for s in guard.0.iter() {
363            if signals.contains(s) {
364                return true;
365            }
366        }
367        false
368    }
369
370    /// Waits for a signal to arrive
371    pub async fn wait_for_signal(&self) {
372        // This poller will process any signals when the main working function is idle
373        struct SignalPoller<'a> {
374            thread: &'a WasiThread,
375        }
376        impl std::future::Future for SignalPoller<'_> {
377            type Output = ();
378            fn poll(
379                self: std::pin::Pin<&mut Self>,
380                cx: &mut std::task::Context<'_>,
381            ) -> std::task::Poll<Self::Output> {
382                if self.thread.has_signals_or_subscribe(cx.waker()) {
383                    return std::task::Poll::Ready(());
384                }
385                std::task::Poll::Pending
386            }
387        }
388        SignalPoller { thread: self }.await
389    }
390
391    /// Returns all the signals that are waiting to be processed
392    pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
393        let mut guard = self.state.signals.lock().unwrap();
394        let mut ret = Vec::new();
395        std::mem::swap(&mut ret, &mut guard.0);
396        match ret.is_empty() {
397            true => {
398                if !guard.1.iter().any(|w| w.will_wake(waker)) {
399                    guard.1.push(waker.clone());
400                }
401                None
402            }
403            false => Some(ret),
404        }
405    }
406
407    pub fn signals_subscribe(&self, waker: &Waker) {
408        let mut guard = self.state.signals.lock().unwrap();
409        if !guard.1.iter().any(|w| w.will_wake(waker)) {
410            guard.1.push(waker.clone());
411        }
412    }
413
414    /// Returns all the signals that are waiting to be processed
415    pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
416        let mut guard = self.state.signals.lock().unwrap();
417        let has_signals = !guard.0.is_empty();
418        if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
419            guard.1.push(waker.clone());
420        }
421        has_signals
422    }
423
424    /// Returns all the signals that are waiting to be processed
425    pub fn pop_signals(&self) -> Vec<Signal> {
426        let mut guard = self.state.signals.lock().unwrap();
427        let mut ret = Vec::new();
428        std::mem::swap(&mut ret, &mut guard.0);
429        ret
430    }
431
432    /// Adds a stack snapshot and removes dead ones
433    pub fn add_snapshot(
434        &self,
435        mut memory_stack: &[u8],
436        mut memory_stack_corrected: &[u8],
437        hash: u128,
438        rewind_stack: &[u8],
439        store_data: &[u8],
440    ) {
441        // Lock the stack
442        let mut stack = self.state.stack.lock().unwrap();
443        let mut pstack = stack.deref_mut();
444        loop {
445            // First we validate if the stack is no longer valid
446            let memory_stack_before = pstack.memory_stack.len();
447            let memory_stack_after = memory_stack.len();
448            if memory_stack_before > memory_stack_after
449                || (!pstack
450                    .memory_stack
451                    .iter()
452                    .zip(memory_stack.iter())
453                    .any(|(a, b)| *a == *b)
454                    && !pstack
455                        .memory_stack_corrected
456                        .iter()
457                        .zip(memory_stack.iter())
458                        .any(|(a, b)| *a == *b))
459            {
460                // The stacks have changed so need to start again at this segment
461                let mut new_stack = ThreadStack {
462                    memory_stack: memory_stack.to_vec(),
463                    memory_stack_corrected: memory_stack_corrected.to_vec(),
464                    ..Default::default()
465                };
466                std::mem::swap(pstack, &mut new_stack);
467                memory_stack = &NO_MORE_BYTES[..];
468                memory_stack_corrected = &NO_MORE_BYTES[..];
469
470                // Output debug info for the dead stack
471                let mut disown = Some(Box::new(new_stack));
472                if let Some(disown) = disown.as_ref()
473                    && !disown.snapshots.is_empty()
474                {
475                    tracing::trace!(
476                        "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
477                        self.pid(),
478                        memory_stack_before,
479                        memory_stack_after
480                    );
481                }
482                let mut total_forgotten = 0usize;
483                while let Some(disowned) = disown {
484                    for _hash in disowned.snapshots.keys() {
485                        total_forgotten += 1;
486                    }
487                    disown = disowned.next;
488                }
489                if total_forgotten > 0 {
490                    tracing::trace!(
491                        "wasi[{}]::stack has been forgotten (cnt={})",
492                        self.pid(),
493                        total_forgotten
494                    );
495                }
496            } else {
497                memory_stack = &memory_stack[pstack.memory_stack.len()..];
498                memory_stack_corrected =
499                    &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
500            }
501
502            // If there is no more memory stack then we are done and can add the call stack
503            if memory_stack.is_empty() {
504                break;
505            }
506
507            // Otherwise we need to add a next stack pointer and continue the iterations
508            if pstack.next.is_none() {
509                let new_stack = ThreadStack {
510                    memory_stack: memory_stack.to_vec(),
511                    memory_stack_corrected: memory_stack_corrected.to_vec(),
512                    ..Default::default()
513                };
514                pstack.next.replace(Box::new(new_stack));
515            }
516            pstack = pstack.next.as_mut().unwrap();
517        }
518
519        // Add the call stack
520        pstack.snapshots.insert(
521            hash,
522            ThreadSnapshot {
523                call_stack: BytesMut::from(rewind_stack).freeze(),
524                store_data: BytesMut::from(store_data).freeze(),
525            },
526        );
527    }
528
529    /// Gets a snapshot that was previously addedf
530    pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
531        let mut memory_stack = BytesMut::new();
532
533        let stack = self.state.stack.lock().unwrap();
534        let mut pstack = stack.deref();
535        loop {
536            memory_stack.extend(pstack.memory_stack_corrected.iter());
537            if let Some(snapshot) = pstack.snapshots.get(&hash) {
538                return Some((
539                    memory_stack,
540                    snapshot.call_stack.clone(),
541                    snapshot.store_data.clone(),
542                ));
543            }
544            if let Some(next) = pstack.next.as_ref() {
545                pstack = next.deref();
546            } else {
547                return None;
548            }
549        }
550    }
551
552    // Copy the stacks from another thread
553    pub fn copy_stack_from(&self, other: &WasiThread) {
554        let mut stack = {
555            let stack_guard = other.state.stack.lock().unwrap();
556            stack_guard.clone()
557        };
558
559        let mut stack_guard = self.state.stack.lock().unwrap();
560        std::mem::swap(stack_guard.deref_mut(), &mut stack);
561    }
562}
563
564#[derive(Debug)]
565pub struct WasiThreadHandleProtected {
566    thread: WasiThread,
567    inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
568}
569
570#[derive(Debug, Clone)]
571pub struct WasiThreadHandle {
572    protected: Arc<WasiThreadHandleProtected>,
573}
574
575impl WasiThreadHandle {
576    pub(crate) fn new(
577        thread: WasiThread,
578        inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
579    ) -> WasiThreadHandle {
580        Self {
581            protected: Arc::new(WasiThreadHandleProtected {
582                thread,
583                inner: Arc::downgrade(inner),
584            }),
585        }
586    }
587
588    pub fn id(&self) -> WasiThreadId {
589        self.protected.thread.tid()
590    }
591
592    pub fn as_thread(&self) -> WasiThread {
593        self.protected.thread.clone()
594    }
595}
596
597impl Drop for WasiThreadHandleProtected {
598    fn drop(&mut self) {
599        let id = self.thread.tid();
600        if let Some(inner) = Weak::upgrade(&self.inner) {
601            let mut inner = inner.0.lock().unwrap();
602            if let Some(ctrl) = inner.threads.remove(&id) {
603                ctrl.set_status_finished(Ok(Errno::Success.into()));
604            }
605            inner.thread_count -= 1;
606        }
607    }
608}
609
610impl std::ops::Deref for WasiThreadHandle {
611    type Target = WasiThread;
612
613    fn deref(&self) -> &Self::Target {
614        &self.protected.thread
615    }
616}
617
618#[derive(thiserror::Error, Debug, Clone)]
619pub enum WasiThreadError {
620    #[error("Multithreading is not supported")]
621    Unsupported,
622    #[error("The method named is not an exported function")]
623    MethodNotFound,
624    #[error("Failed to create the requested memory - {0}")]
625    MemoryCreateFailed(MemoryError),
626    #[error("{0}")]
627    ExportError(ExportError),
628    #[error("Failed to create additional imports - {0}")]
629    AdditionalImportCreationFailed(Arc<anyhow::Error>),
630    #[error("Linker error: {0}")]
631    LinkError(Arc<LinkError>),
632    #[error("Failed to create the instance - {0}")]
633    // Note: Boxed so we can keep the error size down
634    InstanceCreateFailed(Box<InstantiationError>),
635    #[error("Initialization function failed - {0}")]
636    InitFailed(Arc<anyhow::Error>),
637    /// This will happen if WASM is running in a thread has not been created by the spawn_wasm call
638    #[error("WASM context is invalid")]
639    InvalidWasmContext,
640}
641
642impl From<WasiThreadError> for Errno {
643    fn from(a: WasiThreadError) -> Errno {
644        match a {
645            WasiThreadError::Unsupported => Errno::Notsup,
646            WasiThreadError::MethodNotFound => Errno::Inval,
647            WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
648            WasiThreadError::ExportError(_) => Errno::Noexec,
649            WasiThreadError::AdditionalImportCreationFailed(_) => Errno::Noexec,
650            WasiThreadError::LinkError(_) => Errno::Noexec,
651            WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
652            WasiThreadError::InitFailed(_) => Errno::Noexec,
653            WasiThreadError::InvalidWasmContext => Errno::Noexec,
654        }
655    }
656}