1use derive_more::Debug;
5use replace_with::replace_with_or_abort;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{
9 future::Future,
10 io::{self, *},
11};
12
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite};
14
15use crate::{BufferFile, VirtualFile};
16
/// Backing storage of a [`CopyOnWriteFile`].
///
/// The code in this file only ever moves forward through the variants:
/// `ReadOnly` -> `Copying` (in `start_copy`) -> `Copied` (in
/// `poll_copy_progress`).
#[derive(Debug)]
enum CowState {
    /// The wrapped original file, used as-is until the copy is started.
    ReadOnly(Box<dyn VirtualFile + Send + Sync>),
    /// The original file is being read into memory by `future`.
    Copying {
        /// Future producing the in-memory copy; skipped in `Debug` output
        /// because a boxed `dyn Future` has no `Debug` impl.
        #[debug(skip)]
        future: Pin<Box<dyn Future<Output = io::Result<BufferFile>> + Send + Sync>>,
        /// Length requested via `set_len` while copying; applied to the copy
        /// once it completes.
        requested_size: Option<u64>,
        /// Seek requested via `start_seek` while copying; applied to the copy
        /// once it completes.
        requested_position: Option<SeekFrom>,
        /// Size of the original file at the moment the copy was started.
        cached_size: u64,
    },
    /// The private in-memory copy; all further I/O is served from it.
    Copied(BufferFile),
}
29
30impl CowState {
31 fn inner_mut(&mut self) -> &mut (dyn VirtualFile + Send + Sync) {
32 match self {
33 Self::ReadOnly(inner) => inner.as_mut(),
34 Self::Copying { .. } => panic!("Cannot access inner file while copying"),
35 Self::Copied(inner) => inner,
36 }
37 }
38}
39
/// A file that serves I/O from a wrapped file until an operation needs to
/// mutate it (e.g. a write or `set_len`), at which point the wrapped file is
/// read into a private in-memory [`BufferFile`] and all further I/O goes
/// there.
///
/// Timestamps are stored on the wrapper itself: `set_times` updates these
/// fields and never touches the wrapped file.
#[derive(Debug)]
pub struct CopyOnWriteFile {
    /// Value returned by `last_accessed`; initialised from the wrapped file.
    last_accessed: u64,
    /// Value returned by `last_modified`; initialised from the wrapped file.
    last_modified: u64,
    /// Value returned by `created_time`; initialised from the wrapped file.
    created_time: u64,
    /// Which file currently backs I/O (see [`CowState`]).
    state: CowState,
}
47
48impl CopyOnWriteFile {
49 pub fn new(inner: Box<dyn VirtualFile + Send + Sync>) -> Self {
50 Self {
51 last_accessed: inner.last_accessed(),
52 last_modified: inner.last_modified(),
53 created_time: inner.created_time(),
54 state: CowState::ReadOnly(inner),
55 }
56 }
57
58 async fn copy(mut inner: Box<dyn VirtualFile + Send + Sync>) -> io::Result<BufferFile> {
59 let initial_position = inner.seek(SeekFrom::Current(0)).await?;
60 inner.seek(SeekFrom::Start(0)).await?;
61
62 let mut buffer = [0u8; 8192];
63 let mut buffer_file = BufferFile::default();
64 loop {
65 let read_bytes = inner.read_buf(&mut &mut buffer[..]).await?;
66 if read_bytes == 0 {
67 break;
68 }
69 buffer_file.data.write_all(&buffer[0..read_bytes])?;
70 }
71
72 buffer_file.seek(SeekFrom::Start(initial_position)).await?;
73
74 Ok(buffer_file)
75 }
76
77 fn poll_copy_progress(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
78 match self.state {
79 CowState::Copying {
80 ref mut future,
81 requested_size,
82 requested_position,
83 ..
84 } => match future.as_mut().poll(cx) {
85 Poll::Ready(Ok(mut buf)) => {
86 if let Some(requested_size) = requested_size {
87 buf.set_len(requested_size)?;
88 }
89 if let Some(requested_position) = requested_position {
90 Pin::new(&mut buf).start_seek(requested_position)?;
91 }
92 self.state = CowState::Copied(buf);
93 Poll::Ready(Ok(()))
94 }
95 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
96 Poll::Pending => Poll::Pending,
97 },
98 _ => Poll::Ready(Ok(())),
99 }
100 }
101
102 fn start_copy(&mut self) {
103 replace_with_or_abort(&mut self.state, |state| match state {
104 CowState::ReadOnly(inner) => CowState::Copying {
105 cached_size: inner.size(),
106 requested_size: None,
107 requested_position: None,
108 future: Box::pin(Self::copy(inner)),
109 },
110 state => state,
111 });
112 }
113
114 fn poll_copy_start_and_progress(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
115 self.start_copy();
116 self.poll_copy_progress(cx)
117 }
118}
119
120impl AsyncSeek for CopyOnWriteFile {
121 fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
122 match self.state {
123 CowState::Copying {
124 ref mut requested_position,
125 ..
126 } => {
127 *requested_position = Some(position);
128 Ok(())
129 }
130
131 _ => Pin::new(self.state.inner_mut()).start_seek(position),
132 }
133 }
134
135 fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
136 match self.poll_copy_progress(cx) {
137 Poll::Ready(Ok(())) => {}
138 Poll::Pending => return Poll::Pending,
139 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
140 }
141
142 Pin::new(self.state.inner_mut()).poll_complete(cx)
143 }
144}
145
146impl AsyncWrite for CopyOnWriteFile {
147 fn poll_write(
148 mut self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &[u8],
151 ) -> Poll<io::Result<usize>> {
152 match self.poll_copy_start_and_progress(cx) {
153 Poll::Pending => return Poll::Pending,
154 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
155 Poll::Ready(Ok(())) => {}
156 }
157 Pin::new(self.state.inner_mut()).poll_write(cx, buf)
158 }
159
160 fn poll_write_vectored(
161 mut self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 bufs: &[io::IoSlice<'_>],
164 ) -> Poll<io::Result<usize>> {
165 match self.poll_copy_start_and_progress(cx) {
166 Poll::Pending => return Poll::Pending,
167 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
168 Poll::Ready(Ok(())) => {}
169 }
170 Pin::new(self.state.inner_mut()).poll_write_vectored(cx, bufs)
171 }
172
173 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174 match self.poll_copy_start_and_progress(cx) {
175 Poll::Ready(Ok(())) => {}
176 p => return p,
177 }
178 Pin::new(self.state.inner_mut()).poll_flush(cx)
179 }
180
181 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182 match self.poll_copy_start_and_progress(cx) {
183 Poll::Ready(Ok(())) => {}
184 p => return p,
185 }
186 Pin::new(self.state.inner_mut()).poll_shutdown(cx)
187 }
188}
189
190impl AsyncRead for CopyOnWriteFile {
191 fn poll_read(
192 mut self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 buf: &mut tokio::io::ReadBuf<'_>,
195 ) -> Poll<io::Result<()>> {
196 match self.poll_copy_progress(cx) {
197 Poll::Ready(Ok(())) => {}
198 p => return p,
199 }
200 Pin::new(self.state.inner_mut()).poll_read(cx, buf)
201 }
202}
203
204impl VirtualFile for CopyOnWriteFile {
205 fn last_accessed(&self) -> u64 {
206 self.last_accessed
207 }
208
209 fn last_modified(&self) -> u64 {
210 self.last_modified
211 }
212
213 fn created_time(&self) -> u64 {
214 self.created_time
215 }
216
217 fn set_times(&mut self, atime: Option<u64>, mtime: Option<u64>) -> crate::Result<()> {
218 if let Some(atime) = atime {
219 self.last_accessed = atime;
220 }
221 if let Some(mtime) = mtime {
222 self.last_modified = mtime;
223 }
224
225 Ok(())
226 }
227
228 fn size(&self) -> u64 {
229 match &self.state {
230 CowState::ReadOnly(inner) => inner.size(),
231 CowState::Copying {
232 requested_size: Some(size),
233 ..
234 } => *size,
235 CowState::Copying { cached_size, .. } => *cached_size,
236 CowState::Copied(buffer_file) => buffer_file.size(),
237 }
238 }
239
240 fn set_len(&mut self, new_size: u64) -> crate::Result<()> {
241 match self.state {
242 CowState::ReadOnly(_) => {
243 self.start_copy();
244 let CowState::Copying {
245 ref mut requested_size,
246 ..
247 } = self.state
248 else {
249 unreachable!()
250 };
251 *requested_size = Some(new_size);
252 }
253
254 CowState::Copying {
255 ref mut requested_size,
256 ..
257 } => {
258 *requested_size = Some(new_size);
259 }
260
261 CowState::Copied(ref mut buf) => {
262 buf.set_len(new_size)?;
263 }
264 }
265
266 Ok(())
267 }
268
269 fn unlink(&mut self) -> crate::Result<()> {
270 self.set_len(0)
272 }
273
274 fn poll_read_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
275 match self.poll_copy_progress(cx) {
276 Poll::Pending => return Poll::Pending,
277 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
278 Poll::Ready(Ok(())) => {}
279 }
280 Pin::new(self.state.inner_mut()).poll_read_ready(cx)
281 }
282
283 fn poll_write_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
284 self.poll_copy_progress(cx).map_ok(|_| 8192)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a file through ReadOnly -> Copying -> Copied and checks that
    /// reads, seeks and the reported size stay consistent at each step.
    #[tokio::test]
    async fn cow_file_works() {
        const LEN: usize = 16385;
        let data: Vec<u8> = (0..LEN).map(|i| i as u8).collect();
        let original = BufferFile {
            data: Cursor::new(data),
        };
        let mut file = CopyOnWriteFile::new(Box::new(original));

        // Freshly wrapped: nothing copied, metadata taken from the original.
        assert!(matches!(file.state, CowState::ReadOnly(_)));
        assert_eq!(file.size(), LEN as u64);
        assert_ne!(file.created_time(), 0);
        assert_ne!(file.last_accessed(), 0);
        assert_ne!(file.last_modified(), 0);

        // Reading does not trigger a copy.
        let mut chunk = [0u8; 4];
        let n = file.read_exact(&mut chunk).await.unwrap();
        assert_eq!(n, 4);
        assert_eq!(chunk, [0, 1, 2, 3]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 4);
        assert!(matches!(file.state, CowState::ReadOnly { .. }));

        // Copy started but not yet polled.
        file.start_copy();
        assert!(matches!(file.state, CowState::Copying { .. }));
        assert_eq!(file.size(), LEN as u64);

        // A resize during the copy is deferred but already visible via size().
        file.set_len(16400).unwrap();
        assert!(matches!(file.state, CowState::Copying { .. }));
        assert_eq!(file.size(), 16400);

        // The next read finishes the copy and continues from the same offset.
        let n = file.read_exact(&mut chunk).await.unwrap();
        assert!(matches!(file.state, CowState::Copied { .. }));
        assert_eq!(n, 4);
        assert_eq!(chunk, [4, 5, 6, 7]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 8);
        assert_eq!(file.size(), 16400);

        // Bytes past the original end are zero-filled by the deferred resize.
        file.seek(SeekFrom::Start(16383)).await.unwrap();
        let n = file.read_exact(&mut chunk).await.unwrap();
        assert_eq!(n, 4);
        assert_eq!(chunk, [(16383 % 256) as u8, (16384 % 256) as u8, 0, 0]);
        assert_eq!(file.seek(SeekFrom::Current(0)).await.unwrap(), 16387);
        assert!(matches!(file.state, CowState::Copied { .. }));
    }
}