wasmer_wasix/os/task/process.rs

use crate::{WasiEnv, WasiRuntimeError, journal::SnapshotTrigger};
#[cfg(feature = "journal")]
use crate::{WasiResult, journal::JournalEffector, syscalls::do_checkpoint_from_outside, unwind};
use serde::{Deserialize, Serialize};
#[cfg(feature = "journal")]
use std::collections::HashSet;
use std::{
    collections::HashMap,
    convert::TryInto,
    ops::Range,
    sync::{
        Arc, Condvar, Mutex, MutexGuard, RwLock, Weak,
        atomic::{AtomicU32, Ordering},
    },
    task::Waker,
    time::Duration,
};
use tracing::trace;
use wasmer::FunctionEnvMut;
use wasmer_types::ModuleHash;
use wasmer_wasix_types::{
    types::Signal,
    wasi::{Errno, ExitCode, Snapshot0Clockid},
    wasix::ThreadStartType,
};

use crate::{
    WasiThread, WasiThreadHandle, WasiThreadId, os::task::signal::WasiSignalInterval,
    syscalls::platform_clock_time_get,
};

use super::{
    TaskStatus,
    backoff::WasiProcessCpuBackoff,
    control_plane::{ControlPlaneError, WasiControlPlaneHandle},
    signal::{SignalDeliveryError, SignalHandlerAbi},
    task_join_handle::OwnedTaskStatus,
    thread::WasiMemoryLayout,
};

/// Represents the ID of a sub-process
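///
/// A small illustrative sketch of the conversions this type supports (not a
/// doctest; the value is arbitrary):
///
/// ```ignore
/// let pid: WasiProcessId = 42u32.into();   // from a raw u32
/// assert_eq!(pid.raw(), 42);               // back to the raw value
/// assert_eq!(i32::from(pid), 42);          // libc-style signed form
/// ```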
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct WasiProcessId(u32);

impl WasiProcessId {
    pub fn raw(&self) -> u32 {
        self.0
    }
}

impl From<i32> for WasiProcessId {
    fn from(id: i32) -> Self {
        Self(id as u32)
    }
}

impl From<WasiProcessId> for i32 {
    fn from(val: WasiProcessId) -> Self {
        val.0 as i32
    }
}

impl From<u32> for WasiProcessId {
    fn from(id: u32) -> Self {
        Self(id)
    }
}

impl From<WasiProcessId> for u32 {
    fn from(val: WasiProcessId) -> Self {
        val.0
    }
}

impl std::fmt::Display for WasiProcessId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl std::fmt::Debug for WasiProcessId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

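/// The process state is shared behind a `Mutex` paired with a `Condvar` so
/// that one thread can mutate it while others block until it changes (for
/// example while waiting for a checkpoint to complete). A minimal sketch of
/// the locking pattern (illustrative only, not a doctest):
///
/// ```ignore
/// let (lock, condvar) = (&inner.0, &inner.1);
/// let mut guard = lock.lock().unwrap();
/// while guard.checkpoint != WasiProcessCheckpoint::Execute {
///     // `wait` releases the lock while sleeping and re-acquires it on wake
///     guard = condvar.wait(guard).unwrap();
/// }
/// ```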
pub type LockableWasiProcessInner = Arc<(Mutex<WasiProcessInner>, Condvar)>;

/// Represents a process running within the compute state
/// TODO: fields should be private and only accessed via methods.
#[derive(Debug, Clone)]
pub struct WasiProcess {
    /// Unique ID of this process
    pub(crate) pid: WasiProcessId,
    /// Hash of the module that this process is using
    pub(crate) module_hash: ModuleHash,
    /// The parent process that spawned this one (if any)
    pub(crate) parent: Option<Weak<RwLock<WasiProcessInner>>>,
    /// The inner protected region of the process, paired with a condition
    /// variable that is used for coordination such as snapshots.
    pub(crate) inner: LockableWasiProcessInner,
    /// Reference back to the compute engine
    // TODO: remove this reference, access should happen via separate state instead
    // (we don't want cyclical references)
    pub(crate) compute: WasiControlPlaneHandle,
    /// Reference to the exit code for the main thread
    pub(crate) finished: Arc<OwnedTaskStatus>,
    /// Number of threads waiting for children to exit
    pub(crate) waiting: Arc<AtomicU32>,
    /// Number of tokens that are currently active; while any are held, the
    /// exponential CPU backoff is halted (i.e. the CPU is allowed to run
    /// freely)
    pub(crate) cpu_run_tokens: Arc<AtomicU32>,
}

/// Represents a freeze of all threads to perform some action
/// on the total state-machine. This is normally done for
/// things like snapshots which require the memory to remain
/// stable while a diff is performed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum WasiProcessCheckpoint {
    /// No checkpoint will take place and the process
    /// should just execute as per normal
    Execute,
    /// The process needs to take a snapshot of the
    /// memory and state-machine
    Snapshot { trigger: SnapshotTrigger },
}

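/// A contiguous region of memory expressed as raw start/end offsets so that it
/// can be hashed and compared cheaply when deciding which regions actually need
/// to be re-written to the journal.
///
/// It round-trips with `Range<u64>` (illustrative sketch, not a doctest):
///
/// ```ignore
/// let region: MemorySnapshotRegion = (0u64..4096u64).into();
/// let range: Range<u64> = region.into();
/// assert_eq!(range, 0..4096);
/// ```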
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MemorySnapshotRegion {
    pub start: u64,
    pub end: u64,
}

impl From<Range<u64>> for MemorySnapshotRegion {
    fn from(value: Range<u64>) -> Self {
        Self {
            start: value.start,
            end: value.end,
        }
    }
}

#[allow(clippy::from_over_into)]
impl Into<Range<u64>> for MemorySnapshotRegion {
    fn into(self) -> Range<u64> {
        self.start..self.end
    }
}

// TODO: fields should be private and only accessed via methods.
#[derive(Debug)]
pub struct WasiProcessInner {
    /// Unique ID of this process
    pub pid: WasiProcessId,
    /// Number of threads waiting for children to exit
    pub(crate) waiting: Arc<AtomicU32>,
    /// The threads that make up this process
    pub threads: HashMap<WasiThreadId, WasiThread>,
    /// Number of threads running for this process
    pub thread_count: u32,
    /// Signals that will be triggered at specific intervals
    pub signal_intervals: HashMap<Signal, WasiSignalInterval>,
    /// List of all the children spawned from this process
    pub children: Vec<WasiProcess>,
    /// Represents a checkpoint which blocks all the threads
    /// and then executes some maintenance action
    pub checkpoint: WasiProcessCheckpoint,
    /// If true then journaling will be disabled after the
    /// next snapshot is taken
    pub disable_journaling_after_checkpoint: bool,
    /// If true then the process will stop running after the
    /// next snapshot is taken
    pub stop_running_after_checkpoint: bool,
    /// List of situations that the process will checkpoint on
    #[cfg(feature = "journal")]
    pub snapshot_on: HashSet<SnapshotTrigger>,
    /// Any wakers waiting on this process (for example for a checkpoint)
    pub wakers: Vec<Waker>,
    /// Hashes of snapshotted memory regions; these significantly reduce the
    /// number of duplicate entries in the journal for memory that has not changed
    #[cfg(feature = "journal")]
    pub snapshot_memory_hash: HashMap<MemorySnapshotRegion, u64>,
    /// Represents all the backoff properties for this process
    /// which will be used to determine if the CPU should be
    /// throttled or not
    pub(super) backoff: WasiProcessCpuBackoff,
}

pub enum MaybeCheckpointResult<'a> {
    NotThisTime(FunctionEnvMut<'a, WasiEnv>),
    Unwinding,
}

impl WasiProcessInner {
    /// Checkpoints the process, which will cause all other threads to
    /// pause so that the thread and memory state can be saved
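    ///
    /// Roughly how this is expected to be driven from inside a syscall handler
    /// (an illustrative sketch only, not a doctest; `ctx` is the calling
    /// thread's `FunctionEnvMut<WasiEnv>` and `M` its memory size type):
    ///
    /// ```ignore
    /// let inner = ctx.data().process.inner.clone();
    /// if let Ok(Ok(MaybeCheckpointResult::NotThisTime(ctx))) = WasiProcessInner::checkpoint::<M>(
    ///     inner,
    ///     ctx,
    ///     WasiProcessCheckpoint::Snapshot { trigger: SnapshotTrigger::Sigint },
    /// ) {
    ///     // no unwind was needed; keep running with the returned `ctx`
    /// }
    /// ```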
    #[cfg(feature = "journal")]
    pub fn checkpoint<M: wasmer_types::MemorySize>(
        inner: LockableWasiProcessInner,
        ctx: FunctionEnvMut<'_, WasiEnv>,
        for_what: WasiProcessCheckpoint,
    ) -> WasiResult<MaybeCheckpointResult<'_>> {
        // Set the checkpoint flag and then enter the normal processing loop
        {
            let mut guard = inner.0.lock().unwrap();
            guard.checkpoint = for_what;
            for waker in guard.wakers.drain(..) {
                waker.wake();
            }
            inner.1.notify_all();
        }

        Self::maybe_checkpoint::<M>(inner, ctx)
    }

    /// If a checkpoint has been started this will block the current process
    /// until the checkpoint operation has completed
    #[cfg(feature = "journal")]
    pub fn maybe_checkpoint<M: wasmer_types::MemorySize>(
        inner: LockableWasiProcessInner,
        ctx: FunctionEnvMut<'_, WasiEnv>,
    ) -> WasiResult<MaybeCheckpointResult<'_>> {
        // Enter the lock which will determine if we are in a checkpoint or not

        use bytes::Bytes;
        use wasmer::AsStoreMut;
        use wasmer_types::OnCalledAction;

        use crate::{WasiError, os::task::thread::RewindResultType, rewind_ext};
        let guard = inner.0.lock().unwrap();
        if guard.checkpoint == WasiProcessCheckpoint::Execute {
            // No checkpoint so just carry on
            return Ok(Ok(MaybeCheckpointResult::NotThisTime(ctx)));
        }
        trace!("checkpoint capture");
        drop(guard);

        // Perform the unwind action
        let thread_layout = ctx.data().thread.memory_layout().clone();
        unwind::<M, _>(ctx, move |mut ctx, memory_stack, rewind_stack| {
            // Grab all the globals and serialize them
            let store_data = crate::utils::store::capture_store_snapshot(&mut ctx.as_store_mut())
                .serialize()
                .unwrap();
            let memory_stack = memory_stack.freeze();
            let rewind_stack = rewind_stack.freeze();
            let store_data = Bytes::from(store_data);

            tracing::debug!(
                "stack snapshot unwind (memory_stack={}, rewind_stack={}, store_data={})",
                memory_stack.len(),
                rewind_stack.len(),
                store_data.len(),
            );

            // Write our thread state to the snapshot
            let thread_start = ctx.data().thread.thread_start_type();
            let tid = ctx.data().thread.tid();
            if let Err(err) = JournalEffector::save_thread_state::<M>(
                &mut ctx,
                tid,
                memory_stack.clone(),
                rewind_stack.clone(),
                store_data.clone(),
                thread_start,
                thread_layout,
            ) {
                return wasmer_types::OnCalledAction::Trap(err.into());
            }

            let mut guard = inner.0.lock().unwrap();

            // Wait for the checkpoint to finish (or if we are the last thread
            // to freeze then we have to execute the checksum operation)
            loop {
                if let WasiProcessCheckpoint::Snapshot { trigger } = guard.checkpoint {
                    ctx.data().thread.set_checkpointing(true);

                    // Now if we are the last thread we also write the memory
                    let is_last_thread = guard
                        .threads
                        .values()
                        .all(|t| t.is_check_pointing() || t.is_deep_sleeping());
                    if is_last_thread {
                        if let Err(err) =
                            JournalEffector::save_memory_and_snapshot(&mut ctx, &mut guard, trigger)
                        {
                            inner.1.notify_all();
                            return wasmer_types::OnCalledAction::Trap(err.into());
                        }

                        // Clear the checkpointing flag and notify everyone to wake up
                        ctx.data().thread.set_checkpointing(false);
                        trace!("checkpoint complete");
                        if guard.disable_journaling_after_checkpoint {
                            ctx.data_mut().enable_journal = false;
                        }
                        guard.checkpoint = WasiProcessCheckpoint::Execute;
                        for waker in guard.wakers.drain(..) {
                            waker.wake();
                        }
                        inner.1.notify_all();
                    } else {
                        guard = inner.1.wait(guard).unwrap();
                    }
                    continue;
                }

                ctx.data().thread.set_checkpointing(false);
                trace!("checkpoint finished");

                if guard.stop_running_after_checkpoint {
                    trace!("will stop running now");
                    // Need to stop recording journal events so we don't also record the
                    // thread and process exit events
                    ctx.data_mut().enable_journal = false;
                    return OnCalledAction::Finish;
                }

                // Rewind the stack and carry on
                return match rewind_ext::<M>(
                    &mut ctx,
                    Some(memory_stack),
                    rewind_stack,
                    store_data,
                    RewindResultType::RewindWithoutResult,
                ) {
                    Errno::Success => OnCalledAction::InvokeAgain,
                    err => {
                        tracing::warn!(
                            "snapshot resumption failed - could not rewind the stack - errno={}",
                            err
                        );
                        OnCalledAction::Trap(Box::new(WasiError::Exit(err.into())))
                    }
                };
            }
        })?;

        Ok(Ok(MaybeCheckpointResult::Unwinding))
    }

    // Execute any checkpoints that can be executed while outside of the WASM process
    #[cfg(not(feature = "journal"))]
    pub fn do_checkpoints_from_outside(_ctx: &mut FunctionEnvMut<'_, WasiEnv>) {}

    // Execute any checkpoints that can be executed while outside of the WASM process
    #[cfg(feature = "journal")]
    pub fn do_checkpoints_from_outside(ctx: &mut FunctionEnvMut<'_, WasiEnv>) {
        let inner = ctx.data().process.inner.clone();
        let mut guard = inner.0.lock().unwrap();

        // Wait for the checkpoint to finish (or if we are the last thread
        // to freeze then we have to execute the checksum operation)
        while let WasiProcessCheckpoint::Snapshot { trigger } = guard.checkpoint {
            ctx.data().thread.set_checkpointing(true);

            // Now if we are the last thread we also write the memory
            let is_last_thread = guard
                .threads
                .values()
                .all(|t| t.is_check_pointing() || t.is_deep_sleeping());
            if is_last_thread {
                if let Err(err) =
                    JournalEffector::save_memory_and_snapshot(ctx, &mut guard, trigger)
                {
                    inner.1.notify_all();
                    tracing::error!("failed to snapshot memory and threads - {}", err);
                    return;
                }

                // Clear the checkpointing flag and notify everyone to wake up
                ctx.data().thread.set_checkpointing(false);
                trace!("checkpoint complete");
                if guard.disable_journaling_after_checkpoint {
                    ctx.data_mut().enable_journal = false;
                }
                guard.checkpoint = WasiProcessCheckpoint::Execute;
                for waker in guard.wakers.drain(..) {
                    waker.wake();
                }
                inner.1.notify_all();
            } else {
                guard = inner.1.wait(guard).unwrap();
            }
            continue;
        }

        ctx.data().thread.set_checkpointing(false);
        trace!("checkpoint finished");
    }
}

// TODO: why do we need this, how is it used?
pub(crate) struct WasiProcessWait {
    waiting: Arc<AtomicU32>,
}

impl WasiProcessWait {
    pub fn new(process: &WasiProcess) -> Self {
        process.waiting.fetch_add(1, Ordering::AcqRel);
        Self {
            waiting: process.waiting.clone(),
        }
    }
}

impl Drop for WasiProcessWait {
    fn drop(&mut self) {
        self.waiting.fetch_sub(1, Ordering::AcqRel);
    }
}

impl WasiProcess {
    pub fn new(pid: WasiProcessId, module_hash: ModuleHash, plane: WasiControlPlaneHandle) -> Self {
        let max_cpu_backoff_time = plane
            .upgrade()
            .and_then(|p| p.config().enable_exponential_cpu_backoff)
            .unwrap_or(Duration::from_secs(30));
        let max_cpu_cool_off_time = Duration::from_millis(500);

        let waiting = Arc::new(AtomicU32::new(0));
        let inner = Arc::new((
            Mutex::new(WasiProcessInner {
                pid,
                threads: Default::default(),
                thread_count: Default::default(),
                signal_intervals: Default::default(),
                children: Default::default(),
                checkpoint: WasiProcessCheckpoint::Execute,
                wakers: Default::default(),
                waiting: waiting.clone(),
                #[cfg(feature = "journal")]
                snapshot_on: Default::default(),
                #[cfg(feature = "journal")]
                snapshot_memory_hash: Default::default(),
                disable_journaling_after_checkpoint: false,
                stop_running_after_checkpoint: false,
                backoff: WasiProcessCpuBackoff::new(max_cpu_backoff_time, max_cpu_cool_off_time),
            }),
            Condvar::new(),
        ));

        #[derive(Debug)]
        struct SignalHandler(LockableWasiProcessInner);
        impl SignalHandlerAbi for SignalHandler {
            fn signal(&self, signal: u8) -> Result<(), SignalDeliveryError> {
                if let Ok(signal) = signal.try_into() {
                    signal_process_internal(&self.0, signal);
                    Ok(())
                } else {
                    Err(SignalDeliveryError)
                }
            }
        }

        WasiProcess {
            pid,
            module_hash,
            parent: None,
            compute: plane,
            inner: inner.clone(),
            finished: Arc::new(
                OwnedTaskStatus::new(TaskStatus::Pending)
                    .with_signal_handler(Arc::new(SignalHandler(inner))),
            ),
            waiting,
            cpu_run_tokens: Arc::new(AtomicU32::new(0)),
        }
    }

    pub(super) fn set_pid(&mut self, pid: WasiProcessId) {
        self.pid = pid;
    }

    /// Gets the process ID of this process
    pub fn pid(&self) -> WasiProcessId {
        self.pid
    }

    /// Gets the process ID of the parent process
    pub fn ppid(&self) -> WasiProcessId {
        self.parent
            .iter()
            .filter_map(|parent| parent.upgrade())
            .map(|parent| parent.read().unwrap().pid)
            .next()
            .unwrap_or(WasiProcessId(0))
    }

    /// Gains access to the process internals
    // TODO: Make this private, all inner access should be exposed with methods.
    pub fn lock(&self) -> MutexGuard<'_, WasiProcessInner> {
        self.inner.0.lock().unwrap()
    }

    /// Creates a thread and returns it
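    ///
    /// Illustrative usage (hypothetical `process` and `layout` values; not a doctest):
    ///
    /// ```ignore
    /// let handle = process.new_thread(layout, ThreadStartType::MainThread)?;
    /// // for the main thread the TID equals the PID; other threads get a fresh ID
    /// ```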
    pub fn new_thread(
        &self,
        layout: WasiMemoryLayout,
        start: ThreadStartType,
    ) -> Result<WasiThreadHandle, ControlPlaneError> {
        let control_plane = self.compute.must_upgrade();

        // Determine if it's the main thread or not
        let is_main = matches!(start, ThreadStartType::MainThread);

        // Generate a new ID (the process ID and thread ID address spaces must
        // not overlap in libc). For the main process the TID equals the PID
        let tid: WasiThreadId = if is_main {
            self.pid().raw().into()
        } else {
            let tid: u32 = control_plane.generate_id()?.into();
            tid.into()
        };

        self.new_thread_with_id(layout, start, tid)
    }

    /// Creates a thread with the given ID and returns it
    pub fn new_thread_with_id(
        &self,
        layout: WasiMemoryLayout,
        start: ThreadStartType,
        tid: WasiThreadId,
    ) -> Result<WasiThreadHandle, ControlPlaneError> {
        let control_plane = self.compute.must_upgrade();
        let task_count_guard = control_plane.register_task()?;

        let is_main = matches!(start, ThreadStartType::MainThread);

        // The finished handle should be the process's own if it's the main thread
        let mut inner = self.inner.0.lock().unwrap();
        let finished = if is_main {
            self.finished.clone()
        } else {
            Arc::new(OwnedTaskStatus::default())
        };

        // Insert the thread into the pool
        let ctrl = WasiThread::new(
            self.pid(),
            tid,
            is_main,
            finished,
            task_count_guard,
            layout,
            start,
        );
        inner.threads.insert(tid, ctrl.clone());
        inner.thread_count += 1;

        Ok(WasiThreadHandle::new(ctrl, &self.inner))
    }

    pub fn all_threads(&self) -> Vec<WasiThreadId> {
        let inner = self.inner.0.lock().unwrap();
        inner.threads.keys().cloned().collect()
    }

    /// Gets a reference to a particular thread
    pub fn get_thread(&self, tid: &WasiThreadId) -> Option<WasiThread> {
        let inner = self.inner.0.lock().unwrap();
        inner.threads.get(tid).cloned()
    }

    /// Signals a particular thread in the process
    pub fn signal_thread(&self, tid: &WasiThreadId, signal: Signal) {
        // Sometimes we will signal the process rather than the thread, hence this hard-coded libc value
        let mut tid = tid.raw();
        if tid == 1073741823 {
            tid = self.pid().raw();
        }
        let tid: WasiThreadId = tid.into();

        let pid = self.pid();
        tracing::trace!(%pid, %tid, "signal-thread({:?})", signal);

        let inner = self.inner.0.lock().unwrap();
        if let Some(thread) = inner.threads.get(&tid) {
            thread.signal(signal);
        } else {
            trace!(
                "wasi[{}]::lost-signal(tid={}, sig={:?})",
                self.pid(),
                tid,
                signal
            );
        }
    }

    /// Signals all the threads in this process
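    ///
    /// For example (illustrative; not a doctest):
    ///
    /// ```ignore
    /// // deliver SIGINT to the process; if threads are waiting on children the
    /// // signal is forwarded to those children instead
    /// process.signal_process(Signal::Sigint);
    /// ```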
    pub fn signal_process(&self, signal: Signal) {
        signal_process_internal(&self.inner, signal);
    }

    /// Takes a snapshot of the process and disables journaling, returning
    /// a future that can be awaited for the snapshot to complete
    ///
    /// Note: If you ignore the returned future the checkpoint will still
    /// occur but it will execute asynchronously
    pub fn snapshot_and_disable_journaling(
        &self,
        trigger: SnapshotTrigger,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        let mut guard = self.inner.0.lock().unwrap();
        guard.disable_journaling_after_checkpoint = true;
        guard.checkpoint = WasiProcessCheckpoint::Snapshot { trigger };
        self.wait_for_checkpoint_finish()
    }

    /// Takes a snapshot of the process and shuts it down after the snapshot
    /// is taken.
    ///
    /// Note: If you ignore the returned future the checkpoint will still
    /// occur but it will execute asynchronously
    pub fn snapshot_and_stop(
        &self,
        trigger: SnapshotTrigger,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        let mut guard = self.inner.0.lock().unwrap();
        guard.stop_running_after_checkpoint = true;
        guard.checkpoint = WasiProcessCheckpoint::Snapshot { trigger };
        self.wait_for_checkpoint_finish()
    }

    /// Takes a snapshot of the process
    ///
    /// Note: If you ignore the returned future the checkpoint will still
    /// occur but it will execute asynchronously
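    ///
    /// For example (illustrative; not a doctest):
    ///
    /// ```ignore
    /// // request a snapshot and wait until it has been written to the journal
    /// process.snapshot(SnapshotTrigger::Sigint).await;
    /// ```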
    pub fn snapshot(
        &self,
        trigger: SnapshotTrigger,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        let mut guard = self.inner.0.lock().unwrap();
        guard.checkpoint = WasiProcessCheckpoint::Snapshot { trigger };
        self.wait_for_checkpoint_finish()
    }

    /// Disables journaling after the next checkpoint is taken
    pub fn disable_journaling_after_checkpoint(&self) {
        let mut guard = self.inner.0.lock().unwrap();
        guard.disable_journaling_after_checkpoint = true;
    }

    /// Stop running once a checkpoint is taken
    pub fn stop_running_after_checkpoint(&self) {
        let mut guard = self.inner.0.lock().unwrap();
        guard.stop_running_after_checkpoint = true;
    }

    /// Waits until a checkpoint has been triggered for this process
    #[cfg(not(feature = "journal"))]
    pub fn wait_for_checkpoint(
        &self,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        Box::pin(std::future::pending())
    }

    /// Waits until a checkpoint has been triggered for this process
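    ///
    /// For example (illustrative; not a doctest):
    ///
    /// ```ignore
    /// // resolves once another party has requested a checkpoint on this process
    /// process.wait_for_checkpoint().await;
    /// ```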
    #[cfg(feature = "journal")]
    pub fn wait_for_checkpoint(
        &self,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        use futures::Future;
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        struct Poller {
            inner: LockableWasiProcessInner,
        }
        impl Future for Poller {
            type Output = ();
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let mut guard = self.inner.0.lock().unwrap();
                if !matches!(guard.checkpoint, WasiProcessCheckpoint::Execute) {
                    return Poll::Ready(());
                }
                if !guard.wakers.iter().any(|w| w.will_wake(cx.waker())) {
                    guard.wakers.push(cx.waker().clone());
                }
                Poll::Pending
            }
        }
        Box::pin(Poller {
            inner: self.inner.clone(),
        })
    }

    /// Waits for the checkpoint to finish
    #[cfg(not(feature = "journal"))]
    pub fn wait_for_checkpoint_finish(
        &self,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        Box::pin(std::future::pending())
    }

    /// Waits for the checkpoint to finish
    #[cfg(feature = "journal")]
    pub fn wait_for_checkpoint_finish(
        &self,
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + Sync>> {
        use futures::Future;
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        struct Poller {
            inner: LockableWasiProcessInner,
        }
        impl Future for Poller {
            type Output = ();
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let mut guard = self.inner.0.lock().unwrap();
                if matches!(guard.checkpoint, WasiProcessCheckpoint::Execute) {
                    return Poll::Ready(());
                }
                if !guard.wakers.iter().any(|w| w.will_wake(cx.waker())) {
                    guard.wakers.push(cx.waker().clone());
                }
                Poll::Pending
            }
        }
        Box::pin(Poller {
            inner: self.inner.clone(),
        })
    }

    /// Signals one of the threads every interval
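    ///
    /// For example (illustrative; not a doctest):
    ///
    /// ```ignore
    /// // raise SIGALRM roughly every second until the interval is removed again
    /// process.signal_interval(Signal::Sigalrm, Some(Duration::from_secs(1)), true);
    /// // passing `None` cancels the interval for that signal
    /// process.signal_interval(Signal::Sigalrm, None, false);
    /// ```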
    pub fn signal_interval(&self, signal: Signal, interval: Option<Duration>, repeat: bool) {
        let mut inner = self.inner.0.lock().unwrap();

        let interval = match interval {
            None => {
                inner.signal_intervals.remove(&signal);
                return;
            }
            Some(a) => a,
        };

        let now = platform_clock_time_get(Snapshot0Clockid::Monotonic, 1_000_000).unwrap() as u128;
        inner.signal_intervals.insert(
            signal,
            WasiSignalInterval {
                signal,
                interval,
                last_signal: now,
                repeat,
            },
        );
    }

    /// Returns the number of active threads for this process
    pub fn active_threads(&self) -> u32 {
        let inner = self.inner.0.lock().unwrap();
        inner.thread_count
    }

    /// Waits until the process is finished.
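    ///
    /// For example (illustrative; not a doctest):
    ///
    /// ```ignore
    /// // resolves once the main thread has terminated
    /// let status: Result<ExitCode, Arc<WasiRuntimeError>> = process.join().await;
    /// ```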
    pub async fn join(&self) -> Result<ExitCode, Arc<WasiRuntimeError>> {
        let _guard = WasiProcessWait::new(self);
        self.finished.await_termination().await
    }

    /// Attempts to join on the process
    pub fn try_join(&self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
        self.finished.status().into_finished()
    }

    /// Waits for all the children to be finished
    pub async fn join_children(&mut self) -> Option<Result<ExitCode, Arc<WasiRuntimeError>>> {
        let _guard = WasiProcessWait::new(self);
        let children: Vec<_> = {
            let inner = self.inner.0.lock().unwrap();
            inner.children.clone()
        };
        if children.is_empty() {
            return None;
        }
        let mut waits = Vec::new();
        for child in children {
            if let Some(process) = self.compute.must_upgrade().get_process(child.pid) {
                let inner = self.inner.clone();
                waits.push(async move {
                    let join = process.join().await;
                    let mut inner = inner.0.lock().unwrap();
                    inner.children.retain(|a| a.pid != child.pid);
                    join
                })
            }
        }
        futures::future::join_all(waits.into_iter())
            .await
            .into_iter()
            .next()
    }

    /// Waits for any of the children to finish
    pub async fn join_any_child(&mut self) -> Result<Option<(WasiProcessId, ExitCode)>, Errno> {
        let _guard = WasiProcessWait::new(self);
        let children: Vec<_> = {
            let inner = self.inner.0.lock().unwrap();
            inner.children.clone()
        };
        if children.is_empty() {
            return Err(Errno::Child);
        }

        let mut waits = Vec::new();
        for child in children {
            if let Some(process) = self.compute.must_upgrade().get_process(child.pid) {
                let inner = self.inner.clone();
                waits.push(async move {
                    let join = process.join().await;
                    let mut inner = inner.0.lock().unwrap();
                    inner.children.retain(|a| a.pid != child.pid);
                    (child, join)
                })
            }
        }
        let (child, res) = futures::future::select_all(waits.into_iter().map(Box::pin))
            .await
            .0;

        let code =
            res.unwrap_or_else(|e| e.as_exit_code().unwrap_or_else(|| Errno::Canceled.into()));

        Ok(Some((child.pid, code)))
    }

    /// Terminate the process and all its threads
    pub fn terminate(&self, exit_code: ExitCode) {
        // FIXME: this is wrong, threads might still be running!
        // Need special logic for the main thread.
        let guard = self.inner.0.lock().unwrap();
        for thread in guard.threads.values() {
            thread.set_status_finished(Ok(exit_code))
        }
    }
}

/// Signals all the threads in this process
fn signal_process_internal(process: &LockableWasiProcessInner, signal: Signal) {
    #[allow(unused_mut)]
    let mut guard = process.0.lock().unwrap();
    let pid = guard.pid;
    tracing::trace!(%pid, "signal-process({:?})", signal);

    // If the snapshot on ctrl-c is currently registered then we need
    // to take a snapshot and exit
    #[cfg(feature = "journal")]
    {
        if signal == Signal::Sigint
            && (guard.snapshot_on.contains(&SnapshotTrigger::Sigint)
                || guard.snapshot_on.remove(&SnapshotTrigger::FirstSigint))
        {
            drop(guard);

            tracing::debug!(%pid, "snapshot-on-interrupt-signal");

            do_checkpoint_from_outside(
                process,
                WasiProcessCheckpoint::Snapshot {
                    trigger: SnapshotTrigger::Sigint,
                },
            );
            return;
        };
    }

    // Check if there are subprocesses that will receive this signal
    // instead of this process
    if guard.waiting.load(Ordering::Acquire) > 0 {
        let mut triggered = false;
        for child in guard.children.iter() {
            child.signal_process(signal);
            triggered = true;
        }
        if triggered {
            return;
        }
    }

    // Otherwise just send the signal to all the threads
    for thread in guard.threads.values() {
        thread.signal(signal);
    }
}

impl SignalHandlerAbi for WasiProcess {
    fn signal(&self, sig: u8) -> Result<(), SignalDeliveryError> {
        if let Ok(sig) = sig.try_into() {
            self.signal_process(sig);
            Ok(())
        } else {
            Err(SignalDeliveryError)
        }
    }
}