1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use super::*;
use crate::syscalls::*;

/// ### `stack_checkpoint()`
/// Creates a snapshot of the current stack which allows it to be restored
/// later using its stack hash.
#[instrument(level = "trace", skip_all, ret)]
pub fn stack_checkpoint<M: MemorySize>(
    mut ctx: FunctionEnvMut<'_, WasiEnv>,
    snapshot_ptr: WasmPtr<StackSnapshot, M>,
    ret_val: WasmPtr<Longsize, M>,
) -> Result<Errno, WasiError> {
    // If we were just restored then we need to return the value instead
    if let Some(val) = unsafe { handle_rewind::<M, Longsize>(&mut ctx) } {
        let env = ctx.data();
        let memory = unsafe { env.memory_view(&ctx) };
        wasi_try_mem_ok!(ret_val.write(&memory, val));
        trace!("restored - (ret={})", val);
        return Ok(Errno::Success);
    }
    trace!("capturing",);

    wasi_try_ok!(WasiEnv::process_signals_and_exit(&mut ctx)?);

    // Set the return value that we will give back to
    // indicate we are a normal function call that has not yet
    // been restored
    let env = ctx.data();
    let memory = unsafe { env.memory_view(&ctx) };
    wasi_try_mem_ok!(ret_val.write(&memory, 0));

    // Pass some offsets to the unwind function
    let ret_offset = ret_val.offset();
    let snapshot_offset = snapshot_ptr.offset();
    let secret = env.state().secret;

    // We clear the target memory location before we grab the stack so that
    // it correctly hashes
    if let Err(err) = snapshot_ptr.write(&memory, StackSnapshot::new(0, 0)) {
        warn!(
            %err
        );
    }

    // Perform the unwind action
    unwind::<M, _>(ctx, move |mut ctx, mut memory_stack, rewind_stack| {
        // Grab all the globals and serialize them
        let store_data = crate::utils::store::capture_store_snapshot(&mut ctx.as_store_mut())
            .serialize()
            .unwrap();
        let env = ctx.data();
        let store_data = Bytes::from(store_data);

        // We compute the hash again for two reasons... integrity so if there
        // is a long jump that goes to the wrong place it will fail gracefully.
        // and security so that the stack can not be used to attempt to break
        // out of the sandbox
        let hash = {
            use sha2::{Digest, Sha256};
            let mut hasher = Sha256::new();
            hasher.update(&secret[..]);
            hasher.update(&memory_stack[..]);
            hasher.update(&rewind_stack[..]);
            hasher.update(&store_data[..]);
            let hash: [u8; 16] = hasher.finalize()[..16].try_into().unwrap();
            u128::from_le_bytes(hash)
        };

        // Build a stack snapshot
        let snapshot = StackSnapshot::new(ret_offset.into(), hash);

        // Get a reference directly to the bytes of snapshot
        let val_bytes = unsafe {
            let p = &snapshot;
            ::std::slice::from_raw_parts(
                (p as *const StackSnapshot) as *const u8,
                ::std::mem::size_of::<StackSnapshot>(),
            )
        };

        // The snapshot may itself reside on the stack (which means we
        // need to update the memory stack rather than write to the memory
        // as otherwise the rewind will wipe out the structure)
        // This correct memory stack is stored as well for validation purposes
        let mut memory_stack_corrected = memory_stack.clone();
        {
            let snapshot_offset: u64 = snapshot_offset.into();
            if snapshot_offset >= env.layout.stack_lower
                && (snapshot_offset + val_bytes.len() as u64) <= env.layout.stack_upper
            {
                // Make sure its within the "active" part of the memory stack
                // (note - the area being written to might not go past the memory pointer)
                let offset = env.layout.stack_upper - snapshot_offset;
                if (offset as usize) < memory_stack_corrected.len() {
                    let left = memory_stack_corrected.len() - (offset as usize);
                    let end = offset + (val_bytes.len().min(left) as u64);
                    if end as usize <= memory_stack_corrected.len() {
                        let pstart = memory_stack_corrected.len() - offset as usize;
                        let pend = pstart + val_bytes.len();
                        let pbytes = &mut memory_stack_corrected[pstart..pend];
                        pbytes.clone_from_slice(val_bytes);
                    }
                }
            }
        }

        /// Add a snapshot to the stack
        ctx.data().thread.add_snapshot(
            &memory_stack[..],
            &memory_stack_corrected[..],
            hash,
            &rewind_stack[..],
            &store_data[..],
        );
        trace!(hash = snapshot.hash(), user = snapshot.user);

        // Save the stack snapshot
        let env = ctx.data();
        let memory = unsafe { env.memory_view(&ctx) };
        let snapshot_ptr: WasmPtr<StackSnapshot, M> = WasmPtr::new(snapshot_offset);
        if let Err(err) = snapshot_ptr.write(&memory, snapshot) {
            warn!("could not save stack snapshot - {}", err);
            return OnCalledAction::Trap(Box::new(WasiError::Exit(mem_error_to_wasi(err).into())));
        }

        // Rewind the stack and carry on
        let pid = ctx.data().pid();
        let tid = ctx.data().tid();
        match rewind::<M, _>(
            ctx,
            memory_stack_corrected.freeze(),
            rewind_stack.freeze(),
            store_data,
            0 as Longsize,
        ) {
            Errno::Success => OnCalledAction::InvokeAgain,
            err => {
                warn!(
                    "failed checkpoint - could not rewind the stack - errno={}",
                    err
                );
                OnCalledAction::Trap(Box::new(WasiError::Exit(err.into())))
            }
        }
    })
}