1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::net::IpAddr;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::sync::atomic::AtomicU64;
10use std::sync::atomic::Ordering;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16use std::time::Duration;
17
18use bytes::Buf;
19use bytes::BytesMut;
20use futures_util::Sink;
21use futures_util::Stream;
22use futures_util::StreamExt;
23use futures_util::future::BoxFuture;
24use futures_util::stream::FuturesOrdered;
25#[cfg(feature = "hyper")]
26use hyper_util::rt::tokio::TokioIo;
27use tokio::io::AsyncRead;
28use tokio::io::AsyncWrite;
29use tokio::sync::mpsc;
30use tokio::sync::mpsc::error::TryRecvError;
31use tokio::sync::mpsc::error::TrySendError;
32use tokio_serde::SymmetricallyFramed;
33use tokio_serde::formats::SymmetricalBincode;
34#[cfg(feature = "cbor")]
35use tokio_serde::formats::SymmetricalCbor;
36#[cfg(feature = "json")]
37use tokio_serde::formats::SymmetricalJson;
38#[cfg(feature = "messagepack")]
39use tokio_serde::formats::SymmetricalMessagePack;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use tokio_util::codec::LengthDelimitedCodec;
43use virtual_mio::InlineWaker;
44use virtual_mio::InterestType;
45
46use crate::IpCidr;
47use crate::IpRoute;
48use crate::NetworkError;
49use crate::StreamSecurity;
50use crate::VirtualConnectedSocket;
51use crate::VirtualConnectionlessSocket;
52use crate::VirtualIcmpSocket;
53use crate::VirtualIoSource;
54use crate::VirtualNetworking;
55use crate::VirtualRawSocket;
56use crate::VirtualSocket;
57use crate::VirtualTcpListener;
58use crate::VirtualTcpSocket;
59use crate::VirtualUdpSocket;
60use crate::meta;
61use crate::meta::FrameSerializationFormat;
62use crate::meta::RequestType;
63use crate::meta::ResponseType;
64use crate::meta::SocketId;
65use crate::meta::{MessageRequest, MessageResponse};
66
67use crate::Result;
68use crate::rx_tx::RemoteRx;
69use crate::rx_tx::RemoteTx;
70use crate::rx_tx::RemoteTxWakers;
71
/// [`VirtualNetworking`] implementation that turns every operation into a
/// request sent over a remote transport (`RemoteTx`); replies are read and
/// routed by the paired [`RemoteNetworkingClientDriver`].
///
/// Cloning is cheap: every clone shares the same `RemoteCommon`. Requests that
/// wait for an answer only complete while the driver returned by the
/// constructors is being polled.
#[derive(Debug, Clone)]
pub struct RemoteNetworkingClient {
    // State shared with the driver and with every socket created by this client.
    common: Arc<RemoteCommon>,
}
76
77impl RemoteNetworkingClient {
78 fn new(
79 tx: RemoteTx<MessageRequest>,
80 rx: RemoteRx<MessageResponse>,
81 rx_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
82 ) -> (Self, RemoteNetworkingClientDriver) {
83 let common = RemoteCommon {
84 tx,
85 rx: Mutex::new(rx),
86 request_seed: AtomicU64::new(1),
87 requests: Default::default(),
88 socket_seed: AtomicU64::new(1),
89 recv_tx: Default::default(),
90 recv_with_addr_tx: Default::default(),
91 accept_tx: Default::default(),
92 sent_tx: Default::default(),
93 handlers: Default::default(),
94 stall: Default::default(),
95 };
96 let common = Arc::new(common);
97
98 let driver = RemoteNetworkingClientDriver {
99 more_work: rx_work,
100 tasks: Default::default(),
101 common: common.clone(),
102 };
103 let networking = Self { common };
104
105 (networking, driver)
106 }
107
108 pub fn new_from_mpsc(
111 tx: mpsc::Sender<MessageRequest>,
112 rx: mpsc::Receiver<MessageResponse>,
113 ) -> (Self, RemoteNetworkingClientDriver) {
114 let (tx_work, rx_work) = mpsc::unbounded_channel();
115 let tx_wakers = RemoteTxWakers::default();
116
117 let tx = RemoteTx::Mpsc {
118 tx,
119 work: tx_work,
120 wakers: tx_wakers.clone(),
121 };
122 let rx = RemoteRx::Mpsc {
123 rx,
124 wakers: tx_wakers,
125 };
126
127 Self::new(tx, rx, rx_work)
128 }
129
130 pub fn new_from_async_io<TX, RX>(
136 tx: TX,
137 rx: RX,
138 format: FrameSerializationFormat,
139 ) -> (Self, RemoteNetworkingClientDriver)
140 where
141 TX: AsyncWrite + Send + 'static,
142 RX: AsyncRead + Send + 'static,
143 {
144 let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
145 let tx: Pin<Box<dyn Sink<MessageRequest, Error = std::io::Error> + Send + 'static>> =
146 match format {
147 FrameSerializationFormat::Bincode => {
148 Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
149 }
150 #[cfg(feature = "json")]
151 FrameSerializationFormat::Json => {
152 Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
153 }
154 #[cfg(feature = "messagepack")]
155 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
156 tx,
157 SymmetricalMessagePack::default(),
158 )),
159 #[cfg(feature = "cbor")]
160 FrameSerializationFormat::Cbor => {
161 Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
162 }
163 };
164
165 let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
166 let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageResponse>> + Send + 'static>> =
167 match format {
168 FrameSerializationFormat::Bincode => {
169 Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
170 }
171 #[cfg(feature = "json")]
172 FrameSerializationFormat::Json => {
173 Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
174 }
175 #[cfg(feature = "messagepack")]
176 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
177 rx,
178 SymmetricalMessagePack::default(),
179 )),
180 #[cfg(feature = "cbor")]
181 FrameSerializationFormat::Cbor => {
182 Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
183 }
184 };
185
186 let (tx_work, rx_work) = mpsc::unbounded_channel();
187 let tx_wakers = RemoteTxWakers::default();
188
189 let tx = RemoteTx::Stream {
190 tx: Arc::new(tokio::sync::Mutex::new(tx)),
191 work: tx_work,
192 wakers: tx_wakers,
193 };
194 let rx = RemoteRx::Stream { rx };
195
196 Self::new(tx, rx, rx_work)
197 }
198
199 #[cfg(feature = "hyper")]
202 pub fn new_from_hyper_ws_io(
203 tx: futures_util::stream::SplitSink<
204 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
205 hyper_tungstenite::tungstenite::Message,
206 >,
207 rx: futures_util::stream::SplitStream<
208 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
209 >,
210 format: FrameSerializationFormat,
211 ) -> (Self, RemoteNetworkingClientDriver) {
212 let (tx_work, rx_work) = mpsc::unbounded_channel();
213
214 let tx = RemoteTx::HyperWebSocket {
215 tx: Arc::new(tokio::sync::Mutex::new(tx)),
216 work: tx_work,
217 wakers: RemoteTxWakers::default(),
218 format,
219 };
220 let rx = RemoteRx::HyperWebSocket { rx, format };
221 Self::new(tx, rx, rx_work)
222 }
223
224 #[cfg(feature = "tokio-tungstenite")]
227 pub fn new_from_tokio_ws_io(
228 tx: futures_util::stream::SplitSink<
229 tokio_tungstenite::WebSocketStream<
230 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
231 >,
232 tokio_tungstenite::tungstenite::Message,
233 >,
234 rx: futures_util::stream::SplitStream<
235 tokio_tungstenite::WebSocketStream<
236 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
237 >,
238 >,
239 format: FrameSerializationFormat,
240 ) -> (Self, RemoteNetworkingClientDriver) {
241 let (tx_work, rx_work) = mpsc::unbounded_channel();
242
243 let tx = RemoteTx::TokioWebSocket {
244 tx: Arc::new(tokio::sync::Mutex::new(tx)),
245 work: tx_work,
246 wakers: RemoteTxWakers::default(),
247 format,
248 };
249 let rx = RemoteRx::TokioWebSocket { rx, format };
250 Self::new(tx, rx, rx_work)
251 }
252
253 fn new_socket(&self, id: SocketId) -> RemoteSocket {
254 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
255 self.common.recv_tx.lock().unwrap().insert(id, tx);
256
257 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
258 self.common.recv_with_addr_tx.lock().unwrap().insert(id, tx);
259
260 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
261 self.common.accept_tx.lock().unwrap().insert(id, tx);
262
263 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
264 self.common.sent_tx.lock().unwrap().insert(id, tx);
265
266 RemoteSocket {
267 socket_id: id,
268 common: self.common.clone(),
269 rx_buffer: BytesMut::new(),
270 rx_recv,
271 rx_recv_with_addr,
272 rx_accept,
273 rx_sent,
274 tx_waker: TxWaker::new(&self.common).as_waker(),
275 pending_accept: None,
276 buffer_accept: Default::default(),
277 buffer_recv_with_addr: Default::default(),
278 send_available: 0,
279 }
280 }
281}
282
pin_project_lite::pin_project! {
    /// Future that services a [`RemoteNetworkingClient`]: it reads responses
    /// from the transport, routes them to sockets and waiting requests, and
    /// drives queued background work. It must be polled (e.g. spawned on an
    /// executor) for the client to make progress; it resolves once the
    /// response stream ends.
    pub struct RemoteNetworkingClientDriver {
        // State shared with the client handle and all of its sockets.
        common: Arc<RemoteCommon>,
        // Background futures submitted through the transport's `work` channel.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // Background futures currently being driven; results are yielded in
        // the order the futures were queued.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
    }
}
291
impl Future for RemoteNetworkingClientDriver {
    type Output = ();

