1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6 VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7 VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::Arc;
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
/// Host-backed implementation of [`VirtualNetworking`] that opens real sockets
/// through the standard library / mio.
///
/// Every socket created by an instance receives a clone of the same
/// [`Selector`]. When a [`Ruleset`] is configured, bind/connect addresses and
/// resolved domains are checked against it before the operation is performed.
#[derive(Debug)]
pub struct LocalNetworking {
    /// Readiness selector cloned into every socket created by this instance.
    selector: Arc<Selector>,
    /// Tokio runtime handle captured at construction; `resolve` spawns its
    /// DNS lookups on it.
    handle: Handle,
    /// Optional firewall rules; `None` means no checks are performed.
    ruleset: Option<Ruleset>,
}
38
39impl LocalNetworking {
40 pub fn new() -> Self {
41 Self {
42 selector: Selector::new(),
43 handle: Handle::current(),
44 ruleset: None,
45 }
46 }
47
48 pub fn with_ruleset(ruleset: Ruleset) -> Self {
49 Self {
50 selector: Selector::new(),
51 handle: Handle::current(),
52 ruleset: Some(ruleset),
53 }
54 }
55}
56
57impl Drop for LocalNetworking {
58 fn drop(&mut self) {
59 self.selector.shutdown();
60 }
61}
62
63impl Default for LocalNetworking {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72 async fn listen_tcp(
73 &self,
74 addr: SocketAddr,
75 only_v6: bool,
76 reuse_port: bool,
77 reuse_addr: bool,
78 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79 if let Some(ruleset) = self.ruleset.as_ref() {
80 if !ruleset.allows_socket(addr, Direction::Inbound) {
81 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
82 return Err(NetworkError::PermissionDenied);
83 }
84 }
85
86 let listener = std::net::TcpListener::bind(addr)
87 .map(|sock| {
88 sock.set_nonblocking(true).ok();
89 Box::new(LocalTcpListener {
90 stream: mio::net::TcpListener::from_std(sock),
91 selector: self.selector.clone(),
92 handler_guard: HandlerGuardState::None,
93 no_delay: None,
94 keep_alive: None,
95 backlog: Default::default(),
96 ruleset: self.ruleset.clone(),
97 })
98 })
99 .map_err(io_err_into_net_error)?;
100 Ok(listener)
101 }
102
103 async fn bind_udp(
104 &self,
105 addr: SocketAddr,
106 reuse_port: bool,
107 reuse_addr: bool,
108 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109 use socket2::{Domain, Socket, Type};
110
111 if let Some(ruleset) = self.ruleset.as_ref() {
112 if !ruleset.allows_socket(addr, Direction::Inbound) {
113 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
114 return Err(NetworkError::PermissionDenied);
115 }
116 }
117
118 #[cfg(not(windows))]
119 let socket = {
120 let domain = if addr.is_ipv4() {
121 Domain::IPV4
122 } else {
123 Domain::IPV6
124 };
125 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
126 std_sock
127 .set_nonblocking(true)
128 .map_err(io_err_into_net_error)?;
129 std_sock
130 .set_reuse_address(reuse_addr)
131 .map_err(io_err_into_net_error)?;
132 std_sock
133 .set_reuse_port(reuse_port)
134 .map_err(io_err_into_net_error)?;
135 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
136 mio::net::UdpSocket::from_std(std_sock.into())
137 };
138 #[cfg(windows)]
139 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
140
141 #[allow(unused_mut)]
142 let mut ret = LocalUdpSocket {
143 selector: self.selector.clone(),
144 socket,
145 addr,
146 handler_guard: HandlerGuardState::None,
147 backlog: Default::default(),
148 ruleset: self.ruleset.clone(),
149 };
150
151 #[cfg(target_os = "windows")]
156 {
157 let (state, selector, socket) = ret.split_borrow();
158 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
159 map.push(InterestType::Writable);
160 }
161
162 Ok(Box::new(ret))
163 }
164
165 async fn connect_tcp(
166 &self,
167 _addr: SocketAddr,
168 mut peer: SocketAddr,
169 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
170 if let Some(ruleset) = self.ruleset.as_ref() {
171 if !ruleset.allows_socket(peer, Direction::Outbound) {
172 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
173 return Err(NetworkError::PermissionDenied);
174 }
175 }
176
177 let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
178
179 if let Ok(p) = stream.peer_addr() {
180 peer = p;
181 }
182 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
183 Ok(socket)
184 }
185
186 async fn resolve(
187 &self,
188 host: &str,
189 port: Option<u16>,
190 dns_server: Option<IpAddr>,
191 ) -> Result<Vec<IpAddr>> {
192 if let Some(ruleset) = self.ruleset.as_ref() {
193 if !ruleset.allows_domain(host) {
194 tracing::warn!(%host, "dns resolve blocked by firewall rule");
195 return Err(NetworkError::PermissionDenied);
196 }
197 }
198
199 let host_to_lookup = if host.contains(':') {
200 host.to_string()
201 } else {
202 format!("{}:{}", host, port.unwrap_or(0))
203 };
204 let addrs = self
205 .handle
206 .spawn(tokio::net::lookup_host(host_to_lookup))
207 .await
208 .map_err(|_| NetworkError::IOError)?
209 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
210 .map_err(io_err_into_net_error)?;
211
212 if let Some(ruleset) = self.ruleset.as_ref() {
213 if let Err(e) = ruleset.expand_domain(host, &addrs) {
214 tracing::debug!(err=%e, "ruleset expansion failed");
215 } else {
216 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
217 }
218 }
219
220 Ok(addrs)
221 }
222}
223
/// Host TCP listener returned by [`LocalNetworking::listen_tcp`].
#[derive(Debug)]
pub struct LocalTcpListener {
    /// mio listener wrapping the bound host socket.
    stream: mio::net::TcpListener,
    /// Selector used to register readiness interest for this listener.
    selector: Arc<Selector>,
    /// Current readiness registration: none, an external handler, or a waker map.
    handler_guard: HandlerGuardState,
    /// When set, applied to every accepted stream via `set_nodelay`.
    no_delay: Option<bool>,
    /// When set, applied to every accepted stream via `set_keepalive`.
    keep_alive: Option<bool>,
    /// Connections accepted while polling for readiness that `try_accept` has
    /// not handed out yet.
    backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
    /// Firewall rules checked against the peer address of each accepted connection.
    ruleset: Option<Ruleset>,
}
234
235impl LocalTcpListener {
236 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
237 match self.stream.accept().map_err(io_err_into_net_error) {
238 Ok((stream, addr)) => {
239 if let Some(ruleset) = self.ruleset.as_ref() {
240 if !ruleset.allows_socket(addr, Direction::Outbound) {
241 tracing::warn!(%addr, "try_accept blocked by firewall rule");
242 return Err(NetworkError::PermissionDenied);
243 }
244 }
245
246 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
247 if let Some(no_delay) = self.no_delay {
248 socket.set_nodelay(no_delay).ok();
249 }
250 if let Some(keep_alive) = self.keep_alive {
251 socket.set_keepalive(keep_alive).ok();
252 }
253 Ok((Box::new(socket), addr))
254 }
255 Err(NetworkError::WouldBlock) => {
256 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
257 map.pop(InterestType::Readable);
258 map.pop(InterestType::Writable);
259 }
260 Err(NetworkError::WouldBlock)
261 }
262 Err(err) => Err(err),
263 }
264 }
265}
266
267impl VirtualTcpListener for LocalTcpListener {
268 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
269 if let Some(child) = self.backlog.pop_front() {
270 return Ok(child);
271 }
272 self.try_accept_internal()
273 }
274
275 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
276 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
277 match guard.replace_handler(handler) {
278 Ok(()) => return Ok(()),
279 Err(h) => handler = h,
280 }
281
282 if let Err(err) = guard.unregister(&mut self.stream) {
284 tracing::debug!("failed to unregister previous token - {}", err);
285 }
286 }
287
288 let guard = InterestGuard::new(
289 &self.selector,
290 handler,
291 &mut self.stream,
292 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
293 )
294 .map_err(io_err_into_net_error)?;
295
296 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
297
298 Ok(())
299 }
300
301 fn addr_local(&self) -> Result<SocketAddr> {
302 self.stream.local_addr().map_err(io_err_into_net_error)
303 }
304
305 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
306 self.stream
307 .set_ttl(ttl as u32)
308 .map_err(io_err_into_net_error)
309 }
310
311 fn ttl(&self) -> Result<u8> {
312 self.stream
313 .ttl()
314 .map(|ttl| ttl as u8)
315 .map_err(io_err_into_net_error)
316 }
317}
318
319impl LocalTcpListener {
320 fn split_borrow(
321 &mut self,
322 ) -> (
323 &mut HandlerGuardState,
324 &Arc<Selector>,
325 &mut mio::net::TcpListener,
326 ) {
327 (&mut self.handler_guard, &self.selector, &mut self.stream)
328 }
329}
330
331impl VirtualIoSource for LocalTcpListener {
332 fn remove_handler(&mut self) {
333 let mut guard = HandlerGuardState::None;
334 std::mem::swap(&mut guard, &mut self.handler_guard);
335 match guard {
336 HandlerGuardState::ExternalHandler(mut guard) => {
337 guard.unregister(&mut self.stream).ok();
338 }
339 HandlerGuardState::WakerMap(mut guard, _) => {
340 guard.unregister(&mut self.stream).ok();
341 }
342 HandlerGuardState::None => {}
343 }
344 }
345
346 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
347 if !self.backlog.is_empty() {
348 return Poll::Ready(Ok(self.backlog.len()));
349 }
350
351 let (state, selector, source) = self.split_borrow();
352 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
353 map.add(InterestType::Readable, cx.waker());
354
355 if let Ok(child) = self.try_accept_internal() {
356 self.backlog.push_back(child);
357 return Poll::Ready(Ok(1));
358 }
359 Poll::Pending
360 }
361
362 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
363 if !self.backlog.is_empty() {
364 return Poll::Ready(Ok(self.backlog.len()));
365 }
366
367 let (state, selector, source) = self.split_borrow();
368 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
369 map.add(InterestType::Writable, cx.waker());
370
371 if let Ok(child) = self.try_accept_internal() {
372 self.backlog.push_back(child);
373 return Poll::Ready(Ok(1));
374 }
375 Poll::Pending
376 }
377}
378
/// Host TCP stream, either accepted by a [`LocalTcpListener`] or opened by
/// [`LocalNetworking::connect_tcp`].
#[derive(Debug)]
pub struct LocalTcpStream {
    /// Underlying mio stream.
    stream: mio::net::TcpStream,
    /// Peer address returned by `addr_peer`.
    addr: SocketAddr,
    /// Direction passed to the last successful `shutdown` call, if any.
    shutdown: Option<Shutdown>,
    /// Selector used to register readiness interest for this stream.
    selector: Arc<Selector>,
    /// Current readiness registration for this stream.
    handler_guard: HandlerGuardState,
    /// Bytes read ahead by `poll_read_ready`; `try_recv` serves these first.
    buffer: BytesMut,
}
388
impl LocalTcpStream {
    /// Wraps a mio stream whose peer is `addr`. The new stream has no readiness
    /// registration and an empty read-ahead buffer.
    ///
    /// On Windows a `Writable` entry is pushed into the waker map right away;
    /// a failure to create the map is ignored.
    fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
        #[allow(unused_mut)]
        let mut ret = Self {
            stream,
            addr,
            shutdown: None,
            selector,
            handler_guard: HandlerGuardState::None,
            buffer: BytesMut::new(),
        };

