virtual_fs/
cow_file.rs

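//! A copy-on-write wrapper around a [`VirtualFile`]: reads are served by the
//! wrapped file until the file is first written to, flushed, shut down or
//! resized, at which point its contents are copied into an in-memory
//! [`BufferFile`] that serves all further access.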
3
4use derive_more::Debug;
5use replace_with::replace_with_or_abort;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{
9    future::Future,
10    io::{self, *},
11};
12
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite};
14
15use crate::{BufferFile, VirtualFile};
16
17#[derive(Debug)]
18enum CowState {
19    ReadOnly(Box<dyn VirtualFile + Send + Sync>),
20    Copying {
21        #[debug(skip)]
22        future: Pin<Box<dyn Future<Output = io::Result<BufferFile>> + Send + Sync>>,
23        requested_size: Option<u64>,
24        requested_position: Option<SeekFrom>,
25        cached_size: u64,
26    },
27    Copied(BufferFile),
28}
29
30impl CowState {
31    fn inner_mut(&mut self) -> &mut (dyn VirtualFile + Send + Sync) {
32        match self {
33            Self::ReadOnly(inner) => inner.as_mut(),
34            Self::Copying { .. } => panic!("Cannot access inner file while copying"),
35            Self::Copied(inner) => inner,
36        }
37    }
38}
39
40#[derive(Debug)]
41pub struct CopyOnWriteFile {
42    last_accessed: u64,
43    last_modified: u64,
44    created_time: u64,
45    state: CowState,
46}
47
48impl CopyOnWriteFile {
49    pub fn new(inner: Box<dyn VirtualFile + Send + Sync>) -> Self {
50        Self {
51            last_accessed: inner.last_accessed(),
52            last_modified: inner.last_modified(),
53            created_time: inner.created_time(),
54            state: CowState::ReadOnly(inner),
55        }
56    }
57
58    async fn copy(mut inner: Box<dyn VirtualFile + Send + Sync>) -> io::Result<BufferFile> {
59        let initial_position = inner.seek(SeekFrom::Current(0)).await?;
60        inner.seek(SeekFrom::Start(0)).await?;
61
62        let mut buffer = [0u8; 8192];
63        let mut buffer_file = BufferFile::default();
64        loop {
65            let read_bytes = inner.read_buf(&mut &mut buffer[..]).await?;
66            if read_bytes == 0 {
67                break;
68            }
69            buffer_file.data.write_all(&buffer[0..read_bytes])?;
70        }
71
72        buffer_file.seek(SeekFrom::Start(initial_position)).await?;
73
74        Ok(buffer_file)
75    }
76
77    fn poll_copy_progress(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
78        match self.state {
79            CowState::Copying {
80                ref mut future,
81                requested_size,
82                requested_position,
83                ..
84            } => match future.as_mut().poll(cx) {
85                Poll::Ready(Ok(mut buf)) => {
86                    if let Some(requested_size) = requested_size {
87                        buf.set_len(requested_size)?;
88                    }
89                    if let Some(requested_position) = requested_position {
90                        Pin::new(&mut buf).start_seek(requested_position)?;
91                    }
92                    self.state = CowState::Copied(buf);
93                    Poll::Ready(Ok(()))
94                }
95                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
96                Poll::Pending => Poll::Pending,
97            },
98            _ => Poll::Ready(Ok(())),
99        }
100    }
101
102    fn start_copy(&mut self) {
103        replace_with_or_abort(&mut self.state, |state| match state {
104            CowState::ReadOnly(inner) => CowState::Copying {
105                cached_size: inner.size(),
106                requested_size: None,
107                requested_position: None,
108                future: Box::pin(Self::copy(inner)),
109            },
110            state => state,
111        });
112    }
113
114    fn poll_copy_start_and_progress(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
115        self.start_copy();
116        self.poll_copy_progress(cx)
117    }
118}
119
120impl AsyncSeek for CopyOnWriteFile {
121    fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
122        match self.state {
123            CowState::Copying {
124                ref mut requested_position,
125                ..
126            } => {
127                *requested_position = Some(position);
128                Ok(())
129            }
130
131            _ => Pin::new(self.state.inner_mut()).start_seek(position),
132        }
133    }
134
135    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
136        match self.poll_copy_progress(cx) {
137            Poll::Ready(Ok(())) => {}
138            Poll::Pending => return Poll::Pending,
139            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
140        }
141
142        Pin::new(self.state.inner_mut()).poll_complete(cx)
143    }
144}
145
146impl AsyncWrite for CopyOnWriteFile {
147    fn poll_write(
148        mut self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &[u8],
151    ) -> Poll<io::Result<usize>> {
152        match self.poll_copy_start_and_progress(cx) {
153            Poll::Pending => return Poll::Pending,
154            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
155            Poll::Ready(Ok(())) => {}
156        }
157        Pin::new(self.state.inner_mut()).poll_write(cx, buf)
158    }
159
160    fn poll_write_vectored(
161        mut self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        bufs: &[io::IoSlice<'_>],
164    ) -> Poll<io::Result<usize>> {
165        match self.poll_copy_start_and_progress(cx) {
166            Poll::Pending => return Poll::Pending,
167            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
168            Poll::Ready(Ok(())) => {}
169        }
170        Pin::new(self.state.inner_mut()).poll_write_vectored(cx, bufs)
171    }
172
173    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174        match self.poll_copy_start_and_progress(cx) {
175            Poll::Ready(Ok(())) => {}
176            p => return p,
177        }
178        Pin::new(self.state.inner_mut()).poll_flush(cx)
179    }
180
181    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182        match self.poll_copy_start_and_progress(cx) {
183            Poll::Ready(Ok(())) => {}
184            p => return p,
185        }
186        Pin::new(self.state.inner_mut()).poll_shutdown(cx)
187    }
188}
189
190impl AsyncRead for CopyOnWriteFile {
191    fn poll_read(
192        mut self: Pin<&mut Self>,
193        cx: &mut Context<'_>,
194        buf: &mut tokio::io::ReadBuf<'_>,
195    ) -> Poll<io::Result<()>> {
196        match self.poll_copy_progress(cx) {
197            Poll::Ready(Ok(())) => {}
198            p => return p,
199        }
200        Pin::new(self.state.inner_mut()).poll_read(cx, buf)
201    }
202}
203
204impl VirtualFile for CopyOnWriteFile {
205    fn last_accessed(&self) -> u64 {
206        self.last_accessed
207    }
208
209    fn last_modified(&self) -> u64 {
210        self.last_modified
211    }
212
213    fn created_time(&self) -> u64 {
214        self.created_time
215    }
216
217    fn set_times(&mut self, atime: Option<u64>, mtime: Option<u64>) -> crate::Result<()> {
218        if let Some(atime) = atime {
219            self.last_accessed = atime;
220        }
221        if let Some(mtime) = mtime {
222            self.last_modified = mtime;
223        }
224
225        Ok(())
226    }
227
228    fn size(&self) -> u64 {
229        match &self.state {
230            CowState::ReadOnly(inner) => inner.size(),
231            CowState::Copying {
232                requested_size: Some(size),
233                ..
234            } => *size,
235            CowState::Copying { cached_size, .. } => *cached_size,
236            CowState::Copied(buffer_file) => buffer_file.size(),
237        }
238    }
239
240    fn set_len(&mut self, new_size: u64) -> crate::Result<()> {
241        match self.state {
242            CowState::ReadOnly(_) => {
243                self.start_copy();
244                let CowState::Copying {
245                    ref mut requested_size,
246                    ..
247                } = self.state
248                else {
249                    unreachable!()
250                };
251                *requested_size = Some(new_size);
252            }
253
254            CowState::Copying {
255                ref mut requested_size,
256                ..
257            } => {
258                *requested_size = Some(new_size);
259            }
260
261            CowState::Copied(ref mut buf) => {
262                buf.set_len(new_size)?;
263            }
264        }
265
266        Ok(())
267    }
268
269    fn unlink(&mut self) -> crate::Result<()> {
270        // TODO: one can imagine interrupting an in-progress copy here
271        self.set_len(0)
272    }
273
274    fn poll_read_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
275        match self.poll_copy_progress(cx) {
276            Poll::Pending => return Poll::Pending,
277            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
278            Poll::Ready(Ok(())) => {}
279        }
280        Pin::new(self.state.inner_mut()).poll_read_ready(cx)
281    }
282
283    fn poll_write_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
284        self.poll_copy_progress(cx).map_ok(|_| 8192)
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    // Deliberately one long scenario rather than many small tests: it walks a
    // single file through every state and re-checks each behavior that was
    // previously known to be wrong in this implementation.
    #[tokio::test]
    async fn cow_file_works() {
        // 16385 bytes = two full 8192-byte copy chunks plus one extra byte.
        let contents: Vec<u8> = (0..16385).map(|i| i as u8).collect();
        let source = BufferFile {
            data: Cursor::new(contents),
        };
        let mut file = CopyOnWriteFile::new(Box::new(source));

        assert!(matches!(file.state, CowState::ReadOnly(_)));
        assert_eq!(file.size(), 16385);
        assert_ne!(file.created_time(), 0);
        assert_ne!(file.last_accessed(), 0);
        assert_ne!(file.last_modified(), 0);

        // Reading and querying the position must not leave the read-only state.
        let mut chunk = [0u8; 4];
        let read = file.read_exact(&mut chunk[..]).await.unwrap();
        assert_eq!(read, 4);
        assert_eq!(chunk, [0, 1, 2, 3]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 4);
        assert!(matches!(file.state, CowState::ReadOnly { .. }));

        // After this call, the file will "start" copying, but the actual
        // future won't be polled until we try to read or write.
        // Until then the cached source length is what `size` reports.
        file.start_copy();
        assert!(matches!(file.state, CowState::Copying { .. }));
        assert_eq!(file.size(), 16385);

        // A length requested while copying is reported immediately, even
        // though it has not been applied to any buffer yet.
        file.set_len(16400).unwrap();
        assert!(matches!(file.state, CowState::Copying { .. }));
        assert_eq!(file.size(), 16400);

        // Now try to read from the file, which will drive the copy to
        // completion; the read position must carry over into the copy.
        let read = file.read_exact(&mut chunk[..]).await.unwrap();
        assert!(matches!(file.state, CowState::Copied { .. }));
        assert_eq!(read, 4);
        assert_eq!(chunk, [4, 5, 6, 7]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 8);
        assert_eq!(file.size(), 16400);

        // Read across the end of the original data.
        file.seek(SeekFrom::Start(16383)).await.unwrap();
        let read = file.read_exact(&mut chunk[..]).await.unwrap();
        assert_eq!(read, 4);
        // set_len should have filled the rest with zeroes
        assert_eq!(chunk, [(16383 % 256) as u8, (16384 % 256) as u8, 0, 0]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 16387);
        assert!(matches!(file.state, CowState::Copied { .. }));
    }
}