virtual_net/
client.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::net::IpAddr;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::sync::atomic::AtomicU64;
10use std::sync::atomic::Ordering;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16use std::time::Duration;
17
18use bytes::Buf;
19use bytes::BytesMut;
20use futures_util::Sink;
21use futures_util::Stream;
22use futures_util::StreamExt;
23use futures_util::future::BoxFuture;
24use futures_util::stream::FuturesOrdered;
25#[cfg(feature = "hyper")]
26use hyper_util::rt::tokio::TokioIo;
27use tokio::io::AsyncRead;
28use tokio::io::AsyncWrite;
29use tokio::sync::mpsc;
30use tokio::sync::mpsc::error::TryRecvError;
31use tokio::sync::mpsc::error::TrySendError;
32use tokio_serde::SymmetricallyFramed;
33use tokio_serde::formats::SymmetricalBincode;
34#[cfg(feature = "cbor")]
35use tokio_serde::formats::SymmetricalCbor;
36#[cfg(feature = "json")]
37use tokio_serde::formats::SymmetricalJson;
38#[cfg(feature = "messagepack")]
39use tokio_serde::formats::SymmetricalMessagePack;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use tokio_util::codec::LengthDelimitedCodec;
43use virtual_mio::InterestType;
44use virtual_mio::block_on;
45
46use crate::IpCidr;
47use crate::IpRoute;
48use crate::NetworkError;
49use crate::StreamSecurity;
50use crate::VirtualConnectedSocket;
51use crate::VirtualConnectionlessSocket;
52use crate::VirtualIcmpSocket;
53use crate::VirtualIoSource;
54use crate::VirtualNetworking;
55use crate::VirtualRawSocket;
56use crate::VirtualSocket;
57use crate::VirtualTcpBoundSocket;
58use crate::VirtualTcpListener;
59use crate::VirtualTcpSocket;
60use crate::VirtualUdpSocket;
61use crate::meta;
62use crate::meta::FrameSerializationFormat;
63use crate::meta::RequestType;
64use crate::meta::ResponseType;
65use crate::meta::SocketId;
66use crate::meta::{MessageRequest, MessageResponse};
67
68use crate::Result;
69use crate::rx_tx::RemoteRx;
70use crate::rx_tx::RemoteTx;
71use crate::rx_tx::RemoteTxWakers;
72
73#[derive(Debug, Clone)]
74pub struct RemoteNetworkingClient {
75    common: Arc<RemoteCommon>,
76}
77
78impl RemoteNetworkingClient {
79    fn new(
80        tx: RemoteTx<MessageRequest>,
81        rx: RemoteRx<MessageResponse>,
82        rx_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
83    ) -> (Self, RemoteNetworkingClientDriver) {
84        let common = RemoteCommon {
85            tx,
86            rx: Mutex::new(rx),
87            request_seed: AtomicU64::new(1),
88            requests: Default::default(),
89            socket_seed: AtomicU64::new(1),
90            recv_tx: Default::default(),
91            recv_with_addr_tx: Default::default(),
92            accept_tx: Default::default(),
93            sent_tx: Default::default(),
94            handlers: Default::default(),
95            stall: Default::default(),
96        };
97        let common = Arc::new(common);
98
99        let driver = RemoteNetworkingClientDriver {
100            more_work: rx_work,
101            tasks: Default::default(),
102            common: common.clone(),
103        };
104        let networking = Self { common };
105
106        (networking, driver)
107    }
108
109    /// Creates a new interface on the remote location using
110    /// a unique interface ID and a pair of channels
111    pub fn new_from_mpsc(
112        tx: mpsc::Sender<MessageRequest>,
113        rx: mpsc::Receiver<MessageResponse>,
114    ) -> (Self, RemoteNetworkingClientDriver) {
115        let (tx_work, rx_work) = mpsc::unbounded_channel();
116        let tx_wakers = RemoteTxWakers::default();
117
118        let tx = RemoteTx::Mpsc {
119            tx,
120            work: tx_work,
121            wakers: tx_wakers.clone(),
122        };
123        let rx = RemoteRx::Mpsc {
124            rx,
125            wakers: tx_wakers,
126        };
127
128        Self::new(tx, rx, rx_work)
129    }
130
131    /// Creates a new interface on the remote location using
132    /// a unique interface ID and a pair of channels
133    ///
134    /// This version will run the async read and write operations
135    /// only the driver (this is needed for mixed runtimes)
136    pub fn new_from_async_io<TX, RX>(
137        tx: TX,
138        rx: RX,
139        format: FrameSerializationFormat,
140    ) -> (Self, RemoteNetworkingClientDriver)
141    where
142        TX: AsyncWrite + Send + 'static,
143        RX: AsyncRead + Send + 'static,
144    {
145        let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
146        let tx: Pin<Box<dyn Sink<MessageRequest, Error = std::io::Error> + Send + 'static>> =
147            match format {
148                FrameSerializationFormat::Bincode => {
149                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
150                }
151                #[cfg(feature = "json")]
152                FrameSerializationFormat::Json => {
153                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
154                }
155                #[cfg(feature = "messagepack")]
156                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
157                    tx,
158                    SymmetricalMessagePack::default(),
159                )),
160                #[cfg(feature = "cbor")]
161                FrameSerializationFormat::Cbor => {
162                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
163                }
164            };
165
166        let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
167        let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageResponse>> + Send + 'static>> =
168            match format {
169                FrameSerializationFormat::Bincode => {
170                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
171                }
172                #[cfg(feature = "json")]
173                FrameSerializationFormat::Json => {
174                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
175                }
176                #[cfg(feature = "messagepack")]
177                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
178                    rx,
179                    SymmetricalMessagePack::default(),
180                )),
181                #[cfg(feature = "cbor")]
182                FrameSerializationFormat::Cbor => {
183                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
184                }
185            };
186
187        let (tx_work, rx_work) = mpsc::unbounded_channel();
188        let tx_wakers = RemoteTxWakers::default();
189
190        let tx = RemoteTx::Stream {
191            tx: Arc::new(tokio::sync::Mutex::new(tx)),
192            work: tx_work,
193            wakers: tx_wakers,
194        };
195        let rx = RemoteRx::Stream { rx };
196
197        Self::new(tx, rx, rx_work)
198    }
199
200    /// Creates a new interface on the remote location using
201    /// a unique interface ID and a pair of channels
202    #[cfg(feature = "hyper")]
203    pub fn new_from_hyper_ws_io(
204        tx: futures_util::stream::SplitSink<
205            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
206            hyper_tungstenite::tungstenite::Message,
207        >,
208        rx: futures_util::stream::SplitStream<
209            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
210        >,
211        format: FrameSerializationFormat,
212    ) -> (Self, RemoteNetworkingClientDriver) {
213        let (tx_work, rx_work) = mpsc::unbounded_channel();
214
215        let tx = RemoteTx::HyperWebSocket {
216            tx: Arc::new(tokio::sync::Mutex::new(tx)),
217            work: tx_work,
218            wakers: RemoteTxWakers::default(),
219            format,
220        };
221        let rx = RemoteRx::HyperWebSocket { rx, format };
222        Self::new(tx, rx, rx_work)
223    }
224
225    /// Creates a new interface on the remote location using
226    /// a unique interface ID and a pair of channels
227    #[cfg(feature = "tokio-tungstenite")]
228    pub fn new_from_tokio_ws_io(
229        tx: futures_util::stream::SplitSink<
230            tokio_tungstenite::WebSocketStream<
231                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
232            >,
233            tokio_tungstenite::tungstenite::Message,
234        >,
235        rx: futures_util::stream::SplitStream<
236            tokio_tungstenite::WebSocketStream<
237                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
238            >,
239        >,
240        format: FrameSerializationFormat,
241    ) -> (Self, RemoteNetworkingClientDriver) {
242        let (tx_work, rx_work) = mpsc::unbounded_channel();
243
244        let tx = RemoteTx::TokioWebSocket {
245            tx: Arc::new(tokio::sync::Mutex::new(tx)),
246            work: tx_work,
247            wakers: RemoteTxWakers::default(),
248            format,
249        };
250        let rx = RemoteRx::TokioWebSocket { rx, format };
251        Self::new(tx, rx, rx_work)
252    }
253
254    fn new_socket(&self, id: SocketId) -> RemoteSocket {
255        let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
256        self.common.recv_tx.lock().unwrap().insert(id, tx);
257
258        let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
259        self.common.recv_with_addr_tx.lock().unwrap().insert(id, tx);
260
261        let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
262        self.common.accept_tx.lock().unwrap().insert(id, tx);
263
264        let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
265        self.common.sent_tx.lock().unwrap().insert(id, tx);
266
267        RemoteSocket {
268            socket_id: id,
269            common: self.common.clone(),
270            rx_buffer: BytesMut::new(),
271            rx_recv,
272            rx_recv_with_addr,
273            rx_accept,
274            rx_sent,
275            tx_waker: TxWaker::new(&self.common).as_waker(),
276            pending_accept: None,
277            buffer_accept: Default::default(),
278            buffer_recv_with_addr: Default::default(),
279            send_available: 0,
280            owns_socket_bindings: true,
281        }
282    }
283}
284
pin_project_lite::pin_project! {
    // Future that drives a `RemoteNetworkingClient`: it reads messages from
    // the server, routes them to the per-socket channels and interest
    // handlers in `common`, and runs queued background futures. Replies to
    // requests are only delivered while this future is being polled.
    pub struct RemoteNetworkingClientDriver {
        // State shared with the client handle and its sockets.
        common: Arc<RemoteCommon>,
        // Receives background futures submitted through the transport's
        // `work` sender; they are moved into `tasks` when polled.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // Background futures currently in flight; results are yielded in the
        // order the futures were queued.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
    }
}
293
impl Future for RemoteNetworkingClientDriver {
    type Output = ();