        // NOTE(review): presumably this marks the socket writable up front
        // because no initial writable event is expected on Windows — confirm.
        #[cfg(target_os = "windows")]
        {
            let (state, selector, socket, _) = ret.split_borrow();
            if let Ok(map) = state_as_waker_map(state, selector, socket) {
                map.push(InterestType::Writable);
            }
        }

        ret
    }

    /// Runs `f` with a `socket2::SockRef` borrowing this stream's socket.
    /// Socket options such as keepalive and linger are accessed through it.
    fn with_sock_ref<F, R>(&self, f: F) -> R
    where
        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
    {
        #[cfg(not(windows))]
        let r = socket2::SockRef::from(&self.stream);

        // SAFETY: the raw socket belongs to `self.stream`, which is borrowed for
        // the whole call, so the borrowed handle cannot outlive the socket.
        #[cfg(windows)]
        let b = unsafe {
            std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
        };
        #[cfg(windows)]
        let r = socket2::SockRef::from(&b);

        f(r)
    }
}
433
434impl VirtualTcpSocket for LocalTcpStream {
435 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
436 Ok(())
437 }
438
439 fn recv_buf_size(&self) -> Result<usize> {
440 Err(NetworkError::Unsupported)
441 }
442
443 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
444 Ok(())
445 }
446
447 fn send_buf_size(&self) -> Result<usize> {
448 Err(NetworkError::Unsupported)
449 }
450
451 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
452 self.stream
453 .set_nodelay(nodelay)
454 .map_err(io_err_into_net_error)
455 }
456
457 fn nodelay(&self) -> Result<bool> {
458 self.stream.nodelay().map_err(io_err_into_net_error)
459 }
460
461 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
462 self.with_sock_ref(|s| s.set_keepalive(true))
463 .map_err(io_err_into_net_error)?;
464 Ok(())
465 }
466
467 fn keepalive(&self) -> Result<bool> {
468 let ret = self
469 .with_sock_ref(|s| s.keepalive())
470 .map_err(io_err_into_net_error)?;
471 Ok(ret)
472 }
473
474 #[cfg(not(target_os = "windows"))]
475 fn set_dontroute(&mut self, val: bool) -> Result<()> {
476 let val = val as libc::c_int;
482 let payload = &val as *const libc::c_int as *const libc::c_void;
483 let err = unsafe {
484 libc::setsockopt(
485 self.stream.as_raw_fd(),
486 libc::SOL_SOCKET,
487 libc::SO_DONTROUTE,
488 payload,
489 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
490 )
491 };
492 if err == -1 {
493 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
494 }
495 Ok(())
496 }
497 #[cfg(target_os = "windows")]
498 fn set_dontroute(&mut self, val: bool) -> Result<()> {
499 Err(NetworkError::Unsupported)
500 }
501
502 #[cfg(not(target_os = "windows"))]
503 fn dontroute(&self) -> Result<bool> {
504 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
505 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
506 let err = unsafe {
507 libc::getsockopt(
508 self.stream.as_raw_fd(),
509 libc::SOL_SOCKET,
510 libc::SO_DONTROUTE,
511 payload.as_mut_ptr().cast(),
512 &mut len,
513 )
514 };
515 if err == -1 {
516 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
517 }
518 Ok(unsafe { payload.assume_init() != 0 })
519 }
520 #[cfg(target_os = "windows")]
521 fn dontroute(&self) -> Result<bool> {
522 Err(NetworkError::Unsupported)
523 }
524
525 fn addr_peer(&self) -> Result<SocketAddr> {
526 Ok(self.addr)
527 }
528
529 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
530 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
531 self.shutdown = Some(how);
532 Ok(())
533 }
534
535 fn is_closed(&self) -> bool {
536 false
537 }
538}
539
540impl VirtualConnectedSocket for LocalTcpStream {
541 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
542 self.with_sock_ref(|s| s.set_linger(linger))
543 .map_err(io_err_into_net_error)?;
544 Ok(())
545 }
546
547 fn linger(&self) -> Result<Option<Duration>> {
548 self.with_sock_ref(|s| s.linger())
549 .map_err(io_err_into_net_error)
550 }
551
552 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
553 let ret = self.stream.write(data).map_err(io_err_into_net_error);
554 match &ret {
555 Ok(0) | Err(NetworkError::WouldBlock) => {
556 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
557 map.pop(InterestType::Writable);
558 }
559 }
560 _ => {}
561 }
562 ret
563 }
564
565 fn try_flush(&mut self) -> Result<()> {
566 self.stream.flush().map_err(io_err_into_net_error)
567 }
568
569 fn close(&mut self) -> Result<()> {
570 Ok(())
571 }
572
573 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
574 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
575 if !self.buffer.is_empty() {
576 let amt = buf.len().min(self.buffer.len());
577 buf[..amt].copy_from_slice(&self.buffer[..amt]);
578 if !peek {
579 self.buffer.advance(amt);
580 }
581 return Ok(amt);
582 }
583
584 if peek {
585 self.stream.peek(buf)
586 } else {
587 self.stream.read(buf)
588 }
589 .map_err(io_err_into_net_error)
590 }
591}
592
593impl VirtualSocket for LocalTcpStream {
594 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
595 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
596 }
597
598 fn ttl(&self) -> Result<u32> {
599 self.stream.ttl().map_err(io_err_into_net_error)
600 }
601
602 fn addr_local(&self) -> Result<SocketAddr> {
603 self.stream.local_addr().map_err(io_err_into_net_error)
604 }
605
606 fn status(&self) -> Result<SocketStatus> {
607 Ok(SocketStatus::Opened)
608 }
609
610 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
611 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
612 match guard.replace_handler(handler) {
613 Ok(()) => return Ok(()),
614 Err(h) => handler = h,
615 }
616
617 if let Err(err) = guard.unregister(&mut self.stream) {
619 tracing::debug!("failed to unregister previous token - {}", err);
620 }
621 }
622
623 let guard = InterestGuard::new(
624 &self.selector,
625 handler,
626 &mut self.stream,
627 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
628 )
629 .map_err(io_err_into_net_error)?;
630
631 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
632
633 Ok(())
634 }
635}
636
637impl LocalTcpStream {
638 fn split_borrow(
639 &mut self,
640 ) -> (
641 &mut HandlerGuardState,
642 &Arc<Selector>,
643 &mut mio::net::TcpStream,
644 &mut BytesMut,
645 ) {
646 (
647 &mut self.handler_guard,
648 &self.selector,
649 &mut self.stream,
650 &mut self.buffer,
651 )
652 }
653}
654
impl VirtualIoSource for LocalTcpStream {
    /// Takes the current readiness registration out of the stream, leaving
    /// `HandlerGuardState::None`, and unregisters it. Unregister errors are
    /// ignored.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.stream).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Reads ahead from the socket into the internal buffer and reports how
    /// many bytes are available.
    ///
    /// If data is already buffered, returns its length. Otherwise the
    /// `Readable` entry is popped, the waker is re-registered, and one read is
    /// attempted:
    /// - `Ok(0)`, `ConnectionAborted`, `ConnectionReset` → `Ready(Ok(0))`
    /// - `Ok(n)` → the bytes are appended to the buffer, `Ready(Ok(n))`
    /// - `WouldBlock` → `Pending`
    /// - any other error → `Ready(Err(..))`
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        if !self.buffer.is_empty() {
            return Poll::Ready(Ok(self.buffer.len()));
        }