    /// Pumps the client: queues newly submitted background work, drives the
    /// queued tasks, and dispatches every message read from the transport.
    ///
    /// Resolves to `()` once the response stream is exhausted.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Holds the `stall` lock while this poll keeps reading messages even
        // though background tasks are still pending. It is released when the
        // task queue drains, or when this poll returns (the local is dropped).
        let mut not_stalled_guard = None;

        // Loop until the sources return Pending (registering `cx`) or the
        // response stream ends.
        loop {
            // Move all newly submitted background work into the task queue.
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            match self.tasks.poll_next_unpin(cx) {
                // A task finished; go around again to keep draining the queue.
                Poll::Ready(Some(_)) => continue,
                // The queue is empty: reading is no longer constrained.
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                // Work is still pending. Keep reading messages only if the
                // `stall` lock can be taken; otherwise yield without reading.
                // NOTE(review): whoever else locks `stall` is not visible
                // here — presumably a back-pressure mechanism.
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // Read the next message; the `rx` mutex is held only for this poll.
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            return match msg {
                Poll::Ready(Some(msg)) => {
                    match msg {
                        // Stream data. A background task forwards it to the
                        // socket's receive channel (waiting for capacity if
                        // needed) and then marks the socket readable.
                        // Messages for unknown sockets are dropped.
                        MessageResponse::Recv { socket_id, data } => {
                            let tx = {
                                let guard = self.common.recv_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => {
                                        continue;
                                    }
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(data).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        // Datagram with an address: same as `Recv`, but on the
                        // `recv_with_addr` channel.
                        MessageResponse::RecvWithAddr {
                            socket_id,
                            data,
                            addr,
                        } => {
                            let tx = {
                                let guard = self.common.recv_with_addr_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(DataWithAddr { data, addr }).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        // Send credit. The amount is queued to the socket's
                        // `sent` channel by a background task, but the socket is
                        // marked writable immediately (before that task runs).
                        MessageResponse::Sent {
                            socket_id, amount, ..
                        } => {
                            let tx = {
                                let guard = self.common.sent_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(amount).await.ok();
                            }));
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Writable)
                            }
                        }
                        // Failed send: connection-terminating errors raise
                        // Closed interest, any other error raises Writable.
                        MessageResponse::SendError {
                            socket_id, error, ..
                        } => match &error {
                            NetworkError::ConnectionAborted
                            | NetworkError::ConnectionReset
                            | NetworkError::BrokenPipe => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Closed)
                                }
                            }
                            _ => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Writable)
                                }
                            }
                        },
                        // Accepted connection. A background task delivers it to
                        // the listener's accept channel (if still registered)
                        // and then marks the listener readable.
                        MessageResponse::FinishAccept {
                            socket_id,
                            child_id,
                            addr,
                        } => {
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                let tx = common.accept_tx.lock().unwrap().get(&socket_id).cloned();
                                if let Some(tx) = tx {
                                    tx.send(SocketWithAddr {
                                        socket: child_id,
                                        addr,
                                    })
                                    .await
                                    .ok();
                                }

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        // The remote side closed the socket.
                        MessageResponse::Closed { socket_id } => {
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Closed)
                            }
                        }
                        // Answer to an `io_iface`/`io_socket` call: remove the
                        // pending entry and hand the response to the waiter.
                        MessageResponse::ResponseToRequest { req_id, res } => {
                            let mut requests = self.common.requests.lock().unwrap();
                            if let Some(request) = requests.remove(&req_id) {
                                request.try_send(res).ok();
                            }
                        }
                    }
                    // Keep reading until the stream would block.
                    continue;
                }
                // Response stream finished: the driver is done.
                Poll::Ready(None) => Poll::Ready(()),
                // `cx` is registered with the stream; wait for more data.
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
463
464#[derive(Debug)]
465struct TxWaker {
466 common: Arc<RemoteCommon>,
467}
468impl TxWaker {
469 pub fn new(common: &Arc<RemoteCommon>) -> Arc<Self> {
470 Arc::new(Self {
471 common: common.clone(),
472 })
473 }
474
475 fn wake_now(&self) {
476 let mut guard = self.common.handlers.lock().unwrap();
477 for (_, handler) in guard.iter_mut() {
478 handler.push_interest(InterestType::Writable);
479 }
480 }
481
482 pub fn as_waker(self: &Arc<Self>) -> Waker {
483 let s: *const Self = Arc::into_raw(Arc::clone(self));
484 let raw_waker = RawWaker::new(s as *const (), &VTABLE);
485 unsafe { Waker::from_raw(raw_waker) }
486 }
487}
488
489fn tx_waker_wake(s: &TxWaker) {
490 let waker_arc = unsafe { Arc::from_raw(s) };
491 waker_arc.wake_now();
492}
493
494fn tx_waker_clone(s: &TxWaker) -> RawWaker {
495 let arc = unsafe { Arc::from_raw(s) };
496 std::mem::forget(arc.clone());
497 RawWaker::new(Arc::into_raw(arc) as *const (), &VTABLE)
498}
499
500const VTABLE: RawWakerVTable = unsafe {
501 RawWakerVTable::new(
502 |s| tx_waker_clone(&*(s as *const TxWaker)), |s| tx_waker_wake(&*(s as *const TxWaker)), |s| (*(s as *const TxWaker)).wake_now(), |s| drop(Arc::from_raw(s as *const TxWaker)), )
507};
508
509#[derive(Debug)]
510struct RequestTx {
511 tx: mpsc::Sender<ResponseType>,
512}
513impl RequestTx {
514 pub fn try_send(self, msg: ResponseType) -> Result<()> {
515 match self.tx.try_send(msg) {
516 Ok(()) => Ok(()),
517 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
518 Err(TrySendError::Full(_)) => Err(NetworkError::WouldBlock),
519 }
520 }
521}
522
/// Datagram routed to a socket by the driver (from `RecvWithAddr`).
#[derive(Debug)]
struct DataWithAddr {
    /// Payload bytes.
    pub data: Vec<u8>,
    /// Address reported alongside the payload (presumably the sender — confirm against the server).
    pub addr: SocketAddr,
}
/// Connection reported to a listener by the driver (from `FinishAccept`).
#[derive(Debug)]
struct SocketWithAddr {
    /// Id of the accepted child socket; expected to be the id sent in `BeginAccept`.
    pub socket: SocketId,
    /// Address reported for the accepted connection (presumably the remote peer).
    pub addr: SocketAddr,
}
/// Per-socket routing table keyed by [`SocketId`].
type SocketMap<T> = HashMap<SocketId, T>;
534
/// State shared (through `Arc`) by the client handle, the driver and every socket.
#[derive(derive_more::Debug)]
struct RemoteCommon {
    /// Outbound half of the transport to the server.
    #[debug(ignore)]
    tx: RemoteTx<MessageRequest>,
    /// Inbound half of the transport; polled by the driver.
    #[debug(ignore)]
    rx: Mutex<RemoteRx<MessageResponse>>,
    /// Source of request ids used to match `ResponseToRequest` messages.
    request_seed: AtomicU64,
    /// Requests awaiting an answer, keyed by request id; the driver removes an
    /// entry when it delivers the response.
    requests: Mutex<HashMap<u64, RequestTx>>,
    /// Source of client-allocated socket ids.
    socket_seed: AtomicU64,
    /// Per-socket channels the driver forwards `Recv` payloads into.
    recv_tx: Mutex<SocketMap<mpsc::Sender<Vec<u8>>>>,
    /// Per-socket channels the driver forwards `RecvWithAddr` datagrams into.
    recv_with_addr_tx: Mutex<SocketMap<mpsc::Sender<DataWithAddr>>>,
    /// Per-listener channels the driver forwards `FinishAccept` results into.
    accept_tx: Mutex<SocketMap<mpsc::Sender<SocketWithAddr>>>,
    /// Per-socket channels the driver forwards `Sent` credit amounts into.
    sent_tx: Mutex<SocketMap<mpsc::Sender<u64>>>,
    /// Interest handlers installed via `set_handler`; notified by the driver
    /// and by `TxWaker`.
    #[debug(ignore)]
    handlers: Mutex<SocketMap<Box<dyn virtual_mio::InterestHandler + Send + Sync>>>,

