1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6 VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7 VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::{Arc, Mutex};
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
/// Backlog length passed to `listen(2)` when a bound TCP socket becomes a listener.
///
/// On Unix with the `libc` feature enabled this is the platform's `SOMAXCONN`.
#[cfg(all(target_family = "unix", feature = "libc"))]
const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
/// Fallback backlog length used when `libc::SOMAXCONN` is not available.
#[cfg(not(all(target_family = "unix", feature = "libc")))]
const LISTEN_BACKLOG: i32 = 128;
39
40#[derive(Debug)]
41pub struct LocalNetworking {
42 selector: Arc<Selector>,
43 handle: Handle,
44 ruleset: Option<Ruleset>,
45}
46
47impl LocalNetworking {
48 pub fn new() -> Self {
49 Self {
50 selector: Selector::new(),
51 handle: Handle::current(),
52 ruleset: None,
53 }
54 }
55
56 pub fn with_ruleset(ruleset: Ruleset) -> Self {
57 Self {
58 selector: Selector::new(),
59 handle: Handle::current(),
60 ruleset: Some(ruleset),
61 }
62 }
63}
64
65impl Drop for LocalNetworking {
66 fn drop(&mut self) {
67 self.selector.shutdown();
68 }
69}
70
71impl Default for LocalNetworking {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
78 addr.as_socket().ok_or(NetworkError::UnknownError)
79}
80
81fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
82 if addr.is_ipv4() {
83 socket2::Domain::IPV4
84 } else {
85 socket2::Domain::IPV6
86 }
87}
88
/// Reports whether `err`, returned by a non-blocking `connect`, only means the
/// connection attempt is still underway rather than that it failed.
///
/// `WouldBlock` and `Interrupted` always qualify. On Unix with the `libc`
/// feature, the raw `EINPROGRESS` / `EALREADY` codes qualify as well.
fn tcp_connect_in_progress(err: &io::Error) -> bool {
    if let io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted = err.kind() {
        return true;
    }
    connect_pending_os_code(err)
}

/// Checks the raw OS error code for "connect still pending" values.
#[cfg(all(target_family = "unix", feature = "libc"))]
fn connect_pending_os_code(err: &io::Error) -> bool {
    matches!(
        err.raw_os_error(),
        Some(code) if code == libc::EINPROGRESS || code == libc::EALREADY
    )
}