        let (state, selector, stream, buffer) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // The buffer is empty here, so this reserves at least 10240 bytes of
        // spare capacity for the read.
        buffer.reserve(buffer.len() + 10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // NOTE(review): uninitialised spare capacity is exposed as `&mut [u8]`
        // for `read` to fill; only the `amt` bytes it reports are later marked
        // initialised via `set_len`. Confirm this is acceptable soundness-wise.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match stream.read(uninit_unsafe) {
            // A zero-byte read into a non-empty buffer signals end of stream.
            Ok(0) => Poll::Ready(Ok(0)),
            Ok(amt) => {
                // SAFETY: `read` initialised the first `amt` bytes of spare capacity.
                unsafe {
                    buffer.set_len(buffer.len() + amt);
                }
                Poll::Ready(Ok(amt))
            }
            // Aborted/reset connections are reported the same way as end of stream.
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Registers the waker for `Writable` and `Closed` interest and reports
    /// whether the stream can be written.
    ///
    /// Returns `Ready(Ok(0))` when the waker map reports `Closed` interest or,
    /// on non-Windows targets, when `poll(2)` reports `POLLHUP`. Returns
    /// `Ready(Ok(10240))` when the socket is reported writable; this is a fixed
    /// hint, not the actual free buffer space. Otherwise returns `Pending`.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, stream, _) = self.split_borrow();
        let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());
        map.add(InterestType::Closed, cx.waker());
        if map.has_interest(InterestType::Closed) {
            return Poll::Ready(Ok(0));
        }

