1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, MAX_SOCKET_PAYLOAD, NetworkError, Result, SocketStatus, StreamSecurity,
6 VirtualConnectedSocket, VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking,
7 VirtualRawSocket, VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket,
8 VirtualUdpSocket,
9};
10use crate::{VirtualIoSource, io_err_into_net_error};
11use bytes::{Buf, BytesMut};
12use std::collections::VecDeque;
13use std::io::{self, Read, Write};
14use std::mem::MaybeUninit;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
16#[cfg(not(target_os = "windows"))]
17use std::os::fd::AsRawFd;
18#[cfg(not(target_os = "windows"))]
19use std::os::fd::RawFd;
20#[cfg(windows)]
21use std::os::windows::io::AsRawSocket;
22
23use std::sync::{Arc, Mutex};
24use std::task::Poll;
25use std::time::Duration;
26use tokio::runtime::Handle;
27#[allow(unused_imports, dead_code)]
28use tracing::{debug, error, info, trace, warn};
29use virtual_mio::{
30 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
31};
32
// Backlog passed to `listen()` by `LocalTcpBoundSocket::listen`.
// Unix builds with the `libc` feature use the platform's `SOMAXCONN`;
// every other build falls back to a fixed value of 128.
#[cfg(all(target_family = "unix", feature = "libc"))]
const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
#[cfg(not(all(target_family = "unix", feature = "libc")))]
const LISTEN_BACKLOG: i32 = 128;
40
/// `VirtualNetworking` implementation backed by the host operating system's sockets.
///
/// Every socket created through one instance shares the same readiness
/// `Selector`, and all of them are filtered by the same optional `Ruleset`.
#[derive(Debug)]
pub struct LocalNetworking {
    /// Readiness selector, cloned into every socket this instance creates.
    selector: Arc<Selector>,
    /// Tokio runtime handle captured at construction; DNS lookups are spawned on it.
    handle: Handle,
    /// Optional firewall rules; `None` means no filtering is applied.
    ruleset: Option<Ruleset>,
}
47
48impl LocalNetworking {
49 pub fn new() -> Self {
50 Self {
51 selector: Selector::new(),
52 handle: Handle::current(),
53 ruleset: None,
54 }
55 }
56
57 pub fn with_ruleset(ruleset: Ruleset) -> Self {
58 Self {
59 selector: Selector::new(),
60 handle: Handle::current(),
61 ruleset: Some(ruleset),
62 }
63 }
64}
65
66impl Drop for LocalNetworking {
67 fn drop(&mut self) {
68 self.selector.shutdown();
69 }
70}
71
72impl Default for LocalNetworking {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
79 addr.as_socket().ok_or(NetworkError::UnknownError)
80}
81
82fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
83 if addr.is_ipv4() {
84 socket2::Domain::IPV4
85 } else {
86 socket2::Domain::IPV6
87 }
88}
89
/// Returns `true` when `err`, produced by a non-blocking `connect()`, only means
/// "the connection has not finished yet" rather than a real failure.
///
/// `WouldBlock` and `Interrupted` always qualify. On Unix builds with the `libc`
/// feature, the raw `EINPROGRESS` / `EALREADY` codes qualify too.
fn tcp_connect_in_progress(err: &io::Error) -> bool {
    if let io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted = err.kind() {
        return true;
    }
    os_connect_pending(err)
}

/// Checks the raw OS error code for the "connect still pending" errnos.
#[cfg(all(target_family = "unix", feature = "libc"))]
fn os_connect_pending(err: &io::Error) -> bool {
    matches!(
        err.raw_os_error(),
        Some(code) if code == libc::EINPROGRESS || code == libc::EALREADY
    )
}

