1use crate::rt::RawMem;
2use crate::{Endian, GuestError, Le};
3use std::collections::HashSet;
4use std::convert::TryInto;
5use std::marker;
6use std::mem;
7use wasmer::RuntimeError;
8
9pub struct BorrowChecker<'a> {
12 shared_borrows: HashSet<Region>,
17 mut_borrows: HashSet<Region>,
18 _marker: marker::PhantomData<&'a mut [u8]>,
19 ptr: *mut u8,
20 len: usize,
21}
22
23unsafe impl Send for BorrowChecker<'_> {}
26unsafe impl Sync for BorrowChecker<'_> {}
27
28fn to_error(err: impl std::fmt::Display) -> RuntimeError {
29 RuntimeError::new(err.to_string())
30}
31
32impl<'a> BorrowChecker<'a> {
33 pub fn new(data: &'a mut [u8]) -> BorrowChecker<'a> {
34 BorrowChecker {
35 ptr: data.as_mut_ptr(),
36 len: data.len(),
37 shared_borrows: Default::default(),
38 mut_borrows: Default::default(),
39 _marker: marker::PhantomData,
40 }
41 }
42
43 pub fn slice<T: AllBytesValid>(&mut self, ptr: i32, len: i32) -> Result<&'a [T], RuntimeError> {
44 let (ret, r) = self.get_slice(ptr, len)?;
45 let ret = unsafe { &*(ret as *const [T]) };
51 self.shared_borrows.insert(r);
52 Ok(ret)
53 }
54
55 pub fn slice_mut<T: AllBytesValid>(
56 &mut self,
57 ptr: i32,
58 len: i32,
59 ) -> Result<&'a mut [T], RuntimeError> {
60 let (ret, r) = self.get_slice_mut(ptr, len)?;
61 let ret = unsafe { &mut *(ret as *mut [T]) };
66 self.mut_borrows.insert(r);
67 Ok(ret)
68 }
69
70 fn get_slice<T: AllBytesValid>(
71 &self,
72 ptr: i32,
73 len: i32,
74 ) -> Result<(&[T], Region), RuntimeError> {
75 let r = self.region::<T>(ptr, len)?;
76 if self.is_mut_borrowed(r) {
77 Err(to_error(GuestError::PtrBorrowed(r)))
78 } else {
79 Ok((
80 unsafe {
93 std::slice::from_raw_parts(
94 self.ptr.add(r.start as usize) as *const T,
95 len as usize,
96 )
97 },
98 r,
99 ))
100 }
101 }
102
103 fn get_slice_mut<T>(&mut self, ptr: i32, len: i32) -> Result<(&mut [T], Region), RuntimeError> {
104 let r = self.region::<T>(ptr, len)?;
105 if self.is_mut_borrowed(r) || self.is_shared_borrowed(r) {
106 Err(to_error(GuestError::PtrBorrowed(r)))
107 } else {
108 Ok((
109 unsafe {
112 std::slice::from_raw_parts_mut(
113 self.ptr.add(r.start as usize) as *mut T,
114 len as usize,
115 )
116 },
117 r,
118 ))
119 }
120 }
121
122 fn region<T>(&self, ptr: i32, len: i32) -> Result<Region, RuntimeError> {
123 assert_eq!(std::mem::align_of::<T>(), 1);
124 let r = Region {
125 start: ptr as u32,
126 len: (len as u32)
127 .checked_mul(mem::size_of::<T>() as u32)
128 .ok_or_else(|| to_error(GuestError::PtrOverflow))?,
129 };
130 self.validate_contains(&r)?;
131 Ok(r)
132 }
133
134 pub fn slice_str(&mut self, ptr: i32, len: i32) -> Result<&'a str, RuntimeError> {
135 let bytes = self.slice(ptr, len)?;
136 std::str::from_utf8(bytes).map_err(to_error)
137 }
138
139 fn validate_contains(&self, region: &Region) -> Result<(), RuntimeError> {
140 let end = region
141 .start
142 .checked_add(region.len)
143 .ok_or_else(|| to_error(GuestError::PtrOverflow))? as usize;
144 if end <= self.len {
145 Ok(())
146 } else {
147 Err(to_error(GuestError::PtrOutOfBounds(*region)))
148 }
149 }
150
151 fn is_shared_borrowed(&self, r: Region) -> bool {
152 self.shared_borrows.iter().any(|b| b.overlaps(r))
153 }
154
155 fn is_mut_borrowed(&self, r: Region) -> bool {
156 self.mut_borrows.iter().any(|b| b.overlaps(r))
157 }
158
159 pub fn raw(&self) -> *mut [u8] {
160 std::ptr::slice_from_raw_parts_mut(self.ptr, self.len)
161 }
162}
163
164impl RawMem for BorrowChecker<'_> {
165 fn store<T: Endian>(&mut self, offset: i32, val: T) -> Result<(), RuntimeError> {
166 let (slice, _) = self.get_slice_mut::<Le<T>>(offset, 1)?;
167 slice[0].set(val);
168 Ok(())
169 }
170
171 fn store_many<T: Endian>(&mut self, offset: i32, val: &[T]) -> Result<(), RuntimeError> {
172 let (slice, _) = self.get_slice_mut::<Le<T>>(
173 offset,
174 val.len()
175 .try_into()
176 .map_err(|_| to_error(GuestError::PtrOverflow))?,
177 )?;
178 for (slot, val) in slice.iter_mut().zip(val) {
179 slot.set(*val);
180 }
181 Ok(())
182 }
183
184 fn load<T: Endian>(&self, offset: i32) -> Result<T, RuntimeError> {
185 let (slice, _) = self.get_slice::<Le<T>>(offset, 1)?;
186 Ok(slice[0].get())
187 }
188}
189
/// Marker for types whose every bit pattern is a valid value, so
/// `BorrowChecker::slice`/`slice_mut` may reinterpret arbitrary memory bytes
/// as `Self`.
///
/// # Safety
///
/// Implement only for types with no invalid bit patterns.
pub unsafe trait AllBytesValid {}