/// Without `libc` there are no raw codes to compare against.
#[cfg(not(all(target_family = "unix", feature = "libc")))]
fn connect_pending_os_code(_err: &io::Error) -> bool {
    false
}
111
112#[async_trait::async_trait]
113#[allow(unused_variables)]
114impl VirtualNetworking for LocalNetworking {
115 async fn listen_tcp(
116 &self,
117 addr: SocketAddr,
118 only_v6: bool,
119 reuse_port: bool,
120 reuse_addr: bool,
121 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
122 if let Some(ruleset) = self.ruleset.as_ref()
123 && !ruleset.allows_socket(addr, Direction::Inbound)
124 {
125 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
126 return Err(NetworkError::PermissionDenied);
127 }
128
129 self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
130 .await?
131 .listen()
132 }
133
134 async fn bind_tcp(
135 &self,
136 addr: SocketAddr,
137 only_v6: bool,
138 reuse_port: bool,
139 reuse_addr: bool,
140 ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
141 if let Some(ruleset) = self.ruleset.as_ref()
142 && !ruleset.allows_socket(addr, Direction::Inbound)
143 {
144 tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
145 return Err(NetworkError::PermissionDenied);
146 }
147
148 let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None)
149 .map_err(io_err_into_net_error)?;
150 socket
151 .set_nonblocking(true)
152 .map_err(io_err_into_net_error)?;
153 if addr.is_ipv6() {
154 socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
155 }
156 socket
157 .set_reuse_address(reuse_addr)
158 .map_err(io_err_into_net_error)?;
159 #[cfg(not(windows))]
160 socket
161 .set_reuse_port(reuse_port)
162 .map_err(io_err_into_net_error)?;
163 socket.bind(&addr.into()).map_err(io_err_into_net_error)?;
164
165 Ok(Box::new(LocalTcpBoundSocket {
166 socket: Some(socket),
167 selector: self.selector.clone(),
168 ruleset: self.ruleset.clone(),
169 }))
170 }
171
172 async fn bind_udp(
173 &self,
174 addr: SocketAddr,
175 reuse_port: bool,
176 reuse_addr: bool,
177 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
178 #[cfg(not(windows))]
179 use socket2::{Domain, Socket, Type};
180
181 if let Some(ruleset) = self.ruleset.as_ref()
182 && !ruleset.allows_socket(addr, Direction::Inbound)
183 {
184 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
185 return Err(NetworkError::PermissionDenied);
186 }
187
188 #[cfg(not(windows))]
189 let socket = {
190 let domain = if addr.is_ipv4() {
191 Domain::IPV4
192 } else {
193 Domain::IPV6
194 };
195 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
196 std_sock
197 .set_nonblocking(true)
198 .map_err(io_err_into_net_error)?;
199 std_sock
200 .set_reuse_address(reuse_addr)
201 .map_err(io_err_into_net_error)?;
202 std_sock
203 .set_reuse_port(reuse_port)
204 .map_err(io_err_into_net_error)?;
205 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
206 mio::net::UdpSocket::from_std(std_sock.into())
207 };
208 #[cfg(windows)]
209 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
210
211 #[allow(unused_mut)]
212 let mut ret = LocalUdpSocket {
213 selector: self.selector.clone(),
214 socket,
215 addr,
216 handler_guard: HandlerGuardState::None,
217 backlog: Default::default(),
218 ruleset: self.ruleset.clone(),
219 };
220
221 #[cfg(target_os = "windows")]
226 {
227 let (state, selector, socket) = ret.split_borrow();
228 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
229 map.push(InterestType::Writable);
230 }
231
232 Ok(Box::new(ret))
233 }
234
235 async fn connect_tcp(
236 &self,
237 _addr: SocketAddr,
238 mut peer: SocketAddr,
239 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
240 if let Some(ruleset) = self.ruleset.as_ref()
241 && !ruleset.allows_socket(peer, Direction::Outbound)
242 {
243 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
244 return Err(NetworkError::PermissionDenied);
245 }
246
247 let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
248
249 if let Ok(p) = stream.peer_addr() {
250 peer = p;
251 }
252 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
253 Ok(socket)
254 }
255
256 async fn resolve(
257 &self,
258 host: &str,
259 port: Option<u16>,
260 dns_server: Option<IpAddr>,
261 ) -> Result<Vec<IpAddr>> {
262 if let Some(ruleset) = self.ruleset.as_ref()
263 && !ruleset.allows_domain(host)
264 {
265 tracing::warn!(%host, "dns resolve blocked by firewall rule");
266 return Err(NetworkError::PermissionDenied);
267 }
268
269 let host_to_lookup = if host.contains(':') {
270 host.to_string()
271 } else {
272 format!("{}:{}", host, port.unwrap_or(0))
273 };
274 let addrs = self
275 .handle
276 .spawn(tokio::net::lookup_host(host_to_lookup))
277 .await
278 .map_err(|_| NetworkError::IOError)?
279 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
280 .map_err(io_err_into_net_error)?;
281
282 if let Some(ruleset) = self.ruleset.as_ref() {
283 if let Err(e) = ruleset.expand_domain(host, &addrs) {
284 tracing::debug!(err=%e, "ruleset expansion failed");
285 } else {
286 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
287 }
288 }
289
290 Ok(addrs)
291 }
292}
293
294#[derive(Debug)]
295pub struct LocalTcpListener {
296 stream: mio::net::TcpListener,
297 selector: Arc<Selector>,
298 handler_guard: HandlerGuardState,
299 no_delay: Option<bool>,
300 keep_alive: Option<bool>,
301 backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
302 ruleset: Option<Ruleset>,
303}
304
305impl LocalTcpListener {
306 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
307 match self.stream.accept().map_err(io_err_into_net_error) {
308 Ok((stream, addr)) => {
309 if let Some(ruleset) = self.ruleset.as_ref()
310 && !ruleset.allows_socket(addr, Direction::Outbound)
311 {
312 tracing::warn!(%addr, "try_accept blocked by firewall rule");
313 return Err(NetworkError::PermissionDenied);
314 }
315
316 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
317 if let Some(no_delay) = self.no_delay {
318 socket.set_nodelay(no_delay).ok();
319 }
320 if let Some(keep_alive) = self.keep_alive {
321 socket.set_keepalive(keep_alive).ok();
322 }
323 Ok((Box::new(socket), addr))
324 }
325 Err(NetworkError::WouldBlock) => {
326 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
327 map.pop(InterestType::Readable);
328 map.pop(InterestType::Writable);
329 }
330 Err(NetworkError::WouldBlock)
331 }
332 Err(err) => Err(err),
333 }
334 }
335}
336
337impl VirtualTcpListener for LocalTcpListener {
338 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
339 if let Some(child) = self.backlog.pop_front() {
340 return Ok(child);
341 }
342 self.try_accept_internal()
343 }
344
345 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
346 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
347 match guard.replace_handler(handler) {
348 Ok(()) => return Ok(()),
349 Err(h) => handler = h,
350 }
351
352 if let Err(err) = guard.unregister(&mut self.stream) {
354 tracing::debug!("failed to unregister previous token - {}", err);
355 }
356 }
357
358 let guard = InterestGuard::new(
359 &self.selector,
360 handler,
361 &mut self.stream,
362 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
363 )
364 .map_err(io_err_into_net_error)?;
365
366 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
367
368 Ok(())
369 }
370
371 fn addr_local(&self) -> Result<SocketAddr> {
372 self.stream.local_addr().map_err(io_err_into_net_error)
373 }
374
375 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
376 self.stream
377 .set_ttl(ttl as u32)
378 .map_err(io_err_into_net_error)
379 }
380
381 fn ttl(&self) -> Result<u8> {
382 self.stream
383 .ttl()
384 .map(|ttl| ttl as u8)
385 .map_err(io_err_into_net_error)
386 }
387}
388
389impl LocalTcpListener {
390 fn split_borrow(
391 &mut self,
392 ) -> (
393 &mut HandlerGuardState,
394 &Arc<Selector>,
395 &mut mio::net::TcpListener,
396 ) {
397 (&mut self.handler_guard, &self.selector, &mut self.stream)
398 }
399}
400
401impl VirtualIoSource for LocalTcpListener {
402 fn remove_handler(&mut self) {
403 let mut guard = HandlerGuardState::None;
404 std::mem::swap(&mut guard, &mut self.handler_guard);
405 match guard {
406 HandlerGuardState::ExternalHandler(mut guard) => {
407 guard.unregister(&mut self.stream).ok();
408 }
409 HandlerGuardState::WakerMap(mut guard, _) => {
410 guard.unregister(&mut self.stream).ok();
411 }
412 HandlerGuardState::None => {}
413 }
414 }
415
416 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
417 if !self.backlog.is_empty() {
418 return Poll::Ready(Ok(self.backlog.len()));
419 }
420
421 let (state, selector, source) = self.split_borrow();
422 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
423 map.add(InterestType::Readable, cx.waker());
424
425 if let Ok(child) = self.try_accept_internal() {
426 self.backlog.push_back(child);
427 return Poll::Ready(Ok(1));
428 }
429 Poll::Pending
430 }
431
432 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
433 if !self.backlog.is_empty() {
434 return Poll::Ready(Ok(self.backlog.len()));
435 }
436
437 let (state, selector, source) = self.split_borrow();
438 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
439 map.add(InterestType::Writable, cx.waker());
440
441 if let Ok(child) = self.try_accept_internal() {
442 self.backlog.push_back(child);
443 return Poll::Ready(Ok(1));
444 }
445 Poll::Pending
446 }
447}
448
449#[derive(Debug)]
450pub struct LocalTcpBoundSocket {
451 socket: Option<socket2::Socket>,
452 selector: Arc<Selector>,
453 ruleset: Option<Ruleset>,
454}
455
456impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
457 fn addr_local(&self) -> Result<SocketAddr> {
458 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
459 let addr = socket.local_addr().map_err(io_err_into_net_error)?;
460 sock_addr_into_socket_addr(addr)
461 }
462
463 fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
464 let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
465 socket
466 .listen(LISTEN_BACKLOG)
467 .map_err(io_err_into_net_error)?;
468 let listener = mio::net::TcpListener::from_std(socket.into());
469 Ok(Box::new(LocalTcpListener {
470 stream: listener,
471 selector: self.selector.clone(),
472 handler_guard: HandlerGuardState::None,
473 no_delay: None,
474 keep_alive: None,
475 backlog: Default::default(),
476 ruleset: self.ruleset.clone(),
477 }))
478 }
479
480 fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
481 if let Some(ruleset) = self.ruleset.as_ref()
482 && !ruleset.allows_socket(peer, Direction::Outbound)
483 {
484 tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
485 return Err(NetworkError::PermissionDenied);
486 }
487
488 let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
489 if let Err(err) = socket.connect(&peer.into())
490 && !tcp_connect_in_progress(&err)
491 {
492 return Err(io_err_into_net_error(err));
493 }
494
495 let stream = mio::net::TcpStream::from_std(socket.into());
496 if let Ok(p) = stream.peer_addr() {
497 peer = p;
498 }
499 Ok(Box::new(LocalTcpStream::new(
500 self.selector.clone(),
501 stream,
502 peer,
503 )))
504 }
505
506 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
507 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
508 match self.addr_local()?.ip() {
509 IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
510 IpAddr::V6(_) => socket
511 .set_unicast_hops_v6(ttl)
512 .map_err(io_err_into_net_error),
513 }
514 }
515
516 fn ttl(&self) -> Result<u32> {
517 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
518 match self.addr_local()?.ip() {
519 IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
520 IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
521 }
522 }
523}
524
/// Cached outcome of the connect as observed by `LocalTcpStream::status`.
///
/// `Unknown` means no outcome has been recorded yet. `status` only ever writes
/// `Opened` or `Failed`, so once set these values stay put.
#[derive(Debug)]
enum ConnectState {
    Unknown,
    Opened,
    Failed,
}