/// Without `libc` there are no raw errno constants to compare against.
#[cfg(not(all(target_family = "unix", feature = "libc")))]
fn os_connect_pending(_err: &io::Error) -> bool {
    false
}
112
113#[async_trait::async_trait]
114#[allow(unused_variables)]
115impl VirtualNetworking for LocalNetworking {
116 async fn listen_tcp(
117 &self,
118 addr: SocketAddr,
119 only_v6: bool,
120 reuse_port: bool,
121 reuse_addr: bool,
122 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
123 if let Some(ruleset) = self.ruleset.as_ref()
124 && !ruleset.allows_socket(addr, Direction::Inbound)
125 {
126 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
127 return Err(NetworkError::PermissionDenied);
128 }
129
130 self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
131 .await?
132 .listen()
133 }
134
135 async fn bind_tcp(
136 &self,
137 addr: SocketAddr,
138 only_v6: bool,
139 reuse_port: bool,
140 reuse_addr: bool,
141 ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
142 if let Some(ruleset) = self.ruleset.as_ref()
143 && !ruleset.allows_socket(addr, Direction::Inbound)
144 {
145 tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
146 return Err(NetworkError::PermissionDenied);
147 }
148
149 let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None)
150 .map_err(io_err_into_net_error)?;
151 socket
152 .set_nonblocking(true)
153 .map_err(io_err_into_net_error)?;
154 if addr.is_ipv6() {
155 socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
156 }
157 socket
158 .set_reuse_address(reuse_addr)
159 .map_err(io_err_into_net_error)?;
160 #[cfg(not(windows))]
161 socket
162 .set_reuse_port(reuse_port)
163 .map_err(io_err_into_net_error)?;
164 socket.bind(&addr.into()).map_err(io_err_into_net_error)?;
165
166 Ok(Box::new(LocalTcpBoundSocket {
167 socket: Some(socket),
168 selector: self.selector.clone(),
169 ruleset: self.ruleset.clone(),
170 }))
171 }
172
173 async fn bind_udp(
174 &self,
175 addr: SocketAddr,
176 reuse_port: bool,
177 reuse_addr: bool,
178 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
179 #[cfg(not(windows))]
180 use socket2::{Domain, Socket, Type};
181
182 if let Some(ruleset) = self.ruleset.as_ref()
183 && !ruleset.allows_socket(addr, Direction::Inbound)
184 {
185 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
186 return Err(NetworkError::PermissionDenied);
187 }
188
189 #[cfg(not(windows))]
190 let socket = {
191 let domain = if addr.is_ipv4() {
192 Domain::IPV4
193 } else {
194 Domain::IPV6
195 };
196 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
197 std_sock
198 .set_nonblocking(true)
199 .map_err(io_err_into_net_error)?;
200 std_sock
201 .set_reuse_address(reuse_addr)
202 .map_err(io_err_into_net_error)?;
203 std_sock
204 .set_reuse_port(reuse_port)
205 .map_err(io_err_into_net_error)?;
206 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
207 mio::net::UdpSocket::from_std(std_sock.into())
208 };
209 #[cfg(windows)]
210 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
211
212 #[allow(unused_mut)]
213 let mut ret = LocalUdpSocket {
214 selector: self.selector.clone(),
215 socket,
216 addr,
217 handler_guard: HandlerGuardState::None,
218 backlog: Default::default(),
219 ruleset: self.ruleset.clone(),
220 };
221
222 #[cfg(target_os = "windows")]
227 {
228 let (state, selector, socket) = ret.split_borrow();
229 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
230 map.push(InterestType::Writable);
231 }
232
233 Ok(Box::new(ret))
234 }
235
236 async fn connect_tcp(
237 &self,
238 _addr: SocketAddr,
239 mut peer: SocketAddr,
240 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
241 if let Some(ruleset) = self.ruleset.as_ref()
242 && !ruleset.allows_socket(peer, Direction::Outbound)
243 {
244 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
245 return Err(NetworkError::PermissionDenied);
246 }
247
248 let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
249
250 if let Ok(p) = stream.peer_addr() {
251 peer = p;
252 }
253 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
254 Ok(socket)
255 }
256
257 async fn resolve(
258 &self,
259 host: &str,
260 port: Option<u16>,
261 dns_server: Option<IpAddr>,
262 ) -> Result<Vec<IpAddr>> {
263 if let Some(ruleset) = self.ruleset.as_ref()
264 && !ruleset.allows_domain(host)
265 {
266 tracing::warn!(%host, "dns resolve blocked by firewall rule");
267 return Err(NetworkError::PermissionDenied);
268 }
269
270 let host_to_lookup = if host.contains(':') {
271 host.to_string()
272 } else {
273 format!("{}:{}", host, port.unwrap_or(0))
274 };
275 let addrs = self
276 .handle
277 .spawn(tokio::net::lookup_host(host_to_lookup))
278 .await
279 .map_err(|_| NetworkError::IOError)?
280 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
281 .map_err(io_err_into_net_error)?;
282
283 if let Some(ruleset) = self.ruleset.as_ref() {
284 if let Err(e) = ruleset.expand_domain(host, &addrs) {
285 tracing::debug!(err=%e, "ruleset expansion failed");
286 } else {
287 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
288 }
289 }
290
291 Ok(addrs)
292 }
293}
294
/// Listening TCP socket backed by a non-blocking `mio` listener.
#[derive(Debug)]
pub struct LocalTcpListener {
    /// The OS listener; `accept` is called on it directly.
    stream: mio::net::TcpListener,
    /// Shared readiness selector, also handed to every accepted stream.
    selector: Arc<Selector>,
    /// Current registration with the selector (none, external handler, or waker map).
    handler_guard: HandlerGuardState,
    /// When `Some`, applied (best effort) to each accepted stream.
    no_delay: Option<bool>,
    /// When `Some`, applied (best effort) to each accepted stream.
    keep_alive: Option<bool>,
    /// Connections accepted while polling for readiness; `try_accept` hands these
    /// out before calling `accept` again.
    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
    /// Optional firewall rules checked against each accepted peer address.
    ruleset: Option<Ruleset>,
}
305
306impl LocalTcpListener {
307 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
308 match self.stream.accept().map_err(io_err_into_net_error) {
309 Ok((stream, addr)) => {
310 if let Some(ruleset) = self.ruleset.as_ref()
311 && !ruleset.allows_socket(addr, Direction::Outbound)
312 {
313 tracing::warn!(%addr, "try_accept blocked by firewall rule");
314 return Err(NetworkError::PermissionDenied);
315 }
316
317 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
318 if let Some(no_delay) = self.no_delay {
319 socket.set_nodelay(no_delay).ok();
320 }
321 if let Some(keep_alive) = self.keep_alive {
322 socket.set_keepalive(keep_alive).ok();
323 }
324 Ok((Box::new(socket), addr))
325 }
326 Err(NetworkError::WouldBlock) => {
327 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
328 map.pop(InterestType::Readable);
329 map.pop(InterestType::Writable);
330 }
331 Err(NetworkError::WouldBlock)
332 }
333 Err(err) => Err(err),
334 }
335 }
336}
337
338impl VirtualTcpListener for LocalTcpListener {
339 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
340 if let Some(child) = self.backlog.pop_front() {
341 return Ok(child);
342 }
343 self.try_accept_internal()
344 }
345
346 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
347 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
348 match guard.replace_handler(handler) {
349 Ok(()) => return Ok(()),
350 Err(h) => handler = h,
351 }
352
353 if let Err(err) = guard.unregister(&mut self.stream) {
355 tracing::debug!("failed to unregister previous token - {}", err);
356 }
357 }
358
359 let guard = InterestGuard::new(
360 &self.selector,
361 handler,
362 &mut self.stream,
363 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
364 )
365 .map_err(io_err_into_net_error)?;
366
367 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
368
369 Ok(())
370 }
371
372 fn addr_local(&self) -> Result<SocketAddr> {
373 self.stream.local_addr().map_err(io_err_into_net_error)
374 }
375
376 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
377 self.stream
378 .set_ttl(ttl as u32)
379 .map_err(io_err_into_net_error)
380 }
381
382 fn ttl(&self) -> Result<u8> {
383 self.stream
384 .ttl()
385 .map(|ttl| ttl as u8)
386 .map_err(io_err_into_net_error)
387 }
388}
389
390impl LocalTcpListener {
391 fn split_borrow(
392 &mut self,
393 ) -> (
394 &mut HandlerGuardState,
395 &Arc<Selector>,
396 &mut mio::net::TcpListener,
397 ) {
398 (&mut self.handler_guard, &self.selector, &mut self.stream)
399 }
400}
401
402impl VirtualIoSource for LocalTcpListener {
403 fn remove_handler(&mut self) {
404 let mut guard = HandlerGuardState::None;
405 std::mem::swap(&mut guard, &mut self.handler_guard);
406 match guard {
407 HandlerGuardState::ExternalHandler(mut guard) => {
408 guard.unregister(&mut self.stream).ok();
409 }
410 HandlerGuardState::WakerMap(mut guard, _) => {
411 guard.unregister(&mut self.stream).ok();
412 }
413 HandlerGuardState::None => {}
414 }
415 }
416
417 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
418 if !self.backlog.is_empty() {
419 return Poll::Ready(Ok(self.backlog.len()));
420 }
421
422 let (state, selector, source) = self.split_borrow();
423 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
424 map.add(InterestType::Readable, cx.waker());
425
426 if let Ok(child) = self.try_accept_internal() {
427 self.backlog.push_back(child);
428 return Poll::Ready(Ok(1));
429 }
430 Poll::Pending
431 }
432
433 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
434 if !self.backlog.is_empty() {
435 return Poll::Ready(Ok(self.backlog.len()));
436 }
437
438 let (state, selector, source) = self.split_borrow();
439 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
440 map.add(InterestType::Writable, cx.waker());
441
442 if let Ok(child) = self.try_accept_internal() {
443 self.backlog.push_back(child);
444 return Poll::Ready(Ok(1));
445 }
446 Poll::Pending
447 }
448}
449
/// A TCP socket that has been bound to a local address but is neither
/// listening nor connected yet.
///
/// `listen` and `connect` take the socket out of the `Option`. Once it is gone,
/// the remaining operations report `NetworkError::InvalidFd`.
#[derive(Debug)]
pub struct LocalTcpBoundSocket {
    /// The bound, non-blocking OS socket (`None` after it has been taken).
    socket: Option<socket2::Socket>,
    /// Shared readiness selector, handed to the resulting listener or stream.
    selector: Arc<Selector>,
    /// Optional firewall rules; consulted by `connect` and passed on to listeners.
    ruleset: Option<Ruleset>,
}
456
457impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
458 fn addr_local(&self) -> Result<SocketAddr> {
459 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
460 let addr = socket.local_addr().map_err(io_err_into_net_error)?;
461 sock_addr_into_socket_addr(addr)
462 }
463
464 fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
465 let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
466 socket
467 .listen(LISTEN_BACKLOG)
468 .map_err(io_err_into_net_error)?;
469 let listener = mio::net::TcpListener::from_std(socket.into());
470 Ok(Box::new(LocalTcpListener {
471 stream: listener,
472 selector: self.selector.clone(),
473 handler_guard: HandlerGuardState::None,
474 no_delay: None,
475 keep_alive: None,
476 backlog: Default::default(),
477 ruleset: self.ruleset.clone(),
478 }))
479 }
480
481 fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
482 if let Some(ruleset) = self.ruleset.as_ref()
483 && !ruleset.allows_socket(peer, Direction::Outbound)
484 {
485 tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
486 return Err(NetworkError::PermissionDenied);
487 }
488
489 let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
490 if let Err(err) = socket.connect(&peer.into())
491 && !tcp_connect_in_progress(&err)
492 {
493 return Err(io_err_into_net_error(err));
494 }
495
496 let stream = mio::net::TcpStream::from_std(socket.into());
497 if let Ok(p) = stream.peer_addr() {
498 peer = p;
499 }
500 Ok(Box::new(LocalTcpStream::new(
501 self.selector.clone(),
502 stream,
503 peer,
504 )))
505 }
506
507 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
508 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
509 match self.addr_local()?.ip() {
510 IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
511 IpAddr::V6(_) => socket
512 .set_unicast_hops_v6(ttl)
513 .map_err(io_err_into_net_error),
514 }
515 }
516
517 fn ttl(&self) -> Result<u32> {
518 let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
519 match self.addr_local()?.ip() {
520 IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
521 IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
522 }
523 }
524}
525
/// Cached connection outcome of a `LocalTcpStream`, maintained by
/// `status()` and `last_error()`.
#[derive(Debug)]
enum ConnectState {
    /// Not determined yet; `status()` will query the socket.
    Unknown,
    /// `peer_addr()` succeeded, so the connection is established.
    Opened,
    /// The connection failed; the error is kept until `last_error()` returns it.
    Failed(NetworkError),
    /// The failure has already been handed out by `last_error()`.
    Reset,
}
533
/// Connected (or connecting) TCP stream backed by a non-blocking `mio` socket.
#[derive(Debug)]
pub struct LocalTcpStream {
    /// The OS stream.
    stream: mio::net::TcpStream,
    /// Peer address, returned verbatim by `addr_peer()`.
    addr: SocketAddr,
    /// Last direction successfully passed to `shutdown()`.
    shutdown: Option<Shutdown>,
    /// Shared readiness selector.
    selector: Arc<Selector>,
    /// Current registration with the selector (none, external handler, or waker map).
    handler_guard: HandlerGuardState,
    /// Bytes read ahead by `poll_read_ready`; `try_recv` drains these first.
    buffer: BytesMut,
    /// Cached connect outcome. It sits behind a `Mutex` because `status()` and
    /// `last_error()` update it through `&self`.
    connect_state: Mutex<ConnectState>,
}
544
545impl LocalTcpStream {
546 fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
547 #[allow(unused_mut)]
548 let mut ret = Self {
549 stream,
550 addr,
551 shutdown: None,
552 selector,
553 handler_guard: HandlerGuardState::None,
554 buffer: BytesMut::new(),
555 connect_state: Mutex::new(ConnectState::Unknown),
556 };
557
558 #[cfg(target_os = "windows")]
563 {
564 let (state, selector, socket, _) = ret.split_borrow();
565 if let Ok(map) = state_as_waker_map(state, selector, socket) {
566 map.push(InterestType::Writable);
567 }
568 }
569
570 ret
571 }
572
573 fn with_sock_ref<F, R>(&self, f: F) -> R
574 where
575 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
576 {
577 #[cfg(not(windows))]
578 let r = socket2::SockRef::from(&self.stream);
579
580 #[cfg(windows)]
581 let b = unsafe {
582 std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
583 };
584 #[cfg(windows)]
585 let r = socket2::SockRef::from(&b);
586
587 f(r)
588 }
589}
590
591impl VirtualTcpSocket for LocalTcpStream {
592 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
593 Ok(())
594 }
595
596 fn recv_buf_size(&self) -> Result<usize> {
597 Err(NetworkError::Unsupported)
598 }
599
600 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
601 Ok(())
602 }
603
604 fn send_buf_size(&self) -> Result<usize> {
605 Err(NetworkError::Unsupported)
606 }
607
608 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
609 self.stream
610 .set_nodelay(nodelay)
611 .map_err(io_err_into_net_error)
612 }
613
614 fn nodelay(&self) -> Result<bool> {
615 self.stream.nodelay().map_err(io_err_into_net_error)
616 }
617
618 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
619 self.with_sock_ref(|s| s.set_keepalive(true))
620 .map_err(io_err_into_net_error)?;
621 Ok(())
622 }
623
624 fn keepalive(&self) -> Result<bool> {
625 let ret = self
626 .with_sock_ref(|s| s.keepalive())
627 .map_err(io_err_into_net_error)?;
628 Ok(ret)
629 }
630
631 #[cfg(not(target_os = "windows"))]
632 fn set_dontroute(&mut self, val: bool) -> Result<()> {
633 let val = val as libc::c_int;
639 let payload = &val as *const libc::c_int as *const libc::c_void;
640 let err = unsafe {
641 libc::setsockopt(
642 self.stream.as_raw_fd(),
643 libc::SOL_SOCKET,
644 libc::SO_DONTROUTE,
645 payload,
646 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
647 )
648 };
649 if err == -1 {
650 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
651 }
652 Ok(())
653 }
654 #[cfg(target_os = "windows")]
655 fn set_dontroute(&mut self, val: bool) -> Result<()> {
656 Err(NetworkError::Unsupported)
657 }
658
659 #[cfg(not(target_os = "windows"))]
660 fn dontroute(&self) -> Result<bool> {
661 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
662 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
663 let err = unsafe {
664 libc::getsockopt(
665 self.stream.as_raw_fd(),
666 libc::SOL_SOCKET,
667 libc::SO_DONTROUTE,
668 payload.as_mut_ptr().cast(),
669 &mut len,
670 )
671 };
672 if err == -1 {
673 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
674 }
675 Ok(unsafe { payload.assume_init() != 0 })
676 }
677 #[cfg(target_os = "windows")]
678 fn dontroute(&self) -> Result<bool> {
679 Err(NetworkError::Unsupported)
680 }
681
682 fn addr_peer(&self) -> Result<SocketAddr> {
683 Ok(self.addr)
684 }
685
686 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
687 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
688 self.shutdown = Some(how);
689 Ok(())
690 }
691
692 fn is_closed(&self) -> bool {
693 false
694 }
695}
696
697impl VirtualConnectedSocket for LocalTcpStream {
698 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
699 self.with_sock_ref(|s| s.set_linger(linger))
700 .map_err(io_err_into_net_error)?;
701 Ok(())
702 }
703
704 fn linger(&self) -> Result<Option<Duration>> {
705 self.with_sock_ref(|s| s.linger())
706 .map_err(io_err_into_net_error)
707 }
708
709 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
710 let ret = self.stream.write(data).map_err(io_err_into_net_error);
711 match &ret {
712 Ok(0) | Err(NetworkError::WouldBlock) => {
713 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
714 map.pop(InterestType::Writable);
715 }
716 }
717 _ => {}
718 }
719 ret
720 }
721
722 fn try_flush(&mut self) -> Result<()> {
723 self.stream.flush().map_err(io_err_into_net_error)
724 }
725
726 fn close(&mut self) -> Result<()> {
727 Ok(())
728 }
729
730 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
731 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
732 if !self.buffer.is_empty() {
733 let amt = buf.len().min(self.buffer.len());
734 buf[..amt].copy_from_slice(&self.buffer[..amt]);
735 if !peek {
736 self.buffer.advance(amt);
737 }
738 return Ok(amt);
739 }
740
741 if peek {
742 self.stream.peek(buf)
743 } else {
744 self.stream.read(buf)
745 }
746 .map_err(io_err_into_net_error)
747 }
748}
749
750impl VirtualSocket for LocalTcpStream {
751 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
752 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
753 }
754
755 fn ttl(&self) -> Result<u32> {
756 self.stream.ttl().map_err(io_err_into_net_error)
757 }
758
759 fn addr_local(&self) -> Result<SocketAddr> {
760 self.stream.local_addr().map_err(io_err_into_net_error)
761 }
762
763 fn status(&self) -> Result<SocketStatus> {
764 let mut connect_state = self.connect_state.lock().unwrap();
767 match *connect_state {
768 ConnectState::Opened => return Ok(SocketStatus::Opened),
769 ConnectState::Failed(_) | ConnectState::Reset => {
770 return Ok(SocketStatus::Failed);
771 }
772 ConnectState::Unknown => {}
773 }
774
775 if let Some(err) = self
776 .with_sock_ref(|sockref| sockref.take_error())
777 .map_err(io_err_into_net_error)?
778 {
779 *connect_state = ConnectState::Failed(io_err_into_net_error(err));
780 return Ok(SocketStatus::Failed); }
782 match self.stream.peer_addr() {
783 Ok(_) => {
784 *connect_state = ConnectState::Opened;
785 Ok(SocketStatus::Opened) }
787 Err(err) => {
788 if matches!(
789 err.kind(),
790 io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
791 ) {
792 Ok(SocketStatus::Opening) } else {
794 *connect_state = ConnectState::Failed(io_err_into_net_error(err));
795 Ok(SocketStatus::Failed) }
797 }
798 }
799 }
800
801 fn last_error(&self) -> Result<Option<NetworkError>> {
802 let mut connect_state = self.connect_state.lock().unwrap();
803 Ok(match *connect_state {
804 ConnectState::Failed(err) => {
805 *connect_state = ConnectState::Reset;
806 Some(err)
807 }
808 _ => None,
809 })
810 }
811
812 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
813 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
814 match guard.replace_handler(handler) {
815 Ok(()) => return Ok(()),
816 Err(h) => handler = h,
817 }
818
819 if let Err(err) = guard.unregister(&mut self.stream) {
821 tracing::debug!("failed to unregister previous token - {}", err);
822 }
823 }
824
825 let guard = InterestGuard::new(
826 &self.selector,
827 handler,
828 &mut self.stream,
829 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
830 )
831 .map_err(io_err_into_net_error)?;
832
833 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
834
835 Ok(())
836 }
837}
838
839impl LocalTcpStream {
840 fn split_borrow(
841 &mut self,
842 ) -> (
843 &mut HandlerGuardState,
844 &Arc<Selector>,
845 &mut mio::net::TcpStream,
846 &mut BytesMut,
847 ) {
848 (
849 &mut self.handler_guard,
850 &self.selector,
851 &mut self.stream,
852 &mut self.buffer,
853 )
854 }
855}
856
impl VirtualIoSource for LocalTcpStream {
    /// Drops the current selector registration: the stream is unregistered
    /// (errors ignored) and the handler state is reset to `None`.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Polls for readable data by reading ahead into `self.buffer`.
    ///
    /// Returns `Ready(Ok(n))` with the buffered/read byte count, `Ready(Ok(0))`
    /// on EOF or on an aborted/reset connection, and `Pending` on `WouldBlock`.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        // Data already read ahead: report it without touching the socket.
        if !self.buffer.is_empty() {
            return Poll::Ready(Ok(self.buffer.len()));
        }

