1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::net::IpAddr;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::sync::atomic::AtomicU64;
10use std::sync::atomic::Ordering;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16use std::time::Duration;
17
18use bytes::Buf;
19use bytes::BytesMut;
20use futures_util::Sink;
21use futures_util::Stream;
22use futures_util::StreamExt;
23use futures_util::future::BoxFuture;
24use futures_util::stream::FuturesOrdered;
25#[cfg(feature = "hyper")]
26use hyper_util::rt::tokio::TokioIo;
27use tokio::io::AsyncRead;
28use tokio::io::AsyncWrite;
29use tokio::sync::mpsc;
30use tokio::sync::mpsc::error::TryRecvError;
31use tokio::sync::mpsc::error::TrySendError;
32use tokio_serde::SymmetricallyFramed;
33use tokio_serde::formats::SymmetricalBincode;
34#[cfg(feature = "cbor")]
35use tokio_serde::formats::SymmetricalCbor;
36#[cfg(feature = "json")]
37use tokio_serde::formats::SymmetricalJson;
38#[cfg(feature = "messagepack")]
39use tokio_serde::formats::SymmetricalMessagePack;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use tokio_util::codec::LengthDelimitedCodec;
43use virtual_mio::InterestType;
44use virtual_mio::block_on;
45
46use crate::IpCidr;
47use crate::IpRoute;
48use crate::NetworkError;
49use crate::StreamSecurity;
50use crate::VirtualConnectedSocket;
51use crate::VirtualConnectionlessSocket;
52use crate::VirtualIcmpSocket;
53use crate::VirtualIoSource;
54use crate::VirtualNetworking;
55use crate::VirtualRawSocket;
56use crate::VirtualSocket;
57use crate::VirtualTcpBoundSocket;
58use crate::VirtualTcpListener;
59use crate::VirtualTcpSocket;
60use crate::VirtualUdpSocket;
61use crate::meta;
62use crate::meta::FrameSerializationFormat;
63use crate::meta::RequestType;
64use crate::meta::ResponseType;
65use crate::meta::SocketId;
66use crate::meta::{MessageRequest, MessageResponse};
67
68use crate::Result;
69use crate::rx_tx::RemoteRx;
70use crate::rx_tx::RemoteTx;
71use crate::rx_tx::RemoteTxWakers;
72
73#[derive(Debug, Clone)]
74pub struct RemoteNetworkingClient {
75 common: Arc<RemoteCommon>,
76}
77
78impl RemoteNetworkingClient {
79 fn new(
80 tx: RemoteTx<MessageRequest>,
81 rx: RemoteRx<MessageResponse>,
82 rx_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
83 ) -> (Self, RemoteNetworkingClientDriver) {
84 let common = RemoteCommon {
85 tx,
86 rx: Mutex::new(rx),
87 request_seed: AtomicU64::new(1),
88 requests: Default::default(),
89 socket_seed: AtomicU64::new(1),
90 recv_tx: Default::default(),
91 recv_with_addr_tx: Default::default(),
92 accept_tx: Default::default(),
93 sent_tx: Default::default(),
94 handlers: Default::default(),
95 stall: Default::default(),
96 };
97 let common = Arc::new(common);
98
99 let driver = RemoteNetworkingClientDriver {
100 more_work: rx_work,
101 tasks: Default::default(),
102 common: common.clone(),
103 };
104 let networking = Self { common };
105
106 (networking, driver)
107 }
108
109 pub fn new_from_mpsc(
112 tx: mpsc::Sender<MessageRequest>,
113 rx: mpsc::Receiver<MessageResponse>,
114 ) -> (Self, RemoteNetworkingClientDriver) {
115 let (tx_work, rx_work) = mpsc::unbounded_channel();
116 let tx_wakers = RemoteTxWakers::default();
117
118 let tx = RemoteTx::Mpsc {
119 tx,
120 work: tx_work,
121 wakers: tx_wakers.clone(),
122 };
123 let rx = RemoteRx::Mpsc {
124 rx,
125 wakers: tx_wakers,
126 };
127
128 Self::new(tx, rx, rx_work)
129 }
130
131 pub fn new_from_async_io<TX, RX>(
137 tx: TX,
138 rx: RX,
139 format: FrameSerializationFormat,
140 ) -> (Self, RemoteNetworkingClientDriver)
141 where
142 TX: AsyncWrite + Send + 'static,
143 RX: AsyncRead + Send + 'static,
144 {
145 let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
146 let tx: Pin<Box<dyn Sink<MessageRequest, Error = std::io::Error> + Send + 'static>> =
147 match format {
148 FrameSerializationFormat::Bincode => {
149 Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
150 }
151 #[cfg(feature = "json")]
152 FrameSerializationFormat::Json => {
153 Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
154 }
155 #[cfg(feature = "messagepack")]
156 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
157 tx,
158 SymmetricalMessagePack::default(),
159 )),
160 #[cfg(feature = "cbor")]
161 FrameSerializationFormat::Cbor => {
162 Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
163 }
164 };
165
166 let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
167 let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageResponse>> + Send + 'static>> =
168 match format {
169 FrameSerializationFormat::Bincode => {
170 Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
171 }
172 #[cfg(feature = "json")]
173 FrameSerializationFormat::Json => {
174 Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
175 }
176 #[cfg(feature = "messagepack")]
177 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
178 rx,
179 SymmetricalMessagePack::default(),
180 )),
181 #[cfg(feature = "cbor")]
182 FrameSerializationFormat::Cbor => {
183 Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
184 }
185 };
186
187 let (tx_work, rx_work) = mpsc::unbounded_channel();
188 let tx_wakers = RemoteTxWakers::default();
189
190 let tx = RemoteTx::Stream {
191 tx: Arc::new(tokio::sync::Mutex::new(tx)),
192 work: tx_work,
193 wakers: tx_wakers,
194 };
195 let rx = RemoteRx::Stream { rx };
196
197 Self::new(tx, rx, rx_work)
198 }
199
200 #[cfg(feature = "hyper")]
203 pub fn new_from_hyper_ws_io(
204 tx: futures_util::stream::SplitSink<
205 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
206 hyper_tungstenite::tungstenite::Message,
207 >,
208 rx: futures_util::stream::SplitStream<
209 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
210 >,
211 format: FrameSerializationFormat,
212 ) -> (Self, RemoteNetworkingClientDriver) {
213 let (tx_work, rx_work) = mpsc::unbounded_channel();
214
215 let tx = RemoteTx::HyperWebSocket {
216 tx: Arc::new(tokio::sync::Mutex::new(tx)),
217 work: tx_work,
218 wakers: RemoteTxWakers::default(),
219 format,
220 };
221 let rx = RemoteRx::HyperWebSocket { rx, format };
222 Self::new(tx, rx, rx_work)
223 }
224
225 #[cfg(feature = "tokio-tungstenite")]
228 pub fn new_from_tokio_ws_io(
229 tx: futures_util::stream::SplitSink<
230 tokio_tungstenite::WebSocketStream<
231 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
232 >,
233 tokio_tungstenite::tungstenite::Message,
234 >,
235 rx: futures_util::stream::SplitStream<
236 tokio_tungstenite::WebSocketStream<
237 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
238 >,
239 >,
240 format: FrameSerializationFormat,
241 ) -> (Self, RemoteNetworkingClientDriver) {
242 let (tx_work, rx_work) = mpsc::unbounded_channel();
243
244 let tx = RemoteTx::TokioWebSocket {
245 tx: Arc::new(tokio::sync::Mutex::new(tx)),
246 work: tx_work,
247 wakers: RemoteTxWakers::default(),
248 format,
249 };
250 let rx = RemoteRx::TokioWebSocket { rx, format };
251 Self::new(tx, rx, rx_work)
252 }
253
254 fn new_socket(&self, id: SocketId) -> RemoteSocket {
255 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
256 self.common.recv_tx.lock().unwrap().insert(id, tx);
257
258 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
259 self.common.recv_with_addr_tx.lock().unwrap().insert(id, tx);
260
261 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
262 self.common.accept_tx.lock().unwrap().insert(id, tx);
263
264 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
265 self.common.sent_tx.lock().unwrap().insert(id, tx);
266
267 RemoteSocket {
268 socket_id: id,
269 common: self.common.clone(),
270 rx_buffer: BytesMut::new(),
271 rx_recv,
272 rx_recv_with_addr,
273 rx_accept,
274 rx_sent,
275 tx_waker: TxWaker::new(&self.common).as_waker(),
276 pending_accept: None,
277 buffer_accept: Default::default(),
278 buffer_recv_with_addr: Default::default(),
279 send_available: 0,
280 owns_socket_bindings: true,
281 }
282 }
283}
284
pin_project_lite::pin_project! {
    /// Future that keeps a [`RemoteNetworkingClient`] running.
    ///
    /// It executes background work queued by the client's transmit side and
    /// routes each incoming [`MessageResponse`] to the socket or pending
    /// request it belongs to. It completes once the response stream ends.
    pub struct RemoteNetworkingClientDriver {
        // State shared with the client and with every socket it creates.
        common: Arc<RemoteCommon>,
        // Receiving end of the work channel whose sender lives in `RemoteTx`.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // Background tasks, polled in submission order.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
    }
}
293
impl Future for RemoteNetworkingClientDriver {
    type Output = ();

