virtual_net/
host.rs

1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5    IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6    VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7    VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29    HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34    selector: Arc<Selector>,
35    handle: Handle,
36    ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40    pub fn new() -> Self {
41        Self {
42            selector: Selector::new(),
43            handle: Handle::current(),
44            ruleset: None,
45        }
46    }
47
48    pub fn with_ruleset(ruleset: Ruleset) -> Self {
49        Self {
50            selector: Selector::new(),
51            handle: Handle::current(),
52            ruleset: Some(ruleset),
53        }
54    }
55}
56
57impl Drop for LocalNetworking {
58    fn drop(&mut self) {
59        self.selector.shutdown();
60    }
61}
62
63impl Default for LocalNetworking {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72    async fn listen_tcp(
73        &self,
74        addr: SocketAddr,
75        only_v6: bool,
76        reuse_port: bool,
77        reuse_addr: bool,
78    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79        if let Some(ruleset) = self.ruleset.as_ref() {
80            if !ruleset.allows_socket(addr, Direction::Inbound) {
81                tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
82                return Err(NetworkError::PermissionDenied);
83            }
84        }
85
86        let listener = std::net::TcpListener::bind(addr)
87            .map(|sock| {
88                sock.set_nonblocking(true).ok();
89                Box::new(LocalTcpListener {
90                    stream: mio::net::TcpListener::from_std(sock),
91                    selector: self.selector.clone(),
92                    handler_guard: HandlerGuardState::None,
93                    no_delay: None,
94                    keep_alive: None,
95                    backlog: Default::default(),
96                    ruleset: self.ruleset.clone(),
97                })
98            })
99            .map_err(io_err_into_net_error)?;
100        Ok(listener)
101    }
102
103    async fn bind_udp(
104        &self,
105        addr: SocketAddr,
106        reuse_port: bool,
107        reuse_addr: bool,
108    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109        use socket2::{Domain, Socket, Type};
110
111        if let Some(ruleset) = self.ruleset.as_ref() {
112            if !ruleset.allows_socket(addr, Direction::Inbound) {
113                tracing::warn!(%addr, "bind_udp blocked by firewall rule");
114                return Err(NetworkError::PermissionDenied);
115            }
116        }
117
118        #[cfg(not(windows))]
119        let socket = {
120            let domain = if addr.is_ipv4() {
121                Domain::IPV4
122            } else {
123                Domain::IPV6
124            };
125            let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
126            std_sock
127                .set_nonblocking(true)
128                .map_err(io_err_into_net_error)?;
129            std_sock
130                .set_reuse_address(reuse_addr)
131                .map_err(io_err_into_net_error)?;
132            std_sock
133                .set_reuse_port(reuse_port)
134                .map_err(io_err_into_net_error)?;
135            std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
136            mio::net::UdpSocket::from_std(std_sock.into())
137        };
138        #[cfg(windows)]
139        let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
140
141        #[allow(unused_mut)]
142        let mut ret = LocalUdpSocket {
143            selector: self.selector.clone(),
144            socket,
145            addr,
146            handler_guard: HandlerGuardState::None,
147            backlog: Default::default(),
148            ruleset: self.ruleset.clone(),
149        };
150
151        // In windows we can not poll the socket as it is not supported and hence
152        // what we do is immediately set the writable flag and relay on `mio` to
153        // refresh that flag when the state changes. In Linux what we do is actually
154        // make a non-blocking `poll` call to determine this state
155        #[cfg(target_os = "windows")]
156        {
157            let (state, selector, socket) = ret.split_borrow();
158            let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
159            map.push(InterestType::Writable);
160        }
161
162        Ok(Box::new(ret))
163    }
164
165    async fn connect_tcp(
166        &self,
167        _addr: SocketAddr,
168        mut peer: SocketAddr,
169    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
170        if let Some(ruleset) = self.ruleset.as_ref() {
171            if !ruleset.allows_socket(peer, Direction::Outbound) {
172                tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
173                return Err(NetworkError::PermissionDenied);
174            }
175        }
176
177        let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
178
179        if let Ok(p) = stream.peer_addr() {
180            peer = p;
181        }
182        let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
183        Ok(socket)
184    }
185
186    async fn resolve(
187        &self,
188        host: &str,
189        port: Option<u16>,
190        dns_server: Option<IpAddr>,
191    ) -> Result<Vec<IpAddr>> {
192        if let Some(ruleset) = self.ruleset.as_ref() {
193            if !ruleset.allows_domain(host) {
194                tracing::warn!(%host, "dns resolve blocked by firewall rule");
195                return Err(NetworkError::PermissionDenied);
196            }
197        }
198
199        let host_to_lookup = if host.contains(':') {
200            host.to_string()
201        } else {
202            format!("{}:{}", host, port.unwrap_or(0))
203        };
204        let addrs = self
205            .handle
206            .spawn(tokio::net::lookup_host(host_to_lookup))
207            .await
208            .map_err(|_| NetworkError::IOError)?
209            .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
210            .map_err(io_err_into_net_error)?;
211
212        if let Some(ruleset) = self.ruleset.as_ref() {
213            if let Err(e) = ruleset.expand_domain(host, &addrs) {
214                tracing::debug!(err=%e, "ruleset expansion failed");
215            } else {
216                tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
217            }
218        }
219
220        Ok(addrs)
221    }
222}
223
224#[derive(Debug)]
225pub struct LocalTcpListener {
226    stream: mio::net::TcpListener,
227    selector: Arc<Selector>,
228    handler_guard: HandlerGuardState,
229    no_delay: Option<bool>,
230    keep_alive: Option<bool>,
231    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
232    ruleset: Option<Ruleset>,
233}
234
235impl LocalTcpListener {
236    fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
237        match self.stream.accept().map_err(io_err_into_net_error) {
238            Ok((stream, addr)) => {
239                if let Some(ruleset) = self.ruleset.as_ref() {
240                    if !ruleset.allows_socket(addr, Direction::Outbound) {
241                        tracing::warn!(%addr, "try_accept blocked by firewall rule");
242                        return Err(NetworkError::PermissionDenied);
243                    }
244                }
245
246                let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
247                if let Some(no_delay) = self.no_delay {
248                    socket.set_nodelay(no_delay).ok();
249                }
250                if let Some(keep_alive) = self.keep_alive {
251                    socket.set_keepalive(keep_alive).ok();
252                }
253                Ok((Box::new(socket), addr))
254            }
255            Err(NetworkError::WouldBlock) => {
256                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
257                    map.pop(InterestType::Readable);
258                    map.pop(InterestType::Writable);
259                }
260                Err(NetworkError::WouldBlock)
261            }
262            Err(err) => Err(err),
263        }
264    }
265}
266
267impl VirtualTcpListener for LocalTcpListener {
268    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
269        if let Some(child) = self.backlog.pop_front() {
270            return Ok(child);
271        }
272        self.try_accept_internal()
273    }
274
275    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
276        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
277            match guard.replace_handler(handler) {
278                Ok(()) => return Ok(()),
279                Err(h) => handler = h,
280            }
281
282            // the handler could not be replaced so we need to build a new handler instead
283            if let Err(err) = guard.unregister(&mut self.stream) {
284                tracing::debug!("failed to unregister previous token - {}", err);
285            }
286        }
287
288        let guard = InterestGuard::new(
289            &self.selector,
290            handler,
291            &mut self.stream,
292            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
293        )
294        .map_err(io_err_into_net_error)?;
295
296        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
297
298        Ok(())
299    }
300
301    fn addr_local(&self) -> Result<SocketAddr> {
302        self.stream.local_addr().map_err(io_err_into_net_error)
303    }
304
305    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
306        self.stream
307            .set_ttl(ttl as u32)
308            .map_err(io_err_into_net_error)
309    }
310
311    fn ttl(&self) -> Result<u8> {
312        self.stream
313            .ttl()
314            .map(|ttl| ttl as u8)
315            .map_err(io_err_into_net_error)
316    }
317}
318
319impl LocalTcpListener {
320    fn split_borrow(
321        &mut self,
322    ) -> (
323        &mut HandlerGuardState,
324        &Arc<Selector>,
325        &mut mio::net::TcpListener,
326    ) {
327        (&mut self.handler_guard, &self.selector, &mut self.stream)
328    }
329}
330
331impl VirtualIoSource for LocalTcpListener {
332    fn remove_handler(&mut self) {
333        let mut guard = HandlerGuardState::None;
334        std::mem::swap(&mut guard, &mut self.handler_guard);
335        match guard {
336            HandlerGuardState::ExternalHandler(mut guard) => {
337                guard.unregister(&mut self.stream).ok();
338            }
339            HandlerGuardState::WakerMap(mut guard, _) => {
340                guard.unregister(&mut self.stream).ok();
341            }
342            HandlerGuardState::None => {}
343        }
344    }
345
346    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
347        if !self.backlog.is_empty() {
348            return Poll::Ready(Ok(self.backlog.len()));
349        }
350
351        let (state, selector, source) = self.split_borrow();
352        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
353        map.add(InterestType::Readable, cx.waker());
354
355        if let Ok(child) = self.try_accept_internal() {
356            self.backlog.push_back(child);
357            return Poll::Ready(Ok(1));
358        }
359        Poll::Pending
360    }
361
362    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
363        if !self.backlog.is_empty() {
364            return Poll::Ready(Ok(self.backlog.len()));
365        }
366
367        let (state, selector, source) = self.split_borrow();
368        let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
369        map.add(InterestType::Writable, cx.waker());
370
371        if let Ok(child) = self.try_accept_internal() {
372            self.backlog.push_back(child);
373            return Poll::Ready(Ok(1));
374        }
375        Poll::Pending
376    }
377}
378
379#[derive(Debug)]
380pub struct LocalTcpStream {
381    stream: mio::net::TcpStream,
382    addr: SocketAddr,
383    shutdown: Option<Shutdown>,
384    selector: Arc<Selector>,
385    handler_guard: HandlerGuardState,
386    buffer: BytesMut,
387}
388
389impl LocalTcpStream {
390    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
391        #[allow(unused_mut)]
392        let mut ret = Self {
393            stream,
394            addr,
395            shutdown: None,
396            selector,
397            handler_guard: HandlerGuardState::None,
398            buffer: BytesMut::new(),
399        };
400
401        // In windows we can not poll the socket as it is not supported and hence
402        // what we do is immediately set the writable flag and relay on `mio` to
403        // refresh that flag when the state changes. In Linux what we do is actually
404        // make a non-blocking `poll` call to determine this state
405        #[cfg(target_os = "windows")]
406        {
407            let (state, selector, socket, _) = ret.split_borrow();
408            if let Ok(map) = state_as_waker_map(state, selector, socket) {
409                map.push(InterestType::Writable);
410            }
411        }
412
413        ret
414    }
415
416    fn with_sock_ref<F, R>(&self, f: F) -> R
417    where
418        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
419    {
420        #[cfg(not(windows))]
421        let r = socket2::SockRef::from(&self.stream);
422
423        #[cfg(windows)]
424        let b = unsafe {
425            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
426        };
427        #[cfg(windows)]
428        let r = socket2::SockRef::from(&b);
429
430        f(r)
431    }
432}
433
434impl VirtualTcpSocket for LocalTcpStream {
435    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
436        Ok(())
437    }
438
439    fn recv_buf_size(&self) -> Result<usize> {
440        Err(NetworkError::Unsupported)
441    }
442
443    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
444        Ok(())
445    }
446
447    fn send_buf_size(&self) -> Result<usize> {
448        Err(NetworkError::Unsupported)
449    }
450
451    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
452        self.stream
453            .set_nodelay(nodelay)
454            .map_err(io_err_into_net_error)
455    }
456
457    fn nodelay(&self) -> Result<bool> {
458        self.stream.nodelay().map_err(io_err_into_net_error)
459    }
460
461    fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
462        self.with_sock_ref(|s| s.set_keepalive(true))
463            .map_err(io_err_into_net_error)?;
464        Ok(())
465    }
466
467    fn keepalive(&self) -> Result<bool> {
468        let ret = self
469            .with_sock_ref(|s| s.keepalive())
470            .map_err(io_err_into_net_error)?;
471        Ok(ret)
472    }
473
474    #[cfg(not(target_os = "windows"))]
475    fn set_dontroute(&mut self, val: bool) -> Result<()> {
476        // TODO:
477        // Don't route is being set by WASIX which breaks networking
478        // Why this is being set is unknown but we need to disable
479        // the functionality for now as it breaks everything
480
481        let val = val as libc::c_int;
482        let payload = &val as *const libc::c_int as *const libc::c_void;
483        let err = unsafe {
484            libc::setsockopt(
485                self.stream.as_raw_fd(),
486                libc::SOL_SOCKET,
487                libc::SO_DONTROUTE,
488                payload,
489                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
490            )
491        };
492        if err == -1 {
493            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
494        }
495        Ok(())
496    }
497    #[cfg(target_os = "windows")]
498    fn set_dontroute(&mut self, val: bool) -> Result<()> {
499        Err(NetworkError::Unsupported)
500    }
501
502    #[cfg(not(target_os = "windows"))]
503    fn dontroute(&self) -> Result<bool> {
504        let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
505        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
506        let err = unsafe {
507            libc::getsockopt(
508                self.stream.as_raw_fd(),
509                libc::SOL_SOCKET,
510                libc::SO_DONTROUTE,
511                payload.as_mut_ptr().cast(),
512                &mut len,
513            )
514        };
515        if err == -1 {
516            return Err(io_err_into_net_error(std::io::Error::last_os_error()));
517        }
518        Ok(unsafe { payload.assume_init() != 0 })
519    }
520    #[cfg(target_os = "windows")]
521    fn dontroute(&self) -> Result<bool> {
522        Err(NetworkError::Unsupported)
523    }
524
525    fn addr_peer(&self) -> Result<SocketAddr> {
526        Ok(self.addr)
527    }
528
529    fn shutdown(&mut self, how: Shutdown) -> Result<()> {
530        self.stream.shutdown(how).map_err(io_err_into_net_error)?;
531        self.shutdown = Some(how);
532        Ok(())
533    }
534
535    fn is_closed(&self) -> bool {
536        false
537    }
538}
539
540impl VirtualConnectedSocket for LocalTcpStream {
541    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
542        self.with_sock_ref(|s| s.set_linger(linger))
543            .map_err(io_err_into_net_error)?;
544        Ok(())
545    }
546
547    fn linger(&self) -> Result<Option<Duration>> {
548        self.with_sock_ref(|s| s.linger())
549            .map_err(io_err_into_net_error)
550    }
551
552    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
553        let ret = self.stream.write(data).map_err(io_err_into_net_error);
554        match &ret {
555            Ok(0) | Err(NetworkError::WouldBlock) => {
556                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
557                    map.pop(InterestType::Writable);
558                }
559            }
560            _ => {}
561        }
562        ret
563    }
564
565    fn try_flush(&mut self) -> Result<()> {
566        self.stream.flush().map_err(io_err_into_net_error)
567    }
568
569    fn close(&mut self) -> Result<()> {
570        Ok(())
571    }
572
573    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
574        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
575        if !self.buffer.is_empty() {
576            let amt = buf.len().min(self.buffer.len());
577            buf[..amt].copy_from_slice(&self.buffer[..amt]);
578            if !peek {
579                self.buffer.advance(amt);
580            }
581            return Ok(amt);
582        }
583
584        if peek {
585            self.stream.peek(buf)
586        } else {
587            self.stream.read(buf)
588        }
589        .map_err(io_err_into_net_error)
590    }
591}
592
593impl VirtualSocket for LocalTcpStream {
594    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
595        self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
596    }
597
598    fn ttl(&self) -> Result<u32> {
599        self.stream.ttl().map_err(io_err_into_net_error)
600    }
601
602    fn addr_local(&self) -> Result<SocketAddr> {
603        self.stream.local_addr().map_err(io_err_into_net_error)
604    }
605
606    fn status(&self) -> Result<SocketStatus> {
607        Ok(SocketStatus::Opened)
608    }
609
610    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
611        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
612            match guard.replace_handler(handler) {
613                Ok(()) => return Ok(()),
614                Err(h) => handler = h,
615            }
616
617            // the handler could not be replaced so we need to build a new handler instead
618            if let Err(err) = guard.unregister(&mut self.stream) {
619                tracing::debug!("failed to unregister previous token - {}", err);
620            }
621        }
622
623        let guard = InterestGuard::new(
624            &self.selector,
625            handler,
626            &mut self.stream,
627            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
628        )
629        .map_err(io_err_into_net_error)?;
630
631        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
632
633        Ok(())
634    }
635}
636
637impl LocalTcpStream {
638    fn split_borrow(
639        &mut self,
640    ) -> (
641        &mut HandlerGuardState,
642        &Arc<Selector>,
643        &mut mio::net::TcpStream,
644        &mut BytesMut,
645    ) {
646        (
647            &mut self.handler_guard,
648            &self.selector,
649            &mut self.stream,
650            &mut self.buffer,
651        )
652    }
653}
654
655impl VirtualIoSource for LocalTcpStream {
656    fn remove_handler(&mut self) {
657        let mut guard = HandlerGuardState::None;
658        std::mem::swap(&mut guard, &mut self.handler_guard);
659        match guard {
660            HandlerGuardState::ExternalHandler(mut guard) => {
661                guard.unregister(&mut self.stream).ok();
662            }
663            HandlerGuardState::WakerMap(mut guard, _) => {
664                guard.unregister(&mut self.stream).ok();
665            }
666            HandlerGuardState::None => {}
667        }
668    }
669
670    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
671        if !self.buffer.is_empty() {
672            return Poll::Ready(Ok(self.buffer.len()));
673        }
674
675        let (state, selector, stream, buffer) = self.split_borrow();
676        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
677        map.pop(InterestType::Readable);
678        map.add(InterestType::Readable, cx.waker());
679
680        buffer.reserve(buffer.len() + 10240);
681        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
682        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
683
684        match stream.read(uninit_unsafe) {
685            Ok(0) => Poll::Ready(Ok(0)),
686            Ok(amt) => {
687                unsafe {
688                    buffer.set_len(buffer.len() + amt);
689                }
690                Poll::Ready(Ok(amt))
691            }
692            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
693            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
694            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
695            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
696        }
697    }
698
699    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
700        let (state, selector, stream, _) = self.split_borrow();
701        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
702        #[cfg(not(target_os = "windows"))]
703        map.pop(InterestType::Writable);
704        map.add(InterestType::Writable, cx.waker());
705        map.add(InterestType::Closed, cx.waker());
706        if map.has_interest(InterestType::Closed) {
707            return Poll::Ready(Ok(0));
708        }
709
710        #[cfg(not(target_os = "windows"))]
711        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
712            Some(val) if (val & libc::POLLHUP) != 0 => {
713                return Poll::Ready(Ok(0));
714            }
715            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
716            _ => {}
717        }
718
719        // In windows we can not poll the socket as it is not supported and hence
720        // what we do is immediately set the writable flag and relay on `mio` to
721        // refresh that flag when the state changes. In Linux what we do is actually
722        // make a non-blocking `poll` call to determine this state
723        #[cfg(target_os = "windows")]
724        if map.has_interest(InterestType::Writable) {
725            return Poll::Ready(Ok(10240));
726        }
727
728        Poll::Pending
729    }
730}
731
732#[cfg(not(target_os = "windows"))]
733fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
734    let mut fds: [libc::pollfd; 1] = [libc::pollfd {
735        fd,
736        events,
737        revents: 0,
738    }];
739    let fds_mut = &mut fds[..];
740    let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
741    match ret == 1 {
742        true => Some(fds[0].revents),
743        false => None,
744    }
745}
746
747#[derive(Debug)]
748pub struct LocalUdpSocket {
749    socket: mio::net::UdpSocket,
750    #[allow(dead_code)]
751    addr: SocketAddr,
752    selector: Arc<Selector>,
753    handler_guard: HandlerGuardState,
754    backlog: VecDeque<(BytesMut, SocketAddr)>,
755    ruleset: Option<Ruleset>,
756}
757
758impl LocalUdpSocket {
759    fn with_sock_ref<F, R>(&self, f: F) -> R
760    where
761        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
762    {
763        #[cfg(not(windows))]
764        let r = socket2::SockRef::from(&self.socket);
765
766        #[cfg(windows)]
767        let b = unsafe {
768            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
769        };
770        #[cfg(windows)]
771        let r = socket2::SockRef::from(&b);
772
773        f(r)
774    }
775}
776
777impl VirtualUdpSocket for LocalUdpSocket {
778    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
779        self.socket
780            .set_broadcast(broadcast)
781            .map_err(io_err_into_net_error)
782    }
783
784    fn broadcast(&self) -> Result<bool> {
785        self.socket.broadcast().map_err(io_err_into_net_error)
786    }
787
788    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
789        self.socket
790            .set_multicast_loop_v4(val)
791            .map_err(io_err_into_net_error)
792    }
793
794    fn multicast_loop_v4(&self) -> Result<bool> {
795        self.socket
796            .multicast_loop_v4()
797            .map_err(io_err_into_net_error)
798    }
799
800    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
801        self.socket
802            .set_multicast_loop_v6(val)
803            .map_err(io_err_into_net_error)
804    }
805
806    fn multicast_loop_v6(&self) -> Result<bool> {
807        self.socket
808            .multicast_loop_v6()
809            .map_err(io_err_into_net_error)
810    }
811
812    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
813        self.socket
814            .set_multicast_ttl_v4(ttl)
815            .map_err(io_err_into_net_error)
816    }
817
818    fn multicast_ttl_v4(&self) -> Result<u32> {
819        self.socket
820            .multicast_ttl_v4()
821            .map_err(io_err_into_net_error)
822    }
823
824    fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
825        self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
826            .map_err(io_err_into_net_error)
827    }
828
829    fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
830        self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
831            .map_err(io_err_into_net_error)
832    }
833
834    fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
835        self.socket
836            .join_multicast_v6(&multiaddr, iface)
837            .map_err(io_err_into_net_error)
838    }
839
840    fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
841        self.socket
842            .leave_multicast_v6(&multiaddr, iface)
843            .map_err(io_err_into_net_error)
844    }
845
846    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
847        self.socket
848            .peer_addr()
849            .map(Some)
850            .map_err(io_err_into_net_error)
851    }
852}
853
854impl VirtualConnectionlessSocket for LocalUdpSocket {
855    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
856        if let Some(ruleset) = self.ruleset.as_ref() {
857            if !ruleset.allows_socket(addr, Direction::Outbound) {
858                tracing::warn!(%addr, "try_send blocked by firewall rule");
859                return Err(NetworkError::PermissionDenied);
860            }
861        }
862
863        let ret = self
864            .socket
865            .send_to(data, addr)
866            .map_err(io_err_into_net_error);
867        match &ret {
868            Ok(0) | Err(NetworkError::WouldBlock) => {
869                if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
870                    map.pop(InterestType::Writable);
871                }
872            }
873            _ => {}
874        }
875        ret
876    }
877
878    fn try_recv_from(
879        &mut self,
880        buf: &mut [MaybeUninit<u8>],
881        peek: bool,
882    ) -> Result<(usize, SocketAddr)> {
883        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
884        if peek {
885            self.socket.peek_from(buf)
886        } else {
887            self.socket.recv_from(buf)
888        }
889        .map_err(io_err_into_net_error)
890    }
891}
892
893impl VirtualSocket for LocalUdpSocket {
894    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
895        self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
896    }
897
898    fn ttl(&self) -> Result<u32> {
899        self.socket.ttl().map_err(io_err_into_net_error)
900    }
901
902    fn addr_local(&self) -> Result<SocketAddr> {
903        self.socket.local_addr().map_err(io_err_into_net_error)
904    }
905
906    fn status(&self) -> Result<SocketStatus> {
907        Ok(SocketStatus::Opened)
908    }
909
910    fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
911        if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
912            match guard.replace_handler(handler) {
913                Ok(()) => {
914                    return Ok(());
915                }
916                Err(h) => handler = h,
917            }
918
919            // the handler could not be replaced so we need to build a new handler instead
920            if let Err(err) = guard.unregister(&mut self.socket) {
921                tracing::debug!("failed to unregister previous token - {}", err);
922            }
923        }
924
925        let guard = InterestGuard::new(
926            &self.selector,
927            handler,
928            &mut self.socket,
929            mio::Interest::READABLE.add(mio::Interest::WRITABLE),
930        )
931        .map_err(io_err_into_net_error)?;
932
933        self.handler_guard = HandlerGuardState::ExternalHandler(guard);
934
935        Ok(())
936    }
937}
938
939impl LocalUdpSocket {
940    fn split_borrow(
941        &mut self,
942    ) -> (
943        &mut HandlerGuardState,
944        &Arc<Selector>,
945        &mut mio::net::UdpSocket,
946    ) {
947        (&mut self.handler_guard, &self.selector, &mut self.socket)
948    }
949}
950
951impl VirtualIoSource for LocalUdpSocket {
952    fn remove_handler(&mut self) {
953        let mut guard = HandlerGuardState::None;
954        std::mem::swap(&mut guard, &mut self.handler_guard);
955        match guard {
956            HandlerGuardState::ExternalHandler(mut guard) => {
957                guard.unregister(&mut self.socket).ok();
958            }
959            HandlerGuardState::WakerMap(mut guard, _) => {
960                guard.unregister(&mut self.socket).ok();
961            }
962            HandlerGuardState::None => {}
963        }
964    }
965
966    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
967        if !self.backlog.is_empty() {
968            let total = self.backlog.iter().map(|a| a.0.len()).sum();
969            return Poll::Ready(Ok(total));
970        }
971
972        let (state, selector, socket) = self.split_borrow();
973        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
974        map.pop(InterestType::Readable);
975        map.add(InterestType::Readable, cx.waker());
976
977        let mut buffer = BytesMut::default();
978        buffer.reserve(10240);
979        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
980        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
981
982        match self.socket.recv_from(uninit_unsafe) {
983            Ok((0, _)) => Poll::Ready(Ok(0)),
984            Ok((amt, peer)) => {
985                unsafe {
986                    buffer.set_len(amt);
987                }
988                self.backlog.push_back((buffer, peer));
989                Poll::Ready(Ok(amt))
990            }
991            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
992            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
993            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
994            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
995        }
996    }
997
998    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
999        let (state, selector, socket) = self.split_borrow();
1000        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1001        #[cfg(not(target_os = "windows"))]
1002        map.pop(InterestType::Writable);
1003        map.add(InterestType::Writable, cx.waker());
1004
1005        #[cfg(not(target_os = "windows"))]
1006        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1007            Some(val) if (val & libc::POLLHUP) != 0 => {
1008                return Poll::Ready(Ok(0));
1009            }
1010            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1011            _ => {}
1012        }
1013
1014        // In windows we can not poll the socket as it is not supported and hence
1015        // what we do is immediately set the writable flag and relay on `mio` to
1016        // refresh that flag when the state changes. In Linux what we do is actually
1017        // make a non-blocking `poll` call to determine this state
1018        #[cfg(target_os = "windows")]
1019        if map.has_interest(InterestType::Writable) {
1020            return Poll::Ready(Ok(10240));
1021        }
1022
1023        Poll::Pending
1024    }
1025}