1#![allow(unused_variables)]
2use crate::ruleset::{Direction, Ruleset};
3#[allow(unused_imports)]
4use crate::{
5 IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
6 VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
7 VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
8};
9use crate::{VirtualIoSource, io_err_into_net_error};
10use bytes::{Buf, BytesMut};
11use std::collections::VecDeque;
12use std::io::{self, Read, Write};
13use std::mem::MaybeUninit;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
15#[cfg(not(target_os = "windows"))]
16use std::os::fd::AsRawFd;
17#[cfg(not(target_os = "windows"))]
18use std::os::fd::RawFd;
19#[cfg(windows)]
20use std::os::windows::io::AsRawSocket;
21
22use std::sync::{Arc, Mutex};
23use std::task::Poll;
24use std::time::Duration;
25use tokio::runtime::Handle;
26#[allow(unused_imports, dead_code)]
27use tracing::{debug, error, info, trace, warn};
28use virtual_mio::{
29 HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
30};
31
32#[derive(Debug)]
33pub struct LocalNetworking {
34 selector: Arc<Selector>,
35 handle: Handle,
36 ruleset: Option<Ruleset>,
37}
38
39impl LocalNetworking {
40 pub fn new() -> Self {
41 Self {
42 selector: Selector::new(),
43 handle: Handle::current(),
44 ruleset: None,
45 }
46 }
47
48 pub fn with_ruleset(ruleset: Ruleset) -> Self {
49 Self {
50 selector: Selector::new(),
51 handle: Handle::current(),
52 ruleset: Some(ruleset),
53 }
54 }
55}
56
57impl Drop for LocalNetworking {
58 fn drop(&mut self) {
59 self.selector.shutdown();
60 }
61}
62
63impl Default for LocalNetworking {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait::async_trait]
70#[allow(unused_variables)]
71impl VirtualNetworking for LocalNetworking {
72 async fn listen_tcp(
73 &self,
74 addr: SocketAddr,
75 only_v6: bool,
76 reuse_port: bool,
77 reuse_addr: bool,
78 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
79 if let Some(ruleset) = self.ruleset.as_ref()
80 && !ruleset.allows_socket(addr, Direction::Inbound)
81 {
82 tracing::warn!(%addr, "listen_tcp blocked by firewall rule");
83 return Err(NetworkError::PermissionDenied);
84 }
85
86 let listener = std::net::TcpListener::bind(addr)
87 .map(|sock| {
88 sock.set_nonblocking(true).ok();
89 Box::new(LocalTcpListener {
90 stream: mio::net::TcpListener::from_std(sock),
91 selector: self.selector.clone(),
92 handler_guard: HandlerGuardState::None,
93 no_delay: None,
94 keep_alive: None,
95 backlog: Default::default(),
96 ruleset: self.ruleset.clone(),
97 })
98 })
99 .map_err(io_err_into_net_error)?;
100 Ok(listener)
101 }
102
103 async fn bind_udp(
104 &self,
105 addr: SocketAddr,
106 reuse_port: bool,
107 reuse_addr: bool,
108 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
109 #[cfg(not(windows))]
110 use socket2::{Domain, Socket, Type};
111
112 if let Some(ruleset) = self.ruleset.as_ref()
113 && !ruleset.allows_socket(addr, Direction::Inbound)
114 {
115 tracing::warn!(%addr, "bind_udp blocked by firewall rule");
116 return Err(NetworkError::PermissionDenied);
117 }
118
119 #[cfg(not(windows))]
120 let socket = {
121 let domain = if addr.is_ipv4() {
122 Domain::IPV4
123 } else {
124 Domain::IPV6
125 };
126 let std_sock = Socket::new(domain, Type::DGRAM, None).map_err(io_err_into_net_error)?;
127 std_sock
128 .set_nonblocking(true)
129 .map_err(io_err_into_net_error)?;
130 std_sock
131 .set_reuse_address(reuse_addr)
132 .map_err(io_err_into_net_error)?;
133 std_sock
134 .set_reuse_port(reuse_port)
135 .map_err(io_err_into_net_error)?;
136 std_sock.bind(&addr.into()).map_err(io_err_into_net_error)?;
137 mio::net::UdpSocket::from_std(std_sock.into())
138 };
139 #[cfg(windows)]
140 let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
141
142 #[allow(unused_mut)]
143 let mut ret = LocalUdpSocket {
144 selector: self.selector.clone(),
145 socket,
146 addr,
147 handler_guard: HandlerGuardState::None,
148 backlog: Default::default(),
149 ruleset: self.ruleset.clone(),
150 };
151
152 #[cfg(target_os = "windows")]
157 {
158 let (state, selector, socket) = ret.split_borrow();
159 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
160 map.push(InterestType::Writable);
161 }
162
163 Ok(Box::new(ret))
164 }
165
166 async fn connect_tcp(
167 &self,
168 _addr: SocketAddr,
169 mut peer: SocketAddr,
170 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
171 if let Some(ruleset) = self.ruleset.as_ref()
172 && !ruleset.allows_socket(peer, Direction::Outbound)
173 {
174 tracing::warn!(%peer, "connect_tcp blocked by firewall rule");
175 return Err(NetworkError::PermissionDenied);
176 }
177
178 let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;
179
180 if let Ok(p) = stream.peer_addr() {
181 peer = p;
182 }
183 let socket = Box::new(LocalTcpStream::new(self.selector.clone(), stream, peer));
184 Ok(socket)
185 }
186
187 async fn resolve(
188 &self,
189 host: &str,
190 port: Option<u16>,
191 dns_server: Option<IpAddr>,
192 ) -> Result<Vec<IpAddr>> {
193 if let Some(ruleset) = self.ruleset.as_ref()
194 && !ruleset.allows_domain(host)
195 {
196 tracing::warn!(%host, "dns resolve blocked by firewall rule");
197 return Err(NetworkError::PermissionDenied);
198 }
199
200 let host_to_lookup = if host.contains(':') {
201 host.to_string()
202 } else {
203 format!("{}:{}", host, port.unwrap_or(0))
204 };
205 let addrs = self
206 .handle
207 .spawn(tokio::net::lookup_host(host_to_lookup))
208 .await
209 .map_err(|_| NetworkError::IOError)?
210 .map(|a| a.map(|a| a.ip()).collect::<Vec<_>>())
211 .map_err(io_err_into_net_error)?;
212
213 if let Some(ruleset) = self.ruleset.as_ref() {
214 if let Err(e) = ruleset.expand_domain(host, &addrs) {
215 tracing::debug!(err=%e, "ruleset expansion failed");
216 } else {
217 tracing::debug!(addrs=?addrs, domain = host, "ruleset expansion")
218 }
219 }
220
221 Ok(addrs)
222 }
223}
224
225#[derive(Debug)]
226pub struct LocalTcpListener {
227 stream: mio::net::TcpListener,
228 selector: Arc<Selector>,
229 handler_guard: HandlerGuardState,
230 no_delay: Option<bool>,
231 keep_alive: Option<bool>,
232 backlog: VecDeque<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)>,
233 ruleset: Option<Ruleset>,
234}
235
236impl LocalTcpListener {
237 fn try_accept_internal(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
238 match self.stream.accept().map_err(io_err_into_net_error) {
239 Ok((stream, addr)) => {
240 if let Some(ruleset) = self.ruleset.as_ref()
241 && !ruleset.allows_socket(addr, Direction::Outbound)
242 {
243 tracing::warn!(%addr, "try_accept blocked by firewall rule");
244 return Err(NetworkError::PermissionDenied);
245 }
246
247 let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr);
248 if let Some(no_delay) = self.no_delay {
249 socket.set_nodelay(no_delay).ok();
250 }
251 if let Some(keep_alive) = self.keep_alive {
252 socket.set_keepalive(keep_alive).ok();
253 }
254 Ok((Box::new(socket), addr))
255 }
256 Err(NetworkError::WouldBlock) => {
257 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
258 map.pop(InterestType::Readable);
259 map.pop(InterestType::Writable);
260 }
261 Err(NetworkError::WouldBlock)
262 }
263 Err(err) => Err(err),
264 }
265 }
266}
267
268impl VirtualTcpListener for LocalTcpListener {
269 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
270 if let Some(child) = self.backlog.pop_front() {
271 return Ok(child);
272 }
273 self.try_accept_internal()
274 }
275
276 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
277 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
278 match guard.replace_handler(handler) {
279 Ok(()) => return Ok(()),
280 Err(h) => handler = h,
281 }
282
283 if let Err(err) = guard.unregister(&mut self.stream) {
285 tracing::debug!("failed to unregister previous token - {}", err);
286 }
287 }
288
289 let guard = InterestGuard::new(
290 &self.selector,
291 handler,
292 &mut self.stream,
293 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
294 )
295 .map_err(io_err_into_net_error)?;
296
297 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
298
299 Ok(())
300 }
301
302 fn addr_local(&self) -> Result<SocketAddr> {
303 self.stream.local_addr().map_err(io_err_into_net_error)
304 }
305
306 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
307 self.stream
308 .set_ttl(ttl as u32)
309 .map_err(io_err_into_net_error)
310 }
311
312 fn ttl(&self) -> Result<u8> {
313 self.stream
314 .ttl()
315 .map(|ttl| ttl as u8)
316 .map_err(io_err_into_net_error)
317 }
318}
319
320impl LocalTcpListener {
321 fn split_borrow(
322 &mut self,
323 ) -> (
324 &mut HandlerGuardState,
325 &Arc<Selector>,
326 &mut mio::net::TcpListener,
327 ) {
328 (&mut self.handler_guard, &self.selector, &mut self.stream)
329 }
330}
331
332impl VirtualIoSource for LocalTcpListener {
333 fn remove_handler(&mut self) {
334 let mut guard = HandlerGuardState::None;
335 std::mem::swap(&mut guard, &mut self.handler_guard);
336 match guard {
337 HandlerGuardState::ExternalHandler(mut guard) => {
338 guard.unregister(&mut self.stream).ok();
339 }
340 HandlerGuardState::WakerMap(mut guard, _) => {
341 guard.unregister(&mut self.stream).ok();
342 }
343 HandlerGuardState::None => {}
344 }
345 }
346
347 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
348 if !self.backlog.is_empty() {
349 return Poll::Ready(Ok(self.backlog.len()));
350 }
351
352 let (state, selector, source) = self.split_borrow();
353 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
354 map.add(InterestType::Readable, cx.waker());
355
356 if let Ok(child) = self.try_accept_internal() {
357 self.backlog.push_back(child);
358 return Poll::Ready(Ok(1));
359 }
360 Poll::Pending
361 }
362
363 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
364 if !self.backlog.is_empty() {
365 return Poll::Ready(Ok(self.backlog.len()));
366 }
367
368 let (state, selector, source) = self.split_borrow();
369 let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?;
370 map.add(InterestType::Writable, cx.waker());
371
372 if let Ok(child) = self.try_accept_internal() {
373 self.backlog.push_back(child);
374 return Poll::Ready(Ok(1));
375 }
376 Poll::Pending
377 }
378}
379
/// Cached outcome of a TCP connection attempt, as tracked by
/// `LocalTcpStream::status`.
#[derive(Debug)]
enum ConnectState {
    /// No definitive outcome has been recorded yet.
    Unknown,
    /// The connection was observed to be established.
    Opened,
    /// The connection was observed to have failed.
    Failed,
}
386
387#[derive(Debug)]
388pub struct LocalTcpStream {
389 stream: mio::net::TcpStream,
390 addr: SocketAddr,
391 shutdown: Option<Shutdown>,
392 selector: Arc<Selector>,
393 handler_guard: HandlerGuardState,
394 buffer: BytesMut,
395 connect_state: Mutex<ConnectState>,
396}
397
398impl LocalTcpStream {
399 fn new(selector: Arc<Selector>, stream: mio::net::TcpStream, addr: SocketAddr) -> Self {
400 #[allow(unused_mut)]
401 let mut ret = Self {
402 stream,
403 addr,
404 shutdown: None,
405 selector,
406 handler_guard: HandlerGuardState::None,
407 buffer: BytesMut::new(),
408 connect_state: Mutex::new(ConnectState::Unknown),
409 };
410
411 #[cfg(target_os = "windows")]
416 {
417 let (state, selector, socket, _) = ret.split_borrow();
418 if let Ok(map) = state_as_waker_map(state, selector, socket) {
419 map.push(InterestType::Writable);
420 }
421 }
422
423 ret
424 }
425
426 fn with_sock_ref<F, R>(&self, f: F) -> R
427 where
428 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
429 {
430 #[cfg(not(windows))]
431 let r = socket2::SockRef::from(&self.stream);
432
433 #[cfg(windows)]
434 let b = unsafe {
435 std::os::windows::io::BorrowedSocket::borrow_raw(self.stream.as_raw_socket())
436 };
437 #[cfg(windows)]
438 let r = socket2::SockRef::from(&b);
439
440 f(r)
441 }
442}
443
444impl VirtualTcpSocket for LocalTcpStream {
445 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
446 Ok(())
447 }
448
449 fn recv_buf_size(&self) -> Result<usize> {
450 Err(NetworkError::Unsupported)
451 }
452
453 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
454 Ok(())
455 }
456
457 fn send_buf_size(&self) -> Result<usize> {
458 Err(NetworkError::Unsupported)
459 }
460
461 fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
462 self.stream
463 .set_nodelay(nodelay)
464 .map_err(io_err_into_net_error)
465 }
466
467 fn nodelay(&self) -> Result<bool> {
468 self.stream.nodelay().map_err(io_err_into_net_error)
469 }
470
471 fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
472 self.with_sock_ref(|s| s.set_keepalive(true))
473 .map_err(io_err_into_net_error)?;
474 Ok(())
475 }
476
477 fn keepalive(&self) -> Result<bool> {
478 let ret = self
479 .with_sock_ref(|s| s.keepalive())
480 .map_err(io_err_into_net_error)?;
481 Ok(ret)
482 }
483
484 #[cfg(not(target_os = "windows"))]
485 fn set_dontroute(&mut self, val: bool) -> Result<()> {
486 let val = val as libc::c_int;
492 let payload = &val as *const libc::c_int as *const libc::c_void;
493 let err = unsafe {
494 libc::setsockopt(
495 self.stream.as_raw_fd(),
496 libc::SOL_SOCKET,
497 libc::SO_DONTROUTE,
498 payload,
499 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
500 )
501 };
502 if err == -1 {
503 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
504 }
505 Ok(())
506 }
507 #[cfg(target_os = "windows")]
508 fn set_dontroute(&mut self, val: bool) -> Result<()> {
509 Err(NetworkError::Unsupported)
510 }
511
512 #[cfg(not(target_os = "windows"))]
513 fn dontroute(&self) -> Result<bool> {
514 let mut payload: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
515 let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
516 let err = unsafe {
517 libc::getsockopt(
518 self.stream.as_raw_fd(),
519 libc::SOL_SOCKET,
520 libc::SO_DONTROUTE,
521 payload.as_mut_ptr().cast(),
522 &mut len,
523 )
524 };
525 if err == -1 {
526 return Err(io_err_into_net_error(std::io::Error::last_os_error()));
527 }
528 Ok(unsafe { payload.assume_init() != 0 })
529 }
530 #[cfg(target_os = "windows")]
531 fn dontroute(&self) -> Result<bool> {
532 Err(NetworkError::Unsupported)
533 }
534
535 fn addr_peer(&self) -> Result<SocketAddr> {
536 Ok(self.addr)
537 }
538
539 fn shutdown(&mut self, how: Shutdown) -> Result<()> {
540 self.stream.shutdown(how).map_err(io_err_into_net_error)?;
541 self.shutdown = Some(how);
542 Ok(())
543 }
544
545 fn is_closed(&self) -> bool {
546 false
547 }
548}
549
550impl VirtualConnectedSocket for LocalTcpStream {
551 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
552 self.with_sock_ref(|s| s.set_linger(linger))
553 .map_err(io_err_into_net_error)?;
554 Ok(())
555 }
556
557 fn linger(&self) -> Result<Option<Duration>> {
558 self.with_sock_ref(|s| s.linger())
559 .map_err(io_err_into_net_error)
560 }
561
562 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
563 let ret = self.stream.write(data).map_err(io_err_into_net_error);
564 match &ret {
565 Ok(0) | Err(NetworkError::WouldBlock) => {
566 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
567 map.pop(InterestType::Writable);
568 }
569 }
570 _ => {}
571 }
572 ret
573 }
574
575 fn try_flush(&mut self) -> Result<()> {
576 self.stream.flush().map_err(io_err_into_net_error)
577 }
578
579 fn close(&mut self) -> Result<()> {
580 Ok(())
581 }
582
583 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> Result<usize> {
584 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
585 if !self.buffer.is_empty() {
586 let amt = buf.len().min(self.buffer.len());
587 buf[..amt].copy_from_slice(&self.buffer[..amt]);
588 if !peek {
589 self.buffer.advance(amt);
590 }
591 return Ok(amt);
592 }
593
594 if peek {
595 self.stream.peek(buf)
596 } else {
597 self.stream.read(buf)
598 }
599 .map_err(io_err_into_net_error)
600 }
601}
602
603impl VirtualSocket for LocalTcpStream {
604 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
605 self.stream.set_ttl(ttl).map_err(io_err_into_net_error)
606 }
607
608 fn ttl(&self) -> Result<u32> {
609 self.stream.ttl().map_err(io_err_into_net_error)
610 }
611
612 fn addr_local(&self) -> Result<SocketAddr> {
613 self.stream.local_addr().map_err(io_err_into_net_error)
614 }
615
616 fn status(&self) -> Result<SocketStatus> {
617 let mut connect_state = self.connect_state.lock().unwrap();
620 match *connect_state {
621 ConnectState::Opened => return Ok(SocketStatus::Opened),
622 ConnectState::Failed => return Ok(SocketStatus::Failed),
623 ConnectState::Unknown => {}
624 }
625
626 if self
627 .with_sock_ref(|sockref| sockref.take_error())
628 .map_err(io_err_into_net_error)?
629 .is_some()
630 {
631 *connect_state = ConnectState::Failed;
632 return Ok(SocketStatus::Failed); }
634 match self.stream.peer_addr() {
635 Ok(_) => {
636 *connect_state = ConnectState::Opened;
637 Ok(SocketStatus::Opened) }
639 Err(err) => {
640 if matches!(
641 err.kind(),
642 io::ErrorKind::NotConnected | io::ErrorKind::WouldBlock
643 ) {
644 Ok(SocketStatus::Opening) } else {
646 *connect_state = ConnectState::Failed;
648 Ok(SocketStatus::Failed) }
650 }
651 }
652 }
653
654 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
655 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
656 match guard.replace_handler(handler) {
657 Ok(()) => return Ok(()),
658 Err(h) => handler = h,
659 }
660
661 if let Err(err) = guard.unregister(&mut self.stream) {
663 tracing::debug!("failed to unregister previous token - {}", err);
664 }
665 }
666
667 let guard = InterestGuard::new(
668 &self.selector,
669 handler,
670 &mut self.stream,
671 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
672 )
673 .map_err(io_err_into_net_error)?;
674
675 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
676
677 Ok(())
678 }
679}
680
681impl LocalTcpStream {
682 fn split_borrow(
683 &mut self,
684 ) -> (
685 &mut HandlerGuardState,
686 &Arc<Selector>,
687 &mut mio::net::TcpStream,
688 &mut BytesMut,
689 ) {
690 (
691 &mut self.handler_guard,
692 &self.selector,
693 &mut self.stream,
694 &mut self.buffer,
695 )
696 }
697}
698
699impl VirtualIoSource for LocalTcpStream {
700 fn remove_handler(&mut self) {
701 let mut guard = HandlerGuardState::None;
702 std::mem::swap(&mut guard, &mut self.handler_guard);
703 match guard {
704 HandlerGuardState::ExternalHandler(mut guard) => {
705 guard.unregister(&mut self.stream).ok();
706 }
707 HandlerGuardState::WakerMap(mut guard, _) => {
708 guard.unregister(&mut self.stream).ok();
709 }
710 HandlerGuardState::None => {}
711 }
712 }
713
714 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
715 if !self.buffer.is_empty() {
716 return Poll::Ready(Ok(self.buffer.len()));
717 }
718
719 let (state, selector, stream, buffer) = self.split_borrow();
720 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
721 map.pop(InterestType::Readable);
722 map.add(InterestType::Readable, cx.waker());
723
724 buffer.reserve(buffer.len() + 10240);
725 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
726 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
727
728 match stream.read(uninit_unsafe) {
729 Ok(0) => Poll::Ready(Ok(0)),
730 Ok(amt) => {
731 unsafe {
732 buffer.set_len(buffer.len() + amt);
733 }
734 Poll::Ready(Ok(amt))
735 }
736 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
737 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
738 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
739 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
740 }
741 }
742
743 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
744 let (state, selector, stream, _) = self.split_borrow();
745 let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?;
746 #[cfg(not(target_os = "windows"))]
747 map.pop(InterestType::Writable);
748 map.add(InterestType::Writable, cx.waker());
749 map.add(InterestType::Closed, cx.waker());
750 if map.has_interest(InterestType::Closed) {
751 return Poll::Ready(Ok(0));
752 }
753
754 #[cfg(not(target_os = "windows"))]
755 match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
756 Some(val) if (val & libc::POLLHUP) != 0 => {
757 return Poll::Ready(Ok(0));
758 }
759 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
760 _ => {}
761 }
762
763 #[cfg(target_os = "windows")]
768 if map.has_interest(InterestType::Writable) {
769 return Poll::Ready(Ok(10240));
770 }
771
772 Poll::Pending
773 }
774}
775
776#[cfg(not(target_os = "windows"))]
777fn libc_poll(fd: RawFd, events: libc::c_short) -> Option<libc::c_short> {
778 let mut fds: [libc::pollfd; 1] = [libc::pollfd {
779 fd,
780 events,
781 revents: 0,
782 }];
783 let fds_mut = &mut fds[..];
784 let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) };
785 match ret == 1 {
786 true => Some(fds[0].revents),
787 false => None,
788 }
789}
790
791#[derive(Debug)]
792pub struct LocalUdpSocket {
793 socket: mio::net::UdpSocket,
794 #[allow(dead_code)]
795 addr: SocketAddr,
796 selector: Arc<Selector>,
797 handler_guard: HandlerGuardState,
798 backlog: VecDeque<(BytesMut, SocketAddr)>,
799 ruleset: Option<Ruleset>,
800}
801
802impl LocalUdpSocket {
803 fn with_sock_ref<F, R>(&self, f: F) -> R
804 where
805 for<'a> F: FnOnce(socket2::SockRef<'a>) -> R,
806 {
807 #[cfg(not(windows))]
808 let r = socket2::SockRef::from(&self.socket);
809
810 #[cfg(windows)]
811 let b = unsafe {
812 std::os::windows::io::BorrowedSocket::borrow_raw(self.socket.as_raw_socket())
813 };
814 #[cfg(windows)]
815 let r = socket2::SockRef::from(&b);
816
817 f(r)
818 }
819}
820
821impl VirtualUdpSocket for LocalUdpSocket {
822 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
823 self.socket
824 .set_broadcast(broadcast)
825 .map_err(io_err_into_net_error)
826 }
827
828 fn broadcast(&self) -> Result<bool> {
829 self.socket.broadcast().map_err(io_err_into_net_error)
830 }
831
832 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
833 self.socket
834 .set_multicast_loop_v4(val)
835 .map_err(io_err_into_net_error)
836 }
837
838 fn multicast_loop_v4(&self) -> Result<bool> {
839 self.socket
840 .multicast_loop_v4()
841 .map_err(io_err_into_net_error)
842 }
843
844 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
845 self.socket
846 .set_multicast_loop_v6(val)
847 .map_err(io_err_into_net_error)
848 }
849
850 fn multicast_loop_v6(&self) -> Result<bool> {
851 self.socket
852 .multicast_loop_v6()
853 .map_err(io_err_into_net_error)
854 }
855
856 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
857 self.socket
858 .set_multicast_ttl_v4(ttl)
859 .map_err(io_err_into_net_error)
860 }
861
862 fn multicast_ttl_v4(&self) -> Result<u32> {
863 self.socket
864 .multicast_ttl_v4()
865 .map_err(io_err_into_net_error)
866 }
867
868 fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
869 self.with_sock_ref(|s| s.join_multicast_v4(&multiaddr, &iface))
870 .map_err(io_err_into_net_error)
871 }
872
873 fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
874 self.with_sock_ref(|s| s.leave_multicast_v4(&multiaddr, &iface))
875 .map_err(io_err_into_net_error)
876 }
877
878 fn join_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
879 self.socket
880 .join_multicast_v6(&multiaddr, iface)
881 .map_err(io_err_into_net_error)
882 }
883
884 fn leave_multicast_v6(&mut self, multiaddr: Ipv6Addr, iface: u32) -> Result<()> {
885 self.socket
886 .leave_multicast_v6(&multiaddr, iface)
887 .map_err(io_err_into_net_error)
888 }
889
890 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
891 self.socket
892 .peer_addr()
893 .map(Some)
894 .map_err(io_err_into_net_error)
895 }
896}
897
898impl VirtualConnectionlessSocket for LocalUdpSocket {
899 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
900 if let Some(ruleset) = self.ruleset.as_ref()
901 && !ruleset.allows_socket(addr, Direction::Outbound)
902 {
903 tracing::warn!(%addr, "try_send blocked by firewall rule");
904 return Err(NetworkError::PermissionDenied);
905 }
906
907 let ret = self
908 .socket
909 .send_to(data, addr)
910 .map_err(io_err_into_net_error);
911 match &ret {
912 Ok(0) | Err(NetworkError::WouldBlock) => {
913 if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard {
914 map.pop(InterestType::Writable);
915 }
916 }
917 _ => {}
918 }
919 ret
920 }
921
922 fn try_recv_from(
923 &mut self,
924 buf: &mut [MaybeUninit<u8>],
925 peek: bool,
926 ) -> Result<(usize, SocketAddr)> {
927 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
928 if peek {
929 self.socket.peek_from(buf)
930 } else {
931 self.socket.recv_from(buf)
932 }
933 .map_err(io_err_into_net_error)
934 }
935}
936
937impl VirtualSocket for LocalUdpSocket {
938 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
939 self.socket.set_ttl(ttl).map_err(io_err_into_net_error)
940 }
941
942 fn ttl(&self) -> Result<u32> {
943 self.socket.ttl().map_err(io_err_into_net_error)
944 }
945
946 fn addr_local(&self) -> Result<SocketAddr> {
947 self.socket.local_addr().map_err(io_err_into_net_error)
948 }
949
950 fn status(&self) -> Result<SocketStatus> {
951 Ok(SocketStatus::Opened)
952 }
953
954 fn set_handler(&mut self, mut handler: Box<dyn InterestHandler + Send + Sync>) -> Result<()> {
955 if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard {
956 match guard.replace_handler(handler) {
957 Ok(()) => {
958 return Ok(());
959 }
960 Err(h) => handler = h,
961 }
962
963 if let Err(err) = guard.unregister(&mut self.socket) {
965 tracing::debug!("failed to unregister previous token - {}", err);
966 }
967 }
968
969 let guard = InterestGuard::new(
970 &self.selector,
971 handler,
972 &mut self.socket,
973 mio::Interest::READABLE.add(mio::Interest::WRITABLE),
974 )
975 .map_err(io_err_into_net_error)?;
976
977 self.handler_guard = HandlerGuardState::ExternalHandler(guard);
978
979 Ok(())
980 }
981}
982
983impl LocalUdpSocket {
984 fn split_borrow(
985 &mut self,
986 ) -> (
987 &mut HandlerGuardState,
988 &Arc<Selector>,
989 &mut mio::net::UdpSocket,
990 ) {
991 (&mut self.handler_guard, &self.selector, &mut self.socket)
992 }
993}
994
995impl VirtualIoSource for LocalUdpSocket {
996 fn remove_handler(&mut self) {
997 let mut guard = HandlerGuardState::None;
998 std::mem::swap(&mut guard, &mut self.handler_guard);
999 match guard {
1000 HandlerGuardState::ExternalHandler(mut guard) => {
1001 guard.unregister(&mut self.socket).ok();
1002 }
1003 HandlerGuardState::WakerMap(mut guard, _) => {
1004 guard.unregister(&mut self.socket).ok();
1005 }
1006 HandlerGuardState::None => {}
1007 }
1008 }
1009
1010 fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1011 if !self.backlog.is_empty() {
1012 let total = self.backlog.iter().map(|a| a.0.len()).sum();
1013 return Poll::Ready(Ok(total));
1014 }
1015
1016 let (state, selector, socket) = self.split_borrow();
1017 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1018 map.pop(InterestType::Readable);
1019 map.add(InterestType::Readable, cx.waker());
1020
1021 let mut buffer = BytesMut::default();
1022 buffer.reserve(10240);
1023 let uninit: &mut [MaybeUninit<u8>] = buffer.spare_capacity_mut();
1024 let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) };
1025
1026 match self.socket.recv_from(uninit_unsafe) {
1027 Ok((0, _)) => Poll::Ready(Ok(0)),
1028 Ok((amt, peer)) => {
1029 unsafe {
1030 buffer.set_len(amt);
1031 }
1032 self.backlog.push_back((buffer, peer));
1033 Poll::Ready(Ok(amt))
1034 }
1035 Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)),
1036 Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)),
1037 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
1038 Err(err) => Poll::Ready(Err(io_err_into_net_error(err))),
1039 }
1040 }
1041
1042 fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<usize>> {
1043 let (state, selector, socket) = self.split_borrow();
1044 let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?;
1045 #[cfg(not(target_os = "windows"))]
1046 map.pop(InterestType::Writable);
1047 map.add(InterestType::Writable, cx.waker());
1048
1049 #[cfg(not(target_os = "windows"))]
1050 match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) {
1051 Some(val) if (val & libc::POLLHUP) != 0 => {
1052 return Poll::Ready(Ok(0));
1053 }
1054 Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)),
1055 _ => {}
1056 }
1057
1058 #[cfg(target_os = "windows")]
1063 if map.has_interest(InterestType::Writable) {
1064 return Poll::Ready(Ok(10240));
1065 }
1066
1067 Poll::Pending
1068 }
1069}