virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6    VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7    VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::{Arc, Mutex};
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34    selector: Arc<Selector>,
35    handle: Handle,
36    ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40    pub fn new() -> Self {
41        Self {
42            selector: Selector::new(),
43            handle: Handle::current(),
44            ruleset: None,
45        }
46    }
47
48    pub fn with_ruleset(ruleset: Ruleset) -> Self {
49        Self {
50            selector: Selector::new(),
51            handle: Handle::current(),
52            ruleset: Some(ruleset),
53        }
54    }
55}
56
57impl Drop for LocalNetworking {
58    fn drop(&mut self) {
59        self.selector.shutdown();
60    }
61}
62
63impl Default for LocalNetworking {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72    async fn listen_tcp(
73        &self,
74        addr: SocketAddr,
75        only_v6: bool,
76        reuse_port: bool,
77        reuse_addr: bool,
78    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79        if let Some(ruleset) = self.ruleset.as_ref()
80            && !ruleset.allows_socket(addr, Direction::Inbound)
81        {
82            tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83            return Err(NetworkError::PermissionDenied);
84        }
85
86        let listener = std::net::TcpListener::bind(addr)
87            .map(|sock| {
88                sock.set_nonblocking(true).ok();
89                Box::new(LocalTcpListener {
90                    stream: mio::net::TcpListener::from_std(sock),
91                    selector: self.selector.clone(),
92                    handler_guard: HandlerGuardState::None,
93                    no_delay: None,
94                    keep_alive: None,
95                    backlog: Default::default(),
96                    ruleset: self.ruleset.clone(),
97                })
98            })
99            .map_err(io_err_into_net_error)?;
100        Ok(listener)
101    }
102
103    async fn bind_udp(
104        &self,
105        addr: SocketAddr,
106        reuse_port: bool,
107        reuse_addr: bool,
108    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109        #[cfg(not(windows))]
110        use socket2::{Domain, Socket, Type};
111
112        if let Some(ruleset) = self.ruleset.as_ref()
113            && !ruleset.allows_socket(addr, Direction::Inbound)
114        {
115            tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116            return Err(NetworkError::PermissionDenied);
117        }
118
119        #[cfg(not(windows))]
120        let socket = {
121            let domain = if addr.is_ipv4() {
122                Domain::IPV4
123            } else {
124                Domain::IPV6
125            };
126            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127            std_sock
128                .set_nonblocking(true)
129                .map_err(io_err_into_net_error)?;
130            std_sock
131                .set_reuse_address(reuse_addr)
132                .map_err(io_err_into_net_error)?;
133            std_sock
134                .set_reuse_port(reuse_port)
135                .map_err(io_err_into_net_error)?;
136            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137            mio::net::UdpSocket::from_std(std_sock.into())
138        };
139        #[cfg(windows)]
140        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142        #[allow(unused_mut)]
143        let mut ret = LocalUdpSocket {
144            selector: self.selector.clone(),
145            socket,
146            addr,
147            handler_guard: HandlerGuardState::None,
148            backlog: Default::default(),
149            ruleset: self.ruleset.clone(),
150        };
151
152        // In windows we can not poll the socket as it is not supported and hence
153        // what we do is immediately set the writable flag and relay on `mio` to
154        // refresh that flag when the state changes. In Linux what we do is actually
155        // make a non-blocking `poll` call to determine this state
156        #[cfg(target_os = "windows")]
157        {
158            let (state, selector, socket) = ret.split_borrow();
159            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160            map.push(InterestType::Writable);
161        }
162
163        Ok(Box::new(ret))
164    }
165
166    async fn connect_tcp(
167        &self,
168        _addr: SocketAddr,
169        mut peer: SocketAddr,
170    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171        if let Some(ruleset) = self.ruleset.as_ref()
172            && !ruleset.allows_socket(peer, Direction::Outbound)
173        {
174            tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175            return Err(NetworkError::PermissionDenied);
176        }
177
178        let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
179
180        if let Ok(p) = stream.peer_addr() {
181            peer = p;
182        }
183        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
184        Ok(socket)
185    }
186
187    async fn resolve(
188        &self,
189        host: &str,
190        port: Option<u16>,
191        dns_server: Option<IpAddr>,
192    ) -> Result<Vec<IpAddr>> {
193        if let Some(ruleset) = self.ruleset.as_ref()
194            && !ruleset.allows_domain(host)
195        {
196            tracing::warn!(%host, "dns resolve blocked by firewall rule");
197            return Err(NetworkError::PermissionDenied);
198        }
199
200        let host_to_lookup = if host.contains(':') {
201            host.to_string()
202        } else {
203            format!("{}:{}", host, port.unwrap_or(0))
204        };
205        let addrs = self
206            .handle
207            .spawn(tokio::net::lookup_host(host_to_lookup))
208            .await
209            .map_err(|_| NetworkError::IOError)?
210            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
211            .map_err(io_err_into_net_error)?;
212
213        if let Some(ruleset) = self.ruleset.as_ref() {
214            if let Err(e) = ruleset.expand_domain(host, &addrs) {
215                tracing::debug!(err=%e, "ruleset expansion failed");
216            } else {
217                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
218            }
219        }
220
221        Ok(addrs)
222    }
223}
224
225#[derive(Debug)]
226pub struct LocalTcpListener {
227    stream: mio::net::TcpListener,
228    selector: Arc<Selector>,
229    handler_guard: HandlerGuardState,
230    no_delay: Option<bool>,
231    keep_alive: Option<bool>,
232    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
233    ruleset: Option<Ruleset>,
234}
235
236impl LocalTcpListener {
237    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
238        match self.stream.accept().map_err(io_err_into_net_error) {
239            Ok((stream, addr)) => {
240                if let Some(ruleset) = self.ruleset.as_ref()
241                    && !ruleset.allows_socket(addr, Direction::Outbound)
242                {
243                    tracing::warn!(%addr, "try_accept blocked by firewall rule");
244                    return Err(NetworkError::PermissionDenied);
245                }
246
247                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
248                if let Some(no_delay) = self.no_delay {
249                    socket.set_nodelay(no_delay).ok();
250                }
251                if let Some(keep_alive) = self.keep_alive {
252                    socket.set_keepalive(keep_alive).ok();
253                }
254                Ok((Box::new(socket), addr))
255            }
256            Err(NetworkError::WouldBlock) => {
257                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
258                    map.pop(InterestType::Readable);
259                    map.pop(InterestType::Writable);
260                }
261                Err(NetworkError::WouldBlock)
262            }
263            Err(err) => Err(err),
264        }
265    }
266}
267
268impl VirtualTcpListener for LocalTcpListener {
269    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
270        if let Some(child) = self.backlog.pop_front() {
271            return Ok(child);
272        }
273        self.try_accept_internal()
274    }
275
276    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
277        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
278            match guard.replace_handler(handler) {
279                Ok(()) => return Ok(()),
280                Err(h) => handler = h,
281            }
282
283            // the handler could not be replaced so we need to build a new handler instead
284            if let Err(err) = guard.unregister(&mut self.stream) {
285                tracing::debug!("failed to unregister previous token - {}", err);
286            }
287        }
288
289        let guard = InterestGuard::new(
290            &self.selector,
291            handler,
292            &mut self.stream,
293            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
294        )
295        .map_err(io_err_into_net_error)?;
296
297        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
298
299        Ok(())
300    }
301
302    fn addr_local(&self) -> Result<SocketAddr> {
303        self.stream.local_addr().map_err(io_err_into_net_error)
304    }
305
306    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
307        self.stream
308            .set_ttl(ttl as u32)
309            .map_err(io_err_into_net_error)
310    }
311
312    fn ttl(&self) -> Result<u8> {
313        self.stream
314            .ttl()
315            .map(|ttl| ttl as u8)
316            .map_err(io_err_into_net_error)
317    }
318}
319
320impl LocalTcpListener {
321    fn split_borrow(
322        &mut self,
323    ) -> (
324        &mut HandlerGuardState,
325        &Arc<Selector>,
326        &mut mio::net::TcpListener,
327    ) {
328        (&mut self.handler_guard, &self.selector, &mut self.stream)
329    }
330}
331
332impl VirtualIoSource for LocalTcpListener {
333    fn remove_handler(&mut self) {
334        let mut guard = HandlerGuardState::None;
335        std::mem::swap(&mut guard, &mut self.handler_guard);
336        match guard {
337            HandlerGuardState::ExternalHandler(mut guard) => {
338                guard.unregister(&mut self.stream).ok();
339            }
340            HandlerGuardState::WakerMap(mut guard, _) => {
341                guard.unregister(&mut self.stream).ok();
342            }
343            HandlerGuardState::None => {}
344        }
345    }
346
347    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
348        if !self.backlog.is_empty() {
349            return Poll::Ready(Ok(self.backlog.len()));
350        }
351
352        let (state, selector, source) = self.split_borrow();
353        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
354        map.add(InterestType::Readable, cx.waker());
355
356        if let Ok(child) = self.try_accept_internal() {
357            self.backlog.push_back(child);
358            return Poll::Ready(Ok(1));
359        }
360        Poll::Pending
361    }
362
363    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
364        if !self.backlog.is_empty() {
365            return Poll::Ready(Ok(self.backlog.len()));
366        }
367
368        let (state, selector, source) = self.split_borrow();
369        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
370        map.add(InterestType::Writable, cx.waker());
371
372        if let Ok(child) = self.try_accept_internal() {
373            self.backlog.push_back(child);
374            return Poll::Ready(Ok(1));
375        }
376        Poll::Pending
377    }
378}
379
380#[derive(Debug)]
381enum ConnectState {
382    Unknown,
383    Opened,
384    Failed,
385}
386
387#[derive(Debug)]
388pub struct LocalTcpStream {
389    stream: mio::net::TcpStream,
390    addr: SocketAddr,
391    shutdown: Option<Shutdown>,
392    selector: Arc<Selector>,
393    handler_guard: HandlerGuardState,
394    buffer: BytesMut,
395    connect_state: Mutex<ConnectState>,
396}
397
398impl LocalTcpStream {
399    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
400        #[allow(unused_mut)]
401        let mut ret = Self {
402            stream,
403            addr,
404            shutdown: None,
405            selector,
406            handler_guard: HandlerGuardState::None,
407            buffer: BytesMut::new(),
408            connect_state: Mutex::new(ConnectState::Unknown),
409        };
410
411        // In windows we can not poll the socket as it is not supported and hence
412        // what we do is immediately set the writable flag and relay on `mio` to
413        // refresh that flag when the state changes. In Linux what we do is actually
414        // make a non-blocking `poll` call to determine this state
415        #[cfg(target_os = "windows")]
416        {
417            let (state, selector, socket, _) = ret.split_borrow();
418            if let Ok(map) = state_as_waker_map(state, selector, socket) {
419                map.push(InterestType::Writable);
420            }
421        }
422
423        ret
424    }
425
426    fn with_sock_ref<F, R>(&self, f: F) -> R
427    where
428        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
429    {
430        #[cfg(not(windows))]
431        let r = socket2::SockRef::from(&self.stream);
432
433        #[cfg(windows)]
434        let b = unsafe {
435            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
436        };
437        #[cfg(windows)]
438        let r = socket2::SockRef::from(&b);
439
440        f(r)
441    }
442}
443
444impl VirtualTcpSocket for LocalTcpStream {
445    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
446        Ok(())
447    }
448
449    fn recv_buf_size(&self) -> Result<usize> {
450        Err(NetworkError::Unsupported)
451    }
452
453    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
454        Ok(())
455    }
456
457    fn send_buf_size(&self) -> Result<usize> {
458        Err(NetworkError::Unsupported)
459    }
460
461    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
462        self.stream
463            .set_nodelay(nodelay)
464            .map_err(io_err_into_net_error)
465    }
466
467    fn nodelay(&self) -> Result<bool> {
468        self.stream.nodelay().map_err(io_err_into_net_error)
469    }
470
471    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
472        self.with_sock_ref(|s| s.set_keepalive(true))
473            .map_err(io_err_into_net_error)?;
474        Ok(())
475    }
476
477    fn keepalive(&self) -> Result<bool> {
478        let ret = self
479            .with_sock_ref(|s| s.keepalive())
480            .map_err(io_err_into_net_error)?;
481        Ok(ret)
482    }
483
484    #[cfg(not(target_os = "windows"))]
485    fn set_dontroute(&mut self, val: bool) -> Result<()> {
486        // TODO:
487        // Don't route is being set by WASIX which breaks networking
488        // Why this is being set is unknown but we need to disable
489        // the functionality for now as it breaks everything
490
491        let val = val as libc::c_int;
492        let payload = &val as *const libc::c_int as *const libc::c_void;
493        let err = unsafe {
494            libc::setsockopt(
495                self.stream.as_raw_fd(),
496                libc::SOL_SOCKET,
497                libc::SO_DONTROUTE,
498                payload,
499                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
500            )
501        };
502        if err == -1 {
503            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
504        }
505        Ok(())
506    }
507    #[cfg(target_os = "windows")]
508    fn set_dontroute(&mut self, val: bool) -> Result<()> {
509        Err(NetworkError::Unsupported)
510    }
511
512    #[cfg(not(target_os = "windows"))]
513    fn dontroute(&self) -> Result<bool> {
514        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
515        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
516        let err = unsafe {
517            libc::getsockopt(
518                self.stream.as_raw_fd(),
519                libc::SOL_SOCKET,
520                libc::SO_DONTROUTE,
521                payload.as_mut_ptr().cast(),
522                &mut len,
523            )
524        };
525        if err == -1 {
526            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
527        }
528        Ok(unsafe { payload.assume_init() != 0 })
529    }
530    #[cfg(target_os = "windows")]
531    fn dontroute(&self) -> Result<bool> {
532        Err(NetworkError::Unsupported)
533    }
534
535    fn addr_peer(&self) -> Result<SocketAddr> {
536        Ok(self.addr)
537    }
538
539    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
540        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
541        self.shutdown = Some(how);
542        Ok(())
543    }
544
545    fn is_closed(&self) -> bool {
546        false
547    }
548}
549
550impl VirtualConnectedSocket for LocalTcpStream {
551    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
552        self.with_sock_ref(|s| s.set_linger(linger))
553            .map_err(io_err_into_net_error)?;
554        Ok(())
555    }
556
557    fn linger(&self) -> Result<Option<Duration>> {
558        self.with_sock_ref(|s| s.linger())
559            .map_err(io_err_into_net_error)
560    }
561
562    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
563        let ret = self.stream.write(data).map_err(io_err_into_net_error);
564        match &ret {
565            Ok(0) | Err(NetworkError::WouldBlock) => {
566                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
567                    map.pop(InterestType::Writable);
568                }
569            }
570            _ => {}
571        }
572        ret
573    }
574
575    fn try_flush(&mut self) -> Result<()> {
576        self.stream.flush().map_err(io_err_into_net_error)
577    }
578
579    fn close(&mut self) -> Result<()> {
580        Ok(())
581    }
582
583    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
584        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
585        if !self.buffer.is_empty() {
586            let amt = buf.len().min(self.buffer.len());
587            buf[..amt].copy_from_slice(&self.buffer[..amt]);
588            if !peek {
589                self.buffer.advance(amt);
590            }
591            return Ok(amt);
592        }
593
594        if peek {
595            self.stream.peek(buf)
596        } else {
597            self.stream.read(buf)
598        }
599        .map_err(io_err_into_net_error)
600    }
601}
602
603impl VirtualSocket for LocalTcpStream {
604    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
605        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
606    }
607
608    fn ttl(&self) -> Result<u32> {
609        self.stream.ttl().map_err(io_err_into_net_error)
610    }
611
612    fn addr_local(&self) -> Result<SocketAddr> {
613        self.stream.local_addr().map_err(io_err_into_net_error)
614    }
615
616    fn status(&self) -> Result<SocketStatus> {
617        // `take_error()` consumes the latched socket error, so once the
618        // connect resolves we cache the terminal state to keep status() stable.
619        let mut connect_state = self.connect_state.lock().unwrap();
620        match *connect_state {
621            ConnectState::Opened => return Ok(SocketStatus::Opened),
622            ConnectState::Failed => return Ok(SocketStatus::Failed),
623            ConnectState::Unknown => {}
624        }
625
626        if self
627            .with_sock_ref(|sockref| sockref.take_error())
628            .map_err(io_err_into_net_error)?
629            .is_some()
630        {
631            *connect_state = ConnectState::Failed;
632            return Ok(SocketStatus::Failed); // connect error on the socket
633        }
634        match self.stream.peer_addr() {
635            Ok(_) => {
636                *connect_state = ConnectState::Opened;
637                Ok(SocketStatus::Opened) // TCP handshake completed.
638            }
639            Err(err) => {
640                if matches!(
641                    err.kind(),
642                    io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
643                ) {
644                    Ok(SocketStatus::Opening) // The connect is still in progress
645                } else {
646                    // TODO: Store the concrete err so we can return it later on
647                    *connect_state = ConnectState::Failed;
648                    Ok(SocketStatus::Failed) // Any other error means the socket is unusable
649                }
650            }
651        }
652    }
653
654    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
655        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
656            match guard.replace_handler(handler) {
657                Ok(()) => return Ok(()),
658                Err(h) => handler = h,
659            }
660
661            // the handler could not be replaced so we need to build a new handler instead
662            if let Err(err) = guard.unregister(&mut self.stream) {
663                tracing::debug!("failed to unregister previous token - {}", err);
664            }
665        }
666
667        let guard = InterestGuard::new(
668            &self.selector,
669            handler,
670            &mut self.stream,
671            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
672        )
673        .map_err(io_err_into_net_error)?;
674
675        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
676
677        Ok(())
678    }
679}
680
681impl LocalTcpStream {
682    fn split_borrow(
683        &mut self,
684    ) -> (
685        &mut HandlerGuardState,
686        &Arc<Selector>,
687        &mut mio::net::TcpStream,
688        &mut BytesMut,
689    ) {
690        (
691            &mut self.handler_guard,
692            &self.selector,
693            &mut self.stream,
694            &mut self.buffer,
695        )
696    }
697}
698
impl VirtualIoSource for LocalTcpStream {
    /// Detaches whatever selector registration this stream currently holds.
    ///
    /// The handler state is reset to `HandlerGuardState::None` before
    /// unregistering, and unregistration errors are ignored.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Polls for readable data.
    ///
    /// Bytes already staged in `self.buffer` are reported immediately.
    /// Otherwise the task's waker is registered for readable interest and one
    /// non-blocking read is attempted into the buffer's spare capacity. Bytes
    /// read are kept in `self.buffer` for a later `try_recv`.
    ///
    /// Returns `Ready(Ok(0))` on EOF and also when the connection was aborted
    /// or reset, `Pending` on `WouldBlock`, and `Ready(Err(..))` for any
    /// other I/O error.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        if !self.buffer.is_empty() {
            return Poll::Ready(Ok(self.buffer.len()));
        }

        let (state, selector, stream, buffer) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Drop any cached readable flag, then register this task's waker.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // `buffer` is empty here (checked above), so this guarantees at least
        // 10240 bytes of spare capacity for the read below.
        buffer.reserve(buffer.len() + 10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // NOTE(review): this exposes uninitialised memory as `&mut [u8]`; it
        // relies on `read` only writing into the slice, never reading from it.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match stream.read(uninit_unsafe) {
            // EOF.
            Ok(0) => Poll::Ready(Ok(0)),
            Ok(amt) => {
                // SAFETY: `read` reported `amt` bytes written into the spare
                // capacity directly after the current length.
                unsafe {
                    buffer.set_len(buffer.len() + amt);
                }
                Poll::Ready(Ok(amt))
            }
            // Aborted and reset connections are surfaced the same way as EOF.
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            // Nothing available yet; relies on the waker registered above.
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Polls whether the stream can accept more data.
    ///
    /// Returns `Ready(Ok(0))` when the connection appears closed or hung up,
    /// and `Ready(Ok(10240))` when writable. The 10240 is a fixed nominal
    /// amount, not the real send-buffer space. Otherwise returns `Pending`,
    /// with the task's waker registered for writable and closed interest.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, stream, _) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Off Windows the cached writable flag is discarded, so the real
        // socket state is re-checked with `libc_poll` below.
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());
        map.add(InterestType::Closed, cx.waker());
        // NOTE(review): presumably `has_interest(Closed)` becomes true once the
        // selector has recorded a close event for this socket. Confirm against
        // `virtual_mio`.
        if map.has_interest(InterestType::Closed) {
            return Poll::Ready(Ok(0));
        }

        // Zero-timeout poll: a hang-up is reported as 0 writable bytes, and
        // POLLOUT as the nominal 10240 bytes.
        #[cfg(not(target_os = "windows"))]
        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // Windows cannot poll the socket in a non-blocking way, so the
        // writable flag is set immediately and `mio` is relied on to refresh
        // it when the state changes. Other platforms instead make the
        // non-blocking `poll` call above to determine this state.
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}
775
776#[cfg(not(target_os = "windows"))]
777fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
778    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
779        fd,
780        events,
781        revents: 0,
782    }];
783    let fds_mut = &mut fds[..];
784    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
785    match ret == 1 {
786        true => Some(fds[0].revents),
787        false => None,
788    }
789}
790
791#[derive(Debug)]
792pub struct LocalUdpSocket {
793    socket: mio::net::UdpSocket,
794    #[allow(dead_code)]
795    addr: SocketAddr,
796    selector: Arc<Selector>,
797    handler_guard: HandlerGuardState,
798    backlog: VecDeque<(BytesMut, SocketAddr)>,
799    ruleset: Option<Ruleset>,
800}
801
802impl LocalUdpSocket {
803    fn with_sock_ref<F, R>(&self, f: F) -> R
804    where
805        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
806    {
807        #[cfg(not(windows))]
808        let r = socket2::SockRef::from(&self.socket);
809
810        #[cfg(windows)]
811        let b = unsafe {
812            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
813        };
814        #[cfg(windows)]
815        let r = socket2::SockRef::from(&b);
816
817        f(r)
818    }
819}
820
821impl VirtualUdpSocket for LocalUdpSocket {
822    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
823        self.socket
824            .set_broadcast(broadcast)
825            .map_err(io_err_into_net_error)
826    }
827
828    fn broadcast(&self) -> Result<bool> {
829        self.socket.broadcast().map_err(io_err_into_net_error)
830    }
831
832    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
833        self.socket
834            .set_multicast_loop_v4(val)
835            .map_err(io_err_into_net_error)
836    }
837
838    fn multicast_loop_v4(&self) -> Result<bool> {
839        self.socket
840            .multicast_loop_v4()
841            .map_err(io_err_into_net_error)
842    }
843
844    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
845        self.socket
846            .set_multicast_loop_v6(val)
847            .map_err(io_err_into_net_error)
848    }
849
850    fn multicast_loop_v6(&self) -> Result<bool> {
851        self.socket
852            .multicast_loop_v6()
853            .map_err(io_err_into_net_error)
854    }
855
856    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
857        self.socket
858            .set_multicast_ttl_v4(ttl)
859            .map_err(io_err_into_net_error)
860    }
861
862    fn multicast_ttl_v4(&self) -> Result<u32> {
863        self.socket
864            .multicast_ttl_v4()
865            .map_err(io_err_into_net_error)
866    }
867
868    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
869        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
870            .map_err(io_err_into_net_error)
871    }
872
873    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
874        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
875            .map_err(io_err_into_net_error)
876    }
877
878    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
879        self.socket
880            .join_multicast_v6(&multiaddr, iface)
881            .map_err(io_err_into_net_error)
882    }
883
884    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
885        self.socket
886            .leave_multicast_v6(&multiaddr, iface)
887            .map_err(io_err_into_net_error)
888    }
889
890    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
891        self.socket
892            .peer_addr()
893            .map(Some)
894            .map_err(io_err_into_net_error)
895    }
896}
897
898impl VirtualConnectionlessSocket for LocalUdpSocket {
899    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
900        if let Some(ruleset) = self.ruleset.as_ref()
901            && !ruleset.allows_socket(addr, Direction::Outbound)
902        {
903            tracing::warn!(%addr, "try_send blocked by firewall rule");
904            return Err(NetworkError::PermissionDenied);
905        }
906
907        let ret = self
908            .socket
909            .send_to(data, addr)
910            .map_err(io_err_into_net_error);
911        match &ret {
912            Ok(0) | Err(NetworkError::WouldBlock) => {
913                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
914                    map.pop(InterestType::Writable);
915                }
916            }
917            _ => {}
918        }
919        ret
920    }
921
922    fn try_recv_from(
923        &mut self,
924        buf: &mut [MaybeUninit<u8>],
925        peek: bool,
926    ) -> Result<(usize, SocketAddr)> {
927        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
928        if peek {
929            self.socket.peek_from(buf)
930        } else {
931            self.socket.recv_from(buf)
932        }
933        .map_err(io_err_into_net_error)
934    }
935}
936
937impl VirtualSocket for LocalUdpSocket {
938    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
939        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
940    }
941
942    fn ttl(&self) -> Result<u32> {
943        self.socket.ttl().map_err(io_err_into_net_error)
944    }
945
946    fn addr_local(&self) -> Result<SocketAddr> {
947        self.socket.local_addr().map_err(io_err_into_net_error)
948    }
949
950    fn status(&self) -> Result<SocketStatus> {
951        Ok(SocketStatus::Opened)
952    }
953
954    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
955        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
956            match guard.replace_handler(handler) {
957                Ok(()) => {
958                    return Ok(());
959                }
960                Err(h) => handler = h,
961            }
962
963            // the handler could not be replaced so we need to build a new handler instead
964            if let Err(err) = guard.unregister(&mut self.socket) {
965                tracing::debug!("failed to unregister previous token - {}", err);
966            }
967        }
968
969        let guard = InterestGuard::new(
970            &self.selector,
971            handler,
972            &mut self.socket,
973            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
974        )
975        .map_err(io_err_into_net_error)?;
976
977        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
978
979        Ok(())
980    }
981}
982
983impl LocalUdpSocket {
984    fn split_borrow(
985        &mut self,
986    ) -> (
987        &mut HandlerGuardState,
988        &Arc<Selector>,
989        &mut mio::net::UdpSocket,
990    ) {
991        (&mut self.handler_guard, &self.selector, &mut self.socket)
992    }
993}
994
995impl VirtualIoSource for LocalUdpSocket {
996    fn remove_handler(&mut self) {
997        let mut guard = HandlerGuardState::None;
998        std::mem::swap(&mut guard, &mut self.handler_guard);
999        match guard {
1000            HandlerGuardState::ExternalHandler(mut guard) => {
1001                guard.unregister(&mut self.socket).ok();
1002            }
1003            HandlerGuardState::WakerMap(mut guard, _) => {
1004                guard.unregister(&mut self.socket).ok();
1005            }
1006            HandlerGuardState::None => {}
1007        }
1008    }
1009
1010    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1011        if !self.backlog.is_empty() {
1012            let total = self.backlog.iter().map(|a| a.0.len()).sum();
1013            return Poll::Ready(Ok(total));
1014        }
1015
1016        let (state, selector, socket) = self.split_borrow();
1017        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1018        map.pop(InterestType::Readable);
1019        map.add(InterestType::Readable, cx.waker());
1020
1021        let mut buffer = BytesMut::default();
1022        buffer.reserve(10240);
1023        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
1024        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
1025
1026        match self.socket.recv_from(uninit_unsafe) {
1027            Ok((0, _)) => Poll::Ready(Ok(0)),
1028            Ok((amt, peer)) => {
1029                unsafe {
1030                    buffer.set_len(amt);
1031                }
1032                self.backlog.push_back((buffer, peer));
1033                Poll::Ready(Ok(amt))
1034            }
1035            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1036            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1037            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1038            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1039        }
1040    }
1041
1042    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1043        let (state, selector, socket) = self.split_borrow();
1044        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1045        #[cfg(not(target_os = "windows"))]
1046        map.pop(InterestType::Writable);
1047        map.add(InterestType::Writable, cx.waker());
1048
1049        #[cfg(not(target_os = "windows"))]
1050        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1051            Some(val) if (val & libc::POLLHUP) != 0 => {
1052                return Poll::Ready(Ok(0));
1053            }
1054            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1055            _ => {}
1056        }
1057
1058        // In windows we can not poll the socket as it is not supported and hence
1059        // what we do is immediately set the writable flag and relay on `mio` to
1060        // refresh that flag when the state changes. In Linux what we do is actually
1061        // make a non-blocking `poll` call to determine this state
1062        #[cfg(target_os = "windows")]
1063        if map.has_interest(InterestType::Writable) {
1064            return Poll::Ready(Ok(10240));
1065        }
1066
1067        Poll::Pending
1068    }
1069}