macro_rules! impl_all_bytes_valid {
    ($($ty:ty),* $(,)?) => {
        $(
            // SAFETY: every bit pattern of this primitive is a valid value.
            unsafe impl AllBytesValid for $ty {}
        )*
    };
}

impl_all_bytes_valid!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
212
/// Implements `AllBytesValid` for tuples whose element types all implement it.
///
/// Each parenthesized group of identifiers in the invocation produces one
/// impl, using those identifiers as the tuple's type parameters.
macro_rules! tuples {
    ($( ( $($elem:ident)* ) )*) => {
        $(
            // SAFETY: a tuple has no invalid bit patterns if none of its fields do.
            unsafe impl<$($elem: AllBytesValid),*> AllBytesValid for ($($elem,)*) {}
        )*
    };
}
218
219tuples! {
220 ()
221 (T1)
222 (T1 T2)
223 (T1 T2 T3)
224 (T1 T2 T3 T4)
225 (T1 T2 T3 T4 T5)
226 (T1 T2 T3 T4 T5 T6)
227 (T1 T2 T3 T4 T5 T6 T7)
228 (T1 T2 T3 T4 T5 T6 T7 T8)
229 (T1 T2 T3 T4 T5 T6 T7 T8 T9)
230 (T1 T2 T3 T4 T5 T6 T7 T8 T9 T10)
231}
232
/// A byte range `[start, start + len)` within the checked buffer.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Region {
    /// Byte offset of the first byte in the region.
    pub start: u32,
    /// Number of bytes covered by the region.
    pub len: u32,
}
239
240impl Region {
241 fn overlaps(&self, rhs: Region) -> bool {
243 if self.len == 0 || rhs.len == 0 {
245 return false;
246 }
247
248 let self_start = self.start as u64;
249 let self_end = self_start + (self.len - 1) as u64;
250
251 let rhs_start = rhs.start as u64;
252 let rhs_end = rhs_start + (rhs.len - 1) as u64;
253
254 if self_start <= rhs_start {
255 self_end >= rhs_start
256 } else {
257 rhs_end >= self_start
258 }
259 }
260}
261
#[cfg(test)]
mod test {
    use super::*;

    /// Takes each `(ptr, len)` in `held` as a shared borrow on fresh memory,
    /// then checks that `probe` cannot be borrowed mutably but can still be
    /// borrowed shared.
    fn assert_shared_excludes_mut(held: &[(i32, i32)], probe: (i32, i32)) {
        let mut memory = [0u8; 100];
        let mut checker = BorrowChecker::new(&mut memory);
        for &(ptr, len) in held {
            checker.slice::<u8>(ptr, len).unwrap();
        }
        let (ptr, len) = probe;
        checker.slice_mut::<u8>(ptr, len).unwrap_err();
        checker.slice::<u8>(ptr, len).unwrap();
    }

    #[test]
    fn nonoverlapping() {
        let mut memory = [0u8; 100];
        for &(first, second) in &[(0, 10), (10, 0)] {
            let mut checker = BorrowChecker::new(&mut memory);
            checker.slice::<u8>(first, 10).unwrap();
            checker.slice::<u8>(second, 10).unwrap();
        }
        for &(first, second) in &[(0, 10), (10, 0)] {
            let mut checker = BorrowChecker::new(&mut memory);
            checker.slice_mut::<u8>(first, 10).unwrap();
            checker.slice_mut::<u8>(second, 10).unwrap();
        }
    }

    #[test]
    fn overlapping() {
        assert_shared_excludes_mut(&[(0, 10)], (9, 10));
        assert_shared_excludes_mut(&[(0, 10)], (2, 5));
        assert_shared_excludes_mut(&[(9, 10)], (0, 10));
        assert_shared_excludes_mut(&[(2, 5)], (0, 10));
        assert_shared_excludes_mut(&[(2, 5), (10, 5), (15, 5)], (0, 10));
    }

    #[test]
    fn zero_length() {
        let mut memory = [0u8; 100];
        let mut checker = BorrowChecker::new(&mut memory);
        for _ in 0..2 {
            checker.slice_mut::<u8>(0, 0).unwrap();
        }
        checker.slice::<u8>(0, 1).unwrap();
    }
}