virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, MAX_SOCKET_PAYLOAD, NetworkError, Result, SocketStatus, StreamSecurity,
6    VirtualConnectedSocket, VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking,
7    VirtualRawSocket, VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket,
8    VirtualUdpSocket,
9};
10use crate::{VirtualIoSource, io_err_into_net_error};
11use bytes::{Buf, BytesMut};
12use std::collections::VecDeque;
13use std::io::{self, Read, Write};
14use std::mem::MaybeUninit;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
16#[cfg(not(target_os = "windows"))]
17use std::os::fd::AsRawFd;
18#[cfg(not(target_os = "windows"))]
19use std::os::fd::RawFd;
20#[cfg(windows)]
21use std::os::windows::io::AsRawSocket;
22
23use std::sync::{Arc, Mutex};
24use std::task::Poll;
25use std::time::Duration;
26use tokio::runtime::Handle;
27#[allow(unused_imports, dead_code)]
28use tracing::{debug, error, info, trace, warn};
29use virtual_mio::{
30    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
31};
32
/// Backlog handed to `listen(2)` by `LocalTcpBoundSocket::listen`.
///
/// Without a Unix target plus the `libc` feature a fixed backlog of 128 is used.
#[cfg(not(all(target_family = "unix", feature = "libc")))]
const LISTEN_BACKLOG: i32 = 128;
/// Backlog handed to `listen(2)` by `LocalTcpBoundSocket::listen`.
///
/// On Unix with the `libc` feature this is the platform's `SOMAXCONN`, chosen
/// so the listener keeps the accept capacity of the earlier
/// `std::net::TcpListener`-based implementation.
#[cfg(all(target_family = "unix", feature = "libc"))]
const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
40
/// Networking backend that creates sockets directly on the host operating
/// system (via `socket2` / `mio`).
///
/// Every socket created through this type shares `selector`, and receives a
/// clone of the optional firewall `ruleset`.
#[derive(Debug)]
pub struct LocalNetworking {
    // Shared readiness selector; `Drop` calls `shutdown()` on it.
    selector: Arc<Selector>,
    // Tokio runtime handle captured at construction; `resolve` spawns its
    // DNS lookups onto it.
    handle: Handle,
    // Optional firewall rules consulted before binding, connecting and
    // resolving (and cloned into created sockets). `None` disables filtering.
    ruleset: Option<Ruleset>,
}
47
48impl LocalNetworking {
49    pub fn new() -> Self {
50        Self {
51            selector: Selector::new(),
52            handle: Handle::current(),
53            ruleset: None,
54        }
55    }
56
57    pub fn with_ruleset(ruleset: Ruleset) -> Self {
58        Self {
59            selector: Selector::new(),
60            handle: Handle::current(),
61            ruleset: Some(ruleset),
62        }
63    }
64}
65
66impl Drop for LocalNetworking {
67    fn drop(&mut self) {
68        self.selector.shutdown();
69    }
70}
71
72impl Default for LocalNetworking {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
79    addr.as_socket().ok_or(NetworkError::UnknownError)
80}
81
82fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
83    if addr.is_ipv4() {
84        socket2::Domain::IPV4
85    } else {
86        socket2::Domain::IPV6
87    }
88}
89
/// Returns `true` when `err` from a non-blocking `connect` only means the
/// connection attempt is still underway rather than failed.
///
/// `WouldBlock` and `Interrupted` always qualify; on Unix with the `libc`
/// feature the raw `EINPROGRESS` / `EALREADY` codes qualify as well.
fn tcp_connect_in_progress(err: &io::Error) -> bool {
    let pending_kind = matches!(
        err.kind(),
        io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
    );
    pending_kind || raw_connect_in_progress(err)
}

