1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::net::IpAddr;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::sync::atomic::AtomicU64;
10use std::sync::atomic::Ordering;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16use std::time::Duration;
17
18use bytes::Buf;
19use bytes::BytesMut;
20use futures_util::Sink;
21use futures_util::Stream;
22use futures_util::StreamExt;
23use futures_util::future::BoxFuture;
24use futures_util::stream::FuturesOrdered;
25#[cfg(feature = "hyper")]
26use hyper_util::rt::tokio::TokioIo;
27use tokio::io::AsyncRead;
28use tokio::io::AsyncWrite;
29use tokio::sync::mpsc;
30use tokio::sync::mpsc::error::TryRecvError;
31use tokio::sync::mpsc::error::TrySendError;
32use tokio_serde::SymmetricallyFramed;
33use tokio_serde::formats::SymmetricalBincode;
34#[cfg(feature = "cbor")]
35use tokio_serde::formats::SymmetricalCbor;
36#[cfg(feature = "json")]
37use tokio_serde::formats::SymmetricalJson;
38#[cfg(feature = "messagepack")]
39use tokio_serde::formats::SymmetricalMessagePack;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use tokio_util::codec::LengthDelimitedCodec;
43use virtual_mio::InterestType;
44use virtual_mio::block_on;
45
46use crate::IpCidr;
47use crate::IpRoute;
48use crate::NetworkError;
49use crate::StreamSecurity;
50use crate::VirtualConnectedSocket;
51use crate::VirtualConnectionlessSocket;
52use crate::VirtualIcmpSocket;
53use crate::VirtualIoSource;
54use crate::VirtualNetworking;
55use crate::VirtualRawSocket;
56use crate::VirtualSocket;
57use crate::VirtualTcpBoundSocket;
58use crate::VirtualTcpListener;
59use crate::VirtualTcpSocket;
60use crate::VirtualUdpSocket;
61use crate::meta;
62use crate::meta::FrameSerializationFormat;
63use crate::meta::RequestType;
64use crate::meta::ResponseType;
65use crate::meta::SocketId;
66use crate::meta::{MessageRequest, MessageResponse};
67
68use crate::Result;
69use crate::rx_tx::RemoteRx;
70use crate::rx_tx::RemoteTx;
71use crate::rx_tx::RemoteTxWakers;
72
/// Client-side handle that forwards networking requests to a remote peer.
///
/// Cloning the handle is cheap: every clone shares the same `RemoteCommon`
/// state behind an `Arc`.
#[derive(Debug, Clone)]
pub struct RemoteNetworkingClient {
    // State shared with the driver and with every socket created by this client.
    common: Arc<RemoteCommon>,
}
77
78impl RemoteNetworkingClient {
79 fn new(
80 tx: RemoteTx<MessageRequest>,
81 rx: RemoteRx<MessageResponse>,
82 rx_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
83 ) -> (Self, RemoteNetworkingClientDriver) {
84 let common = RemoteCommon {
85 tx,
86 rx: Mutex::new(rx),
87 request_seed: AtomicU64::new(1),
88 requests: Default::default(),
89 socket_seed: AtomicU64::new(1),
90 recv_tx: Default::default(),
91 recv_with_addr_tx: Default::default(),
92 accept_tx: Default::default(),
93 sent_tx: Default::default(),
94 handlers: Default::default(),
95 stall: Default::default(),
96 };
97 let common = Arc::new(common);
98
99 let driver = RemoteNetworkingClientDriver {
100 more_work: rx_work,
101 tasks: Default::default(),
102 common: common.clone(),
103 };
104 let networking = Self { common };
105
106 (networking, driver)
107 }
108
109 pub fn new_from_mpsc(
112 tx: mpsc::Sender<MessageRequest>,
113 rx: mpsc::Receiver<MessageResponse>,
114 ) -> (Self, RemoteNetworkingClientDriver) {
115 let (tx_work, rx_work) = mpsc::unbounded_channel();
116 let tx_wakers = RemoteTxWakers::default();
117
118 let tx = RemoteTx::Mpsc {
119 tx,
120 work: tx_work,
121 wakers: tx_wakers.clone(),
122 };
123 let rx = RemoteRx::Mpsc {
124 rx,
125 wakers: tx_wakers,
126 };
127
128 Self::new(tx, rx, rx_work)
129 }
130
131 pub fn new_from_async_io<TX, RX>(
137 tx: TX,
138 rx: RX,
139 format: FrameSerializationFormat,
140 ) -> (Self, RemoteNetworkingClientDriver)
141 where
142 TX: AsyncWrite + Send + 'static,
143 RX: AsyncRead + Send + 'static,
144 {
145 let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
146 let tx: Pin<Box<dyn Sink<MessageRequest, Error = std::io::Error> + Send + 'static>> =
147 match format {
148 FrameSerializationFormat::Bincode => {
149 Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
150 }
151 #[cfg(feature = "json")]
152 FrameSerializationFormat::Json => {
153 Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
154 }
155 #[cfg(feature = "messagepack")]
156 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
157 tx,
158 SymmetricalMessagePack::default(),
159 )),
160 #[cfg(feature = "cbor")]
161 FrameSerializationFormat::Cbor => {
162 Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
163 }
164 };
165
166 let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
167 let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageResponse>> + Send + 'static>> =
168 match format {
169 FrameSerializationFormat::Bincode => {
170 Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
171 }
172 #[cfg(feature = "json")]
173 FrameSerializationFormat::Json => {
174 Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
175 }
176 #[cfg(feature = "messagepack")]
177 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
178 rx,
179 SymmetricalMessagePack::default(),
180 )),
181 #[cfg(feature = "cbor")]
182 FrameSerializationFormat::Cbor => {
183 Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
184 }
185 };
186
187 let (tx_work, rx_work) = mpsc::unbounded_channel();
188 let tx_wakers = RemoteTxWakers::default();
189
190 let tx = RemoteTx::Stream {
191 tx: Arc::new(tokio::sync::Mutex::new(tx)),
192 work: tx_work,
193 wakers: tx_wakers,
194 };
195 let rx = RemoteRx::Stream { rx };
196
197 Self::new(tx, rx, rx_work)
198 }
199
200 #[cfg(feature = "hyper")]
203 pub fn new_from_hyper_ws_io(
204 tx: futures_util::stream::SplitSink<
205 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
206 hyper_tungstenite::tungstenite::Message,
207 >,
208 rx: futures_util::stream::SplitStream<
209 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
210 >,
211 format: FrameSerializationFormat,
212 ) -> (Self, RemoteNetworkingClientDriver) {
213 let (tx_work, rx_work) = mpsc::unbounded_channel();
214
215 let tx = RemoteTx::HyperWebSocket {
216 tx: Arc::new(tokio::sync::Mutex::new(tx)),
217 work: tx_work,
218 wakers: RemoteTxWakers::default(),
219 format,
220 };
221 let rx = RemoteRx::HyperWebSocket { rx, format };
222 Self::new(tx, rx, rx_work)
223 }
224
225 #[cfg(feature = "tokio-tungstenite")]
228 pub fn new_from_tokio_ws_io(
229 tx: futures_util::stream::SplitSink<
230 tokio_tungstenite::WebSocketStream<
231 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
232 >,
233 tokio_tungstenite::tungstenite::Message,
234 >,
235 rx: futures_util::stream::SplitStream<
236 tokio_tungstenite::WebSocketStream<
237 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
238 >,
239 >,
240 format: FrameSerializationFormat,
241 ) -> (Self, RemoteNetworkingClientDriver) {
242 let (tx_work, rx_work) = mpsc::unbounded_channel();
243
244 let tx = RemoteTx::TokioWebSocket {
245 tx: Arc::new(tokio::sync::Mutex::new(tx)),
246 work: tx_work,
247 wakers: RemoteTxWakers::default(),
248 format,
249 };
250 let rx = RemoteRx::TokioWebSocket { rx, format };
251 Self::new(tx, rx, rx_work)
252 }
253
254 fn new_socket(&self, id: SocketId) -> RemoteSocket {
255 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
256 self.common.recv_tx.lock().unwrap().insert(id, tx);
257
258 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
259 self.common.recv_with_addr_tx.lock().unwrap().insert(id, tx);
260
261 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
262 self.common.accept_tx.lock().unwrap().insert(id, tx);
263
264 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
265 self.common.sent_tx.lock().unwrap().insert(id, tx);
266
267 RemoteSocket {
268 socket_id: id,
269 common: self.common.clone(),
270 rx_buffer: BytesMut::new(),
271 rx_recv,
272 rx_recv_with_addr,
273 rx_accept,
274 rx_sent,
275 tx_waker: TxWaker::new(&self.common).as_waker(),
276 pending_accept: None,
277 buffer_accept: Default::default(),
278 buffer_recv_with_addr: Default::default(),
279 send_available: 0,
280 owns_socket_bindings: true,
281 }
282 }
283}
284
pin_project_lite::pin_project! {
    /// Future that services a `RemoteNetworkingClient`: it runs queued
    /// background work and dispatches responses read from the transport.
    pub struct RemoteNetworkingClientDriver {
        // State shared with the client handle and its sockets.
        common: Arc<RemoteCommon>,
        // Receives additional background futures that the driver should run.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // Background futures currently being driven, in insertion order.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
    }
}
293
impl Future for RemoteNetworkingClientDriver {
    type Output = ();