    /// Runs queued background tasks and dispatches incoming responses.
    ///
    /// Each loop iteration does three things:
    /// 1. Moves newly queued work into `tasks`.
    /// 2. Polls `tasks`, holding the `stall` lock while tasks are pending.
    /// 3. Reads one message from the transport and routes it.
    ///
    /// Returns `Ready(())` once the response stream has ended.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Holds `common.stall` while background tasks are in flight. It is
        // released once the task queue drains.
        // NOTE(review): other users of `stall` are outside this view —
        // presumably it is used for back-pressure; confirm.
        let mut not_stalled_guard = None;

        loop {
            // Pull every newly queued piece of work into the ordered task set.
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            match self.tasks.poll_next_unpin(cx) {
                // A task finished; loop again to make further progress.
                Poll::Ready(Some(_)) => continue,
                // No tasks left, so the stall lock is released.
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                // Tasks are pending and we do not hold the stall lock yet.
                // Try to take it; if it is held elsewhere, yield.
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // Poll the transport for the next message. The std mutex guard is
            // scoped so it is released before the message is handled.
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            return match msg {
                Poll::Ready(Some(msg)) => {
                    match msg {
                        // Stream data: forward it to the socket's recv channel
                        // from a background task (the channel is bounded), then
                        // signal Readable. Unknown sockets are skipped.
                        MessageResponse::Recv { socket_id, data } => {
                            let tx = {
                                let guard = self.common.recv_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => {
                                        continue;
                                    }
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(data).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        // Addressed datagram: same pattern as `Recv`, using the
                        // `recv_with_addr_tx` table.
                        MessageResponse::RecvWithAddr {
                            socket_id,
                            data,
                            addr,
                        } => {
                            let tx = {
                                let guard = self.common.recv_with_addr_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(DataWithAddr { data, addr }).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        // Send acknowledgement: the amount is delivered by a
                        // background task, but Writable is signalled right away.
                        // Readers of `rx_sent` are woken by their own waker when
                        // the amount arrives.
                        MessageResponse::Sent {
                            socket_id, amount, ..
                        } => {
                            let tx = {
                                let guard = self.common.sent_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(amount).await.ok();
                            }));
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Writable)
                            }
                        }
                        // Connection-level send failures close the socket; any
                        // other error just re-signals Writable.
                        MessageResponse::SendError {
                            socket_id, error, ..
                        } => match &error {
                            NetworkError::ConnectionAborted
                            | NetworkError::ConnectionReset
                            | NetworkError::BrokenPipe => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Closed)
                                }
                            }
                            _ => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Writable)
                                }
                            }
                        },
                        // An accepted child: hand it to the listener's accept
                        // channel, then signal Readable on the listener.
                        MessageResponse::FinishAccept {
                            socket_id,
                            child_id,
                            addr,
                        } => {
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                let tx = common.accept_tx.lock().unwrap().get(&socket_id).cloned();
                                if let Some(tx) = tx {
                                    tx.send(SocketWithAddr {
                                        socket: child_id,
                                        addr,
                                    })
                                    .await
                                    .ok();
                                }

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Closed { socket_id } => {
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Closed)
                            }
                        }
                        // Complete the matching in-flight request. Delivery
                        // failures (for example, the requester is gone) are
                        // ignored.
                        MessageResponse::ResponseToRequest { req_id, res } => {
                            let mut requests = self.common.requests.lock().unwrap();
                            if let Some(request) = requests.remove(&req_id) {
                                request.try_send(res).ok();
                            }
                        }
                    }
                    // Message handled; go round again for more work or messages.
                    continue;
                }
                Poll::Ready(None) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