        let (state, selector, stream, buffer) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Pop any previously recorded Readable entry before registering this
        // waker. NOTE(review): presumably this clears stale readiness so only a
        // fresh event wakes the task; confirm against `virtual_mio`.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // The early return above guarantees `buffer.len()` is 0 here, so this
        // reserves 10240 bytes.
        buffer.reserve(buffer.len() + 10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // SAFETY(review): exposes spare (uninitialised) capacity as `[u8]` for
        // `read`; only the first `amt` bytes are committed via `set_len` below.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match stream.read(uninit_unsafe) {
            Ok(0) => Poll::Ready(Ok(0)),
            Ok(amt) => {
                // SAFETY: `read` initialised `amt` bytes of the spare capacity.
                unsafe {
                    buffer.set_len(buffer.len() + amt);
                }
                Poll::Ready(Ok(amt))
            }
            // Aborted/reset connections are reported like EOF.
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Polls for write readiness.
    ///
    /// `Ready(Ok(0))` signals a closed/hung-up socket. `Ready(Ok(10240))` signals
    /// writability; 10240 is a fixed nominal value, not the measured free space
    /// in the send buffer.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, stream, _) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        // Outside Windows, the Writable entry is popped before re-registering; the
        // zero-timeout `libc_poll` below then checks writability directly.
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());
        map.add(InterestType::Closed, cx.waker());
        if map.has_interest(InterestType::Closed) {
            return Poll::Ready(Ok(0));
        }

