// wasmer_wasix/os/task/thread.rs

1use serde::{Deserialize, Serialize};
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::{
4    collections::HashMap,
5    ops::{Deref, DerefMut},
6    sync::{Arc, Condvar, Mutex, Weak},
7    task::Waker,
8};
9
10use bytes::{Bytes, BytesMut};
11use wasmer::{ExportError, InstantiationError, MemoryError};
12use wasmer_wasix_types::{
13    types::Signal,
14    wasi::{Errno, ExitCode},
15    wasix::ThreadStartType,
16};
17
18use crate::{
19    WasiRuntimeError,
20    os::task::process::{WasiProcessId, WasiProcessInner},
21    state::LinkError,
22    syscalls::HandleRewindType,
23};
24
25use super::{
26    control_plane::TaskCountGuard,
27    task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
28};
29
30/// Represents the ID of a WASI thread
31#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
32pub struct WasiThreadId(u32);
33
34impl WasiThreadId {
35    pub fn raw(&self) -> u32 {
36        self.0
37    }
38
39    pub fn inc(&mut self) -> WasiThreadId {
40        let ret = *self;
41        self.0 += 1;
42        ret
43    }
44}
45
46impl From<i32> for WasiThreadId {
47    fn from(id: i32) -> Self {
48        Self(id as u32)
49    }
50}
51
52impl From<WasiThreadId> for i32 {
53    fn from(val: WasiThreadId) -> Self {
54        val.0 as i32
55    }
56}
57
58impl From<u32> for WasiThreadId {
59    fn from(id: u32) -> Self {
60        Self(id)
61    }
62}
63
64impl From<WasiThreadId> for u32 {
65    fn from(t: WasiThreadId) -> u32 {
66        t.0
67    }
68}
69
70impl std::fmt::Display for WasiThreadId {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(f, "{}", self.0)
73    }
74}
75
76impl std::fmt::Debug for WasiThreadId {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        write!(f, "{}", self.0)
79    }
80}
81
/// A single snapshot stored in a [`ThreadStack`] segment, keyed there by its hash.
#[derive(Debug, Clone)]
struct ThreadSnapshot {
    // Rewind (call) stack that was passed to `WasiThread::add_snapshot`
    call_stack: Bytes,
    // Store data that was passed to `WasiThread::add_snapshot`
    store_data: Bytes,
}
88
/// Represents a linked list of stack snapshots
///
/// Each node is one segment of a thread's memory stack. `next` holds the
/// segment built from the bytes that follow this one, and `snapshots` holds
/// the snapshots recorded against this segment.
#[derive(Debug, Clone, Default)]
pub struct ThreadStack {
    // Bytes of this segment; compared against new stacks in `add_snapshot`
    // to decide whether the segment is still valid
    memory_stack: Vec<u8>,
    // "Corrected" bytes of this segment; `get_snapshot` concatenates these
    // (not `memory_stack`) when rebuilding a stack
    memory_stack_corrected: Vec<u8>,
    // Snapshots recorded against this segment, keyed by hash
    snapshots: HashMap<u128, ThreadSnapshot>,
    // Segment holding the bytes that follow this one, if any
    next: Option<Box<ThreadStack>>,
}
97
/// Represents a running thread which allows a joiner to
/// wait for the thread to exit
///
/// Clones share the same `WasiThreadState` through an `Arc` (IDs, signals,
/// stack snapshots, status), while `layout`, `start` and `rewind` are held
/// separately by each clone.
#[derive(Clone, Debug)]
pub struct WasiThread {
    // State shared by every clone of this thread
    state: Arc<WasiThreadState>,
    // Memory layout used by this thread
    layout: WasiMemoryLayout,
    // How this thread was started
    start: ThreadStartType,

