virtual_fs/mem_fs/
stdio.rs

1//! This module contains the standard I/O streams, i.e. “emulated”
2//! `stdin`, `stdout` and `stderr`.
3
4use crate::{FsError, Result, VirtualFile};
5use std::io::{self, Write};
6
/// Generates an in-memory standard stream type together with its
/// `VirtualFile`, `AsyncSeek`, `AsyncRead` and `AsyncWrite` implementations.
///
/// Invoke it as `Name { readable: <bool expr>, writable: <bool expr> }`.
/// The two flags decide whether `poll_read`, `poll_write` and `poll_flush`
/// operate on the internal buffer or fail with `PermissionDenied`.
/// Seeking always fails. The `impl Async… for Name` arms are internal
/// helpers expanded by the main arm.
macro_rules! impl_virtualfile_on_std_streams {
    // Main arm: declares the type and wires up every trait implementation.
    ($name:ident { readable: $readable:expr_2021, writable: $writable:expr_2021 $(,)* }) => {
        /// An in-memory stand-in for the standard I/O stream of the same
        /// name, usable wherever a `VirtualFile` is expected.
        #[derive(Debug, Default)]
        pub struct $name {
            /// Backing storage: reads consume bytes from the front,
            /// writes append bytes at the end.
            pub buf: Vec<u8>,
        }

        impl $name {
            /// Whether reading from this stream is permitted.
            const fn is_readable(&self) -> bool {
                $readable
            }

            /// Whether writing to (and flushing) this stream is permitted.
            const fn is_writable(&self) -> bool {
                $writable
            }
        }

        #[async_trait::async_trait]
        impl VirtualFile for $name {
            // No timestamps or size are tracked for the emulated streams,
            // so every metadata getter reports zero.
            fn last_accessed(&self) -> u64 {
                0
            }

            fn last_modified(&self) -> u64 {
                0
            }

            fn created_time(&self) -> u64 {
                0
            }

            fn size(&self) -> u64 {
                0
            }

            /// Resizing is never allowed; always fails with
            /// `FsError::PermissionDenied`.
            fn set_len(&mut self, _new_size: u64) -> Result<()> {
                Err(FsError::PermissionDenied)
            }

            /// Does nothing and reports success.
            fn unlink(&mut self) -> Result<()> {
                Ok(())
            }

            /// Immediately ready; reports the number of bytes currently
            /// held in the buffer.
            fn poll_read_ready(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<usize>> {
                let buffered = self.buf.len();

                std::task::Poll::Ready(Ok(buffered))
            }

            /// Immediately ready; always reports a fixed capacity of
            /// 8192 bytes.
            fn poll_write_ready(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::task::Poll::Ready(Ok(8192))
            }
        }

        impl_virtualfile_on_std_streams!(impl AsyncSeek for $name);
        impl_virtualfile_on_std_streams!(impl AsyncRead for $name);
        impl_virtualfile_on_std_streams!(impl AsyncWrite for $name);
    };

    // Seeking is rejected unconditionally, both when starting a seek and
    // when polling for its completion.
    (impl AsyncSeek for $name:ident) => {
        impl tokio::io::AsyncSeek for $name {
            fn start_seek(
                self: std::pin::Pin<&mut Self>,
                _position: io::SeekFrom,
            ) -> io::Result<()> {
                let denied = io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    concat!("cannot seek `", stringify!($name), "`"),
                );

                Err(denied)
            }

            fn poll_complete(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<io::Result<u64>> {
                let denied = io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    concat!("cannot seek `", stringify!($name), "`"),
                );

                std::task::Poll::Ready(Err(denied))
            }
        }
    };

    // Reading hands out bytes from the front of the buffer and removes
    // them, or fails when the stream is not readable.
    (impl AsyncRead for $name:ident) => {
        impl tokio::io::AsyncRead for $name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<io::Result<()>> {
                if !self.is_readable() {
                    return std::task::Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::PermissionDenied,
                        concat!("cannot read from `", stringify!($name), "`"),
                    )));
                }

                // Copy as much as both the caller's buffer and our own
                // storage allow.
                let count = self.buf.len().min(buf.remaining());
                buf.put_slice(&self.buf[..count]);

                // Drop the bytes that were just handed out.
                self.buf.drain(..count);

                std::task::Poll::Ready(Ok(()))
            }
        }
    };

    // Writing appends to the buffer, or fails when the stream is not
    // writable. Shutdown always succeeds.
    (impl AsyncWrite for $name:ident) => {
        impl tokio::io::AsyncWrite for $name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<io::Result<usize>> {
                if !self.is_writable() {
                    return std::task::Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::PermissionDenied,
                        concat!("cannot write to `", stringify!($name), "`"),
                    )));
                }

                let written = self.buf.write(buf);

                std::task::Poll::Ready(written)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<io::Result<()>> {
                if !self.is_writable() {
                    return std::task::Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::PermissionDenied,
                        concat!("cannot flush `", stringify!($name), "`"),
                    )));
                }

                let flushed = self.buf.flush();

                std::task::Poll::Ready(flushed)
            }

            fn poll_shutdown(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<io::Result<()>> {
                std::task::Poll::Ready(Ok(()))
            }
        }
    };
}
163
164impl_virtualfile_on_std_streams!(Stdin {
165    readable: true,
166    writable: false,
167});
168impl_virtualfile_on_std_streams!(Stdout {
169    readable: false,
170    writable: true,
171});
172impl_virtualfile_on_std_streams!(Stderr {
173    readable: false,
174    writable: true,
175});
176
#[cfg(test)]
mod test_read_write_seek {
    use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};

    use crate::mem_fs::*;
    use std::io;

    /// The six bytes `foobar`, used to pre-fill a stream buffer.
    fn sample_bytes() -> Vec<u8> {
        b"foobar".to_vec()
    }

    // ----- stdin: readable, not writable, not seekable -----

    #[tokio::test]
    async fn test_read_stdin() {
        let mut stdin = Stdin {
            buf: sample_bytes(),
        };

        // A 3-byte destination receives the first half of the buffer.
        let mut head = [0; 3];
        let first = stdin.read(&mut head).await;

        assert!(matches!(first, Ok(3)), "reading bytes from `stdin`");
        assert_eq!(head, *b"foo", "checking the bytes read from `stdin`");

        // Reading to the end yields whatever has not been consumed yet.
        let mut tail = Vec::new();
        let second = stdin.read_to_end(&mut tail).await;

        assert!(matches!(second, Ok(3)), "reading bytes again from `stdin`");
        assert_eq!(tail, b"bar", "checking the bytes read from `stdin`");

        // Nothing is left, so an exact read of one byte has to fail.
        let mut one = [0; 1];
        let third = stdin.read_exact(&mut one).await;

        assert!(
            third.is_err(),
            "cannot read bytes again because `stdin` has fully consumed"
        );
    }

    #[tokio::test]
    async fn test_write_stdin() {
        let mut stdin = Stdin { buf: Vec::new() };
        let outcome = stdin.write(b"bazqux").await;

        assert!(outcome.is_err(), "cannot write into `stdin`");
    }

    #[tokio::test]
    async fn test_seek_stdin() {
        let mut stdin = Stdin {
            buf: sample_bytes(),
        };
        let outcome = stdin.seek(io::SeekFrom::End(0)).await;

        assert!(outcome.is_err(), "cannot seek `stdin`");
    }

    // ----- stdout: writable, not readable, not seekable -----

    #[tokio::test]
    async fn test_read_stdout() {
        let mut stdout = Stdout {
            buf: sample_bytes(),
        };
        let mut text = String::new();
        let outcome = stdout.read_to_string(&mut text).await;

        assert!(outcome.is_err(), "cannot read from `stdout`");
    }

    #[tokio::test]
    async fn test_write_stdout() {
        let mut stdout = Stdout { buf: Vec::new() };

        let first = stdout.write(b"baz").await;
        assert!(matches!(first, Ok(3)), "writing into `stdout`");

        let second = stdout.write(b"qux").await;
        assert!(matches!(second, Ok(3)), "writing again into `stdout`");

        // Both writes must have been appended in order.
        assert_eq!(stdout.buf, b"bazqux", "checking the content of `stdout`");
    }

    #[tokio::test]
    async fn test_seek_stdout() {
        let mut stdout = Stdout {
            buf: sample_bytes(),
        };
        let outcome = stdout.seek(io::SeekFrom::End(0)).await;

        assert!(outcome.is_err(), "cannot seek `stdout`");
    }

    // ----- stderr: writable, not readable, not seekable -----

    #[tokio::test]
    async fn test_read_stderr() {
        let mut stderr = Stderr {
            buf: sample_bytes(),
        };
        let mut text = String::new();
        let outcome = stderr.read_to_string(&mut text).await;

        assert!(outcome.is_err(), "cannot read from `stderr`");
    }

    #[tokio::test]
    async fn test_write_stderr() {
        let mut stderr = Stderr { buf: Vec::new() };

        let first = stderr.write(b"baz").await;
        assert!(matches!(first, Ok(3)), "writing into `stderr`");

        let second = stderr.write(b"qux").await;
        assert!(matches!(second, Ok(3)), "writing again into `stderr`");

        // Both writes must have been appended in order.
        assert_eq!(stderr.buf, b"bazqux", "checking the content of `stderr`");
    }

    #[tokio::test]
    async fn test_seek_stderr() {
        let mut stderr = Stderr {
            buf: sample_bytes(),
        };
        let outcome = stderr.seek(io::SeekFrom::End(0)).await;

        assert!(outcome.is_err(), "cannot seek `stderr`");
    }
}