virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6    VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7    VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
/// Networking backend that bundles a shared [`Selector`], a Tokio runtime
/// [`Handle`] and an optional [`Ruleset`] used to filter socket operations.
#[derive(Debug)]
pub struct LocalNetworking {
    /// Selector whose `Arc` is cloned into every listener and socket created here.
    selector: Arc<Selector>,
    /// Runtime handle on which the TCP connect and DNS lookup futures are spawned.
    handle: Handle,
    /// Rules consulted before binding, connecting or resolving; `None` skips the checks.
    ruleset: Option<Ruleset>,
}
38
39impl LocalNetworking {
40    pub fn new() -> Self {
41        Self {
42            selector: Selector::new(),
43            handle: Handle::current(),
44            ruleset: None,
45        }
46    }
47
48    pub fn with_ruleset(ruleset: Ruleset) -> Self {
49        Self {
50            selector: Selector::new(),
51            handle: Handle::current(),
52            ruleset: Some(ruleset),
53        }
54    }
55}
56
impl Drop for LocalNetworking {
    /// Shuts the selector down when the networking backend is dropped.
    fn drop(&mut self) {
        // NOTE(review): listeners and sockets created by this backend hold
        // clones of the same `Arc<Selector>`; confirm they are expected to
        // stop being serviced once this runs.
        self.selector.shutdown();
    }
}
62
63impl Default for LocalNetworking {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72    async fn listen_tcp(
73        &self,
74        addr: SocketAddr,
75        only_v6: bool,
76        reuse_port: bool,
77        reuse_addr: bool,
78    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79        if let Some(ruleset) = self.ruleset.as_ref()
80            && !ruleset.allows_socket(addr, Direction::Inbound)
81        {
82            tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83            return Err(NetworkError::PermissionDenied);
84        }
85
86        let listener = std::net::TcpListener::bind(addr)
87            .map(|sock| {
88                sock.set_nonblocking(true).ok();
89                Box::new(LocalTcpListener {
90                    stream: mio::net::TcpListener::from_std(sock),
91                    selector: self.selector.clone(),
92                    handler_guard: HandlerGuardState::None,
93                    no_delay: None,
94                    keep_alive: None,
95                    backlog: Default::default(),
96                    ruleset: self.ruleset.clone(),
97                })
98            })
99            .map_err(io_err_into_net_error)?;
100        Ok(listener)
101    }
102
103    async fn bind_udp(
104        &self,
105        addr: SocketAddr,
106        reuse_port: bool,
107        reuse_addr: bool,
108    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109        #[cfg(not(windows))]
110        use socket2::{Domain, Socket, Type};
111
112        if let Some(ruleset) = self.ruleset.as_ref()
113            && !ruleset.allows_socket(addr, Direction::Inbound)
114        {
115            tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116            return Err(NetworkError::PermissionDenied);
117        }
118
119        #[cfg(not(windows))]
120        let socket = {
121            let domain = if addr.is_ipv4() {
122                Domain::IPV4
123            } else {
124                Domain::IPV6
125            };
126            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127            std_sock
128                .set_nonblocking(true)
129                .map_err(io_err_into_net_error)?;
130            std_sock
131                .set_reuse_address(reuse_addr)
132                .map_err(io_err_into_net_error)?;
133            std_sock
134                .set_reuse_port(reuse_port)
135                .map_err(io_err_into_net_error)?;
136            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137            mio::net::UdpSocket::from_std(std_sock.into())
138        };
139        #[cfg(windows)]
140        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142        #[allow(unused_mut)]
143        let mut ret = LocalUdpSocket {
144            selector: self.selector.clone(),
145            socket,
146            addr,
147            handler_guard: HandlerGuardState::None,
148            backlog: Default::default(),
149            ruleset: self.ruleset.clone(),
150        };
151
152        // In windows we can not poll the socket as it is not supported and hence
153        // what we do is immediately set the writable flag and relay on `mio` to
154        // refresh that flag when the state changes. In Linux what we do is actually
155        // make a non-blocking `poll` call to determine this state
156        #[cfg(target_os = "windows")]
157        {
158            let (state, selector, socket) = ret.split_borrow();
159            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160            map.push(InterestType::Writable);
161        }
162
163        Ok(Box::new(ret))
164    }
165
166    async fn connect_tcp(
167        &self,
168        _addr: SocketAddr,
169        mut peer: SocketAddr,
170    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171        if let Some(ruleset) = self.ruleset.as_ref()
172            && !ruleset.allows_socket(peer, Direction::Outbound)
173        {
174            tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175            return Err(NetworkError::PermissionDenied);
176        }
177
178        // This future may be polled outside Tokio's executor context, so run
179        // the connect itself on the stored runtime handle to guarantee a reactor.
180        let stream = self
181            .handle
182            .spawn(tokio::net::TcpStream::connect(peer))
183            .await
184            .map_err(|_| NetworkError::IOError)?
185            .map_err(io_err_into_net_error)?;
186        let stream = stream.into_std().map_err(io_err_into_net_error)?;
187        let stream = mio::net::TcpStream::from_std(stream);
188
189        if let Ok(p) = stream.peer_addr() {
190            peer = p;
191        }
192        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
193        Ok(socket)
194    }
195
196    async fn resolve(
197        &self,
198        host: &str,
199        port: Option<u16>,
200        dns_server: Option<IpAddr>,
201    ) -> Result<Vec<IpAddr>> {
202        if let Some(ruleset) = self.ruleset.as_ref()
203            && !ruleset.allows_domain(host)
204        {
205            tracing::warn!(%host, "dns resolve blocked by firewall rule");
206            return Err(NetworkError::PermissionDenied);
207        }
208
209        let host_to_lookup = if host.contains(':') {
210            host.to_string()
211        } else {
212            format!("{}:{}", host, port.unwrap_or(0))
213        };
214        let addrs = self
215            .handle
216            .spawn(tokio::net::lookup_host(host_to_lookup))
217            .await
218            .map_err(|_| NetworkError::IOError)?
219            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
220            .map_err(io_err_into_net_error)?;
221
222        if let Some(ruleset) = self.ruleset.as_ref() {
223            if let Err(e) = ruleset.expand_domain(host, &addrs) {
224                tracing::debug!(err=%e, "ruleset expansion failed");
225            } else {
226                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
227            }
228        }
229
230        Ok(addrs)
231    }
232}
233
/// TCP listener backed by a `mio` listener registered with a shared [`Selector`].
#[derive(Debug)]
pub struct LocalTcpListener {
    /// Underlying non-blocking listener.
    stream: mio::net::TcpListener,
    /// Selector used when registering interest handlers and waker maps.
    selector: Arc<Selector>,
    /// Current registration state: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    /// When `Some`, applied as the no-delay setting of every accepted socket.
    no_delay: Option<bool>,
    /// When `Some`, applied as the keepalive setting of every accepted socket.
    keep_alive: Option<bool>,
    /// Connections accepted while polling for readiness, handed out first by `try_accept`.
    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
    /// Rules checked against the peer address of each accepted connection.
    ruleset: Option<Ruleset>,
}
244
245impl LocalTcpListener {
246    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
247        match self.stream.accept().map_err(io_err_into_net_error) {
248            Ok((stream, addr)) => {
249                if let Some(ruleset) = self.ruleset.as_ref()
250                    && !ruleset.allows_socket(addr, Direction::Outbound)
251                {
252                    tracing::warn!(%addr, "try_accept blocked by firewall rule");
253                    return Err(NetworkError::PermissionDenied);
254                }
255
256                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
257                if let Some(no_delay) = self.no_delay {
258                    socket.set_nodelay(no_delay).ok();
259                }
260                if let Some(keep_alive) = self.keep_alive {
261                    socket.set_keepalive(keep_alive).ok();
262                }
263                Ok((Box::new(socket), addr))
264            }
265            Err(NetworkError::WouldBlock) => {
266                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
267                    map.pop(InterestType::Readable);
268                    map.pop(InterestType::Writable);
269                }
270                Err(NetworkError::WouldBlock)
271            }
272            Err(err) => Err(err),
273        }
274    }
275}
276
277impl VirtualTcpListener for LocalTcpListener {
278    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
279        if let Some(child) = self.backlog.pop_front() {
280            return Ok(child);
281        }
282        self.try_accept_internal()
283    }
284
285    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
286        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
287            match guard.replace_handler(handler) {
288                Ok(()) => return Ok(()),
289                Err(h) => handler = h,
290            }
291
292            // the handler could not be replaced so we need to build a new handler instead
293            if let Err(err) = guard.unregister(&mut self.stream) {
294                tracing::debug!("failed to unregister previous token - {}", err);
295            }
296        }
297
298        let guard = InterestGuard::new(
299            &self.selector,
300            handler,
301            &mut self.stream,
302            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
303        )
304        .map_err(io_err_into_net_error)?;
305
306        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
307
308        Ok(())
309    }
310
311    fn addr_local(&self) -> Result<SocketAddr> {
312        self.stream.local_addr().map_err(io_err_into_net_error)
313    }
314
315    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
316        self.stream
317            .set_ttl(ttl as u32)
318            .map_err(io_err_into_net_error)
319    }
320
321    fn ttl(&self) -> Result<u8> {
322        self.stream
323            .ttl()
324            .map(|ttl| ttl as u8)
325            .map_err(io_err_into_net_error)
326    }
327}
328
329impl LocalTcpListener {
330    fn split_borrow(
331        &mut self,
332    ) -> (
333        &mut HandlerGuardState,
334        &Arc<Selector>,
335        &mut mio::net::TcpListener,
336    ) {
337        (&mut self.handler_guard, &self.selector, &mut self.stream)
338    }
339}
340
341impl VirtualIoSource for LocalTcpListener {
342    fn remove_handler(&mut self) {
343        let mut guard = HandlerGuardState::None;
344        std::mem::swap(&mut guard, &mut self.handler_guard);
345        match guard {
346            HandlerGuardState::ExternalHandler(mut guard) => {
347                guard.unregister(&mut self.stream).ok();
348            }
349            HandlerGuardState::WakerMap(mut guard, _) => {
350                guard.unregister(&mut self.stream).ok();
351            }
352            HandlerGuardState::None => {}
353        }
354    }
355
356    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
357        if !self.backlog.is_empty() {
358            return Poll::Ready(Ok(self.backlog.len()));
359        }
360
361        let (state, selector, source) = self.split_borrow();
362        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
363        map.add(InterestType::Readable, cx.waker());
364
365        if let Ok(child) = self.try_accept_internal() {
366            self.backlog.push_back(child);
367            return Poll::Ready(Ok(1));
368        }
369        Poll::Pending
370    }
371
372    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
373        if !self.backlog.is_empty() {
374            return Poll::Ready(Ok(self.backlog.len()));
375        }
376
377        let (state, selector, source) = self.split_borrow();
378        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
379        map.add(InterestType::Writable, cx.waker());
380
381        if let Ok(child) = self.try_accept_internal() {
382            self.backlog.push_back(child);
383            return Poll::Ready(Ok(1));
384        }
385        Poll::Pending
386    }
387}
388
/// Connected TCP stream backed by a `mio` stream registered with a shared [`Selector`].
#[derive(Debug)]
pub struct LocalTcpStream {
    /// Underlying non-blocking stream.
    stream: mio::net::TcpStream,
    /// Peer address supplied at construction and reported by `addr_peer`.
    addr: SocketAddr,
    /// Mode of the last successful `shutdown` call, if any.
    shutdown: Option<Shutdown>,
    /// Selector used when registering interest handlers and waker maps.
    selector: Arc<Selector>,
    /// Current registration state: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    /// Bytes read ahead by `poll_read_ready` and served first by `try_recv`.
    buffer: BytesMut,
}
398
399impl LocalTcpStream {
400    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
401        #[allow(unused_mut)]
402        let mut ret = Self {
403            stream,
404            addr,
405            shutdown: None,
406            selector,
407            handler_guard: HandlerGuardState::None,
408            buffer: BytesMut::new(),
409        };
410
411        // In windows we can not poll the socket as it is not supported and hence
412        // what we do is immediately set the writable flag and relay on `mio` to
413        // refresh that flag when the state changes. In Linux what we do is actually
414        // make a non-blocking `poll` call to determine this state
415        #[cfg(target_os = "windows")]
416        {
417            let (state, selector, socket, _) = ret.split_borrow();
418            if let Ok(map) = state_as_waker_map(state, selector, socket) {
419                map.push(InterestType::Writable);
420            }
421        }
422
423        ret
424    }
425
426    fn with_sock_ref<F, R>(&self, f: F) -> R
427    where
428        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
429    {
430        #[cfg(not(windows))]
431        let r = socket2::SockRef::from(&self.stream);
432
433        #[cfg(windows)]
434        let b = unsafe {
435            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
436        };
437        #[cfg(windows)]
438        let r = socket2::SockRef::from(&b);
439
440        f(r)
441    }
442}
443
444impl VirtualTcpSocket for LocalTcpStream {
445    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
446        Ok(())
447    }
448
449    fn recv_buf_size(&self) -> Result<usize> {
450        Err(NetworkError::Unsupported)
451    }
452
453    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
454        Ok(())
455    }
456
457    fn send_buf_size(&self) -> Result<usize> {
458        Err(NetworkError::Unsupported)
459    }
460
461    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
462        self.stream
463            .set_nodelay(nodelay)
464            .map_err(io_err_into_net_error)
465    }
466
467    fn nodelay(&self) -> Result<bool> {
468        self.stream.nodelay().map_err(io_err_into_net_error)
469    }
470
471    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
472        self.with_sock_ref(|s| s.set_keepalive(true))
473            .map_err(io_err_into_net_error)?;
474        Ok(())
475    }
476
477    fn keepalive(&self) -> Result<bool> {
478        let ret = self
479            .with_sock_ref(|s| s.keepalive())
480            .map_err(io_err_into_net_error)?;
481        Ok(ret)
482    }
483
484    #[cfg(not(target_os = "windows"))]
485    fn set_dontroute(&mut self, val: bool) -> Result<()> {
486        // TODO:
487        // Don't route is being set by WASIX which breaks networking
488        // Why this is being set is unknown but we need to disable
489        // the functionality for now as it breaks everything
490
491        let val = val as libc::c_int;
492        let payload = &val as *const libc::c_int as *const libc::c_void;
493        let err = unsafe {
494            libc::setsockopt(
495                self.stream.as_raw_fd(),
496                libc::SOL_SOCKET,
497                libc::SO_DONTROUTE,
498                payload,
499                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
500            )
501        };
502        if err == -1 {
503            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
504        }
505        Ok(())
506    }
507    #[cfg(target_os = "windows")]
508    fn set_dontroute(&mut self, val: bool) -> Result<()> {
509        Err(NetworkError::Unsupported)
510    }
511
512    #[cfg(not(target_os = "windows"))]
513    fn dontroute(&self) -> Result<bool> {
514        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
515        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
516        let err = unsafe {
517            libc::getsockopt(
518                self.stream.as_raw_fd(),
519                libc::SOL_SOCKET,
520                libc::SO_DONTROUTE,
521                payload.as_mut_ptr().cast(),
522                &mut len,
523            )
524        };
525        if err == -1 {
526            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
527        }
528        Ok(unsafe { payload.assume_init() != 0 })
529    }
530    #[cfg(target_os = "windows")]
531    fn dontroute(&self) -> Result<bool> {
532        Err(NetworkError::Unsupported)
533    }
534
535    fn addr_peer(&self) -> Result<SocketAddr> {
536        Ok(self.addr)
537    }
538
539    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
540        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
541        self.shutdown = Some(how);
542        Ok(())
543    }
544
545    fn is_closed(&self) -> bool {
546        false
547    }
548}
549
550impl VirtualConnectedSocket for LocalTcpStream {
551    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
552        self.with_sock_ref(|s| s.set_linger(linger))
553            .map_err(io_err_into_net_error)?;
554        Ok(())
555    }
556
557    fn linger(&self) -> Result<Option<Duration>> {
558        self.with_sock_ref(|s| s.linger())
559            .map_err(io_err_into_net_error)
560    }
561
562    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
563        let ret = self.stream.write(data).map_err(io_err_into_net_error);
564        match &ret {
565            Ok(0) | Err(NetworkError::WouldBlock) => {
566                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
567                    map.pop(InterestType::Writable);
568                }
569            }
570            _ => {}
571        }
572        ret
573    }
574
575    fn try_flush(&mut self) -> Result<()> {
576        self.stream.flush().map_err(io_err_into_net_error)
577    }
578
579    fn close(&mut self) -> Result<()> {
580        Ok(())
581    }
582
583    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
584        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
585        if !self.buffer.is_empty() {
586            let amt = buf.len().min(self.buffer.len());
587            buf[..amt].copy_from_slice(&self.buffer[..amt]);
588            if !peek {
589                self.buffer.advance(amt);
590            }
591            return Ok(amt);
592        }
593
594        if peek {
595            self.stream.peek(buf)
596        } else {
597            self.stream.read(buf)
598        }
599        .map_err(io_err_into_net_error)
600    }
601}
602
603impl VirtualSocket for LocalTcpStream {
604    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
605        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
606    }
607
608    fn ttl(&self) -> Result<u32> {
609        self.stream.ttl().map_err(io_err_into_net_error)
610    }
611
612    fn addr_local(&self) -> Result<SocketAddr> {
613        self.stream.local_addr().map_err(io_err_into_net_error)
614    }
615
616    fn status(&self) -> Result<SocketStatus> {
617        Ok(SocketStatus::Opened)
618    }
619
620    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
621        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
622            match guard.replace_handler(handler) {
623                Ok(()) => return Ok(()),
624                Err(h) => handler = h,
625            }
626
627            // the handler could not be replaced so we need to build a new handler instead
628            if let Err(err) = guard.unregister(&mut self.stream) {
629                tracing::debug!("failed to unregister previous token - {}", err);
630            }
631        }
632
633        let guard = InterestGuard::new(
634            &self.selector,
635            handler,
636            &mut self.stream,
637            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
638        )
639        .map_err(io_err_into_net_error)?;
640
641        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
642
643        Ok(())
644    }
645}
646
647impl LocalTcpStream {
648    fn split_borrow(
649        &mut self,
650    ) -> (
651        &mut HandlerGuardState,
652        &Arc<Selector>,
653        &mut mio::net::TcpStream,
654        &mut BytesMut,
655    ) {
656        (
657            &mut self.handler_guard,
658            &self.selector,
659            &mut self.stream,
660            &mut self.buffer,
661        )
662    }
663}
664
665impl VirtualIoSource for LocalTcpStream {
666    fn remove_handler(&mut self) {
667        let mut guard = HandlerGuardState::None;
668        std::mem::swap(&mut guard, &mut self.handler_guard);
669        match guard {
670            HandlerGuardState::ExternalHandler(mut guard) => {
671                guard.unregister(&mut self.stream).ok();
672            }
673            HandlerGuardState::WakerMap(mut guard, _) => {
674                guard.unregister(&mut self.stream).ok();
675            }
676            HandlerGuardState::None => {}
677        }
678    }
679
680    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
681        if !self.buffer.is_empty() {
682            return Poll::Ready(Ok(self.buffer.len()));
683        }
684
685        let (state, selector, stream, buffer) = self.split_borrow();
686        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
687        map.pop(InterestType::Readable);
688        map.add(InterestType::Readable, cx.waker());
689
690        buffer.reserve(buffer.len() + 10240);
691        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
692        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
693
694        match stream.read(uninit_unsafe) {
695            Ok(0) => Poll::Ready(Ok(0)),
696            Ok(amt) => {
697                unsafe {
698                    buffer.set_len(buffer.len() + amt);
699                }
700                Poll::Ready(Ok(amt))
701            }
702            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
703            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
704            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
705            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
706        }
707    }
708
709    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
710        let (state, selector, stream, _) = self.split_borrow();
711        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
712        #[cfg(not(target_os = "windows"))]
713        map.pop(InterestType::Writable);
714        map.add(InterestType::Writable, cx.waker());
715        map.add(InterestType::Closed, cx.waker());
716        if map.has_interest(InterestType::Closed) {
717            return Poll::Ready(Ok(0));
718        }
719
720        #[cfg(not(target_os = "windows"))]
721        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
722            Some(val) if (val & libc::POLLHUP) != 0 => {
723                return Poll::Ready(Ok(0));
724            }
725            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
726            _ => {}
727        }
728
729        // In windows we can not poll the socket as it is not supported and hence
730        // what we do is immediately set the writable flag and relay on `mio` to
731        // refresh that flag when the state changes. In Linux what we do is actually
732        // make a non-blocking `poll` call to determine this state
733        #[cfg(target_os = "windows")]
734        if map.has_interest(InterestType::Writable) {
735            return Poll::Ready(Ok(10240));
736        }
737
738        Poll::Pending
739    }
740}
741
742#[cfg(not(target_os = "windows"))]
743fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
744    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
745        fd,
746        events,
747        revents: 0,
748    }];
749    let fds_mut = &mut fds[..];
750    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
751    match ret == 1 {
752        true => Some(fds[0].revents),
753        false => None,
754    }
755}
756
/// UDP socket backed by a `mio` socket registered with a shared [`Selector`].
#[derive(Debug)]
pub struct LocalUdpSocket {
    /// Underlying non-blocking socket.
    socket: mio::net::UdpSocket,
    /// Address the socket was asked to bind to.
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Selector used when registering interest handlers and waker maps.
    selector: Arc<Selector>,
    /// Current registration state: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    /// Queue of (payload, address) pairs whose total size is reported by `poll_read_ready`.
    backlog: VecDeque<(BytesMut, SocketAddr)>,
    /// Rules checked against the destination of every outgoing datagram.
    ruleset: Option<Ruleset>,
}
767
768impl LocalUdpSocket {
769    fn with_sock_ref<F, R>(&self, f: F) -> R
770    where
771        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
772    {
773        #[cfg(not(windows))]
774        let r = socket2::SockRef::from(&self.socket);
775
776        #[cfg(windows)]
777        let b = unsafe {
778            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
779        };
780        #[cfg(windows)]
781        let r = socket2::SockRef::from(&b);
782
783        f(r)
784    }
785}
786
/// UDP-specific socket options, each delegating to the underlying socket and
/// converting I/O errors with `io_err_into_net_error`.
impl VirtualUdpSocket for LocalUdpSocket {
    /// Enables or disables the broadcast option.
    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
        self.socket
            .set_broadcast(broadcast)
            .map_err(io_err_into_net_error)
    }

    /// Reads the broadcast option.
    fn broadcast(&self) -> Result<bool> {
        self.socket.broadcast().map_err(io_err_into_net_error)
    }

    /// Enables or disables IPv4 multicast loopback.
    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
        self.socket
            .set_multicast_loop_v4(val)
            .map_err(io_err_into_net_error)
    }

    /// Reads the IPv4 multicast loopback option.
    fn multicast_loop_v4(&self) -> Result<bool> {
        self.socket
            .multicast_loop_v4()
            .map_err(io_err_into_net_error)
    }

    /// Enables or disables IPv6 multicast loopback.
    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
        self.socket
            .set_multicast_loop_v6(val)
            .map_err(io_err_into_net_error)
    }

    /// Reads the IPv6 multicast loopback option.
    fn multicast_loop_v6(&self) -> Result<bool> {
        self.socket
            .multicast_loop_v6()
            .map_err(io_err_into_net_error)
    }

    /// Sets the TTL used for outgoing IPv4 multicast packets.
    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
        self.socket
            .set_multicast_ttl_v4(ttl)
            .map_err(io_err_into_net_error)
    }

    /// Reads the TTL used for outgoing IPv4 multicast packets.
    fn multicast_ttl_v4(&self) -> Result<u32> {
        self.socket
            .multicast_ttl_v4()
            .map_err(io_err_into_net_error)
    }

    /// Joins the IPv4 multicast group `multiaddr` on interface `iface`
    /// (goes through `socket2` rather than `mio`).
    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
            .map_err(io_err_into_net_error)
    }

    /// Leaves the IPv4 multicast group `multiaddr` on interface `iface`
    /// (goes through `socket2` rather than `mio`).
    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
            .map_err(io_err_into_net_error)
    }

    /// Joins the IPv6 multicast group `multiaddr` on the interface with index `iface`.
    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
        self.socket
            .join_multicast_v6(&multiaddr, iface)
            .map_err(io_err_into_net_error)
    }

    /// Leaves the IPv6 multicast group `multiaddr` on the interface with index `iface`.
    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
        self.socket
            .leave_multicast_v6(&multiaddr, iface)
            .map_err(io_err_into_net_error)
    }

    /// Address of the connected peer, wrapped in `Some`.
    ///
    /// A failing `peer_addr` query is returned as an error; this method never
    /// yields `Ok(None)`.
    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
        self.socket
            .peer_addr()
            .map(Some)
            .map_err(io_err_into_net_error)
    }
}
863
864impl VirtualConnectionlessSocket for LocalUdpSocket {
865    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
866        if let Some(ruleset) = self.ruleset.as_ref()
867            && !ruleset.allows_socket(addr, Direction::Outbound)
868        {
869            tracing::warn!(%addr, "try_send blocked by firewall rule");
870            return Err(NetworkError::PermissionDenied);
871        }
872
873        let ret = self
874            .socket
875            .send_to(data, addr)
876            .map_err(io_err_into_net_error);
877        match &ret {
878            Ok(0) | Err(NetworkError::WouldBlock) => {
879                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
880                    map.pop(InterestType::Writable);
881                }
882            }
883            _ => {}
884        }
885        ret
886    }
887
888    fn try_recv_from(
889        &mut self,
890        buf: &mut [MaybeUninit<u8>],
891        peek: bool,
892    ) -> Result<(usize, SocketAddr)> {
893        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
894        if peek {
895            self.socket.peek_from(buf)
896        } else {
897            self.socket.recv_from(buf)
898        }
899        .map_err(io_err_into_net_error)
900    }
901}
902
903impl VirtualSocket for LocalUdpSocket {
904    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
905        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
906    }
907
908    fn ttl(&self) -> Result<u32> {
909        self.socket.ttl().map_err(io_err_into_net_error)
910    }
911
912    fn addr_local(&self) -> Result<SocketAddr> {
913        self.socket.local_addr().map_err(io_err_into_net_error)
914    }
915
916    fn status(&self) -> Result<SocketStatus> {
917        Ok(SocketStatus::Opened)
918    }
919
920    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
921        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
922            match guard.replace_handler(handler) {
923                Ok(()) => {
924                    return Ok(());
925                }
926                Err(h) => handler = h,
927            }
928
929            // the handler could not be replaced so we need to build a new handler instead
930            if let Err(err) = guard.unregister(&mut self.socket) {
931                tracing::debug!("failed to unregister previous token - {}", err);
932            }
933        }
934
935        let guard = InterestGuard::new(
936            &self.selector,
937            handler,
938            &mut self.socket,
939            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
940        )
941        .map_err(io_err_into_net_error)?;
942
943        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
944
945        Ok(())
946    }
947}
948
949impl LocalUdpSocket {
950    fn split_borrow(
951        &mut self,
952    ) -> (
953        &mut HandlerGuardState,
954        &Arc<Selector>,
955        &mut mio::net::UdpSocket,
956    ) {
957        (&mut self.handler_guard, &self.selector, &mut self.socket)
958    }
959}
960
961impl VirtualIoSource for LocalUdpSocket {
962    fn remove_handler(&mut self) {
963        let mut guard = HandlerGuardState::None;
964        std::mem::swap(&mut guard, &mut self.handler_guard);
965        match guard {
966            HandlerGuardState::ExternalHandler(mut guard) => {
967                guard.unregister(&mut self.socket).ok();
968            }
969            HandlerGuardState::WakerMap(mut guard, _) => {
970                guard.unregister(&mut self.socket).ok();
971            }
972            HandlerGuardState::None => {}
973        }
974    }
975
976    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
977        if !self.backlog.is_empty() {
978            let total = self.backlog.iter().map(|a| a.0.len()).sum();
979            return Poll::Ready(Ok(total));
980        }
981
982        let (state, selector, socket) = self.split_borrow();
983        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
984        map.pop(InterestType::Readable);
985        map.add(InterestType::Readable, cx.waker());
986
987        let mut buffer = BytesMut::default();
988        buffer.reserve(10240);
989        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
990        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
991
992        match self.socket.recv_from(uninit_unsafe) {
993            Ok((0, _)) => Poll::Ready(Ok(0)),
994            Ok((amt, peer)) => {
995                unsafe {
996                    buffer.set_len(amt);
997                }
998                self.backlog.push_back((buffer, peer));
999                Poll::Ready(Ok(amt))
1000            }
1001            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1002            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1003            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1004            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1005        }
1006    }
1007
1008    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1009        let (state, selector, socket) = self.split_borrow();
1010        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1011        #[cfg(not(target_os = "windows"))]
1012        map.pop(InterestType::Writable);
1013        map.add(InterestType::Writable, cx.waker());
1014
1015        #[cfg(not(target_os = "windows"))]
1016        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1017            Some(val) if (val & libc::POLLHUP) != 0 => {
1018                return Poll::Ready(Ok(0));
1019            }
1020            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1021            _ => {}
1022        }
1023
1024        // In windows we can not poll the socket as it is not supported and hence
1025        // what we do is immediately set the writable flag and relay on `mio` to
1026        // refresh that flag when the state changes. In Linux what we do is actually
1027        // make a non-blocking `poll` call to determine this state
1028        #[cfg(target_os = "windows")]
1029        if map.has_interest(InterestType::Writable) {
1030            return Poll::Ready(Ok(10240));
1031        }
1032
1033        Poll::Pending
1034    }
1035}