1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6 VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7 VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34 selector: Arc<Selector>,
35 handle: Handle,
36 ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40 pub fn new() -> Self {
41 Self {
42 selector: Selector::new(),
43 handle: Handle::current(),
44 ruleset: None,
45 }
46 }
47
48 pub fn with_ruleset(ruleset: Ruleset) -> Self {
49 Self {
50 selector: Selector::new(),
51 handle: Handle::current(),
52 ruleset: Some(ruleset),
53 }
54 }
55}
56
57impl Drop for LocalNetworking {
58 fn drop(&mut self) {
59 self.selector.shutdown();
60 }
61}
62
63impl Default for LocalNetworking {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72 async fn listen_tcp(
73 &self,
74 addr: SocketAddr,
75 only_v6: bool,
76 reuse_port: bool,
77 reuse_addr: bool,
78 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79 if let Some(ruleset) = self.ruleset.as_ref()
80 && !ruleset.allows_socket(addr, Direction::Inbound)
81 {
82 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83 return Err(NetworkError::PermissionDenied);
84 }
85
86 let listener = std::net::TcpListener::bind(addr)
87 .map(|sock| {
88 sock.set_nonblocking(true).ok();
89 Box::new(LocalTcpListener {
90 stream: mio::net::TcpListener::from_std(sock),
91 selector: self.selector.clone(),
92 handler_guard: HandlerGuardState::None,
93 no_delay: None,
94 keep_alive: None,
95 backlog: Default::default(),
96 ruleset: self.ruleset.clone(),
97 })
98 })
99 .map_err(io_err_into_net_error)?;
100 Ok(listener)
101 }
102
103 async fn bind_udp(
104 &self,
105 addr: SocketAddr,
106 reuse_port: bool,
107 reuse_addr: bool,
108 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109 #[cfg(not(windows))]
110 use socket2::{Domain, Socket, Type};
111
112 if let Some(ruleset) = self.ruleset.as_ref()
113 && !ruleset.allows_socket(addr, Direction::Inbound)
114 {
115 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116 return Err(NetworkError::PermissionDenied);
117 }
118
119 #[cfg(not(windows))]
120 let socket = {
121 let domain = if addr.is_ipv4() {
122 Domain::IPV4
123 } else {
124 Domain::IPV6
125 };
126 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127 std_sock
128 .set_nonblocking(true)
129 .map_err(io_err_into_net_error)?;
130 std_sock
131 .set_reuse_address(reuse_addr)
132 .map_err(io_err_into_net_error)?;
133 std_sock
134 .set_reuse_port(reuse_port)
135 .map_err(io_err_into_net_error)?;
136 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137 mio::net::UdpSocket::from_std(std_sock.into())
138 };
139 #[cfg(windows)]
140 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142 #[allow(unused_mut)]
143 let mut ret = LocalUdpSocket {
144 selector: self.selector.clone(),
145 socket,
146 addr,
147 handler_guard: HandlerGuardState::None,
148 backlog: Default::default(),
149 ruleset: self.ruleset.clone(),
150 };
151
152 #[cfg(target_os = "windows")]
157 {
158 let (state, selector, socket) = ret.split_borrow();
159 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160 map.push(InterestType::Writable);
161 }
162
163 Ok(Box::new(ret))
164 }
165
166 async fn connect_tcp(
167 &self,
168 _addr: SocketAddr,
169 mut peer: SocketAddr,
170 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171 if let Some(ruleset) = self.ruleset.as_ref()
172 && !ruleset.allows_socket(peer, Direction::Outbound)
173 {
174 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175 return Err(NetworkError::PermissionDenied);
176 }
177
178 let stream = self
181 .handle
182 .spawn(tokio::net::TcpStream::connect(peer))
183 .await
184 .map_err(|_| NetworkError::IOError)?
185 .map_err(io_err_into_net_error)?;
186 let stream = stream.into_std().map_err(io_err_into_net_error)?;
187 let stream = mio::net::TcpStream::from_std(stream);
188
189 if let Ok(p) = stream.peer_addr() {
190 peer = p;
191 }
192 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
193 Ok(socket)
194 }
195
196 async fn resolve(
197 &self,
198 host: &str,
199 port: Option<u16>,
200 dns_server: Option<IpAddr>,
201 ) -> Result<Vec<IpAddr>> {
202 if let Some(ruleset) = self.ruleset.as_ref()
203 && !ruleset.allows_domain(host)
204 {
205 tracing::warn!(%host, "dns resolve blocked by firewall rule");
206 return Err(NetworkError::PermissionDenied);
207 }
208
209 let host_to_lookup = if host.contains(':') {
210 host.to_string()
211 } else {
212 format!("{}:{}", host, port.unwrap_or(0))
213 };
214 let addrs = self
215 .handle
216 .spawn(tokio::net::lookup_host(host_to_lookup))
217 .await
218 .map_err(|_| NetworkError::IOError)?
219 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
220 .map_err(io_err_into_net_error)?;
221
222 if let Some(ruleset) = self.ruleset.as_ref() {
223 if let Err(e) = ruleset.expand_domain(host, &addrs) {
224 tracing::debug!(err=%e, "ruleset expansion failed");
225 } else {
226 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
227 }
228 }
229
230 Ok(addrs)
231 }
232}
233
234#[derive(Debug)]
235pub struct LocalTcpListener {
236 stream: mio::net::TcpListener,
237 selector: Arc<Selector>,
238 handler_guard: HandlerGuardState,
239 no_delay: Option<bool>,
240 keep_alive: Option<bool>,
241 backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
242 ruleset: Option<Ruleset>,
243}
244
245impl LocalTcpListener {
246 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
247 match self.stream.accept().map_err(io_err_into_net_error) {
248 Ok((stream, addr)) => {
249 if let Some(ruleset) = self.ruleset.as_ref()
250 && !ruleset.allows_socket(addr, Direction::Outbound)
251 {
252 tracing::warn!(%addr, "try_accept blocked by firewall rule");
253 return Err(NetworkError::PermissionDenied);
254 }
255
256 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
257 if let Some(no_delay) = self.no_delay {
258 socket.set_nodelay(no_delay).ok();
259 }
260 if let Some(keep_alive) = self.keep_alive {
261 socket.set_keepalive(keep_alive).ok();
262 }
263 Ok((Box::new(socket), addr))
264 }
265 Err(NetworkError::WouldBlock) => {
266 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
267 map.pop(InterestType::Readable);
268 map.pop(InterestType::Writable);
269 }
270 Err(NetworkError::WouldBlock)
271 }
272 Err(err) => Err(err),
273 }
274 }
275}
276
277impl VirtualTcpListener for LocalTcpListener {
278 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
279 if let Some(child) = self.backlog.pop_front() {
280 return Ok(child);
281 }
282 self.try_accept_internal()
283 }
284
285 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
286 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
287 match guard.replace_handler(handler) {
288 Ok(()) => return Ok(()),
289 Err(h) => handler = h,
290 }
291
292 if let Err(err) = guard.unregister(&mut self.stream) {
294 tracing::debug!("failed to unregister previous token - {}", err);
295 }
296 }
297
298 let guard = InterestGuard::new(
299 &self.selector,
300 handler,
301 &mut self.stream,
302 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
303 )
304 .map_err(io_err_into_net_error)?;
305
306 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
307
308 Ok(())
309 }
310
311 fn addr_local(&self) -> Result<SocketAddr> {
312 self.stream.local_addr().map_err(io_err_into_net_error)
313 }
314
315 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
316 self.stream
317 .set_ttl(ttl as u32)
318 .map_err(io_err_into_net_error)
319 }
320
321 fn ttl(&self) -> Result<u8> {
322 self.stream
323 .ttl()
324 .map(|ttl| ttl as u8)
325 .map_err(io_err_into_net_error)
326 }
327}
328
329impl LocalTcpListener {
330 fn split_borrow(
331 &mut self,
332 ) -> (
333 &mut HandlerGuardState,
334 &Arc<Selector>,
335 &mut mio::net::TcpListener,
336 ) {
337 (&mut self.handler_guard, &self.selector, &mut self.stream)
338 }
339}
340
341impl VirtualIoSource for LocalTcpListener {
342 fn remove_handler(&mut self) {
343 let mut guard = HandlerGuardState::None;
344 std::mem::swap(&mut guard, &mut self.handler_guard);
345 match guard {
346 HandlerGuardState::ExternalHandler(mut guard) => {
347 guard.unregister(&mut self.stream).ok();
348 }
349 HandlerGuardState::WakerMap(mut guard, _) => {
350 guard.unregister(&mut self.stream).ok();
351 }
352 HandlerGuardState::None => {}
353 }
354 }
355
356 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
357 if !self.backlog.is_empty() {
358 return Poll::Ready(Ok(self.backlog.len()));
359 }
360
361 let (state, selector, source) = self.split_borrow();
362 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
363 map.add(InterestType::Readable, cx.waker());
364
365 if let Ok(child) = self.try_accept_internal() {
366 self.backlog.push_back(child);
367 return Poll::Ready(Ok(1));
368 }
369 Poll::Pending
370 }
371
372 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
373 if !self.backlog.is_empty() {
374 return Poll::Ready(Ok(self.backlog.len()));
375 }
376
377 let (state, selector, source) = self.split_borrow();
378 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
379 map.add(InterestType::Writable, cx.waker());
380
381 if let Ok(child) = self.try_accept_internal() {
382 self.backlog.push_back(child);
383 return Poll::Ready(Ok(1));
384 }
385 Poll::Pending
386 }
387}
388
389#[derive(Debug)]
390pub struct LocalTcpStream {
391 stream: mio::net::TcpStream,
392 addr: SocketAddr,
393 shutdown: Option<Shutdown>,
394 selector: Arc<Selector>,
395 handler_guard: HandlerGuardState,
396 buffer: BytesMut,
397}
398
399impl LocalTcpStream {
400 fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
401 #[allow(unused_mut)]
402 let mut ret = Self {
403 stream,
404 addr,
405 shutdown: None,
406 selector,
407 handler_guard: HandlerGuardState::None,
408 buffer: BytesMut::new(),
409 };
410
411 #[cfg(target_os = "windows")]
416 {
417 let (state, selector, socket, _) = ret.split_borrow();
418 if let Ok(map) = state_as_waker_map(state, selector, socket) {
419 map.push(InterestType::Writable);
420 }
421 }
422
423 ret
424 }
425
426 fn with_sock_ref<F, R>(&self, f: F) -> R
427 where
428 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
429 {
430 #[cfg(not(windows))]
431 let r = socket2::SockRef::from(&self.stream);
432
433 #[cfg(windows)]
434 let b = unsafe {
435 std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
436 };
437 #[cfg(windows)]
438 let r = socket2::SockRef::from(&b);
439
440 f(r)
441 }
442}
443
444impl VirtualTcpSocket for LocalTcpStream {
445 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
446 Ok(())
447 }
448
449 fn recv_buf_size(&self) -> Result<usize> {
450 Err(NetworkError::Unsupported)
451 }
452
453 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
454 Ok(())
455 }
456
457 fn send_buf_size(&self) -> Result<usize> {
458 Err(NetworkError::Unsupported)
459 }
460
461 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
462 self.stream
463 .set_nodelay(nodelay)
464 .map_err(io_err_into_net_error)
465 }
466
467 fn nodelay(&self) -> Result<bool> {
468 self.stream.nodelay().map_err(io_err_into_net_error)
469 }
470
471 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
472 self.with_sock_ref(|s| s.set_keepalive(true))
473 .map_err(io_err_into_net_error)?;
474 Ok(())
475 }
476
477 fn keepalive(&self) -> Result<bool> {
478 let ret = self
479 .with_sock_ref(|s| s.keepalive())
480 .map_err(io_err_into_net_error)?;
481 Ok(ret)
482 }
483
484 #[cfg(not(target_os = "windows"))]
485 fn set_dontroute(&mut self, val: bool) -> Result<()> {
486 let val = val as libc::c_int;
492 let payload = &val as *const libc::c_int as *const libc::c_void;
493 let err = unsafe {
494 libc::setsockopt(
495 self.stream.as_raw_fd(),
496 libc::SOL_SOCKET,
497 libc::SO_DONTROUTE,
498 payload,
499 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
500 )
501 };
502 if err == -1 {
503 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
504 }
505 Ok(())
506 }
507 #[cfg(target_os = "windows")]
508 fn set_dontroute(&mut self, val: bool) -> Result<()> {
509 Err(NetworkError::Unsupported)
510 }
511
512 #[cfg(not(target_os = "windows"))]
513 fn dontroute(&self) -> Result<bool> {
514 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
515 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
516 let err = unsafe {
517 libc::getsockopt(
518 self.stream.as_raw_fd(),
519 libc::SOL_SOCKET,
520 libc::SO_DONTROUTE,
521 payload.as_mut_ptr().cast(),
522 &mut len,
523 )
524 };
525 if err == -1 {
526 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
527 }
528 Ok(unsafe { payload.assume_init() != 0 })
529 }
530 #[cfg(target_os = "windows")]
531 fn dontroute(&self) -> Result<bool> {
532 Err(NetworkError::Unsupported)
533 }
534
535 fn addr_peer(&self) -> Result<SocketAddr> {
536 Ok(self.addr)
537 }
538
539 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
540 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
541 self.shutdown = Some(how);
542 Ok(())
543 }
544
545 fn is_closed(&self) -> bool {
546 false
547 }
548}
549
550impl VirtualConnectedSocket for LocalTcpStream {
551 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
552 self.with_sock_ref(|s| s.set_linger(linger))
553 .map_err(io_err_into_net_error)?;
554 Ok(())
555 }
556
557 fn linger(&self) -> Result<Option<Duration>> {
558 self.with_sock_ref(|s| s.linger())
559 .map_err(io_err_into_net_error)
560 }
561
562 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
563 let ret = self.stream.write(data).map_err(io_err_into_net_error);
564 match &ret {
565 Ok(0) | Err(NetworkError::WouldBlock) => {
566 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
567 map.pop(InterestType::Writable);
568 }
569 }
570 _ => {}
571 }
572 ret
573 }
574
575 fn try_flush(&mut self) -> Result<()> {
576 self.stream.flush().map_err(io_err_into_net_error)
577 }
578
579 fn close(&mut self) -> Result<()> {
580 Ok(())
581 }
582
583 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
584 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
585 if !self.buffer.is_empty() {
586 let amt = buf.len().min(self.buffer.len());
587 buf[..amt].copy_from_slice(&self.buffer[..amt]);
588 if !peek {
589 self.buffer.advance(amt);
590 }
591 return Ok(amt);
592 }
593
594 if peek {
595 self.stream.peek(buf)
596 } else {
597 self.stream.read(buf)
598 }
599 .map_err(io_err_into_net_error)
600 }
601}
602
603impl VirtualSocket for LocalTcpStream {
604 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
605 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
606 }
607
608 fn ttl(&self) -> Result<u32> {
609 self.stream.ttl().map_err(io_err_into_net_error)
610 }
611
612 fn addr_local(&self) -> Result<SocketAddr> {
613 self.stream.local_addr().map_err(io_err_into_net_error)
614 }
615
616 fn status(&self) -> Result<SocketStatus> {
617 Ok(SocketStatus::Opened)
618 }
619
620 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
621 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
622 match guard.replace_handler(handler) {
623 Ok(()) => return Ok(()),
624 Err(h) => handler = h,
625 }
626
627 if let Err(err) = guard.unregister(&mut self.stream) {
629 tracing::debug!("failed to unregister previous token - {}", err);
630 }
631 }
632
633 let guard = InterestGuard::new(
634 &self.selector,
635 handler,
636 &mut self.stream,
637 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
638 )
639 .map_err(io_err_into_net_error)?;
640
641 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
642
643 Ok(())
644 }
645}
646
647impl LocalTcpStream {
648 fn split_borrow(
649 &mut self,
650 ) -> (
651 &mut HandlerGuardState,
652 &Arc<Selector>,
653 &mut mio::net::TcpStream,
654 &mut BytesMut,
655 ) {
656 (
657 &mut self.handler_guard,
658 &self.selector,
659 &mut self.stream,
660 &mut self.buffer,
661 )
662 }
663}
664
665impl VirtualIoSource for LocalTcpStream {
666 fn remove_handler(&mut self) {
667 let mut guard = HandlerGuardState::None;
668 std::mem::swap(&mut guard, &mut self.handler_guard);
669 match guard {
670 HandlerGuardState::ExternalHandler(mut guard) => {
671 guard.unregister(&mut self.stream).ok();
672 }
673 HandlerGuardState::WakerMap(mut guard, _) => {
674 guard.unregister(&mut self.stream).ok();
675 }
676 HandlerGuardState::None => {}
677 }
678 }
679
680 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
681 if !self.buffer.is_empty() {
682 return Poll::Ready(Ok(self.buffer.len()));
683 }
684
685 let (state, selector, stream, buffer) = self.split_borrow();
686 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
687 map.pop(InterestType::Readable);
688 map.add(InterestType::Readable, cx.waker());
689
690 buffer.reserve(buffer.len() + 10240);
691 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
692 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
693
694 match stream.read(uninit_unsafe) {
695 Ok(0) => Poll::Ready(Ok(0)),
696 Ok(amt) => {
697 unsafe {
698 buffer.set_len(buffer.len() + amt);
699 }
700 Poll::Ready(Ok(amt))
701 }
702 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
703 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
704 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
705 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
706 }
707 }
708
709 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
710 let (state, selector, stream, _) = self.split_borrow();
711 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
712 #[cfg(not(target_os = "windows"))]
713 map.pop(InterestType::Writable);
714 map.add(InterestType::Writable, cx.waker());
715 map.add(InterestType::Closed, cx.waker());
716 if map.has_interest(InterestType::Closed) {
717 return Poll::Ready(Ok(0));
718 }
719
720 #[cfg(not(target_os = "windows"))]
721 match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
722 Some(val) if (val & libc::POLLHUP) != 0 => {
723 return Poll::Ready(Ok(0));
724 }
725 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
726 _ => {}
727 }
728
729 #[cfg(target_os = "windows")]
734 if map.has_interest(InterestType::Writable) {
735 return Poll::Ready(Ok(10240));
736 }
737
738 Poll::Pending
739 }
740}
741
742#[cfg(not(target_os = "windows"))]
743fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
744 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
745 fd,
746 events,
747 revents: 0,
748 }];
749 let fds_mut = &mut fds[..];
750 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
751 match ret == 1 {
752 true => Some(fds[0].revents),
753 false => None,
754 }
755}
756
757#[derive(Debug)]
758pub struct LocalUdpSocket {
759 socket: mio::net::UdpSocket,
760 #[allow(dead_code)]
761 addr: SocketAddr,
762 selector: Arc<Selector>,
763 handler_guard: HandlerGuardState,
764 backlog: VecDeque<(BytesMut, SocketAddr)>,
765 ruleset: Option<Ruleset>,
766}
767
768impl LocalUdpSocket {
769 fn with_sock_ref<F, R>(&self, f: F) -> R
770 where
771 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
772 {
773 #[cfg(not(windows))]
774 let r = socket2::SockRef::from(&self.socket);
775
776 #[cfg(windows)]
777 let b = unsafe {
778 std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
779 };
780 #[cfg(windows)]
781 let r = socket2::SockRef::from(&b);
782
783 f(r)
784 }
785}
786
787impl VirtualUdpSocket for LocalUdpSocket {
788 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
789 self.socket
790 .set_broadcast(broadcast)
791 .map_err(io_err_into_net_error)
792 }
793
794 fn broadcast(&self) -> Result<bool> {
795 self.socket.broadcast().map_err(io_err_into_net_error)
796 }
797
798 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
799 self.socket
800 .set_multicast_loop_v4(val)
801 .map_err(io_err_into_net_error)
802 }
803
804 fn multicast_loop_v4(&self) -> Result<bool> {
805 self.socket
806 .multicast_loop_v4()
807 .map_err(io_err_into_net_error)
808 }
809
810 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
811 self.socket
812 .set_multicast_loop_v6(val)
813 .map_err(io_err_into_net_error)
814 }
815
816 fn multicast_loop_v6(&self) -> Result<bool> {
817 self.socket
818 .multicast_loop_v6()
819 .map_err(io_err_into_net_error)
820 }
821
822 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
823 self.socket
824 .set_multicast_ttl_v4(ttl)
825 .map_err(io_err_into_net_error)
826 }
827
828 fn multicast_ttl_v4(&self) -> Result<u32> {
829 self.socket
830 .multicast_ttl_v4()
831 .map_err(io_err_into_net_error)
832 }
833
834 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
835 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
836 .map_err(io_err_into_net_error)
837 }
838
839 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
840 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
841 .map_err(io_err_into_net_error)
842 }
843
844 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
845 self.socket
846 .join_multicast_v6(&multiaddr, iface)
847 .map_err(io_err_into_net_error)
848 }
849
850 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
851 self.socket
852 .leave_multicast_v6(&multiaddr, iface)
853 .map_err(io_err_into_net_error)
854 }
855
856 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
857 self.socket
858 .peer_addr()
859 .map(Some)
860 .map_err(io_err_into_net_error)
861 }
862}
863
864impl VirtualConnectionlessSocket for LocalUdpSocket {
865 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
866 if let Some(ruleset) = self.ruleset.as_ref()
867 && !ruleset.allows_socket(addr, Direction::Outbound)
868 {
869 tracing::warn!(%addr, "try_send blocked by firewall rule");
870 return Err(NetworkError::PermissionDenied);
871 }
872
873 let ret = self
874 .socket
875 .send_to(data, addr)
876 .map_err(io_err_into_net_error);
877 match &ret {
878 Ok(0) | Err(NetworkError::WouldBlock) => {
879 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
880 map.pop(InterestType::Writable);
881 }
882 }
883 _ => {}
884 }
885 ret
886 }
887
888 fn try_recv_from(
889 &mut self,
890 buf: &mut [MaybeUninit<u8>],
891 peek: bool,
892 ) -> Result<(usize, SocketAddr)> {
893 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
894 if peek {
895 self.socket.peek_from(buf)
896 } else {
897 self.socket.recv_from(buf)
898 }
899 .map_err(io_err_into_net_error)
900 }
901}
902
903impl VirtualSocket for LocalUdpSocket {
904 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
905 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
906 }
907
908 fn ttl(&self) -> Result<u32> {
909 self.socket.ttl().map_err(io_err_into_net_error)
910 }
911
912 fn addr_local(&self) -> Result<SocketAddr> {
913 self.socket.local_addr().map_err(io_err_into_net_error)
914 }
915
916 fn status(&self) -> Result<SocketStatus> {
917 Ok(SocketStatus::Opened)
918 }
919
920 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
921 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
922 match guard.replace_handler(handler) {
923 Ok(()) => {
924 return Ok(());
925 }
926 Err(h) => handler = h,
927 }
928
929 if let Err(err) = guard.unregister(&mut self.socket) {
931 tracing::debug!("failed to unregister previous token - {}", err);
932 }
933 }
934
935 let guard = InterestGuard::new(
936 &self.selector,
937 handler,
938 &mut self.socket,
939 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
940 )
941 .map_err(io_err_into_net_error)?;
942
943 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
944
945 Ok(())
946 }
947}
948
949impl LocalUdpSocket {
950 fn split_borrow(
951 &mut self,
952 ) -> (
953 &mut HandlerGuardState,
954 &Arc<Selector>,
955 &mut mio::net::UdpSocket,
956 ) {
957 (&mut self.handler_guard, &self.selector, &mut self.socket)
958 }
959}
960
961impl VirtualIoSource for LocalUdpSocket {
962 fn remove_handler(&mut self) {
963 let mut guard = HandlerGuardState::None;
964 std::mem::swap(&mut guard, &mut self.handler_guard);
965 match guard {
966 HandlerGuardState::ExternalHandler(mut guard) => {
967 guard.unregister(&mut self.socket).ok();
968 }
969 HandlerGuardState::WakerMap(mut guard, _) => {
970 guard.unregister(&mut self.socket).ok();
971 }
972 HandlerGuardState::None => {}
973 }
974 }
975
976 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
977 if !self.backlog.is_empty() {
978 let total = self.backlog.iter().map(|a| a.0.len()).sum();
979 return Poll::Ready(Ok(total));
980 }
981
982 let (state, selector, socket) = self.split_borrow();
983 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
984 map.pop(InterestType::Readable);
985 map.add(InterestType::Readable, cx.waker());
986
987 let mut buffer = BytesMut::default();
988 buffer.reserve(10240);
989 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
990 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
991
992 match self.socket.recv_from(uninit_unsafe) {
993 Ok((0, _)) => Poll::Ready(Ok(0)),
994 Ok((amt, peer)) => {
995 unsafe {
996 buffer.set_len(amt);
997 }
998 self.backlog.push_back((buffer, peer));
999 Poll::Ready(Ok(amt))
1000 }
1001 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1002 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1003 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1004 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1005 }
1006 }
1007
1008 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1009 let (state, selector, socket) = self.split_borrow();
1010 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1011 #[cfg(not(target_os = "windows"))]
1012 map.pop(InterestType::Writable);
1013 map.add(InterestType::Writable, cx.waker());
1014
1015 #[cfg(not(target_os = "windows"))]
1016 match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1017 Some(val) if (val & libc::POLLHUP) != 0 => {
1018 return Poll::Ready(Ok(0));
1019 }
1020 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1021 _ => {}
1022 }
1023
1024 #[cfg(target_os = "windows")]
1029 if map.has_interest(InterestType::Writable) {
1030 return Poll::Ready(Ok(10240));
1031 }
1032
1033 Poll::Pending
1034 }
1035}