        #[cfg(not(target_os = "windows"))]
        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // On Windows, writability comes from the waker map's Writable entry
        // (seeded in `LocalTcpStream::new`).
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}
933
934#[cfg(not(target_os = "windows"))]
935fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
936 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
937 fd,
938 events,
939 revents: 0,
940 }];
941 let fds_mut = &mut fds[..];
942 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
943 match ret == 1 {
944 true => Some(fds[0].revents),
945 false => None,
946 }
947}
948
/// UDP socket backed by a non-blocking `mio` socket.
#[derive(Debug)]
pub struct LocalUdpSocket {
    /// The OS socket.
    socket: mio::net::UdpSocket,
    /// Address passed to `bind_udp`; stored but not otherwise read (hence the
    /// `dead_code` allowance).
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Shared readiness selector.
    selector: Arc<Selector>,
    /// Current registration with the selector (none, external handler, or waker map).
    handler_guard: HandlerGuardState,
    /// Datagrams received ahead of `try_recv_from`, paired with their sender.
    backlog: VecDeque<(BytesMut, SocketAddr)>,
    /// Optional firewall rules; consulted for outbound `try_send_to`.
    ruleset: Option<Ruleset>,
}
959
960impl LocalUdpSocket {
961 fn with_sock_ref<F, R>(&self, f: F) -> R
962 where
963 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
964 {
965 #[cfg(not(windows))]
966 let r = socket2::SockRef::from(&self.socket);
967
968 #[cfg(windows)]
969 let b = unsafe {
970 std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
971 };
972 #[cfg(windows)]
973 let r = socket2::SockRef::from(&b);
974
975 f(r)
976 }
977}
978
979impl VirtualUdpSocket for LocalUdpSocket {
980 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
981 self.socket
982 .set_broadcast(broadcast)
983 .map_err(io_err_into_net_error)
984 }
985
986 fn broadcast(&self) -> Result<bool> {
987 self.socket.broadcast().map_err(io_err_into_net_error)
988 }
989
990 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
991 self.socket
992 .set_multicast_loop_v4(val)
993 .map_err(io_err_into_net_error)
994 }
995
996 fn multicast_loop_v4(&self) -> Result<bool> {
997 self.socket
998 .multicast_loop_v4()
999 .map_err(io_err_into_net_error)
1000 }
1001
1002 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1003 self.socket
1004 .set_multicast_loop_v6(val)
1005 .map_err(io_err_into_net_error)
1006 }
1007
1008 fn multicast_loop_v6(&self) -> Result<bool> {
1009 self.socket
1010 .multicast_loop_v6()
1011 .map_err(io_err_into_net_error)
1012 }
1013
1014 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1015 self.socket
1016 .set_multicast_ttl_v4(ttl)
1017 .map_err(io_err_into_net_error)
1018 }
1019
1020 fn multicast_ttl_v4(&self) -> Result<u32> {
1021 self.socket
1022 .multicast_ttl_v4()
1023 .map_err(io_err_into_net_error)
1024 }
1025
1026 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1027 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
1028 .map_err(io_err_into_net_error)
1029 }
1030
1031 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
1032 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
1033 .map_err(io_err_into_net_error)
1034 }
1035
1036 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1037 self.socket
1038 .join_multicast_v6(&multiaddr, iface)
1039 .map_err(io_err_into_net_error)
1040 }
1041
1042 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
1043 self.socket
1044 .leave_multicast_v6(&multiaddr, iface)
1045 .map_err(io_err_into_net_error)
1046 }
1047
1048 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1049 self.socket
1050 .peer_addr()
1051 .map(Some)
1052 .map_err(io_err_into_net_error)
1053 }
1054}
1055
1056impl VirtualConnectionlessSocket for LocalUdpSocket {
1057 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1058 if let Some(ruleset) = self.ruleset.as_ref()
1059 && !ruleset.allows_socket(addr, Direction::Outbound)
1060 {
1061 tracing::warn!(%addr, "try_send blocked by firewall rule");
1062 return Err(NetworkError::PermissionDenied);
1063 }
1064
1065 let ret = self
1066 .socket
1067 .send_to(data, addr)
1068 .map_err(io_err_into_net_error);
1069 match &ret {
1070 Ok(0) | Err(NetworkError::WouldBlock) => {
1071 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
1072 map.pop(InterestType::Writable);
1073 }
1074 }
1075 _ => {}
1076 }
1077 ret
1078 }
1079
1080 fn try_recv_from(
1081 &mut self,
1082 buf: &mut [MaybeUninit<u8>],
1083 peek: bool,
1084 ) -> Result<(usize, SocketAddr)> {
1085 if self.backlog.is_empty() {
1086 self.recv_into_backlog().map_err(io_err_into_net_error)?;
1087 }
1088
1089 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1090 let Some((packet, addr)) = self.backlog.front() else {
1091 return Err(NetworkError::WouldBlock);
1092 };
1093 let amt = buf.len().min(packet.len());
1094 let addr = *addr;
1095 buf[..amt].copy_from_slice(&packet[..amt]);
1096 if !peek {
1097 self.backlog.pop_front();
1098 }
1099 Ok((amt, addr))
1100 }
1101}
1102
1103impl VirtualSocket for LocalUdpSocket {
1104 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1105 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
1106 }
1107
1108 fn ttl(&self) -> Result<u32> {
1109 self.socket.ttl().map_err(io_err_into_net_error)
1110 }
1111
1112 fn addr_local(&self) -> Result<SocketAddr> {
1113 self.socket.local_addr().map_err(io_err_into_net_error)
1114 }
1115
1116 fn status(&self) -> Result<SocketStatus> {
1117 Ok(SocketStatus::Opened)
1118 }
1119
1120 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
1121 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
1122 match guard.replace_handler(handler) {
1123 Ok(()) => {
1124 return Ok(());
1125 }
1126 Err(h) => handler = h,
1127 }
1128
1129 if let Err(err) = guard.unregister(&mut self.socket) {
1131 tracing::debug!("failed to unregister previous token - {}", err);
1132 }
1133 }
1134
1135 let guard = InterestGuard::new(
1136 &self.selector,
1137 handler,
1138 &mut self.socket,
1139 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
1140 )
1141 .map_err(io_err_into_net_error)?;
1142
1143 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
1144
1145 Ok(())
1146 }
1147}
1148
1149impl LocalUdpSocket {
1150 fn backlog_read_ready_len(&self) -> usize {
1153 self.backlog
1154 .front()
1155 .map(|(packet, _)| packet.len().max(1))
1156 .unwrap_or_default()
1157 }
1158
1159 fn recv_into_backlog(&mut self) -> io::Result<()> {
1160 let mut buffer = BytesMut::default();
1161 buffer.reserve(MAX_SOCKET_PAYLOAD);
1162 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
1163 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
1164
1165 let (amt, peer) = self.socket.recv_from(uninit_unsafe)?;
1166 unsafe {
1167 buffer.set_len(amt);
1168 }
1169 self.backlog.push_back((buffer, peer));
1170 Ok(())
1171 }
1172
1173 fn split_borrow(
1174 &mut self,
1175 ) -> (
1176 &mut HandlerGuardState,
1177 &Arc<Selector>,
1178 &mut mio::net::UdpSocket,
1179 ) {
1180 (&mut self.handler_guard, &self.selector, &mut self.socket)
1181 }
1182}
1183
impl VirtualIoSource for LocalUdpSocket {
    /// Drops the current selector registration: the socket is unregistered
    /// (errors ignored) and the handler state is reset to `None`.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Polls for a readable datagram, receiving it into the backlog.
    ///
    /// Returns `Ready(Ok(n))` with the next datagram's readiness size (see
    /// `backlog_read_ready_len`), `Ready(Ok(0))` on aborted/reset errors, and
    /// `Pending` on `WouldBlock`.
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        // A datagram is already queued: report it without touching the socket.
        if !self.backlog.is_empty() {
            return Poll::Ready(Ok(self.backlog_read_ready_len()));
        }

        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        // Pop any previously recorded Readable entry before registering this
        // waker. NOTE(review): presumably this clears stale readiness; confirm
        // against `virtual_mio`.
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        match self.recv_into_backlog() {
            Ok(()) => Poll::Ready(Ok(self.backlog_read_ready_len())),
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Polls for write readiness.
    ///
    /// `Ready(Ok(0))` signals a hang-up. `Ready(Ok(10240))` signals writability;
    /// 10240 is a fixed nominal value, not a measured buffer size.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        // Outside Windows, the Writable entry is popped before re-registering; the
        // zero-timeout `libc_poll` below then checks writability directly.
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());

        #[cfg(not(target_os = "windows"))]
        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // On Windows, writability comes from the waker map's Writable entry
        // (seeded in `bind_udp`).
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}