/// Checks the raw OS error code for connect's "still in progress" errnos.
/// Always `false` unless built for Unix with the `libc` feature.
fn raw_connect_in_progress(err: &io::Error) -> bool {
    #[cfg(all(target_family = "unix", feature = "libc"))]
    {
        matches!(
            err.raw_os_error(),
            Some(code) if code == libc::EINPROGRESS || code == libc::EALREADY
        )
    }

    #[cfg(not(all(target_family = "unix", feature = "libc")))]
    {
        let _ = err;
        false
    }
}
112
113#[async_trait::async_trait]
114#[allow(unused_variables)]
115impl VirtualNetworking for LocalNetworking {
116    async fn listen_tcp(
117        &self,
118        addr: SocketAddr,
119        only_v6: bool,
120        reuse_port: bool,
121        reuse_addr: bool,
122    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
123        if let Some(ruleset) = self.ruleset.as_ref()
124            && !ruleset.allows_socket(addr, Direction::Inbound)
125        {
126            tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
127            return Err(NetworkError::PermissionDenied);
128        }
129
130        self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
131            .await?
132            .listen()
133    }
134
135    async fn bind_tcp(
136        &self,
137        addr: SocketAddr,
138        only_v6: bool,
139        reuse_port: bool,
140        reuse_addr: bool,
141    ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
142        if let Some(ruleset) = self.ruleset.as_ref()
143            && !ruleset.allows_socket(addr, Direction::Inbound)
144        {
145            tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
146            return Err(NetworkError::PermissionDenied);
147        }
148
149        let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None)
150            .map_err(io_err_into_net_error)?;
151        socket
152            .set_nonblocking(true)
153            .map_err(io_err_into_net_error)?;
154        if addr.is_ipv6() {
155            socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
156        }
157        socket
158            .set_reuse_address(reuse_addr)
159            .map_err(io_err_into_net_error)?;
160        #[cfg(not(windows))]
161        socket
162            .set_reuse_port(reuse_port)
163            .map_err(io_err_into_net_error)?;
164        socket.bind(&addr.into()).map_err(io_err_into_net_error)?;
165
166        Ok(Box::new(LocalTcpBoundSocket {
167            socket: Some(socket),
168            selector: self.selector.clone(),
169            ruleset: self.ruleset.clone(),
170        }))
171    }
172
173    async fn bind_udp(
174        &self,
175        addr: SocketAddr,
176        reuse_port: bool,
177        reuse_addr: bool,
178    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
179        #[cfg(not(windows))]
180        use socket2::{Domain, Socket, Type};
181
182        if let Some(ruleset) = self.ruleset.as_ref()
183            && !ruleset.allows_socket(addr, Direction::Inbound)
184        {
185            tracing::warn!(%addr, "bind_udp blocked by firewall rule");
186            return Err(NetworkError::PermissionDenied);
187        }
188
189        #[cfg(not(windows))]
190        let socket = {
191            let domain = if addr.is_ipv4() {
192                Domain::IPV4
193            } else {
194                Domain::IPV6
195            };
196            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
197            std_sock
198                .set_nonblocking(true)
199                .map_err(io_err_into_net_error)?;
200            std_sock
201                .set_reuse_address(reuse_addr)
202                .map_err(io_err_into_net_error)?;
203            std_sock
204                .set_reuse_port(reuse_port)
205                .map_err(io_err_into_net_error)?;
206            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
207            mio::net::UdpSocket::from_std(std_sock.into())
208        };
209        #[cfg(windows)]
210        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
211
212        #[allow(unused_mut)]
213        let mut ret = LocalUdpSocket {
214            selector: self.selector.clone(),
215            socket,
216            addr,
217            handler_guard: HandlerGuardState::None,
218            backlog: Default::default(),
219            ruleset: self.ruleset.clone(),
220        };
221
222        // In windows we can not poll the socket as it is not supported and hence
223        // what we do is immediately set the writable flag and relay on `mio` to
224        // refresh that flag when the state changes. In Linux what we do is actually
225        // make a non-blocking `poll` call to determine this state
226        #[cfg(target_os = "windows")]
227        {
228            let (state, selector, socket) = ret.split_borrow();
229            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
230            map.push(InterestType::Writable);
231        }
232
233        Ok(Box::new(ret))
234    }
235
236    async fn connect_tcp(
237        &self,
238        _addr: SocketAddr,
239        mut peer: SocketAddr,
240    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
241        if let Some(ruleset) = self.ruleset.as_ref()
242            && !ruleset.allows_socket(peer, Direction::Outbound)
243        {
244            tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
245            return Err(NetworkError::PermissionDenied);
246        }
247
248        let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
249
250        if let Ok(p) = stream.peer_addr() {
251            peer = p;
252        }
253        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
254        Ok(socket)
255    }
256
257    async fn resolve(
258        &self,
259        host: &str,
260        port: Option<u16>,
261        dns_server: Option<IpAddr>,
262    ) -> Result<Vec<IpAddr>> {
263        if let Some(ruleset) = self.ruleset.as_ref()
264            && !ruleset.allows_domain(host)
265        {
266            tracing::warn!(%host, "dns resolve blocked by firewall rule");
267            return Err(NetworkError::PermissionDenied);
268        }
269
270        let host_to_lookup = if host.contains(':') {
271            host.to_string()
272        } else {
273            format!("{}:{}", host, port.unwrap_or(0))
274        };
275        let addrs = self
276            .handle
277            .spawn(tokio::net::lookup_host(host_to_lookup))
278            .await
279            .map_err(|_| NetworkError::IOError)?
280            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
281            .map_err(io_err_into_net_error)?;
282
283        if let Some(ruleset) = self.ruleset.as_ref() {
284            if let Err(e) = ruleset.expand_domain(host, &addrs) {
285                tracing::debug!(err=%e, "ruleset expansion failed");
286            } else {
287                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
288            }
289        }
290
291        Ok(addrs)
292    }
293}
294
/// A listening TCP socket whose readiness is tracked through the shared
/// selector.
#[derive(Debug)]
pub struct LocalTcpListener {
    // `mio` listener built from the bound socket in `LocalTcpBoundSocket::listen`.
    stream: mio::net::TcpListener,
    selector: Arc<Selector>,
    // Current readiness registration: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    // When `Some`, applied to every accepted stream via `set_nodelay`.
    no_delay: Option<bool>,
    // When `Some`, applied to every accepted stream via `set_keepalive`.
    keep_alive: Option<bool>,
    // Connections already accepted by the readiness pollers and not yet
    // handed out by `try_accept`.
    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
    // Firewall rules checked against each accepted peer address.
    ruleset: Option<Ruleset>,
}
305
306impl LocalTcpListener {
307    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
308        match self.stream.accept().map_err(io_err_into_net_error) {
309            Ok((stream, addr)) => {
310                if let Some(ruleset) = self.ruleset.as_ref()
311                    && !ruleset.allows_socket(addr, Direction::Outbound)
312                {
313                    tracing::warn!(%addr, "try_accept blocked by firewall rule");
314                    return Err(NetworkError::PermissionDenied);
315                }
316
317                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
318                if let Some(no_delay) = self.no_delay {
319                    socket.set_nodelay(no_delay).ok();
320                }
321                if let Some(keep_alive) = self.keep_alive {
322                    socket.set_keepalive(keep_alive).ok();
323                }
324                Ok((Box::new(socket), addr))
325            }
326            Err(NetworkError::WouldBlock) => {
327                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
328                    map.pop(InterestType::Readable);
329                    map.pop(InterestType::Writable);
330                }
331                Err(NetworkError::WouldBlock)
332            }
333            Err(err) => Err(err),
334        }
335    }
336}
337
338impl VirtualTcpListener for LocalTcpListener {
339    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
340        if let Some(child) = self.backlog.pop_front() {
341            return Ok(child);
342        }
343        self.try_accept_internal()
344    }
345
346    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
347        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
348            match guard.replace_handler(handler) {
349                Ok(()) => return Ok(()),
350                Err(h) => handler = h,
351            }
352
353            // the handler could not be replaced so we need to build a new handler instead
354            if let Err(err) = guard.unregister(&mut self.stream) {
355                tracing::debug!("failed to unregister previous token - {}", err);
356            }
357        }
358
359        let guard = InterestGuard::new(
360            &self.selector,
361            handler,
362            &mut self.stream,
363            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
364        )
365        .map_err(io_err_into_net_error)?;
366
367        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
368
369        Ok(())
370    }
371
372    fn addr_local(&self) -> Result<SocketAddr> {
373        self.stream.local_addr().map_err(io_err_into_net_error)
374    }
375
376    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
377        self.stream
378            .set_ttl(ttl as u32)
379            .map_err(io_err_into_net_error)
380    }
381
382    fn ttl(&self) -> Result<u8> {
383        self.stream
384            .ttl()
385            .map(|ttl| ttl as u8)
386            .map_err(io_err_into_net_error)
387    }
388}
389
390impl LocalTcpListener {
391    fn split_borrow(
392        &mut self,
393    ) -> (
394        &mut HandlerGuardState,
395        &Arc<Selector>,
396        &mut mio::net::TcpListener,
397    ) {
398        (&mut self.handler_guard, &self.selector, &mut self.stream)
399    }
400}
401
402impl VirtualIoSource for LocalTcpListener {
403    fn remove_handler(&mut self) {
404        let mut guard = HandlerGuardState::None;
405        std::mem::swap(&mut guard, &mut self.handler_guard);
406        match guard {
407            HandlerGuardState::ExternalHandler(mut guard) => {
408                guard.unregister(&mut self.stream).ok();
409            }
410            HandlerGuardState::WakerMap(mut guard, _) => {
411                guard.unregister(&mut self.stream).ok();
412            }
413            HandlerGuardState::None => {}
414        }
415    }
416
417    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
418        if !self.backlog.is_empty() {
419            return Poll::Ready(Ok(self.backlog.len()));
420        }
421
422        let (state, selector, source) = self.split_borrow();
423        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
424        map.add(InterestType::Readable, cx.waker());
425
426        if let Ok(child) = self.try_accept_internal() {
427            self.backlog.push_back(child);
428            return Poll::Ready(Ok(1));
429        }
430        Poll::Pending
431    }
432
433    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
434        if !self.backlog.is_empty() {
435            return Poll::Ready(Ok(self.backlog.len()));
436        }
437
438        let (state, selector, source) = self.split_borrow();
439        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
440        map.add(InterestType::Writable, cx.waker());
441
442        if let Ok(child) = self.try_accept_internal() {
443            self.backlog.push_back(child);
444            return Poll::Ready(Ok(1));
445        }
446        Poll::Pending
447    }
448}
449
/// A TCP socket bound to a local address that has not yet been turned into a
/// listener (`listen`) or a stream (`connect`).
#[derive(Debug)]
pub struct LocalTcpBoundSocket {
    // `Some` until `listen` or `connect` takes it; afterwards the methods
    // that need the socket fail with `NetworkError::InvalidFd`.
    socket: Option<socket2::Socket>,
    selector: Arc<Selector>,
    // Firewall rules inherited from `LocalNetworking`; checked by `connect`
    // and passed on to the listener created by `listen`.
    ruleset: Option<Ruleset>,
}
456
457impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
458    fn addr_local(&self) -> Result<SocketAddr> {
459        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
460        let addr = socket.local_addr().map_err(io_err_into_net_error)?;
461        sock_addr_into_socket_addr(addr)
462    }
463
464    fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
465        let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
466        socket
467            .listen(LISTEN_BACKLOG)
468            .map_err(io_err_into_net_error)?;
469        let listener = mio::net::TcpListener::from_std(socket.into());
470        Ok(Box::new(LocalTcpListener {
471            stream: listener,
472            selector: self.selector.clone(),
473            handler_guard: HandlerGuardState::None,
474            no_delay: None,
475            keep_alive: None,
476            backlog: Default::default(),
477            ruleset: self.ruleset.clone(),
478        }))
479    }
480
481    fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
482        if let Some(ruleset) = self.ruleset.as_ref()
483            && !ruleset.allows_socket(peer, Direction::Outbound)
484        {
485            tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
486            return Err(NetworkError::PermissionDenied);
487        }
488
489        let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
490        if let Err(err) = socket.connect(&peer.into())
491            && !tcp_connect_in_progress(&err)
492        {
493            return Err(io_err_into_net_error(err));
494        }
495
496        let stream = mio::net::TcpStream::from_std(socket.into());
497        if let Ok(p) = stream.peer_addr() {
498            peer = p;
499        }
500        Ok(Box::new(LocalTcpStream::new(
501            self.selector.clone(),
502            stream,
503            peer,
504        )))
505    }
506
507    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
508        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
509        match self.addr_local()?.ip() {
510            IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
511            IpAddr::V6(_) => socket
512                .set_unicast_hops_v6(ttl)
513                .map_err(io_err_into_net_error),
514        }
515    }
516
517    fn ttl(&self) -> Result<u32> {
518        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
519        match self.addr_local()?.ip() {
520            IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
521            IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
522        }
523    }
524}
525
/// Cached outcome of a stream's connect, maintained by
/// `LocalTcpStream::status` and `LocalTcpStream::last_error`.
#[derive(Debug)]
enum ConnectState {
    /// Not determined yet; `status` queries the OS socket.
    Unknown,
    /// The peer address was readable, so the connection is established.
    Opened,
    /// The connect failed; `last_error` returns this error once.
    Failed(NetworkError),
    /// A failure that `last_error` has already reported.
    Reset,
}
533
/// A TCP stream (connected or still connecting) backed by a `mio` socket
/// whose readiness is tracked through the shared selector.
#[derive(Debug)]
pub struct LocalTcpStream {
    stream: mio::net::TcpStream,
    // Peer address returned by `addr_peer`; fixed at construction.
    addr: SocketAddr,
    // Direction passed to the last successful `shutdown`.
    shutdown: Option<Shutdown>,
    selector: Arc<Selector>,
    // Current readiness registration: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    // Bytes read ahead by `poll_read_ready` that `try_recv` has not returned yet.
    buffer: BytesMut,
    // Cached connect outcome; behind a `Mutex` so `status(&self)` can update it.
    connect_state: Mutex<ConnectState>,
}
544
545impl LocalTcpStream {
546    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
547        #[allow(unused_mut)]
548        let mut ret = Self {
549            stream,
550            addr,
551            shutdown: None,
552            selector,
553            handler_guard: HandlerGuardState::None,
554            buffer: BytesMut::new(),
555            connect_state: Mutex::new(ConnectState::Unknown),
556        };
557
558        // In windows we can not poll the socket as it is not supported and hence
559        // what we do is immediately set the writable flag and relay on `mio` to
560        // refresh that flag when the state changes. In Linux what we do is actually
561        // make a non-blocking `poll` call to determine this state
562        #[cfg(target_os = "windows")]
563        {
564            let (state, selector, socket, _) = ret.split_borrow();
565            if let Ok(map) = state_as_waker_map(state, selector, socket) {
566                map.push(InterestType::Writable);
567            }
568        }
569
570        ret
571    }
572
573    fn with_sock_ref<F, R>(&self, f: F) -> R
574    where
575        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
576    {
577        #[cfg(not(windows))]
578        let r = socket2::SockRef::from(&self.stream);
579
580        #[cfg(windows)]
581        let b = unsafe {
582            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
583        };
584        #[cfg(windows)]
585        let r = socket2::SockRef::from(&b);
586
587        f(r)
588    }
589}
590
591impl VirtualTcpSocket for LocalTcpStream {
592    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
593        Ok(())
594    }
595
596    fn recv_buf_size(&self) -> Result<usize> {
597        Err(NetworkError::Unsupported)
598    }
599
600    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
601        Ok(())
602    }
603
604    fn send_buf_size(&self) -> Result<usize> {
605        Err(NetworkError::Unsupported)
606    }
607
608    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
609        self.stream
610            .set_nodelay(nodelay)
611            .map_err(io_err_into_net_error)
612    }
613
614    fn nodelay(&self) -> Result<bool> {
615        self.stream.nodelay().map_err(io_err_into_net_error)
616    }
617
618    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
619        self.with_sock_ref(|s| s.set_keepalive(true))
620            .map_err(io_err_into_net_error)?;
621        Ok(())
622    }
623
624    fn keepalive(&self) -> Result<bool> {
625        let ret = self
626            .with_sock_ref(|s| s.keepalive())
627            .map_err(io_err_into_net_error)?;
628        Ok(ret)
629    }
630
631    #[cfg(not(target_os = "windows"))]
632    fn set_dontroute(&mut self, val: bool) -> Result<()> {
633        // TODO:
634        // Don't route is being set by WASIX which breaks networking
635        // Why this is being set is unknown but we need to disable
636        // the functionality for now as it breaks everything
637
638        let val = val as libc::c_int;
639        let payload = &val as *const libc::c_int as *const libc::c_void;
640        let err = unsafe {
641            libc::setsockopt(
642                self.stream.as_raw_fd(),
643                libc::SOL_SOCKET,
644                libc::SO_DONTROUTE,
645                payload,
646                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
647            )
648        };
649        if err == -1 {
650            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
651        }
652        Ok(())
653    }
654    #[cfg(target_os = "windows")]
655    fn set_dontroute(&mut self, val: bool) -> Result<()> {
656        Err(NetworkError::Unsupported)
657    }
658
659    #[cfg(not(target_os = "windows"))]
660    fn dontroute(&self) -> Result<bool> {
661        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
662        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
663        let err = unsafe {
664            libc::getsockopt(
665                self.stream.as_raw_fd(),
666                libc::SOL_SOCKET,
667                libc::SO_DONTROUTE,
668                payload.as_mut_ptr().cast(),
669                &mut len,
670            )
671        };
672        if err == -1 {
673            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
674        }
675        Ok(unsafe { payload.assume_init() != 0 })
676    }
677    #[cfg(target_os = "windows")]
678    fn dontroute(&self) -> Result<bool> {
679        Err(NetworkError::Unsupported)
680    }
681
682    fn addr_peer(&self) -> Result<SocketAddr> {
683        Ok(self.addr)
684    }
685
686    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
687        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
688        self.shutdown = Some(how);
689        Ok(())
690    }
691
692    fn is_closed(&self) -> bool {
693        false
694    }
695}
696
697impl VirtualConnectedSocket for LocalTcpStream {
698    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
699        self.with_sock_ref(|s| s.set_linger(linger))
700            .map_err(io_err_into_net_error)?;
701        Ok(())
702    }
703
704    fn linger(&self) -> Result<Option<Duration>> {
705        self.with_sock_ref(|s| s.linger())
706            .map_err(io_err_into_net_error)
707    }
708
709    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
710        let ret = self.stream.write(data).map_err(io_err_into_net_error);
711        match &ret {
712            Ok(0) | Err(NetworkError::WouldBlock) => {
713                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
714                    map.pop(InterestType::Writable);
715                }
716            }
717            _ => {}
718        }
719        ret
720    }
721
722    fn try_flush(&mut self) -> Result<()> {
723        self.stream.flush().map_err(io_err_into_net_error)
724    }
725
726    fn close(&mut self) -> Result<()> {
727        Ok(())
728    }
729
730    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
731        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
732        if !self.buffer.is_empty() {
733            let amt = buf.len().min(self.buffer.len());
734            buf[..amt].copy_from_slice(&self.buffer[..amt]);
735            if !peek {
736                self.buffer.advance(amt);
737            }
738            return Ok(amt);
739        }
740
741        if peek {
742            self.stream.peek(buf)
743        } else {
744            self.stream.read(buf)
745        }
746        .map_err(io_err_into_net_error)
747    }
748}
749
750impl VirtualSocket for LocalTcpStream {
751    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
752        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
753    }
754
755    fn ttl(&self) -> Result<u32> {
756        self.stream.ttl().map_err(io_err_into_net_error)
757    }
758
759    fn addr_local(&self) -> Result<SocketAddr> {
760        self.stream.local_addr().map_err(io_err_into_net_error)
761    }
762
763    fn status(&self) -> Result<SocketStatus> {
764        // `take_error()` consumes the latched socket error, so once the
765        // connect resolves we cache the terminal state to keep status() stable.
766        let mut connect_state = self.connect_state.lock().unwrap();
767        match *connect_state {
768            ConnectState::Opened => return Ok(SocketStatus::Opened),
769            ConnectState::Failed(_) | ConnectState::Reset => {
770                return Ok(SocketStatus::Failed);
771            }
772            ConnectState::Unknown => {}
773        }
774
775        if let Some(err) = self
776            .with_sock_ref(|sockref| sockref.take_error())
777            .map_err(io_err_into_net_error)?
778        {
779            *connect_state = ConnectState::Failed(io_err_into_net_error(err));
780            return Ok(SocketStatus::Failed); // connect error on the socket
781        }
782        match self.stream.peer_addr() {
783            Ok(_) => {
784                *connect_state = ConnectState::Opened;
785                Ok(SocketStatus::Opened) // TCP handshake completed.
786            }
787            Err(err) => {
788                if matches!(
789                    err.kind(),
790                    io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
791                ) {
792                    Ok(SocketStatus::Opening) // The connect is still in progress
793                } else {
794                    *connect_state = ConnectState::Failed(io_err_into_net_error(err));
795                    Ok(SocketStatus::Failed) // Any other error means the socket is unusable
796                }
797            }
798        }
799    }
800
801    fn last_error(&self) -> Result<Option<NetworkError>> {
802        let mut connect_state = self.connect_state.lock().unwrap();
803        Ok(match *connect_state {
804            ConnectState::Failed(err) => {
805                *connect_state = ConnectState::Reset;
806                Some(err)
807            }
808            _ => None,
809        })
810    }
811
812    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
813        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
814            match guard.replace_handler(handler) {
815                Ok(()) => return Ok(()),
816                Err(h) => handler = h,
817            }
818
819            // the handler could not be replaced so we need to build a new handler instead
820            if let Err(err) = guard.unregister(&mut self.stream) {
821                tracing::debug!("failed to unregister previous token - {}", err);
822            }
823        }
824
825        let guard = InterestGuard::new(
826            &self.selector,
827            handler,
828            &mut self.stream,
829            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
830        )
831        .map_err(io_err_into_net_error)?;
832
833        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
834
835        Ok(())
836    }
837}
838
839impl LocalTcpStream {
840    fn split_borrow(
841        &mut self,
842    ) -> (
843        &mut HandlerGuardState,
844        &Arc<Selector>,
845        &mut mio::net::TcpStream,
846        &mut BytesMut,
847    ) {
848        (
849            &mut self.handler_guard,
850            &self.selector,
851            &mut self.stream,
852            &mut self.buffer,
853        )
854    }
855}
856
impl VirtualIoSource for LocalTcpStream {
    /// Drops the current readiness registration (external handler or waker
    /// map) and unregisters the stream from the selector, ignoring errors.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Reports read readiness, reading ahead into `self.buffer` when data is
    /// available.
    ///
    /// Returns:
    /// - `Ready(Ok(n))` with the number of buffered or freshly read bytes;
    /// - `Ready(Ok(0))` on EOF or an aborted/reset connection;
    /// - `Pending` when the read would block (the waker is registered first);
    /// - `Ready(Err(..))` for any other I/O error.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        // Read-ahead data from an earlier poll has not been consumed yet.
        if !self.buffer.is_empty() {
            return Poll::Ready(Ok(self.buffer.len()));
        }