    /// Drives background tasks and dispatches incoming responses.
    ///
    /// Completes with `()` once the response stream reports its end
    /// (`Poll::Ready(None)`); otherwise returns `Poll::Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Owned guard of `common.stall`; taken below while background tasks
        // are still pending and dropped again once the task queue is empty.
        let mut not_stalled_guard = None;

        loop {
            // Move any newly submitted background work into the task queue.
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            match self.tasks.poll_next_unpin(cx) {
                // A task finished: go around again before touching the stream.
                Poll::Ready(Some(_)) => continue,
                // No tasks left: release the stall guard if we hold one.
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                // Tasks are still in flight and we do not hold the stall lock
                // yet: try to take it, and yield if somebody else holds it.
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // Poll the transport for the next response; the std mutex is only
            // held for the duration of this single poll.
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            // Every handled message ends in `continue`, so this `return` only
            // fires when the stream is pending or finished.
            return match msg {
                Poll::Ready(Some(msg)) => {
                    match msg {
                        MessageResponse::Recv { socket_id, data } => {
                            // Messages for sockets without a registered channel are dropped.
                            let tx = {
                                let guard = self.common.recv_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => {
                                        continue;
                                    }
                                }
                            };
                            let common = self.common.clone();
                            // Delivery is queued as a task because the bounded
                            // channel send may have to wait; the handler is
                            // notified only after the send has completed.
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(data).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::RecvWithAddr {
                            socket_id,
                            data,
                            addr,
                        } => {
                            let tx = {
                                let guard = self.common.recv_with_addr_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(DataWithAddr { data, addr }).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Sent {
                            socket_id, amount, ..
                        } => {
                            let tx = {
                                let guard = self.common.sent_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(amount).await.ok();
                            }));
                            // Unlike the receive paths, the handler is notified
                            // right away rather than after the queued send.
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Writable)
                            }
                        }
                        MessageResponse::SendError {
                            socket_id, error, ..
                        } => match &error {
                            // These three errors are reported to the handler as a close.
                            NetworkError::ConnectionAborted
                            | NetworkError::ConnectionReset
                            | NetworkError::BrokenPipe => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Closed)
                                }
                            }
                            // Any other send error is reported as writable interest.
                            _ => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Writable)
                                }
                            }
                        },
                        MessageResponse::FinishAccept {
                            socket_id,
                            child_id,
                            addr,
                        } => {
                            let common = self.common.clone();
                            // The accept channel is looked up inside the task,
                            // i.e. when the task runs rather than when the
                            // message arrives.
                            self.tasks.push_back(Box::pin(async move {
                                let tx = common.accept_tx.lock().unwrap().get(&socket_id).cloned();
                                if let Some(tx) = tx {
                                    tx.send(SocketWithAddr {
                                        socket: child_id,
                                        addr,
                                    })
                                    .await
                                    .ok();
                                }

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Closed { socket_id } => {
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Closed)
                            }
                        }
                        MessageResponse::ResponseToRequest { req_id, res } => {
                            // Hand the reply to whoever registered `req_id`;
                            // replies for unknown ids are ignored.
                            let mut requests = self.common.requests.lock().unwrap();
                            if let Some(request) = requests.remove(&req_id) {
                                request.try_send(res).ok();
                            }
                        }
                    }
                    continue;
                }
                Poll::Ready(None) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
