// wasmer_wasix/syscalls/wasix/stack_checkpoint.rs

1use super::*;
2use crate::syscalls::*;
3
4/// ### `stack_checkpoint()`
5/// Creates a snapshot of the current stack which allows it to be restored
6/// later using its stack hash.
7#[instrument(level = "trace", skip_all, ret)]
8pub fn stack_checkpoint<M: MemorySize>(
9    mut ctx: FunctionEnvMut<'_, WasiEnv>,
10    snapshot_ptr: WasmPtr<StackSnapshot, M>,
11    ret_val: WasmPtr<Longsize, M>,
12) -> Result<Errno, WasiError> {
13    // If we were just restored then we need to return the value instead
14    if let Some(val) = unsafe { handle_rewind::<M, Longsize>(&mut ctx) } {
15        let env = ctx.data();
16        let memory = unsafe { env.memory_view(&ctx) };
17        wasi_try_mem_ok!(ret_val.write(&memory, val));
18        trace!("restored - (ret={})", val);
19        return Ok(Errno::Success);
20    }
21    trace!("capturing");
22
23    WasiEnv::do_pending_operations(&mut ctx)?;
24
25    // Set the return value that we will give back to
26    // indicate we are a normal function call that has not yet
27    // been restored
28    let env = ctx.data();
29    let memory = unsafe { env.memory_view(&ctx) };
30    wasi_try_mem_ok!(ret_val.write(&memory, 0));
31
32    // Pass some offsets to the unwind function
33    let ret_offset = ret_val.offset();
34    let snapshot_offset = snapshot_ptr.offset();
35    let secret = env.state().secret;
36
37    // We clear the target memory location before we grab the stack so that
38    // it correctly hashes
39    if let Err(err) = snapshot_ptr.write(&memory, StackSnapshot::new(0, 0)) {
40        warn!(
41            %err
42        );
43    }
44
45    // Perform the unwind action
46    unwind::<M, _>(ctx, move |mut ctx, mut memory_stack, rewind_stack| {
47        // Grab all the globals and serialize them
48        let store_data = crate::utils::store::capture_store_snapshot(&mut ctx.as_store_mut())
49            .serialize()
50            .unwrap();
51        let env = ctx.data();
52        let store_data = Bytes::from(store_data);
53
54        // We compute the hash again for two reasons... integrity so if there
55        // is a long jump that goes to the wrong place it will fail gracefully.
56        // and security so that the stack can not be used to attempt to break
57        // out of the sandbox
58        let hash = {
59            use sha2::{Digest, Sha256};
60            let mut hasher = Sha256::new();
61            hasher.update(&secret[..]);
62            hasher.update(&memory_stack[..]);
63            hasher.update(&rewind_stack[..]);
64            hasher.update(&store_data[..]);
65            let hash: [u8; 16] = hasher.finalize()[..16].try_into().unwrap();
66            u128::from_le_bytes(hash)
67        };
68
69        // Build a stack snapshot
70        let snapshot = StackSnapshot::new(ret_offset.into(), hash);
71
72        // Get a reference directly to the bytes of snapshot
73        let val_bytes = unsafe {
74            let p = &snapshot;
75            ::std::slice::from_raw_parts(
76                (p as *const StackSnapshot) as *const u8,
77                ::std::mem::size_of::<StackSnapshot>(),
78            )
79        };
80
81        // The snapshot may itself reside on the stack (which means we
82        // need to update the memory stack rather than write to the memory
83        // as otherwise the rewind will wipe out the structure)
84        // This correct memory stack is stored as well for validation purposes
85        let mut memory_stack_corrected = memory_stack.clone();
86        {
87            let snapshot_offset: u64 = snapshot_offset.into();
88            if snapshot_offset >= env.layout.stack_lower
89                && (snapshot_offset + val_bytes.len() as u64) <= env.layout.stack_upper
90            {
91                // Make sure its within the "active" part of the memory stack
92                // (note - the area being written to might not go past the memory pointer)
93                let offset = env.layout.stack_upper - snapshot_offset;
94                if (offset as usize) < memory_stack_corrected.len() {
95                    let left = memory_stack_corrected.len() - (offset as usize);
96                    let end = offset + (val_bytes.len().min(left) as u64);
97                    if end as usize <= memory_stack_corrected.len() {
98                        let pstart = memory_stack_corrected.len() - offset as usize;
99                        let pend = pstart + val_bytes.len();
100                        let pbytes = &mut memory_stack_corrected[pstart..pend];
101                        pbytes.clone_from_slice(val_bytes);
102                    }
103                }
104            }
105        }
106
107        /// Add a snapshot to the stack
108        ctx.data().thread.add_snapshot(
109            &memory_stack[..],
110            &memory_stack_corrected[..],
111            hash,
112            &rewind_stack[..],
113            &store_data[..],
114        );
115        trace!(hash = snapshot.hash(), user = snapshot.user);
116
117        // Save the stack snapshot
118        let env = ctx.data();
119        let memory = unsafe { env.memory_view(&ctx) };
120        let snapshot_ptr: WasmPtr<StackSnapshot, M> = WasmPtr::new(snapshot_offset);
121        if let Err(err) = snapshot_ptr.write(&memory, snapshot) {
122            warn!("could not save stack snapshot - {}", err);
123            return OnCalledAction::Trap(Box::new(WasiError::Exit(mem_error_to_wasi(err).into())));
124        }
125
126        // Rewind the stack and carry on
127        let pid = ctx.data().pid();
128        let tid = ctx.data().tid();
129        match rewind::<M, _>(
130            ctx,
131            Some(memory_stack_corrected.freeze()),
132            rewind_stack.freeze(),
133            store_data,
134            0 as Longsize,
135        ) {
136            Errno::Success => OnCalledAction::InvokeAgain,
137            err => {
138                warn!(
139                    "failed checkpoint - could not rewind the stack - errno={}",
140                    err
141                );
142                OnCalledAction::Trap(Box::new(WasiError::Exit(err.into())))
143            }
144        }
145    })
146}