    // This is used for stack rewinds: set by `set_rewind`, consumed by `take_rewind`
    rewind: Option<RewindResult>,
}
109
110impl WasiThread {
111    pub fn id(&self) -> WasiThreadId {
112        self.state.id
113    }
114
115    /// Sets that a rewind will take place
116    pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
117        self.rewind.replace(rewind);
118    }
119
120    /// Pops any rewinds that need to take place
121    pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
122        self.rewind.take()
123    }
124
125    /// Gets the thread start type for this thread
126    pub fn thread_start_type(&self) -> ThreadStartType {
127        self.start
128    }
129
130    /// Returns true if a rewind of a particular type has been queued
131    /// for processed by a rewind operation
132    pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
133        match type_ {
134            HandleRewindType::ResultDriven => match &self.rewind {
135                Some(rewind) => match rewind.rewind_result {
136                    RewindResultType::RewindRestart => true,
137                    RewindResultType::RewindWithoutResult => false,
138                    RewindResultType::RewindWithResult(_) => true,
139                },
140                None => false,
141            },
142            HandleRewindType::ResultLess => match &self.rewind {
143                Some(rewind) => match rewind.rewind_result {
144                    RewindResultType::RewindRestart => true,
145                    RewindResultType::RewindWithoutResult => true,
146                    RewindResultType::RewindWithResult(_) => false,
147                },
148                None => false,
149            },
150        }
151    }
152
153    /// Sets a flag that tells others if this thread is currently
154    /// deep sleeping
155    pub(crate) fn set_deep_sleeping(&self, val: bool) {
156        self.state.deep_sleeping.store(val, Ordering::SeqCst);
157    }
158
159    /// Reads a flag that determines if this thread is currently
160    /// deep sleeping
161    pub(crate) fn is_deep_sleeping(&self) -> bool {
162        self.state.deep_sleeping.load(Ordering::SeqCst)
163    }
164
165    /// Sets a flag that tells others that this thread is currently
166    /// check pointing itself
167    #[cfg(feature = "journal")]
168    pub(crate) fn set_checkpointing(&self, val: bool) {
169        self.state.check_pointing.store(val, Ordering::SeqCst);
170    }
171
172    /// Reads a flag that determines if this thread is currently
173    /// check pointing itself or not
174    #[cfg(feature = "journal")]
175    pub(crate) fn is_check_pointing(&self) -> bool {
176        self.state.check_pointing.load(Ordering::SeqCst)
177    }
178
179    /// Gets the memory layout for this thread
180    #[allow(dead_code)]
181    pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
182        &self.layout
183    }
184
185    /// Gets the memory layout for this thread
186    pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
187        self.layout = layout;
188    }
189}
190
/// A guard that ensures a thread is marked as terminated when dropped.
///
/// Normally the thread status should be registered explicitly: with
/// [`WasiThread::set_status_running`] when the thread starts and with
/// [`WasiThread::set_status_finished`] when it exits. This guard makes sure
/// the thread is still marked as terminated if that final step is forgotten
/// or a panic occurs. On drop it reports a "Thread manager disconnected"
/// error through [`WasiThread::set_status_finished`].
pub struct WasiThreadRunGuard {
    /// The thread whose termination is guaranteed by this guard
    pub thread: WasiThread,
}
200
201impl WasiThreadRunGuard {
202    pub fn new(thread: WasiThread) -> Self {
203        Self { thread }
204    }
205}
206
207impl Drop for WasiThreadRunGuard {
208    fn drop(&mut self) {
209        self.thread
210            .set_status_finished(Err(
211                crate::RuntimeError::new("Thread manager disconnected").into()
212            ));
213    }
214}
215
216/// Represents the memory layout of the parts that the thread itself uses
217pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
218
/// How a queued rewind should be resumed once the stack has been restored.
#[derive(Clone, Debug)]
pub enum RewindResultType {
    /// The rewind must restart the operation it had already started
    RewindRestart,
    /// The rewind has been triggered and should be handled but has no result
    RewindWithoutResult,
    /// The rewind has been triggered and should be handled with the supplied result
    RewindWithResult(Bytes),
}
228
/// Contains the result of a rewind operation
#[derive(Clone, Debug)]
pub(crate) struct RewindResult {
    /// Memory stack used to restore the memory stack (thing that holds local variables) back to where it was
    pub memory_stack: Option<Bytes>,
    /// How the rewind should be resumed. For `RewindWithResult` the payload is
    /// a generic serialized object passed back to the rewind resumption code
    /// (previously documented as using the bincode serializer — confirm)
    pub rewind_result: RewindResultType,
}
238
/// State shared by every clone of a `WasiThread` (held behind an `Arc`).
#[derive(Debug)]
struct WasiThreadState {
    // True if this is the main thread of its process
    is_main: bool,
    // ID of the process that owns this thread
    pid: WasiProcessId,
    // ID of this thread
    id: WasiThreadId,
    // Pending signals (deduplicated) and the wakers to wake when a signal arrives
    signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
    // Chain of memory-stack segments and the snapshots recorded against them
    stack: Mutex<ThreadStack>,
    // Task status shared with joiners (see `TaskJoinHandle`)
    status: Arc<OwnedTaskStatus>,
    // Set while the thread is check pointing itself
    #[cfg(feature = "journal")]
    check_pointing: AtomicBool,
    // Set while the thread is deep sleeping
    deep_sleeping: AtomicBool,