465
/// Waker payload that, when woken, pushes `Writable` interest to every
/// registered socket handler.
#[derive(Debug)]
struct TxWaker {
    // Shared state whose `handlers` map is notified on wake.
    common: Arc<RemoteCommon>,
}
470impl TxWaker {
471 pub fn new(common: &Arc<RemoteCommon>) -> Arc<Self> {
472 Arc::new(Self {
473 common: common.clone(),
474 })
475 }
476
477 fn wake_now(&self) {
478 let mut guard = self.common.handlers.lock().unwrap();
479 for handler in guard.values_mut() {
480 handler.push_interest(InterestType::Writable);
481 }
482 }
483
484 pub fn as_waker(self: &Arc<Self>) -> Waker {
485 let s: *const Self = Arc::into_raw(Arc::clone(self));
486 let raw_waker = RawWaker::new(s as *const (), &VTABLE);
487 unsafe { Waker::from_raw(raw_waker) }
488 }
489}
490
491fn tx_waker_wake(s: &TxWaker) {
492 let waker_arc = unsafe { Arc::from_raw(s) };
493 waker_arc.wake_now();
494}
495
496fn tx_waker_clone(s: &TxWaker) -> RawWaker {
497 let arc = unsafe { Arc::from_raw(s) };
498 std::mem::forget(arc.clone());
499 RawWaker::new(Arc::into_raw(arc) as *const (), &VTABLE)
500}
501
502const VTABLE: RawWakerVTable = unsafe {
503 RawWakerVTable::new(
504 |s| tx_waker_clone(&*(s as *const TxWaker)), |s| tx_waker_wake(&*(s as *const TxWaker)), |s| (*(s as *const TxWaker)).wake_now(), |s| drop(Arc::from_raw(s as *const TxWaker)), )
509};
510
/// Sending half used to hand the response of one pending request back to the
/// task that is waiting for it.
#[derive(Debug)]
struct RequestTx {
    // Channel on which the response is delivered.
    tx: mpsc::Sender<ResponseType>,
}
515impl RequestTx {
516 pub fn try_send(self, msg: ResponseType) -> Result<()> {
517 match self.tx.try_send(msg) {
518 Ok(()) => Ok(()),
519 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
520 Err(TrySendError::Full(_)) => Err(NetworkError::WouldBlock),
521 }
522 }
523}
524
/// A received payload together with the socket address that came with it.
#[derive(Debug)]
struct DataWithAddr {
    // Raw bytes of the payload.
    pub data: Vec<u8>,
    // Address delivered alongside the payload in `MessageResponse::RecvWithAddr`.
    pub addr: SocketAddr,
}
/// An accepted child socket together with the address reported for it.
#[derive(Debug)]
struct SocketWithAddr {
    // Identifier of the accepted (child) socket.
    pub socket: SocketId,
    // Address delivered alongside the child in `MessageResponse::FinishAccept`.
    pub addr: SocketAddr,
}
/// Map keyed by socket identifier.
type SocketMap<T> = HashMap<SocketId, T>;
536
/// State shared between the client handle, its driver and all of its sockets.
#[derive(derive_more::Debug)]
struct RemoteCommon {
    /// Outgoing side of the transport.
    #[debug(ignore)]
    tx: RemoteTx<MessageRequest>,
    /// Incoming side of the transport; locked by the driver for each poll.
    #[debug(ignore)]
    rx: Mutex<RemoteRx<MessageResponse>>,
    /// Counter from which request identifiers are allocated.
    request_seed: AtomicU64,
    /// Requests that are waiting for a response, keyed by request identifier.
    requests: Mutex<HashMap<u64, RequestTx>>,
    /// Counter from which socket identifiers are allocated.
    socket_seed: AtomicU64,
    /// Per-socket channels for `MessageResponse::Recv` payloads.
    recv_tx: Mutex<SocketMap<mpsc::Sender<Vec<u8>>>>,
    /// Per-socket channels for `MessageResponse::RecvWithAddr` payloads.
    recv_with_addr_tx: Mutex<SocketMap<mpsc::Sender<DataWithAddr>>>,
    /// Per-socket channels for `MessageResponse::FinishAccept` notifications.
    accept_tx: Mutex<SocketMap<mpsc::Sender<SocketWithAddr>>>,
    /// Per-socket channels for `MessageResponse::Sent` amounts.
    sent_tx: Mutex<SocketMap<mpsc::Sender<u64>>>,
    /// Interest handlers registered per socket.
    #[debug(ignore)]
    handlers: Mutex<SocketMap<Box<dyn virtual_mio::InterestHandler + Send + Sync>>>,

