wai_bindgen_wasmer/region.rs

use crate::rt::RawMem;
use crate::{Endian, GuestError, Le};
use std::collections::HashSet;
use std::convert::TryInto;
use std::marker;
use std::mem;
use wasmer::RuntimeError;

// This is a pretty naive way to account for borrows. This data structure
// could be made a lot more efficient with some effort.
pub struct BorrowChecker<'a> {
    /// Sets of the regions currently borrowed. A `HashSet` is probably not
    /// ideal for this but it works. It would be more efficient if we could
    /// check `is_borrowed` without an O(n) iteration, by organizing borrows
    /// by an ordering of `Region`.
    shared_borrows: HashSet<Region>,
    mut_borrows: HashSet<Region>,
    _marker: marker::PhantomData<&'a mut [u8]>,
    ptr: *mut u8,
    len: usize,
}
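
// A minimal usage sketch (with a hypothetical `memory_bytes` slice owned by
// the caller): borrows handed out are tracked until this `BorrowChecker` is
// dropped.
//
//     let mut bc = BorrowChecker::new(memory_bytes);
//     let header = bc.slice::<u8>(0, 8).unwrap();    // shared borrow of bytes 0..8
//     let body = bc.slice_mut::<u8>(8, 16).unwrap(); // ok: disjoint region
//     // bc.slice_mut::<u8>(4, 8) would now fail: it overlaps `header`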

// These are not automatically implemented with our storage of `*mut u8`, so we
// need to manually declare that this type is threadsafe.
unsafe impl Send for BorrowChecker<'_> {}
unsafe impl Sync for BorrowChecker<'_> {}

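/// Converts any displayable error into a wasmer `RuntimeError` carrying its message.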
fn to_error(err: impl std::fmt::Display) -> RuntimeError {
    RuntimeError::new(err.to_string())
}

impl<'a> BorrowChecker<'a> {
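    /// Creates a borrow checker over the given slice of wasm linear memory.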
    pub fn new(data: &'a mut [u8]) -> BorrowChecker<'a> {
        BorrowChecker {
            ptr: data.as_mut_ptr(),
            len: data.len(),
            shared_borrows: Default::default(),
            mut_borrows: Default::default(),
            _marker: marker::PhantomData,
        }
    }

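    /// Returns a shared slice of `len` values of `T` starting at `ptr` in
    /// guest memory, recording a shared borrow of that region for the rest of
    /// this `BorrowChecker`'s lifetime.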
    pub fn slice<T: AllBytesValid>(&mut self, ptr: i32, len: i32) -> Result<&'a [T], RuntimeError> {
        let (ret, r) = self.get_slice(ptr, len)?;
        // SAFETY: We're promoting the valid lifetime of `ret` from a temporary
        // borrow on `self` to `'a` on this `BorrowChecker`. At the same time
        // we're recording that this is a persistent shared borrow (until this
        // borrow checker is deleted), which disallows future mutable borrows
        // of the same data.
        let ret = unsafe { &*(ret as *const [T]) };
        self.shared_borrows.insert(r);
        Ok(ret)
    }

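    /// Returns a mutable slice of `len` values of `T` starting at `ptr` in
    /// guest memory, recording a mutable borrow of that region for the rest of
    /// this `BorrowChecker`'s lifetime.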
    pub fn slice_mut<T: AllBytesValid>(
        &mut self,
        ptr: i32,
        len: i32,
    ) -> Result<&'a mut [T], RuntimeError> {
        let (ret, r) = self.get_slice_mut(ptr, len)?;
        // SAFETY: see `slice` for how we're extending the lifetime by
        // recording the borrow here. Note that the `mut_borrows` list is
        // checked on both shared and mutable borrows in the future since a
        // mutable borrow can't alias with anything.
        let ret = unsafe { &mut *(ret as *mut [T]) };
        self.mut_borrows.insert(r);
        Ok(ret)
    }

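    /// Validates `ptr`/`len` and returns a shared slice tied to `&self`, along
    /// with the `Region` it covers, without recording a borrow.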
    fn get_slice<T: AllBytesValid>(
        &self,
        ptr: i32,
        len: i32,
    ) -> Result<(&[T], Region), RuntimeError> {
        let r = self.region::<T>(ptr, len)?;
        if self.is_mut_borrowed(r) {
            Err(to_error(GuestError::PtrBorrowed(r)))
        } else {
            Ok((
                // SAFETY: invariants to uphold:
                //
                // * The lifetime of the input is valid for the lifetime of the
                //   output. In this case we're threading through the lifetime
                //   of `&self` to the output.
                // * The actual output is valid, which is guaranteed with the
                //   `AllBytesValid` bound.
                // * We uphold Rust's borrowing guarantees, namely that this
                //   borrow we're returning isn't overlapping with any mutable
                //   borrows.
                // * The region `r` we're returning accurately describes the
                //   slice we're returning in wasm linear memory.
                unsafe {
                    std::slice::from_raw_parts(
                        self.ptr.add(r.start as usize) as *const T,
                        len as usize,
                    )
                },
                r,
            ))
        }
    }

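    /// Mutable counterpart of `get_slice`: also rejects regions that overlap
    /// an existing shared borrow.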
    fn get_slice_mut<T>(&mut self, ptr: i32, len: i32) -> Result<(&mut [T], Region), RuntimeError> {
        let r = self.region::<T>(ptr, len)?;
        if self.is_mut_borrowed(r) || self.is_shared_borrowed(r) {
            Err(to_error(GuestError::PtrBorrowed(r)))
        } else {
            Ok((
                // SAFETY: same as `get_slice`, except that we're threading
                // through `&mut` properties as well.
                unsafe {
                    std::slice::from_raw_parts_mut(
                        self.ptr.add(r.start as usize) as *mut T,
                        len as usize,
                    )
                },
                r,
            ))
        }
    }

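    /// Computes the byte `Region` covered by `len` values of `T` starting at
    /// `ptr`, checking for arithmetic overflow and that the region lies within
    /// the wasm linear memory.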
    fn region<T>(&self, ptr: i32, len: i32) -> Result<Region, RuntimeError> {
        assert_eq!(std::mem::align_of::<T>(), 1);
        let r = Region {
            start: ptr as u32,
            len: (len as u32)
                .checked_mul(mem::size_of::<T>() as u32)
                .ok_or_else(|| to_error(GuestError::PtrOverflow))?,
        };
        self.validate_contains(&r)?;
        Ok(r)
    }

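    /// Like `slice`, but additionally validates that the bytes are UTF-8.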
    pub fn slice_str(&mut self, ptr: i32, len: i32) -> Result<&'a str, RuntimeError> {
        let bytes = self.slice(ptr, len)?;
        std::str::from_utf8(bytes).map_err(to_error)
    }

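    /// Checks that `region` lies entirely within the wasm linear memory.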
    fn validate_contains(&self, region: &Region) -> Result<(), RuntimeError> {
        let end = region
            .start
            .checked_add(region.len)
            .ok_or_else(|| to_error(GuestError::PtrOverflow))? as usize;
        if end <= self.len {
            Ok(())
        } else {
            Err(to_error(GuestError::PtrOutOfBounds(*region)))
        }
    }

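    /// Whether `r` overlaps any outstanding shared borrow.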
    fn is_shared_borrowed(&self, r: Region) -> bool {
        self.shared_borrows.iter().any(|b| b.overlaps(r))
    }

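    /// Whether `r` overlaps any outstanding mutable borrow.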
    fn is_mut_borrowed(&self, r: Region) -> bool {
        self.mut_borrows.iter().any(|b| b.overlaps(r))
    }

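    /// Returns the entire underlying memory as a raw slice pointer.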
    pub fn raw(&self) -> *mut [u8] {
        std::ptr::slice_from_raw_parts_mut(self.ptr, self.len)
    }
}

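// `RawMem` loads and stores go through the same bounds and borrow checks as
// the slice methods above, using the `Le<T>` wrapper for the in-memory
// (little-endian) representation of values.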
impl RawMem for BorrowChecker<'_> {
    fn store<T: Endian>(&mut self, offset: i32, val: T) -> Result<(), RuntimeError> {
        let (slice, _) = self.get_slice_mut::<Le<T>>(offset, 1)?;
        slice[0].set(val);
        Ok(())
    }

    fn store_many<T: Endian>(&mut self, offset: i32, val: &[T]) -> Result<(), RuntimeError> {
        let (slice, _) = self.get_slice_mut::<Le<T>>(
            offset,
            val.len()
                .try_into()
                .map_err(|_| to_error(GuestError::PtrOverflow))?,
        )?;
        for (slot, val) in slice.iter_mut().zip(val) {
            slot.set(*val);
        }
        Ok(())
    }

    fn load<T: Endian>(&self, offset: i32) -> Result<T, RuntimeError> {
        let (slice, _) = self.get_slice::<Le<T>>(offset, 1)?;
        Ok(slice[0].get())
    }
}

/// Unsafe trait representing types where every byte pattern is valid for their
/// representation.
///
/// This is the set of types which wasmer can have a raw pointer to for
/// values which reside in wasm linear memory.
///
/// # Safety
///
/// Implementations must guarantee that every possible bit pattern of the
/// type's size is a valid value of the type, so that reading one directly out
/// of untrusted wasm linear memory can never produce an invalid value.
pub unsafe trait AllBytesValid {}

unsafe impl AllBytesValid for u8 {}
unsafe impl AllBytesValid for u16 {}
unsafe impl AllBytesValid for u32 {}
unsafe impl AllBytesValid for u64 {}
unsafe impl AllBytesValid for i8 {}
unsafe impl AllBytesValid for i16 {}
unsafe impl AllBytesValid for i32 {}
unsafe impl AllBytesValid for i64 {}
unsafe impl AllBytesValid for f32 {}
unsafe impl AllBytesValid for f64 {}

macro_rules! tuples {
    ($(($($t:ident)*))*) => ($(
        unsafe impl<$($t: AllBytesValid,)*> AllBytesValid for ($($t,)*) {}
    )*)
}

tuples! {
    ()
    (T1)
    (T1 T2)
    (T1 T2 T3)
    (T1 T2 T3 T4)
    (T1 T2 T3 T4 T5)
    (T1 T2 T3 T4 T5 T6)
    (T1 T2 T3 T4 T5 T6 T7)
    (T1 T2 T3 T4 T5 T6 T7 T8)
    (T1 T2 T3 T4 T5 T6 T7 T8 T9)
    (T1 T2 T3 T4 T5 T6 T7 T8 T9 T10)
}

/// Represents a contiguous region in memory.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Region {
    pub start: u32,
    pub len: u32,
}

impl Region {
    /// Checks whether this `Region` overlaps with the `rhs` `Region`.
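    /// For example, `Region { start: 0, len: 10 }` covers bytes `0..=9`, so it
    /// overlaps `Region { start: 9, len: 10 }` but not `Region { start: 10, len: 5 }`.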
    fn overlaps(&self, rhs: Region) -> bool {
        // Zero-length regions can never overlap!
        if self.len == 0 || rhs.len == 0 {
            return false;
        }

        let self_start = self.start as u64;
        let self_end = self_start + (self.len - 1) as u64;

        let rhs_start = rhs.start as u64;
        let rhs_end = rhs_start + (rhs.len - 1) as u64;

        if self_start <= rhs_start {
            self_end >= rhs_start
        } else {
            rhs_end >= self_start
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn nonoverlapping() {
        let mut bytes = [0; 100];
        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(0, 10).unwrap();
        bc.slice::<u8>(10, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(10, 10).unwrap();
        bc.slice::<u8>(0, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice_mut::<u8>(0, 10).unwrap();
        bc.slice_mut::<u8>(10, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice_mut::<u8>(10, 10).unwrap();
        bc.slice_mut::<u8>(0, 10).unwrap();
    }

    #[test]
    fn overlapping() {
        let mut bytes = [0; 100];
        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(0, 10).unwrap();
        bc.slice_mut::<u8>(9, 10).unwrap_err();
        bc.slice::<u8>(9, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(0, 10).unwrap();
        bc.slice_mut::<u8>(2, 5).unwrap_err();
        bc.slice::<u8>(2, 5).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(9, 10).unwrap();
        bc.slice_mut::<u8>(0, 10).unwrap_err();
        bc.slice::<u8>(0, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(2, 5).unwrap();
        bc.slice_mut::<u8>(0, 10).unwrap_err();
        bc.slice::<u8>(0, 10).unwrap();

        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice::<u8>(2, 5).unwrap();
        bc.slice::<u8>(10, 5).unwrap();
        bc.slice::<u8>(15, 5).unwrap();
        bc.slice_mut::<u8>(0, 10).unwrap_err();
        bc.slice::<u8>(0, 10).unwrap();
    }

    #[test]
    fn zero_length() {
        let mut bytes = [0; 100];
        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice_mut::<u8>(0, 0).unwrap();
        bc.slice_mut::<u8>(0, 0).unwrap();
        bc.slice::<u8>(0, 1).unwrap();
    }
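
    // Sketch of the complementary case: once a region is mutably borrowed,
    // later overlapping borrows of either kind are rejected.
    #[test]
    fn mut_then_overlapping() {
        let mut bytes = [0; 100];
        let mut bc = BorrowChecker::new(&mut bytes);
        bc.slice_mut::<u8>(0, 10).unwrap();
        bc.slice::<u8>(5, 10).unwrap_err();
        bc.slice_mut::<u8>(5, 10).unwrap_err();
    }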
}