    /// Lock the driver tries to take (`try_lock_owned`) before reading more
    /// messages while background tasks are still pending.
    stall: Arc<tokio::sync::Mutex<()>>,
}
555
556impl RemoteCommon {
557 async fn io_iface(&self, req: RequestType) -> ResponseType {
558 let req_id = self.request_seed.fetch_add(1, Ordering::SeqCst);
559 let mut req_rx = {
560 let (tx, rx) = mpsc::channel(1);
561 let mut guard = self.requests.lock().unwrap();
562 guard.insert(req_id, RequestTx { tx });
563 rx
564 };
565 if let Err(err) = self
566 .tx
567 .send(MessageRequest::Interface {
568 req_id: Some(req_id),
569 req,
570 })
571 .await
572 {
573 return ResponseType::Err(err);
574 };
575 req_rx.recv().await.unwrap()
576 }
577
578 fn io_iface_fire_and_forget(&self, req: RequestType) -> Result<()> {
579 self.tx
580 .send_with_driver(MessageRequest::Interface { req_id: None, req })
581 }
582}
583
584#[async_trait::async_trait]
585impl VirtualNetworking for RemoteNetworkingClient {
586 async fn bridge(
587 &self,
588 network: &str,
589 access_token: &str,
590 security: StreamSecurity,
591 ) -> Result<()> {
592 match self
593 .common
594 .io_iface(RequestType::Bridge {
595 network: network.to_string(),
596 access_token: access_token.to_string(),
597 security,
598 })
599 .await
600 {
601 ResponseType::Err(err) => Err(err),
602 ResponseType::None => Ok(()),
603 res => {
604 tracing::debug!("invalid response to bridge request - {res:?}");
605 Err(NetworkError::IOError)
606 }
607 }
608 }
609
610 async fn unbridge(&self) -> Result<()> {
611 match self.common.io_iface(RequestType::Unbridge).await {
612 ResponseType::Err(err) => Err(err),
613 ResponseType::None => Ok(()),
614 res => {
615 tracing::debug!("invalid response to unbridge request - {res:?}");
616 Err(NetworkError::IOError)
617 }
618 }
619 }
620
621 async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>> {
622 match self.common.io_iface(RequestType::DhcpAcquire).await {
623 ResponseType::Err(err) => Err(err),
624 ResponseType::IpAddressList(ips) => Ok(ips),
625 res => {
626 tracing::debug!("invalid response to DHCP acquire request - {res:?}");
627 Err(NetworkError::IOError)
628 }
629 }
630 }
631
632 async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<()> {
633 self.common
634 .io_iface_fire_and_forget(RequestType::IpAdd { ip, prefix })
635 }
636
637 async fn ip_remove(&self, ip: IpAddr) -> Result<()> {
638 self.common
639 .io_iface_fire_and_forget(RequestType::IpRemove(ip))
640 }
641
642 async fn ip_clear(&self) -> Result<()> {
643 self.common.io_iface_fire_and_forget(RequestType::IpClear)
644 }
645
646 async fn ip_list(&self) -> Result<Vec<IpCidr>> {
647 match self.common.io_iface(RequestType::GetIpList).await {
648 ResponseType::Err(err) => Err(err),
649 ResponseType::CidrList(routes) => Ok(routes),
650 res => {
651 tracing::debug!("invalid response to IP list request - {res:?}");
652 Err(NetworkError::IOError)
653 }
654 }
655 }
656
657 async fn mac(&self) -> Result<[u8; 6]> {
658 match self.common.io_iface(RequestType::GetMac).await {
659 ResponseType::Err(err) => Err(err),
660 ResponseType::Mac(mac) => Ok(mac),
661 res => {
662 tracing::debug!("invalid response to MAC request - {res:?}");
663 Err(NetworkError::IOError)
664 }
665 }
666 }
667
668 async fn gateway_set(&self, ip: IpAddr) -> Result<()> {
669 self.common
670 .io_iface_fire_and_forget(RequestType::GatewaySet(ip))
671 }
672
673 async fn route_add(
674 &self,
675 cidr: IpCidr,
676 via_router: IpAddr,
677 preferred_until: Option<Duration>,
678 expires_at: Option<Duration>,
679 ) -> Result<()> {
680 self.common.io_iface_fire_and_forget(RequestType::RouteAdd {
681 cidr,
682 via_router,
683 preferred_until,
684 expires_at,
685 })
686 }
687
688 async fn route_remove(&self, cidr: IpAddr) -> Result<()> {
689 self.common
690 .io_iface_fire_and_forget(RequestType::RouteRemove(cidr))
691 }
692
693 async fn route_clear(&self) -> Result<()> {
694 self.common
695 .io_iface_fire_and_forget(RequestType::RouteClear)
696 }
697
698 async fn route_list(&self) -> Result<Vec<IpRoute>> {
699 match self.common.io_iface(RequestType::GetRouteList).await {
700 ResponseType::Err(err) => Err(err),
701 ResponseType::RouteList(routes) => Ok(routes),
702 res => {
703 tracing::debug!("invalid response to route list request - {res:?}");
704 Err(NetworkError::IOError)
705 }
706 }
707 }
708
709 async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>> {
710 let socket_id: SocketId = self
711 .common
712 .socket_seed
713 .fetch_add(1, Ordering::SeqCst)
714 .into();
715 match self.common.io_iface(RequestType::BindRaw(socket_id)).await {
716 ResponseType::Err(err) => Err(err),
717 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
718 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
719 res => {
720 tracing::debug!("invalid response to bind RAw request - {res:?}");
721 Err(NetworkError::IOError)
722 }
723 }
724 }
725
726 async fn listen_tcp(
727 &self,
728 addr: SocketAddr,
729 only_v6: bool,
730 reuse_port: bool,
731 reuse_addr: bool,
732 ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
733 let socket_id: SocketId = self
734 .common
735 .socket_seed
736 .fetch_add(1, Ordering::SeqCst)
737 .into();
738 match self
739 .common
740 .io_iface(RequestType::ListenTcp {
741 socket_id,
742 addr,
743 only_v6,
744 reuse_port,
745 reuse_addr,
746 })
747 .await
748 {
749 ResponseType::Err(err) => Err(err),
750 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
751 ResponseType::Socket(socket_id) => {
752 let mut socket = self.new_socket(socket_id);
753 socket.touch_begin_accept().ok();
754 Ok(Box::new(socket))
755 }
756 res => {
757 tracing::debug!("invalid response to listen TCP request - {res:?}");
758 Err(NetworkError::IOError)
759 }
760 }
761 }
762
763 async fn bind_udp(
764 &self,
765 addr: SocketAddr,
766 reuse_port: bool,
767 reuse_addr: bool,
768 ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
769 let socket_id: SocketId = self
770 .common
771 .socket_seed
772 .fetch_add(1, Ordering::SeqCst)
773 .into();
774 match self
775 .common
776 .io_iface(RequestType::BindUdp {
777 socket_id,
778 addr,
779 reuse_port,
780 reuse_addr,
781 })
782 .await
783 {
784 ResponseType::Err(err) => Err(err),
785 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
786 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
787 res => {
788 tracing::debug!("invalid response to bind UDP request - {res:?}");
789 Err(NetworkError::IOError)
790 }
791 }
792 }
793
794 async fn bind_icmp(&self, addr: IpAddr) -> Result<Box<dyn VirtualIcmpSocket + Sync>> {
795 let socket_id: SocketId = self
796 .common
797 .socket_seed
798 .fetch_add(1, Ordering::SeqCst)
799 .into();
800 match self
801 .common
802 .io_iface(RequestType::BindIcmp { socket_id, addr })
803 .await
804 {
805 ResponseType::Err(err) => Err(err),
806 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
807 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
808 res => {
809 tracing::debug!("invalid response to bind ICMP request - {res:?}");
810 Err(NetworkError::IOError)
811 }
812 }
813 }
814
815 async fn connect_tcp(
816 &self,
817 addr: SocketAddr,
818 peer: SocketAddr,
819 ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
820 let socket_id: SocketId = self
821 .common
822 .socket_seed
823 .fetch_add(1, Ordering::SeqCst)
824 .into();
825 match self
826 .common
827 .io_iface(RequestType::ConnectTcp {
828 socket_id,
829 addr,
830 peer,
831 })
832 .await
833 {
834 ResponseType::Err(err) => Err(err),
835 ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
836 ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
837 res => {
838 tracing::debug!("invalid response to connect TCP request - {res:?}");
839 Err(NetworkError::IOError)
840 }
841 }
842 }
843
844 async fn resolve(
845 &self,
846 host: &str,
847 port: Option<u16>,
848 dns_server: Option<IpAddr>,
849 ) -> Result<Vec<IpAddr>> {
850 match self
851 .common
852 .io_iface(RequestType::Resolve {
853 host: host.to_string(),
854 port,
855 dns_server,
856 })
857 .await
858 {
859 ResponseType::Err(err) => Err(err),
860 ResponseType::IpAddressList(ips) => Ok(ips),
861 res => {
862 tracing::debug!("invalid response to resolve request - {res:?}");
863 Err(NetworkError::IOError)
864 }
865 }
866 }
867}
868
/// Client-side proxy for one socket that lives on the server. Socket
/// operations are sent as requests through `common`; incoming data arrives on
/// the channels the driver routes into.
#[derive(Debug)]
struct RemoteSocket {
    /// Id the server knows this socket by.
    socket_id: SocketId,
    /// Shared client state (transport, routing maps, handlers).
    common: Arc<RemoteCommon>,
    /// Stream bytes received but not yet consumed by `try_recv`.
    rx_buffer: BytesMut,
    /// Stream payloads routed here by the driver (`Recv`).
    rx_recv: mpsc::Receiver<Vec<u8>>,
    /// Datagrams routed here by the driver (`RecvWithAddr`).
    rx_recv_with_addr: mpsc::Receiver<DataWithAddr>,
    /// `TxWaker`-backed waker used as the context when polling the transport
    /// for sends.
    tx_waker: Waker,
    /// Accepted connections routed here by the driver (`FinishAccept`).
    rx_accept: mpsc::Receiver<SocketWithAddr>,
    /// Send-credit amounts routed here by the driver (`Sent`).
    rx_sent: mpsc::Receiver<u64>,
    /// Listener only: child id of the outstanding `BeginAccept` and the
    /// receive channel pre-registered for that child.
    pending_accept: Option<(SocketId, mpsc::Receiver<Vec<u8>>)>,
    /// Datagrams taken off `rx_recv_with_addr` by `poll_read_ready`.
    buffer_recv_with_addr: VecDeque<DataWithAddr>,
    /// Accepted connections taken off `rx_accept` by `poll_read_ready`.
    buffer_accept: VecDeque<SocketWithAddr>,
    /// Last reported send credit; reset to 0 when a send would block.
    send_available: u64,
}
884impl Drop for RemoteSocket {
885 fn drop(&mut self) {
886 self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
887 self.common
888 .recv_with_addr_tx
889 .lock()
890 .unwrap()
891 .remove(&self.socket_id);
892 }
893}
894
895impl RemoteSocket {
896 async fn io_socket(&self, req: RequestType) -> ResponseType {
897 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
898 let mut req_rx = {
899 let (tx, rx) = mpsc::channel(1);
900 let mut guard = self.common.requests.lock().unwrap();
901 guard.insert(req_id, RequestTx { tx });
902 rx
903 };
904 if let Err(err) = self
905 .common
906 .tx
907 .send(MessageRequest::Socket {
908 socket: self.socket_id,
909 req_id: Some(req_id),
910 req,
911 })
912 .await
913 {
914 return ResponseType::Err(err);
915 };
916 req_rx.recv().await.unwrap()
917 }
918
919 fn io_socket_fire_and_forget(&self, req: RequestType) -> Result<()> {
920 self.common.tx.send_with_driver(MessageRequest::Socket {
921 socket: self.socket_id,
922 req_id: None,
923 req,
924 })
925 }
926
927 fn touch_begin_accept(&mut self) -> Result<()> {
928 if self.pending_accept.is_some() {
929 return Ok(());
930 }
931 let child_id: SocketId = self
932 .common
933 .socket_seed
934 .fetch_add(1, Ordering::SeqCst)
935 .into();
936 self.io_socket_fire_and_forget(RequestType::BeginAccept(child_id))?;
937
938 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
939 self.common.recv_tx.lock().unwrap().insert(child_id, tx);
940
941 self.pending_accept.replace((child_id, rx_recv));
942 Ok(())
943 }
944}
945
946impl VirtualIoSource for RemoteSocket {
947 fn remove_handler(&mut self) {
948 self.common.handlers.lock().unwrap().remove(&self.socket_id);
949 }
950
951 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
952 if !self.rx_buffer.is_empty() {
953 return Poll::Ready(Ok(self.rx_buffer.len()));
954 }
955 match self.rx_recv.poll_recv(cx) {
956 Poll::Ready(Some(data)) => {
957 self.rx_buffer.extend_from_slice(&data);
958 return Poll::Ready(Ok(self.rx_buffer.len()));
959 }
960 Poll::Ready(None) => return Poll::Ready(Ok(0)),
961 Poll::Pending => {}
962 }
963 if !self.buffer_recv_with_addr.is_empty() {
964 let total = self
965 .buffer_recv_with_addr
966 .iter()
967 .map(|a| a.data.len())
968 .sum();
969 return Poll::Ready(Ok(total));
970 }
971 match self.rx_recv_with_addr.poll_recv(cx) {
972 Poll::Ready(Some(data)) => self.buffer_recv_with_addr.push_back(data),
973 Poll::Ready(None) => return Poll::Ready(Ok(0)),
974 Poll::Pending => {}
975 }
976 if !self.buffer_accept.is_empty() {
977 return Poll::Ready(Ok(self.buffer_accept.len()));
978 }
979 match self.rx_accept.poll_recv(cx) {
980 Poll::Ready(Some(data)) => self.buffer_accept.push_back(data),
981 Poll::Ready(None) => {}
982 Poll::Pending => {}
983 }
984 Poll::Pending
985 }
986
987 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
988 if self.send_available > 0 {
989 return Poll::Ready(Ok(self.send_available as usize));
990 }
991 match self.rx_sent.poll_recv(cx) {
992 Poll::Ready(Some(amt)) => {
993 self.send_available += amt;
994 return Poll::Ready(Ok(self.send_available as usize));
995 }
996 Poll::Ready(None) => return Poll::Ready(Ok(0)),
997 Poll::Pending => {}
998 }
999 Poll::Pending
1000 }
1001}
1002
1003impl VirtualSocket for RemoteSocket {
1004 fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1005 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl))
1006 }
1007
1008 fn ttl(&self) -> Result<u32> {
1009 match InlineWaker::block_on(self.io_socket(RequestType::GetTtl)) {
1010 ResponseType::Err(err) => Err(err),
1011 ResponseType::Ttl(ttl) => Ok(ttl),
1012 res => {
1013 tracing::debug!("invalid response to get TTL request - {res:?}");
1014 Err(NetworkError::IOError)
1015 }
1016 }
1017 }
1018
1019 fn addr_local(&self) -> Result<SocketAddr> {
1020 match InlineWaker::block_on(self.io_socket(RequestType::GetAddrLocal)) {
1021 ResponseType::Err(err) => Err(err),
1022 ResponseType::SocketAddr(addr) => Ok(addr),
1023 res => {
1024 tracing::debug!("invalid response to address local request - {res:?}");
1025 Err(NetworkError::IOError)
1026 }
1027 }
1028 }
1029
1030 fn status(&self) -> Result<crate::SocketStatus> {
1031 match InlineWaker::block_on(self.io_socket(RequestType::GetStatus)) {
1032 ResponseType::Err(err) => Err(err),
1033 ResponseType::Status(status) => Ok(status),
1034 res => {
1035 tracing::debug!("invalid response to status request - {res:?}");
1036 Err(NetworkError::IOError)
1037 }
1038 }
1039 }
1040
1041 fn set_handler(
1042 &mut self,
1043 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1044 ) -> Result<()> {
1045 self.common
1046 .handlers
1047 .lock()
1048 .unwrap()
1049 .insert(self.socket_id, handler);
1050 Ok(())
1051 }
1052}
1053
1054impl VirtualTcpListener for RemoteSocket {
1055 fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
1056 self.touch_begin_accept()?;
1058 let accepted = if let Some(child) = self.buffer_accept.pop_front() {
1059 child
1060 } else {
1061 self.rx_accept.try_recv().map_err(|err| match err {
1062 TryRecvError::Empty => NetworkError::WouldBlock,
1063 TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1064 })?
1065 };
1066
1067 let mut rx_recv = None;
1071 if let Some((rx_socket, existing_rx_recv)) = self.pending_accept.take() {
1072 if accepted.socket == rx_socket {
1073 rx_recv.replace(existing_rx_recv);
1074 }
1075 }
1076 let rx_recv = match rx_recv {
1077 Some(rx_recv) => rx_recv,
1078 None => {
1079 let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
1080 self.common
1081 .recv_tx
1082 .lock()
1083 .unwrap()
1084 .insert(accepted.socket, tx);
1085 rx_recv
1086 }
1087 };
1088 self.touch_begin_accept().ok();
1089
1090 let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
1091 self.common
1092 .recv_with_addr_tx
1093 .lock()
1094 .unwrap()
1095 .insert(accepted.socket, tx);
1096
1097 let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
1098 self.common
1099 .accept_tx
1100 .lock()
1101 .unwrap()
1102 .insert(accepted.socket, tx);
1103
1104 let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
1105 self.common
1106 .sent_tx
1107 .lock()
1108 .unwrap()
1109 .insert(accepted.socket, tx);
1110
1111 let socket = RemoteSocket {
1112 socket_id: accepted.socket,
1113 common: self.common.clone(),
1114 rx_buffer: BytesMut::new(),
1115 rx_recv,
1116 rx_recv_with_addr,
1117 rx_accept,
1118 rx_sent,
1119 pending_accept: None,
1120 tx_waker: TxWaker::new(&self.common).as_waker(),
1121 buffer_accept: Default::default(),
1122 buffer_recv_with_addr: Default::default(),
1123 send_available: 0,
1124 };
1125 Ok((Box::new(socket), accepted.addr))
1126 }
1127
1128 fn set_handler(
1129 &mut self,
1130 handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1131 ) -> Result<()> {
1132 VirtualSocket::set_handler(self, handler)
1133 }
1134
1135 fn addr_local(&self) -> Result<SocketAddr> {
1136 match InlineWaker::block_on(self.io_socket(RequestType::GetAddrLocal)) {
1137 ResponseType::Err(err) => Err(err),
1138 ResponseType::SocketAddr(addr) => Ok(addr),
1139 res => {
1140 tracing::debug!("invalid response to addr local request - {res:?}");
1141 Err(NetworkError::IOError)
1142 }
1143 }
1144 }
1145
1146 fn set_ttl(&mut self, ttl: u8) -> Result<()> {
1147 self.io_socket_fire_and_forget(RequestType::SetTtl(ttl as u32))
1148 }
1149
1150 fn ttl(&self) -> Result<u8> {
1151 match InlineWaker::block_on(self.io_socket(RequestType::GetTtl)) {
1152 ResponseType::Err(err) => Err(err),
1153 ResponseType::Ttl(val) => Ok(val.try_into().map_err(|_| NetworkError::InvalidData)?),
1154 res => {
1155 tracing::debug!("invalid response to get TTL request - {res:?}");
1156 Err(NetworkError::IOError)
1157 }
1158 }
1159 }
1160}
1161
1162impl VirtualRawSocket for RemoteSocket {
1163 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1164 let mut cx = Context::from_waker(&self.tx_waker);
1165 match self.common.tx.poll_send(
1166 &mut cx,
1167 MessageRequest::Send {
1168 socket: self.socket_id,
1169 data: data.to_vec(),
1170 req_id: None,
1171 },
1172 ) {
1173 Poll::Ready(Ok(())) => Ok(data.len()),
1174 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1175 self.send_available = 0;
1176 Err(NetworkError::WouldBlock)
1177 }
1178 Poll::Ready(Err(err)) => Err(err),
1179 }
1180 }
1181
1182 fn try_flush(&mut self) -> Result<()> {
1183 let mut cx = Context::from_waker(&self.tx_waker);
1184 match self.common.tx.poll_send(
1185 &mut cx,
1186 MessageRequest::Socket {
1187 socket: self.socket_id,
1188 req: RequestType::Flush,
1189 req_id: None,
1190 },
1191 ) {
1192 Poll::Ready(Ok(())) => Ok(()),
1193 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1194 self.send_available = 0;
1195 Err(NetworkError::WouldBlock)
1196 }
1197 Poll::Ready(Err(err)) => Err(err),
1198 }
1199 }
1200
1201 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1202 loop {
1203 if !self.rx_buffer.is_empty() {
1204 let amt = self.rx_buffer.len().min(buf.len());
1205 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1206 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1207 if !peek {
1208 self.rx_buffer.advance(amt);
1209 }
1210 return Ok(amt);
1211 }
1212 match self.rx_recv.try_recv() {
1213 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1214 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1215 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1216 }
1217 }
1218 }
1219
1220 fn set_promiscuous(&mut self, promiscuous: bool) -> Result<()> {
1221 self.io_socket_fire_and_forget(RequestType::SetPromiscuous(promiscuous))
1222 }
1223
1224 fn promiscuous(&self) -> Result<bool> {
1225 match InlineWaker::block_on(self.io_socket(RequestType::GetPromiscuous)) {
1226 ResponseType::Err(err) => Err(err),
1227 ResponseType::Flag(val) => Ok(val),
1228 res => {
1229 tracing::debug!("invalid response to get promiscuous request - {res:?}");
1230 Err(NetworkError::IOError)
1231 }
1232 }
1233 }
1234}
1235
1236impl VirtualConnectionlessSocket for RemoteSocket {
1237 fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1238 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1239 let mut cx = Context::from_waker(&self.tx_waker);
1240 match self.common.tx.poll_send(
1241 &mut cx,
1242 MessageRequest::SendTo {
1243 socket: self.socket_id,
1244 data: data.to_vec(),
1245 addr,
1246 req_id: Some(req_id),
1247 },
1248 ) {
1249 Poll::Ready(Ok(())) => Ok(data.len()),
1250 Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1251 self.send_available = 0;
1252 Err(NetworkError::WouldBlock)
1253 }
1254 Poll::Ready(Err(err)) => Err(err),
1255 }
1256 }
1257
1258 fn try_recv_from(
1259 &mut self,
1260 buf: &mut [std::mem::MaybeUninit<u8>],
1261 peek: bool,
1262 ) -> Result<(usize, SocketAddr)> {
1263 if peek {
1267 return Err(NetworkError::Unsupported);
1268 }
1269
1270 match self.rx_recv_with_addr.try_recv() {
1271 Ok(received) => {
1272 let amt = buf.len().min(received.data.len());
1273 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1274 buf[..amt].copy_from_slice(&received.data[..amt]);
1275 Ok((amt, received.addr))
1276 }
1277 Err(TryRecvError::Disconnected) => Err(NetworkError::ConnectionAborted),
1278 Err(TryRecvError::Empty) => Err(NetworkError::WouldBlock),
1279 }
1280 }
1281}
1282
1283impl VirtualUdpSocket for RemoteSocket {
1284 fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
1285 self.io_socket_fire_and_forget(RequestType::SetBroadcast(broadcast))
1286 }
1287
1288 fn broadcast(&self) -> Result<bool> {
1289 match InlineWaker::block_on(self.io_socket(RequestType::GetBroadcast)) {
1290 ResponseType::Err(err) => Err(err),
1291 ResponseType::Flag(val) => Ok(val),
1292 res => {
1293 tracing::debug!("invalid response to get broadcast request - {res:?}");
1294 Err(NetworkError::IOError)
1295 }
1296 }
1297 }
1298
1299 fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
1300 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV4(val))
1301 }
1302
1303 fn multicast_loop_v4(&self) -> Result<bool> {
1304 match InlineWaker::block_on(self.io_socket(RequestType::GetMulticastLoopV4)) {
1305 ResponseType::Err(err) => Err(err),
1306 ResponseType::Flag(val) => Ok(val),
1307 res => {
1308 tracing::debug!("invalid response to get multicast loop v4 request - {res:?}");
1309 Err(NetworkError::IOError)
1310 }
1311 }
1312 }
1313
1314 fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1315 self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV6(val))
1316 }
1317
1318 fn multicast_loop_v6(&self) -> Result<bool> {
1319 match InlineWaker::block_on(self.io_socket(RequestType::GetMulticastLoopV6)) {
1320 ResponseType::Err(err) => Err(err),
1321 ResponseType::Flag(val) => Ok(val),
1322 res => {
1323 tracing::debug!("invalid response to get multicast loop v6 request - {res:?}");
1324 Err(NetworkError::IOError)
1325 }
1326 }
1327 }
1328
1329 fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1330 self.io_socket_fire_and_forget(RequestType::SetMulticastTtlV4(ttl))
1331 }
1332
1333 fn multicast_ttl_v4(&self) -> Result<u32> {
1334 match InlineWaker::block_on(self.io_socket(RequestType::GetMulticastTtlV4)) {
1335 ResponseType::Err(err) => Err(err),
1336 ResponseType::Ttl(ttl) => Ok(ttl),
1337 res => {
1338 tracing::debug!("invalid response to get multicast TTL v4 request - {res:?}");
1339 Err(NetworkError::IOError)
1340 }
1341 }
1342 }
1343
1344 fn join_multicast_v4(
1345 &mut self,
1346 multiaddr: std::net::Ipv4Addr,
1347 iface: std::net::Ipv4Addr,
1348 ) -> Result<()> {
1349 self.io_socket_fire_and_forget(RequestType::JoinMulticastV4 { multiaddr, iface })
1350 }
1351
1352 fn leave_multicast_v4(
1353 &mut self,
1354 multiaddr: std::net::Ipv4Addr,
1355 iface: std::net::Ipv4Addr,
1356 ) -> Result<()> {
1357 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV4 { multiaddr, iface })
1358 }
1359
1360 fn join_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1361 self.io_socket_fire_and_forget(RequestType::JoinMulticastV6 { multiaddr, iface })
1362 }
1363
1364 fn leave_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1365 self.io_socket_fire_and_forget(RequestType::LeaveMulticastV6 { multiaddr, iface })
1366 }
1367
1368 fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1369 match InlineWaker::block_on(self.io_socket(RequestType::GetAddrPeer)) {
1370 ResponseType::Err(err) => Err(err),
1371 ResponseType::None => Ok(None),
1372 ResponseType::SocketAddr(addr) => Ok(Some(addr)),
1373 res => {
1374 tracing::debug!("invalid response to addr peer request - {res:?}");
1375 Err(NetworkError::IOError)
1376 }
1377 }
1378 }
1379}
1380
1381impl VirtualIcmpSocket for RemoteSocket {}
1382
1383impl VirtualConnectedSocket for RemoteSocket {
1384 fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
1385 self.io_socket_fire_and_forget(RequestType::SetLinger(linger))
1386 }
1387
1388 fn linger(&self) -> Result<Option<Duration>> {
1389 match InlineWaker::block_on(self.io_socket(RequestType::GetLinger)) {
1390 ResponseType::Err(err) => Err(err),
1391 ResponseType::None => Ok(None),
1392 ResponseType::Duration(val) => Ok(Some(val)),
1393 res => {
1394 tracing::debug!("invalid response to get linger request - {res:?}");
1395 Err(NetworkError::IOError)
1396 }
1397 }
1398 }
1399
1400 fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1401 let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1402 let mut cx = Context::from_waker(&self.tx_waker);
1403 match self.common.tx.poll_send(
1404 &mut cx,
1405 MessageRequest::Send {
1406 socket: self.socket_id,
1407 data: data.to_vec(),
1408 req_id: Some(req_id),
1409 },
1410 ) {
1411 Poll::Ready(Ok(())) => Ok(data.len()),
1412 Poll::Ready(Err(err)) => Err(err),
1413 Poll::Pending => Err(NetworkError::WouldBlock),
1414 }
1415 }
1416
1417 fn try_flush(&mut self) -> Result<()> {
1418 let mut cx = Context::from_waker(&self.tx_waker);
1419 match self.common.tx.poll_send(
1420 &mut cx,
1421 MessageRequest::Socket {
1422 socket: self.socket_id,
1423 req: RequestType::Flush,
1424 req_id: None,
1425 },
1426 ) {
1427 Poll::Ready(Ok(())) => Ok(()),
1428 Poll::Ready(Err(err)) => Err(err),
1429 Poll::Pending => Err(NetworkError::WouldBlock),
1430 }
1431 }
1432
1433 fn close(&mut self) -> Result<()> {
1434 self.io_socket_fire_and_forget(RequestType::Close)
1435 }
1436
1437 fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1438 loop {
1439 if !self.rx_buffer.is_empty() {
1440 let amt = self.rx_buffer.len().min(buf.len());
1441 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1442 buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1443 if !peek {
1444 self.rx_buffer.advance(amt);
1445 }
1446 return Ok(amt);
1447 }
1448 match self.rx_recv.try_recv() {
1449 Ok(data) => self.rx_buffer.extend_from_slice(&data),
1450 Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1451 Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1452 }
1453 }
1454 }
1455}
1456
1457impl VirtualTcpSocket for RemoteSocket {
1458 fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
1459 self.io_socket_fire_and_forget(RequestType::SetRecvBufSize(size as u64))
1460 }
1461
1462 fn recv_buf_size(&self) -> Result<usize> {
1463 match InlineWaker::block_on(self.io_socket(RequestType::GetRecvBufSize)) {
1464 ResponseType::Err(err) => Err(err),
1465 ResponseType::Amount(amt) => Ok(amt.try_into().map_err(|_| NetworkError::IOError)?),
1466 res => {
1467 tracing::debug!("invalid response to get recv buf size request - {res:?}");
1468 Err(NetworkError::IOError)
1469 }
1470 }
1471 }
1472
1473 fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
1474 self.io_socket_fire_and_forget(RequestType::SetSendBufSize(size as u64))
1475 }
1476
1477 fn send_buf_size(&self) -> Result<usize> {
1478 match InlineWaker::block_on(self.io_socket(RequestType::GetSendBufSize)) {
1479 ResponseType::Err(err) => Err(err),
1480 ResponseType::Amount(val) => Ok(val.try_into().map_err(|_| NetworkError::IOError)?),
1481 res => {
1482 tracing::debug!("invalid response to get send buf size request - {res:?}");
1483 Err(NetworkError::IOError)
1484 }
1485 }
1486 }
1487
1488 fn set_nodelay(&mut self, reuse: bool) -> Result<()> {
1489 self.io_socket_fire_and_forget(RequestType::SetNoDelay(reuse))
1490 }
1491
1492 fn nodelay(&self) -> Result<bool> {
1493 match InlineWaker::block_on(self.io_socket(RequestType::GetNoDelay)) {
1494 ResponseType::Err(err) => Err(err),
1495 ResponseType::Flag(val) => Ok(val),
1496 res => {
1497 tracing::debug!("invalid response to get nodelay request - {res:?}");
1498 Err(NetworkError::IOError)
1499 }
1500 }
1501 }
1502
1503 fn set_keepalive(&mut self, keep_alive: bool) -> Result<()> {
1504 self.io_socket_fire_and_forget(RequestType::SetKeepAlive(keep_alive))
1505 }
1506
1507 fn keepalive(&self) -> Result<bool> {
1508 match InlineWaker::block_on(self.io_socket(RequestType::GetKeepAlive)) {
1509 ResponseType::Err(err) => Err(err),
1510 ResponseType::Flag(val) => Ok(val),
1511 res => {
1512 tracing::debug!("invalid response to get nodelay request - {res:?}");
1513 Err(NetworkError::IOError)
1514 }
1515 }
1516 }
1517
1518 fn set_dontroute(&mut self, dont_route: bool) -> Result<()> {
1519 self.io_socket_fire_and_forget(RequestType::SetDontRoute(dont_route))
1520 }
1521
1522 fn dontroute(&self) -> Result<bool> {
1523 match InlineWaker::block_on(self.io_socket(RequestType::GetDontRoute)) {
1524 ResponseType::Err(err) => Err(err),
1525 ResponseType::Flag(val) => Ok(val),
1526 res => {
1527 tracing::debug!("invalid response to get nodelay request - {res:?}");
1528 Err(NetworkError::IOError)
1529 }
1530 }
1531 }
1532
1533 fn addr_peer(&self) -> Result<SocketAddr> {
1534 match InlineWaker::block_on(self.io_socket(RequestType::GetAddrPeer)) {
1535 ResponseType::Err(err) => Err(err),
1536 ResponseType::SocketAddr(addr) => Ok(addr),
1537 res => {
1538 tracing::debug!("invalid response to addr peer request - {res:?}");
1539 Err(NetworkError::IOError)
1540 }
1541 }
1542 }
1543
1544 fn shutdown(&mut self, how: std::net::Shutdown) -> Result<()> {
1545 let shutdown = match how {
1546 std::net::Shutdown::Read => meta::Shutdown::Read,
1547 std::net::Shutdown::Write => meta::Shutdown::Write,
1548 std::net::Shutdown::Both => meta::Shutdown::Both,
1549 };
1550 self.io_socket_fire_and_forget(RequestType::Shutdown(shutdown))
1551 }
1552
1553 fn is_closed(&self) -> bool {
1554 match InlineWaker::block_on(self.io_socket(RequestType::IsClosed)) {
1555 ResponseType::Flag(val) => val,
1556 _ => false,
1557 }
1558 }
1559}