        // Non-Windows: query writability directly with a zero-timeout poll(2).
        #[cfg(not(target_os = "windows"))]
        match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // Windows: rely on the `Writable` interest recorded in the waker map.
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}
731
732#[cfg(not(target_os = "windows"))]
733fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
734 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
735 fd,
736 events,
737 revents: 0,
738 }];
739 let fds_mut = &mut fds[..];
740 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
741 match ret == 1 {
742 true => Some(fds[0].revents),
743 false => None,
744 }
745}
746
/// Host UDP socket returned by [`LocalNetworking::bind_udp`].
#[derive(Debug)]
pub struct LocalUdpSocket {
    /// Underlying mio UDP socket.
    socket: mio::net::UdpSocket,
    /// Address passed to `bind_udp`; marked `dead_code` because nothing reads it.
    #[allow(dead_code)]
    addr: SocketAddr,
    /// Selector used to register readiness interest for this socket.
    selector: Arc<Selector>,
    /// Current readiness registration for this socket.
    handler_guard: HandlerGuardState,
    /// Datagrams (payload, sender) read ahead by `poll_read_ready`.
    backlog: VecDeque<(BytesMut, SocketAddr)>,
    /// Firewall rules checked against destination addresses in `try_send_to`.
    ruleset: Option<Ruleset>,
}
757
impl LocalUdpSocket {
    /// Runs `f` with a `socket2::SockRef` borrowing this socket. The IPv4
    /// multicast join/leave calls go through it.
    fn with_sock_ref<F, R>(&self, f: F) -> R
    where
        for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
    {
        #[cfg(not(windows))]
        let r = socket2::SockRef::from(&self.socket);

        // SAFETY: the raw socket belongs to `self.socket`, which is borrowed for
        // the whole call, so the borrowed handle cannot outlive the socket.
        #[cfg(windows)]
        let b = unsafe {
            std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
        };
        #[cfg(windows)]
        let r = socket2::SockRef::from(&b);

        f(r)
    }
}
776
777impl VirtualUdpSocket for LocalUdpSocket {
778 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
779 self.socket
780 .set_broadcast(broadcast)
781 .map_err(io_err_into_net_error)
782 }
783
784 fn broadcast(&self) -> Result<bool> {
785 self.socket.broadcast().map_err(io_err_into_net_error)
786 }
787
788 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
789 self.socket
790 .set_multicast_loop_v4(val)
791 .map_err(io_err_into_net_error)
792 }
793
794 fn multicast_loop_v4(&self) -> Result<bool> {
795 self.socket
796 .multicast_loop_v4()
797 .map_err(io_err_into_net_error)
798 }
799
800 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
801 self.socket
802 .set_multicast_loop_v6(val)
803 .map_err(io_err_into_net_error)
804 }
805
806 fn multicast_loop_v6(&self) -> Result<bool> {
807 self.socket
808 .multicast_loop_v6()
809 .map_err(io_err_into_net_error)
810 }
811
812 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
813 self.socket
814 .set_multicast_ttl_v4(ttl)
815 .map_err(io_err_into_net_error)
816 }
817
818 fn multicast_ttl_v4(&self) -> Result<u32> {
819 self.socket
820 .multicast_ttl_v4()
821 .map_err(io_err_into_net_error)
822 }
823
824 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
825 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
826 .map_err(io_err_into_net_error)
827 }
828
829 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
830 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
831 .map_err(io_err_into_net_error)
832 }
833
834 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
835 self.socket
836 .join_multicast_v6(&multiaddr, iface)
837 .map_err(io_err_into_net_error)
838 }
839
840 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
841 self.socket
842 .leave_multicast_v6(&multiaddr, iface)
843 .map_err(io_err_into_net_error)
844 }
845
846 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
847 self.socket
848 .peer_addr()
849 .map(Some)
850 .map_err(io_err_into_net_error)
851 }
852}
853
854impl VirtualConnectionlessSocket for LocalUdpSocket {
855 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
856 if let Some(ruleset) = self.ruleset.as_ref() {
857 if !ruleset.allows_socket(addr, Direction::Outbound) {
858 tracing::warn!(%addr, "try_send blocked by firewall rule");
859 return Err(NetworkError::PermissionDenied);
860 }
861 }
862
863 let ret = self
864 .socket
865 .send_to(data, addr)
866 .map_err(io_err_into_net_error);
867 match &ret {
868 Ok(0) | Err(NetworkError::WouldBlock) => {
869 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
870 map.pop(InterestType::Writable);
871 }
872 }
873 _ => {}
874 }
875 ret
876 }
877
878 fn try_recv_from(
879 &mut self,
880 buf: &mut [MaybeUninit<u8>],
881 peek: bool,
882 ) -> Result<(usize, SocketAddr)> {
883 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
884 if peek {
885 self.socket.peek_from(buf)
886 } else {
887 self.socket.recv_from(buf)
888 }
889 .map_err(io_err_into_net_error)
890 }
891}
892
893impl VirtualSocket for LocalUdpSocket {
894 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
895 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
896 }
897
898 fn ttl(&self) -> Result<u32> {
899 self.socket.ttl().map_err(io_err_into_net_error)
900 }
901
902 fn addr_local(&self) -> Result<SocketAddr> {
903 self.socket.local_addr().map_err(io_err_into_net_error)
904 }
905
906 fn status(&self) -> Result<SocketStatus> {
907 Ok(SocketStatus::Opened)
908 }
909
910 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
911 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
912 match guard.replace_handler(handler) {
913 Ok(()) => {
914 return Ok(());
915 }
916 Err(h) => handler = h,
917 }
918
919 if let Err(err) = guard.unregister(&mut self.socket) {
921 tracing::debug!("failed to unregister previous token - {}", err);
922 }
923 }
924
925 let guard = InterestGuard::new(
926 &self.selector,
927 handler,
928 &mut self.socket,
929 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
930 )
931 .map_err(io_err_into_net_error)?;
932
933 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
934
935 Ok(())
936 }
937}
938
939impl LocalUdpSocket {
940 fn split_borrow(
941 &mut self,
942 ) -> (
943 &mut HandlerGuardState,
944 &Arc<Selector>,
945 &mut mio::net::UdpSocket,
946 ) {
947 (&mut self.handler_guard, &self.selector, &mut self.socket)
948 }
949}
950
impl VirtualIoSource for LocalUdpSocket {
    /// Takes the current readiness registration out of the socket, leaving
    /// `HandlerGuardState::None`, and unregisters it. Unregister errors are
    /// ignored.
    fn remove_handler(&mut self) {
        let mut guard = HandlerGuardState::None;
        std::mem::swap(&mut guard, &mut self.handler_guard);
        match guard {
            HandlerGuardState::ExternalHandler(mut guard) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::WakerMap(mut guard, _) => {
                guard.unregister(&mut self.socket).ok();
            }
            HandlerGuardState::None => {}
        }
    }

