virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6    VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7    VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::{Arc, Mutex};
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
/// Backlog passed to `listen` by `LocalTcpBoundSocket::listen`.
///
/// Uses the platform's maximum listen backlog (`SOMAXCONN`) where available so
/// that `LocalTcpBoundSocket::listen` preserves the same accept capacity as
/// the previous `std::net::TcpListener`-based implementation; every other
/// configuration falls back to a fixed backlog of 128.
// NOTE(review): elsewhere in this file `libc` is used on every non-Windows
// target without a `feature = "libc"` gate. If the crate defines no `libc`
// feature, the first branch can never be selected and 128 is always used —
// confirm against Cargo.toml.
#[cfg(all(target_family = "unix", feature = "libc"))]
const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
#[cfg(not(all(target_family = "unix", feature = "libc")))]
const LISTEN_BACKLOG: i32 = 128;
36const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
37#[cfg(not(all(target_family = "unix", feature = "libc")))]
38const LISTEN_BACKLOG: i32 = 128;
39
40#[derive(Debug)]
41pub struct LocalNetworking {
42    selector: Arc<Selector>,
43    handle: Handle,
44    ruleset: Option<Ruleset>,
45}
46
47impl LocalNetworking {
48    pub fn new() -> Self {
49        Self {
50            selector: Selector::new(),
51            handle: Handle::current(),
52            ruleset: None,
53        }
54    }
55
56    pub fn with_ruleset(ruleset: Ruleset) -> Self {
57        Self {
58            selector: Selector::new(),
59            handle: Handle::current(),
60            ruleset: Some(ruleset),
61        }
62    }
63}
64
65impl Drop for LocalNetworking {
66    fn drop(&mut self) {
67        self.selector.shutdown();
68    }
69}
70
71impl Default for LocalNetworking {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
78    addr.as_socket().ok_or(NetworkError::UnknownError)
79}
80
81fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
82    if addr.is_ipv4() {
83        socket2::Domain::IPV4
84    } else {
85        socket2::Domain::IPV6
86    }
87}
88
89#[allow(clippy::needless_bool)]
90fn tcp_connect_in_progress(err: &io::Error) -> bool {
91    if matches!(
92        err.kind(),
93        io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
94    ) {
95        true
96    } else {
97        #[cfg(all(target_family = "unix", feature = "libc"))]
98        {
99            matches!(
100                err.raw_os_error(),
101                Some(raw) if raw == libc::EINPROGRESS || raw == libc::EALREADY
102            )
103        }
104
105        #[cfg(not(all(target_family = "unix", feature = "libc")))]
106        {
107            false
108        }
109    }
110}
111
112#[async_trait::async_trait]
113#[allow(unused_variables)]
114impl VirtualNetworking for LocalNetworking {
115    async fn listen_tcp(
116        &self,
117        addr: SocketAddr,
118        only_v6: bool,
119        reuse_port: bool,
120        reuse_addr: bool,
121    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
122        if let Some(ruleset) = self.ruleset.as_ref()
123            && !ruleset.allows_socket(addr, Direction::Inbound)
124        {
125            tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
126            return Err(NetworkError::PermissionDenied);
127        }
128
129        self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
130            .await?
131            .listen()
132    }
133
134    async fn bind_tcp(
135        &self,
136        addr: SocketAddr,
137        only_v6: bool,
138        reuse_port: bool,
139        reuse_addr: bool,
140    ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
141        if let Some(ruleset) = self.ruleset.as_ref()
142            && !ruleset.allows_socket(addr, Direction::Inbound)
143        {
144            tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
145            return Err(NetworkError::PermissionDenied);
146        }
147
148        let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None)
149            .map_err(io_err_into_net_error)?;
150        socket
151            .set_nonblocking(true)
152            .map_err(io_err_into_net_error)?;
153        if addr.is_ipv6() {
154            socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
155        }
156        socket
157            .set_reuse_address(reuse_addr)
158            .map_err(io_err_into_net_error)?;
159        #[cfg(not(windows))]
160        socket
161            .set_reuse_port(reuse_port)
162            .map_err(io_err_into_net_error)?;
163        socket.bind(&addr.into()).map_err(io_err_into_net_error)?;
164
165        Ok(Box::new(LocalTcpBoundSocket {
166            socket: Some(socket),
167            selector: self.selector.clone(),
168            ruleset: self.ruleset.clone(),
169        }))
170    }
171
172    async fn bind_udp(
173        &self,
174        addr: SocketAddr,
175        reuse_port: bool,
176        reuse_addr: bool,
177    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
178        #[cfg(not(windows))]
179        use socket2::{Domain, Socket, Type};
180
181        if let Some(ruleset) = self.ruleset.as_ref()
182            && !ruleset.allows_socket(addr, Direction::Inbound)
183        {
184            tracing::warn!(%addr, "bind_udp blocked by firewall rule");
185            return Err(NetworkError::PermissionDenied);
186        }
187
188        #[cfg(not(windows))]
189        let socket = {
190            let domain = if addr.is_ipv4() {
191                Domain::IPV4
192            } else {
193                Domain::IPV6
194            };
195            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
196            std_sock
197                .set_nonblocking(true)
198                .map_err(io_err_into_net_error)?;
199            std_sock
200                .set_reuse_address(reuse_addr)
201                .map_err(io_err_into_net_error)?;
202            std_sock
203                .set_reuse_port(reuse_port)
204                .map_err(io_err_into_net_error)?;
205            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
206            mio::net::UdpSocket::from_std(std_sock.into())
207        };
208        #[cfg(windows)]
209        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
210
211        #[allow(unused_mut)]
212        let mut ret = LocalUdpSocket {
213            selector: self.selector.clone(),
214            socket,
215            addr,
216            handler_guard: HandlerGuardState::None,
217            backlog: Default::default(),
218            ruleset: self.ruleset.clone(),
219        };
220
221        // In windows we can not poll the socket as it is not supported and hence
222        // what we do is immediately set the writable flag and relay on `mio` to
223        // refresh that flag when the state changes. In Linux what we do is actually
224        // make a non-blocking `poll` call to determine this state
225        #[cfg(target_os = "windows")]
226        {
227            let (state, selector, socket) = ret.split_borrow();
228            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
229            map.push(InterestType::Writable);
230        }
231
232        Ok(Box::new(ret))
233    }
234
235    async fn connect_tcp(
236        &self,
237        _addr: SocketAddr,
238        mut peer: SocketAddr,
239    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
240        if let Some(ruleset) = self.ruleset.as_ref()
241            && !ruleset.allows_socket(peer, Direction::Outbound)
242        {
243            tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
244            return Err(NetworkError::PermissionDenied);
245        }
246
247        let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
248
249        if let Ok(p) = stream.peer_addr() {
250            peer = p;
251        }
252        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
253        Ok(socket)
254    }
255
256    async fn resolve(
257        &self,
258        host: &str,
259        port: Option<u16>,
260        dns_server: Option<IpAddr>,
261    ) -> Result<Vec<IpAddr>> {
262        if let Some(ruleset) = self.ruleset.as_ref()
263            && !ruleset.allows_domain(host)
264        {
265            tracing::warn!(%host, "dns resolve blocked by firewall rule");
266            return Err(NetworkError::PermissionDenied);
267        }
268
269        let host_to_lookup = if host.contains(':') {
270            host.to_string()
271        } else {
272            format!("{}:{}", host, port.unwrap_or(0))
273        };
274        let addrs = self
275            .handle
276            .spawn(tokio::net::lookup_host(host_to_lookup))
277            .await
278            .map_err(|_| NetworkError::IOError)?
279            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
280            .map_err(io_err_into_net_error)?;
281
282        if let Some(ruleset) = self.ruleset.as_ref() {
283            if let Err(e) = ruleset.expand_domain(host, &addrs) {
284                tracing::debug!(err=%e, "ruleset expansion failed");
285            } else {
286                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
287            }
288        }
289
290        Ok(addrs)
291    }
292}
293
294#[derive(Debug)]
295pub struct LocalTcpListener {
296    stream: mio::net::TcpListener,
297    selector: Arc<Selector>,
298    handler_guard: HandlerGuardState,
299    no_delay: Option<bool>,
300    keep_alive: Option<bool>,
301    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
302    ruleset: Option<Ruleset>,
303}
304
305impl LocalTcpListener {
306    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
307        match self.stream.accept().map_err(io_err_into_net_error) {
308            Ok((stream, addr)) => {
309                if let Some(ruleset) = self.ruleset.as_ref()
310                    && !ruleset.allows_socket(addr, Direction::Outbound)
311                {
312                    tracing::warn!(%addr, "try_accept blocked by firewall rule");
313                    return Err(NetworkError::PermissionDenied);
314                }
315
316                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
317                if let Some(no_delay) = self.no_delay {
318                    socket.set_nodelay(no_delay).ok();
319                }
320                if let Some(keep_alive) = self.keep_alive {
321                    socket.set_keepalive(keep_alive).ok();
322                }
323                Ok((Box::new(socket), addr))
324            }
325            Err(NetworkError::WouldBlock) => {
326                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
327                    map.pop(InterestType::Readable);
328                    map.pop(InterestType::Writable);
329                }
330                Err(NetworkError::WouldBlock)
331            }
332            Err(err) => Err(err),
333        }
334    }
335}
336
337impl VirtualTcpListener for LocalTcpListener {
338    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
339        if let Some(child) = self.backlog.pop_front() {
340            return Ok(child);
341        }
342        self.try_accept_internal()
343    }
344
345    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
346        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
347            match guard.replace_handler(handler) {
348                Ok(()) => return Ok(()),
349                Err(h) => handler = h,
350            }
351
352            // the handler could not be replaced so we need to build a new handler instead
353            if let Err(err) = guard.unregister(&mut self.stream) {
354                tracing::debug!("failed to unregister previous token - {}", err);
355            }
356        }
357
358        let guard = InterestGuard::new(
359            &self.selector,
360            handler,
361            &mut self.stream,
362            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
363        )
364        .map_err(io_err_into_net_error)?;
365
366        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
367
368        Ok(())
369    }
370
371    fn addr_local(&self) -> Result<SocketAddr> {
372        self.stream.local_addr().map_err(io_err_into_net_error)
373    }
374
375    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
376        self.stream
377            .set_ttl(ttl as u32)
378            .map_err(io_err_into_net_error)
379    }
380
381    fn ttl(&self) -> Result<u8> {
382        self.stream
383            .ttl()
384            .map(|ttl| ttl as u8)
385            .map_err(io_err_into_net_error)
386    }
387}
388
389impl LocalTcpListener {
390    fn split_borrow(
391        &mut self,
392    ) -> (
393        &mut HandlerGuardState,
394        &Arc<Selector>,
395        &mut mio::net::TcpListener,
396    ) {
397        (&mut self.handler_guard, &self.selector, &mut self.stream)
398    }
399}
400
401impl VirtualIoSource for LocalTcpListener {
402    fn remove_handler(&mut self) {
403        let mut guard = HandlerGuardState::None;
404        std::mem::swap(&mut guard, &mut self.handler_guard);
405        match guard {
406            HandlerGuardState::ExternalHandler(mut guard) => {
407                guard.unregister(&mut self.stream).ok();
408            }
409            HandlerGuardState::WakerMap(mut guard, _) => {
410                guard.unregister(&mut self.stream).ok();
411            }
412            HandlerGuardState::None => {}
413        }
414    }
415
416    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
417        if !self.backlog.is_empty() {
418            return Poll::Ready(Ok(self.backlog.len()));
419        }
420
421        let (state, selector, source) = self.split_borrow();
422        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
423        map.add(InterestType::Readable, cx.waker());
424
425        if let Ok(child) = self.try_accept_internal() {
426            self.backlog.push_back(child);
427            return Poll::Ready(Ok(1));
428        }
429        Poll::Pending
430    }
431
432    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
433        if !self.backlog.is_empty() {
434            return Poll::Ready(Ok(self.backlog.len()));
435        }
436
437        let (state, selector, source) = self.split_borrow();
438        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
439        map.add(InterestType::Writable, cx.waker());
440
441        if let Ok(child) = self.try_accept_internal() {
442            self.backlog.push_back(child);
443            return Poll::Ready(Ok(1));
444        }
445        Poll::Pending
446    }
447}
448
449#[derive(Debug)]
450pub struct LocalTcpBoundSocket {
451    socket: Option<socket2::Socket>,
452    selector: Arc<Selector>,
453    ruleset: Option<Ruleset>,
454}
455
456impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
457    fn addr_local(&self) -> Result<SocketAddr> {
458        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
459        let addr = socket.local_addr().map_err(io_err_into_net_error)?;
460        sock_addr_into_socket_addr(addr)
461    }
462
463    fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
464        let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
465        socket
466            .listen(LISTEN_BACKLOG)
467            .map_err(io_err_into_net_error)?;
468        let listener = mio::net::TcpListener::from_std(socket.into());
469        Ok(Box::new(LocalTcpListener {
470            stream: listener,
471            selector: self.selector.clone(),
472            handler_guard: HandlerGuardState::None,
473            no_delay: None,
474            keep_alive: None,
475            backlog: Default::default(),
476            ruleset: self.ruleset.clone(),
477        }))
478    }
479
480    fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
481        if let Some(ruleset) = self.ruleset.as_ref()
482            && !ruleset.allows_socket(peer, Direction::Outbound)
483        {
484            tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
485            return Err(NetworkError::PermissionDenied);
486        }
487
488        let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
489        if let Err(err) = socket.connect(&peer.into())
490            && !tcp_connect_in_progress(&err)
491        {
492            return Err(io_err_into_net_error(err));
493        }
494
495        let stream = mio::net::TcpStream::from_std(socket.into());
496        if let Ok(p) = stream.peer_addr() {
497            peer = p;
498        }
499        Ok(Box::new(LocalTcpStream::new(
500            self.selector.clone(),
501            stream,
502            peer,
503        )))
504    }
505
506    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
507        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
508        match self.addr_local()?.ip() {
509            IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
510            IpAddr::V6(_) => socket
511                .set_unicast_hops_v6(ttl)
512                .map_err(io_err_into_net_error),
513        }
514    }
515
516    fn ttl(&self) -> Result<u32> {
517        let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
518        match self.addr_local()?.ip() {
519            IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
520            IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
521        }
522    }
523}
524
525#[derive(Debug)]
526enum ConnectState {
527    Unknown,
528    Opened,
529    Failed,
530}
531
532#[derive(Debug)]
533pub struct LocalTcpStream {
534    stream: mio::net::TcpStream,
535    addr: SocketAddr,
536    shutdown: Option<Shutdown>,
537    selector: Arc<Selector>,
538    handler_guard: HandlerGuardState,
539    buffer: BytesMut,
540    connect_state: Mutex<ConnectState>,
541}
542
543impl LocalTcpStream {
544    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
545        #[allow(unused_mut)]
546        let mut ret = Self {
547            stream,
548            addr,
549            shutdown: None,
550            selector,
551            handler_guard: HandlerGuardState::None,
552            buffer: BytesMut::new(),
553            connect_state: Mutex::new(ConnectState::Unknown),
554        };
555
556        // In windows we can not poll the socket as it is not supported and hence
557        // what we do is immediately set the writable flag and relay on `mio` to
558        // refresh that flag when the state changes. In Linux what we do is actually
559        // make a non-blocking `poll` call to determine this state
560        #[cfg(target_os = "windows")]
561        {
562            let (state, selector, socket, _) = ret.split_borrow();
563            if let Ok(map) = state_as_waker_map(state, selector, socket) {
564                map.push(InterestType::Writable);
565            }
566        }
567
568        ret
569    }
570
571    fn with_sock_ref<F, R>(&self, f: F) -> R
572    where
573        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
574    {
575        #[cfg(not(windows))]
576        let r = socket2::SockRef::from(&self.stream);
577
578        #[cfg(windows)]
579        let b = unsafe {
580            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
581        };
582        #[cfg(windows)]
583        let r = socket2::SockRef::from(&b);
584
585        f(r)
586    }
587}
588
589impl VirtualTcpSocket for LocalTcpStream {
590    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
591        Ok(())
592    }
593
594    fn recv_buf_size(&self) -> Result<usize> {
595        Err(NetworkError::Unsupported)
596    }
597
598    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
599        Ok(())
600    }
601
602    fn send_buf_size(&self) -> Result<usize> {
603        Err(NetworkError::Unsupported)
604    }
605
606    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
607        self.stream
608            .set_nodelay(nodelay)
609            .map_err(io_err_into_net_error)
610    }
611
612    fn nodelay(&self) -> Result<bool> {
613        self.stream.nodelay().map_err(io_err_into_net_error)
614    }
615
616    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
617        self.with_sock_ref(|s| s.set_keepalive(true))
618            .map_err(io_err_into_net_error)?;
619        Ok(())
620    }
621
622    fn keepalive(&self) -> Result<bool> {
623        let ret = self
624            .with_sock_ref(|s| s.keepalive())
625            .map_err(io_err_into_net_error)?;
626        Ok(ret)
627    }
628
629    #[cfg(not(target_os = "windows"))]
630    fn set_dontroute(&mut self, val: bool) -> Result<()> {
631        // TODO:
632        // Don't route is being set by WASIX which breaks networking
633        // Why this is being set is unknown but we need to disable
634        // the functionality for now as it breaks everything
635
636        let val = val as libc::c_int;
637        let payload = &val as *const libc::c_int as *const libc::c_void;
638        let err = unsafe {
639            libc::setsockopt(
640                self.stream.as_raw_fd(),
641                libc::SOL_SOCKET,
642                libc::SO_DONTROUTE,
643                payload,
644                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
645            )
646        };
647        if err == -1 {
648            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
649        }
650        Ok(())
651    }
652    #[cfg(target_os = "windows")]
653    fn set_dontroute(&mut self, val: bool) -> Result<()> {
654        Err(NetworkError::Unsupported)
655    }
656
657    #[cfg(not(target_os = "windows"))]
658    fn dontroute(&self) -> Result<bool> {
659        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
660        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
661        let err = unsafe {
662            libc::getsockopt(
663                self.stream.as_raw_fd(),
664                libc::SOL_SOCKET,
665                libc::SO_DONTROUTE,
666                payload.as_mut_ptr().cast(),
667                &mut len,
668            )
669        };
670        if err == -1 {
671            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
672        }
673        Ok(unsafe { payload.assume_init() != 0 })
674    }
675    #[cfg(target_os = "windows")]
676    fn dontroute(&self) -> Result<bool> {
677        Err(NetworkError::Unsupported)
678    }
679
680    fn addr_peer(&self) -> Result<SocketAddr> {
681        Ok(self.addr)
682    }
683
684    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
685        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
686        self.shutdown = Some(how);
687        Ok(())
688    }
689
690    fn is_closed(&self) -> bool {
691        false
692    }
693}
694
695impl VirtualConnectedSocket for LocalTcpStream {
696    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
697        self.with_sock_ref(|s| s.set_linger(linger))
698            .map_err(io_err_into_net_error)?;
699        Ok(())
700    }
701
702    fn linger(&self) -> Result<Option<Duration>> {
703        self.with_sock_ref(|s| s.linger())
704            .map_err(io_err_into_net_error)
705    }
706
707    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
708        let ret = self.stream.write(data).map_err(io_err_into_net_error);
709        match &ret {
710            Ok(0) | Err(NetworkError::WouldBlock) => {
711                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
712                    map.pop(InterestType::Writable);
713                }
714            }
715            _ => {}
716        }
717        ret
718    }
719
720    fn try_flush(&mut self) -> Result<()> {
721        self.stream.flush().map_err(io_err_into_net_error)
722    }
723
724    fn close(&mut self) -> Result<()> {
725        Ok(())
726    }
727
728    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
729        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
730        if !self.buffer.is_empty() {
731            let amt = buf.len().min(self.buffer.len());
732            buf[..amt].copy_from_slice(&self.buffer[..amt]);
733            if !peek {
734                self.buffer.advance(amt);
735            }
736            return Ok(amt);
737        }
738
739        if peek {
740            self.stream.peek(buf)
741        } else {
742            self.stream.read(buf)
743        }
744        .map_err(io_err_into_net_error)
745    }
746}
747
748impl VirtualSocket for LocalTcpStream {
749    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
750        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
751    }
752
753    fn ttl(&self) -> Result<u32> {
754        self.stream.ttl().map_err(io_err_into_net_error)
755    }
756
757    fn addr_local(&self) -> Result<SocketAddr> {
758        self.stream.local_addr().map_err(io_err_into_net_error)
759    }
760
761    fn status(&self) -> Result<SocketStatus> {
762        // `take_error()` consumes the latched socket error, so once the
763        // connect resolves we cache the terminal state to keep status() stable.
764        let mut connect_state = self.connect_state.lock().unwrap();
765        match *connect_state {
766            ConnectState::Opened => return Ok(SocketStatus::Opened),
767            ConnectState::Failed => return Ok(SocketStatus::Failed),
768            ConnectState::Unknown => {}
769        }
770
771        if self
772            .with_sock_ref(|sockref| sockref.take_error())
773            .map_err(io_err_into_net_error)?
774            .is_some()
775        {
776            *connect_state = ConnectState::Failed;
777            return Ok(SocketStatus::Failed); // connect error on the socket
778        }
779        match self.stream.peer_addr() {
780            Ok(_) => {
781                *connect_state = ConnectState::Opened;
782                Ok(SocketStatus::Opened) // TCP handshake completed.
783            }
784            Err(err) => {
785                if matches!(
786                    err.kind(),
787                    io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
788                ) {
789                    Ok(SocketStatus::Opening) // The connect is still in progress
790                } else {
791                    // TODO: Store the concrete err so we can return it later on
792                    *connect_state = ConnectState::Failed;
793                    Ok(SocketStatus::Failed) // Any other error means the socket is unusable
794                }
795            }
796        }
797    }
798
799    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
800        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
801            match guard.replace_handler(handler) {
802                Ok(()) => return Ok(()),
803                Err(h) => handler = h,
804            }
805
806            // the handler could not be replaced so we need to build a new handler instead
807            if let Err(err) = guard.unregister(&mut self.stream) {
808                tracing::debug!("failed to unregister previous token - {}", err);
809            }
810        }
811
812        let guard = InterestGuard::new(
813            &self.selector,
814            handler,
815            &mut self.stream,
816            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
817        )
818        .map_err(io_err_into_net_error)?;
819
820        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
821
822        Ok(())
823    }
824}
825
826impl LocalTcpStream {
827    fn split_borrow(
828        &mut self,
829    ) -> (
830        &mut HandlerGuardState,
831        &Arc<Selector>,
832        &mut mio::net::TcpStream,
833        &mut BytesMut,
834    ) {
835        (
836            &mut self.handler_guard,
837            &self.selector,
838            &mut self.stream,
839            &mut self.buffer,
840        )
841    }
842}
843
844impl VirtualIoSource for LocalTcpStream {
845    fn remove_handler(&mut self) {
846        let mut guard = HandlerGuardState::None;
847        std::mem::swap(&mut guard, &mut self.handler_guard);
848        match guard {
849            HandlerGuardState::ExternalHandler(mut guard) => {
850                guard.unregister(&mut self.stream).ok();
851            }
852            HandlerGuardState::WakerMap(mut guard, _) => {
853                guard.unregister(&mut self.stream).ok();
854            }
855            HandlerGuardState::None => {}
856        }
857    }
858
859    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
860        if !self.buffer.is_empty() {
861            return Poll::Ready(Ok(self.buffer.len()));
862        }
863
864        let (state, selector, stream, buffer) = self.split_borrow();
865        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
866        map.pop(InterestType::Readable);
867        map.add(InterestType::Readable, cx.waker());
868
869        buffer.reserve(buffer.len() + 10240);
870        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
871        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
872
873        match stream.read(uninit_unsafe) {
874            Ok(0) => Poll::Ready(Ok(0)),
875            Ok(amt) => {
876                unsafe {
877                    buffer.set_len(buffer.len() + amt);
878                }
879                Poll::Ready(Ok(amt))
880            }
881            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
882            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
883            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
884            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
885        }
886    }
887
888    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
889        let (state, selector, stream, _) = self.split_borrow();
890        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
891        #[cfg(not(target_os = "windows"))]
892        map.pop(InterestType::Writable);
893        map.add(InterestType::Writable, cx.waker());
894        map.add(InterestType::Closed, cx.waker());
895        if map.has_interest(InterestType::Closed) {
896            return Poll::Ready(Ok(0));
897        }
898
899        #[cfg(not(target_os = "windows"))]
900        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
901            Some(val) if (val & libc::POLLHUP) != 0 => {
902                return Poll::Ready(Ok(0));
903            }
904            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
905            _ => {}
906        }
907
908        // In windows we can not poll the socket as it is not supported and hence
909        // what we do is immediately set the writable flag and relay on `mio` to
910        // refresh that flag when the state changes. In Linux what we do is actually
911        // make a non-blocking `poll` call to determine this state
912        #[cfg(target_os = "windows")]
913        if map.has_interest(InterestType::Writable) {
914            return Poll::Ready(Ok(10240));
915        }
916
917        Poll::Pending
918    }
919}
920
921#[cfg(not(target_os = "windows"))]
922fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
923    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
924        fd,
925        events,
926        revents: 0,
927    }];
928    let fds_mut = &mut fds[..];
929    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
930    match ret == 1 {
931        true => Some(fds[0].revents),
932        false => None,
933    }
934}
935
/// A UDP socket on the host, wrapping a `mio::net::UdpSocket`.
#[derive(Debug)]
pub struct LocalUdpSocket {
    /// The underlying host socket.
    socket: mio::net::UdpSocket,
    /// Address stored alongside the socket; presumably the bound address — TODO confirm.
    /// It is not read by the code in this section (hence `dead_code` is allowed).
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Selector the socket is registered with for readiness events.
    selector: Arc<Selector>,
    /// Current readiness registration: none, an external handler, or the internal waker map.
    handler_guard: HandlerGuardState,
    /// Datagrams (with their sender address) already received from the socket by
    /// `poll_read_ready` while probing for readability.
    backlog: VecDeque<(BytesMut, SocketAddr)>,
    /// Optional firewall rules; `try_send_to` checks outbound destinations against them.
    ruleset: Option<Ruleset>,
}
946
947impl LocalUdpSocket {
948    fn with_sock_ref<F, R>(&self, f: F) -> R
949    where
950        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
951    {
952        #[cfg(not(windows))]
953        let r = socket2::SockRef::from(&self.socket);
954
955        #[cfg(windows)]
956        let b = unsafe {
957            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
958        };
959        #[cfg(windows)]
960        let r = socket2::SockRef::from(&b);
961
962        f(r)
963    }
964}
965
966impl VirtualUdpSocket for LocalUdpSocket {
967    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
968        self.socket
969            .set_broadcast(broadcast)
970            .map_err(io_err_into_net_error)
971    }
972
973    fn broadcast(&self) -> Result<bool> {
974        self.socket.broadcast().map_err(io_err_into_net_error)
975    }
976
977    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
978        self.socket
979            .set_multicast_loop_v4(val)
980            .map_err(io_err_into_net_error)
981    }
982
983    fn multicast_loop_v4(&self) -> Result<bool> {
984        self.socket
985            .multicast_loop_v4()
986            .map_err(io_err_into_net_error)
987    }
988
989    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
990        self.socket
991            .set_multicast_loop_v6(val)
992            .map_err(io_err_into_net_error)
993    }
994
995    fn multicast_loop_v6(&self) -> Result<bool> {
996        self.socket
997            .multicast_loop_v6()
998            .map_err(io_err_into_net_error)
999    }
1000
1001    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1002        self.socket
1003            .set_multicast_ttl_v4(ttl)
1004            .map_err(io_err_into_net_error)
1005    }
1006
1007    fn multicast_ttl_v4(&self) -> Result<u32> {
1008        self.socket
1009            .multicast_ttl_v4()
1010            .map_err(io_err_into_net_error)
1011    }
1012
1013    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1014        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
1015            .map_err(io_err_into_net_error)
1016    }
1017
1018    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1019        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
1020            .map_err(io_err_into_net_error)
1021    }
1022
1023    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1024        self.socket
1025            .join_multicast_v6(&multiaddr, iface)
1026            .map_err(io_err_into_net_error)
1027    }
1028
1029    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1030        self.socket
1031            .leave_multicast_v6(&multiaddr, iface)
1032            .map_err(io_err_into_net_error)
1033    }
1034
1035    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1036        self.socket
1037            .peer_addr()
1038            .map(Some)
1039            .map_err(io_err_into_net_error)
1040    }
1041}
1042
1043impl VirtualConnectionlessSocket for LocalUdpSocket {
1044    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1045        if let Some(ruleset) = self.ruleset.as_ref()
1046            && !ruleset.allows_socket(addr, Direction::Outbound)
1047        {
1048            tracing::warn!(%addr, "try_send blocked by firewall rule");
1049            return Err(NetworkError::PermissionDenied);
1050        }
1051
1052        let ret = self
1053            .socket
1054            .send_to(data, addr)
1055            .map_err(io_err_into_net_error);
1056        match &ret {
1057            Ok(0) | Err(NetworkError::WouldBlock) => {
1058                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
1059                    map.pop(InterestType::Writable);
1060                }
1061            }
1062            _ => {}
1063        }
1064        ret
1065    }
1066
1067    fn try_recv_from(
1068        &mut self,
1069        buf: &mut [MaybeUninit<u8>],
1070        peek: bool,
1071    ) -> Result<(usize, SocketAddr)> {
1072        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1073        if peek {
1074            self.socket.peek_from(buf)
1075        } else {
1076            self.socket.recv_from(buf)
1077        }
1078        .map_err(io_err_into_net_error)
1079    }
1080}
1081
1082impl VirtualSocket for LocalUdpSocket {
1083    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1084        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
1085    }
1086
1087    fn ttl(&self) -> Result<u32> {
1088        self.socket.ttl().map_err(io_err_into_net_error)
1089    }
1090
1091    fn addr_local(&self) -> Result<SocketAddr> {
1092        self.socket.local_addr().map_err(io_err_into_net_error)
1093    }
1094
1095    fn status(&self) -> Result<SocketStatus> {
1096        Ok(SocketStatus::Opened)
1097    }
1098
1099    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
1100        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
1101            match guard.replace_handler(handler) {
1102                Ok(()) => {
1103                    return Ok(());
1104                }
1105                Err(h) => handler = h,
1106            }
1107
1108            // the handler could not be replaced so we need to build a new handler instead
1109            if let Err(err) = guard.unregister(&mut self.socket) {
1110                tracing::debug!("failed to unregister previous token - {}", err);
1111            }
1112        }
1113
1114        let guard = InterestGuard::new(
1115            &self.selector,
1116            handler,
1117            &mut self.socket,
1118            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
1119        )
1120        .map_err(io_err_into_net_error)?;
1121
1122        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
1123
1124        Ok(())
1125    }
1126}
1127
1128impl LocalUdpSocket {
1129    fn split_borrow(
1130        &mut self,
1131    ) -> (
1132        &mut HandlerGuardState,
1133        &Arc<Selector>,
1134        &mut mio::net::UdpSocket,
1135    ) {
1136        (&mut self.handler_guard, &self.selector, &mut self.socket)
1137    }
1138}
1139
1140impl VirtualIoSource for LocalUdpSocket {
1141    fn remove_handler(&mut self) {
1142        let mut guard = HandlerGuardState::None;
1143        std::mem::swap(&mut guard, &mut self.handler_guard);
1144        match guard {
1145            HandlerGuardState::ExternalHandler(mut guard) => {
1146                guard.unregister(&mut self.socket).ok();
1147            }
1148            HandlerGuardState::WakerMap(mut guard, _) => {
1149                guard.unregister(&mut self.socket).ok();
1150            }
1151            HandlerGuardState::None => {}
1152        }
1153    }
1154
1155    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1156        if !self.backlog.is_empty() {
1157            let total = self.backlog.iter().map(|a| a.0.len()).sum();
1158            return Poll::Ready(Ok(total));
1159        }
1160
1161        let (state, selector, socket) = self.split_borrow();
1162        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1163        map.pop(InterestType::Readable);
1164        map.add(InterestType::Readable, cx.waker());
1165
1166        let mut buffer = BytesMut::default();
1167        buffer.reserve(10240);
1168        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
1169        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
1170
1171        match self.socket.recv_from(uninit_unsafe) {
1172            Ok((0, _)) => Poll::Ready(Ok(0)),
1173            Ok((amt, peer)) => {
1174                unsafe {
1175                    buffer.set_len(amt);
1176                }
1177                self.backlog.push_back((buffer, peer));
1178                Poll::Ready(Ok(amt))
1179            }
1180            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1181            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1182            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1183            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1184        }
1185    }
1186
1187    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1188        let (state, selector, socket) = self.split_borrow();
1189        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1190        #[cfg(not(target_os = "windows"))]
1191        map.pop(InterestType::Writable);
1192        map.add(InterestType::Writable, cx.waker());
1193
1194        #[cfg(not(target_os = "windows"))]
1195        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1196            Some(val) if (val & libc::POLLHUP) != 0 => {
1197                return Poll::Ready(Ok(0));
1198            }
1199            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1200            _ => {}
1201        }
1202
1203        // In windows we can not poll the socket as it is not supported and hence
1204        // what we do is immediately set the writable flag and relay on `mio` to
1205        // refresh that flag when the state changes. In Linux what we do is actually
1206        // make a non-blocking `poll` call to determine this state
1207        #[cfg(target_os = "windows")]
1208        if map.has_interest(InterestType::Writable) {
1209            return Poll::Ready(Ok(10240));
1210        }
1211
1212        Poll::Pending
1213    }
1214}