wasmer_wasix/os/task/
thread.rs

1use super::{
2    control_plane::TaskCountGuard,
3    task_join_handle::{OwnedTaskStatus, TaskJoinHandle},
4};
5use crate::{
6    WasiRuntimeError,
7    os::task::process::{WasiProcessId, WasiProcessInner},
8    state::LinkError,
9    syscalls::HandleRewindType,
10};
11use bytes::{Bytes, BytesMut};
12use serde::{Deserialize, Serialize};
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::{
15    collections::HashMap,
16    ops::{Deref, DerefMut},
17    sync::{Arc, Condvar, Mutex, Weak},
18    task::Waker,
19};
20use wasmer::{ExportError, InstantiationError, MemoryError};
21use wasmer_wasix_types::{
22    types::Signal,
23    wasi::{Errno, ExitCode},
24    wasix::ThreadStartType,
25};
26
27/// Represents the ID of a WASI thread
28#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
29pub struct WasiThreadId(u32);
30
31impl WasiThreadId {
32    pub fn raw(&self) -> u32 {
33        self.0
34    }
35
36    pub fn inc(&mut self) -> WasiThreadId {
37        let ret = *self;
38        self.0 += 1;
39        ret
40    }
41}
42
43impl From<i32> for WasiThreadId {
44    fn from(id: i32) -> Self {
45        Self(id as u32)
46    }
47}
48
49impl From<WasiThreadId> for i32 {
50    fn from(val: WasiThreadId) -> Self {
51        val.0 as i32
52    }
53}
54
55impl From<u32> for WasiThreadId {
56    fn from(id: u32) -> Self {
57        Self(id)
58    }
59}
60
61impl From<WasiThreadId> for u32 {
62    fn from(t: WasiThreadId) -> u32 {
63        t.0
64    }
65}
66
67impl std::fmt::Display for WasiThreadId {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73impl std::fmt::Debug for WasiThreadId {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        write!(f, "{}", self.0)
76    }
77}
78
79/// Represents a linked list of stack snapshots
80#[derive(Debug, Clone)]
81struct ThreadSnapshot {
82    call_stack: Bytes,
83    store_data: Bytes,
84}
85
86/// Represents a linked list of stack snapshots
87#[derive(Debug, Clone, Default)]
88pub struct ThreadStack {
89    memory_stack: Vec<u8>,
90    memory_stack_corrected: Vec<u8>,
91    snapshots: HashMap<u128, ThreadSnapshot>,
92    next: Option<Box<ThreadStack>>,
93}
94
95/// Represents a running thread which allows a joiner to
96/// wait for the thread to exit
97#[derive(Clone, Debug)]
98pub struct WasiThread {
99    state: Arc<WasiThreadState>,
100    layout: WasiMemoryLayout,
101    start: ThreadStartType,
102
103    // This is used for stack rewinds
104    rewind: Option<RewindResult>,
105}
106
107impl WasiThread {
108    pub fn id(&self) -> WasiThreadId {
109        self.state.id
110    }
111
112    /// Sets that a rewind will take place
113    pub(crate) fn set_rewind(&mut self, rewind: RewindResult) {
114        self.rewind.replace(rewind);
115    }
116
117    /// Pops any rewinds that need to take place
118    pub(crate) fn take_rewind(&mut self) -> Option<RewindResult> {
119        self.rewind.take()
120    }
121
122    /// Gets the thread start type for this thread
123    pub fn thread_start_type(&self) -> ThreadStartType {
124        self.start
125    }
126
127    /// Returns true if a rewind of a particular type has been queued
128    /// for processed by a rewind operation
129    pub(crate) fn has_rewind_of_type(&self, type_: HandleRewindType) -> bool {
130        match type_ {
131            HandleRewindType::ResultDriven => match &self.rewind {
132                Some(rewind) => match rewind.rewind_result {
133                    RewindResultType::RewindRestart => true,
134                    RewindResultType::RewindWithoutResult => false,
135                    RewindResultType::RewindWithResult(_) => true,
136                },
137                None => false,
138            },
139            HandleRewindType::ResultLess => match &self.rewind {
140                Some(rewind) => match rewind.rewind_result {
141                    RewindResultType::RewindRestart => true,
142                    RewindResultType::RewindWithoutResult => true,
143                    RewindResultType::RewindWithResult(_) => false,
144                },
145                None => false,
146            },
147        }
148    }
149
150    /// Sets a flag that tells others if this thread is currently
151    /// deep sleeping
152    pub(crate) fn set_deep_sleeping(&self, val: bool) {
153        self.state.deep_sleeping.store(val, Ordering::SeqCst);
154    }
155
156    /// Reads a flag that determines if this thread is currently
157    /// deep sleeping
158    pub(crate) fn is_deep_sleeping(&self) -> bool {
159        self.state.deep_sleeping.load(Ordering::SeqCst)
160    }
161
162    /// Sets a flag that tells others that this thread is currently
163    /// check pointing itself
164    #[cfg(feature = "journal")]
165    pub(crate) fn set_checkpointing(&self, val: bool) {
166        self.state.check_pointing.store(val, Ordering::SeqCst);
167    }
168
169    /// Reads a flag that determines if this thread is currently
170    /// check pointing itself or not
171    #[cfg(feature = "journal")]
172    pub(crate) fn is_check_pointing(&self) -> bool {
173        self.state.check_pointing.load(Ordering::SeqCst)
174    }
175
176    /// Gets the memory layout for this thread
177    #[allow(dead_code)]
178    pub(crate) fn memory_layout(&self) -> &WasiMemoryLayout {
179        &self.layout
180    }
181
182    /// Gets the memory layout for this thread
183    pub(crate) fn set_memory_layout(&mut self, layout: WasiMemoryLayout) {
184        self.layout = layout;
185    }
186}
187
188/// A guard that ensures a thread is marked as terminated when dropped.
189///
190/// Normally the thread result should be manually registered with
191/// [`WasiThread::set_status_running`] or [`WasiThread::set_status_finished`],
192/// but this guard can ensure that the thread is marked as terminated even if
193/// this is forgotten or a panic occurs.
194pub struct WasiThreadRunGuard {
195    pub thread: WasiThread,
196}
197
198impl WasiThreadRunGuard {
199    pub fn new(thread: WasiThread) -> Self {
200        Self { thread }
201    }
202}
203
204impl Drop for WasiThreadRunGuard {
205    fn drop(&mut self) {
206        self.thread
207            .set_status_finished(Err(
208                crate::RuntimeError::new("Thread manager disconnected").into()
209            ));
210    }
211}
212
213/// Represents the memory layout of the parts that the thread itself uses
214pub use wasmer_wasix_types::wasix::WasiMemoryLayout;
215
216#[derive(Clone, Debug)]
217pub enum RewindResultType {
218    // The rewind must restart the operation it had already started
219    RewindRestart,
220    // The rewind has been triggered and should be handled but has not result
221    RewindWithoutResult,
222    // The rewind has been triggered and should be handled with the supplied result
223    RewindWithResult(Bytes),
224}
225
226// Contains the result of a rewind operation
227#[derive(Clone, Debug)]
228pub(crate) struct RewindResult {
229    /// Memory stack used to restore the memory stack (thing that holds local variables) back to where it was
230    pub memory_stack: Option<Bytes>,
231    /// Generic serialized object passed back to the rewind resumption code
232    /// (uses the bincode serializer)
233    pub rewind_result: RewindResultType,
234}
235
236#[derive(Debug)]
237struct WasiThreadState {
238    is_main: bool,
239    pid: WasiProcessId,
240    id: WasiThreadId,
241    signals: Mutex<(Vec<Signal>, Vec<Waker>)>,
242    stack: Mutex<ThreadStack>,
243    status: Arc<OwnedTaskStatus>,
244    #[cfg(feature = "journal")]
245    check_pointing: AtomicBool,
246    deep_sleeping: AtomicBool,
247
248    // Registers the task termination with the ControlPlane on drop.
249    // Never accessed, since it's a drop guard.
250    _task_count_guard: TaskCountGuard,
251}
252
/// Empty byte buffer, used to mark a stack slice as fully consumed.
static NO_MORE_BYTES: [u8; 0] = [];
254
255impl WasiThread {
256    pub fn new(
257        pid: WasiProcessId,
258        id: WasiThreadId,
259        is_main: bool,
260        status: Arc<OwnedTaskStatus>,
261        guard: TaskCountGuard,
262        layout: WasiMemoryLayout,
263        start: ThreadStartType,
264    ) -> Self {
265        Self {
266            state: Arc::new(WasiThreadState {
267                is_main,
268                pid,
269                id,
270                status,
271                signals: Mutex::new((Vec::new(), Vec::new())),
272                stack: Mutex::new(ThreadStack::default()),
273                #[cfg(feature = "journal")]
274                check_pointing: AtomicBool::new(false),
275                deep_sleeping: AtomicBool::new(false),
276                _task_count_guard: guard,
277            }),
278            layout,
279            start,
280            rewind: None,
281        }
282    }
283
284    /// Returns the process ID
285    pub fn pid(&self) -> WasiProcessId {
286        self.state.pid
287    }
288
289    /// Returns the thread ID
290    pub fn tid(&self) -> WasiThreadId {
291        self.state.id
292    }
293
294    /// Returns true if this thread is the main thread
295    pub fn is_main(&self) -> bool {
296        self.state.is_main
297    }
298
299    /// Get a join handle to watch the task status.
300    pub fn join_handle(&self) -> TaskJoinHandle {
301        self.state.status.handle()
302    }
303
304    // TODO: this should be private, access should go through utility methods.
305    pub fn signals(&self) -> &Mutex<(Vec<Signal>, Vec<Waker>)> {
306        &self.state.signals
307    }
308
309    pub fn set_status_running(&self) {
310        self.state.status.set_running();
311    }
312
313    /// Gets or sets the exit code based of a signal that was received
314    /// Note: if the exit code was already set earlier this method will
315    /// just return that earlier set exit code
316    pub fn set_or_get_exit_code_for_signal(&self, sig: Signal) -> ExitCode {
317        let default_exitcode: ExitCode = match sig {
318            Signal::Sigquit | Signal::Sigabrt => Errno::Success.into(),
319            Signal::Sigpipe => Errno::Pipe.into(),
320            _ => Errno::Intr.into(),
321        };
322        // This will only set the status code if its not already set
323        self.set_status_finished(Ok(default_exitcode));
324        self.try_join()
325            .map(|r| r.unwrap_or(default_exitcode))
326            .unwrap_or(default_exitcode)
327    }
328
329    /// Marks the thread as finished (which will cause anyone that
330    /// joined on it to wake up)
331    pub fn set_status_finished(&self, res: Result<ExitCode, WasiRuntimeError>) {
332        self.state.status.set_finished(res.map_err(Arc::new));
333    }
334
335    /// Waits until the thread is finished or the timeout is reached
336    pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
337        self.state.status.await_termination().await
338    }
339
340    /// Attempts to join on the thread
341    pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
342        self.state.status.status().into_finished()
343    }
344
345    /// Adds a signal for this thread to process
346    pub fn signal(&self, signal: Signal) {
347        let tid = self.tid();
348        tracing::trace!(%tid, "signal-thread({:?})", signal);
349
350        let mut guard = self.state.signals.lock().unwrap();
351        if !guard.0.contains(&signal) {
352            guard.0.push(signal);
353        }
354        guard.1.drain(..).for_each(|w| w.wake());
355    }
356
357    /// Returns all the signals that are waiting to be processed
358    pub fn has_signal(&self, signals: &[Signal]) -> bool {
359        let guard = self.state.signals.lock().unwrap();
360        for s in guard.0.iter() {
361            if signals.contains(s) {
362                return true;
363            }
364        }
365        false
366    }
367
368    /// Waits for a signal to arrive
369    pub async fn wait_for_signal(&self) {
370        // This poller will process any signals when the main working function is idle
371        struct SignalPoller<'a> {
372            thread: &'a WasiThread,
373        }
374        impl std::future::Future for SignalPoller<'_> {
375            type Output = ();
376            fn poll(
377                self: std::pin::Pin<&mut Self>,
378                cx: &mut std::task::Context<'_>,
379            ) -> std::task::Poll<Self::Output> {
380                if self.thread.has_signals_or_subscribe(cx.waker()) {
381                    return std::task::Poll::Ready(());
382                }
383                std::task::Poll::Pending
384            }
385        }
386        SignalPoller { thread: self }.await
387    }
388
389    /// Returns all the signals that are waiting to be processed
390    pub fn pop_signals_or_subscribe(&self, waker: &Waker) -> Option<Vec<Signal>> {
391        let mut guard = self.state.signals.lock().unwrap();
392        let mut ret = Vec::new();
393        std::mem::swap(&mut ret, &mut guard.0);
394        match ret.is_empty() {
395            true => {
396                if !guard.1.iter().any(|w| w.will_wake(waker)) {
397                    guard.1.push(waker.clone());
398                }
399                None
400            }
401            false => Some(ret),
402        }
403    }
404
405    pub fn signals_subscribe(&self, waker: &Waker) {
406        let mut guard = self.state.signals.lock().unwrap();
407        if !guard.1.iter().any(|w| w.will_wake(waker)) {
408            guard.1.push(waker.clone());
409        }
410    }
411
412    /// Returns all the signals that are waiting to be processed
413    pub fn has_signals_or_subscribe(&self, waker: &Waker) -> bool {
414        let mut guard = self.state.signals.lock().unwrap();
415        let has_signals = !guard.0.is_empty();
416        if !has_signals && !guard.1.iter().any(|w| w.will_wake(waker)) {
417            guard.1.push(waker.clone());
418        }
419        has_signals
420    }
421
422    /// Returns all the signals that are waiting to be processed
423    pub fn pop_signals(&self) -> Vec<Signal> {
424        let mut guard = self.state.signals.lock().unwrap();
425        let mut ret = Vec::new();
426        std::mem::swap(&mut ret, &mut guard.0);
427        ret
428    }
429
430    /// Adds a stack snapshot and removes dead ones
431    pub fn add_snapshot(
432        &self,
433        mut memory_stack: &[u8],
434        mut memory_stack_corrected: &[u8],
435        hash: u128,
436        rewind_stack: &[u8],
437        store_data: &[u8],
438    ) {
439        // Lock the stack
440        let mut stack = self.state.stack.lock().unwrap();
441        let mut pstack = stack.deref_mut();
442        loop {
443            // First we validate if the stack is no longer valid
444            let memory_stack_before = pstack.memory_stack.len();
445            let memory_stack_after = memory_stack.len();
446            if memory_stack_before > memory_stack_after
447                || (!pstack
448                    .memory_stack
449                    .iter()
450                    .zip(memory_stack.iter())
451                    .any(|(a, b)| *a == *b)
452                    && !pstack
453                        .memory_stack_corrected
454                        .iter()
455                        .zip(memory_stack.iter())
456                        .any(|(a, b)| *a == *b))
457            {
458                // The stacks have changed so need to start again at this segment
459                let mut new_stack = ThreadStack {
460                    memory_stack: memory_stack.to_vec(),
461                    memory_stack_corrected: memory_stack_corrected.to_vec(),
462                    ..Default::default()
463                };
464                std::mem::swap(pstack, &mut new_stack);
465                memory_stack = &NO_MORE_BYTES[..];
466                memory_stack_corrected = &NO_MORE_BYTES[..];
467
468                // Output debug info for the dead stack
469                let mut disown = Some(Box::new(new_stack));
470                if let Some(disown) = disown.as_ref()
471                    && !disown.snapshots.is_empty()
472                {
473                    tracing::trace!(
474                        "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})",
475                        self.pid(),
476                        memory_stack_before,
477                        memory_stack_after
478                    );
479                }
480                let mut total_forgotten = 0usize;
481                while let Some(disowned) = disown {
482                    for _hash in disowned.snapshots.keys() {
483                        total_forgotten += 1;
484                    }
485                    disown = disowned.next;
486                }
487                if total_forgotten > 0 {
488                    tracing::trace!(
489                        "wasi[{}]::stack has been forgotten (cnt={})",
490                        self.pid(),
491                        total_forgotten
492                    );
493                }
494            } else {
495                memory_stack = &memory_stack[pstack.memory_stack.len()..];
496                memory_stack_corrected =
497                    &memory_stack_corrected[pstack.memory_stack_corrected.len()..];
498            }
499
500            // If there is no more memory stack then we are done and can add the call stack
501            if memory_stack.is_empty() {
502                break;
503            }
504
505            // Otherwise we need to add a next stack pointer and continue the iterations
506            if pstack.next.is_none() {
507                let new_stack = ThreadStack {
508                    memory_stack: memory_stack.to_vec(),
509                    memory_stack_corrected: memory_stack_corrected.to_vec(),
510                    ..Default::default()
511                };
512                pstack.next.replace(Box::new(new_stack));
513            }
514            pstack = pstack.next.as_mut().unwrap();
515        }
516
517        // Add the call stack
518        pstack.snapshots.insert(
519            hash,
520            ThreadSnapshot {
521                call_stack: BytesMut::from(rewind_stack).freeze(),
522                store_data: BytesMut::from(store_data).freeze(),
523            },
524        );
525    }
526
527    /// Gets a snapshot that was previously addedf
528    pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> {
529        let mut memory_stack = BytesMut::new();
530
531        let stack = self.state.stack.lock().unwrap();
532        let mut pstack = stack.deref();
533        loop {
534            memory_stack.extend(pstack.memory_stack_corrected.iter());
535            if let Some(snapshot) = pstack.snapshots.get(&hash) {
536                return Some((
537                    memory_stack,
538                    snapshot.call_stack.clone(),
539                    snapshot.store_data.clone(),
540                ));
541            }
542            if let Some(next) = pstack.next.as_ref() {
543                pstack = next.deref();
544            } else {
545                return None;
546            }
547        }
548    }
549
550    // Copy the stacks from another thread
551    pub fn copy_stack_from(&self, other: &WasiThread) {
552        let mut stack = {
553            let stack_guard = other.state.stack.lock().unwrap();
554            stack_guard.clone()
555        };
556
557        let mut stack_guard = self.state.stack.lock().unwrap();
558        std::mem::swap(stack_guard.deref_mut(), &mut stack);
559    }
560}
561
562#[derive(Debug)]
563pub struct WasiThreadHandleProtected {
564    thread: WasiThread,
565    inner: Weak<(Mutex<WasiProcessInner>, Condvar)>,
566}
567
568#[derive(Debug, Clone)]
569pub struct WasiThreadHandle {
570    protected: Arc<WasiThreadHandleProtected>,
571}
572
573impl WasiThreadHandle {
574    pub(crate) fn new(
575        thread: WasiThread,
576        inner: &Arc<(Mutex<WasiProcessInner>, Condvar)>,
577    ) -> WasiThreadHandle {
578        Self {
579            protected: Arc::new(WasiThreadHandleProtected {
580                thread,
581                inner: Arc::downgrade(inner),
582            }),
583        }
584    }
585
586    pub fn id(&self) -> WasiThreadId {
587        self.protected.thread.tid()
588    }
589
590    pub fn as_thread(&self) -> WasiThread {
591        self.protected.thread.clone()
592    }
593}
594
595impl Drop for WasiThreadHandleProtected {
596    fn drop(&mut self) {
597        let id = self.thread.tid();
598        if let Some(inner) = Weak::upgrade(&self.inner) {
599            let mut inner = inner.0.lock().unwrap();
600            if let Some(ctrl) = inner.threads.remove(&id) {
601                ctrl.set_status_finished(Ok(Errno::Success.into()));
602            }
603            inner.thread_count -= 1;
604        }
605    }
606}
607
608impl std::ops::Deref for WasiThreadHandle {
609    type Target = WasiThread;
610
611    fn deref(&self) -> &Self::Target {
612        &self.protected.thread
613    }
614}
615
616#[derive(thiserror::Error, Debug, Clone)]
617pub enum WasiThreadError {
618    #[error("Multithreading is not supported")]
619    Unsupported,
620    #[error("The method named is not an exported function")]
621    MethodNotFound,
622    #[error("Failed to create the requested memory - {0}")]
623    MemoryCreateFailed(MemoryError),
624    #[error("{0}")]
625    ExportError(ExportError),
626    #[error("Linker error: {0}")]
627    LinkError(Arc<LinkError>),
628    #[error("Failed to create the instance - {0}")]
629    // Note: Boxed so we can keep the error size down
630    InstanceCreateFailed(Box<InstantiationError>),
631    #[error("Initialization function failed - {0}")]
632    InitFailed(Arc<anyhow::Error>),
633    /// This will happen if WASM is running in a thread has not been created by the spawn_wasm call
634    #[error("WASM context is invalid")]
635    InvalidWasmContext,
636}
637
638impl From<WasiThreadError> for Errno {
639    fn from(a: WasiThreadError) -> Errno {
640        match a {
641            WasiThreadError::Unsupported => Errno::Notsup,
642            WasiThreadError::MethodNotFound => Errno::Inval,
643            WasiThreadError::MemoryCreateFailed(_) => Errno::Nomem,
644            WasiThreadError::ExportError(_) => Errno::Noexec,
645            WasiThreadError::LinkError(_) => Errno::Noexec,
646            WasiThreadError::InstanceCreateFailed(_) => Errno::Noexec,
647            WasiThreadError::InitFailed(_) => Errno::Noexec,
648            WasiThreadError::InvalidWasmContext => Errno::Noexec,
649        }
650    }
651}