465
466#[derive(Debug)]
467struct TxWaker {
468 common: Arc<RemoteCommon>,
469}
470impl TxWaker {
471 pub fn new(common: &Arc<RemoteCommon>) -> Arc<Self> {
472 Arc::new(Self {
473 common: common.clone(),
474 })
475 }
476
477 fn wake_now(&self) {
478 let mut guard = self.common.handlers.lock().unwrap();
479 for handler in guard.values_mut() {
480 handler.push_interest(InterestType::Writable);
481 }
482 }
483
484 pub fn as_waker(self: &Arc<Self>) -> Waker {
485 let s: *const Self = Arc::into_raw(Arc::clone(self));
486 let raw_waker = RawWaker::new(s as *const (), &VTABLE);
487 unsafe { Waker::from_raw(raw_waker) }
488 }
489}
490
491fn tx_waker_wake(s: &TxWaker) {
492 let waker_arc = unsafe { Arc::from_raw(s) };
493 waker_arc.wake_now();
494}
495
496fn tx_waker_clone(s: &TxWaker) -> RawWaker {
497 let arc = unsafe { Arc::from_raw(s) };
498 std::mem::forget(arc.clone());
499 RawWaker::new(Arc::into_raw(arc) as *const (), &VTABLE)
500}
501
/// Raw waker vtable for [`TxWaker`].
///
/// The data pointer is always an `Arc<TxWaker>` turned into a raw pointer
/// with `Arc::into_raw`.
const VTABLE: RawWakerVTable = unsafe {
    RawWakerVTable::new(
        // clone: add a strong reference and return a new raw waker.
        |s| tx_waker_clone(&*(s as *const TxWaker)),
        // wake: consume this waker's reference after waking.
        |s| tx_waker_wake(&*(s as *const TxWaker)),
        // wake_by_ref: wake without changing the reference count.
        |s| (*(s as *const TxWaker)).wake_now(),
        // drop: release this waker's reference.
        |s| drop(Arc::from_raw(s as *const TxWaker)),
    )
};
510
511#[derive(Debug)]
512struct RequestTx {
513 tx: mpsc::Sender<ResponseType>,
514}
515impl RequestTx {
516 pub fn try_send(self, msg: ResponseType) -> Result<()> {
517 match self.tx.try_send(msg) {
518 Ok(()) => Ok(()),
519 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
520 Err(TrySendError::Full(_)) => Err(NetworkError::WouldBlock),
521 }
522 }
523}
524
/// A received payload together with the address delivered alongside it.
#[derive(Debug)]
struct DataWithAddr {
    /// Payload bytes.
    pub data: Vec<u8>,
    /// Address carried by the `RecvWithAddr` response.
    pub addr: SocketAddr,
}
530#[derive(Debug)]
531struct SocketWithAddr {
532 pub socket: SocketId,
533 pub addr: SocketAddr,
534}
535type SocketMap<T> = HashMap<SocketId, T>;
536
537#[derive(derive_more::Debug)]
538struct RemoteCommon {
539 #[debug(ignore)]
540 tx: RemoteTx<MessageRequest>,
541 #[debug(ignore)]
542 rx: Mutex<RemoteRx<MessageResponse>>,
543 request_seed: AtomicU64,
544 requests: Mutex<HashMap<u64, RequestTx>>,
545 socket_seed: AtomicU64,
546 recv_tx: Mutex<SocketMap<mpsc::Sender<Vec<u8>>>>,
547 recv_with_addr_tx: Mutex<SocketMap<mpsc::Sender<DataWithAddr>>>,
548 accept_tx: Mutex<SocketMap<mpsc::Sender<SocketWithAddr>>>,
549 sent_tx: Mutex<SocketMap<mpsc::Sender<u64>>>,
550 #[debug(ignore)]
551 handlers: Mutex<SocketMap<Box<dyn virtual_mio::InterestHandler + Send + Sync>>>,
552
553 stall: Arc<tokio::sync::Mutex<()>>,
556}
557
558impl RemoteCommon {
559 async fn io_iface(&self, req: RequestType) -> ResponseType {
560 let req_id = self.request_seed.fetch_add(1, Ordering::SeqCst);
561 let mut req_rx = {
562 let (tx, rx) = mpsc::channel(1);
563 let mut guard = self.requests.lock().unwrap();
564 guard.insert(req_id, RequestTx { tx });
565 rx
566 };
567 if let Err(err) = self
568 .tx
569 .send(MessageRequest::Interface {
570 req_id: Some(req_id),
571 req,
572 })
573 .await
574 {
575 return ResponseType::Err(err);
576 };
577 req_rx.recv().await.unwrap()
578 }
579
580 fn io_iface_fire_and_forget(&self, req: RequestType) -> Result<()> {
581 self.tx
582 .send_with_driver(MessageRequest::Interface { req_id: None, req })
583 }
584}
585
586#[async_trait::async_trait]
587impl VirtualNetworking for RemoteNetworkingClient {
588 async fn bridge(
589 &self,
590 network: &str,
591 access_token: &str,
592 security: StreamSecurity,
593 ) -> Result<()> {
594 match self
595 .common
596 .io_iface(RequestType::Bridge {
597 network: network.to_string(),
598 access_token: access_token.to_string(),
599 security,
600 })
601 .await
602 {
603 ResponseType::Err(err) => Err(err),
604 ResponseType::None => Ok(()),
605 res => {
606 tracing::debug!("invalid response to bridge request - {res:?}");
607 Err(NetworkError::IOError)
608 }
609 }
610 }
611
612 async fn unbridge(&self) -> Result<()> {
613 match self.common.io_iface(RequestType::Unbridge).await {
614 ResponseType::Err(err) => Err(err),
615 ResponseType::None => Ok(()),
616 res => {
617 tracing::debug!("invalid response to unbridge request - {res:?}");
618 Err(NetworkError::IOError)
619 }
620 }
621 }
622
623 async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>> {
624 match self.common.io_iface(RequestType::DhcpAcquire).await {
625 ResponseType::Err(err) => Err(err),
626 ResponseType::IpAddressList(ips) => Ok(ips),
627 res => {
628 tracing::debug!("invalid response to DHCP acquire request - {res:?}");
629 Err(NetworkError::IOError)
630 }
631 }
632 }
633
634 async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<()> {
635 self.common
636 .io_iface_fire_and_forget(RequestType::IpAdd { ip, prefix })
637 }
638
639 async fn ip_remove(&self, ip: IpAddr) -> Result<()> {
640 self.common
641 .io_iface_fire_and_forget(RequestType::IpRemove(ip))
642 }
643
644 async fn ip_clear(&self) -> Result<()> {
645 self.common.io_iface_fire_and_forget(RequestType::IpClear)
646 }
647
648 async fn ip_list(&self) -> Result<Vec<IpCidr>> {
649 match self.common.io_iface(RequestType::GetIpList).await {
650 ResponseType::Err(err) => Err(err),
651 ResponseType::CidrList(routes) => Ok(routes),
652 res => {
653 tracing::debug!("invalid response to IP list request - {res:?}");
654 Err(NetworkError::IOError)
655 }
656 }
657 }
658
659 async fn mac(&self) -> Result<[u8; 6]> {
660 match self.common.io_iface(RequestType::GetMac).await {
661 ResponseType::Err(err) => Err(err),
662 ResponseType::Mac(mac) => Ok(mac),
663 res => {
664 tracing::debug!("invalid response to MAC request - {res:?}");
665 Err(NetworkError::IOError)
666 }
667 }
668 }
669
670 async fn gateway_set(&self, ip: IpAddr) -> Result<()> {
671 self.common
672 .io_iface_fire_and_forget(RequestType::GatewaySet(ip))
673 }
674
675 async fn route_add(
676 &self,
677 cidr: IpCidr,
678 via_router: IpAddr,
679 preferred_until: Option<Duration>,
680 expires_at: Option<Duration>,
681 ) -> Result<()> {
682 self.common.io_iface_fire_and_forget(RequestType::RouteAdd {
683 cidr,
684 via_router,
685 preferred_until,
686 expires_at,
687 })
688 }
689
690 async fn route_remove(&self, cidr: IpAddr) -> Result<()> {
691 self.common
692 .io_iface_fire_and_forget(RequestType::RouteRemove(cidr))
693 }
694
695 async fn route_clear(&self) -> Result<()> {
696 self.common
697 .io_iface_fire_and_forget(RequestType::RouteClear)
698 }
699
700 async fn route_list(&self) -> Result<Vec<IpRoute>> {
701 match self.common.io_iface(RequestType::GetRouteList).await {
702 ResponseType::Err(err) => Err(err),
703 ResponseType::RouteList(routes) => Ok(routes),
704 res => {
705 tracing::debug!("invalid response to route list request - {res:?}");
706 Err(NetworkError::IOError)
707 }
708 }
709 }
710
711 async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>> {
712 let socket_id: SocketId = self
713 .common
714 .socket_seed
715 .fetch_add(1, Ordering::SeqCst)
716 .into();
717 match self.common.io_iface(RequestType::BindRaw(socket_id)).await {
718 ResponseType::Err(err) => Err(err),
719 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
720 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
721 res => {
722 tracing::debug!("invalid response to bind RAw request - {res:?}");
723 Err(NetworkError::IOError)
724 }
725 }
726 }
727
728 async fn listen_tcp(
729 &self,
730 addr: SocketAddr,
731 only_v6: bool,
732 reuse_port: bool,
733 reuse_addr: bool,
734 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
735 let socket_id: SocketId = self
736 .common
737 .socket_seed
738 .fetch_add(1, Ordering::SeqCst)
739 .into();
740 match self
741 .common
742 .io_iface(RequestType::ListenTcp {
743 socket_id,
744 addr,
745 only_v6,
746 reuse_port,
747 reuse_addr,
748 })
749 .await
750 {
751 ResponseType::Err(err) => Err(err),
752 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
753 ResponseType::Socket(socket_id) => {
754 let mut socket = self.new_socket(socket_id);
755 socket.touch_begin_accept().ok();
756 Ok(Box::new(socket))
757 }
758 res => {
759 tracing::debug!("invalid response to listen TCP request - {res:?}");
760 Err(NetworkError::IOError)
761 }
762 }
763 }
764
765 async fn bind_tcp(
766 &self,
767 addr: SocketAddr,
768 only_v6: bool,
769 reuse_port: bool,
770 reuse_addr: bool,
771 ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
772 let socket_id: SocketId = self
773 .common
774 .socket_seed
775 .fetch_add(1, Ordering::SeqCst)
776 .into();
777 match self
778 .common
779 .io_iface(RequestType::BindTcp {
780 socket_id,
781 addr,
782 only_v6,
783 reuse_port,
784 reuse_addr,
785 })
786 .await
787 {
788 ResponseType::Err(err) => Err(err),
789 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
790 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
791 res => {
792 tracing::debug!("invalid response to bind TCP request - {res:?}");
793 Err(NetworkError::IOError)
794 }
795 }
796 }
797
798 async fn bind_udp(
799 &self,
800 addr: SocketAddr,
801 reuse_port: bool,
802 reuse_addr: bool,
803 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
804 let socket_id: SocketId = self
805 .common
806 .socket_seed
807 .fetch_add(1, Ordering::SeqCst)
808 .into();
809 match self
810 .common
811 .io_iface(RequestType::BindUdp {
812 socket_id,
813 addr,
814 reuse_port,
815 reuse_addr,
816 })
817 .await
818 {
819 ResponseType::Err(err) => Err(err),
820 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
821 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
822 res => {
823 tracing::debug!("invalid response to bind UDP request - {res:?}");
824 Err(NetworkError::IOError)
825 }
826 }
827 }
828
829 async fn bind_icmp(&self, addr: IpAddr) -> Result<Box<dyn VirtualIcmpSocket + Sync>> {
830 let socket_id: SocketId = self
831 .common
832 .socket_seed
833 .fetch_add(1, Ordering::SeqCst)
834 .into();
835 match self
836 .common
837 .io_iface(RequestType::BindIcmp { socket_id, addr })
838 .await
839 {
840 ResponseType::Err(err) => Err(err),
841 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
842 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
843 res => {
844 tracing::debug!("invalid response to bind ICMP request - {res:?}");
845 Err(NetworkError::IOError)
846 }
847 }
848 }
849
850 async fn connect_tcp(
851 &self,
852 addr: SocketAddr,
853 peer: SocketAddr,
854 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
855 let socket_id: SocketId = self
856 .common
857 .socket_seed
858 .fetch_add(1, Ordering::SeqCst)
859 .into();
860 match self
861 .common
862 .io_iface(RequestType::ConnectTcp {
863 socket_id,
864 addr,
865 peer,
866 })
867 .await
868 {
869 ResponseType::Err(err) => Err(err),
870 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
871 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
872 res => {
873 tracing::debug!("invalid response to connect TCP request - {res:?}");
874 Err(NetworkError::IOError)
875 }
876 }
877 }
878
879 async fn resolve(
880 &self,
881 host: &str,
882 port: Option<u16>,
883 dns_server: Option<IpAddr>,
884 ) -> Result<Vec<IpAddr>> {
885 match self
886 .common
887 .io_iface(RequestType::Resolve {
888 host: host.to_string(),
889 port,
890 dns_server,
891 })
892 .await
893 {
894 ResponseType::Err(err) => Err(err),
895 ResponseType::IpAddressList(ips) => Ok(ips),
896 res => {
897 tracing::debug!("invalid response to resolve request - {res:?}");
898 Err(NetworkError::IOError)
899 }
900 }
901 }
902}
903
904#[derive(Debug)]
905struct RemoteSocket {
906 socket_id: SocketId,
907 common: Arc<RemoteCommon>,
908 rx_buffer: BytesMut,
909 rx_recv: mpsc::Receiver<Vec<u8>>,
910 rx_recv_with_addr: mpsc::Receiver<DataWithAddr>,
911 tx_waker: Waker,
912 rx_accept: mpsc::Receiver<SocketWithAddr>,
913 rx_sent: mpsc::Receiver<u64>,
914 pending_accept: Option<(SocketId, mpsc::Receiver<Vec<u8>>)>,
915 buffer_recv_with_addr: VecDeque<DataWithAddr>,
916 buffer_accept: VecDeque<SocketWithAddr>,
917 send_available: u64,
918 owns_socket_bindings: bool,
919}
920impl Drop for RemoteSocket {
921 fn drop(&mut self) {
922 if !self.owns_socket_bindings {
923 return;
924 }
925 let _ = self.io_socket_fire_and_forget(RequestType::Close);
926 self.release_socket_bindings();
927 }
928}
929
930impl RemoteSocket {
931 fn release_socket_bindings(&mut self) {
932 self.owns_socket_bindings = false;
933 self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
934 self.common
935 .recv_with_addr_tx
936 .lock()
937 .unwrap()
938 .remove(&self.socket_id);
939 self.common
940 .accept_tx
941 .lock()
942 .unwrap()
943 .remove(&self.socket_id);
944 self.common.sent_tx.lock().unwrap().remove(&self.socket_id);
945 self.common.handlers.lock().unwrap().remove(&self.socket_id);
946
947 if let Some((child_id, _)) = self.pending_accept.take() {
948 self.common.recv_tx.lock().unwrap().remove(&child_id);
949 }
950 }
951
952 async fn io_socket(&self, req: RequestType) -> ResponseType {
953 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
954 let mut req_rx = {
955 let (tx, rx) = mpsc::channel(1);
956 let mut guard = self.common.requests.lock().unwrap();
957 guard.insert(req_id, RequestTx { tx });
958 rx
959 };
960 if let Err(err) = self
961 .common
962 .tx
963 .send(MessageRequest::Socket {
964 socket: self.socket_id,
965 req_id: Some(req_id),
966 req,
967 })
968 .await
969 {
970 return ResponseType::Err(err);
971 };
972 req_rx.recv().await.unwrap()
973 }
974
975 fn io_socket_fire_and_forget(&self, req: RequestType) -> Result<()> {
976 self.common.tx.send_with_driver(MessageRequest::Socket {
977 socket: self.socket_id,
978 req_id: None,
979 req,
980 })
981 }
982
983 fn touch_begin_accept(&mut self) -> Result<()> {
984 if self.pending_accept.is_some() {
985 return Ok(());
986 }
987 let child_id: SocketId = self
988 .common
989 .socket_seed
990 .fetch_add(1, Ordering::SeqCst)
991 .into();
992 self.io_socket_fire_and_forget(RequestType::BeginAccept(child_id))?;
993
994 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
995 self.common.recv_tx.lock().unwrap().insert(child_id, tx);
996
997 self.pending_accept.replace((child_id, rx_recv));
998 Ok(())
999 }
1000
1001 fn transition_socket(&mut self) -> RemoteSocket {
1002 let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1);
1003 let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1);
1004 let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1);
1005 let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1);
1006
1007 self.owns_socket_bindings = false;
1008
1009 RemoteSocket {
1010 socket_id: self.socket_id,
1011 common: self.common.clone(),
1012 rx_buffer: std::mem::take(&mut self.rx_buffer),
1013 rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv),
1014 rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr),
1015 tx_waker: self.tx_waker.clone(),
1016 rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept),
1017 rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent),
1018 pending_accept: self.pending_accept.take(),
1019 buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr),
1020 buffer_accept: std::mem::take(&mut self.buffer_accept),
1021 send_available: self.send_available,
1022 owns_socket_bindings: true,
1023 }
1024 }
1025}
1026
1027impl VirtualIoSource for RemoteSocket {
1028 fn remove_handler(&mut self) {
1029 self.common.handlers.lock().unwrap().remove(&self.socket_id);
1030 }
1031
1032 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1033 if !self.rx_buffer.is_empty() {
1034 return Poll::Ready(Ok(self.rx_buffer.len()));
1035 }
1036 match self.rx_recv.poll_recv(cx) {
1037 Poll::Ready(Some(data)) => {
1038 self.rx_buffer.extend_from_slice(&data);
1039 return Poll::Ready(Ok(self.rx_buffer.len()));
1040 }
1041 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1042 Poll::Pending => {}
1043 }
1044 if !self.buffer_recv_with_addr.is_empty() {
1045 let total = self
1046 .buffer_recv_with_addr
1047 .iter()
1048 .map(|a| a.data.len())
1049 .sum();
1050 return Poll::Ready(Ok(total));
1051 }
1052 match self.rx_recv_with_addr.poll_recv(cx) {
1053 Poll::Ready(Some(data)) => self.buffer_recv_with_addr.push_back(data),
1054 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1055 Poll::Pending => {}
1056 }
1057 if !self.buffer_accept.is_empty() {
1058 return Poll::Ready(Ok(self.buffer_accept.len()));
1059 }
1060 match self.rx_accept.poll_recv(cx) {
1061 Poll::Ready(Some(data)) => self.buffer_accept.push_back(data),
1062 Poll::Ready(None) => {}
1063 Poll::Pending => {}
1064 }
1065 Poll::Pending
1066 }
1067
1068 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1069 if self.send_available > 0 {
1070 return Poll::Ready(Ok(self.send_available as usize));
1071 }
1072 match self.rx_sent.poll_recv(cx) {
1073 Poll::Ready(Some(amt)) => {
1074 self.send_available += amt;
1075 return Poll::Ready(Ok(self.send_available as usize));
1076 }
1077 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1078 Poll::Pending => {}
1079 }
1080 Poll::Pending
1081 }
1082}
1083
1084impl VirtualSocket for RemoteSocket {
1085 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1086 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl))
1087 }
1088
1089 fn ttl(&self) -> Result<u32> {
1090 match block_on(self.io_socket(RequestType::GetTtl)) {
1091 ResponseType::Err(err) => Err(err),
1092 ResponseType::Ttl(ttl) => Ok(ttl),
1093 res => {
1094 tracing::debug!("invalid response to get TTL request - {res:?}");
1095 Err(NetworkError::IOError)
1096 }
1097 }
1098 }
1099
1100 fn addr_local(&self) -> Result<SocketAddr> {
1101 match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1102 ResponseType::Err(err) => Err(err),
1103 ResponseType::SocketAddr(addr) => Ok(addr),
1104 res => {
1105 tracing::debug!("invalid response to address local request - {res:?}");
1106 Err(NetworkError::IOError)
1107 }
1108 }
1109 }
1110
1111 fn status(&self) -> Result<crate::SocketStatus> {
1112 match block_on(self.io_socket(RequestType::GetStatus)) {
1113 ResponseType::Err(err) => Err(err),
1114 ResponseType::Status(status) => Ok(status),
1115 res => {
1116 tracing::debug!("invalid response to status request - {res:?}");
1117 Err(NetworkError::IOError)
1118 }
1119 }
1120 }
1121
1122 fn set_handler(
1123 &mut self,
1124 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1125 ) -> Result<()> {
1126 self.common
1127 .handlers
1128 .lock()
1129 .unwrap()
1130 .insert(self.socket_id, handler);
1131 Ok(())
1132 }
1133}
1134
1135impl VirtualTcpListener for RemoteSocket {
1136 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
1137 self.touch_begin_accept()?;
1139 let accepted = if let Some(child) = self.buffer_accept.pop_front() {
1140 child
1141 } else {
1142 self.rx_accept.try_recv().map_err(|err| match err {
1143 TryRecvError::Empty => NetworkError::WouldBlock,
1144 TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1145 })?
1146 };
1147
1148 let mut rx_recv = None;
1152 if let Some((rx_socket, existing_rx_recv)) = self.pending_accept.take()
1153 && accepted.socket == rx_socket
1154 {
1155 rx_recv.replace(existing_rx_recv);
1156 }
1157 let rx_recv = match rx_recv {
1158 Some(rx_recv) => rx_recv,
1159 None => {
1160 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
1161 self.common
1162 .recv_tx
1163 .lock()
1164 .unwrap()
1165 .insert(accepted.socket, tx);
1166 rx_recv
1167 }
1168 };
1169 self.touch_begin_accept().ok();
1170
1171 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
1172 self.common
1173 .recv_with_addr_tx
1174 .lock()
1175 .unwrap()
1176 .insert(accepted.socket, tx);
1177
1178 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
1179 self.common
1180 .accept_tx
1181 .lock()
1182 .unwrap()
1183 .insert(accepted.socket, tx);
1184
1185 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
1186 self.common
1187 .sent_tx
1188 .lock()
1189 .unwrap()
1190 .insert(accepted.socket, tx);
1191
1192 let socket = RemoteSocket {
1193 socket_id: accepted.socket,
1194 common: self.common.clone(),
1195 rx_buffer: BytesMut::new(),
1196 rx_recv,
1197 rx_recv_with_addr,
1198 rx_accept,
1199 rx_sent,
1200 pending_accept: None,
1201 tx_waker: TxWaker::new(&self.common).as_waker(),
1202 buffer_accept: Default::default(),
1203 buffer_recv_with_addr: Default::default(),
1204 send_available: 0,
1205 owns_socket_bindings: true,
1206 };
1207 Ok((Box::new(socket), accepted.addr))
1208 }
1209
1210 fn set_handler(
1211 &mut self,
1212 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1213 ) -> Result<()> {
1214 VirtualSocket::set_handler(self, handler)
1215 }
1216
1217 fn addr_local(&self) -> Result<SocketAddr> {
1218 match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1219 ResponseType::Err(err) => Err(err),
1220 ResponseType::SocketAddr(addr) => Ok(addr),
1221 res => {
1222 tracing::debug!("invalid response to addr local request - {res:?}");
1223 Err(NetworkError::IOError)
1224 }
1225 }
1226 }
1227
1228 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
1229 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl as u32))
1230 }
1231
1232 fn ttl(&self) -> Result<u8> {
1233 match block_on(self.io_socket(RequestType::GetTtl)) {
1234 ResponseType::Err(err) => Err(err),
1235 ResponseType::Ttl(val) => Ok(val.try_into().map_err(|_| NetworkError::InvalidData)?),
1236 res => {
1237 tracing::debug!("invalid response to get TTL request - {res:?}");
1238 Err(NetworkError::IOError)
1239 }
1240 }
1241 }
1242}
1243
1244impl VirtualTcpBoundSocket for RemoteSocket {
1245 fn addr_local(&self) -> Result<SocketAddr> {
1246 VirtualSocket::addr_local(self)
1247 }
1248
1249 fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
1250 match block_on(self.io_socket(RequestType::ListenBound)) {
1251 ResponseType::Err(err) => Err(err),
1252 ResponseType::None => {
1253 let mut socket = self.transition_socket();
1254 socket.touch_begin_accept().ok();
1255 Ok(Box::new(socket))
1256 }
1257 res => {
1258 tracing::debug!("invalid response to listen bound request - {res:?}");
1259 Err(NetworkError::IOError)
1260 }
1261 }
1262 }
1263
1264 fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
1265 match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
1266 ResponseType::Err(err) => Err(err),
1267 ResponseType::None => Ok(Box::new(self.transition_socket())),
1268 res => {
1269 tracing::debug!("invalid response to connect bound request - {res:?}");
1270 Err(NetworkError::IOError)
1271 }
1272 }
1273 }
1274
1275 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1276 VirtualSocket::set_ttl(self, ttl)
1277 }
1278
1279 fn ttl(&self) -> Result<u32> {
1280 VirtualSocket::ttl(self)
1281 }
1282}
1283
1284impl VirtualRawSocket for RemoteSocket {
1285 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1286 let mut cx = Context::from_waker(&self.tx_waker);
1287 match self.common.tx.poll_send(
1288 &mut cx,
1289 MessageRequest::Send {
1290 socket: self.socket_id,
1291 data: data.to_vec(),
1292 req_id: None,
1293 },
1294 ) {
1295 Poll::Ready(Ok(())) => Ok(data.len()),
1296 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1297 self.send_available = 0;
1298 Err(NetworkError::WouldBlock)
1299 }
1300 Poll::Ready(Err(err)) => Err(err),
1301 }
1302 }
1303
1304 fn try_flush(&mut self) -> Result<()> {
1305 let mut cx = Context::from_waker(&self.tx_waker);
1306 match self.common.tx.poll_send(
1307 &mut cx,
1308 MessageRequest::Socket {
1309 socket: self.socket_id,
1310 req: RequestType::Flush,
1311 req_id: None,
1312 },
1313 ) {
1314 Poll::Ready(Ok(())) => Ok(()),
1315 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1316 self.send_available = 0;
1317 Err(NetworkError::WouldBlock)
1318 }
1319 Poll::Ready(Err(err)) => Err(err),
1320 }
1321 }
1322
1323 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1324 loop {
1325 if !self.rx_buffer.is_empty() {
1326 let amt = self.rx_buffer.len().min(buf.len());
1327 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1328 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1329 if !peek {
1330 self.rx_buffer.advance(amt);
1331 }
1332 return Ok(amt);
1333 }
1334 match self.rx_recv.try_recv() {
1335 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1336 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1337 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1338 }
1339 }
1340 }
1341
1342 fn set_promiscuous(&mut self, promiscuous: bool) -> Result<()> {
1343 self.io_socket_fire_and_forget(RequestType::SetPromiscuous(promiscuous))
1344 }
1345
1346 fn promiscuous(&self) -> Result<bool> {
1347 match block_on(self.io_socket(RequestType::GetPromiscuous)) {
1348 ResponseType::Err(err) => Err(err),
1349 ResponseType::Flag(val) => Ok(val),
1350 res => {
1351 tracing::debug!("invalid response to get promiscuous request - {res:?}");
1352 Err(NetworkError::IOError)
1353 }
1354 }
1355 }
1356}
1357
1358impl VirtualConnectionlessSocket for RemoteSocket {
1359 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1360 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1361 let mut cx = Context::from_waker(&self.tx_waker);
1362 match self.common.tx.poll_send(
1363 &mut cx,
1364 MessageRequest::SendTo {
1365 socket: self.socket_id,
1366 data: data.to_vec(),
1367 addr,
1368 req_id: Some(req_id),
1369 },
1370 ) {
1371 Poll::Ready(Ok(())) => Ok(data.len()),
1372 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1373 self.send_available = 0;
1374 Err(NetworkError::WouldBlock)
1375 }
1376 Poll::Ready(Err(err)) => Err(err),
1377 }
1378 }
1379
1380 fn try_recv_from(
1381 &mut self,
1382 buf: &mut [std::mem::MaybeUninit<u8>],
1383 peek: bool,
1384 ) -> Result<(usize, SocketAddr)> {
1385 if peek {
1389 return Err(NetworkError::Unsupported);
1390 }
1391
1392 match self.rx_recv_with_addr.try_recv() {
1393 Ok(received) => {
1394 let amt = buf.len().min(received.data.len());
1395 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1396 buf[..amt].copy_from_slice(&received.data[..amt]);
1397 Ok((amt, received.addr))
1398 }
1399 Err(TryRecvError::Disconnected) => Err(NetworkError::ConnectionAborted),
1400 Err(TryRecvError::Empty) => Err(NetworkError::WouldBlock),
1401 }
1402 }
1403}
1404
1405impl VirtualUdpSocket for RemoteSocket {
1406 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
1407 self.io_socket_fire_and_forget(RequestType::SetBroadcast(broadcast))
1408 }
1409
1410 fn broadcast(&self) -> Result<bool> {
1411 match block_on(self.io_socket(RequestType::GetBroadcast)) {
1412 ResponseType::Err(err) => Err(err),
1413 ResponseType::Flag(val) => Ok(val),
1414 res => {
1415 tracing::debug!("invalid response to get broadcast request - {res:?}");
1416 Err(NetworkError::IOError)
1417 }
1418 }
1419 }
1420
1421 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
1422 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV4(val))
1423 }
1424
1425 fn multicast_loop_v4(&self) -> Result<bool> {
1426 match block_on(self.io_socket(RequestType::GetMulticastLoopV4)) {
1427 ResponseType::Err(err) => Err(err),
1428 ResponseType::Flag(val) => Ok(val),
1429 res => {
1430 tracing::debug!("invalid response to get multicast loop v4 request - {res:?}");
1431 Err(NetworkError::IOError)
1432 }
1433 }
1434 }
1435
1436 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1437 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV6(val))
1438 }
1439
1440 fn multicast_loop_v6(&self) -> Result<bool> {
1441 match block_on(self.io_socket(RequestType::GetMulticastLoopV6)) {
1442 ResponseType::Err(err) => Err(err),
1443 ResponseType::Flag(val) => Ok(val),
1444 res => {
1445 tracing::debug!("invalid response to get multicast loop v6 request - {res:?}");
1446 Err(NetworkError::IOError)
1447 }
1448 }
1449 }
1450
1451 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1452 self.io_socket_fire_and_forget(RequestType::SetMulticastTtlV4(ttl))
1453 }
1454
1455 fn multicast_ttl_v4(&self) -> Result<u32> {
1456 match block_on(self.io_socket(RequestType::GetMulticastTtlV4)) {
1457 ResponseType::Err(err) => Err(err),
1458 ResponseType::Ttl(ttl) => Ok(ttl),
1459 res => {
1460 tracing::debug!("invalid response to get multicast TTL v4 request - {res:?}");
1461 Err(NetworkError::IOError)
1462 }
1463 }
1464 }
1465
1466 fn join_multicast_v4(
1467 &mut self,
1468 multiaddr: std::net::Ipv4Addr,
1469 iface: std::net::Ipv4Addr,
1470 ) -> Result<()> {
1471 self.io_socket_fire_and_forget(RequestType::JoinMulticastV4 { multiaddr, iface })
1472 }
1473
1474 fn leave_multicast_v4(
1475 &mut self,
1476 multiaddr: std::net::Ipv4Addr,
1477 iface: std::net::Ipv4Addr,
1478 ) -> Result<()> {
1479 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV4 { multiaddr, iface })
1480 }
1481
1482 fn join_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1483 self.io_socket_fire_and_forget(RequestType::JoinMulticastV6 { multiaddr, iface })
1484 }
1485
1486 fn leave_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1487 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV6 { multiaddr, iface })
1488 }
1489
1490 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1491 match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1492 ResponseType::Err(err) => Err(err),
1493 ResponseType::None => Ok(None),
1494 ResponseType::SocketAddr(addr) => Ok(Some(addr)),
1495 res => {
1496 tracing::debug!("invalid response to addr peer request - {res:?}");
1497 Err(NetworkError::IOError)
1498 }
1499 }
1500 }
1501}
1502
/// `RemoteSocket` implements `VirtualIcmpSocket` with an empty impl, so every
/// method falls back to whatever the trait itself provides (not visible here).
impl VirtualIcmpSocket for RemoteSocket {}
1504
1505impl VirtualConnectedSocket for RemoteSocket {
1506 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
1507 self.io_socket_fire_and_forget(RequestType::SetLinger(linger))
1508 }
1509
1510 fn linger(&self) -> Result<Option<Duration>> {
1511 match block_on(self.io_socket(RequestType::GetLinger)) {
1512 ResponseType::Err(err) => Err(err),
1513 ResponseType::None => Ok(None),
1514 ResponseType::Duration(val) => Ok(Some(val)),
1515 res => {
1516 tracing::debug!("invalid response to get linger request - {res:?}");
1517 Err(NetworkError::IOError)
1518 }
1519 }
1520 }
1521
1522 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1523 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1524 let mut cx = Context::from_waker(&self.tx_waker);
1525 match self.common.tx.poll_send(
1526 &mut cx,
1527 MessageRequest::Send {
1528 socket: self.socket_id,
1529 data: data.to_vec(),
1530 req_id: Some(req_id),
1531 },
1532 ) {
1533 Poll::Ready(Ok(())) => Ok(data.len()),
1534 Poll::Ready(Err(err)) => Err(err),
1535 Poll::Pending => Err(NetworkError::WouldBlock),
1536 }
1537 }
1538
1539 fn try_flush(&mut self) -> Result<()> {
1540 let mut cx = Context::from_waker(&self.tx_waker);
1541 match self.common.tx.poll_send(
1542 &mut cx,
1543 MessageRequest::Socket {
1544 socket: self.socket_id,
1545 req: RequestType::Flush,
1546 req_id: None,
1547 },
1548 ) {
1549 Poll::Ready(Ok(())) => Ok(()),
1550 Poll::Ready(Err(err)) => Err(err),
1551 Poll::Pending => Err(NetworkError::WouldBlock),
1552 }
1553 }
1554
1555 fn close(&mut self) -> Result<()> {
1556 let ret = self.io_socket_fire_and_forget(RequestType::Close);
1557 if ret.is_ok() {
1558 self.release_socket_bindings();
1559 }
1560 ret
1561 }
1562
1563 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1564 loop {
1565 if !self.rx_buffer.is_empty() {
1566 let amt = self.rx_buffer.len().min(buf.len());
1567 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1568 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1569 if !peek {
1570 self.rx_buffer.advance(amt);
1571 }
1572 return Ok(amt);
1573 }
1574 match self.rx_recv.try_recv() {
1575 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1576 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1577 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1578 }
1579 }
1580 }
1581}
1582
1583impl VirtualTcpSocket for RemoteSocket {
1584 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
1585 self.io_socket_fire_and_forget(RequestType::SetRecvBufSize(size as u64))
1586 }
1587
1588 fn recv_buf_size(&self) -> Result<usize> {
1589 match block_on(self.io_socket(RequestType::GetRecvBufSize)) {
1590 ResponseType::Err(err) => Err(err),
1591 ResponseType::Amount(amt) => Ok(amt.try_into().map_err(|_| NetworkError::IOError)?),
1592 res => {
1593 tracing::debug!("invalid response to get recv buf size request - {res:?}");
1594 Err(NetworkError::IOError)
1595 }
1596 }
1597 }
1598
1599 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
1600 self.io_socket_fire_and_forget(RequestType::SetSendBufSize(size as u64))
1601 }
1602
1603 fn send_buf_size(&self) -> Result<usize> {
1604 match block_on(self.io_socket(RequestType::GetSendBufSize)) {
1605 ResponseType::Err(err) => Err(err),
1606 ResponseType::Amount(val) => Ok(val.try_into().map_err(|_| NetworkError::IOError)?),
1607 res => {
1608 tracing::debug!("invalid response to get send buf size request - {res:?}");
1609 Err(NetworkError::IOError)
1610 }
1611 }
1612 }
1613
1614 fn set_nodelay(&mut self, reuse: bool) -> Result<()> {
1615 self.io_socket_fire_and_forget(RequestType::SetNoDelay(reuse))
1616 }
1617
1618 fn nodelay(&self) -> Result<bool> {
1619 match block_on(self.io_socket(RequestType::GetNoDelay)) {
1620 ResponseType::Err(err) => Err(err),
1621 ResponseType::Flag(val) => Ok(val),
1622 res => {
1623 tracing::debug!("invalid response to get nodelay request - {res:?}");
1624 Err(NetworkError::IOError)
1625 }
1626 }
1627 }
1628
1629 fn set_keepalive(&mut self, keep_alive: bool) -> Result<()> {
1630 self.io_socket_fire_and_forget(RequestType::SetKeepAlive(keep_alive))
1631 }
1632
1633 fn keepalive(&self) -> Result<bool> {
1634 match block_on(self.io_socket(RequestType::GetKeepAlive)) {
1635 ResponseType::Err(err) => Err(err),
1636 ResponseType::Flag(val) => Ok(val),
1637 res => {
1638 tracing::debug!("invalid response to get nodelay request - {res:?}");
1639 Err(NetworkError::IOError)
1640 }
1641 }
1642 }
1643
1644 fn set_dontroute(&mut self, dont_route: bool) -> Result<()> {
1645 self.io_socket_fire_and_forget(RequestType::SetDontRoute(dont_route))
1646 }
1647
1648 fn dontroute(&self) -> Result<bool> {
1649 match block_on(self.io_socket(RequestType::GetDontRoute)) {
1650 ResponseType::Err(err) => Err(err),
1651 ResponseType::Flag(val) => Ok(val),
1652 res => {
1653 tracing::debug!("invalid response to get nodelay request - {res:?}");
1654 Err(NetworkError::IOError)
1655 }
1656 }
1657 }
1658
1659 fn addr_peer(&self) -> Result<SocketAddr> {
1660 match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1661 ResponseType::Err(err) => Err(err),
1662 ResponseType::SocketAddr(addr) => Ok(addr),
1663 res => {
1664 tracing::debug!("invalid response to addr peer request - {res:?}");
1665 Err(NetworkError::IOError)
1666 }
1667 }
1668 }
1669
1670 fn shutdown(&mut self, how: std::net::Shutdown) -> Result<()> {
1671 let shutdown = match how {
1672 std::net::Shutdown::Read => meta::Shutdown::Read,
1673 std::net::Shutdown::Write => meta::Shutdown::Write,
1674 std::net::Shutdown::Both => meta::Shutdown::Both,
1675 };
1676 self.io_socket_fire_and_forget(RequestType::Shutdown(shutdown))
1677 }
1678
1679 fn is_closed(&self) -> bool {
1680 match block_on(self.io_socket(RequestType::IsClosed)) {
1681 ResponseType::Flag(val) => val,
1682 _ => false,
1683 }
1684 }
1685}