virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6    VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7    VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34    selector: Arc<Selector>,
35    handle: Handle,
36    ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40    pub fn new() -> Self {
41        Self {
42            selector: Selector::new(),
43            handle: Handle::current(),
44            ruleset: None,
45        }
46    }
47
48    pub fn with_ruleset(ruleset: Ruleset) -> Self {
49        Self {
50            selector: Selector::new(),
51            handle: Handle::current(),
52            ruleset: Some(ruleset),
53        }
54    }
55}
56
57impl Drop for LocalNetworking {
58    fn drop(&mut self) {
59        self.selector.shutdown();
60    }
61}
62
63impl Default for LocalNetworking {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72    async fn listen_tcp(
73        &self,
74        addr: SocketAddr,
75        only_v6: bool,
76        reuse_port: bool,
77        reuse_addr: bool,
78    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79        if let Some(ruleset) = self.ruleset.as_ref()
80            && !ruleset.allows_socket(addr, Direction::Inbound)
81        {
82            tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83            return Err(NetworkError::PermissionDenied);
84        }
85
86        let listener = std::net::TcpListener::bind(addr)
87            .map(|sock| {
88                sock.set_nonblocking(true).ok();
89                Box::new(LocalTcpListener {
90                    stream: mio::net::TcpListener::from_std(sock),
91                    selector: self.selector.clone(),
92                    handler_guard: HandlerGuardState::None,
93                    no_delay: None,
94                    keep_alive: None,
95                    backlog: Default::default(),
96                    ruleset: self.ruleset.clone(),
97                })
98            })
99            .map_err(io_err_into_net_error)?;
100        Ok(listener)
101    }
102
103    async fn bind_udp(
104        &self,
105        addr: SocketAddr,
106        reuse_port: bool,
107        reuse_addr: bool,
108    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109        #[cfg(not(windows))]
110        use socket2::{Domain, Socket, Type};
111
112        if let Some(ruleset) = self.ruleset.as_ref()
113            && !ruleset.allows_socket(addr, Direction::Inbound)
114        {
115            tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116            return Err(NetworkError::PermissionDenied);
117        }
118
119        #[cfg(not(windows))]
120        let socket = {
121            let domain = if addr.is_ipv4() {
122                Domain::IPV4
123            } else {
124                Domain::IPV6
125            };
126            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127            std_sock
128                .set_nonblocking(true)
129                .map_err(io_err_into_net_error)?;
130            std_sock
131                .set_reuse_address(reuse_addr)
132                .map_err(io_err_into_net_error)?;
133            std_sock
134                .set_reuse_port(reuse_port)
135                .map_err(io_err_into_net_error)?;
136            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137            mio::net::UdpSocket::from_std(std_sock.into())
138        };
139        #[cfg(windows)]
140        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142        #[allow(unused_mut)]
143        let mut ret = LocalUdpSocket {
144            selector: self.selector.clone(),
145            socket,
146            addr,
147            handler_guard: HandlerGuardState::None,
148            backlog: Default::default(),
149            ruleset: self.ruleset.clone(),
150        };
151
152        // In windows we can not poll the socket as it is not supported and hence
153        // what we do is immediately set the writable flag and relay on `mio` to
154        // refresh that flag when the state changes. In Linux what we do is actually
155        // make a non-blocking `poll` call to determine this state
156        #[cfg(target_os = "windows")]
157        {
158            let (state, selector, socket) = ret.split_borrow();
159            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160            map.push(InterestType::Writable);
161        }
162
163        Ok(Box::new(ret))
164    }
165
166    async fn connect_tcp(
167        &self,
168        _addr: SocketAddr,
169        mut peer: SocketAddr,
170    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171        if let Some(ruleset) = self.ruleset.as_ref()
172            && !ruleset.allows_socket(peer, Direction::Outbound)
173        {
174            tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175            return Err(NetworkError::PermissionDenied);
176        }
177
178        let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
179
180        if let Ok(p) = stream.peer_addr() {
181            peer = p;
182        }
183        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
184        Ok(socket)
185    }
186
187    async fn resolve(
188        &self,
189        host: &str,
190        port: Option<u16>,
191        dns_server: Option<IpAddr>,
192    ) -> Result<Vec<IpAddr>> {
193        if let Some(ruleset) = self.ruleset.as_ref()
194            && !ruleset.allows_domain(host)
195        {
196            tracing::warn!(%host, "dns resolve blocked by firewall rule");
197            return Err(NetworkError::PermissionDenied);
198        }
199
200        let host_to_lookup = if host.contains(':') {
201            host.to_string()
202        } else {
203            format!("{}:{}", host, port.unwrap_or(0))
204        };
205        let addrs = self
206            .handle
207            .spawn(tokio::net::lookup_host(host_to_lookup))
208            .await
209            .map_err(|_| NetworkError::IOError)?
210            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
211            .map_err(io_err_into_net_error)?;
212
213        if let Some(ruleset) = self.ruleset.as_ref() {
214            if let Err(e) = ruleset.expand_domain(host, &addrs) {
215                tracing::debug!(err=%e, "ruleset expansion failed");
216            } else {
217                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
218            }
219        }
220
221        Ok(addrs)
222    }
223}
224
225#[derive(Debug)]
226pub struct LocalTcpListener {
227    stream: mio::net::TcpListener,
228    selector: Arc<Selector>,
229    handler_guard: HandlerGuardState,
230    no_delay: Option<bool>,
231    keep_alive: Option<bool>,
232    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
233    ruleset: Option<Ruleset>,
234}
235
236impl LocalTcpListener {
237    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
238        match self.stream.accept().map_err(io_err_into_net_error) {
239            Ok((stream, addr)) => {
240                if let Some(ruleset) = self.ruleset.as_ref()
241                    && !ruleset.allows_socket(addr, Direction::Outbound)
242                {
243                    tracing::warn!(%addr, "try_accept blocked by firewall rule");
244                    return Err(NetworkError::PermissionDenied);
245                }
246
247                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
248                if let Some(no_delay) = self.no_delay {
249                    socket.set_nodelay(no_delay).ok();
250                }
251                if let Some(keep_alive) = self.keep_alive {
252                    socket.set_keepalive(keep_alive).ok();
253                }
254                Ok((Box::new(socket), addr))
255            }
256            Err(NetworkError::WouldBlock) => {
257                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
258                    map.pop(InterestType::Readable);
259                    map.pop(InterestType::Writable);
260                }
261                Err(NetworkError::WouldBlock)
262            }
263            Err(err) => Err(err),
264        }
265    }
266}
267
268impl VirtualTcpListener for LocalTcpListener {
269    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
270        if let Some(child) = self.backlog.pop_front() {
271            return Ok(child);
272        }
273        self.try_accept_internal()
274    }
275
276    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
277        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
278            match guard.replace_handler(handler) {
279                Ok(()) => return Ok(()),
280                Err(h) => handler = h,
281            }
282
283            // the handler could not be replaced so we need to build a new handler instead
284            if let Err(err) = guard.unregister(&mut self.stream) {
285                tracing::debug!("failed to unregister previous token - {}", err);
286            }
287        }
288
289        let guard = InterestGuard::new(
290            &self.selector,
291            handler,
292            &mut self.stream,
293            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
294        )
295        .map_err(io_err_into_net_error)?;
296
297        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
298
299        Ok(())
300    }
301
302    fn addr_local(&self) -> Result<SocketAddr> {
303        self.stream.local_addr().map_err(io_err_into_net_error)
304    }
305
306    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
307        self.stream
308            .set_ttl(ttl as u32)
309            .map_err(io_err_into_net_error)
310    }
311
312    fn ttl(&self) -> Result<u8> {
313        self.stream
314            .ttl()
315            .map(|ttl| ttl as u8)
316            .map_err(io_err_into_net_error)
317    }
318}
319
320impl LocalTcpListener {
321    fn split_borrow(
322        &mut self,
323    ) -> (
324        &mut HandlerGuardState,
325        &Arc<Selector>,
326        &mut mio::net::TcpListener,
327    ) {
328        (&mut self.handler_guard, &self.selector, &mut self.stream)
329    }
330}
331
332impl VirtualIoSource for LocalTcpListener {
333    fn remove_handler(&mut self) {
334        let mut guard = HandlerGuardState::None;
335        std::mem::swap(&mut guard, &mut self.handler_guard);
336        match guard {
337            HandlerGuardState::ExternalHandler(mut guard) => {
338                guard.unregister(&mut self.stream).ok();
339            }
340            HandlerGuardState::WakerMap(mut guard, _) => {
341                guard.unregister(&mut self.stream).ok();
342            }
343            HandlerGuardState::None => {}
344        }
345    }
346
347    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
348        if !self.backlog.is_empty() {
349            return Poll::Ready(Ok(self.backlog.len()));
350        }
351
352        let (state, selector, source) = self.split_borrow();
353        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
354        map.add(InterestType::Readable, cx.waker());
355
356        if let Ok(child) = self.try_accept_internal() {
357            self.backlog.push_back(child);
358            return Poll::Ready(Ok(1));
359        }
360        Poll::Pending
361    }
362
363    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
364        if !self.backlog.is_empty() {
365            return Poll::Ready(Ok(self.backlog.len()));
366        }
367
368        let (state, selector, source) = self.split_borrow();
369        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
370        map.add(InterestType::Writable, cx.waker());
371
372        if let Ok(child) = self.try_accept_internal() {
373            self.backlog.push_back(child);
374            return Poll::Ready(Ok(1));
375        }
376        Poll::Pending
377    }
378}
379
380#[derive(Debug)]
381pub struct LocalTcpStream {
382    stream: mio::net::TcpStream,
383    addr: SocketAddr,
384    shutdown: Option<Shutdown>,
385    selector: Arc<Selector>,
386    handler_guard: HandlerGuardState,
387    buffer: BytesMut,
388}
389
390impl LocalTcpStream {
391    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
392        #[allow(unused_mut)]
393        let mut ret = Self {
394            stream,
395            addr,
396            shutdown: None,
397            selector,
398            handler_guard: HandlerGuardState::None,
399            buffer: BytesMut::new(),
400        };
401
402        // In windows we can not poll the socket as it is not supported and hence
403        // what we do is immediately set the writable flag and relay on `mio` to
404        // refresh that flag when the state changes. In Linux what we do is actually
405        // make a non-blocking `poll` call to determine this state
406        #[cfg(target_os = "windows")]
407        {
408            let (state, selector, socket, _) = ret.split_borrow();
409            if let Ok(map) = state_as_waker_map(state, selector, socket) {
410                map.push(InterestType::Writable);
411            }
412        }
413
414        ret
415    }
416
417    fn with_sock_ref<F, R>(&self, f: F) -> R
418    where
419        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
420    {
421        #[cfg(not(windows))]
422        let r = socket2::SockRef::from(&self.stream);
423
424        #[cfg(windows)]
425        let b = unsafe {
426            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
427        };
428        #[cfg(windows)]
429        let r = socket2::SockRef::from(&b);
430
431        f(r)
432    }
433}
434
435impl VirtualTcpSocket for LocalTcpStream {
436    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
437        Ok(())
438    }
439
440    fn recv_buf_size(&self) -> Result<usize> {
441        Err(NetworkError::Unsupported)
442    }
443
444    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
445        Ok(())
446    }
447
448    fn send_buf_size(&self) -> Result<usize> {
449        Err(NetworkError::Unsupported)
450    }
451
452    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
453        self.stream
454            .set_nodelay(nodelay)
455            .map_err(io_err_into_net_error)
456    }
457
458    fn nodelay(&self) -> Result<bool> {
459        self.stream.nodelay().map_err(io_err_into_net_error)
460    }
461
462    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
463        self.with_sock_ref(|s| s.set_keepalive(true))
464            .map_err(io_err_into_net_error)?;
465        Ok(())
466    }
467
468    fn keepalive(&self) -> Result<bool> {
469        let ret = self
470            .with_sock_ref(|s| s.keepalive())
471            .map_err(io_err_into_net_error)?;
472        Ok(ret)
473    }
474
475    #[cfg(not(target_os = "windows"))]
476    fn set_dontroute(&mut self, val: bool) -> Result<()> {
477        // TODO:
478        // Don't route is being set by WASIX which breaks networking
479        // Why this is being set is unknown but we need to disable
480        // the functionality for now as it breaks everything
481
482        let val = val as libc::c_int;
483        let payload = &val as *const libc::c_int as *const libc::c_void;
484        let err = unsafe {
485            libc::setsockopt(
486                self.stream.as_raw_fd(),
487                libc::SOL_SOCKET,
488                libc::SO_DONTROUTE,
489                payload,
490                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
491            )
492        };
493        if err == -1 {
494            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
495        }
496        Ok(())
497    }
498    #[cfg(target_os = "windows")]
499    fn set_dontroute(&mut self, val: bool) -> Result<()> {
500        Err(NetworkError::Unsupported)
501    }
502
503    #[cfg(not(target_os = "windows"))]
504    fn dontroute(&self) -> Result<bool> {
505        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
506        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
507        let err = unsafe {
508            libc::getsockopt(
509                self.stream.as_raw_fd(),
510                libc::SOL_SOCKET,
511                libc::SO_DONTROUTE,
512                payload.as_mut_ptr().cast(),
513                &mut len,
514            )
515        };
516        if err == -1 {
517            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
518        }
519        Ok(unsafe { payload.assume_init() != 0 })
520    }
521    #[cfg(target_os = "windows")]
522    fn dontroute(&self) -> Result<bool> {
523        Err(NetworkError::Unsupported)
524    }
525
526    fn addr_peer(&self) -> Result<SocketAddr> {
527        Ok(self.addr)
528    }
529
530    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
531        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
532        self.shutdown = Some(how);
533        Ok(())
534    }
535
536    fn is_closed(&self) -> bool {
537        false
538    }
539}
540
541impl VirtualConnectedSocket for LocalTcpStream {
542    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
543        self.with_sock_ref(|s| s.set_linger(linger))
544            .map_err(io_err_into_net_error)?;
545        Ok(())
546    }
547
548    fn linger(&self) -> Result<Option<Duration>> {
549        self.with_sock_ref(|s| s.linger())
550            .map_err(io_err_into_net_error)
551    }
552
553    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
554        let ret = self.stream.write(data).map_err(io_err_into_net_error);
555        match &ret {
556            Ok(0) | Err(NetworkError::WouldBlock) => {
557                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
558                    map.pop(InterestType::Writable);
559                }
560            }
561            _ => {}
562        }
563        ret
564    }
565
566    fn try_flush(&mut self) -> Result<()> {
567        self.stream.flush().map_err(io_err_into_net_error)
568    }
569
570    fn close(&mut self) -> Result<()> {
571        Ok(())
572    }
573
574    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
575        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
576        if !self.buffer.is_empty() {
577            let amt = buf.len().min(self.buffer.len());
578            buf[..amt].copy_from_slice(&self.buffer[..amt]);
579            if !peek {
580                self.buffer.advance(amt);
581            }
582            return Ok(amt);
583        }
584
585        if peek {
586            self.stream.peek(buf)
587        } else {
588            self.stream.read(buf)
589        }
590        .map_err(io_err_into_net_error)
591    }
592}
593
594impl VirtualSocket for LocalTcpStream {
595    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
596        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
597    }
598
599    fn ttl(&self) -> Result<u32> {
600        self.stream.ttl().map_err(io_err_into_net_error)
601    }
602
603    fn addr_local(&self) -> Result<SocketAddr> {
604        self.stream.local_addr().map_err(io_err_into_net_error)
605    }
606
607    fn status(&self) -> Result<SocketStatus> {
608        Ok(SocketStatus::Opened)
609    }
610
611    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
612        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
613            match guard.replace_handler(handler) {
614                Ok(()) => return Ok(()),
615                Err(h) => handler = h,
616            }
617
618            // the handler could not be replaced so we need to build a new handler instead
619            if let Err(err) = guard.unregister(&mut self.stream) {
620                tracing::debug!("failed to unregister previous token - {}", err);
621            }
622        }
623
624        let guard = InterestGuard::new(
625            &self.selector,
626            handler,
627            &mut self.stream,
628            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
629        )
630        .map_err(io_err_into_net_error)?;
631
632        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
633
634        Ok(())
635    }
636}
637
638impl LocalTcpStream {
639    fn split_borrow(
640        &mut self,
641    ) -> (
642        &mut HandlerGuardState,
643        &Arc<Selector>,
644        &mut mio::net::TcpStream,
645        &mut BytesMut,
646    ) {
647        (
648            &mut self.handler_guard,
649            &self.selector,
650            &mut self.stream,
651            &mut self.buffer,
652        )
653    }
654}
655
656impl VirtualIoSource for LocalTcpStream {
657    fn remove_handler(&mut self) {
658        let mut guard = HandlerGuardState::None;
659        std::mem::swap(&mut guard, &mut self.handler_guard);
660        match guard {
661            HandlerGuardState::ExternalHandler(mut guard) => {
662                guard.unregister(&mut self.stream).ok();
663            }
664            HandlerGuardState::WakerMap(mut guard, _) => {
665                guard.unregister(&mut self.stream).ok();
666            }
667            HandlerGuardState::None => {}
668        }
669    }
670
671    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
672        if !self.buffer.is_empty() {
673            return Poll::Ready(Ok(self.buffer.len()));
674        }
675
676        let (state, selector, stream, buffer) = self.split_borrow();
677        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
678        map.pop(InterestType::Readable);
679        map.add(InterestType::Readable, cx.waker());
680
681        buffer.reserve(buffer.len() + 10240);
682        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
683        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
684
685        match stream.read(uninit_unsafe) {
686            Ok(0) => Poll::Ready(Ok(0)),
687            Ok(amt) => {
688                unsafe {
689                    buffer.set_len(buffer.len() + amt);
690                }
691                Poll::Ready(Ok(amt))
692            }
693            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
694            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
695            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
696            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
697        }
698    }
699
700    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
701        let (state, selector, stream, _) = self.split_borrow();
702        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
703        #[cfg(not(target_os = "windows"))]
704        map.pop(InterestType::Writable);
705        map.add(InterestType::Writable, cx.waker());
706        map.add(InterestType::Closed, cx.waker());
707        if map.has_interest(InterestType::Closed) {
708            return Poll::Ready(Ok(0));
709        }
710
711        #[cfg(not(target_os = "windows"))]
712        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
713            Some(val) if (val & libc::POLLHUP) != 0 => {
714                return Poll::Ready(Ok(0));
715            }
716            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
717            _ => {}
718        }
719
720        // In windows we can not poll the socket as it is not supported and hence
721        // what we do is immediately set the writable flag and relay on `mio` to
722        // refresh that flag when the state changes. In Linux what we do is actually
723        // make a non-blocking `poll` call to determine this state
724        #[cfg(target_os = "windows")]
725        if map.has_interest(InterestType::Writable) {
726            return Poll::Ready(Ok(10240));
727        }
728
729        Poll::Pending
730    }
731}
732
733#[cfg(not(target_os = "windows"))]
734fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
735    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
736        fd,
737        events,
738        revents: 0,
739    }];
740    let fds_mut = &mut fds[..];
741    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
742    match ret == 1 {
743        true => Some(fds[0].revents),
744        false => None,
745    }
746}
747
748#[derive(Debug)]
749pub struct LocalUdpSocket {
750    socket: mio::net::UdpSocket,
751    #[allow(dead_code)]
752    addr: SocketAddr,
753    selector: Arc<Selector>,
754    handler_guard: HandlerGuardState,
755    backlog: VecDeque<(BytesMut, SocketAddr)>,
756    ruleset: Option<Ruleset>,
757}
758
759impl LocalUdpSocket {
760    fn with_sock_ref<F, R>(&self, f: F) -> R
761    where
762        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
763    {
764        #[cfg(not(windows))]
765        let r = socket2::SockRef::from(&self.socket);
766
767        #[cfg(windows)]
768        let b = unsafe {
769            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
770        };
771        #[cfg(windows)]
772        let r = socket2::SockRef::from(&b);
773
774        f(r)
775    }
776}
777
778impl VirtualUdpSocket for LocalUdpSocket {
779    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
780        self.socket
781            .set_broadcast(broadcast)
782            .map_err(io_err_into_net_error)
783    }
784
785    fn broadcast(&self) -> Result<bool> {
786        self.socket.broadcast().map_err(io_err_into_net_error)
787    }
788
789    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
790        self.socket
791            .set_multicast_loop_v4(val)
792            .map_err(io_err_into_net_error)
793    }
794
795    fn multicast_loop_v4(&self) -> Result<bool> {
796        self.socket
797            .multicast_loop_v4()
798            .map_err(io_err_into_net_error)
799    }
800
801    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
802        self.socket
803            .set_multicast_loop_v6(val)
804            .map_err(io_err_into_net_error)
805    }
806
807    fn multicast_loop_v6(&self) -> Result<bool> {
808        self.socket
809            .multicast_loop_v6()
810            .map_err(io_err_into_net_error)
811    }
812
813    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
814        self.socket
815            .set_multicast_ttl_v4(ttl)
816            .map_err(io_err_into_net_error)
817    }
818
819    fn multicast_ttl_v4(&self) -> Result<u32> {
820        self.socket
821            .multicast_ttl_v4()
822            .map_err(io_err_into_net_error)
823    }
824
825    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
826        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
827            .map_err(io_err_into_net_error)
828    }
829
830    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
831        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
832            .map_err(io_err_into_net_error)
833    }
834
835    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
836        self.socket
837            .join_multicast_v6(&multiaddr, iface)
838            .map_err(io_err_into_net_error)
839    }
840
841    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
842        self.socket
843            .leave_multicast_v6(&multiaddr, iface)
844            .map_err(io_err_into_net_error)
845    }
846
847    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
848        self.socket
849            .peer_addr()
850            .map(Some)
851            .map_err(io_err_into_net_error)
852    }
853}
854
855impl VirtualConnectionlessSocket for LocalUdpSocket {
856    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
857        if let Some(ruleset) = self.ruleset.as_ref()
858            && !ruleset.allows_socket(addr, Direction::Outbound)
859        {
860            tracing::warn!(%addr, "try_send blocked by firewall rule");
861            return Err(NetworkError::PermissionDenied);
862        }
863
864        let ret = self
865            .socket
866            .send_to(data, addr)
867            .map_err(io_err_into_net_error);
868        match &ret {
869            Ok(0) | Err(NetworkError::WouldBlock) => {
870                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
871                    map.pop(InterestType::Writable);
872                }
873            }
874            _ => {}
875        }
876        ret
877    }
878
879    fn try_recv_from(
880        &mut self,
881        buf: &mut [MaybeUninit<u8>],
882        peek: bool,
883    ) -> Result<(usize, SocketAddr)> {
884        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
885        if peek {
886            self.socket.peek_from(buf)
887        } else {
888            self.socket.recv_from(buf)
889        }
890        .map_err(io_err_into_net_error)
891    }
892}
893
894impl VirtualSocket for LocalUdpSocket {
895    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
896        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
897    }
898
899    fn ttl(&self) -> Result<u32> {
900        self.socket.ttl().map_err(io_err_into_net_error)
901    }
902
903    fn addr_local(&self) -> Result<SocketAddr> {
904        self.socket.local_addr().map_err(io_err_into_net_error)
905    }
906
907    fn status(&self) -> Result<SocketStatus> {
908        Ok(SocketStatus::Opened)
909    }
910
911    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
912        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
913            match guard.replace_handler(handler) {
914                Ok(()) => {
915                    return Ok(());
916                }
917                Err(h) => handler = h,
918            }
919
920            // the handler could not be replaced so we need to build a new handler instead
921            if let Err(err) = guard.unregister(&mut self.socket) {
922                tracing::debug!("failed to unregister previous token - {}", err);
923            }
924        }
925
926        let guard = InterestGuard::new(
927            &self.selector,
928            handler,
929            &mut self.socket,
930            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
931        )
932        .map_err(io_err_into_net_error)?;
933
934        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
935
936        Ok(())
937    }
938}
939
940impl LocalUdpSocket {
941    fn split_borrow(
942        &mut self,
943    ) -> (
944        &mut HandlerGuardState,
945        &Arc<Selector>,
946        &mut mio::net::UdpSocket,
947    ) {
948        (&mut self.handler_guard, &self.selector, &mut self.socket)
949    }
950}
951
952impl VirtualIoSource for LocalUdpSocket {
953    fn remove_handler(&mut self) {
954        let mut guard = HandlerGuardState::None;
955        std::mem::swap(&mut guard, &mut self.handler_guard);
956        match guard {
957            HandlerGuardState::ExternalHandler(mut guard) => {
958                guard.unregister(&mut self.socket).ok();
959            }
960            HandlerGuardState::WakerMap(mut guard, _) => {
961                guard.unregister(&mut self.socket).ok();
962            }
963            HandlerGuardState::None => {}
964        }
965    }
966
967    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
968        if !self.backlog.is_empty() {
969            let total = self.backlog.iter().map(|a| a.0.len()).sum();
970            return Poll::Ready(Ok(total));
971        }
972
973        let (state, selector, socket) = self.split_borrow();
974        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
975        map.pop(InterestType::Readable);
976        map.add(InterestType::Readable, cx.waker());
977
978        let mut buffer = BytesMut::default();
979        buffer.reserve(10240);
980        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
981        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
982
983        match self.socket.recv_from(uninit_unsafe) {
984            Ok((0, _)) => Poll::Ready(Ok(0)),
985            Ok((amt, peer)) => {
986                unsafe {
987                    buffer.set_len(amt);
988                }
989                self.backlog.push_back((buffer, peer));
990                Poll::Ready(Ok(amt))
991            }
992            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
993            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
994            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
995            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
996        }
997    }
998
999    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1000        let (state, selector, socket) = self.split_borrow();
1001        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1002        #[cfg(not(target_os = "windows"))]
1003        map.pop(InterestType::Writable);
1004        map.add(InterestType::Writable, cx.waker());
1005
1006        #[cfg(not(target_os = "windows"))]
1007        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1008            Some(val) if (val & libc::POLLHUP) != 0 => {
1009                return Poll::Ready(Ok(0));
1010            }
1011            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1012            _ => {}
1013        }
1014
1015        // In windows we can not poll the socket as it is not supported and hence
1016        // what we do is immediately set the writable flag and relay on `mio` to
1017        // refresh that flag when the state changes. In Linux what we do is actually
1018        // make a non-blocking `poll` call to determine this state
1019        #[cfg(target_os = "windows")]
1020        if map.has_interest(InterestType::Writable) {
1021            return Poll::Ready(Ok(10240));
1022        }
1023
1024        Poll::Pending
1025    }
1026}