        let (state, selector, stream, buffer) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Discard readable state previously recorded in the map, then
        // register this task's waker before attempting the read.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // `buffer` is empty at this point, so this reserves at least 10240 bytes.
        buffer.reserve(buffer.len() + 10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // SAFETY: `MaybeUninit<u8>` and `u8` share layout; the slice is only
        // handed to `read`, and only the `amt` bytes it reports are exposed
        // through `set_len` below. NOTE(review): forming `&mut [u8]` over
        // uninitialized memory is still formally questionable.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match stream.read(uninit_unsafe) {
            // EOF.
            Ok(0) => Poll::Ready(Ok(0)),
            Ok(amt) => {
                // SAFETY: `read` initialized the first `amt` spare bytes.
                unsafe {
                    buffer.set_len(buffer.len() + amt);
                }
                Poll::Ready(Ok(amt))
            }
            // Aborted and reset connections are reported like EOF.
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Reports write readiness.
    ///
    /// Returns `Ready(Ok(0))` once a closed event is recorded or the peer has
    /// hung up, `Ready(Ok(10240))` when the stream appears writable (a fixed
    /// nominal figure, not a measured send-buffer size), and `Pending`
    /// otherwise. The task's waker is registered for writable and closed
    /// interest.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, stream, _) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Off Windows writability is re-probed with `libc_poll` below, so any
        // recorded writable state is discarded first.
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());
        map.add(InterestType::Closed, cx.waker());
        // Presumably true once a closed event has been recorded for this source.
        if map.has_interest(InterestType::Closed) {
            return Poll::Ready(Ok(0));
        }

