1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6 VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7 VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34 selector: Arc<Selector>,
35 handle: Handle,
36 ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40 pub fn new() -> Self {
41 Self {
42 selector: Selector::new(),
43 handle: Handle::current(),
44 ruleset: None,
45 }
46 }
47
48 pub fn with_ruleset(ruleset: Ruleset) -> Self {
49 Self {
50 selector: Selector::new(),
51 handle: Handle::current(),
52 ruleset: Some(ruleset),
53 }
54 }
55}
56
57impl Drop for LocalNetworking {
58 fn drop(&mut self) {
59 self.selector.shutdown();
60 }
61}
62
63impl Default for LocalNetworking {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72 async fn listen_tcp(
73 &self,
74 addr: SocketAddr,
75 only_v6: bool,
76 reuse_port: bool,
77 reuse_addr: bool,
78 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79 if let Some(ruleset) = self.ruleset.as_ref()
80 && !ruleset.allows_socket(addr, Direction::Inbound)
81 {
82 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83 return Err(NetworkError::PermissionDenied);
84 }
85
86 let listener = std::net::TcpListener::bind(addr)
87 .map(|sock| {
88 sock.set_nonblocking(true).ok();
89 Box::new(LocalTcpListener {
90 stream: mio::net::TcpListener::from_std(sock),
91 selector: self.selector.clone(),
92 handler_guard: HandlerGuardState::None,
93 no_delay: None,
94 keep_alive: None,
95 backlog: Default::default(),
96 ruleset: self.ruleset.clone(),
97 })
98 })
99 .map_err(io_err_into_net_error)?;
100 Ok(listener)
101 }
102
103 async fn bind_udp(
104 &self,
105 addr: SocketAddr,
106 reuse_port: bool,
107 reuse_addr: bool,
108 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109 #[cfg(not(windows))]
110 use socket2::{Domain, Socket, Type};
111
112 if let Some(ruleset) = self.ruleset.as_ref()
113 && !ruleset.allows_socket(addr, Direction::Inbound)
114 {
115 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116 return Err(NetworkError::PermissionDenied);
117 }
118
119 #[cfg(not(windows))]
120 let socket = {
121 let domain = if addr.is_ipv4() {
122 Domain::IPV4
123 } else {
124 Domain::IPV6
125 };
126 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127 std_sock
128 .set_nonblocking(true)
129 .map_err(io_err_into_net_error)?;
130 std_sock
131 .set_reuse_address(reuse_addr)
132 .map_err(io_err_into_net_error)?;
133 std_sock
134 .set_reuse_port(reuse_port)
135 .map_err(io_err_into_net_error)?;
136 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137 mio::net::UdpSocket::from_std(std_sock.into())
138 };
139 #[cfg(windows)]
140 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142 #[allow(unused_mut)]
143 let mut ret = LocalUdpSocket {
144 selector: self.selector.clone(),
145 socket,
146 addr,
147 handler_guard: HandlerGuardState::None,
148 backlog: Default::default(),
149 ruleset: self.ruleset.clone(),
150 };
151
152 #[cfg(target_os = "windows")]
157 {
158 let (state, selector, socket) = ret.split_borrow();
159 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160 map.push(InterestType::Writable);
161 }
162
163 Ok(Box::new(ret))
164 }
165
166 async fn connect_tcp(
167 &self,
168 _addr: SocketAddr,
169 mut peer: SocketAddr,
170 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171 if let Some(ruleset) = self.ruleset.as_ref()
172 && !ruleset.allows_socket(peer, Direction::Outbound)
173 {
174 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175 return Err(NetworkError::PermissionDenied);
176 }
177
178 let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
179
180 if let Ok(p) = stream.peer_addr() {
181 peer = p;
182 }
183 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
184 Ok(socket)
185 }
186
187 async fn resolve(
188 &self,
189 host: &str,
190 port: Option<u16>,
191 dns_server: Option<IpAddr>,
192 ) -> Result<Vec<IpAddr>> {
193 if let Some(ruleset) = self.ruleset.as_ref()
194 && !ruleset.allows_domain(host)
195 {
196 tracing::warn!(%host, "dns resolve blocked by firewall rule");
197 return Err(NetworkError::PermissionDenied);
198 }
199
200 let host_to_lookup = if host.contains(':') {
201 host.to_string()
202 } else {
203 format!("{}:{}", host, port.unwrap_or(0))
204 };
205 let addrs = self
206 .handle
207 .spawn(tokio::net::lookup_host(host_to_lookup))
208 .await
209 .map_err(|_| NetworkError::IOError)?
210 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
211 .map_err(io_err_into_net_error)?;
212
213 if let Some(ruleset) = self.ruleset.as_ref() {
214 if let Err(e) = ruleset.expand_domain(host, &addrs) {
215 tracing::debug!(err=%e, "ruleset expansion failed");
216 } else {
217 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
218 }
219 }
220
221 Ok(addrs)
222 }
223}
224
225#[derive(Debug)]
226pub struct LocalTcpListener {
227 stream: mio::net::TcpListener,
228 selector: Arc<Selector>,
229 handler_guard: HandlerGuardState,
230 no_delay: Option<bool>,
231 keep_alive: Option<bool>,
232 backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
233 ruleset: Option<Ruleset>,
234}
235
236impl LocalTcpListener {
237 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
238 match self.stream.accept().map_err(io_err_into_net_error) {
239 Ok((stream, addr)) => {
240 if let Some(ruleset) = self.ruleset.as_ref()
241 && !ruleset.allows_socket(addr, Direction::Outbound)
242 {
243 tracing::warn!(%addr, "try_accept blocked by firewall rule");
244 return Err(NetworkError::PermissionDenied);
245 }
246
247 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
248 if let Some(no_delay) = self.no_delay {
249 socket.set_nodelay(no_delay).ok();
250 }
251 if let Some(keep_alive) = self.keep_alive {
252 socket.set_keepalive(keep_alive).ok();
253 }
254 Ok((Box::new(socket), addr))
255 }
256 Err(NetworkError::WouldBlock) => {
257 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
258 map.pop(InterestType::Readable);
259 map.pop(InterestType::Writable);
260 }
261 Err(NetworkError::WouldBlock)
262 }
263 Err(err) => Err(err),
264 }
265 }
266}
267
268impl VirtualTcpListener for LocalTcpListener {
269 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
270 if let Some(child) = self.backlog.pop_front() {
271 return Ok(child);
272 }
273 self.try_accept_internal()
274 }
275
276 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
277 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
278 match guard.replace_handler(handler) {
279 Ok(()) => return Ok(()),
280 Err(h) => handler = h,
281 }
282
283 if let Err(err) = guard.unregister(&mut self.stream) {
285 tracing::debug!("failed to unregister previous token - {}", err);
286 }
287 }
288
289 let guard = InterestGuard::new(
290 &self.selector,
291 handler,
292 &mut self.stream,
293 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
294 )
295 .map_err(io_err_into_net_error)?;
296
297 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
298
299 Ok(())
300 }
301
302 fn addr_local(&self) -> Result<SocketAddr> {
303 self.stream.local_addr().map_err(io_err_into_net_error)
304 }
305
306 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
307 self.stream
308 .set_ttl(ttl as u32)
309 .map_err(io_err_into_net_error)
310 }
311
312 fn ttl(&self) -> Result<u8> {
313 self.stream
314 .ttl()
315 .map(|ttl| ttl as u8)
316 .map_err(io_err_into_net_error)
317 }
318}
319
320impl LocalTcpListener {
321 fn split_borrow(
322 &mut self,
323 ) -> (
324 &mut HandlerGuardState,
325 &Arc<Selector>,
326 &mut mio::net::TcpListener,
327 ) {
328 (&mut self.handler_guard, &self.selector, &mut self.stream)
329 }
330}
331
332impl VirtualIoSource for LocalTcpListener {
333 fn remove_handler(&mut self) {
334 let mut guard = HandlerGuardState::None;
335 std::mem::swap(&mut guard, &mut self.handler_guard);
336 match guard {
337 HandlerGuardState::ExternalHandler(mut guard) => {
338 guard.unregister(&mut self.stream).ok();
339 }
340 HandlerGuardState::WakerMap(mut guard, _) => {
341 guard.unregister(&mut self.stream).ok();
342 }
343 HandlerGuardState::None => {}
344 }
345 }
346
347 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
348 if !self.backlog.is_empty() {
349 return Poll::Ready(Ok(self.backlog.len()));
350 }
351
352 let (state, selector, source) = self.split_borrow();
353 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
354 map.add(InterestType::Readable, cx.waker());
355
356 if let Ok(child) = self.try_accept_internal() {
357 self.backlog.push_back(child);
358 return Poll::Ready(Ok(1));
359 }
360 Poll::Pending
361 }
362
363 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
364 if !self.backlog.is_empty() {
365 return Poll::Ready(Ok(self.backlog.len()));
366 }
367
368 let (state, selector, source) = self.split_borrow();
369 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
370 map.add(InterestType::Writable, cx.waker());
371
372 if let Ok(child) = self.try_accept_internal() {
373 self.backlog.push_back(child);
374 return Poll::Ready(Ok(1));
375 }
376 Poll::Pending
377 }
378}
379
380#[derive(Debug)]
381pub struct LocalTcpStream {
382 stream: mio::net::TcpStream,
383 addr: SocketAddr,
384 shutdown: Option<Shutdown>,
385 selector: Arc<Selector>,
386 handler_guard: HandlerGuardState,
387 buffer: BytesMut,
388}
389
390impl LocalTcpStream {
391 fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
392 #[allow(unused_mut)]
393 let mut ret = Self {
394 stream,
395 addr,
396 shutdown: None,
397 selector,
398 handler_guard: HandlerGuardState::None,
399 buffer: BytesMut::new(),
400 };
401
402 #[cfg(target_os = "windows")]
407 {
408 let (state, selector, socket, _) = ret.split_borrow();
409 if let Ok(map) = state_as_waker_map(state, selector, socket) {
410 map.push(InterestType::Writable);
411 }
412 }
413
414 ret
415 }
416
417 fn with_sock_ref<F, R>(&self, f: F) -> R
418 where
419 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
420 {
421 #[cfg(not(windows))]
422 let r = socket2::SockRef::from(&self.stream);
423
424 #[cfg(windows)]
425 let b = unsafe {
426 std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
427 };
428 #[cfg(windows)]
429 let r = socket2::SockRef::from(&b);
430
431 f(r)
432 }
433}
434
435impl VirtualTcpSocket for LocalTcpStream {
436 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
437 Ok(())
438 }
439
440 fn recv_buf_size(&self) -> Result<usize> {
441 Err(NetworkError::Unsupported)
442 }
443
444 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
445 Ok(())
446 }
447
448 fn send_buf_size(&self) -> Result<usize> {
449 Err(NetworkError::Unsupported)
450 }
451
452 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
453 self.stream
454 .set_nodelay(nodelay)
455 .map_err(io_err_into_net_error)
456 }
457
458 fn nodelay(&self) -> Result<bool> {
459 self.stream.nodelay().map_err(io_err_into_net_error)
460 }
461
462 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
463 self.with_sock_ref(|s| s.set_keepalive(true))
464 .map_err(io_err_into_net_error)?;
465 Ok(())
466 }
467
468 fn keepalive(&self) -> Result<bool> {
469 let ret = self
470 .with_sock_ref(|s| s.keepalive())
471 .map_err(io_err_into_net_error)?;
472 Ok(ret)
473 }
474
475 #[cfg(not(target_os = "windows"))]
476 fn set_dontroute(&mut self, val: bool) -> Result<()> {
477 let val = val as libc::c_int;
483 let payload = &val as *const libc::c_int as *const libc::c_void;
484 let err = unsafe {
485 libc::setsockopt(
486 self.stream.as_raw_fd(),
487 libc::SOL_SOCKET,
488 libc::SO_DONTROUTE,
489 payload,
490 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
491 )
492 };
493 if err == -1 {
494 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
495 }
496 Ok(())
497 }
498 #[cfg(target_os = "windows")]
499 fn set_dontroute(&mut self, val: bool) -> Result<()> {
500 Err(NetworkError::Unsupported)
501 }
502
503 #[cfg(not(target_os = "windows"))]
504 fn dontroute(&self) -> Result<bool> {
505 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
506 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
507 let err = unsafe {
508 libc::getsockopt(
509 self.stream.as_raw_fd(),
510 libc::SOL_SOCKET,
511 libc::SO_DONTROUTE,
512 payload.as_mut_ptr().cast(),
513 &mut len,
514 )
515 };
516 if err == -1 {
517 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
518 }
519 Ok(unsafe { payload.assume_init() != 0 })
520 }
521 #[cfg(target_os = "windows")]
522 fn dontroute(&self) -> Result<bool> {
523 Err(NetworkError::Unsupported)
524 }
525
526 fn addr_peer(&self) -> Result<SocketAddr> {
527 Ok(self.addr)
528 }
529
530 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
531 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
532 self.shutdown = Some(how);
533 Ok(())
534 }
535
536 fn is_closed(&self) -> bool {
537 false
538 }
539}
540
541impl VirtualConnectedSocket for LocalTcpStream {
542 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
543 self.with_sock_ref(|s| s.set_linger(linger))
544 .map_err(io_err_into_net_error)?;
545 Ok(())
546 }
547
548 fn linger(&self) -> Result<Option<Duration>> {
549 self.with_sock_ref(|s| s.linger())
550 .map_err(io_err_into_net_error)
551 }
552
553 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
554 let ret = self.stream.write(data).map_err(io_err_into_net_error);
555 match &ret {
556 Ok(0) | Err(NetworkError::WouldBlock) => {
557 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
558 map.pop(InterestType::Writable);
559 }
560 }
561 _ => {}
562 }
563 ret
564 }
565
566 fn try_flush(&mut self) -> Result<()> {
567 self.stream.flush().map_err(io_err_into_net_error)
568 }
569
570 fn close(&mut self) -> Result<()> {
571 Ok(())
572 }
573
574 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
575 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
576 if !self.buffer.is_empty() {
577 let amt = buf.len().min(self.buffer.len());
578 buf[..amt].copy_from_slice(&self.buffer[..amt]);
579 if !peek {
580 self.buffer.advance(amt);
581 }
582 return Ok(amt);
583 }
584
585 if peek {
586 self.stream.peek(buf)
587 } else {
588 self.stream.read(buf)
589 }
590 .map_err(io_err_into_net_error)
591 }
592}
593
594impl VirtualSocket for LocalTcpStream {
595 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
596 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
597 }
598
599 fn ttl(&self) -> Result<u32> {
600 self.stream.ttl().map_err(io_err_into_net_error)
601 }
602
603 fn addr_local(&self) -> Result<SocketAddr> {
604 self.stream.local_addr().map_err(io_err_into_net_error)
605 }
606
607 fn status(&self) -> Result<SocketStatus> {
608 Ok(SocketStatus::Opened)
609 }
610
611 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
612 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
613 match guard.replace_handler(handler) {
614 Ok(()) => return Ok(()),
615 Err(h) => handler = h,
616 }
617
618 if let Err(err) = guard.unregister(&mut self.stream) {
620 tracing::debug!("failed to unregister previous token - {}", err);
621 }
622 }
623
624 let guard = InterestGuard::new(
625 &self.selector,
626 handler,
627 &mut self.stream,
628 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
629 )
630 .map_err(io_err_into_net_error)?;
631
632 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
633
634 Ok(())
635 }
636}
637
638impl LocalTcpStream {
639 fn split_borrow(
640 &mut self,
641 ) -> (
642 &mut HandlerGuardState,
643 &Arc<Selector>,
644 &mut mio::net::TcpStream,
645 &mut BytesMut,
646 ) {
647 (
648 &mut self.handler_guard,
649 &self.selector,
650 &mut self.stream,
651 &mut self.buffer,
652 )
653 }
654}
655
656impl VirtualIoSource for LocalTcpStream {
657 fn remove_handler(&mut self) {
658 let mut guard = HandlerGuardState::None;
659 std::mem::swap(&mut guard, &mut self.handler_guard);
660 match guard {
661 HandlerGuardState::ExternalHandler(mut guard) => {
662 guard.unregister(&mut self.stream).ok();
663 }
664 HandlerGuardState::WakerMap(mut guard, _) => {
665 guard.unregister(&mut self.stream).ok();
666 }
667 HandlerGuardState::None => {}
668 }
669 }
670
671 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
672 if !self.buffer.is_empty() {
673 return Poll::Ready(Ok(self.buffer.len()));
674 }
675
676 let (state, selector, stream, buffer) = self.split_borrow();
677 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
678 map.pop(InterestType::Readable);
679 map.add(InterestType::Readable, cx.waker());
680
681 buffer.reserve(buffer.len() + 10240);
682 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
683 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
684
685 match stream.read(uninit_unsafe) {
686 Ok(0) => Poll::Ready(Ok(0)),
687 Ok(amt) => {
688 unsafe {
689 buffer.set_len(buffer.len() + amt);
690 }
691 Poll::Ready(Ok(amt))
692 }
693 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
694 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
695 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
696 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
697 }
698 }
699
700 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
701 let (state, selector, stream, _) = self.split_borrow();
702 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
703 #[cfg(not(target_os = "windows"))]
704 map.pop(InterestType::Writable);
705 map.add(InterestType::Writable, cx.waker());
706 map.add(InterestType::Closed, cx.waker());
707 if map.has_interest(InterestType::Closed) {
708 return Poll::Ready(Ok(0));
709 }
710
711 #[cfg(not(target_os = "windows"))]
712 match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
713 Some(val) if (val & libc::POLLHUP) != 0 => {
714 return Poll::Ready(Ok(0));
715 }
716 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
717 _ => {}
718 }
719
720 #[cfg(target_os = "windows")]
725 if map.has_interest(InterestType::Writable) {
726 return Poll::Ready(Ok(10240));
727 }
728
729 Poll::Pending
730 }
731}
732
733#[cfg(not(target_os = "windows"))]
734fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
735 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
736 fd,
737 events,
738 revents: 0,
739 }];
740 let fds_mut = &mut fds[..];
741 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
742 match ret == 1 {
743 true => Some(fds[0].revents),
744 false => None,
745 }
746}
747
748#[derive(Debug)]
749pub struct LocalUdpSocket {
750 socket: mio::net::UdpSocket,
751 #[allow(dead_code)]
752 addr: SocketAddr,
753 selector: Arc<Selector>,
754 handler_guard: HandlerGuardState,
755 backlog: VecDeque<(BytesMut, SocketAddr)>,
756 ruleset: Option<Ruleset>,
757}
758
759impl LocalUdpSocket {
760 fn with_sock_ref<F, R>(&self, f: F) -> R
761 where
762 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
763 {
764 #[cfg(not(windows))]
765 let r = socket2::SockRef::from(&self.socket);
766
767 #[cfg(windows)]
768 let b = unsafe {
769 std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
770 };
771 #[cfg(windows)]
772 let r = socket2::SockRef::from(&b);
773
774 f(r)
775 }
776}
777
778impl VirtualUdpSocket for LocalUdpSocket {
779 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
780 self.socket
781 .set_broadcast(broadcast)
782 .map_err(io_err_into_net_error)
783 }
784
785 fn broadcast(&self) -> Result<bool> {
786 self.socket.broadcast().map_err(io_err_into_net_error)
787 }
788
789 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
790 self.socket
791 .set_multicast_loop_v4(val)
792 .map_err(io_err_into_net_error)
793 }
794
795 fn multicast_loop_v4(&self) -> Result<bool> {
796 self.socket
797 .multicast_loop_v4()
798 .map_err(io_err_into_net_error)
799 }
800
801 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
802 self.socket
803 .set_multicast_loop_v6(val)
804 .map_err(io_err_into_net_error)
805 }
806
807 fn multicast_loop_v6(&self) -> Result<bool> {
808 self.socket
809 .multicast_loop_v6()
810 .map_err(io_err_into_net_error)
811 }
812
813 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
814 self.socket
815 .set_multicast_ttl_v4(ttl)
816 .map_err(io_err_into_net_error)
817 }
818
819 fn multicast_ttl_v4(&self) -> Result<u32> {
820 self.socket
821 .multicast_ttl_v4()
822 .map_err(io_err_into_net_error)
823 }
824
825 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
826 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
827 .map_err(io_err_into_net_error)
828 }
829
830 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
831 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
832 .map_err(io_err_into_net_error)
833 }
834
835 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
836 self.socket
837 .join_multicast_v6(&multiaddr, iface)
838 .map_err(io_err_into_net_error)
839 }
840
841 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
842 self.socket
843 .leave_multicast_v6(&multiaddr, iface)
844 .map_err(io_err_into_net_error)
845 }
846
847 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
848 self.socket
849 .peer_addr()
850 .map(Some)
851 .map_err(io_err_into_net_error)
852 }
853}
854
855impl VirtualConnectionlessSocket for LocalUdpSocket {
856 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
857 if let Some(ruleset) = self.ruleset.as_ref()
858 && !ruleset.allows_socket(addr, Direction::Outbound)
859 {
860 tracing::warn!(%addr, "try_send blocked by firewall rule");
861 return Err(NetworkError::PermissionDenied);
862 }
863
864 let ret = self
865 .socket
866 .send_to(data, addr)
867 .map_err(io_err_into_net_error);
868 match &ret {
869 Ok(0) | Err(NetworkError::WouldBlock) => {
870 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
871 map.pop(InterestType::Writable);
872 }
873 }
874 _ => {}
875 }
876 ret
877 }
878
879 fn try_recv_from(
880 &mut self,
881 buf: &mut [MaybeUninit<u8>],
882 peek: bool,
883 ) -> Result<(usize, SocketAddr)> {
884 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
885 if peek {
886 self.socket.peek_from(buf)
887 } else {
888 self.socket.recv_from(buf)
889 }
890 .map_err(io_err_into_net_error)
891 }
892}
893
894impl VirtualSocket for LocalUdpSocket {
895 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
896 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
897 }
898
899 fn ttl(&self) -> Result<u32> {
900 self.socket.ttl().map_err(io_err_into_net_error)
901 }
902
903 fn addr_local(&self) -> Result<SocketAddr> {
904 self.socket.local_addr().map_err(io_err_into_net_error)
905 }
906
907 fn status(&self) -> Result<SocketStatus> {
908 Ok(SocketStatus::Opened)
909 }
910
911 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
912 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
913 match guard.replace_handler(handler) {
914 Ok(()) => {
915 return Ok(());
916 }
917 Err(h) => handler = h,
918 }
919
920 if let Err(err) = guard.unregister(&mut self.socket) {
922 tracing::debug!("failed to unregister previous token - {}", err);
923 }
924 }
925
926 let guard = InterestGuard::new(
927 &self.selector,
928 handler,
929 &mut self.socket,
930 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
931 )
932 .map_err(io_err_into_net_error)?;
933
934 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
935
936 Ok(())
937 }
938}
939
940impl LocalUdpSocket {
941 fn split_borrow(
942 &mut self,
943 ) -> (
944 &mut HandlerGuardState,
945 &Arc<Selector>,
946 &mut mio::net::UdpSocket,
947 ) {
948 (&mut self.handler_guard, &self.selector, &mut self.socket)
949 }
950}
951
952impl VirtualIoSource for LocalUdpSocket {
953 fn remove_handler(&mut self) {
954 let mut guard = HandlerGuardState::None;
955 std::mem::swap(&mut guard, &mut self.handler_guard);
956 match guard {
957 HandlerGuardState::ExternalHandler(mut guard) => {
958 guard.unregister(&mut self.socket).ok();
959 }
960 HandlerGuardState::WakerMap(mut guard, _) => {
961 guard.unregister(&mut self.socket).ok();
962 }
963 HandlerGuardState::None => {}
964 }
965 }
966
967 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
968 if !self.backlog.is_empty() {
969 let total = self.backlog.iter().map(|a| a.0.len()).sum();
970 return Poll::Ready(Ok(total));
971 }
972
973 let (state, selector, socket) = self.split_borrow();
974 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
975 map.pop(InterestType::Readable);
976 map.add(InterestType::Readable, cx.waker());
977
978 let mut buffer = BytesMut::default();
979 buffer.reserve(10240);
980 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
981 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
982
983 match self.socket.recv_from(uninit_unsafe) {
984 Ok((0, _)) => Poll::Ready(Ok(0)),
985 Ok((amt, peer)) => {
986 unsafe {
987 buffer.set_len(amt);
988 }
989 self.backlog.push_back((buffer, peer));
990 Poll::Ready(Ok(amt))
991 }
992 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
993 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
994 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
995 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
996 }
997 }
998
999 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1000 let (state, selector, socket) = self.split_borrow();
1001 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1002 #[cfg(not(target_os = "windows"))]
1003 map.pop(InterestType::Writable);
1004 map.add(InterestType::Writable, cx.waker());
1005
1006 #[cfg(not(target_os = "windows"))]
1007 match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1008 Some(val) if (val & libc::POLLHUP) != 0 => {
1009 return Poll::Ready(Ok(0));
1010 }
1011 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1012 _ => {}
1013 }
1014
1015 #[cfg(target_os = "windows")]
1020 if map.has_interest(InterestType::Writable) {
1021 return Poll::Ready(Ok(10240));
1022 }
1023
1024 Poll::Pending
1025 }
1026}