    /// Reports how many bytes of received datagrams are available.
    ///
    /// If datagrams are already queued in `backlog`, returns the sum of their
    /// lengths. Otherwise the `Readable` entry is popped, the waker is
    /// re-registered, and one `recv_from` is attempted:
    /// - non-empty datagram → queued with its sender, `Ready(Ok(len))`
    /// - empty datagram, `ConnectionAborted`, `ConnectionReset` → `Ready(Ok(0))`
    /// - `WouldBlock` → `Pending`
    /// - any other error → `Ready(Err(..))`
    fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        if !self.backlog.is_empty() {
            let total = self.backlog.iter().map(|a| a.0.len()).sum();
            return Poll::Ready(Ok(total));
        }

        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        map.pop(InterestType::Readable);
        map.add(InterestType::Readable, cx.waker());

        // A fresh buffer per datagram with at least 10240 bytes of spare capacity.
        // NOTE(review): the excess of a larger datagram may be discarded by
        // `recv_from` — confirm 10240 is an acceptable maximum size.
        let mut buffer = BytesMut::default();
        buffer.reserve(10240);
        let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
        // NOTE(review): uninitialised spare capacity is exposed as `&mut [u8]`
        // for `recv_from` to fill; only `amt` bytes are later marked initialised.
        let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };

        match self.socket.recv_from(uninit_unsafe) {
            // NOTE(review): a zero-length datagram is consumed here without being
            // queued, so it is never delivered — confirm this is intended.
            Ok((0, _)) => Poll::Ready(Ok(0)),
            Ok((amt, peer)) => {
                // SAFETY: `recv_from` initialised the first `amt` bytes.
                unsafe {
                    buffer.set_len(amt);
                }
                self.backlog.push_back((buffer, peer));
                Poll::Ready(Ok(amt))
            }
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
        }
    }

    /// Registers the waker for `Writable` interest and reports whether the
    /// socket can be written.
    ///
    /// On non-Windows targets, returns `Ready(Ok(0))` when `poll(2)` reports
    /// `POLLHUP` and `Ready(Ok(10240))` (a fixed hint) when it reports
    /// `POLLOUT`. On Windows, returns `Ready(Ok(10240))` when the waker map
    /// reports `Writable` interest. Otherwise returns `Pending`.
    fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
        let (state, selector, socket) = self.split_borrow();
        let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
        #[cfg(not(target_os = "windows"))]
        map.pop(InterestType::Writable);
        map.add(InterestType::Writable, cx.waker());

        // Non-Windows: query writability directly with a zero-timeout poll(2).
        #[cfg(not(target_os = "windows"))]
        match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
            Some(val) if (val & libc::POLLHUP) != 0 => {
                return Poll::Ready(Ok(0));
            }
            Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
            _ => {}
        }

        // Windows: rely on the `Writable` interest recorded in the waker map.
        #[cfg(target_os = "windows")]
        if map.has_interest(InterestType::Writable) {
            return Poll::Ready(Ok(10240));
        }

        Poll::Pending
    }
}