        // Zero-timeout probe: POLLHUP means the peer hung up, POLLOUT means
        // the socket can accept more data.
        #[cfg(not(target_os = "windows"))]
        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // On Windows the socket cannot be polled directly, so the writable
        // flag is set up front and we rely on `mio` to refresh it when the
        // state changes. On other platforms a non-blocking `poll` call
        // determines this state instead.
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}
933
934#[cfg(not(target_os = "windows"))]
935fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
936    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
937        fd,
938        events,
939        revents: 0,
940    }];
941    let fds_mut = &mut fds[..];
942    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
943    match ret == 1 {
944        true => Some(fds[0].revents),
945        false => None,
946    }
947}
948
/// A UDP socket backed by a host `mio::net::UdpSocket`.
///
/// Received datagrams are queued in `backlog`, so readiness checks can receive
/// ahead of `try_recv_from` without losing data.
#[derive(Debug)]
pub struct LocalUdpSocket {
    /// Underlying host socket used for all I/O on this type.
    socket: mio::net::UdpSocket,
    // NOTE(review): not read anywhere in this chunk (hence `dead_code`);
    // presumably the address the socket was bound with — confirm at construction.
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Selector the socket is registered with when a handler or waker map is installed.
    selector: Arc<Selector>,
    /// Current readiness registration: none, an external handler, or the internal waker map.
    handler_guard: HandlerGuardState,
    /// Datagrams (payload, sender) received from the socket but not yet consumed.
    backlog: VecDeque<(BytesMut, SocketAddr)>,
    /// Optional firewall rules; consulted for outbound traffic in `try_send_to`.
    ruleset: Option<Ruleset>,
}
959
960impl LocalUdpSocket {
961    fn with_sock_ref<F, R>(&self, f: F) -> R
962    where
963        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
964    {
965        #[cfg(not(windows))]
966        let r = socket2::SockRef::from(&self.socket);
967
968        #[cfg(windows)]
969        let b = unsafe {
970            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
971        };
972        #[cfg(windows)]
973        let r = socket2::SockRef::from(&b);
974
975        f(r)
976    }
977}
978
979impl VirtualUdpSocket for LocalUdpSocket {
980    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
981        self.socket
982            .set_broadcast(broadcast)
983            .map_err(io_err_into_net_error)
984    }
985
986    fn broadcast(&self) -> Result<bool> {
987        self.socket.broadcast().map_err(io_err_into_net_error)
988    }
989
990    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
991        self.socket
992            .set_multicast_loop_v4(val)
993            .map_err(io_err_into_net_error)
994    }
995
996    fn multicast_loop_v4(&self) -> Result<bool> {
997        self.socket
998            .multicast_loop_v4()
999            .map_err(io_err_into_net_error)
1000    }
1001
1002    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1003        self.socket
1004            .set_multicast_loop_v6(val)
1005            .map_err(io_err_into_net_error)
1006    }
1007
1008    fn multicast_loop_v6(&self) -> Result<bool> {
1009        self.socket
1010            .multicast_loop_v6()
1011            .map_err(io_err_into_net_error)
1012    }
1013
1014    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1015        self.socket
1016            .set_multicast_ttl_v4(ttl)
1017            .map_err(io_err_into_net_error)
1018    }
1019
1020    fn multicast_ttl_v4(&self) -> Result<u32> {
1021        self.socket
1022            .multicast_ttl_v4()
1023            .map_err(io_err_into_net_error)
1024    }
1025
1026    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1027        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
1028            .map_err(io_err_into_net_error)
1029    }
1030
1031    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1032        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
1033            .map_err(io_err_into_net_error)
1034    }
1035
1036    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1037        self.socket
1038            .join_multicast_v6(&multiaddr, iface)
1039            .map_err(io_err_into_net_error)
1040    }
1041
1042    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1043        self.socket
1044            .leave_multicast_v6(&multiaddr, iface)
1045            .map_err(io_err_into_net_error)
1046    }
1047
1048    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1049        self.socket
1050            .peer_addr()
1051            .map(Some)
1052            .map_err(io_err_into_net_error)
1053    }
1054}
1055
1056impl VirtualConnectionlessSocket for LocalUdpSocket {
1057    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1058        if let Some(ruleset) = self.ruleset.as_ref()
1059            && !ruleset.allows_socket(addr, Direction::Outbound)
1060        {
1061            tracing::warn!(%addr, "try_send blocked by firewall rule");
1062            return Err(NetworkError::PermissionDenied);
1063        }
1064
1065        let ret = self
1066            .socket
1067            .send_to(data, addr)
1068            .map_err(io_err_into_net_error);
1069        match &ret {
1070            Ok(0) | Err(NetworkError::WouldBlock) => {
1071                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
1072                    map.pop(InterestType::Writable);
1073                }
1074            }
1075            _ => {}
1076        }
1077        ret
1078    }
1079
1080    fn try_recv_from(
1081        &mut self,
1082        buf: &mut [MaybeUninit<u8>],
1083        peek: bool,
1084    ) -> Result<(usize, SocketAddr)> {
1085        if self.backlog.is_empty() {
1086            self.recv_into_backlog().map_err(io_err_into_net_error)?;
1087        }
1088
1089        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1090        let Some((packet, addr)) = self.backlog.front() else {
1091            return Err(NetworkError::WouldBlock);
1092        };
1093        let amt = buf.len().min(packet.len());
1094        let addr = *addr;
1095        buf[..amt].copy_from_slice(&packet[..amt]);
1096        if !peek {
1097            self.backlog.pop_front();
1098        }
1099        Ok((amt, addr))
1100    }
1101}
1102
1103impl VirtualSocket for LocalUdpSocket {
1104    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1105        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
1106    }
1107
1108    fn ttl(&self) -> Result<u32> {
1109        self.socket.ttl().map_err(io_err_into_net_error)
1110    }
1111
1112    fn addr_local(&self) -> Result<SocketAddr> {
1113        self.socket.local_addr().map_err(io_err_into_net_error)
1114    }
1115
1116    fn status(&self) -> Result<SocketStatus> {
1117        Ok(SocketStatus::Opened)
1118    }
1119
1120    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
1121        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
1122            match guard.replace_handler(handler) {
1123                Ok(()) => {
1124                    return Ok(());
1125                }
1126                Err(h) => handler = h,
1127            }
1128
1129            // the handler could not be replaced so we need to build a new handler instead
1130            if let Err(err) = guard.unregister(&mut self.socket) {
1131                tracing::debug!("failed to unregister previous token - {}", err);
1132            }
1133        }
1134
1135        let guard = InterestGuard::new(
1136            &self.selector,
1137            handler,
1138            &mut self.socket,
1139            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
1140        )
1141        .map_err(io_err_into_net_error)?;
1142
1143        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
1144
1145        Ok(())
1146    }
1147}
1148
1149impl LocalUdpSocket {
1150    /// A queued zero-length UDP datagram is readable (Linux `POLLIN`), but
1151    /// wasix poll treats `nbytes == 0` as hangup, so report at least 1.
1152    fn backlog_read_ready_len(&self) -> usize {
1153        self.backlog
1154            .front()
1155            .map(|(packet, _)| packet.len().max(1))
1156            .unwrap_or_default()
1157    }
1158
1159    fn recv_into_backlog(&mut self) -> io::Result<()> {
1160        let mut buffer = BytesMut::default();
1161        buffer.reserve(MAX_SOCKET_PAYLOAD);
1162        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
1163        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
1164
1165        let (amt, peer) = self.socket.recv_from(uninit_unsafe)?;
1166        unsafe {
1167            buffer.set_len(amt);
1168        }
1169        self.backlog.push_back((buffer, peer));
1170        Ok(())
1171    }
1172
1173    fn split_borrow(
1174        &mut self,
1175    ) -> (
1176        &mut HandlerGuardState,
1177        &Arc<Selector>,
1178        &mut mio::net::UdpSocket,
1179    ) {
1180        (&mut self.handler_guard, &self.selector, &mut self.socket)
1181    }
1182}
1183
1184impl VirtualIoSource for LocalUdpSocket {
1185    fn remove_handler(&mut self) {
1186        let mut guard = HandlerGuardState::None;
1187        std::mem::swap(&mut guard, &mut self.handler_guard);
1188        match guard {
1189            HandlerGuardState::ExternalHandler(mut guard) => {
1190                guard.unregister(&mut self.socket).ok();
1191            }
1192            HandlerGuardState::WakerMap(mut guard, _) => {
1193                guard.unregister(&mut self.socket).ok();
1194            }
1195            HandlerGuardState::None => {}
1196        }
1197    }
1198
1199    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1200        if !self.backlog.is_empty() {
1201            return Poll::Ready(Ok(self.backlog_read_ready_len()));
1202        }
1203
1204        let (state, selector, socket) = self.split_borrow();
1205        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1206        map.pop(InterestType::Readable);
1207        map.add(InterestType::Readable, cx.waker());
1208
1209        match self.recv_into_backlog() {
1210            Ok(()) => Poll::Ready(Ok(self.backlog_read_ready_len())),
1211            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1212            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1213            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1214            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1215        }
1216    }
1217
1218    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1219        let (state, selector, socket) = self.split_borrow();
1220        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1221        #[cfg(not(target_os = "windows"))]
1222        map.pop(InterestType::Writable);
1223        map.add(InterestType::Writable, cx.waker());
1224
1225        #[cfg(not(target_os = "windows"))]
1226        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1227            Some(val) if (val & libc::POLLHUP) != 0 => {
1228                return Poll::Ready(Ok(0));
1229            }
1230            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1231            _ => {}
1232        }
1233
1234        // In windows we can not poll the socket as it is not supported and hence
1235        // what we do is immediately set the writable flag and relay on `mio` to
1236        // refresh that flag when the state changes. In Linux what we do is actually
1237        // make a non-blocking `poll` call to determine this state
1238        #[cfg(target_os = "windows")]
1239        if map.has_interest(InterestType::Writable) {
1240            return Poll::Ready(Ok(10240));
1241        }
1242
1243        Poll::Pending
1244    }
1245}