/// A TCP connection backed by a non-blocking `mio` stream.
#[derive(Debug)]
pub struct LocalTcpStream {
    stream: mio::net::TcpStream,
    // Peer address reported by `addr_peer`.
    addr: SocketAddr,
    // Last direction passed to `shutdown`; written there but not read by the code shown here.
    shutdown: Option<Shutdown>,
    selector: Arc<Selector>,
    handler_guard: HandlerGuardState,
    // Bytes read ahead by `poll_read_ready`; `try_recv` serves these first.
    buffer: BytesMut,
    // Behind a `Mutex` because `status(&self)` needs to update it.
    connect_state: Mutex<ConnectState>,
}
542
543impl LocalTcpStream {
544 fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
545 #[allow(unused_mut)]
546 let mut ret = Self {
547 stream,
548 addr,
549 shutdown: None,
550 selector,
551 handler_guard: HandlerGuardState::None,
552 buffer: BytesMut::new(),
553 connect_state: Mutex::new(ConnectState::Unknown),
554 };
555
556 #[cfg(target_os = "windows")]
561 {
562 let (state, selector, socket, _) = ret.split_borrow();
563 if let Ok(map) = state_as_waker_map(state, selector, socket) {
564 map.push(InterestType::Writable);
565 }
566 }
567
568 ret
569 }
570
571 fn with_sock_ref<F, R>(&self, f: F) -> R
572 where
573 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
574 {
575 #[cfg(not(windows))]
576 let r = socket2::SockRef::from(&self.stream);
577
578 #[cfg(windows)]
579 let b = unsafe {
580 std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
581 };
582 #[cfg(windows)]
583 let r = socket2::SockRef::from(&b);
584
585 f(r)
586 }
587}
588
589impl VirtualTcpSocket for LocalTcpStream {
590 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
591 Ok(())
592 }
593
594 fn recv_buf_size(&self) -> Result<usize> {
595 Err(NetworkError::Unsupported)
596 }
597
598 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
599 Ok(())
600 }
601
602 fn send_buf_size(&self) -> Result<usize> {
603 Err(NetworkError::Unsupported)
604 }
605
606 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
607 self.stream
608 .set_nodelay(nodelay)
609 .map_err(io_err_into_net_error)
610 }
611
612 fn nodelay(&self) -> Result<bool> {
613 self.stream.nodelay().map_err(io_err_into_net_error)
614 }
615
616 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
617 self.with_sock_ref(|s| s.set_keepalive(true))
618 .map_err(io_err_into_net_error)?;
619 Ok(())
620 }
621
622 fn keepalive(&self) -> Result<bool> {
623 let ret = self
624 .with_sock_ref(|s| s.keepalive())
625 .map_err(io_err_into_net_error)?;
626 Ok(ret)
627 }
628
629 #[cfg(not(target_os = "windows"))]
630 fn set_dontroute(&mut self, val: bool) -> Result<()> {
631 let val = val as libc::c_int;
637 let payload = &val as *const libc::c_int as *const libc::c_void;
638 let err = unsafe {
639 libc::setsockopt(
640 self.stream.as_raw_fd(),
641 libc::SOL_SOCKET,
642 libc::SO_DONTROUTE,
643 payload,
644 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
645 )
646 };
647 if err == -1 {
648 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
649 }
650 Ok(())
651 }
652 #[cfg(target_os = "windows")]
653 fn set_dontroute(&mut self, val: bool) -> Result<()> {
654 Err(NetworkError::Unsupported)
655 }
656
657 #[cfg(not(target_os = "windows"))]
658 fn dontroute(&self) -> Result<bool> {
659 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
660 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
661 let err = unsafe {
662 libc::getsockopt(
663 self.stream.as_raw_fd(),
664 libc::SOL_SOCKET,
665 libc::SO_DONTROUTE,
666 payload.as_mut_ptr().cast(),
667 &mut len,
668 )
669 };
670 if err == -1 {
671 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
672 }
673 Ok(unsafe { payload.assume_init() != 0 })
674 }
675 #[cfg(target_os = "windows")]
676 fn dontroute(&self) -> Result<bool> {
677 Err(NetworkError::Unsupported)
678 }
679
680 fn addr_peer(&self) -> Result<SocketAddr> {
681 Ok(self.addr)
682 }
683
684 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
685 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
686 self.shutdown = Some(how);
687 Ok(())
688 }
689
690 fn is_closed(&self) -> bool {
691 false
692 }
693}
694
695impl VirtualConnectedSocket for LocalTcpStream {
696 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
697 self.with_sock_ref(|s| s.set_linger(linger))
698 .map_err(io_err_into_net_error)?;
699 Ok(())
700 }
701
702 fn linger(&self) -> Result<Option<Duration>> {
703 self.with_sock_ref(|s| s.linger())
704 .map_err(io_err_into_net_error)
705 }
706
707 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
708 let ret = self.stream.write(data).map_err(io_err_into_net_error);
709 match &ret {
710 Ok(0) | Err(NetworkError::WouldBlock) => {
711 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
712 map.pop(InterestType::Writable);
713 }
714 }
715 _ => {}
716 }
717 ret
718 }
719
720 fn try_flush(&mut self) -> Result<()> {
721 self.stream.flush().map_err(io_err_into_net_error)
722 }
723
724 fn close(&mut self) -> Result<()> {
725 Ok(())
726 }
727
728 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
729 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
730 if !self.buffer.is_empty() {
731 let amt = buf.len().min(self.buffer.len());
732 buf[..amt].copy_from_slice(&self.buffer[..amt]);
733 if !peek {
734 self.buffer.advance(amt);
735 }
736 return Ok(amt);
737 }
738
739 if peek {
740 self.stream.peek(buf)
741 } else {
742 self.stream.read(buf)
743 }
744 .map_err(io_err_into_net_error)
745 }
746}
747
748impl VirtualSocket for LocalTcpStream {
749 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
750 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
751 }
752
753 fn ttl(&self) -> Result<u32> {
754 self.stream.ttl().map_err(io_err_into_net_error)
755 }
756
757 fn addr_local(&self) -> Result<SocketAddr> {
758 self.stream.local_addr().map_err(io_err_into_net_error)
759 }
760
761 fn status(&self) -> Result<SocketStatus> {
762 let mut connect_state = self.connect_state.lock().unwrap();
765 match *connect_state {
766 ConnectState::Opened => return Ok(SocketStatus::Opened),
767 ConnectState::Failed => return Ok(SocketStatus::Failed),
768 ConnectState::Unknown => {}
769 }
770
771 if self
772 .with_sock_ref(|sockref| sockref.take_error())
773 .map_err(io_err_into_net_error)?
774 .is_some()
775 {
776 *connect_state = ConnectState::Failed;
777 return Ok(SocketStatus::Failed); }
779 match self.stream.peer_addr() {
780 Ok(_) => {
781 *connect_state = ConnectState::Opened;
782 Ok(SocketStatus::Opened) }
784 Err(err) => {
785 if matches!(
786 err.kind(),
787 io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
788 ) {
789 Ok(SocketStatus::Opening) } else {
791 *connect_state = ConnectState::Failed;
793 Ok(SocketStatus::Failed) }
795 }
796 }
797 }
798
799 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
800 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
801 match guard.replace_handler(handler) {
802 Ok(()) => return Ok(()),
803 Err(h) => handler = h,
804 }
805
806 if let Err(err) = guard.unregister(&mut self.stream) {
808 tracing::debug!("failed to unregister previous token - {}", err);
809 }
810 }
811
812 let guard = InterestGuard::new(
813 &self.selector,
814 handler,
815 &mut self.stream,
816 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
817 )
818 .map_err(io_err_into_net_error)?;
819
820 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
821
822 Ok(())
823 }
824}
825
826impl LocalTcpStream {
827 fn split_borrow(
828 &mut self,
829 ) -> (
830 &mut HandlerGuardState,
831 &Arc<Selector>,
832 &mut mio::net::TcpStream,
833 &mut BytesMut,
834 ) {
835 (
836 &mut self.handler_guard,
837 &self.selector,
838 &mut self.stream,
839 &mut self.buffer,
840 )
841 }
842}
843
impl VirtualIoSource for LocalTcpStream {
    /// Detaches the current selector registration and resets the state to
    /// `HandlerGuardState::None`.
    ///
    /// The registration may be an external handler or a waker map.
    /// Unregistration errors are ignored.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Polls for readable data, reading ahead into `self.buffer`.
    ///
    /// If bytes are already buffered, their count is returned immediately.
    /// Otherwise `cx`'s waker is registered for readability and one
    /// non-blocking read is attempted. Returns:
    /// - `Ready(Ok(n))` with the bytes read;
    /// - `Ready(Ok(0))` on EOF or a connection aborted/reset error;
    /// - `Pending` on `WouldBlock`.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        if !self.buffer.is_empty() {
            return Poll::Ready(Ok(self.buffer.len()));
        }

        let (state, selector, stream, buffer) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // `pop` then `add`: presumably clears stale readable state before
        // re-arming with this task's waker.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // `buffer` is empty here (checked above), so this reserves 10240 bytes.
        buffer.reserve(buffer.len() + 10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // NOTE(review): exposes uninitialised spare capacity as `&mut [u8]`;
        // relies on `read` only writing into it.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match stream.read(uninit_unsafe) {
            Ok(0) => Poll::Ready(Ok(0)),
            Ok(amt) => {
                // The first `amt` bytes of spare capacity were just written by `read`.
                unsafe {
                    buffer.set_len(buffer.len() + amt);
                }
                Poll::Ready(Ok(amt))
            }
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Polls for write readiness.
    ///
    /// Returns `Ready(Ok(0))` when the waker map reports a `Closed` interest or,
    /// off Windows, when `poll(2)` reports `POLLHUP`. Returns a fixed
    /// `Ready(Ok(10240))` when the socket looks writable, and `Pending` otherwise.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, stream, _) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());
        map.add(InterestType::Closed, cx.waker());
        if map.has_interest(InterestType::Closed) {
            return Poll::Ready(Ok(0));
        }

        // Off Windows, ask the kernel directly with a zero-timeout poll(2).
        #[cfg(not(target_os = "windows"))]
        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // On Windows, rely on the writable state recorded in the waker map.
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}
920
921#[cfg(not(target_os = "windows"))]
922fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
923 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
924 fd,
925 events,
926 revents: 0,
927 }];
928 let fds_mut = &mut fds[..];
929 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
930 match ret == 1 {
931 true => Some(fds[0].revents),
932 false => None,
933 }
934}
935
936#[derive(Debug)]
937pub struct LocalUdpSocket {
938 socket: mio::net::UdpSocket,
939 #[allow(dead_code)]
940 addr: SocketAddr,
941 selector: Arc<Selector>,
942 handler_guard: HandlerGuardState,
943 backlog: VecDeque<(BytesMut, SocketAddr)>,
944 ruleset: Option<Ruleset>,
945}
946
947impl LocalUdpSocket {
948 fn with_sock_ref<F, R>(&self, f: F) -> R
949 where
950 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
951 {
952 #[cfg(not(windows))]
953 let r = socket2::SockRef::from(&self.socket);
954
955 #[cfg(windows)]
956 let b = unsafe {
957 std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
958 };
959 #[cfg(windows)]
960 let r = socket2::SockRef::from(&b);
961
962 f(r)
963 }
964}
965
966impl VirtualUdpSocket for LocalUdpSocket {
967 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
968 self.socket
969 .set_broadcast(broadcast)
970 .map_err(io_err_into_net_error)
971 }
972
973 fn broadcast(&self) -> Result<bool> {
974 self.socket.broadcast().map_err(io_err_into_net_error)
975 }
976
977 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
978 self.socket
979 .set_multicast_loop_v4(val)
980 .map_err(io_err_into_net_error)
981 }
982
983 fn multicast_loop_v4(&self) -> Result<bool> {
984 self.socket
985 .multicast_loop_v4()
986 .map_err(io_err_into_net_error)
987 }
988
989 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
990 self.socket
991 .set_multicast_loop_v6(val)
992 .map_err(io_err_into_net_error)
993 }
994
995 fn multicast_loop_v6(&self) -> Result<bool> {
996 self.socket
997 .multicast_loop_v6()
998 .map_err(io_err_into_net_error)
999 }
1000
1001 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1002 self.socket
1003 .set_multicast_ttl_v4(ttl)
1004 .map_err(io_err_into_net_error)
1005 }
1006
1007 fn multicast_ttl_v4(&self) -> Result<u32> {
1008 self.socket
1009 .multicast_ttl_v4()
1010 .map_err(io_err_into_net_error)
1011 }
1012
1013 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1014 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
1015 .map_err(io_err_into_net_error)
1016 }
1017
1018 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1019 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
1020 .map_err(io_err_into_net_error)
1021 }
1022
1023 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1024 self.socket
1025 .join_multicast_v6(&multiaddr, iface)
1026 .map_err(io_err_into_net_error)
1027 }
1028
1029 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1030 self.socket
1031 .leave_multicast_v6(&multiaddr, iface)
1032 .map_err(io_err_into_net_error)
1033 }
1034
1035 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1036 self.socket
1037 .peer_addr()
1038 .map(Some)
1039 .map_err(io_err_into_net_error)
1040 }
1041}
1042
1043impl VirtualConnectionlessSocket for LocalUdpSocket {
1044 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1045 if let Some(ruleset) = self.ruleset.as_ref()
1046 && !ruleset.allows_socket(addr, Direction::Outbound)
1047 {
1048 tracing::warn!(%addr, "try_send blocked by firewall rule");
1049 return Err(NetworkError::PermissionDenied);
1050 }
1051
1052 let ret = self
1053 .socket
1054 .send_to(data, addr)
1055 .map_err(io_err_into_net_error);
1056 match &ret {
1057 Ok(0) | Err(NetworkError::WouldBlock) => {
1058 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
1059 map.pop(InterestType::Writable);
1060 }
1061 }
1062 _ => {}
1063 }
1064 ret
1065 }
1066
1067 fn try_recv_from(
1068 &mut self,
1069 buf: &mut [MaybeUninit<u8>],
1070 peek: bool,
1071 ) -> Result<(usize, SocketAddr)> {
1072 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1073 if peek {
1074 self.socket.peek_from(buf)
1075 } else {
1076 self.socket.recv_from(buf)
1077 }
1078 .map_err(io_err_into_net_error)
1079 }
1080}
1081
1082impl VirtualSocket for LocalUdpSocket {
1083 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1084 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
1085 }
1086
1087 fn ttl(&self) -> Result<u32> {
1088 self.socket.ttl().map_err(io_err_into_net_error)
1089 }
1090
1091 fn addr_local(&self) -> Result<SocketAddr> {
1092 self.socket.local_addr().map_err(io_err_into_net_error)
1093 }
1094
1095 fn status(&self) -> Result<SocketStatus> {
1096 Ok(SocketStatus::Opened)
1097 }
1098
1099 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
1100 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
1101 match guard.replace_handler(handler) {
1102 Ok(()) => {
1103 return Ok(());
1104 }
1105 Err(h) => handler = h,
1106 }
1107
1108 if let Err(err) = guard.unregister(&mut self.socket) {
1110 tracing::debug!("failed to unregister previous token - {}", err);
1111 }
1112 }
1113
1114 let guard = InterestGuard::new(
1115 &self.selector,
1116 handler,
1117 &mut self.socket,
1118 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
1119 )
1120 .map_err(io_err_into_net_error)?;
1121
1122 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
1123
1124 Ok(())
1125 }
1126}
1127
1128impl LocalUdpSocket {
1129 fn split_borrow(
1130 &mut self,
1131 ) -> (
1132 &mut HandlerGuardState,
1133 &Arc<Selector>,
1134 &mut mio::net::UdpSocket,
1135 ) {
1136 (&mut self.handler_guard, &self.selector, &mut self.socket)
1137 }
1138}
1139
impl VirtualIoSource for LocalUdpSocket {
    /// Detaches the current selector registration and resets the state to
    /// `HandlerGuardState::None`.
    ///
    /// The registration may be an external handler or a waker map.
    /// Unregistration errors are ignored.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Polls for an incoming datagram.
    ///
    /// If datagrams are already queued in `backlog`, the total of their lengths
    /// is returned immediately. Otherwise `cx`'s waker is registered for
    /// readability and one datagram is received into a new 10240-byte buffer
    /// and queued in `backlog`. The receive path must therefore drain `backlog`
    /// for later polls to reach the socket again.
    ///
    /// NOTE(review): a zero-length datagram is consumed and reported as
    /// `Ready(Ok(0))` without being queued. Datagrams larger than 10240 bytes
    /// are presumably truncated (or rejected, depending on the OS) — confirm.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        if !self.backlog.is_empty() {
            let total = self.backlog.iter().map(|a| a.0.len()).sum();
            return Poll::Ready(Ok(total));
        }

        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        // `pop` then `add`: presumably clears stale readable state before
        // re-arming with this task's waker.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        let mut buffer = BytesMut::default();
        buffer.reserve(10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // NOTE(review): exposes uninitialised spare capacity as `&mut [u8]`;
        // relies on `recv_from` only writing into it.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match self.socket.recv_from(uninit_unsafe) {
            Ok((0, _)) => Poll::Ready(Ok(0)),
            Ok((amt, peer)) => {
                // The first `amt` bytes of spare capacity were just written by `recv_from`.
                unsafe {
                    buffer.set_len(amt);
                }
                self.backlog.push_back((buffer, peer));
                Poll::Ready(Ok(amt))
            }
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Polls for write readiness.
    ///
    /// Off Windows, `poll(2)` is queried directly: `POLLHUP` yields
    /// `Ready(Ok(0))` and `POLLOUT` yields a fixed `Ready(Ok(10240))`. On
    /// Windows, the writable state recorded in the waker map is used instead.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());

        #[cfg(not(target_os = "windows"))]
        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}