// wasmer_wasix/fs/inode_guard.rs

1use std::{
2    future::Future,
3    io::{IoSlice, SeekFrom},
4    ops::{Deref, DerefMut},
5    pin::Pin,
6    sync::{Arc, RwLock},
7    task::{Context, Poll},
8};
9
10use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite};
11use virtual_fs::{FsError, Pipe, PipeRx, PipeTx, VirtualFile};
12use wasmer_wasix_types::{
13    types::Eventtype,
14    wasi::{self, EpollType},
15    wasi::{Errno, EventFdReadwrite, Eventrwflags, Subscription},
16};
17
18use super::{InodeGuard, Kind, notification::NotificationInner};
19use crate::{
20    net::socket::{InodeSocketInner, InodeSocketKind},
21    state::{PollEvent, PollEventSet, WasiState, iterate_poll_events},
22    syscalls::{EventResult, EventResultType, map_io_err},
23    utils::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard},
24};
25
26#[derive(Debug, Clone)]
27pub(crate) enum InodeValFilePollGuardMode {
28    File(Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>),
29    EventNotifications(Arc<NotificationInner>),
30    Socket { inner: Arc<InodeSocketInner> },
31    PipeRx { rx: Arc<RwLock<Box<PipeRx>>> },
32    PipeTx { tx: Arc<RwLock<Box<PipeTx>>> },
33    DuplexPipe { pipe: Arc<RwLock<Box<Pipe>>> },
34}
35
36pub struct InodeValFilePollGuard {
37    pub(crate) fd: u32,
38    pub(crate) peb: PollEventSet,
39    pub(crate) subscription: Subscription,
40    pub(crate) mode: InodeValFilePollGuardMode,
41}
42
43impl InodeValFilePollGuard {
44    pub(crate) fn new(
45        fd: u32,
46        peb: PollEventSet,
47        subscription: Subscription,
48        guard: &Kind,
49    ) -> Option<Self> {
50        let mode = match guard {
51            Kind::EventNotifications { inner, .. } => {
52                InodeValFilePollGuardMode::EventNotifications(inner.clone())
53            }
54            Kind::Socket { socket, .. } => InodeValFilePollGuardMode::Socket {
55                inner: socket.inner.clone(),
56            },
57            Kind::File {
58                handle: Some(handle),
59                ..
60            } => InodeValFilePollGuardMode::File(handle.clone()),
61            Kind::PipeRx { rx } => InodeValFilePollGuardMode::PipeRx {
62                rx: Arc::new(RwLock::new(Box::new(rx.clone()))),
63            },
64            Kind::PipeTx { tx } => InodeValFilePollGuardMode::PipeTx {
65                tx: Arc::new(RwLock::new(Box::new(tx.clone()))),
66            },
67            Kind::DuplexPipe { pipe } => InodeValFilePollGuardMode::DuplexPipe {
68                pipe: Arc::new(RwLock::new(Box::new(pipe.clone()))),
69            },
70            _ => {
71                return None;
72            }
73        };
74        Some(Self {
75            fd,
76            mode,
77            peb,
78            subscription,
79        })
80    }
81}
82
83impl std::fmt::Debug for InodeValFilePollGuard {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match &self.mode {
86            InodeValFilePollGuardMode::File(..) => {
87                write!(f, "guard-file(fd={}, peb={})", self.fd, self.peb)
88            }
89            InodeValFilePollGuardMode::EventNotifications { .. } => {
90                write!(f, "guard-notifications(fd={}, peb={})", self.fd, self.peb)
91            }
92            InodeValFilePollGuardMode::Socket { inner } => {
93                let inner = inner.protected.read().unwrap();
94                match &inner.kind {
95                    InodeSocketKind::TcpListener { .. } => {
96                        write!(f, "guard-tcp-listener(fd={}, peb={})", self.fd, self.peb)
97                    }
98                    InodeSocketKind::TcpStream { socket, .. } => {
99                        if socket.is_closed() {
100                            write!(
101                                f,
102                                "guard-tcp-stream (closed, fd={}, peb={})",
103                                self.fd, self.peb
104                            )
105                        } else {
106                            write!(f, "guard-tcp-stream(fd={}, peb={})", self.fd, self.peb)
107                        }
108                    }
109                    InodeSocketKind::UdpSocket { .. } => {
110                        write!(f, "guard-udp-socket(fd={}, peb={})", self.fd, self.peb)
111                    }
112                    InodeSocketKind::Raw(..) => {
113                        write!(f, "guard-raw-socket(fd={}, peb={})", self.fd, self.peb)
114                    }
115                    _ => write!(f, "guard-socket(fd={}), peb={})", self.fd, self.peb),
116                }
117            }
118            InodeValFilePollGuardMode::PipeRx { .. } => {
119                write!(f, "guard-pipe-rx(...)")
120            }
121            InodeValFilePollGuardMode::PipeTx { .. } => {
122                write!(f, "guard-pipe-tx(...)")
123            }
124            InodeValFilePollGuardMode::DuplexPipe { .. } => {
125                write!(f, "guard-duplex-pipe(...)")
126            }
127        }
128    }
129}
130
131#[derive(Debug)]
132pub struct InodeValFilePollGuardJoin {
133    mode: InodeValFilePollGuardMode,
134    fd: u32,
135    peb: PollEventSet,
136    subscription: Subscription,
137}
138
139impl InodeValFilePollGuardJoin {
140    pub(crate) fn new(guard: InodeValFilePollGuard) -> Self {
141        Self {
142            mode: guard.mode,
143            fd: guard.fd,
144            peb: guard.peb,
145            subscription: guard.subscription,
146        }
147    }
148    pub(crate) fn fd(&self) -> u32 {
149        self.fd
150    }
151    pub(crate) fn peb(&self) -> PollEventSet {
152        self.peb
153    }
154}
155
156pub const POLL_GUARD_MAX_RET: usize = 4;
157
158impl Future for InodeValFilePollGuardJoin {
159    type Output = heapless::Vec<(EventResult, EpollType), POLL_GUARD_MAX_RET>;
160
161    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
162        // Otherwise we need to register for the event
163        let waker = cx.waker();
164        let mut has_read = false;
165        let mut has_write = false;
166        let mut has_close = false;
167        let mut has_hangup = false;
168
169        let mut ret = heapless::Vec::new();
170        for in_event in iterate_poll_events(self.peb) {
171            match in_event {
172                PollEvent::PollIn => {
173                    has_read = true;
174                }
175                PollEvent::PollOut => {
176                    has_write = true;
177                }
178                PollEvent::PollHangUp => {
179                    has_hangup = true;
180                    has_close = true;
181                }
182                PollEvent::PollError | PollEvent::PollInvalid => {
183                    if !has_hangup {
184                        has_close = true;
185                    }
186                }
187            }
188        }
189        if has_read {
190            let poll_result = match &mut self.mode {
191                InodeValFilePollGuardMode::File(file) => {
192                    let mut guard = file.write().unwrap();
193                    let file = Pin::new(guard.as_mut());
194                    file.poll_read_ready(cx)
195                }
196                InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok),
197                InodeValFilePollGuardMode::Socket { inner } => {
198                    let mut guard = inner.protected.write().unwrap();
199                    guard.poll_read_ready(cx)
200                }
201                InodeValFilePollGuardMode::PipeRx { rx } => {
202                    let mut guard = rx.write().unwrap();
203                    let rx = Pin::new(guard.as_mut());
204                    rx.poll_read_ready(cx)
205                }
206                InodeValFilePollGuardMode::PipeTx { .. } => Poll::Ready(Err(std::io::Error::new(
207                    std::io::ErrorKind::InvalidInput,
208                    "Cannot read from a pipe write end",
209                ))),
210                InodeValFilePollGuardMode::DuplexPipe { pipe } => {
211                    let mut guard = pipe.write().unwrap();
212                    let pipe = Pin::new(guard.as_mut());
213                    pipe.poll_read_ready(cx)
214                }
215            };
216            match poll_result {
217                Poll::Ready(Err(err)) if has_close && is_err_closed(&err) => {
218                    let inner = match self.subscription.type_ {
219                        Eventtype::FdRead | Eventtype::FdWrite => {
220                            Some(EventResultType::Fd(EventFdReadwrite {
221                                nbytes: 0,
222                                flags: if has_hangup {
223                                    Eventrwflags::FD_READWRITE_HANGUP
224                                } else {
225                                    Eventrwflags::empty()
226                                },
227                            }))
228                        }
229                        Eventtype::Clock => Some(EventResultType::Clock(0)),
230                        Eventtype::Unknown => None,
231                    };
232                    if let Some(inner) = inner {
233                        ret.push((
234                            EventResult {
235                                userdata: self.subscription.userdata,
236                                error: Errno::Success,
237                                type_: self.subscription.type_,
238                                inner,
239                            },
240                            EpollType::EPOLLHUP,
241                        ))
242                        .ok();
243                    }
244                }
245                Poll::Ready(bytes_available) => {
246                    let mut error = Errno::Success;
247                    let bytes_available = match bytes_available {
248                        Ok(a) => a,
249                        Err(e) => {
250                            error = map_io_err(e);
251                            0
252                        }
253                    };
254                    let inner = match self.subscription.type_ {
255                        Eventtype::FdRead | Eventtype::FdWrite => {
256                            Some(EventResultType::Fd(EventFdReadwrite {
257                                nbytes: bytes_available as u64,
258                                flags: if bytes_available == 0 {
259                                    Eventrwflags::FD_READWRITE_HANGUP
260                                } else {
261                                    Eventrwflags::empty()
262                                },
263                            }))
264                        }
265                        Eventtype::Clock => Some(EventResultType::Clock(0)),
266                        Eventtype::Unknown => None,
267                    };
268                    if let Some(inner) = inner {
269                        ret.push((
270                            EventResult {
271                                userdata: self.subscription.userdata,
272                                error,
273                                type_: self.subscription.type_,
274                                inner,
275                            },
276                            if error == Errno::Success {
277                                EpollType::EPOLLIN
278                            } else {
279                                EpollType::EPOLLERR
280                            },
281                        ))
282                        .ok();
283                    }
284                }
285                Poll::Pending => {}
286            };
287        }
288        if has_write {
289            let poll_result = match &mut self.mode {
290                InodeValFilePollGuardMode::File(file) => {
291                    let mut guard = file.write().unwrap();
292                    let file = Pin::new(guard.as_mut());
293                    file.poll_write_ready(cx)
294                }
295                InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok),
296                InodeValFilePollGuardMode::Socket { inner } => {
297                    let mut guard = inner.protected.write().unwrap();
298                    guard.poll_write_ready(cx)
299                }
300                InodeValFilePollGuardMode::PipeRx { .. } => Poll::Ready(Err(std::io::Error::new(
301                    std::io::ErrorKind::InvalidInput,
302                    "Cannot write to a pipe read end",
303                ))),
304                InodeValFilePollGuardMode::PipeTx { tx } => {
305                    let mut guard = tx.write().unwrap();
306                    let tx = Pin::new(guard.as_mut());
307                    tx.poll_write_ready()
308                }
309                InodeValFilePollGuardMode::DuplexPipe { pipe } => {
310                    let mut guard = pipe.write().unwrap();
311                    let pipe = Pin::new(guard.as_mut());
312                    pipe.poll_write_ready(cx)
313                }
314            };
315            match poll_result {
316                Poll::Ready(Err(err)) if has_close && is_err_closed(&err) => {
317                    let inner = match self.subscription.type_ {
318                        Eventtype::FdRead | Eventtype::FdWrite => {
319                            Some(EventResultType::Fd(EventFdReadwrite {
320                                nbytes: 0,
321                                flags: if has_hangup {
322                                    Eventrwflags::FD_READWRITE_HANGUP
323                                } else {
324                                    Eventrwflags::empty()
325                                },
326                            }))
327                        }
328                        Eventtype::Clock => Some(EventResultType::Clock(0)),
329                        Eventtype::Unknown => None,
330                    };
331                    if let Some(inner) = inner {
332                        ret.push((
333                            EventResult {
334                                userdata: self.subscription.userdata,
335                                error: Errno::Success,
336                                type_: self.subscription.type_,
337                                inner,
338                            },
339                            EpollType::EPOLLHUP,
340                        ))
341                        .ok();
342                    }
343                }
344                Poll::Ready(bytes_available) => {
345                    let mut error = Errno::Success;
346                    let bytes_available = match bytes_available {
347                        Ok(a) => a,
348                        Err(e) => {
349                            error = map_io_err(e);
350                            0
351                        }
352                    };
353                    let inner = match self.subscription.type_ {
354                        Eventtype::FdRead | Eventtype::FdWrite => {
355                            Some(EventResultType::Fd(EventFdReadwrite {
356                                nbytes: bytes_available as u64,
357                                flags: if bytes_available == 0 {
358                                    Eventrwflags::FD_READWRITE_HANGUP
359                                } else {
360                                    Eventrwflags::empty()
361                                },
362                            }))
363                        }
364                        Eventtype::Clock => Some(EventResultType::Clock(0)),
365                        Eventtype::Unknown => None,
366                    };
367                    if let Some(inner) = inner {
368                        ret.push((
369                            EventResult {
370                                userdata: self.subscription.userdata,
371                                error,
372                                type_: self.subscription.type_,
373                                inner,
374                            },
375                            if error == Errno::Success {
376                                EpollType::EPOLLOUT
377                            } else {
378                                EpollType::EPOLLERR
379                            },
380                        ))
381                        .ok();
382                    }
383                }
384                Poll::Pending => {}
385            };
386        }
387        if !ret.is_empty() {
388            return Poll::Ready(ret);
389        }
390        Poll::Pending
391    }
392}
393
394#[derive(Debug)]
395pub(crate) struct InodeValFileReadGuard {
396    guard: OwnedRwLockReadGuard<Box<dyn VirtualFile + Send + Sync + 'static>>,
397}
398
399impl InodeValFileReadGuard {
400    pub(crate) fn new(file: &Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>) -> Self {
401        Self {
402            guard: crate::utils::read_owned(file).unwrap(),
403        }
404    }
405}
406
407impl InodeValFileReadGuard {
408    pub fn into_poll_guard(
409        self,
410        fd: u32,
411        peb: PollEventSet,
412        subscription: Subscription,
413    ) -> InodeValFilePollGuard {
414        InodeValFilePollGuard {
415            fd,
416            peb,
417            subscription,
418            mode: InodeValFilePollGuardMode::File(self.guard.into_inner()),
419        }
420    }
421}
422
423impl Deref for InodeValFileReadGuard {
424    type Target = dyn VirtualFile + Send + Sync + 'static;
425    fn deref(&self) -> &Self::Target {
426        self.guard.deref().deref()
427    }
428}
429
430#[derive(Debug)]
431pub struct InodeValFileWriteGuard {
432    guard: OwnedRwLockWriteGuard<Box<dyn VirtualFile + Send + Sync + 'static>>,
433}
434
435impl InodeValFileWriteGuard {
436    pub(crate) fn new(file: &Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>) -> Self {
437        Self {
438            guard: crate::utils::write_owned(file).unwrap(),
439        }
440    }
441    pub(crate) fn swap(
442        &mut self,
443        mut file: Box<dyn VirtualFile + Send + Sync + 'static>,
444    ) -> Box<dyn VirtualFile + Send + Sync + 'static> {
445        std::mem::swap(self.guard.deref_mut(), &mut file);
446        file
447    }
448}
449
450impl Deref for InodeValFileWriteGuard {
451    type Target = dyn VirtualFile + Send + Sync + 'static;
452    fn deref(&self) -> &Self::Target {
453        self.guard.deref().deref()
454    }
455}
456impl DerefMut for InodeValFileWriteGuard {
457    fn deref_mut(&mut self) -> &mut Self::Target {
458        self.guard.deref_mut().deref_mut()
459    }
460}
461
462#[derive(Debug)]
463pub(crate) struct WasiStateFileGuard {
464    inode: InodeGuard,
465}
466
467impl WasiStateFileGuard {
468    pub fn new(state: &WasiState, fd: wasi::Fd) -> Result<Option<Self>, FsError> {
469        let fd_map = state.fs.fd_map.read().unwrap();
470        if let Some(fd) = fd_map.get(fd) {
471            Ok(Some(Self {
472                inode: fd.inode.clone(),
473            }))
474        } else {
475            Ok(None)
476        }
477    }
478
479    pub fn lock_read(&self) -> Option<InodeValFileReadGuard> {
480        let guard = self.inode.read();
481        if let Kind::File { handle, .. } = guard.deref() {
482            handle.as_ref().map(InodeValFileReadGuard::new)
483        } else {
484            // Our public API should ensure that this is not possible
485            unreachable!("Non-file found in standard device location")
486        }
487    }
488
489    pub fn lock_write(&self) -> Option<InodeValFileWriteGuard> {
490        let guard = self.inode.read();
491        if let Kind::File { handle, .. } = guard.deref() {
492            handle.as_ref().map(InodeValFileWriteGuard::new)
493        } else {
494            // Our public API should ensure that this is not possible
495            unreachable!("Non-file found in standard device location")
496        }
497    }
498}
499
500impl VirtualFile for WasiStateFileGuard {
501    fn last_accessed(&self) -> u64 {
502        let guard = self.lock_read();
503        if let Some(file) = guard.as_ref() {
504            file.last_accessed()
505        } else {
506            0
507        }
508    }
509
510    fn last_modified(&self) -> u64 {
511        let guard = self.lock_read();
512        if let Some(file) = guard.as_ref() {
513            file.last_modified()
514        } else {
515            0
516        }
517    }
518
519    fn created_time(&self) -> u64 {
520        let guard = self.lock_read();
521        if let Some(file) = guard.as_ref() {
522            file.created_time()
523        } else {
524            0
525        }
526    }
527
528    fn set_times(
529        &mut self,
530        atime: Option<u64>,
531        mtime: Option<u64>,
532    ) -> Result<(), virtual_fs::FsError> {
533        let mut guard = self.lock_write();
534        if let Some(file) = guard.as_mut() {
535            file.set_times(atime, mtime)
536        } else {
537            Err(crate::FsError::Lock)
538        }
539    }
540
541    fn size(&self) -> u64 {
542        let guard = self.lock_read();
543        if let Some(file) = guard.as_ref() {
544            file.size()
545        } else {
546            0
547        }
548    }
549
550    fn set_len(&mut self, new_size: u64) -> Result<(), FsError> {
551        let mut guard = self.lock_write();
552        if let Some(file) = guard.as_mut() {
553            file.set_len(new_size)
554        } else {
555            Err(FsError::IOError)
556        }
557    }
558
559    fn unlink(&mut self) -> Result<(), FsError> {
560        let mut guard = self.lock_write();
561        if let Some(file) = guard.as_mut() {
562            file.unlink()
563        } else {
564            Err(FsError::IOError)
565        }
566    }
567
568    fn is_open(&self) -> bool {
569        let guard = self.lock_read();
570        if let Some(file) = guard.as_ref() {
571            file.is_open()
572        } else {
573            false
574        }
575    }
576
577    fn poll_read_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<usize>> {
578        let mut guard = self.lock_write();
579        if let Some(file) = guard.as_mut() {
580            let file = Pin::new(file.deref_mut());
581            file.poll_read_ready(cx)
582        } else {
583            Poll::Ready(Ok(0))
584        }
585    }
586
587    fn poll_write_ready(
588        self: Pin<&mut Self>,
589        cx: &mut Context<'_>,
590    ) -> Poll<std::io::Result<usize>> {
591        let mut guard = self.lock_write();
592        if let Some(file) = guard.as_mut() {
593            let file = Pin::new(file.deref_mut());
594            file.poll_write_ready(cx)
595        } else {
596            Poll::Ready(Ok(0))
597        }
598    }
599}
600
601impl AsyncSeek for WasiStateFileGuard {
602    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
603        let mut guard = self.lock_write();
604        if let Some(guard) = guard.as_mut() {
605            let file = Pin::new(guard.deref_mut());
606            file.start_seek(position)
607        } else {
608            Err(std::io::ErrorKind::Unsupported.into())
609        }
610    }
611    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
612        let mut guard = self.lock_write();
613        if let Some(guard) = guard.as_mut() {
614            let file = Pin::new(guard.deref_mut());
615            file.poll_complete(cx)
616        } else {
617            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
618        }
619    }
620}
621
622impl AsyncWrite for WasiStateFileGuard {
623    fn poll_write(
624        self: Pin<&mut Self>,
625        cx: &mut Context<'_>,
626        buf: &[u8],
627    ) -> Poll<std::io::Result<usize>> {
628        let mut guard = self.lock_write();
629        if let Some(guard) = guard.as_mut() {
630            let file = Pin::new(guard.deref_mut());
631            file.poll_write(cx, buf)
632        } else {
633            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
634        }
635    }
636    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
637        let mut guard = self.lock_write();
638        if let Some(guard) = guard.as_mut() {
639            let file = Pin::new(guard.deref_mut());
640            file.poll_flush(cx)
641        } else {
642            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
643        }
644    }
645    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
646        let mut guard = self.lock_write();
647        if let Some(guard) = guard.as_mut() {
648            let file = Pin::new(guard.deref_mut());
649            file.poll_shutdown(cx)
650        } else {
651            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
652        }
653    }
654    fn poll_write_vectored(
655        self: Pin<&mut Self>,
656        cx: &mut Context<'_>,
657        bufs: &[IoSlice<'_>],
658    ) -> Poll<std::io::Result<usize>> {
659        let mut guard = self.lock_write();
660        if let Some(guard) = guard.as_mut() {
661            let file = Pin::new(guard.deref_mut());
662            file.poll_write_vectored(cx, bufs)
663        } else {
664            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
665        }
666    }
667    fn is_write_vectored(&self) -> bool {
668        let mut guard = self.lock_write();
669        if let Some(guard) = guard.as_mut() {
670            let file = Pin::new(guard.deref_mut());
671            file.is_write_vectored()
672        } else {
673            false
674        }
675    }
676}
677
678impl AsyncRead for WasiStateFileGuard {
679    fn poll_read(
680        self: Pin<&mut Self>,
681        cx: &mut Context<'_>,
682        buf: &mut tokio::io::ReadBuf<'_>,
683    ) -> Poll<std::io::Result<()>> {
684        let mut guard = self.lock_write();
685        if let Some(guard) = guard.as_mut() {
686            let file = Pin::new(guard.deref_mut());
687            file.poll_read(cx, buf)
688        } else {
689            Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
690        }
691    }
692}
693
694fn is_err_closed(err: &std::io::Error) -> bool {
695    err.kind() == std::io::ErrorKind::ConnectionAborted
696        || err.kind() == std::io::ErrorKind::ConnectionRefused
697        || err.kind() == std::io::ErrorKind::ConnectionReset
698        || err.kind() == std::io::ErrorKind::BrokenPipe
699        || err.kind() == std::io::ErrorKind::NotConnected
700        || err.kind() == std::io::ErrorKind::UnexpectedEof
701}