1use std::{
2 future::Future,
3 io::{IoSlice, SeekFrom},
4 ops::{Deref, DerefMut},
5 pin::Pin,
6 sync::{Arc, RwLock},
7 task::{Context, Poll},
8};
9
10use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite};
11use virtual_fs::{FsError, Pipe, PipeRx, PipeTx, VirtualFile};
12use wasmer_wasix_types::{
13 types::Eventtype,
14 wasi::{self, EpollType},
15 wasi::{Errno, EventFdReadwrite, Eventrwflags, Subscription},
16};
17
18use super::{InodeGuard, Kind, notification::NotificationInner};
19use crate::{
20 net::socket::{InodeSocketInner, InodeSocketKind},
21 state::{PollEvent, PollEventSet, WasiState, iterate_poll_events},
22 syscalls::{EventResult, EventResultType, map_io_err},
23 utils::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard},
24};
25
/// The kind of object a poll guard is watching, together with the shared
/// handle that is used to query its readiness.
#[derive(Debug, Clone)]
pub(crate) enum InodeValFilePollGuardMode {
    /// A virtual file behind a shared lock.
    File(Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>),
    /// An event-notification object.
    EventNotifications(Arc<NotificationInner>),
    /// A socket; `inner` is the socket's shared state.
    Socket { inner: Arc<InodeSocketInner> },
    /// The read end of a pipe (polling it for writability yields an error).
    PipeRx { rx: Arc<RwLock<Box<PipeRx>>> },
    /// The write end of a pipe (polling it for readability yields an error).
    PipeTx { tx: Arc<RwLock<Box<PipeTx>>> },
    /// A pipe that can be polled for both reading and writing.
    DuplexPipe { pipe: Arc<RwLock<Box<Pipe>>> },
}
35
/// A poll request for a single file descriptor: the events to watch for and
/// the handle (`mode`) used to check them.
pub struct InodeValFilePollGuard {
    /// File descriptor number this guard was created for.
    pub(crate) fd: u32,
    /// Set of poll events the caller is interested in.
    pub(crate) peb: PollEventSet,
    /// Subscription whose `userdata` and `type_` are copied into produced events.
    pub(crate) subscription: Subscription,
    /// Handle to the object being polled.
    pub(crate) mode: InodeValFilePollGuardMode,
}
42
43impl InodeValFilePollGuard {
44 pub(crate) fn new(
45 fd: u32,
46 peb: PollEventSet,
47 subscription: Subscription,
48 guard: &Kind,
49 ) -> Option<Self> {
50 let mode = match guard {
51 Kind::EventNotifications { inner, .. } => {
52 InodeValFilePollGuardMode::EventNotifications(inner.clone())
53 }
54 Kind::Socket { socket, .. } => InodeValFilePollGuardMode::Socket {
55 inner: socket.inner.clone(),
56 },
57 Kind::File {
58 handle: Some(handle),
59 ..
60 } => InodeValFilePollGuardMode::File(handle.clone()),
61 Kind::PipeRx { rx } => InodeValFilePollGuardMode::PipeRx {
62 rx: Arc::new(RwLock::new(Box::new(rx.clone()))),
63 },
64 Kind::PipeTx { tx } => InodeValFilePollGuardMode::PipeTx {
65 tx: Arc::new(RwLock::new(Box::new(tx.clone()))),
66 },
67 Kind::DuplexPipe { pipe } => InodeValFilePollGuardMode::DuplexPipe {
68 pipe: Arc::new(RwLock::new(Box::new(pipe.clone()))),
69 },
70 _ => {
71 return None;
72 }
73 };
74 Some(Self {
75 fd,
76 mode,
77 peb,
78 subscription,
79 })
80 }
81}
82
83impl std::fmt::Debug for InodeValFilePollGuard {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match &self.mode {
86 InodeValFilePollGuardMode::File(..) => {
87 write!(f, "guard-file(fd={}, peb={})", self.fd, self.peb)
88 }
89 InodeValFilePollGuardMode::EventNotifications { .. } => {
90 write!(f, "guard-notifications(fd={}, peb={})", self.fd, self.peb)
91 }
92 InodeValFilePollGuardMode::Socket { inner } => {
93 let inner = inner.protected.read().unwrap();
94 match &inner.kind {
95 InodeSocketKind::TcpListener { .. } => {
96 write!(f, "guard-tcp-listener(fd={}, peb={})", self.fd, self.peb)
97 }
98 InodeSocketKind::TcpStream { socket, .. } => {
99 if socket.is_closed() {
100 write!(
101 f,
102 "guard-tcp-stream (closed, fd={}, peb={})",
103 self.fd, self.peb
104 )
105 } else {
106 write!(f, "guard-tcp-stream(fd={}, peb={})", self.fd, self.peb)
107 }
108 }
109 InodeSocketKind::UdpSocket { .. } => {
110 write!(f, "guard-udp-socket(fd={}, peb={})", self.fd, self.peb)
111 }
112 InodeSocketKind::Raw(..) => {
113 write!(f, "guard-raw-socket(fd={}, peb={})", self.fd, self.peb)
114 }
115 _ => write!(f, "guard-socket(fd={}), peb={})", self.fd, self.peb),
116 }
117 }
118 InodeValFilePollGuardMode::PipeRx { .. } => {
119 write!(f, "guard-pipe-rx(...)")
120 }
121 InodeValFilePollGuardMode::PipeTx { .. } => {
122 write!(f, "guard-pipe-tx(...)")
123 }
124 InodeValFilePollGuardMode::DuplexPipe { .. } => {
125 write!(f, "guard-duplex-pipe(...)")
126 }
127 }
128 }
129}
130
/// Future form of an `InodeValFilePollGuard`: it holds the same four pieces
/// of state and is polled to produce readiness events.
#[derive(Debug)]
pub struct InodeValFilePollGuardJoin {
    /// Handle to the object being polled.
    mode: InodeValFilePollGuardMode,
    /// File descriptor number the poll was requested for.
    fd: u32,
    /// Set of poll events the caller is interested in.
    peb: PollEventSet,
    /// Subscription whose `userdata` and `type_` are copied into produced events.
    subscription: Subscription,
}
138
139impl InodeValFilePollGuardJoin {
140 pub(crate) fn new(guard: InodeValFilePollGuard) -> Self {
141 Self {
142 mode: guard.mode,
143 fd: guard.fd,
144 peb: guard.peb,
145 subscription: guard.subscription,
146 }
147 }
148 pub(crate) fn fd(&self) -> u32 {
149 self.fd
150 }
151 pub(crate) fn peb(&self) -> PollEventSet {
152 self.peb
153 }
154}
155
/// Capacity of the fixed-size vector of events that polling an
/// `InodeValFilePollGuardJoin` yields.
pub const POLL_GUARD_MAX_RET: usize = 4;
157
158impl Future for InodeValFilePollGuardJoin {
159 type Output = heapless::Vec<(EventResult, EpollType), POLL_GUARD_MAX_RET>;
160
161 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
162 let waker = cx.waker();
164 let mut has_read = false;
165 let mut has_write = false;
166 let mut has_close = false;
167 let mut has_hangup = false;
168
169 let mut ret = heapless::Vec::new();
170 for in_event in iterate_poll_events(self.peb) {
171 match in_event {
172 PollEvent::PollIn => {
173 has_read = true;
174 }
175 PollEvent::PollOut => {
176 has_write = true;
177 }
178 PollEvent::PollHangUp => {
179 has_hangup = true;
180 has_close = true;
181 }
182 PollEvent::PollError | PollEvent::PollInvalid => {
183 if !has_hangup {
184 has_close = true;
185 }
186 }
187 }
188 }
189 if has_read {
190 let poll_result = match &mut self.mode {
191 InodeValFilePollGuardMode::File(file) => {
192 let mut guard = file.write().unwrap();
193 let file = Pin::new(guard.as_mut());
194 file.poll_read_ready(cx)
195 }
196 InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok),
197 InodeValFilePollGuardMode::Socket { inner } => {
198 let mut guard = inner.protected.write().unwrap();
199 guard.poll_read_ready(cx)
200 }
201 InodeValFilePollGuardMode::PipeRx { rx } => {
202 let mut guard = rx.write().unwrap();
203 let rx = Pin::new(guard.as_mut());
204 rx.poll_read_ready(cx)
205 }
206 InodeValFilePollGuardMode::PipeTx { .. } => Poll::Ready(Err(std::io::Error::new(
207 std::io::ErrorKind::InvalidInput,
208 "Cannot read from a pipe write end",
209 ))),
210 InodeValFilePollGuardMode::DuplexPipe { pipe } => {
211 let mut guard = pipe.write().unwrap();
212 let pipe = Pin::new(guard.as_mut());
213 pipe.poll_read_ready(cx)
214 }
215 };
216 match poll_result {
217 Poll::Ready(Err(err)) if has_close && is_err_closed(&err) => {
218 let inner = match self.subscription.type_ {
219 Eventtype::FdRead | Eventtype::FdWrite => {
220 Some(EventResultType::Fd(EventFdReadwrite {
221 nbytes: 0,
222 flags: if has_hangup {
223 Eventrwflags::FD_READWRITE_HANGUP
224 } else {
225 Eventrwflags::empty()
226 },
227 }))
228 }
229 Eventtype::Clock => Some(EventResultType::Clock(0)),
230 Eventtype::Unknown => None,
231 };
232 if let Some(inner) = inner {
233 ret.push((
234 EventResult {
235 userdata: self.subscription.userdata,
236 error: Errno::Success,
237 type_: self.subscription.type_,
238 inner,
239 },
240 EpollType::EPOLLHUP,
241 ))
242 .ok();
243 }
244 }
245 Poll::Ready(bytes_available) => {
246 let mut error = Errno::Success;
247 let bytes_available = match bytes_available {
248 Ok(a) => a,
249 Err(e) => {
250 error = map_io_err(e);
251 0
252 }
253 };
254 let inner = match self.subscription.type_ {
255 Eventtype::FdRead | Eventtype::FdWrite => {
256 Some(EventResultType::Fd(EventFdReadwrite {
257 nbytes: bytes_available as u64,
258 flags: if bytes_available == 0 {
259 Eventrwflags::FD_READWRITE_HANGUP
260 } else {
261 Eventrwflags::empty()
262 },
263 }))
264 }
265 Eventtype::Clock => Some(EventResultType::Clock(0)),
266 Eventtype::Unknown => None,
267 };
268 if let Some(inner) = inner {
269 ret.push((
270 EventResult {
271 userdata: self.subscription.userdata,
272 error,
273 type_: self.subscription.type_,
274 inner,
275 },
276 if error == Errno::Success {
277 EpollType::EPOLLIN
278 } else {
279 EpollType::EPOLLERR
280 },
281 ))
282 .ok();
283 }
284 }
285 Poll::Pending => {}
286 };
287 }
288 if has_write {
289 let poll_result = match &mut self.mode {
290 InodeValFilePollGuardMode::File(file) => {
291 let mut guard = file.write().unwrap();
292 let file = Pin::new(guard.as_mut());
293 file.poll_write_ready(cx)
294 }
295 InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok),
296 InodeValFilePollGuardMode::Socket { inner } => {
297 let mut guard = inner.protected.write().unwrap();
298 guard.poll_write_ready(cx)
299 }
300 InodeValFilePollGuardMode::PipeRx { .. } => Poll::Ready(Err(std::io::Error::new(
301 std::io::ErrorKind::InvalidInput,
302 "Cannot write to a pipe read end",
303 ))),
304 InodeValFilePollGuardMode::PipeTx { tx } => {
305 let mut guard = tx.write().unwrap();
306 let tx = Pin::new(guard.as_mut());
307 tx.poll_write_ready()
308 }
309 InodeValFilePollGuardMode::DuplexPipe { pipe } => {
310 let mut guard = pipe.write().unwrap();
311 let pipe = Pin::new(guard.as_mut());
312 pipe.poll_write_ready(cx)
313 }
314 };
315 match poll_result {
316 Poll::Ready(Err(err)) if has_close && is_err_closed(&err) => {
317 let inner = match self.subscription.type_ {
318 Eventtype::FdRead | Eventtype::FdWrite => {
319 Some(EventResultType::Fd(EventFdReadwrite {
320 nbytes: 0,
321 flags: if has_hangup {
322 Eventrwflags::FD_READWRITE_HANGUP
323 } else {
324 Eventrwflags::empty()
325 },
326 }))
327 }
328 Eventtype::Clock => Some(EventResultType::Clock(0)),
329 Eventtype::Unknown => None,
330 };
331 if let Some(inner) = inner {
332 ret.push((
333 EventResult {
334 userdata: self.subscription.userdata,
335 error: Errno::Success,
336 type_: self.subscription.type_,
337 inner,
338 },
339 EpollType::EPOLLHUP,
340 ))
341 .ok();
342 }
343 }
344 Poll::Ready(bytes_available) => {
345 let mut error = Errno::Success;
346 let bytes_available = match bytes_available {
347 Ok(a) => a,
348 Err(e) => {
349 error = map_io_err(e);
350 0
351 }
352 };
353 let inner = match self.subscription.type_ {
354 Eventtype::FdRead | Eventtype::FdWrite => {
355 Some(EventResultType::Fd(EventFdReadwrite {
356 nbytes: bytes_available as u64,
357 flags: if bytes_available == 0 {
358 Eventrwflags::FD_READWRITE_HANGUP
359 } else {
360 Eventrwflags::empty()
361 },
362 }))
363 }
364 Eventtype::Clock => Some(EventResultType::Clock(0)),
365 Eventtype::Unknown => None,
366 };
367 if let Some(inner) = inner {
368 ret.push((
369 EventResult {
370 userdata: self.subscription.userdata,
371 error,
372 type_: self.subscription.type_,
373 inner,
374 },
375 if error == Errno::Success {
376 EpollType::EPOLLOUT
377 } else {
378 EpollType::EPOLLERR
379 },
380 ))
381 .ok();
382 }
383 }
384 Poll::Pending => {}
385 };
386 }
387 if !ret.is_empty() {
388 return Poll::Ready(ret);
389 }
390 Poll::Pending
391 }
392}
393
/// Owned read guard over a shared virtual file; dereferences to the file.
#[derive(Debug)]
pub(crate) struct InodeValFileReadGuard {
    /// The owned read lock on the boxed file.
    guard: OwnedRwLockReadGuard<Box<dyn VirtualFile + Send + Sync + 'static>>,
}
398
399impl InodeValFileReadGuard {
400 pub(crate) fn new(file: &Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>) -> Self {
401 Self {
402 guard: crate::utils::read_owned(file).unwrap(),
403 }
404 }
405}
406
407impl InodeValFileReadGuard {
408 pub fn into_poll_guard(
409 self,
410 fd: u32,
411 peb: PollEventSet,
412 subscription: Subscription,
413 ) -> InodeValFilePollGuard {
414 InodeValFilePollGuard {
415 fd,
416 peb,
417 subscription,
418 mode: InodeValFilePollGuardMode::File(self.guard.into_inner()),
419 }
420 }
421}
422
423impl Deref for InodeValFileReadGuard {
424 type Target = dyn VirtualFile + Send + Sync + 'static;
425 fn deref(&self) -> &Self::Target {
426 self.guard.deref().deref()
427 }
428}
429
/// Owned write guard over a shared virtual file; dereferences (mutably) to the file.
#[derive(Debug)]
pub struct InodeValFileWriteGuard {
    /// The owned write lock on the boxed file.
    guard: OwnedRwLockWriteGuard<Box<dyn VirtualFile + Send + Sync + 'static>>,
}
434
435impl InodeValFileWriteGuard {
436 pub(crate) fn new(file: &Arc<RwLock<Box<dyn VirtualFile + Send + Sync + 'static>>>) -> Self {
437 Self {
438 guard: crate::utils::write_owned(file).unwrap(),
439 }
440 }
441 pub(crate) fn swap(
442 &mut self,
443 mut file: Box<dyn VirtualFile + Send + Sync + 'static>,
444 ) -> Box<dyn VirtualFile + Send + Sync + 'static> {
445 std::mem::swap(self.guard.deref_mut(), &mut file);
446 file
447 }
448}
449
450impl Deref for InodeValFileWriteGuard {
451 type Target = dyn VirtualFile + Send + Sync + 'static;
452 fn deref(&self) -> &Self::Target {
453 self.guard.deref().deref()
454 }
455}
456impl DerefMut for InodeValFileWriteGuard {
457 fn deref_mut(&mut self) -> &mut Self::Target {
458 self.guard.deref_mut().deref_mut()
459 }
460}
461
/// A `VirtualFile` implementation that forwards every operation to the file
/// handle stored in an inode, locking the handle for each call.
#[derive(Debug)]
pub(crate) struct WasiStateFileGuard {
    /// The inode whose file handle is locked on demand.
    inode: InodeGuard,
}
466
467impl WasiStateFileGuard {
468 pub fn new(state: &WasiState, fd: wasi::Fd) -> Result<Option<Self>, FsError> {
469 let fd_map = state.fs.fd_map.read().unwrap();
470 if let Some(fd) = fd_map.get(fd) {
471 Ok(Some(Self {
472 inode: fd.inode.clone(),
473 }))
474 } else {
475 Ok(None)
476 }
477 }
478
479 pub fn lock_read(&self) -> Option<InodeValFileReadGuard> {
480 let guard = self.inode.read();
481 if let Kind::File { handle, .. } = guard.deref() {
482 handle.as_ref().map(InodeValFileReadGuard::new)
483 } else {
484 unreachable!("Non-file found in standard device location")
486 }
487 }
488
489 pub fn lock_write(&self) -> Option<InodeValFileWriteGuard> {
490 let guard = self.inode.read();
491 if let Kind::File { handle, .. } = guard.deref() {
492 handle.as_ref().map(InodeValFileWriteGuard::new)
493 } else {
494 unreachable!("Non-file found in standard device location")
496 }
497 }
498}
499
500impl VirtualFile for WasiStateFileGuard {
501 fn last_accessed(&self) -> u64 {
502 let guard = self.lock_read();
503 if let Some(file) = guard.as_ref() {
504 file.last_accessed()
505 } else {
506 0
507 }
508 }
509
510 fn last_modified(&self) -> u64 {
511 let guard = self.lock_read();
512 if let Some(file) = guard.as_ref() {
513 file.last_modified()
514 } else {
515 0
516 }
517 }
518
519 fn created_time(&self) -> u64 {
520 let guard = self.lock_read();
521 if let Some(file) = guard.as_ref() {
522 file.created_time()
523 } else {
524 0
525 }
526 }
527
528 fn set_times(
529 &mut self,
530 atime: Option<u64>,
531 mtime: Option<u64>,
532 ) -> Result<(), virtual_fs::FsError> {
533 let mut guard = self.lock_write();
534 if let Some(file) = guard.as_mut() {
535 file.set_times(atime, mtime)
536 } else {
537 Err(crate::FsError::Lock)
538 }
539 }
540
541 fn size(&self) -> u64 {
542 let guard = self.lock_read();
543 if let Some(file) = guard.as_ref() {
544 file.size()
545 } else {
546 0
547 }
548 }
549
550 fn set_len(&mut self, new_size: u64) -> Result<(), FsError> {
551 let mut guard = self.lock_write();
552 if let Some(file) = guard.as_mut() {
553 file.set_len(new_size)
554 } else {
555 Err(FsError::IOError)
556 }
557 }
558
559 fn unlink(&mut self) -> Result<(), FsError> {
560 let mut guard = self.lock_write();
561 if let Some(file) = guard.as_mut() {
562 file.unlink()
563 } else {
564 Err(FsError::IOError)
565 }
566 }
567
568 fn is_open(&self) -> bool {
569 let guard = self.lock_read();
570 if let Some(file) = guard.as_ref() {
571 file.is_open()
572 } else {
573 false
574 }
575 }
576
577 fn poll_read_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<usize>> {
578 let mut guard = self.lock_write();
579 if let Some(file) = guard.as_mut() {
580 let file = Pin::new(file.deref_mut());
581 file.poll_read_ready(cx)
582 } else {
583 Poll::Ready(Ok(0))
584 }
585 }
586
587 fn poll_write_ready(
588 self: Pin<&mut Self>,
589 cx: &mut Context<'_>,
590 ) -> Poll<std::io::Result<usize>> {
591 let mut guard = self.lock_write();
592 if let Some(file) = guard.as_mut() {
593 let file = Pin::new(file.deref_mut());
594 file.poll_write_ready(cx)
595 } else {
596 Poll::Ready(Ok(0))
597 }
598 }
599}
600
601impl AsyncSeek for WasiStateFileGuard {
602 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
603 let mut guard = self.lock_write();
604 if let Some(guard) = guard.as_mut() {
605 let file = Pin::new(guard.deref_mut());
606 file.start_seek(position)
607 } else {
608 Err(std::io::ErrorKind::Unsupported.into())
609 }
610 }
611 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
612 let mut guard = self.lock_write();
613 if let Some(guard) = guard.as_mut() {
614 let file = Pin::new(guard.deref_mut());
615 file.poll_complete(cx)
616 } else {
617 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
618 }
619 }
620}
621
622impl AsyncWrite for WasiStateFileGuard {
623 fn poll_write(
624 self: Pin<&mut Self>,
625 cx: &mut Context<'_>,
626 buf: &[u8],
627 ) -> Poll<std::io::Result<usize>> {
628 let mut guard = self.lock_write();
629 if let Some(guard) = guard.as_mut() {
630 let file = Pin::new(guard.deref_mut());
631 file.poll_write(cx, buf)
632 } else {
633 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
634 }
635 }
636 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
637 let mut guard = self.lock_write();
638 if let Some(guard) = guard.as_mut() {
639 let file = Pin::new(guard.deref_mut());
640 file.poll_flush(cx)
641 } else {
642 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
643 }
644 }
645 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
646 let mut guard = self.lock_write();
647 if let Some(guard) = guard.as_mut() {
648 let file = Pin::new(guard.deref_mut());
649 file.poll_shutdown(cx)
650 } else {
651 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
652 }
653 }
654 fn poll_write_vectored(
655 self: Pin<&mut Self>,
656 cx: &mut Context<'_>,
657 bufs: &[IoSlice<'_>],
658 ) -> Poll<std::io::Result<usize>> {
659 let mut guard = self.lock_write();
660 if let Some(guard) = guard.as_mut() {
661 let file = Pin::new(guard.deref_mut());
662 file.poll_write_vectored(cx, bufs)
663 } else {
664 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
665 }
666 }
667 fn is_write_vectored(&self) -> bool {
668 let mut guard = self.lock_write();
669 if let Some(guard) = guard.as_mut() {
670 let file = Pin::new(guard.deref_mut());
671 file.is_write_vectored()
672 } else {
673 false
674 }
675 }
676}
677
678impl AsyncRead for WasiStateFileGuard {
679 fn poll_read(
680 self: Pin<&mut Self>,
681 cx: &mut Context<'_>,
682 buf: &mut tokio::io::ReadBuf<'_>,
683 ) -> Poll<std::io::Result<()>> {
684 let mut guard = self.lock_write();
685 if let Some(guard) = guard.as_mut() {
686 let file = Pin::new(guard.deref_mut());
687 file.poll_read(cx, buf)
688 } else {
689 Poll::Ready(Err(std::io::ErrorKind::Unsupported.into()))
690 }
691 }
692}
693
/// Returns `true` when `err` has one of the kinds this module treats as
/// "the other side is gone": aborted, refused or reset connections, a broken
/// pipe, a socket that is not connected, or an unexpected end of file.
fn is_err_closed(err: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    matches!(
        err.kind(),
        ErrorKind::ConnectionAborted
            | ErrorKind::ConnectionRefused
            | ErrorKind::ConnectionReset
            | ErrorKind::BrokenPipe
            | ErrorKind::NotConnected
            | ErrorKind::UnexpectedEof
    )
}