    /// Runs queued background futures and dispatches messages from the server.
    ///
    /// Each pass of the loop:
    /// 1. moves newly submitted futures from `more_work` into `tasks`;
    /// 2. polls `tasks`, starting a new pass whenever one of them completes;
    /// 3. reads the next server message and routes it to the per-socket
    ///    channels / interest handlers held in `common`.
    ///
    /// Resolves to `()` once the server message stream yields `None`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // This guard is held while the pipeline is not currently stalled by
        // back pressure. It is only acquired while background tasks are
        // pending, and because it is a local it is always released when this
        // call to `poll` returns.
        let mut not_stalled_guard = None;

        // We loop until the waker is registered with the receiving stream
        // and all the background tasks
        loop {
            // Background tasks are sent to this driver in certain circumstances
            // (through the `work` sender handed to the transport).
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            // Background work basically stalls the stream until it's all processed
            // which makes the back pressure system work properly
            match self.tasks.poll_next_unpin(cx) {
                // A task finished: start a new pass to pick up more work.
                Poll::Ready(Some(_)) => continue,
                // No tasks left, so nothing to stall on: release the guard.
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                // Tasks are still pending: keep reading from the server only
                // if the stall lock is free. Otherwise stop here; polling
                // `tasks` above registered `cx`'s waker with them.
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // We grab the next message sent by the server to us
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            return match msg {
                Poll::Ready(Some(msg)) => {
                    match msg {
                        MessageResponse::Recv { socket_id, data } => {
                            // Messages for sockets without a registered
                            // channel are dropped.
                            let tx = {
                                let guard = self.common.recv_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => {
                                        continue;
                                    }
                                }
                            };
                            let common = self.common.clone();
                            // Deliver in a background task so a full channel
                            // stalls the pipeline instead of blocking `poll`;
                            // readers are notified only after delivery.
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(data).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::RecvWithAddr {
                            socket_id,
                            data,
                            addr,
                        } => {
                            // Same pattern as `Recv`, carrying the address too.
                            let tx = {
                                let guard = self.common.recv_with_addr_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(DataWithAddr { data, addr }).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Sent {
                            socket_id, amount, ..
                        } => {
                            let tx = {
                                let guard = self.common.sent_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(amount).await.ok();
                            }));
                            // NOTE(review): unlike `Recv`, the interest is raised
                            // before the queued task has delivered `amount` —
                            // confirm this ordering is intended.
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Writable)
                            }
                        }
                        // The error value itself is not forwarded; only an
                        // interest is raised on the socket's handler.
                        MessageResponse::SendError {
                            socket_id, error, ..
                        } => match &error {
                            NetworkError::ConnectionAborted
                            | NetworkError::ConnectionReset
                            | NetworkError::BrokenPipe => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Closed)
                                }
                            }
                            _ => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Writable)
                                }
                            }
                        },
                        MessageResponse::FinishAccept {
                            socket_id,
                            child_id,
                            addr,
                        } => {
                            // The accept channel is looked up inside the task;
                            // the Readable interest is raised even if no
                            // channel is registered for `socket_id`.
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                let tx = common.accept_tx.lock().unwrap().get(&socket_id).cloned();
                                if let Some(tx) = tx {
                                    tx.send(SocketWithAddr {
                                        socket: child_id,
                                        addr,
                                    })
                                    .await
                                    .ok();
                                }

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Closed { socket_id } => {
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Closed)
                            }
                        }
                        MessageResponse::ResponseToRequest { req_id, res } => {
                            // Hand the reply to the waiting requester. Replies
                            // for ids no longer registered, and delivery
                            // failures, are silently dropped.
                            let mut requests = self.common.requests.lock().unwrap();
                            if let Some(request) = requests.remove(&req_id) {
                                request.try_send(res).ok();
                            }
                        }
                    }
                    continue;
                }
                // The server stream has ended: the driver is finished.
                Poll::Ready(None) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
