wasmer_sys_utils/memory/fd_memory/fd_mmap.rs

// This file contains code from external sources.
// Attributions: https://github.com/wasmerio/wasmer/blob/main/docs/ATTRIBUTIONS.md

use std::{
    io::{self, Read, Write},
    ptr, slice,
};

// /// Round `size` up to the nearest multiple of `page_size`.
// fn round_up_to_page_size(size: usize, page_size: usize) -> usize {
//     (size + (page_size - 1)) & !(page_size - 1)
// }

/// A simple struct consisting of a page-aligned pointer to page-aligned
/// and initially-zeroed memory and a length.
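///
/// A minimal sketch of the intended lifecycle (hypothetical sizes, assuming
/// page-size multiples):
///
/// ```ignore
/// let page = region::page::size();
/// // Reserve four pages of address space, commit the first one read-write.
/// let mut map = FdMmap::accessible_reserved(page, 4 * page)?;
/// map.as_mut_slice()[0] = 1;
/// // Commit a second page later, on demand.
/// map.make_accessible(page, page)?;
/// ```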
#[derive(Debug)]
pub struct FdMmap {
    // Note that this is stored as a `usize` instead of a `*const` or `*mut`
    // pointer to allow this structure to be natively `Send` and `Sync` without
    // `unsafe impl`. This type is sendable across threads and shareable since
    // the coordination all happens at the OS layer.
    ptr: usize,
    len: usize,
    // Backing file that will be closed when the memory mapping goes out of scope
    fd: FdGuard,
}

/// An owned file descriptor that is closed on drop; `-1` marks the empty
/// (default) state.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FdGuard(pub i32);

impl Default for FdGuard {
    fn default() -> Self {
        Self(-1)
    }
}

impl Clone for FdGuard {
    fn clone(&self) -> Self {
        // `dup` returns -1 on failure, leaving the clone in the default
        // (empty) state.
        unsafe { Self(libc::dup(self.0)) }
    }
}

impl Drop for FdGuard {
    fn drop(&mut self) {
        if self.0 >= 0 {
            unsafe {
                libc::close(self.0);
            }
            self.0 = -1;
        }
    }
}

impl FdMmap {
    /// Construct a new empty instance of `FdMmap`.
    pub fn new() -> Self {
        // Rust's slices require non-null pointers, even when empty. `Vec`
        // contains code to create a non-null dangling pointer value when
        // constructed empty, so we reuse that here.
        let empty = Vec::<u8>::new();
        Self {
            ptr: empty.as_ptr() as usize,
            len: 0,
            fd: FdGuard::default(),
        }
    }

    // /// Create a new `FdMmap` pointing to at least `size` bytes of page-aligned accessible memory.
    // pub fn with_at_least(size: usize) -> Result<Self, String> {
    //     let page_size = region::page::size();
    //     let rounded_size = round_up_to_page_size(size, page_size);
    //     Self::accessible_reserved(rounded_size, rounded_size)
    // }

    /// Create a new `FdMmap` pointing to `accessible_size` bytes of page-aligned accessible memory,
    /// within a reserved mapping of `mapping_size` bytes. `accessible_size` and `mapping_size`
    /// must be native page-size multiples.
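    ///
    /// A sketch of the two call shapes (hypothetical sizes; `page` is the
    /// host page size):
    ///
    /// ```ignore
    /// // Fully committed: one read-write region backed by a tmpfile.
    /// let all = FdMmap::accessible_reserved(2 * page, 2 * page)?;
    /// // Partially committed: one page usable now, three reserved for growth.
    /// let grow = FdMmap::accessible_reserved(page, 4 * page)?;
    /// ```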
    pub fn accessible_reserved(
        accessible_size: usize,
        mapping_size: usize,
    ) -> Result<Self, String> {
        let page_size = region::page::size();
        assert!(accessible_size <= mapping_size);
        assert_eq!(mapping_size & (page_size - 1), 0);
        assert_eq!(accessible_size & (page_size - 1), 0);

        // Mmap may return EINVAL if the size is zero, so just
        // special-case that.
        if mapping_size == 0 {
            return Ok(Self::new());
        }

        // Open a temporary file (which is used for swapping). `FdGuard` takes
        // over the underlying descriptor; the `FILE` handle itself is never
        // `fclose`d.
        let fd = unsafe {
            let file = libc::tmpfile();
            if file.is_null() {
                return Err(format!(
                    "failed to create temporary file - {}",
                    io::Error::last_os_error()
                ));
            }
            FdGuard(libc::fileno(file))
        };

        // First we initialize it with zeros
        unsafe {
            if libc::ftruncate(fd.0, mapping_size as libc::off_t) < 0 {
                return Err(format!(
                    "could not truncate tmpfile - {}",
                    io::Error::last_os_error()
                ));
            }
        }

        Ok(if accessible_size == mapping_size {
            // Allocate a single read-write region at once.
            let ptr = unsafe {
                libc::mmap(
                    ptr::null_mut(),
                    mapping_size,
                    libc::PROT_READ | libc::PROT_WRITE,
                    libc::MAP_FILE | libc::MAP_SHARED,
                    fd.0,
                    0,
                )
            };
            if ptr == libc::MAP_FAILED {
                return Err(io::Error::last_os_error().to_string());
            }

            Self {
                ptr: ptr as usize,
                len: mapping_size,
                fd,
            }
        } else {
            // Reserve the mapping size.
            let ptr = unsafe {
                libc::mmap(
                    ptr::null_mut(),
                    mapping_size,
                    libc::PROT_NONE,
                    libc::MAP_FILE | libc::MAP_SHARED,
                    fd.0,
                    0,
                )
            };
            if ptr == libc::MAP_FAILED {
                return Err(io::Error::last_os_error().to_string());
            }

            let mut result = Self {
                ptr: ptr as usize,
                len: mapping_size,
                fd,
            };

            if accessible_size != 0 {
                // Commit the accessible size.
                result.make_accessible(0, accessible_size)?;
            }

            result
        })
    }

    /// Make the memory starting at `start` and extending for `len` bytes accessible.
    /// `start` and `len` must be native page-size multiples and describe a range within
    /// `self`'s reserved memory.
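    ///
    /// For example (a sketch; assumes `map` still has uncommitted pages at
    /// the given range):
    ///
    /// ```ignore
    /// // Commit the second page of the reservation.
    /// map.make_accessible(page, page)?;
    /// ```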
    pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<(), String> {
        let page_size = region::page::size();
        assert_eq!(start & (page_size - 1), 0);
        assert_eq!(len & (page_size - 1), 0);
        assert!(len <= self.len);
        assert!(start <= self.len - len);

        // Commit the accessible size.
        let ptr = self.ptr as *const u8;
        unsafe { region::protect(ptr.add(start), len, region::Protection::READ_WRITE) }
            .map_err(|e| e.to_string())
    }

    /// Return the allocated memory as a slice of u8.
    pub fn as_slice(&self) -> &[u8] {
        unsafe { slice::from_raw_parts(self.ptr as *const u8, self.len) }
    }

    /// Return the allocated memory as a mutable slice of u8.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        unsafe { slice::from_raw_parts_mut(self.ptr as *mut u8, self.len) }
    }

    // /// Return the allocated memory as a pointer to u8.
    // pub fn as_ptr(&self) -> *const u8 {
    //     self.ptr as *const u8
    // }

    /// Return the allocated memory as a mutable pointer to u8.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr as *mut u8
    }

    /// Return the length of the allocated memory.
    pub fn len(&self) -> usize {
        self.len
    }

    // /// Return whether any memory has been allocated.
    // pub fn is_empty(&self) -> bool {
    //     self.len() == 0
    // }

    /// Copies the memory to a new swap file (using copy-on-write if available).
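    ///
    /// A sketch of the expected use (`bytes_in_use` is a hypothetical hint):
    ///
    /// ```ignore
    /// let snapshot = map.duplicate(Some(bytes_in_use))?;
    /// // `snapshot` is backed by its own file; later writes to `map` do not
    /// // affect it.
    /// ```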
    pub fn duplicate(&mut self, hint_used: Option<usize>) -> Result<Self, String> {
        use std::os::unix::prelude::FromRawFd;

        // Empty memory is an edge case
        if self.len == 0 {
            return Ok(Self::new());
        }

        // First we sync all the data to the backing file
        unsafe {
            libc::fsync(self.fd.0);
        }

        // Open a new temporary file (which is used for swapping for the forked memory)
        let fd = unsafe {
            let file = libc::tmpfile();
            if file.is_null() {
                return Err(format!(
                    "failed to create temporary file - {}",
                    io::Error::last_os_error()
                ));
            }
            FdGuard(libc::fileno(file))
        };

        // Attempt to do a shallow copy (needs a backing file system that
        // supports it). FICLONE is `_IOW(0x94, 9, int)`, i.e. 0x40049409, on
        // Linux; elsewhere the ioctl simply fails and we fall back to a full
        // copy.
        unsafe {
            if libc::ioctl(fd.0, 0x40049409 /* FICLONE */, self.fd.0) != 0 {
                #[cfg(feature = "tracing")]
                trace!("memory copy started");

                // Determine how much to copy
                let len = match hint_used {
                    Some(a) => a,
                    None => self.len,
                };

                // The shallow copy failed so we have to do it the hard way.
                // Duplicate the descriptors first so that dropping the `File`
                // wrappers does not close `self.fd` or `fd`.
                let mut source = std::fs::File::from_raw_fd(libc::dup(self.fd.0));
                let mut out = std::fs::File::from_raw_fd(libc::dup(fd.0));
                copy_file_range(&mut source, 0, &mut out, 0, len)
                    .map_err(|err| format!("Could not copy memory: {err}"))?;

                #[cfg(feature = "tracing")]
                trace!("memory copy finished (size={})", len);
            }
        }

        // Compute the flags
        let flags = libc::MAP_FILE | libc::MAP_SHARED;

        // Allocate a single read-write region at once.
        let ptr = unsafe {
            libc::mmap(
                ptr::null_mut(),
                self.len,
                libc::PROT_READ | libc::PROT_WRITE,
                flags,
                fd.0,
                0,
            )
        };
        if ptr == libc::MAP_FAILED {
            return Err(io::Error::last_os_error().to_string());
        }

        Ok(Self {
            ptr: ptr as usize,
            len: self.len,
            fd,
        })
    }
}

impl Drop for FdMmap {
    fn drop(&mut self) {
        if self.len != 0 {
            let r = unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.len) };
            assert_eq!(r, 0, "munmap failed: {}", io::Error::last_os_error());
        }
    }
}

/// Copy a range of a file to another file.
// We could also use libc::copy_file_range on some systems, but that is hard
// to do portably because it is missing from many libc implementations
// (macOS, musl, ...).
#[cfg(target_family = "unix")]
fn copy_file_range(
    source: &mut std::fs::File,
    source_offset: u64,
    out: &mut std::fs::File,
    out_offset: u64,
    len: usize,
) -> Result<(), std::io::Error> {
    use std::io::{Seek, SeekFrom};

    let source_original_pos = source.stream_position()?;
    source.seek(SeekFrom::Start(source_offset))?;

    let out_original_pos = out.stream_position()?;
    out.seek(SeekFrom::Start(out_offset))?;

    // TODO: don't do this horrible "triple buffering" below.
    // let mut reader = std::io::BufReader::new(source);

    // TODO: larger buffer?
    let mut buffer = vec![0u8; 4096];

    let mut to_read = len;
    while to_read > 0 {
        let chunk_size = std::cmp::min(to_read, buffer.len());
        let read = source.read(&mut buffer[0..chunk_size])?;
        if read == 0 {
            // The source ended before `len` bytes could be copied; without
            // this check a read of 0 bytes would loop forever.
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "source file ended before the requested range was copied",
            ));
        }
        out.write_all(&buffer[0..read])?;
        to_read -= read;
    }

    out.flush()?;
    out.sync_data()?;

    // Restore both files to their original positions.
    source.seek(SeekFrom::Start(source_original_pos))?;
    out.seek(SeekFrom::Start(out_original_pos))?;

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    // #[test]
    // fn test_round_up_to_page_size() {
    //     assert_eq!(round_up_to_page_size(0, 4096), 0);
    //     assert_eq!(round_up_to_page_size(1, 4096), 4096);
    //     assert_eq!(round_up_to_page_size(4096, 4096), 4096);
    //     assert_eq!(round_up_to_page_size(4097, 4096), 8192);
    // }

    #[cfg(target_family = "unix")]
    #[test]
    fn test_copy_file_range() -> Result<(), std::io::Error> {
        // The `tempfile` crate would also work here, but this avoids an
        // extra dev-dependency.

        use std::{fs::OpenOptions, io::Seek};

        let dir = std::env::temp_dir().join("wasmer/copy_file_range");
        if dir.is_dir() {
            std::fs::remove_dir_all(&dir).unwrap()
        }
        std::fs::create_dir_all(&dir).unwrap();

        let pa = dir.join("a");
        let pb = dir.join("b");

        let data: Vec<u8> = (0..100).collect();
        let mut a = OpenOptions::new()
            .read(true)
            .write(true)
            .create_new(true)
            .open(pa)
            .unwrap();
        a.write_all(&data).unwrap();

        let datb: Vec<u8> = (100..200).collect();
        let mut b = OpenOptions::new()
            .read(true)
            .write(true)
            .create_new(true)
            .open(pb)
            .unwrap();
        b.write_all(&datb).unwrap();

        // Copy a[10..25] into b[40..55]; both cursors must be restored to
        // where they were before the call.
        a.seek(io::SeekFrom::Start(30)).unwrap();
        b.seek(io::SeekFrom::Start(99)).unwrap();
        copy_file_range(&mut a, 10, &mut b, 40, 15).unwrap();

        assert_eq!(a.stream_position().unwrap(), 30);
        assert_eq!(b.stream_position().unwrap(), 99);

        b.seek(io::SeekFrom::Start(0)).unwrap();
        let mut out = Vec::new();
        let len = b.read_to_end(&mut out).unwrap();
        assert_eq!(len, 100);
        assert_eq!(out[0..40], datb[0..40]);
        assert_eq!(out[40..55], data[10..25]);
        assert_eq!(out[55..100], datb[55..100]);

        // TODO: needs more variant tests, but this is enough for now.

        Ok(())
    }
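
    // A smoke test for the mapping itself: reserve, commit, write, then
    // duplicate. This is a sketch under the assumption that `tmpfile()` is
    // usable on the host and that `duplicate` can fall back to a plain copy
    // where FICLONE is unsupported.
    #[cfg(target_family = "unix")]
    #[test]
    fn test_accessible_reserved_and_duplicate() {
        let page = region::page::size();

        // Reserve two pages and commit the first one.
        let mut map = FdMmap::accessible_reserved(page, 2 * page).unwrap();
        assert_eq!(map.len(), 2 * page);
        map.as_mut_slice()[..4].copy_from_slice(&[1, 2, 3, 4]);

        // Commit the second page and touch it.
        map.make_accessible(page, page).unwrap();
        map.as_mut_slice()[page] = 42;

        // The duplicate sees the same bytes but is independently backed.
        let copy = map.duplicate(Some(2 * page)).unwrap();
        assert_eq!(&copy.as_slice()[..4], &[1, 2, 3, 4]);
        map.as_mut_slice()[0] = 9;
        assert_eq!(copy.as_slice()[0], 1);
    }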
}