    /// Lock the driver holds while background tasks are still pending.
    // NOTE(review): presumably the sending side checks this lock to apply
    // back-pressure; confirm against `rx_tx`.
    stall: Arc<tokio::sync::Mutex<()>>,
}
557
558impl RemoteCommon {
559 async fn io_iface(&self, req: RequestType) -> ResponseType {
560 let req_id = self.request_seed.fetch_add(1, Ordering::SeqCst);
561 let mut req_rx = {
562 let (tx, rx) = mpsc::channel(1);
563 let mut guard = self.requests.lock().unwrap();
564 guard.insert(req_id, RequestTx { tx });
565 rx
566 };
567 if let Err(err) = self
568 .tx
569 .send(MessageRequest::Interface {
570 req_id: Some(req_id),
571 req,
572 })
573 .await
574 {
575 return ResponseType::Err(err);
576 };
577 req_rx.recv().await.unwrap()
578 }
579
580 fn io_iface_fire_and_forget(&self, req: RequestType) -> Result<()> {
581 self.tx
582 .send_with_driver(MessageRequest::Interface { req_id: None, req })
583 }
584}
585
586#[async_trait::async_trait]
587impl VirtualNetworking for RemoteNetworkingClient {
588 async fn bridge(
589 &self,
590 network: &str,
591 access_token: &str,
592 security: StreamSecurity,
593 ) -> Result<()> {
594 match self
595 .common
596 .io_iface(RequestType::Bridge {
597 network: network.to_string(),
598 access_token: access_token.to_string(),
599 security,
600 })
601 .await
602 {
603 ResponseType::Err(err) => Err(err),
604 ResponseType::None => Ok(()),
605 res => {
606 tracing::debug!("invalid response to bridge request - {res:?}");
607 Err(NetworkError::IOError)
608 }
609 }
610 }
611
612 async fn unbridge(&self) -> Result<()> {
613 match self.common.io_iface(RequestType::Unbridge).await {
614 ResponseType::Err(err) => Err(err),
615 ResponseType::None => Ok(()),
616 res => {
617 tracing::debug!("invalid response to unbridge request - {res:?}");
618 Err(NetworkError::IOError)
619 }
620 }
621 }
622
623 async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>> {
624 match self.common.io_iface(RequestType::DhcpAcquire).await {
625 ResponseType::Err(err) => Err(err),
626 ResponseType::IpAddressList(ips) => Ok(ips),
627 res => {
628 tracing::debug!("invalid response to DHCP acquire request - {res:?}");
629 Err(NetworkError::IOError)
630 }
631 }
632 }
633
634 async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<()> {
635 self.common
636 .io_iface_fire_and_forget(RequestType::IpAdd { ip, prefix })
637 }
638
639 async fn ip_remove(&self, ip: IpAddr) -> Result<()> {
640 self.common
641 .io_iface_fire_and_forget(RequestType::IpRemove(ip))
642 }
643
644 async fn ip_clear(&self) -> Result<()> {
645 self.common.io_iface_fire_and_forget(RequestType::IpClear)
646 }
647
648 async fn ip_list(&self) -> Result<Vec<IpCidr>> {
649 match self.common.io_iface(RequestType::GetIpList).await {
650 ResponseType::Err(err) => Err(err),
651 ResponseType::CidrList(routes) => Ok(routes),
652 res => {
653 tracing::debug!("invalid response to IP list request - {res:?}");
654 Err(NetworkError::IOError)
655 }
656 }
657 }
658
659 async fn mac(&self) -> Result<[u8; 6]> {
660 match self.common.io_iface(RequestType::GetMac).await {
661 ResponseType::Err(err) => Err(err),
662 ResponseType::Mac(mac) => Ok(mac),
663 res => {
664 tracing::debug!("invalid response to MAC request - {res:?}");
665 Err(NetworkError::IOError)
666 }
667 }
668 }
669
670 async fn gateway_set(&self, ip: IpAddr) -> Result<()> {
671 self.common
672 .io_iface_fire_and_forget(RequestType::GatewaySet(ip))
673 }
674
675 async fn route_add(
676 &self,
677 cidr: IpCidr,
678 via_router: IpAddr,
679 preferred_until: Option<Duration>,
680 expires_at: Option<Duration>,
681 ) -> Result<()> {
682 self.common.io_iface_fire_and_forget(RequestType::RouteAdd {
683 cidr,
684 via_router,
685 preferred_until,
686 expires_at,
687 })
688 }
689
690 async fn route_remove(&self, cidr: IpAddr) -> Result<()> {
691 self.common
692 .io_iface_fire_and_forget(RequestType::RouteRemove(cidr))
693 }
694
695 async fn route_clear(&self) -> Result<()> {
696 self.common
697 .io_iface_fire_and_forget(RequestType::RouteClear)
698 }
699
700 async fn route_list(&self) -> Result<Vec<IpRoute>> {
701 match self.common.io_iface(RequestType::GetRouteList).await {
702 ResponseType::Err(err) => Err(err),
703 ResponseType::RouteList(routes) => Ok(routes),
704 res => {
705 tracing::debug!("invalid response to route list request - {res:?}");
706 Err(NetworkError::IOError)
707 }
708 }
709 }
710
711 async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>> {
712 let socket_id: SocketId = self
713 .common
714 .socket_seed
715 .fetch_add(1, Ordering::SeqCst)
716 .into();
717 match self.common.io_iface(RequestType::BindRaw(socket_id)).await {
718 ResponseType::Err(err) => Err(err),
719 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
720 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
721 res => {
722 tracing::debug!("invalid response to bind RAw request - {res:?}");
723 Err(NetworkError::IOError)
724 }
725 }
726 }
727
728 async fn listen_tcp(
729 &self,
730 addr: SocketAddr,
731 only_v6: bool,
732 reuse_port: bool,
733 reuse_addr: bool,
734 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
735 let socket_id: SocketId = self
736 .common
737 .socket_seed
738 .fetch_add(1, Ordering::SeqCst)
739 .into();
740 match self
741 .common
742 .io_iface(RequestType::ListenTcp {
743 socket_id,
744 addr,
745 only_v6,
746 reuse_port,
747 reuse_addr,
748 })
749 .await
750 {
751 ResponseType::Err(err) => Err(err),
752 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
753 ResponseType::Socket(socket_id) => {
754 let mut socket = self.new_socket(socket_id);
755 socket.touch_begin_accept().ok();
756 Ok(Box::new(socket))
757 }
758 res => {
759 tracing::debug!("invalid response to listen TCP request - {res:?}");
760 Err(NetworkError::IOError)
761 }
762 }
763 }
764
765 async fn bind_tcp(
766 &self,
767 addr: SocketAddr,
768 only_v6: bool,
769 reuse_port: bool,
770 reuse_addr: bool,
771 ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
772 let socket_id: SocketId = self
773 .common
774 .socket_seed
775 .fetch_add(1, Ordering::SeqCst)
776 .into();
777 match self
778 .common
779 .io_iface(RequestType::BindTcp {
780 socket_id,
781 addr,
782 only_v6,
783 reuse_port,
784 reuse_addr,
785 })
786 .await
787 {
788 ResponseType::Err(err) => Err(err),
789 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
790 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
791 res => {
792 tracing::debug!("invalid response to bind TCP request - {res:?}");
793 Err(NetworkError::IOError)
794 }
795 }
796 }
797
798 async fn bind_udp(
799 &self,
800 addr: SocketAddr,
801 reuse_port: bool,
802 reuse_addr: bool,
803 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
804 let socket_id: SocketId = self
805 .common
806 .socket_seed
807 .fetch_add(1, Ordering::SeqCst)
808 .into();
809 match self
810 .common
811 .io_iface(RequestType::BindUdp {
812 socket_id,
813 addr,
814 reuse_port,
815 reuse_addr,
816 })
817 .await
818 {
819 ResponseType::Err(err) => Err(err),
820 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
821 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
822 res => {
823 tracing::debug!("invalid response to bind UDP request - {res:?}");
824 Err(NetworkError::IOError)
825 }
826 }
827 }
828
829 async fn bind_icmp(&self, addr: IpAddr) -> Result<Box<dyn VirtualIcmpSocket + Sync>> {
830 let socket_id: SocketId = self
831 .common
832 .socket_seed
833 .fetch_add(1, Ordering::SeqCst)
834 .into();
835 match self
836 .common
837 .io_iface(RequestType::BindIcmp { socket_id, addr })
838 .await
839 {
840 ResponseType::Err(err) => Err(err),
841 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
842 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
843 res => {
844 tracing::debug!("invalid response to bind ICMP request - {res:?}");
845 Err(NetworkError::IOError)
846 }
847 }
848 }
849
850 async fn connect_tcp(
851 &self,
852 addr: SocketAddr,
853 peer: SocketAddr,
854 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
855 let socket_id: SocketId = self
856 .common
857 .socket_seed
858 .fetch_add(1, Ordering::SeqCst)
859 .into();
860 match self
861 .common
862 .io_iface(RequestType::ConnectTcp {
863 socket_id,
864 addr,
865 peer,
866 })
867 .await
868 {
869 ResponseType::Err(err) => Err(err),
870 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
871 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
872 res => {
873 tracing::debug!("invalid response to connect TCP request - {res:?}");
874 Err(NetworkError::IOError)
875 }
876 }
877 }
878
879 async fn resolve(
880 &self,
881 host: &str,
882 port: Option<u16>,
883 dns_server: Option<IpAddr>,
884 ) -> Result<Vec<IpAddr>> {
885 match self
886 .common
887 .io_iface(RequestType::Resolve {
888 host: host.to_string(),
889 port,
890 dns_server,
891 })
892 .await
893 {
894 ResponseType::Err(err) => Err(err),
895 ResponseType::IpAddressList(ips) => Ok(ips),
896 res => {
897 tracing::debug!("invalid response to resolve request - {res:?}");
898 Err(NetworkError::IOError)
899 }
900 }
901 }
902}
903
/// Client-side representation of one remote socket.
///
/// The same type backs every socket trait implemented in this file.
#[derive(Debug)]
struct RemoteSocket {
    /// Identifier of this socket, used as the key in the shared maps.
    socket_id: SocketId,
    /// State shared with the client and the driver.
    common: Arc<RemoteCommon>,
    /// Received bytes that have been taken off `rx_recv` but not yet consumed.
    rx_buffer: BytesMut,
    /// Receives payloads without an address.
    rx_recv: mpsc::Receiver<Vec<u8>>,
    /// Receives payloads that carry an address.
    rx_recv_with_addr: mpsc::Receiver<DataWithAddr>,
    /// Waker built from a `TxWaker`; waking it pushes `Writable` interest to
    /// all registered handlers.
    tx_waker: Waker,
    /// Receives accepted child sockets.
    rx_accept: mpsc::Receiver<SocketWithAddr>,
    /// Receives the amounts reported by `MessageResponse::Sent`.
    rx_sent: mpsc::Receiver<u64>,
    /// Child id announced in the outstanding `BeginAccept` request, together
    /// with the receive channel already registered for that child.
    pending_accept: Option<(SocketId, mpsc::Receiver<Vec<u8>>)>,
    /// Addressed payloads taken off the channel during readiness polling.
    buffer_recv_with_addr: VecDeque<DataWithAddr>,
    /// Accepted children taken off the channel during readiness polling.
    buffer_accept: VecDeque<SocketWithAddr>,
    /// Sum of the `Sent` amounts collected so far by `poll_write_ready`.
    send_available: u64,
    /// Whether dropping this value should send `Close` and unregister the
    /// socket from the shared maps.
    owns_socket_bindings: bool,
}
920impl Drop for RemoteSocket {
921 fn drop(&mut self) {
922 if !self.owns_socket_bindings {
923 return;
924 }
925 let _ = self.io_socket_fire_and_forget(RequestType::Close);
926 self.release_socket_bindings();
927 }
928}
929
930impl RemoteSocket {
931 fn release_socket_bindings(&mut self) {
932 self.owns_socket_bindings = false;
933 self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
934 self.common
935 .recv_with_addr_tx
936 .lock()
937 .unwrap()
938 .remove(&self.socket_id);
939 self.common
940 .accept_tx
941 .lock()
942 .unwrap()
943 .remove(&self.socket_id);
944 self.common.sent_tx.lock().unwrap().remove(&self.socket_id);
945 self.common.handlers.lock().unwrap().remove(&self.socket_id);
946
947 if let Some((child_id, _)) = self.pending_accept.take() {
948 self.common.recv_tx.lock().unwrap().remove(&child_id);
949 }
950 }
951
952 async fn io_socket(&self, req: RequestType) -> ResponseType {
953 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
954 let mut req_rx = {
955 let (tx, rx) = mpsc::channel(1);
956 let mut guard = self.common.requests.lock().unwrap();
957 guard.insert(req_id, RequestTx { tx });
958 rx
959 };
960 if let Err(err) = self
961 .common
962 .tx
963 .send(MessageRequest::Socket {
964 socket: self.socket_id,
965 req_id: Some(req_id),
966 req,
967 })
968 .await
969 {
970 return ResponseType::Err(err);
971 };
972 req_rx.recv().await.unwrap()
973 }
974
975 fn io_socket_fire_and_forget(&self, req: RequestType) -> Result<()> {
976 self.common.tx.send_with_driver(MessageRequest::Socket {
977 socket: self.socket_id,
978 req_id: None,
979 req,
980 })
981 }
982
983 fn touch_begin_accept(&mut self) -> Result<()> {
984 if self.pending_accept.is_some() {
985 return Ok(());
986 }
987 let child_id: SocketId = self
988 .common
989 .socket_seed
990 .fetch_add(1, Ordering::SeqCst)
991 .into();
992 self.io_socket_fire_and_forget(RequestType::BeginAccept(child_id))?;
993
994 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
995 self.common.recv_tx.lock().unwrap().insert(child_id, tx);
996
997 self.pending_accept.replace((child_id, rx_recv));
998 Ok(())
999 }
1000
1001 fn transition_socket(&mut self) -> RemoteSocket {
1002 let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1);
1003 let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1);
1004 let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1);
1005 let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1);
1006
1007 self.owns_socket_bindings = false;
1008
1009 RemoteSocket {
1010 socket_id: self.socket_id,
1011 common: self.common.clone(),
1012 rx_buffer: std::mem::take(&mut self.rx_buffer),
1013 rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv),
1014 rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr),
1015 tx_waker: self.tx_waker.clone(),
1016 rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept),
1017 rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent),
1018 pending_accept: self.pending_accept.take(),
1019 buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr),
1020 buffer_accept: std::mem::take(&mut self.buffer_accept),
1021 send_available: self.send_available,
1022 owns_socket_bindings: true,
1023 }
1024 }
1025}
1026
1027impl VirtualIoSource for RemoteSocket {
1028 fn remove_handler(&mut self) {
1029 self.common.handlers.lock().unwrap().remove(&self.socket_id);
1030 }
1031
1032 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1033 if !self.rx_buffer.is_empty() {
1034 return Poll::Ready(Ok(self.rx_buffer.len()));
1035 }
1036 match self.rx_recv.poll_recv(cx) {
1037 Poll::Ready(Some(data)) => {
1038 self.rx_buffer.extend_from_slice(&data);
1039 return Poll::Ready(Ok(self.rx_buffer.len()));
1040 }
1041 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1042 Poll::Pending => {}
1043 }
1044 if !self.buffer_recv_with_addr.is_empty() {
1045 let len = self
1046 .buffer_recv_with_addr
1047 .front()
1048 .map(|a| a.data.len().max(1))
1049 .unwrap_or_default();
1050 return Poll::Ready(Ok(len));
1051 }
1052 match self.rx_recv_with_addr.poll_recv(cx) {
1053 Poll::Ready(Some(data)) => {
1054 let len = data.data.len().max(1);
1055 self.buffer_recv_with_addr.push_back(data);
1056 return Poll::Ready(Ok(len));
1057 }
1058 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1059 Poll::Pending => {}
1060 }
1061 if !self.buffer_accept.is_empty() {
1062 return Poll::Ready(Ok(self.buffer_accept.len()));
1063 }
1064 match self.rx_accept.poll_recv(cx) {
1065 Poll::Ready(Some(data)) => self.buffer_accept.push_back(data),
1066 Poll::Ready(None) => {}
1067 Poll::Pending => {}
1068 }
1069 Poll::Pending
1070 }
1071
1072 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1073 if self.send_available > 0 {
1074 return Poll::Ready(Ok(self.send_available as usize));
1075 }
1076 match self.rx_sent.poll_recv(cx) {
1077 Poll::Ready(Some(amt)) => {
1078 self.send_available += amt;
1079 return Poll::Ready(Ok(self.send_available as usize));
1080 }
1081 Poll::Ready(None) => return Poll::Ready(Ok(0)),
1082 Poll::Pending => {}
1083 }
1084 Poll::Pending
1085 }
1086}
1087
1088impl VirtualSocket for RemoteSocket {
1089 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1090 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl))
1091 }
1092
1093 fn ttl(&self) -> Result<u32> {
1094 match block_on(self.io_socket(RequestType::GetTtl)) {
1095 ResponseType::Err(err) => Err(err),
1096 ResponseType::Ttl(ttl) => Ok(ttl),
1097 res => {
1098 tracing::debug!("invalid response to get TTL request - {res:?}");
1099 Err(NetworkError::IOError)
1100 }
1101 }
1102 }
1103
1104 fn addr_local(&self) -> Result<SocketAddr> {
1105 match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1106 ResponseType::Err(err) => Err(err),
1107 ResponseType::SocketAddr(addr) => Ok(addr),
1108 res => {
1109 tracing::debug!("invalid response to address local request - {res:?}");
1110 Err(NetworkError::IOError)
1111 }
1112 }
1113 }
1114
1115 fn status(&self) -> Result<crate::SocketStatus> {
1116 match block_on(self.io_socket(RequestType::GetStatus)) {
1117 ResponseType::Err(err) => Err(err),
1118 ResponseType::Status(status) => Ok(status),
1119 res => {
1120 tracing::debug!("invalid response to status request - {res:?}");
1121 Err(NetworkError::IOError)
1122 }
1123 }
1124 }
1125
1126 fn set_handler(
1127 &mut self,
1128 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1129 ) -> Result<()> {
1130 self.common
1131 .handlers
1132 .lock()
1133 .unwrap()
1134 .insert(self.socket_id, handler);
1135 Ok(())
1136 }
1137}
1138
1139impl VirtualTcpListener for RemoteSocket {
1140 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
1141 self.touch_begin_accept()?;
1143 let accepted = if let Some(child) = self.buffer_accept.pop_front() {
1144 child
1145 } else {
1146 self.rx_accept.try_recv().map_err(|err| match err {
1147 TryRecvError::Empty => NetworkError::WouldBlock,
1148 TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1149 })?
1150 };
1151
1152 let mut rx_recv = None;
1156 if let Some((rx_socket, existing_rx_recv)) = self.pending_accept.take()
1157 && accepted.socket == rx_socket
1158 {
1159 rx_recv.replace(existing_rx_recv);
1160 }
1161 let rx_recv = match rx_recv {
1162 Some(rx_recv) => rx_recv,
1163 None => {
1164 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
1165 self.common
1166 .recv_tx
1167 .lock()
1168 .unwrap()
1169 .insert(accepted.socket, tx);
1170 rx_recv
1171 }
1172 };
1173 self.touch_begin_accept().ok();
1174
1175 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
1176 self.common
1177 .recv_with_addr_tx
1178 .lock()
1179 .unwrap()
1180 .insert(accepted.socket, tx);
1181
1182 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
1183 self.common
1184 .accept_tx
1185 .lock()
1186 .unwrap()
1187 .insert(accepted.socket, tx);
1188
1189 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
1190 self.common
1191 .sent_tx
1192 .lock()
1193 .unwrap()
1194 .insert(accepted.socket, tx);
1195
1196 let socket = RemoteSocket {
1197 socket_id: accepted.socket,
1198 common: self.common.clone(),
1199 rx_buffer: BytesMut::new(),
1200 rx_recv,
1201 rx_recv_with_addr,
1202 rx_accept,
1203 rx_sent,
1204 pending_accept: None,
1205 tx_waker: TxWaker::new(&self.common).as_waker(),
1206 buffer_accept: Default::default(),
1207 buffer_recv_with_addr: Default::default(),
1208 send_available: 0,
1209 owns_socket_bindings: true,
1210 };
1211 Ok((Box::new(socket), accepted.addr))
1212 }
1213
1214 fn set_handler(
1215 &mut self,
1216 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1217 ) -> Result<()> {
1218 VirtualSocket::set_handler(self, handler)
1219 }
1220
1221 fn addr_local(&self) -> Result<SocketAddr> {
1222 match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1223 ResponseType::Err(err) => Err(err),
1224 ResponseType::SocketAddr(addr) => Ok(addr),
1225 res => {
1226 tracing::debug!("invalid response to addr local request - {res:?}");
1227 Err(NetworkError::IOError)
1228 }
1229 }
1230 }
1231
1232 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
1233 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl as u32))
1234 }
1235
1236 fn ttl(&self) -> Result<u8> {
1237 match block_on(self.io_socket(RequestType::GetTtl)) {
1238 ResponseType::Err(err) => Err(err),
1239 ResponseType::Ttl(val) => Ok(val.try_into().map_err(|_| NetworkError::InvalidData)?),
1240 res => {
1241 tracing::debug!("invalid response to get TTL request - {res:?}");
1242 Err(NetworkError::IOError)
1243 }
1244 }
1245 }
1246}
1247
1248impl VirtualTcpBoundSocket for RemoteSocket {
1249 fn addr_local(&self) -> Result<SocketAddr> {
1250 VirtualSocket::addr_local(self)
1251 }
1252
1253 fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
1254 match block_on(self.io_socket(RequestType::ListenBound)) {
1255 ResponseType::Err(err) => Err(err),
1256 ResponseType::None => {
1257 let mut socket = self.transition_socket();
1258 socket.touch_begin_accept().ok();
1259 Ok(Box::new(socket))
1260 }
1261 res => {
1262 tracing::debug!("invalid response to listen bound request - {res:?}");
1263 Err(NetworkError::IOError)
1264 }
1265 }
1266 }
1267
1268 fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
1269 match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
1270 ResponseType::Err(err) => Err(err),
1271 ResponseType::None => Ok(Box::new(self.transition_socket())),
1272 res => {
1273 tracing::debug!("invalid response to connect bound request - {res:?}");
1274 Err(NetworkError::IOError)
1275 }
1276 }
1277 }
1278
1279 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1280 VirtualSocket::set_ttl(self, ttl)
1281 }
1282
1283 fn ttl(&self) -> Result<u32> {
1284 VirtualSocket::ttl(self)
1285 }
1286}
1287
1288impl VirtualRawSocket for RemoteSocket {
1289 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1290 let mut cx = Context::from_waker(&self.tx_waker);
1291 match self.common.tx.poll_send(
1292 &mut cx,
1293 MessageRequest::Send {
1294 socket: self.socket_id,
1295 data: data.to_vec(),
1296 req_id: None,
1297 },
1298 ) {
1299 Poll::Ready(Ok(())) => Ok(data.len()),
1300 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1301 self.send_available = 0;
1302 Err(NetworkError::WouldBlock)
1303 }
1304 Poll::Ready(Err(err)) => Err(err),
1305 }
1306 }
1307
1308 fn try_flush(&mut self) -> Result<()> {
1309 let mut cx = Context::from_waker(&self.tx_waker);
1310 match self.common.tx.poll_send(
1311 &mut cx,
1312 MessageRequest::Socket {
1313 socket: self.socket_id,
1314 req: RequestType::Flush,
1315 req_id: None,
1316 },
1317 ) {
1318 Poll::Ready(Ok(())) => Ok(()),
1319 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1320 self.send_available = 0;
1321 Err(NetworkError::WouldBlock)
1322 }
1323 Poll::Ready(Err(err)) => Err(err),
1324 }
1325 }
1326
1327 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1328 loop {
1329 if !self.rx_buffer.is_empty() {
1330 let amt = self.rx_buffer.len().min(buf.len());
1331 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1332 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1333 if !peek {
1334 self.rx_buffer.advance(amt);
1335 }
1336 return Ok(amt);
1337 }
1338 match self.rx_recv.try_recv() {
1339 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1340 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1341 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1342 }
1343 }
1344 }
1345
1346 fn set_promiscuous(&mut self, promiscuous: bool) -> Result<()> {
1347 self.io_socket_fire_and_forget(RequestType::SetPromiscuous(promiscuous))
1348 }
1349
1350 fn promiscuous(&self) -> Result<bool> {
1351 match block_on(self.io_socket(RequestType::GetPromiscuous)) {
1352 ResponseType::Err(err) => Err(err),
1353 ResponseType::Flag(val) => Ok(val),
1354 res => {
1355 tracing::debug!("invalid response to get promiscuous request - {res:?}");
1356 Err(NetworkError::IOError)
1357 }
1358 }
1359 }
1360}
1361
1362impl VirtualConnectionlessSocket for RemoteSocket {
1363 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1364 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1365 let mut cx = Context::from_waker(&self.tx_waker);
1366 match self.common.tx.poll_send(
1367 &mut cx,
1368 MessageRequest::SendTo {
1369 socket: self.socket_id,
1370 data: data.to_vec(),
1371 addr,
1372 req_id: Some(req_id),
1373 },
1374 ) {
1375 Poll::Ready(Ok(())) => Ok(data.len()),
1376 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1377 self.send_available = 0;
1378 Err(NetworkError::WouldBlock)
1379 }
1380 Poll::Ready(Err(err)) => Err(err),
1381 }
1382 }
1383
1384 fn try_recv_from(
1385 &mut self,
1386 buf: &mut [std::mem::MaybeUninit<u8>],
1387 peek: bool,
1388 ) -> Result<(usize, SocketAddr)> {
1389 if self.buffer_recv_with_addr.is_empty() {
1390 let received = self.rx_recv_with_addr.try_recv().map_err(|err| match err {
1391 TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1392 TryRecvError::Empty => NetworkError::WouldBlock,
1393 })?;
1394 self.buffer_recv_with_addr.push_back(received);
1395 }
1396
1397 let Some(received) = self.buffer_recv_with_addr.front() else {
1398 return Err(NetworkError::WouldBlock);
1399 };
1400 let amt = buf.len().min(received.data.len());
1401 let addr = received.addr;
1402 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1403 buf[..amt].copy_from_slice(&received.data[..amt]);
1404
1405 if !peek {
1406 self.buffer_recv_with_addr.pop_front();
1407 }
1408
1409 Ok((amt, addr))
1410 }
1411}
1412
1413impl VirtualUdpSocket for RemoteSocket {
1414 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
1415 self.io_socket_fire_and_forget(RequestType::SetBroadcast(broadcast))
1416 }
1417
1418 fn broadcast(&self) -> Result<bool> {
1419 match block_on(self.io_socket(RequestType::GetBroadcast)) {
1420 ResponseType::Err(err) => Err(err),
1421 ResponseType::Flag(val) => Ok(val),
1422 res => {
1423 tracing::debug!("invalid response to get broadcast request - {res:?}");
1424 Err(NetworkError::IOError)
1425 }
1426 }
1427 }
1428
1429 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
1430 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV4(val))
1431 }
1432
1433 fn multicast_loop_v4(&self) -> Result<bool> {
1434 match block_on(self.io_socket(RequestType::GetMulticastLoopV4)) {
1435 ResponseType::Err(err) => Err(err),
1436 ResponseType::Flag(val) => Ok(val),
1437 res => {
1438 tracing::debug!("invalid response to get multicast loop v4 request - {res:?}");
1439 Err(NetworkError::IOError)
1440 }
1441 }
1442 }
1443
1444 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1445 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV6(val))
1446 }
1447
1448 fn multicast_loop_v6(&self) -> Result<bool> {
1449 match block_on(self.io_socket(RequestType::GetMulticastLoopV6)) {
1450 ResponseType::Err(err) => Err(err),
1451 ResponseType::Flag(val) => Ok(val),
1452 res => {
1453 tracing::debug!("invalid response to get multicast loop v6 request - {res:?}");
1454 Err(NetworkError::IOError)
1455 }
1456 }
1457 }
1458
1459 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1460 self.io_socket_fire_and_forget(RequestType::SetMulticastTtlV4(ttl))
1461 }
1462
1463 fn multicast_ttl_v4(&self) -> Result<u32> {
1464 match block_on(self.io_socket(RequestType::GetMulticastTtlV4)) {
1465 ResponseType::Err(err) => Err(err),
1466 ResponseType::Ttl(ttl) => Ok(ttl),
1467 res => {
1468 tracing::debug!("invalid response to get multicast TTL v4 request - {res:?}");
1469 Err(NetworkError::IOError)
1470 }
1471 }
1472 }
1473
1474 fn join_multicast_v4(
1475 &mut self,
1476 multiaddr: std::net::Ipv4Addr,
1477 iface: std::net::Ipv4Addr,
1478 ) -> Result<()> {
1479 self.io_socket_fire_and_forget(RequestType::JoinMulticastV4 { multiaddr, iface })
1480 }
1481
1482 fn leave_multicast_v4(
1483 &mut self,
1484 multiaddr: std::net::Ipv4Addr,
1485 iface: std::net::Ipv4Addr,
1486 ) -> Result<()> {
1487 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV4 { multiaddr, iface })
1488 }
1489
1490 fn join_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1491 self.io_socket_fire_and_forget(RequestType::JoinMulticastV6 { multiaddr, iface })
1492 }
1493
1494 fn leave_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1495 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV6 { multiaddr, iface })
1496 }
1497
1498 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1499 match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1500 ResponseType::Err(err) => Err(err),
1501 ResponseType::None => Ok(None),
1502 ResponseType::SocketAddr(addr) => Ok(Some(addr)),
1503 res => {
1504 tracing::debug!("invalid response to addr peer request - {res:?}");
1505 Err(NetworkError::IOError)
1506 }
1507 }
1508 }
1509}
1510
// `RemoteSocket` opts into `VirtualIcmpSocket` without defining any items of
// its own: the empty body relies entirely on whatever the trait itself provides.
impl VirtualIcmpSocket for RemoteSocket {}
1512
1513impl VirtualConnectedSocket for RemoteSocket {
1514 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
1515 self.io_socket_fire_and_forget(RequestType::SetLinger(linger))
1516 }
1517
1518 fn linger(&self) -> Result<Option<Duration>> {
1519 match block_on(self.io_socket(RequestType::GetLinger)) {
1520 ResponseType::Err(err) => Err(err),
1521 ResponseType::None => Ok(None),
1522 ResponseType::Duration(val) => Ok(Some(val)),
1523 res => {
1524 tracing::debug!("invalid response to get linger request - {res:?}");
1525 Err(NetworkError::IOError)
1526 }
1527 }
1528 }
1529
1530 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1531 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1532 let mut cx = Context::from_waker(&self.tx_waker);
1533 match self.common.tx.poll_send(
1534 &mut cx,
1535 MessageRequest::Send {
1536 socket: self.socket_id,
1537 data: data.to_vec(),
1538 req_id: Some(req_id),
1539 },
1540 ) {
1541 Poll::Ready(Ok(())) => Ok(data.len()),
1542 Poll::Ready(Err(err)) => Err(err),
1543 Poll::Pending => Err(NetworkError::WouldBlock),
1544 }
1545 }
1546
1547 fn try_flush(&mut self) -> Result<()> {
1548 let mut cx = Context::from_waker(&self.tx_waker);
1549 match self.common.tx.poll_send(
1550 &mut cx,
1551 MessageRequest::Socket {
1552 socket: self.socket_id,
1553 req: RequestType::Flush,
1554 req_id: None,
1555 },
1556 ) {
1557 Poll::Ready(Ok(())) => Ok(()),
1558 Poll::Ready(Err(err)) => Err(err),
1559 Poll::Pending => Err(NetworkError::WouldBlock),
1560 }
1561 }
1562
1563 fn close(&mut self) -> Result<()> {
1564 let ret = self.io_socket_fire_and_forget(RequestType::Close);
1565 if ret.is_ok() {
1566 self.release_socket_bindings();
1567 }
1568 ret
1569 }
1570
1571 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1572 loop {
1573 if !self.rx_buffer.is_empty() {
1574 let amt = self.rx_buffer.len().min(buf.len());
1575 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1576 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1577 if !peek {
1578 self.rx_buffer.advance(amt);
1579 }
1580 return Ok(amt);
1581 }
1582 match self.rx_recv.try_recv() {
1583 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1584 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1585 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1586 }
1587 }
1588 }
1589}
1590
1591impl VirtualTcpSocket for RemoteSocket {
1592 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
1593 self.io_socket_fire_and_forget(RequestType::SetRecvBufSize(size as u64))
1594 }
1595
1596 fn recv_buf_size(&self) -> Result<usize> {
1597 match block_on(self.io_socket(RequestType::GetRecvBufSize)) {
1598 ResponseType::Err(err) => Err(err),
1599 ResponseType::Amount(amt) => Ok(amt.try_into().map_err(|_| NetworkError::IOError)?),
1600 res => {
1601 tracing::debug!("invalid response to get recv buf size request - {res:?}");
1602 Err(NetworkError::IOError)
1603 }
1604 }
1605 }
1606
1607 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
1608 self.io_socket_fire_and_forget(RequestType::SetSendBufSize(size as u64))
1609 }
1610
1611 fn send_buf_size(&self) -> Result<usize> {
1612 match block_on(self.io_socket(RequestType::GetSendBufSize)) {
1613 ResponseType::Err(err) => Err(err),
1614 ResponseType::Amount(val) => Ok(val.try_into().map_err(|_| NetworkError::IOError)?),
1615 res => {
1616 tracing::debug!("invalid response to get send buf size request - {res:?}");
1617 Err(NetworkError::IOError)
1618 }
1619 }
1620 }
1621
1622 fn set_nodelay(&mut self, reuse: bool) -> Result<()> {
1623 self.io_socket_fire_and_forget(RequestType::SetNoDelay(reuse))
1624 }
1625
1626 fn nodelay(&self) -> Result<bool> {
1627 match block_on(self.io_socket(RequestType::GetNoDelay)) {
1628 ResponseType::Err(err) => Err(err),
1629 ResponseType::Flag(val) => Ok(val),
1630 res => {
1631 tracing::debug!("invalid response to get nodelay request - {res:?}");
1632 Err(NetworkError::IOError)
1633 }
1634 }
1635 }
1636
1637 fn set_keepalive(&mut self, keep_alive: bool) -> Result<()> {
1638 self.io_socket_fire_and_forget(RequestType::SetKeepAlive(keep_alive))
1639 }
1640
1641 fn keepalive(&self) -> Result<bool> {
1642 match block_on(self.io_socket(RequestType::GetKeepAlive)) {
1643 ResponseType::Err(err) => Err(err),
1644 ResponseType::Flag(val) => Ok(val),
1645 res => {
1646 tracing::debug!("invalid response to get nodelay request - {res:?}");
1647 Err(NetworkError::IOError)
1648 }
1649 }
1650 }
1651
1652 fn set_dontroute(&mut self, dont_route: bool) -> Result<()> {
1653 self.io_socket_fire_and_forget(RequestType::SetDontRoute(dont_route))
1654 }
1655
1656 fn dontroute(&self) -> Result<bool> {
1657 match block_on(self.io_socket(RequestType::GetDontRoute)) {
1658 ResponseType::Err(err) => Err(err),
1659 ResponseType::Flag(val) => Ok(val),
1660 res => {
1661 tracing::debug!("invalid response to get nodelay request - {res:?}");
1662 Err(NetworkError::IOError)
1663 }
1664 }
1665 }
1666
1667 fn addr_peer(&self) -> Result<SocketAddr> {
1668 match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1669 ResponseType::Err(err) => Err(err),
1670 ResponseType::SocketAddr(addr) => Ok(addr),
1671 res => {
1672 tracing::debug!("invalid response to addr peer request - {res:?}");
1673 Err(NetworkError::IOError)
1674 }
1675 }
1676 }
1677
1678 fn shutdown(&mut self, how: std::net::Shutdown) -> Result<()> {
1679 let shutdown = match how {
1680 std::net::Shutdown::Read => meta::Shutdown::Read,
1681 std::net::Shutdown::Write => meta::Shutdown::Write,
1682 std::net::Shutdown::Both => meta::Shutdown::Both,
1683 };
1684 self.io_socket_fire_and_forget(RequestType::Shutdown(shutdown))
1685 }
1686
1687 fn is_closed(&self) -> bool {
1688 match block_on(self.io_socket(RequestType::IsClosed)) {
1689 ResponseType::Flag(val) => val,
1690 _ => false,
1691 }
1692 }
1693}