465
466#[derive(Debug)]
467struct TxWaker {
468    common: Arc<RemoteCommon>,
469}
470impl TxWaker {
471    pub fn new(common: &Arc<RemoteCommon>) -> Arc<Self> {
472        Arc::new(Self {
473            common: common.clone(),
474        })
475    }
476
477    fn wake_now(&self) {
478        let mut guard = self.common.handlers.lock().unwrap();
479        for handler in guard.values_mut() {
480            handler.push_interest(InterestType::Writable);
481        }
482    }
483
484    pub fn as_waker(self: &Arc<Self>) -> Waker {
485        let s: *const Self = Arc::into_raw(Arc::clone(self));
486        let raw_waker = RawWaker::new(s as *const (), &VTABLE);
487        unsafe { Waker::from_raw(raw_waker) }
488    }
489}
490
491fn tx_waker_wake(s: &TxWaker) {
492    let waker_arc = unsafe { Arc::from_raw(s) };
493    waker_arc.wake_now();
494}
495
496fn tx_waker_clone(s: &TxWaker) -> RawWaker {
497    let arc = unsafe { Arc::from_raw(s) };
498    std::mem::forget(arc.clone());
499    RawWaker::new(Arc::into_raw(arc) as *const (), &VTABLE)
500}
501
502const VTABLE: RawWakerVTable = unsafe {
503    RawWakerVTable::new(
504        |s| tx_waker_clone(&*(s as *const TxWaker)),  // clone
505        |s| tx_waker_wake(&*(s as *const TxWaker)),   // wake
506        |s| (*(s as *const TxWaker)).wake_now(),      // wake by ref (don't decrease refcount)
507        |s| drop(Arc::from_raw(s as *const TxWaker)), // decrease refcount
508    )
509};
510
511#[derive(Debug)]
512struct RequestTx {
513    tx: mpsc::Sender<ResponseType>,
514}
515impl RequestTx {
516    pub fn try_send(self, msg: ResponseType) -> Result<()> {
517        match self.tx.try_send(msg) {
518            Ok(()) => Ok(()),
519            Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
520            Err(TrySendError::Full(_)) => Err(NetworkError::WouldBlock),
521        }
522    }
523}
524
/// Payload of a `RecvWithAddr` server message, forwarded by the driver into
/// the socket's `recv_with_addr` channel.
#[derive(Debug)]
struct DataWithAddr {
    /// The received bytes.
    pub data: Vec<u8>,
    /// The address reported alongside the data.
    pub addr: SocketAddr,
}
/// Result of a `FinishAccept` server message, forwarded by the driver into
/// the listening socket's accept channel.
#[derive(Debug)]
struct SocketWithAddr {
    /// Id of the newly accepted child socket (`child_id` in the message).
    pub socket: SocketId,
    /// The address reported with the accepted child.
    pub addr: SocketAddr,
}
/// Per-socket lookup table keyed by `SocketId`.
type SocketMap<T> = HashMap<SocketId, T>;
536
/// State shared between the client handle, its sockets and the driver.
#[derive(derive_more::Debug)]
struct RemoteCommon {
    /// Outgoing half of the transport, used to send requests to the server.
    #[debug(ignore)]
    tx: RemoteTx<MessageRequest>,
    /// Incoming half of the transport; polled by the driver for server messages.
    #[debug(ignore)]
    rx: Mutex<RemoteRx<MessageResponse>>,
    /// Counter used to allocate request ids (e.g. in `io_iface`).
    request_seed: AtomicU64,
    /// Reply slots of in-flight requests, keyed by request id. The driver
    /// removes an entry when its `ResponseToRequest` arrives.
    requests: Mutex<HashMap<u64, RequestTx>>,
    /// Counter used to allocate the socket ids proposed to the server.
    socket_seed: AtomicU64,
    /// Per-socket channels that receive `Recv` payloads.
    recv_tx: Mutex<SocketMap<mpsc::Sender<Vec<u8>>>>,
    /// Per-socket channels that receive `RecvWithAddr` payloads.
    recv_with_addr_tx: Mutex<SocketMap<mpsc::Sender<DataWithAddr>>>,
    /// Per-socket channels that receive `FinishAccept` results.
    accept_tx: Mutex<SocketMap<mpsc::Sender<SocketWithAddr>>>,
    /// Per-socket channels that receive `Sent` byte counts.
    sent_tx: Mutex<SocketMap<mpsc::Sender<u64>>>,
    /// Per-socket interest handlers notified of readable/writable/closed events.
    #[debug(ignore)]
    handlers: Mutex<SocketMap<Box<dyn virtual_mio::InterestHandler + Send + Sync>>>,