    // Registers the task termination with the ControlPlane on drop.
    // Never accessed, since it's a drop guard.
    // Struct fields drop in declaration order, so keeping this field last
    // releases the guard only after all of the state above has been dropped.
    _task_count_guard: TaskCountGuard,
}
255
/// Empty byte buffer that `add_snapshot` uses to mark "no memory stack left to process".
static NO_MORE_BYTES: [u8; 0] = [];
257
258impl WasiThread {
259    pub fn new(
260        pid: WasiProcessId,
261        id: WasiThreadId,
262        is_main: bool,
263        status: Arc<OwnedTaskStatus>,
264        guard: TaskCountGuard,
265        layout: WasiMemoryLayout,
266        start: ThreadStartType,
267    ) -> Self {
268        Self {
269            state: Arc::new(WasiThreadState {
270                is_main,
271                pid,
272                id,
273                status,
274                signals: Mutex::new((Vec::new(), Vec::new())),
275                stack: Mutex::new(ThreadStack::default()),
276                #[cfg(feature = "journal")]
277                check_pointing: AtomicBool::new(false),
278                deep_sleeping: AtomicBool::new(false),
279                _task_count_guard: guard,
280            }),
281            layout,
282            start,
283            rewind: None,
284        }
285    }
286
287    /// Returns the process ID
288    pub fn pid(&self) -> WasiProcessId {
289        self.state.pid
290    }
291
292    /// Returns the thread ID
293    pub fn tid(&self) -> WasiThreadId {
294        self.state.id
295    }
296
297    /// Returns true if this thread is the main thread
298    pub fn is_main(&self) -> bool {
299        self.state.is_main
300    }
301
302    /// Get a join handle to watch the task status.
303    pub fn join_handle(&self) -> TaskJoinHandle {
304        self.state.status.handle()
305    }
306
307    // TODO: this should be private, access should go through utility methods.
308    pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
309        &self.state.signals
310    }
311
312    pub fn set_status_running(&self) {
313        self.state.status.set_running();
314    }
315
316    /// Gets or sets the exit code based of a signal that was received
317    /// Note: if the exit code was already set earlier this method will
318    /// just return that earlier set exit code
319    pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
320        let default_exitcode: ExitCode = match sig {
321            Signal::Sigquit | Signal::Sigabrt => Errno::Success.into(),
322            Signal::Sigpipe => Errno::Pipe.into(),
323            _ => Errno::Intr.into(),
324        };
325        // This will only set the status code if its not already set
326        self.set_status_finished(Ok(default_exitcode));
327        self.try_join()
328            .map(|r| r.unwrap_or(default_exitcode))
329            .unwrap_or(default_exitcode)
330    }
331
332    /// Marks the thread as finished (which will cause anyone that
333    /// joined on it to wake up)
334    pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
335        self.state.status.set_finished(res.map_err(Arc::new));
336    }
337
338    /// Waits until the thread is finished or the timeout is reached
339    pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
340        self.state.status.await_termination().await
341    }
342
343    /// Attempts to join on the thread
344    pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
345        self.state.status.status().into_finished()
346    }
347
348    /// Adds a signal for this thread to process
349    pub fn signal(&self, signal: Signal) {
350        let tid = self.tid();
351        tracing::trace!(%tid, "signal-thread({:?})", signal);
352
353        let mut guard = self.state.signals.lock().unwrap();
354        if !guard.0.contains(&signal) {
355            guard.0.push(signal);
356        }
357        guard.1.drain(..).for_each(|w| w.wake());
358    }
359
360    /// Returns all the signals that are waiting to be processed
361    pub fn has_signal(&self, signals: &[Signal]) -> bool {
362        let guard = self.state.signals.lock().unwrap();
363        for s in guard.0.iter() {
364            if signals.contains(s) {
365                return true;
366            }
367        }
368        false
369    }
370
371    /// Waits for a signal to arrive
372    pub async fn wait_for_signal(&self) {
373        // This poller will process any signals when the main working function is idle
374        struct SignalPoller<'a> {
375            thread: &'a WasiThread,
376        }
377        impl std::future::Future for SignalPoller<'_> {
378            type Output = ();
379            fn poll(
380                self: std::pin::Pin<&mut Self>,
381                cx: &mut std::task::Context<'_>,
382            ) -> std::task::Poll<Self::Output> {
383                if self.thread.has_signals_or_subscribe(cx.waker()) {
384                    return std::task::Poll::Ready(());
385                }
386                std::task::Poll::Pending
387            }
388        }
389        SignalPoller { thread: self }.await
390    }
391
392    /// Returns all the signals that are waiting to be processed
393    pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
394        let mut guard = self.state.signals.lock().unwrap();
395        let mut ret = Vec::new();
396        std::mem::swap(&mut ret, &mut guard.0);
397        match ret.is_empty() {
398            true => {
399                if !guard.1.iter().any(|w| w.will_wake(waker)) {
400                    guard.1.push(waker.clone());
401                }
402                None
403            }
404            false => Some(ret),
405        }
406    }
407
408    pub fn signals_subscribe(&self, waker: &Waker) {
409        let mut guard = self.state.signals.lock().unwrap();
410        if !guard.1.iter().any(|w| w.will_wake(waker)) {
411            guard.1.push(waker.clone());
412        }
413    }
414
415    /// Returns all the signals that are waiting to be processed
416    pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
417        let mut guard = self.state.signals.lock().unwrap();
418        let has_signals = !guard.0.is_empty();
419        if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
420            guard.1.push(waker.clone());
421        }
422        has_signals
423    }
424
425    /// Returns all the signals that are waiting to be processed
426    pub fn pop_signals(&self) -> Vec<Signal> {
427        let mut guard = self.state.signals.lock().unwrap();
428        let mut ret = Vec::new();
429        std::mem::swap(&mut ret, &mut guard.0);
430        ret
431    }
432
433    /// Adds a stack snapshot and removes dead ones
434    pub fn add_snapshot(
435        &self,
436        mut memory_stack: &[u8],
437        mut memory_stack_corrected: &[u8],
438        hash: u128,
439        rewind_stack: &[u8],
440        store_data: &[u8],
441    ) {
442        // Lock the stack
443        let mut stack = self.state.stack.lock().unwrap();
444        let mut pstack = stack.deref_mut();
445        loop {
446            // First we validate if the stack is no longer valid
447            let memory_stack_before = pstack.memory_stack.len();
448            let memory_stack_after = memory_stack.len();
449            if memory_stack_before > memory_stack_after
450                || (!pstack
451                    .memory_stack
452                    .iter()
453                    .zip(memory_stack.iter())
454                    .any(|(a, b)| *a == *b)
455                    && !pstack
456                        .memory_stack_corrected
457                        .iter()
458                        .zip(memory_stack.iter())
459                        .any(|(a, b)| *a == *b))
460            {
461                // The stacks have changed so need to start again at this segment
462                let mut new_stack = ThreadStack {
463                    memory_stack: memory_stack.to_vec(),
464                    memory_stack_corrected: memory_stack_corrected.to_vec(),
465                    ..Default::default()
466                };
467                std::mem::swap(pstack, &mut new_stack);
468                memory_stack = &NO_MORE_BYTES[..];
469                memory_stack_corrected = &NO_MORE_BYTES[..];
470
471                // Output debug info for the dead stack
472                let mut disown = Some(Box::new(new_stack));
473                if let Some(disown) = disown.as_ref() {
474                    if !disown.snapshots.is_empty() {
475                        tracing::trace!(
476                            "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
477                            self.pid(),
478                            memory_stack_before,
479                            memory_stack_after
480                        );
481                    }
482                }
483                let mut total_forgotten = 0usize;
484                while let Some(disowned) = disown {
485                    for _hash in disowned.snapshots.keys() {
486                        total_forgotten += 1;
487                    }
488                    disown = disowned.next;
489                }
490                if total_forgotten > 0 {
491                    tracing::trace!(
492                        "wasi[{}]::stack has been forgotten (cnt={})",
493                        self.pid(),
494                        total_forgotten
495                    );
496                }
497            } else {
498                memory_stack = &memory_stack[pstack.memory_stack.len()..];
499                memory_stack_corrected =
500                    &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
501            }
502
503            // If there is no more memory stack then we are done and can add the call stack
504            if memory_stack.is_empty() {
505                break;
506            }
507
508            // Otherwise we need to add a next stack pointer and continue the iterations
509            if pstack.next.is_none() {
510                let new_stack = ThreadStack {
511                    memory_stack: memory_stack.to_vec(),
512                    memory_stack_corrected: memory_stack_corrected.to_vec(),
513                    ..Default::default()
514                };
515                pstack.next.replace(Box::new(new_stack));
516            }
517            pstack = pstack.next.as_mut().unwrap();
518        }
519
520        // Add the call stack
521        pstack.snapshots.insert(
522            hash,
523            ThreadSnapshot {
524                call_stack: BytesMut::from(rewind_stack).freeze(),
525                store_data: BytesMut::from(store_data).freeze(),
526            },
527        );
528    }
529
530    /// Gets a snapshot that was previously addedf
531    pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
532        let mut memory_stack = BytesMut::new();
533
534        let stack = self.state.stack.lock().unwrap();
535        let mut pstack = stack.deref();
536        loop {
537            memory_stack.extend(pstack.memory_stack_corrected.iter());
538            if let Some(snapshot) = pstack.snapshots.get(&hash) {
539                return Some((
540                    memory_stack,
541                    snapshot.call_stack.clone(),
542                    snapshot.store_data.clone(),
543                ));
544            }
545            if let Some(next) = pstack.next.as_ref() {
546                pstack = next.deref();
547            } else {
548                return None;
549            }
550        }
551    }
552
553    // Copy the stacks from another thread
554    pub fn copy_stack_from(&self, other: &WasiThread) {
555        let mut stack = {
556            let stack_guard = other.state.stack.lock().unwrap();
557            stack_guard.clone()
558        };
559
560        let mut stack_guard = self.state.stack.lock().unwrap();
561        std::mem::swap(stack_guard.deref_mut(), &mut stack);
562    }
563}
564
/// Payload shared by all clones of a [`WasiThreadHandle`].
///
/// It holds only a weak reference to the process state, so it does not keep
/// that state alive. When dropped, it removes the thread from the process
/// (if the process still exists) and decrements the process' thread count.
#[derive(Debug)]
pub struct WasiThreadHandleProtected {
    // The thread this handle refers to
    thread: WasiThread,
    // Weak reference to the owning process' inner state and its condition variable
    inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
}
570
/// Reference-counted handle to a thread.
///
/// Cloning is an `Arc` clone. When the last clone is dropped, the thread is
/// unregistered from its process (if the process still exists).
#[derive(Debug, Clone)]
pub struct WasiThreadHandle {
    protected: Arc<WasiThreadHandleProtected>,
}
575
576impl WasiThreadHandle {
577    pub(crate) fn new(
578        thread: WasiThread,
579        inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
580    ) -> WasiThreadHandle {
581        Self {
582            protected: Arc::new(WasiThreadHandleProtected {
583                thread,
584                inner: Arc::downgrade(inner),
585            }),
586        }
587    }
588
589    pub fn id(&self) -> WasiThreadId {
590        self.protected.thread.tid()
591    }
592
593    pub fn as_thread(&self) -> WasiThread {
594        self.protected.thread.clone()
595    }
596}
597
598impl Drop for WasiThreadHandleProtected {
599    fn drop(&mut self) {
600        let id = self.thread.tid();
601        if let Some(inner) = Weak::upgrade(&self.inner) {
602            let mut inner = inner.0.lock().unwrap();
603            if let Some(ctrl) = inner.threads.remove(&id) {
604                ctrl.set_status_finished(Ok(Errno::Success.into()));
605            }
606            inner.thread_count -= 1;
607        }
608    }
609}
610
611impl std::ops::Deref for WasiThreadHandle {
612    type Target = WasiThread;
613
614    fn deref(&self) -> &Self::Target {
615        &self.protected.thread
616    }
617}
618
/// Errors that can occur while setting up or running a WASI thread.
#[derive(thiserror::Error, Debug, Clone)]
pub enum WasiThreadError {
    /// Multithreading is not supported
    #[error("Multithreading is not supported")]
    Unsupported,
    /// The requested method is not an exported function
    #[error("The method named is not an exported function")]
    MethodNotFound,
    /// Creating the requested memory failed
    #[error("Failed to create the requested memory - {0}")]
    MemoryCreateFailed(MemoryError),
    /// An export lookup failed
    #[error("{0}")]
    ExportError(ExportError),
    /// Linking failed
    #[error("Linker error: {0}")]
    LinkError(Arc<LinkError>),
    /// Creating the instance failed
    #[error("Failed to create the instance - {0}")]
    // Note: Boxed so we can keep the error size down
    InstanceCreateFailed(Box<InstantiationError>),
    /// The initialization function failed
    #[error("Initialization function failed - {0}")]
    InitFailed(Arc<anyhow::Error>),
    /// This will happen if WASM is running in a thread that has not been created by the spawn_wasm call
    #[error("WASM context is invalid")]
    InvalidWasmContext,
}
640
641impl From<WasiThreadError> for Errno {
642    fn from(a: WasiThreadError) -> Errno {
643        match a {
644            WasiThreadError::Unsupported => Errno::Notsup,
645            WasiThreadError::MethodNotFound => Errno::Inval,
646            WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
647            WasiThreadError::ExportError(_) => Errno::Noexec,
648            WasiThreadError::LinkError(_) => Errno::Noexec,
649            WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
650            WasiThreadError::InitFailed(_) => Errno::Noexec,
651            WasiThreadError::InvalidWasmContext => Errno::Noexec,
652        }
653    }
654}