    // Back-pressure lock. While background tasks are pending, the driver only
    // reads the next server message if it can take this lock (`try_lock`).
    // Whoever else holds it therefore pauses reading, which keeps the driver's
    // task list from growing without bound.
    // NOTE(review): the other holders of this lock are not visible here.
    stall: Arc<tokio::sync::Mutex<()>>,
}
557
558impl RemoteCommon {
559    async fn io_iface(&self, req: RequestType) -> ResponseType {
560        let req_id = self.request_seed.fetch_add(1, Ordering::SeqCst);
561        let mut req_rx = {
562            let (tx, rx) = mpsc::channel(1);
563            let mut guard = self.requests.lock().unwrap();
564            guard.insert(req_id, RequestTx { tx });
565            rx
566        };
567        if let Err(err) = self
568            .tx
569            .send(MessageRequest::Interface {
570                req_id: Some(req_id),
571                req,
572            })
573            .await
574        {
575            return ResponseType::Err(err);
576        };
577        req_rx.recv().await.unwrap()
578    }
579
580    fn io_iface_fire_and_forget(&self, req: RequestType) -> Result<()> {
581        self.tx
582            .send_with_driver(MessageRequest::Interface { req_id: None, req })
583    }
584}
585
586#[async_trait::async_trait]
587impl VirtualNetworking for RemoteNetworkingClient {
588    async fn bridge(
589        &self,
590        network: &str,
591        access_token: &str,
592        security: StreamSecurity,
593    ) -> Result<()> {
594        match self
595            .common
596            .io_iface(RequestType::Bridge {
597                network: network.to_string(),
598                access_token: access_token.to_string(),
599                security,
600            })
601            .await
602        {
603            ResponseType::Err(err) => Err(err),
604            ResponseType::None => Ok(()),
605            res => {
606                tracing::debug!("invalid response to bridge request - {res:?}");
607                Err(NetworkError::IOError)
608            }
609        }
610    }
611
612    async fn unbridge(&self) -> Result<()> {
613        match self.common.io_iface(RequestType::Unbridge).await {
614            ResponseType::Err(err) => Err(err),
615            ResponseType::None => Ok(()),
616            res => {
617                tracing::debug!("invalid response to unbridge request - {res:?}");
618                Err(NetworkError::IOError)
619            }
620        }
621    }
622
623    async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>> {
624        match self.common.io_iface(RequestType::DhcpAcquire).await {
625            ResponseType::Err(err) => Err(err),
626            ResponseType::IpAddressList(ips) => Ok(ips),
627            res => {
628                tracing::debug!("invalid response to DHCP acquire request - {res:?}");
629                Err(NetworkError::IOError)
630            }
631        }
632    }
633
634    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<()> {
635        self.common
636            .io_iface_fire_and_forget(RequestType::IpAdd { ip, prefix })
637    }
638
639    async fn ip_remove(&self, ip: IpAddr) -> Result<()> {
640        self.common
641            .io_iface_fire_and_forget(RequestType::IpRemove(ip))
642    }
643
644    async fn ip_clear(&self) -> Result<()> {
645        self.common.io_iface_fire_and_forget(RequestType::IpClear)
646    }
647
648    async fn ip_list(&self) -> Result<Vec<IpCidr>> {
649        match self.common.io_iface(RequestType::GetIpList).await {
650            ResponseType::Err(err) => Err(err),
651            ResponseType::CidrList(routes) => Ok(routes),
652            res => {
653                tracing::debug!("invalid response to IP list request - {res:?}");
654                Err(NetworkError::IOError)
655            }
656        }
657    }
658
659    async fn mac(&self) -> Result<[u8; 6]> {
660        match self.common.io_iface(RequestType::GetMac).await {
661            ResponseType::Err(err) => Err(err),
662            ResponseType::Mac(mac) => Ok(mac),
663            res => {
664                tracing::debug!("invalid response to MAC request - {res:?}");
665                Err(NetworkError::IOError)
666            }
667        }
668    }
669
670    async fn gateway_set(&self, ip: IpAddr) -> Result<()> {
671        self.common
672            .io_iface_fire_and_forget(RequestType::GatewaySet(ip))
673    }
674
675    async fn route_add(
676        &self,
677        cidr: IpCidr,
678        via_router: IpAddr,
679        preferred_until: Option<Duration>,
680        expires_at: Option<Duration>,
681    ) -> Result<()> {
682        self.common.io_iface_fire_and_forget(RequestType::RouteAdd {
683            cidr,
684            via_router,
685            preferred_until,
686            expires_at,
687        })
688    }
689
690    async fn route_remove(&self, cidr: IpAddr) -> Result<()> {
691        self.common
692            .io_iface_fire_and_forget(RequestType::RouteRemove(cidr))
693    }
694
695    async fn route_clear(&self) -> Result<()> {
696        self.common
697            .io_iface_fire_and_forget(RequestType::RouteClear)
698    }
699
700    async fn route_list(&self) -> Result<Vec<IpRoute>> {
701        match self.common.io_iface(RequestType::GetRouteList).await {
702            ResponseType::Err(err) => Err(err),
703            ResponseType::RouteList(routes) => Ok(routes),
704            res => {
705                tracing::debug!("invalid response to route list request - {res:?}");
706                Err(NetworkError::IOError)
707            }
708        }
709    }
710
711    async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>> {
712        let socket_id: SocketId = self
713            .common
714            .socket_seed
715            .fetch_add(1, Ordering::SeqCst)
716            .into();
717        match self.common.io_iface(RequestType::BindRaw(socket_id)).await {
718            ResponseType::Err(err) => Err(err),
719            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
720            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
721            res => {
722                tracing::debug!("invalid response to bind RAw request - {res:?}");
723                Err(NetworkError::IOError)
724            }
725        }
726    }
727
728    async fn listen_tcp(
729        &self,
730        addr: SocketAddr,
731        only_v6: bool,
732        reuse_port: bool,
733        reuse_addr: bool,
734    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
735        let socket_id: SocketId = self
736            .common
737            .socket_seed
738            .fetch_add(1, Ordering::SeqCst)
739            .into();
740        match self
741            .common
742            .io_iface(RequestType::ListenTcp {
743                socket_id,
744                addr,
745                only_v6,
746                reuse_port,
747                reuse_addr,
748            })
749            .await
750        {
751            ResponseType::Err(err) => Err(err),
752            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
753            ResponseType::Socket(socket_id) => {
754                let mut socket = self.new_socket(socket_id);
755                socket.touch_begin_accept().ok();
756                Ok(Box::new(socket))
757            }
758            res => {
759                tracing::debug!("invalid response to listen TCP request - {res:?}");
760                Err(NetworkError::IOError)
761            }
762        }
763    }
764
765    async fn bind_tcp(
766        &self,
767        addr: SocketAddr,
768        only_v6: bool,
769        reuse_port: bool,
770        reuse_addr: bool,
771    ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
772        let socket_id: SocketId = self
773            .common
774            .socket_seed
775            .fetch_add(1, Ordering::SeqCst)
776            .into();
777        match self
778            .common
779            .io_iface(RequestType::BindTcp {
780                socket_id,
781                addr,
782                only_v6,
783                reuse_port,
784                reuse_addr,
785            })
786            .await
787        {
788            ResponseType::Err(err) => Err(err),
789            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
790            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
791            res => {
792                tracing::debug!("invalid response to bind TCP request - {res:?}");
793                Err(NetworkError::IOError)
794            }
795        }
796    }
797
798    async fn bind_udp(
799        &self,
800        addr: SocketAddr,
801        reuse_port: bool,
802        reuse_addr: bool,
803    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
804        let socket_id: SocketId = self
805            .common
806            .socket_seed
807            .fetch_add(1, Ordering::SeqCst)
808            .into();
809        match self
810            .common
811            .io_iface(RequestType::BindUdp {
812                socket_id,
813                addr,
814                reuse_port,
815                reuse_addr,
816            })
817            .await
818        {
819            ResponseType::Err(err) => Err(err),
820            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
821            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
822            res => {
823                tracing::debug!("invalid response to bind UDP request - {res:?}");
824                Err(NetworkError::IOError)
825            }
826        }
827    }
828
829    async fn bind_icmp(&self, addr: IpAddr) -> Result<Box<dyn VirtualIcmpSocket + Sync>> {
830        let socket_id: SocketId = self
831            .common
832            .socket_seed
833            .fetch_add(1, Ordering::SeqCst)
834            .into();
835        match self
836            .common
837            .io_iface(RequestType::BindIcmp { socket_id, addr })
838            .await
839        {
840            ResponseType::Err(err) => Err(err),
841            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
842            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
843            res => {
844                tracing::debug!("invalid response to bind ICMP request - {res:?}");
845                Err(NetworkError::IOError)
846            }
847        }
848    }
849
850    async fn connect_tcp(
851        &self,
852        addr: SocketAddr,
853        peer: SocketAddr,
854    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
855        let socket_id: SocketId = self
856            .common
857            .socket_seed
858            .fetch_add(1, Ordering::SeqCst)
859            .into();
860        match self
861            .common
862            .io_iface(RequestType::ConnectTcp {
863                socket_id,
864                addr,
865                peer,
866            })
867            .await
868        {
869            ResponseType::Err(err) => Err(err),
870            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
871            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
872            res => {
873                tracing::debug!("invalid response to connect TCP request - {res:?}");
874                Err(NetworkError::IOError)
875            }
876        }
877    }
878
879    async fn resolve(
880        &self,
881        host: &str,
882        port: Option<u16>,
883        dns_server: Option<IpAddr>,
884    ) -> Result<Vec<IpAddr>> {
885        match self
886            .common
887            .io_iface(RequestType::Resolve {
888                host: host.to_string(),
889                port,
890                dns_server,
891            })
892            .await
893        {
894            ResponseType::Err(err) => Err(err),
895            ResponseType::IpAddressList(ips) => Ok(ips),
896            res => {
897                tracing::debug!("invalid response to resolve request - {res:?}");
898                Err(NetworkError::IOError)
899            }
900        }
901    }
902}
903
/// Client-side handle for a socket that lives on the remote networking server.
///
/// Requests are forwarded through the shared `RemoteCommon`. Incoming data, accepted
/// connections and send notifications arrive on the per-socket channels below. The senders
/// for these channels are registered in `RemoteCommon`'s tables keyed by `socket_id`.
#[derive(Debug)]
struct RemoteSocket {
    // Id that keys this socket's entries in the shared tables and tags outgoing requests.
    socket_id: SocketId,
    // State shared with the networking client (sender, channel tables, id seeds, handlers).
    common: Arc<RemoteCommon>,
    // Stream bytes already pulled from `rx_recv` but not yet consumed by `try_recv`.
    rx_buffer: BytesMut,
    // Stream data received for this socket.
    rx_recv: mpsc::Receiver<Vec<u8>>,
    // Datagrams received for this socket, together with their source address.
    rx_recv_with_addr: mpsc::Receiver<DataWithAddr>,
    // Waker passed to `poll_send` by the non-async `try_*` methods.
    tx_waker: Waker,
    // Connections accepted by the server on this (listening) socket.
    rx_accept: mpsc::Receiver<SocketWithAddr>,
    // Amounts reported back as sent; added to `send_available` in `poll_write_ready`.
    rx_sent: mpsc::Receiver<u64>,
    // Child socket id and receive channel for the `BeginAccept` request currently outstanding.
    pending_accept: Option<(SocketId, mpsc::Receiver<Vec<u8>>)>,
    // Datagrams pulled from `rx_recv_with_addr` that have not been returned yet.
    buffer_recv_with_addr: VecDeque<DataWithAddr>,
    // Accepted connections pulled from `rx_accept` (e.g. by `poll_read_ready`) not yet returned.
    buffer_accept: VecDeque<SocketWithAddr>,
    // Send credit reported by the server; cleared when a send hits back-pressure.
    send_available: u64,
    // When true, dropping this handle sends `Close` and unregisters the channels above.
    owns_socket_bindings: bool,
}
920impl Drop for RemoteSocket {
921    fn drop(&mut self) {
922        if !self.owns_socket_bindings {
923            return;
924        }
925        let _ = self.io_socket_fire_and_forget(RequestType::Close);
926        self.release_socket_bindings();
927    }
928}
929
930impl RemoteSocket {
931    fn release_socket_bindings(&mut self) {
932        self.owns_socket_bindings = false;
933        self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
934        self.common
935            .recv_with_addr_tx
936            .lock()
937            .unwrap()
938            .remove(&self.socket_id);
939        self.common
940            .accept_tx
941            .lock()
942            .unwrap()
943            .remove(&self.socket_id);
944        self.common.sent_tx.lock().unwrap().remove(&self.socket_id);
945        self.common.handlers.lock().unwrap().remove(&self.socket_id);
946
947        if let Some((child_id, _)) = self.pending_accept.take() {
948            self.common.recv_tx.lock().unwrap().remove(&child_id);
949        }
950    }
951
952    async fn io_socket(&self, req: RequestType) -> ResponseType {
953        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
954        let mut req_rx = {
955            let (tx, rx) = mpsc::channel(1);
956            let mut guard = self.common.requests.lock().unwrap();
957            guard.insert(req_id, RequestTx { tx });
958            rx
959        };
960        if let Err(err) = self
961            .common
962            .tx
963            .send(MessageRequest::Socket {
964                socket: self.socket_id,
965                req_id: Some(req_id),
966                req,
967            })
968            .await
969        {
970            return ResponseType::Err(err);
971        };
972        req_rx.recv().await.unwrap()
973    }
974
975    fn io_socket_fire_and_forget(&self, req: RequestType) -> Result<()> {
976        self.common.tx.send_with_driver(MessageRequest::Socket {
977            socket: self.socket_id,
978            req_id: None,
979            req,
980        })
981    }
982
983    fn touch_begin_accept(&mut self) -> Result<()> {
984        if self.pending_accept.is_some() {
985            return Ok(());
986        }
987        let child_id: SocketId = self
988            .common
989            .socket_seed
990            .fetch_add(1, Ordering::SeqCst)
991            .into();
992        self.io_socket_fire_and_forget(RequestType::BeginAccept(child_id))?;
993
994        let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
995        self.common.recv_tx.lock().unwrap().insert(child_id, tx);
996
997        self.pending_accept.replace((child_id, rx_recv));
998        Ok(())
999    }
1000
1001    fn transition_socket(&mut self) -> RemoteSocket {
1002        let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1);
1003        let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1);
1004        let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1);
1005        let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1);
1006
1007        self.owns_socket_bindings = false;
1008
1009        RemoteSocket {
1010            socket_id: self.socket_id,
1011            common: self.common.clone(),
1012            rx_buffer: std::mem::take(&mut self.rx_buffer),
1013            rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv),
1014            rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr),
1015            tx_waker: self.tx_waker.clone(),
1016            rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept),
1017            rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent),
1018            pending_accept: self.pending_accept.take(),
1019            buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr),
1020            buffer_accept: std::mem::take(&mut self.buffer_accept),
1021            send_available: self.send_available,
1022            owns_socket_bindings: true,
1023        }
1024    }
1025}
1026
1027impl VirtualIoSource for RemoteSocket {
1028    fn remove_handler(&mut self) {
1029        self.common.handlers.lock().unwrap().remove(&self.socket_id);
1030    }
1031
1032    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1033        if !self.rx_buffer.is_empty() {
1034            return Poll::Ready(Ok(self.rx_buffer.len()));
1035        }
1036        match self.rx_recv.poll_recv(cx) {
1037            Poll::Ready(Some(data)) => {
1038                self.rx_buffer.extend_from_slice(&data);
1039                return Poll::Ready(Ok(self.rx_buffer.len()));
1040            }
1041            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1042            Poll::Pending => {}
1043        }
1044        if !self.buffer_recv_with_addr.is_empty() {
1045            let len = self
1046                .buffer_recv_with_addr
1047                .front()
1048                .map(|a| a.data.len().max(1))
1049                .unwrap_or_default();
1050            return Poll::Ready(Ok(len));
1051        }
1052        match self.rx_recv_with_addr.poll_recv(cx) {
1053            Poll::Ready(Some(data)) => {
1054                let len = data.data.len().max(1);
1055                self.buffer_recv_with_addr.push_back(data);
1056                return Poll::Ready(Ok(len));
1057            }
1058            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1059            Poll::Pending => {}
1060        }
1061        if !self.buffer_accept.is_empty() {
1062            return Poll::Ready(Ok(self.buffer_accept.len()));
1063        }
1064        match self.rx_accept.poll_recv(cx) {
1065            Poll::Ready(Some(data)) => self.buffer_accept.push_back(data),
1066            Poll::Ready(None) => {}
1067            Poll::Pending => {}
1068        }
1069        Poll::Pending
1070    }
1071
1072    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1073        if self.send_available > 0 {
1074            return Poll::Ready(Ok(self.send_available as usize));
1075        }
1076        match self.rx_sent.poll_recv(cx) {
1077            Poll::Ready(Some(amt)) => {
1078                self.send_available += amt;
1079                return Poll::Ready(Ok(self.send_available as usize));
1080            }
1081            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1082            Poll::Pending => {}
1083        }
1084        Poll::Pending
1085    }
1086}
1087
1088impl VirtualSocket for RemoteSocket {
1089    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1090        self.io_socket_fire_and_forget(RequestType::SetTtl(ttl))
1091    }
1092
1093    fn ttl(&self) -> Result<u32> {
1094        match block_on(self.io_socket(RequestType::GetTtl)) {
1095            ResponseType::Err(err) => Err(err),
1096            ResponseType::Ttl(ttl) => Ok(ttl),
1097            res => {
1098                tracing::debug!("invalid response to get TTL request - {res:?}");
1099                Err(NetworkError::IOError)
1100            }
1101        }
1102    }
1103
1104    fn addr_local(&self) -> Result<SocketAddr> {
1105        match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1106            ResponseType::Err(err) => Err(err),
1107            ResponseType::SocketAddr(addr) => Ok(addr),
1108            res => {
1109                tracing::debug!("invalid response to address local request - {res:?}");
1110                Err(NetworkError::IOError)
1111            }
1112        }
1113    }
1114
1115    fn status(&self) -> Result<crate::SocketStatus> {
1116        match block_on(self.io_socket(RequestType::GetStatus)) {
1117            ResponseType::Err(err) => Err(err),
1118            ResponseType::Status(status) => Ok(status),
1119            res => {
1120                tracing::debug!("invalid response to status request - {res:?}");
1121                Err(NetworkError::IOError)
1122            }
1123        }
1124    }
1125
1126    fn set_handler(
1127        &mut self,
1128        handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1129    ) -> Result<()> {
1130        self.common
1131            .handlers
1132            .lock()
1133            .unwrap()
1134            .insert(self.socket_id, handler);
1135        Ok(())
1136    }
1137}
1138
1139impl VirtualTcpListener for RemoteSocket {
1140    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
1141        // We may already have accepted a connection in the `poll_read_ready` method
1142        self.touch_begin_accept()?;
1143        let accepted = if let Some(child) = self.buffer_accept.pop_front() {
1144            child
1145        } else {
1146            self.rx_accept.try_recv().map_err(|err| match err {
1147                TryRecvError::Empty => NetworkError::WouldBlock,
1148                TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1149            })?
1150        };
1151
1152        // This placed here will mean there is always an accept request pending at the
1153        // server as the constructor invokes this method and we invoke it here after
1154        // receiving a child connection.
1155        let mut rx_recv = None;
1156        if let Some((rx_socket, existing_rx_recv)) = self.pending_accept.take()
1157            && accepted.socket == rx_socket
1158        {
1159            rx_recv.replace(existing_rx_recv);
1160        }
1161        let rx_recv = match rx_recv {
1162            Some(rx_recv) => rx_recv,
1163            None => {
1164                let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
1165                self.common
1166                    .recv_tx
1167                    .lock()
1168                    .unwrap()
1169                    .insert(accepted.socket, tx);
1170                rx_recv
1171            }
1172        };
1173        self.touch_begin_accept().ok();
1174
1175        let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
1176        self.common
1177            .recv_with_addr_tx
1178            .lock()
1179            .unwrap()
1180            .insert(accepted.socket, tx);
1181
1182        let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
1183        self.common
1184            .accept_tx
1185            .lock()
1186            .unwrap()
1187            .insert(accepted.socket, tx);
1188
1189        let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
1190        self.common
1191            .sent_tx
1192            .lock()
1193            .unwrap()
1194            .insert(accepted.socket, tx);
1195
1196        let socket = RemoteSocket {
1197            socket_id: accepted.socket,
1198            common: self.common.clone(),
1199            rx_buffer: BytesMut::new(),
1200            rx_recv,
1201            rx_recv_with_addr,
1202            rx_accept,
1203            rx_sent,
1204            pending_accept: None,
1205            tx_waker: TxWaker::new(&self.common).as_waker(),
1206            buffer_accept: Default::default(),
1207            buffer_recv_with_addr: Default::default(),
1208            send_available: 0,
1209            owns_socket_bindings: true,
1210        };
1211        Ok((Box::new(socket), accepted.addr))
1212    }
1213
1214    fn set_handler(
1215        &mut self,
1216        handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1217    ) -> Result<()> {
1218        VirtualSocket::set_handler(self, handler)
1219    }
1220
1221    fn addr_local(&self) -> Result<SocketAddr> {
1222        match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1223            ResponseType::Err(err) => Err(err),
1224            ResponseType::SocketAddr(addr) => Ok(addr),
1225            res => {
1226                tracing::debug!("invalid response to addr local request - {res:?}");
1227                Err(NetworkError::IOError)
1228            }
1229        }
1230    }
1231
1232    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
1233        self.io_socket_fire_and_forget(RequestType::SetTtl(ttl as u32))
1234    }
1235
1236    fn ttl(&self) -> Result<u8> {
1237        match block_on(self.io_socket(RequestType::GetTtl)) {
1238            ResponseType::Err(err) => Err(err),
1239            ResponseType::Ttl(val) => Ok(val.try_into().map_err(|_| NetworkError::InvalidData)?),
1240            res => {
1241                tracing::debug!("invalid response to get TTL request - {res:?}");
1242                Err(NetworkError::IOError)
1243            }
1244        }
1245    }
1246}
1247
1248impl VirtualTcpBoundSocket for RemoteSocket {
1249    fn addr_local(&self) -> Result<SocketAddr> {
1250        VirtualSocket::addr_local(self)
1251    }
1252
1253    fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
1254        match block_on(self.io_socket(RequestType::ListenBound)) {
1255            ResponseType::Err(err) => Err(err),
1256            ResponseType::None => {
1257                let mut socket = self.transition_socket();
1258                socket.touch_begin_accept().ok();
1259                Ok(Box::new(socket))
1260            }
1261            res => {
1262                tracing::debug!("invalid response to listen bound request - {res:?}");
1263                Err(NetworkError::IOError)
1264            }
1265        }
1266    }
1267
1268    fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
1269        match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
1270            ResponseType::Err(err) => Err(err),
1271            ResponseType::None => Ok(Box::new(self.transition_socket())),
1272            res => {
1273                tracing::debug!("invalid response to connect bound request - {res:?}");
1274                Err(NetworkError::IOError)
1275            }
1276        }
1277    }
1278
1279    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1280        VirtualSocket::set_ttl(self, ttl)
1281    }
1282
1283    fn ttl(&self) -> Result<u32> {
1284        VirtualSocket::ttl(self)
1285    }
1286}
1287
1288impl VirtualRawSocket for RemoteSocket {
1289    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1290        let mut cx = Context::from_waker(&self.tx_waker);
1291        match self.common.tx.poll_send(
1292            &mut cx,
1293            MessageRequest::Send {
1294                socket: self.socket_id,
1295                data: data.to_vec(),
1296                req_id: None,
1297            },
1298        ) {
1299            Poll::Ready(Ok(())) => Ok(data.len()),
1300            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1301                self.send_available = 0;
1302                Err(NetworkError::WouldBlock)
1303            }
1304            Poll::Ready(Err(err)) => Err(err),
1305        }
1306    }
1307
1308    fn try_flush(&mut self) -> Result<()> {
1309        let mut cx = Context::from_waker(&self.tx_waker);
1310        match self.common.tx.poll_send(
1311            &mut cx,
1312            MessageRequest::Socket {
1313                socket: self.socket_id,
1314                req: RequestType::Flush,
1315                req_id: None,
1316            },
1317        ) {
1318            Poll::Ready(Ok(())) => Ok(()),
1319            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1320                self.send_available = 0;
1321                Err(NetworkError::WouldBlock)
1322            }
1323            Poll::Ready(Err(err)) => Err(err),
1324        }
1325    }
1326
1327    fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1328        loop {
1329            if !self.rx_buffer.is_empty() {
1330                let amt = self.rx_buffer.len().min(buf.len());
1331                let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1332                buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1333                if !peek {
1334                    self.rx_buffer.advance(amt);
1335                }
1336                return Ok(amt);
1337            }
1338            match self.rx_recv.try_recv() {
1339                Ok(data) => self.rx_buffer.extend_from_slice(&data),
1340                Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1341                Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1342            }
1343        }
1344    }
1345
1346    fn set_promiscuous(&mut self, promiscuous: bool) -> Result<()> {
1347        self.io_socket_fire_and_forget(RequestType::SetPromiscuous(promiscuous))
1348    }
1349
1350    fn promiscuous(&self) -> Result<bool> {
1351        match block_on(self.io_socket(RequestType::GetPromiscuous)) {
1352            ResponseType::Err(err) => Err(err),
1353            ResponseType::Flag(val) => Ok(val),
1354            res => {
1355                tracing::debug!("invalid response to get promiscuous request - {res:?}");
1356                Err(NetworkError::IOError)
1357            }
1358        }
1359    }
1360}
1361
1362impl VirtualConnectionlessSocket for RemoteSocket {
1363    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1364        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1365        let mut cx = Context::from_waker(&self.tx_waker);
1366        match self.common.tx.poll_send(
1367            &mut cx,
1368            MessageRequest::SendTo {
1369                socket: self.socket_id,
1370                data: data.to_vec(),
1371                addr,
1372                req_id: Some(req_id),
1373            },
1374        ) {
1375            Poll::Ready(Ok(())) => Ok(data.len()),
1376            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1377                self.send_available = 0;
1378                Err(NetworkError::WouldBlock)
1379            }
1380            Poll::Ready(Err(err)) => Err(err),
1381        }
1382    }
1383
1384    fn try_recv_from(
1385        &mut self,
1386        buf: &mut [std::mem::MaybeUninit<u8>],
1387        peek: bool,
1388    ) -> Result<(usize, SocketAddr)> {
1389        if self.buffer_recv_with_addr.is_empty() {
1390            let received = self.rx_recv_with_addr.try_recv().map_err(|err| match err {
1391                TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1392                TryRecvError::Empty => NetworkError::WouldBlock,
1393            })?;
1394            self.buffer_recv_with_addr.push_back(received);
1395        }
1396
1397        let Some(received) = self.buffer_recv_with_addr.front() else {
1398            return Err(NetworkError::WouldBlock);
1399        };
1400        let amt = buf.len().min(received.data.len());
1401        let addr = received.addr;
1402        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1403        buf[..amt].copy_from_slice(&received.data[..amt]);
1404
1405        if !peek {
1406            self.buffer_recv_with_addr.pop_front();
1407        }
1408
1409        Ok((amt, addr))
1410    }
1411}
1412
1413impl VirtualUdpSocket for RemoteSocket {
1414    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
1415        self.io_socket_fire_and_forget(RequestType::SetBroadcast(broadcast))
1416    }
1417
1418    fn broadcast(&self) -> Result<bool> {
1419        match block_on(self.io_socket(RequestType::GetBroadcast)) {
1420            ResponseType::Err(err) => Err(err),
1421            ResponseType::Flag(val) => Ok(val),
1422            res => {
1423                tracing::debug!("invalid response to get broadcast request - {res:?}");
1424                Err(NetworkError::IOError)
1425            }
1426        }
1427    }
1428
1429    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
1430        self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV4(val))
1431    }
1432
1433    fn multicast_loop_v4(&self) -> Result<bool> {
1434        match block_on(self.io_socket(RequestType::GetMulticastLoopV4)) {
1435            ResponseType::Err(err) => Err(err),
1436            ResponseType::Flag(val) => Ok(val),
1437            res => {
1438                tracing::debug!("invalid response to get multicast loop v4 request - {res:?}");
1439                Err(NetworkError::IOError)
1440            }
1441        }
1442    }
1443
1444    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1445        self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV6(val))
1446    }
1447
1448    fn multicast_loop_v6(&self) -> Result<bool> {
1449        match block_on(self.io_socket(RequestType::GetMulticastLoopV6)) {
1450            ResponseType::Err(err) => Err(err),
1451            ResponseType::Flag(val) => Ok(val),
1452            res => {
1453                tracing::debug!("invalid response to get multicast loop v6 request - {res:?}");
1454                Err(NetworkError::IOError)
1455            }
1456        }
1457    }
1458
1459    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1460        self.io_socket_fire_and_forget(RequestType::SetMulticastTtlV4(ttl))
1461    }
1462
1463    fn multicast_ttl_v4(&self) -> Result<u32> {
1464        match block_on(self.io_socket(RequestType::GetMulticastTtlV4)) {
1465            ResponseType::Err(err) => Err(err),
1466            ResponseType::Ttl(ttl) => Ok(ttl),
1467            res => {
1468                tracing::debug!("invalid response to get multicast TTL v4 request - {res:?}");
1469                Err(NetworkError::IOError)
1470            }
1471        }
1472    }
1473
1474    fn join_multicast_v4(
1475        &mut self,
1476        multiaddr: std::net::Ipv4Addr,
1477        iface: std::net::Ipv4Addr,
1478    ) -> Result<()> {
1479        self.io_socket_fire_and_forget(RequestType::JoinMulticastV4 { multiaddr, iface })
1480    }
1481
1482    fn leave_multicast_v4(
1483        &mut self,
1484        multiaddr: std::net::Ipv4Addr,
1485        iface: std::net::Ipv4Addr,
1486    ) -> Result<()> {
1487        self.io_socket_fire_and_forget(RequestType::LeaveMulticastV4 { multiaddr, iface })
1488    }
1489
1490    fn join_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1491        self.io_socket_fire_and_forget(RequestType::JoinMulticastV6 { multiaddr, iface })
1492    }
1493
1494    fn leave_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1495        self.io_socket_fire_and_forget(RequestType::LeaveMulticastV6 { multiaddr, iface })
1496    }
1497
1498    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1499        match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1500            ResponseType::Err(err) => Err(err),
1501            ResponseType::None => Ok(None),
1502            ResponseType::SocketAddr(addr) => Ok(Some(addr)),
1503            res => {
1504                tracing::debug!("invalid response to addr peer request - {res:?}");
1505                Err(NetworkError::IOError)
1506            }
1507        }
1508    }
1509}
1510
// No methods are overridden for ICMP sockets. Anything `VirtualIcmpSocket` requires is
// presumably provided by the supertrait impls in this file or by default methods
// (NOTE(review): confirm against the trait definition).
impl VirtualIcmpSocket for RemoteSocket {}
1512
1513impl VirtualConnectedSocket for RemoteSocket {
1514    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
1515        self.io_socket_fire_and_forget(RequestType::SetLinger(linger))
1516    }
1517
1518    fn linger(&self) -> Result<Option<Duration>> {
1519        match block_on(self.io_socket(RequestType::GetLinger)) {
1520            ResponseType::Err(err) => Err(err),
1521            ResponseType::None => Ok(None),
1522            ResponseType::Duration(val) => Ok(Some(val)),
1523            res => {
1524                tracing::debug!("invalid response to get linger request - {res:?}");
1525                Err(NetworkError::IOError)
1526            }
1527        }
1528    }
1529
1530    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1531        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1532        let mut cx = Context::from_waker(&self.tx_waker);
1533        match self.common.tx.poll_send(
1534            &mut cx,
1535            MessageRequest::Send {
1536                socket: self.socket_id,
1537                data: data.to_vec(),
1538                req_id: Some(req_id),
1539            },
1540        ) {
1541            Poll::Ready(Ok(())) => Ok(data.len()),
1542            Poll::Ready(Err(err)) => Err(err),
1543            Poll::Pending => Err(NetworkError::WouldBlock),
1544        }
1545    }
1546
1547    fn try_flush(&mut self) -> Result<()> {
1548        let mut cx = Context::from_waker(&self.tx_waker);
1549        match self.common.tx.poll_send(
1550            &mut cx,
1551            MessageRequest::Socket {
1552                socket: self.socket_id,
1553                req: RequestType::Flush,
1554                req_id: None,
1555            },
1556        ) {
1557            Poll::Ready(Ok(())) => Ok(()),
1558            Poll::Ready(Err(err)) => Err(err),
1559            Poll::Pending => Err(NetworkError::WouldBlock),
1560        }
1561    }
1562
1563    fn close(&mut self) -> Result<()> {
1564        let ret = self.io_socket_fire_and_forget(RequestType::Close);
1565        if ret.is_ok() {
1566            self.release_socket_bindings();
1567        }
1568        ret
1569    }
1570
1571    fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1572        loop {
1573            if !self.rx_buffer.is_empty() {
1574                let amt = self.rx_buffer.len().min(buf.len());
1575                let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1576                buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1577                if !peek {
1578                    self.rx_buffer.advance(amt);
1579                }
1580                return Ok(amt);
1581            }
1582            match self.rx_recv.try_recv() {
1583                Ok(data) => self.rx_buffer.extend_from_slice(&data),
1584                Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1585                Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1586            }
1587        }
1588    }
1589}
1590
1591impl VirtualTcpSocket for RemoteSocket {
1592    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
1593        self.io_socket_fire_and_forget(RequestType::SetRecvBufSize(size as u64))
1594    }
1595
1596    fn recv_buf_size(&self) -> Result<usize> {
1597        match block_on(self.io_socket(RequestType::GetRecvBufSize)) {
1598            ResponseType::Err(err) => Err(err),
1599            ResponseType::Amount(amt) => Ok(amt.try_into().map_err(|_| NetworkError::IOError)?),
1600            res => {
1601                tracing::debug!("invalid response to get recv buf size request - {res:?}");
1602                Err(NetworkError::IOError)
1603            }
1604        }
1605    }
1606
1607    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
1608        self.io_socket_fire_and_forget(RequestType::SetSendBufSize(size as u64))
1609    }
1610
1611    fn send_buf_size(&self) -> Result<usize> {
1612        match block_on(self.io_socket(RequestType::GetSendBufSize)) {
1613            ResponseType::Err(err) => Err(err),
1614            ResponseType::Amount(val) => Ok(val.try_into().map_err(|_| NetworkError::IOError)?),
1615            res => {
1616                tracing::debug!("invalid response to get send buf size request - {res:?}");
1617                Err(NetworkError::IOError)
1618            }
1619        }
1620    }
1621
1622    fn set_nodelay(&mut self, reuse: bool) -> Result<()> {
1623        self.io_socket_fire_and_forget(RequestType::SetNoDelay(reuse))
1624    }
1625
1626    fn nodelay(&self) -> Result<bool> {
1627        match block_on(self.io_socket(RequestType::GetNoDelay)) {
1628            ResponseType::Err(err) => Err(err),
1629            ResponseType::Flag(val) => Ok(val),
1630            res => {
1631                tracing::debug!("invalid response to get nodelay request - {res:?}");
1632                Err(NetworkError::IOError)
1633            }
1634        }
1635    }
1636
1637    fn set_keepalive(&mut self, keep_alive: bool) -> Result<()> {
1638        self.io_socket_fire_and_forget(RequestType::SetKeepAlive(keep_alive))
1639    }
1640
1641    fn keepalive(&self) -> Result<bool> {
1642        match block_on(self.io_socket(RequestType::GetKeepAlive)) {
1643            ResponseType::Err(err) => Err(err),
1644            ResponseType::Flag(val) => Ok(val),
1645            res => {
1646                tracing::debug!("invalid response to get nodelay request - {res:?}");
1647                Err(NetworkError::IOError)
1648            }
1649        }
1650    }
1651
1652    fn set_dontroute(&mut self, dont_route: bool) -> Result<()> {
1653        self.io_socket_fire_and_forget(RequestType::SetDontRoute(dont_route))
1654    }
1655
1656    fn dontroute(&self) -> Result<bool> {
1657        match block_on(self.io_socket(RequestType::GetDontRoute)) {
1658            ResponseType::Err(err) => Err(err),
1659            ResponseType::Flag(val) => Ok(val),
1660            res => {
1661                tracing::debug!("invalid response to get nodelay request - {res:?}");
1662                Err(NetworkError::IOError)
1663            }
1664        }
1665    }
1666
1667    fn addr_peer(&self) -> Result<SocketAddr> {
1668        match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1669            ResponseType::Err(err) => Err(err),
1670            ResponseType::SocketAddr(addr) => Ok(addr),
1671            res => {
1672                tracing::debug!("invalid response to addr peer request - {res:?}");
1673                Err(NetworkError::IOError)
1674            }
1675        }
1676    }
1677
1678    fn shutdown(&mut self, how: std::net::Shutdown) -> Result<()> {
1679        let shutdown = match how {
1680            std::net::Shutdown::Read => meta::Shutdown::Read,
1681            std::net::Shutdown::Write => meta::Shutdown::Write,
1682            std::net::Shutdown::Both => meta::Shutdown::Both,
1683        };
1684        self.io_socket_fire_and_forget(RequestType::Shutdown(shutdown))
1685    }
1686
1687    fn is_closed(&self) -> bool {
1688        match block_on(self.io_socket(RequestType::IsClosed)) {
1689            ResponseType::Flag(val) => val,
1690            _ => false,
1691        }
1692    }
1693}