virtual_net/
client.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::net::IpAddr;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::sync::atomic::AtomicU64;
10use std::sync::atomic::Ordering;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16use std::time::Duration;
17
18use bytes::Buf;
19use bytes::BytesMut;
20use futures_util::Sink;
21use futures_util::Stream;
22use futures_util::StreamExt;
23use futures_util::future::BoxFuture;
24use futures_util::stream::FuturesOrdered;
25#[cfg(feature = "hyper")]
26use hyper_util::rt::tokio::TokioIo;
27use tokio::io::AsyncRead;
28use tokio::io::AsyncWrite;
29use tokio::sync::mpsc;
30use tokio::sync::mpsc::error::TryRecvError;
31use tokio::sync::mpsc::error::TrySendError;
32use tokio_serde::SymmetricallyFramed;
33use tokio_serde::formats::SymmetricalBincode;
34#[cfg(feature = "cbor")]
35use tokio_serde::formats::SymmetricalCbor;
36#[cfg(feature = "json")]
37use tokio_serde::formats::SymmetricalJson;
38#[cfg(feature = "messagepack")]
39use tokio_serde::formats::SymmetricalMessagePack;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use tokio_util::codec::LengthDelimitedCodec;
43use virtual_mio::InterestType;
44use virtual_mio::block_on;
45
46use crate::IpCidr;
47use crate::IpRoute;
48use crate::NetworkError;
49use crate::StreamSecurity;
50use crate::VirtualConnectedSocket;
51use crate::VirtualConnectionlessSocket;
52use crate::VirtualIcmpSocket;
53use crate::VirtualIoSource;
54use crate::VirtualNetworking;
55use crate::VirtualRawSocket;
56use crate::VirtualSocket;
57use crate::VirtualTcpBoundSocket;
58use crate::VirtualTcpListener;
59use crate::VirtualTcpSocket;
60use crate::VirtualUdpSocket;
61use crate::meta;
62use crate::meta::FrameSerializationFormat;
63use crate::meta::RequestType;
64use crate::meta::ResponseType;
65use crate::meta::SocketId;
66use crate::meta::{MessageRequest, MessageResponse};
67
68use crate::Result;
69use crate::rx_tx::RemoteRx;
70use crate::rx_tx::RemoteTx;
71use crate::rx_tx::RemoteTxWakers;
72
/// Client-side `VirtualNetworking` implementation that forwards its
/// operations to a remote networking server.
///
/// The state lives behind an `Arc`, so cloning the client is cheap. Each
/// client is created together with a `RemoteNetworkingClientDriver`; only that
/// driver reads responses from the remote side, so it must be polled for
/// requests issued through this client to complete.
#[derive(Debug, Clone)]
pub struct RemoteNetworkingClient {
    // Shared with the driver, the sockets created by this client and their wakers.
    common: Arc<RemoteCommon>,
}
77
78impl RemoteNetworkingClient {
79    fn new(
80        tx: RemoteTx<MessageRequest>,
81        rx: RemoteRx<MessageResponse>,
82        rx_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
83    ) -> (Self, RemoteNetworkingClientDriver) {
84        let common = RemoteCommon {
85            tx,
86            rx: Mutex::new(rx),
87            request_seed: AtomicU64::new(1),
88            requests: Default::default(),
89            socket_seed: AtomicU64::new(1),
90            recv_tx: Default::default(),
91            recv_with_addr_tx: Default::default(),
92            accept_tx: Default::default(),
93            sent_tx: Default::default(),
94            handlers: Default::default(),
95            stall: Default::default(),
96        };
97        let common = Arc::new(common);
98
99        let driver = RemoteNetworkingClientDriver {
100            more_work: rx_work,
101            tasks: Default::default(),
102            common: common.clone(),
103        };
104        let networking = Self { common };
105
106        (networking, driver)
107    }
108
109    /// Creates a new interface on the remote location using
110    /// a unique interface ID and a pair of channels
111    pub fn new_from_mpsc(
112        tx: mpsc::Sender<MessageRequest>,
113        rx: mpsc::Receiver<MessageResponse>,
114    ) -> (Self, RemoteNetworkingClientDriver) {
115        let (tx_work, rx_work) = mpsc::unbounded_channel();
116        let tx_wakers = RemoteTxWakers::default();
117
118        let tx = RemoteTx::Mpsc {
119            tx,
120            work: tx_work,
121            wakers: tx_wakers.clone(),
122        };
123        let rx = RemoteRx::Mpsc {
124            rx,
125            wakers: tx_wakers,
126        };
127
128        Self::new(tx, rx, rx_work)
129    }
130
131    /// Creates a new interface on the remote location using
132    /// a unique interface ID and a pair of channels
133    ///
134    /// This version will run the async read and write operations
135    /// only the driver (this is needed for mixed runtimes)
136    pub fn new_from_async_io<TX, RX>(
137        tx: TX,
138        rx: RX,
139        format: FrameSerializationFormat,
140    ) -> (Self, RemoteNetworkingClientDriver)
141    where
142        TX: AsyncWrite + Send + 'static,
143        RX: AsyncRead + Send + 'static,
144    {
145        let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
146        let tx: Pin<Box<dyn Sink<MessageRequest, Error = std::io::Error> + Send + 'static>> =
147            match format {
148                FrameSerializationFormat::Bincode => {
149                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
150                }
151                #[cfg(feature = "json")]
152                FrameSerializationFormat::Json => {
153                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
154                }
155                #[cfg(feature = "messagepack")]
156                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
157                    tx,
158                    SymmetricalMessagePack::default(),
159                )),
160                #[cfg(feature = "cbor")]
161                FrameSerializationFormat::Cbor => {
162                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
163                }
164            };
165
166        let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
167        let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageResponse>> + Send + 'static>> =
168            match format {
169                FrameSerializationFormat::Bincode => {
170                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
171                }
172                #[cfg(feature = "json")]
173                FrameSerializationFormat::Json => {
174                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
175                }
176                #[cfg(feature = "messagepack")]
177                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
178                    rx,
179                    SymmetricalMessagePack::default(),
180                )),
181                #[cfg(feature = "cbor")]
182                FrameSerializationFormat::Cbor => {
183                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
184                }
185            };
186
187        let (tx_work, rx_work) = mpsc::unbounded_channel();
188        let tx_wakers = RemoteTxWakers::default();
189
190        let tx = RemoteTx::Stream {
191            tx: Arc::new(tokio::sync::Mutex::new(tx)),
192            work: tx_work,
193            wakers: tx_wakers,
194        };
195        let rx = RemoteRx::Stream { rx };
196
197        Self::new(tx, rx, rx_work)
198    }
199
200    /// Creates a new interface on the remote location using
201    /// a unique interface ID and a pair of channels
202    #[cfg(feature = "hyper")]
203    pub fn new_from_hyper_ws_io(
204        tx: futures_util::stream::SplitSink<
205            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
206            hyper_tungstenite::tungstenite::Message,
207        >,
208        rx: futures_util::stream::SplitStream<
209            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
210        >,
211        format: FrameSerializationFormat,
212    ) -> (Self, RemoteNetworkingClientDriver) {
213        let (tx_work, rx_work) = mpsc::unbounded_channel();
214
215        let tx = RemoteTx::HyperWebSocket {
216            tx: Arc::new(tokio::sync::Mutex::new(tx)),
217            work: tx_work,
218            wakers: RemoteTxWakers::default(),
219            format,
220        };
221        let rx = RemoteRx::HyperWebSocket { rx, format };
222        Self::new(tx, rx, rx_work)
223    }
224
225    /// Creates a new interface on the remote location using
226    /// a unique interface ID and a pair of channels
227    #[cfg(feature = "tokio-tungstenite")]
228    pub fn new_from_tokio_ws_io(
229        tx: futures_util::stream::SplitSink<
230            tokio_tungstenite::WebSocketStream<
231                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
232            >,
233            tokio_tungstenite::tungstenite::Message,
234        >,
235        rx: futures_util::stream::SplitStream<
236            tokio_tungstenite::WebSocketStream<
237                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
238            >,
239        >,
240        format: FrameSerializationFormat,
241    ) -> (Self, RemoteNetworkingClientDriver) {
242        let (tx_work, rx_work) = mpsc::unbounded_channel();
243
244        let tx = RemoteTx::TokioWebSocket {
245            tx: Arc::new(tokio::sync::Mutex::new(tx)),
246            work: tx_work,
247            wakers: RemoteTxWakers::default(),
248            format,
249        };
250        let rx = RemoteRx::TokioWebSocket { rx, format };
251        Self::new(tx, rx, rx_work)
252    }
253
254    fn new_socket(&self, id: SocketId) -> RemoteSocket {
255        let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
256        self.common.recv_tx.lock().unwrap().insert(id, tx);
257
258        let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
259        self.common.recv_with_addr_tx.lock().unwrap().insert(id, tx);
260
261        let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
262        self.common.accept_tx.lock().unwrap().insert(id, tx);
263
264        let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
265        self.common.sent_tx.lock().unwrap().insert(id, tx);
266
267        RemoteSocket {
268            socket_id: id,
269            common: self.common.clone(),
270            rx_buffer: BytesMut::new(),
271            rx_recv,
272            rx_recv_with_addr,
273            rx_accept,
274            rx_sent,
275            tx_waker: TxWaker::new(&self.common).as_waker(),
276            pending_accept: None,
277            buffer_accept: Default::default(),
278            buffer_recv_with_addr: Default::default(),
279            send_available: 0,
280            owns_socket_bindings: true,
281        }
282    }
283}
284
pin_project_lite::pin_project! {
    /// Future that services a `RemoteNetworkingClient`.
    ///
    /// It runs background work queued by the transport and dispatches every
    /// message received from the remote side to the matching request or
    /// socket. Nothing else in this file reads the response stream, so this
    /// future must be polled (for example, spawned on a runtime) for the
    /// client to make progress. It completes when the response stream ends.
    pub struct RemoteNetworkingClientDriver {
        // State shared with the client and its sockets.
        common: Arc<RemoteCommon>,
        // Receives background futures queued through the transport's `work` sender.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // Background futures being run; results are yielded in push order.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
    }
}
293
impl Future for RemoteNetworkingClientDriver {
    type Output = ();

    /// Runs queued background work and dispatches inbound messages until the
    /// inbound stream is pending (returns `Poll::Pending`) or closed (returns
    /// `Poll::Ready(())`).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Holds the shared `stall` mutex while background tasks are still
        // pending. It is acquired in the `Poll::Pending` arm below and
        // released once the task queue is empty.
        let mut not_stalled_guard = None;

        // Loop until every source we poll (the work queue, the task queue and
        // the inbound stream) has registered `cx`, or the stream has ended.
        loop {
            // Move every queued background future into the task queue. A
            // closed or empty work channel ends this inner loop.
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            // Drive the background tasks:
            // - When a task finishes, go round again.
            // - When the queue is empty, release the stall guard.
            // - When tasks are still pending, take the stall guard, or keep it if already held.
            // Note that pending tasks do not stop this driver from reading more
            // messages below. The driver only stops early, returning `Pending`,
            // when the `stall` mutex is already held by someone else.
            // NOTE(review): that early `Pending` does not register `cx` with the
            // stall mutex. Wake-up relies on `more_work`/`tasks`. Confirm the
            // stall holder triggers one of those.
            match self.tasks.poll_next_unpin(cx) {
                Poll::Ready(Some(_)) => continue,
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // Poll the next message sent by the server. The `rx` lock is
            // released at the end of this block, before dispatching.
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            // Each dispatched message `continue`s the loop. Otherwise this
            // returns `Ready(())` when the stream has ended, or `Pending`.
            return match msg {
                Poll::Ready(Some(msg)) => {
                    match msg {
                        MessageResponse::Recv { socket_id, data } => {
                            // Unknown socket: drop the data and keep reading.
                            let tx = {
                                let guard = self.common.recv_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => {
                                        continue;
                                    }
                                }
                            };
                            let common = self.common.clone();
                            // Deliver in the background (the channel may be full),
                            // then signal readability once the data is queued.
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(data).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::RecvWithAddr {
                            socket_id,
                            data,
                            addr,
                        } => {
                            // Same as `Recv`, but the address is delivered with the data.
                            let tx = {
                                let guard = self.common.recv_with_addr_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            let common = self.common.clone();
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(DataWithAddr { data, addr }).await.ok();

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Sent {
                            socket_id, amount, ..
                        } => {
                            let tx = {
                                let guard = self.common.sent_tx.lock().unwrap();
                                match guard.get(&socket_id) {
                                    Some(tx) => tx.clone(),
                                    None => continue,
                                }
                            };
                            self.tasks.push_back(Box::pin(async move {
                                tx.send(amount).await.ok();
                            }));
                            // NOTE(review): unlike `Recv`, interest is pushed here
                            // immediately, possibly before `amount` has been queued
                            // by the task above. Confirm this ordering is intended.
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Writable)
                            }
                        }
                        // Connection-level failures are reported as `Closed`.
                        // Any other error only re-signals `Writable`. The
                        // error value itself is not forwarded from here.
                        MessageResponse::SendError {
                            socket_id, error, ..
                        } => match &error {
                            NetworkError::ConnectionAborted
                            | NetworkError::ConnectionReset
                            | NetworkError::BrokenPipe => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Closed)
                                }
                            }
                            _ => {
                                if let Some(h) =
                                    self.common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Writable)
                                }
                            }
                        },
                        MessageResponse::FinishAccept {
                            socket_id,
                            child_id,
                            addr,
                        } => {
                            let common = self.common.clone();
                            // The `accept_tx` lookup happens inside the task. If the
                            // listener is unknown, nothing is queued but readability
                            // is still signalled.
                            self.tasks.push_back(Box::pin(async move {
                                let tx = common.accept_tx.lock().unwrap().get(&socket_id).cloned();
                                if let Some(tx) = tx {
                                    tx.send(SocketWithAddr {
                                        socket: child_id,
                                        addr,
                                    })
                                    .await
                                    .ok();
                                }

                                if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id)
                                {
                                    h.push_interest(InterestType::Readable)
                                }
                            }));
                        }
                        MessageResponse::Closed { socket_id } => {
                            if let Some(h) =
                                self.common.handlers.lock().unwrap().get_mut(&socket_id)
                            {
                                h.push_interest(InterestType::Closed)
                            }
                        }
                        // Completes the matching `io_iface` call. Unknown ids and
                        // delivery failures are ignored.
                        MessageResponse::ResponseToRequest { req_id, res } => {
                            let mut requests = self.common.requests.lock().unwrap();
                            if let Some(request) = requests.remove(&req_id) {
                                request.try_send(res).ok();
                            }
                        }
                    }
                    continue;
                }
                Poll::Ready(None) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
465
466#[derive(Debug)]
467struct TxWaker {
468    common: Arc<RemoteCommon>,
469}
470impl TxWaker {
471    pub fn new(common: &Arc<RemoteCommon>) -> Arc<Self> {
472        Arc::new(Self {
473            common: common.clone(),
474        })
475    }
476
477    fn wake_now(&self) {
478        let mut guard = self.common.handlers.lock().unwrap();
479        for handler in guard.values_mut() {
480            handler.push_interest(InterestType::Writable);
481        }
482    }
483
484    pub fn as_waker(self: &Arc<Self>) -> Waker {
485        let s: *const Self = Arc::into_raw(Arc::clone(self));
486        let raw_waker = RawWaker::new(s as *const (), &VTABLE);
487        unsafe { Waker::from_raw(raw_waker) }
488    }
489}
490
491fn tx_waker_wake(s: &TxWaker) {
492    let waker_arc = unsafe { Arc::from_raw(s) };
493    waker_arc.wake_now();
494}
495
496fn tx_waker_clone(s: &TxWaker) -> RawWaker {
497    let arc = unsafe { Arc::from_raw(s) };
498    std::mem::forget(arc.clone());
499    RawWaker::new(Arc::into_raw(arc) as *const (), &VTABLE)
500}
501
502const VTABLE: RawWakerVTable = unsafe {
503    RawWakerVTable::new(
504        |s| tx_waker_clone(&*(s as *const TxWaker)),  // clone
505        |s| tx_waker_wake(&*(s as *const TxWaker)),   // wake
506        |s| (*(s as *const TxWaker)).wake_now(),      // wake by ref (don't decrease refcount)
507        |s| drop(Arc::from_raw(s as *const TxWaker)), // decrease refcount
508    )
509};
510
511#[derive(Debug)]
512struct RequestTx {
513    tx: mpsc::Sender<ResponseType>,
514}
515impl RequestTx {
516    pub fn try_send(self, msg: ResponseType) -> Result<()> {
517        match self.tx.try_send(msg) {
518            Ok(()) => Ok(()),
519            Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
520            Err(TrySendError::Full(_)) => Err(NetworkError::WouldBlock),
521        }
522    }
523}
524
/// Payload queued on a socket's `recv_with_addr_tx` channel for a
/// `MessageResponse::RecvWithAddr`, together with the address reported with it.
#[derive(Debug)]
struct DataWithAddr {
    // Received bytes, as delivered by the remote side.
    pub data: Vec<u8>,
    // Address carried in the message (presumably the sender's address — TODO confirm).
    pub addr: SocketAddr,
}
/// Accepted child socket reported by a `MessageResponse::FinishAccept`, queued
/// on the listener's `accept_tx` channel.
#[derive(Debug)]
struct SocketWithAddr {
    // Id of the accepted child socket (`child_id` in the message).
    pub socket: SocketId,
    // Address carried in the message (presumably the peer's address — TODO confirm).
    pub addr: SocketAddr,
}
/// Per-socket lookup table keyed by `SocketId`, used for the routing tables in `RemoteCommon`.
type SocketMap<T> = HashMap<SocketId, T>;
536
/// State shared (through an `Arc`) by the client handle, its driver, the
/// sockets it creates and their wakers.
#[derive(derive_more::Debug)]
struct RemoteCommon {
    /// Outbound path used to send requests to the remote side.
    #[debug(ignore)]
    tx: RemoteTx<MessageRequest>,
    /// Inbound response stream; polled by `RemoteNetworkingClientDriver`.
    #[debug(ignore)]
    rx: Mutex<RemoteRx<MessageResponse>>,
    /// Source of request ids for `io_iface`.
    request_seed: AtomicU64,
    /// Requests still waiting for their `ResponseToRequest`, keyed by request id.
    requests: Mutex<HashMap<u64, RequestTx>>,
    /// Source of ids for sockets opened by this client.
    socket_seed: AtomicU64,
    /// Per-socket senders for `Recv` payloads.
    recv_tx: Mutex<SocketMap<mpsc::Sender<Vec<u8>>>>,
    /// Per-socket senders for `RecvWithAddr` payloads.
    recv_with_addr_tx: Mutex<SocketMap<mpsc::Sender<DataWithAddr>>>,
    /// Per-socket senders for `FinishAccept` notifications.
    accept_tx: Mutex<SocketMap<mpsc::Sender<SocketWithAddr>>>,
    /// Per-socket senders for `Sent` amounts (presumably byte counts — TODO confirm).
    sent_tx: Mutex<SocketMap<mpsc::Sender<u64>>>,
    /// Per-socket interest handlers, notified by the driver and by `TxWaker`.
    #[debug(ignore)]
    handlers: Mutex<SocketMap<Box<dyn virtual_mio::InterestHandler + Send + Sync>>>,

    // The stall guard prevents reads while it is held and background tasks are running.
    // The idea is to create back pressure so that the task list does not grow without bound.
    stall: Arc<tokio::sync::Mutex<()>>,
}
557
558impl RemoteCommon {
559    async fn io_iface(&self, req: RequestType) -> ResponseType {
560        let req_id = self.request_seed.fetch_add(1, Ordering::SeqCst);
561        let mut req_rx = {
562            let (tx, rx) = mpsc::channel(1);
563            let mut guard = self.requests.lock().unwrap();
564            guard.insert(req_id, RequestTx { tx });
565            rx
566        };
567        if let Err(err) = self
568            .tx
569            .send(MessageRequest::Interface {
570                req_id: Some(req_id),
571                req,
572            })
573            .await
574        {
575            return ResponseType::Err(err);
576        };
577        req_rx.recv().await.unwrap()
578    }
579
580    fn io_iface_fire_and_forget(&self, req: RequestType) -> Result<()> {
581        self.tx
582            .send_with_driver(MessageRequest::Interface { req_id: None, req })
583    }
584}
585
586#[async_trait::async_trait]
587impl VirtualNetworking for RemoteNetworkingClient {
588    async fn bridge(
589        &self,
590        network: &str,
591        access_token: &str,
592        security: StreamSecurity,
593    ) -> Result<()> {
594        match self
595            .common
596            .io_iface(RequestType::Bridge {
597                network: network.to_string(),
598                access_token: access_token.to_string(),
599                security,
600            })
601            .await
602        {
603            ResponseType::Err(err) => Err(err),
604            ResponseType::None => Ok(()),
605            res => {
606                tracing::debug!("invalid response to bridge request - {res:?}");
607                Err(NetworkError::IOError)
608            }
609        }
610    }
611
612    async fn unbridge(&self) -> Result<()> {
613        match self.common.io_iface(RequestType::Unbridge).await {
614            ResponseType::Err(err) => Err(err),
615            ResponseType::None => Ok(()),
616            res => {
617                tracing::debug!("invalid response to unbridge request - {res:?}");
618                Err(NetworkError::IOError)
619            }
620        }
621    }
622
623    async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>> {
624        match self.common.io_iface(RequestType::DhcpAcquire).await {
625            ResponseType::Err(err) => Err(err),
626            ResponseType::IpAddressList(ips) => Ok(ips),
627            res => {
628                tracing::debug!("invalid response to DHCP acquire request - {res:?}");
629                Err(NetworkError::IOError)
630            }
631        }
632    }
633
634    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<()> {
635        self.common
636            .io_iface_fire_and_forget(RequestType::IpAdd { ip, prefix })
637    }
638
639    async fn ip_remove(&self, ip: IpAddr) -> Result<()> {
640        self.common
641            .io_iface_fire_and_forget(RequestType::IpRemove(ip))
642    }
643
644    async fn ip_clear(&self) -> Result<()> {
645        self.common.io_iface_fire_and_forget(RequestType::IpClear)
646    }
647
648    async fn ip_list(&self) -> Result<Vec<IpCidr>> {
649        match self.common.io_iface(RequestType::GetIpList).await {
650            ResponseType::Err(err) => Err(err),
651            ResponseType::CidrList(routes) => Ok(routes),
652            res => {
653                tracing::debug!("invalid response to IP list request - {res:?}");
654                Err(NetworkError::IOError)
655            }
656        }
657    }
658
659    async fn mac(&self) -> Result<[u8; 6]> {
660        match self.common.io_iface(RequestType::GetMac).await {
661            ResponseType::Err(err) => Err(err),
662            ResponseType::Mac(mac) => Ok(mac),
663            res => {
664                tracing::debug!("invalid response to MAC request - {res:?}");
665                Err(NetworkError::IOError)
666            }
667        }
668    }
669
670    async fn gateway_set(&self, ip: IpAddr) -> Result<()> {
671        self.common
672            .io_iface_fire_and_forget(RequestType::GatewaySet(ip))
673    }
674
675    async fn route_add(
676        &self,
677        cidr: IpCidr,
678        via_router: IpAddr,
679        preferred_until: Option<Duration>,
680        expires_at: Option<Duration>,
681    ) -> Result<()> {
682        self.common.io_iface_fire_and_forget(RequestType::RouteAdd {
683            cidr,
684            via_router,
685            preferred_until,
686            expires_at,
687        })
688    }
689
690    async fn route_remove(&self, cidr: IpAddr) -> Result<()> {
691        self.common
692            .io_iface_fire_and_forget(RequestType::RouteRemove(cidr))
693    }
694
695    async fn route_clear(&self) -> Result<()> {
696        self.common
697            .io_iface_fire_and_forget(RequestType::RouteClear)
698    }
699
700    async fn route_list(&self) -> Result<Vec<IpRoute>> {
701        match self.common.io_iface(RequestType::GetRouteList).await {
702            ResponseType::Err(err) => Err(err),
703            ResponseType::RouteList(routes) => Ok(routes),
704            res => {
705                tracing::debug!("invalid response to route list request - {res:?}");
706                Err(NetworkError::IOError)
707            }
708        }
709    }
710
711    async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>> {
712        let socket_id: SocketId = self
713            .common
714            .socket_seed
715            .fetch_add(1, Ordering::SeqCst)
716            .into();
717        match self.common.io_iface(RequestType::BindRaw(socket_id)).await {
718            ResponseType::Err(err) => Err(err),
719            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
720            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
721            res => {
722                tracing::debug!("invalid response to bind RAw request - {res:?}");
723                Err(NetworkError::IOError)
724            }
725        }
726    }
727
728    async fn listen_tcp(
729        &self,
730        addr: SocketAddr,
731        only_v6: bool,
732        reuse_port: bool,
733        reuse_addr: bool,
734    ) -> Result<Box<dyn VirtualTcpListener + Sync>> {
735        let socket_id: SocketId = self
736            .common
737            .socket_seed
738            .fetch_add(1, Ordering::SeqCst)
739            .into();
740        match self
741            .common
742            .io_iface(RequestType::ListenTcp {
743                socket_id,
744                addr,
745                only_v6,
746                reuse_port,
747                reuse_addr,
748            })
749            .await
750        {
751            ResponseType::Err(err) => Err(err),
752            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
753            ResponseType::Socket(socket_id) => {
754                let mut socket = self.new_socket(socket_id);
755                socket.touch_begin_accept().ok();
756                Ok(Box::new(socket))
757            }
758            res => {
759                tracing::debug!("invalid response to listen TCP request - {res:?}");
760                Err(NetworkError::IOError)
761            }
762        }
763    }
764
765    async fn bind_tcp(
766        &self,
767        addr: SocketAddr,
768        only_v6: bool,
769        reuse_port: bool,
770        reuse_addr: bool,
771    ) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
772        let socket_id: SocketId = self
773            .common
774            .socket_seed
775            .fetch_add(1, Ordering::SeqCst)
776            .into();
777        match self
778            .common
779            .io_iface(RequestType::BindTcp {
780                socket_id,
781                addr,
782                only_v6,
783                reuse_port,
784                reuse_addr,
785            })
786            .await
787        {
788            ResponseType::Err(err) => Err(err),
789            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
790            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
791            res => {
792                tracing::debug!("invalid response to bind TCP request - {res:?}");
793                Err(NetworkError::IOError)
794            }
795        }
796    }
797
798    async fn bind_udp(
799        &self,
800        addr: SocketAddr,
801        reuse_port: bool,
802        reuse_addr: bool,
803    ) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
804        let socket_id: SocketId = self
805            .common
806            .socket_seed
807            .fetch_add(1, Ordering::SeqCst)
808            .into();
809        match self
810            .common
811            .io_iface(RequestType::BindUdp {
812                socket_id,
813                addr,
814                reuse_port,
815                reuse_addr,
816            })
817            .await
818        {
819            ResponseType::Err(err) => Err(err),
820            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
821            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
822            res => {
823                tracing::debug!("invalid response to bind UDP request - {res:?}");
824                Err(NetworkError::IOError)
825            }
826        }
827    }
828
829    async fn bind_icmp(&self, addr: IpAddr) -> Result<Box<dyn VirtualIcmpSocket + Sync>> {
830        let socket_id: SocketId = self
831            .common
832            .socket_seed
833            .fetch_add(1, Ordering::SeqCst)
834            .into();
835        match self
836            .common
837            .io_iface(RequestType::BindIcmp { socket_id, addr })
838            .await
839        {
840            ResponseType::Err(err) => Err(err),
841            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
842            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
843            res => {
844                tracing::debug!("invalid response to bind ICMP request - {res:?}");
845                Err(NetworkError::IOError)
846            }
847        }
848    }
849
850    async fn connect_tcp(
851        &self,
852        addr: SocketAddr,
853        peer: SocketAddr,
854    ) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
855        let socket_id: SocketId = self
856            .common
857            .socket_seed
858            .fetch_add(1, Ordering::SeqCst)
859            .into();
860        match self
861            .common
862            .io_iface(RequestType::ConnectTcp {
863                socket_id,
864                addr,
865                peer,
866            })
867            .await
868        {
869            ResponseType::Err(err) => Err(err),
870            ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
871            ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
872            res => {
873                tracing::debug!("invalid response to connect TCP request - {res:?}");
874                Err(NetworkError::IOError)
875            }
876        }
877    }
878
879    async fn resolve(
880        &self,
881        host: &str,
882        port: Option<u16>,
883        dns_server: Option<IpAddr>,
884    ) -> Result<Vec<IpAddr>> {
885        match self
886            .common
887            .io_iface(RequestType::Resolve {
888                host: host.to_string(),
889                port,
890                dns_server,
891            })
892            .await
893        {
894            ResponseType::Err(err) => Err(err),
895            ResponseType::IpAddressList(ips) => Ok(ips),
896            res => {
897                tracing::debug!("invalid response to resolve request - {res:?}");
898                Err(NetworkError::IOError)
899            }
900        }
901    }
902}
903
/// Client-side handle to a socket that lives on the remote networking service.
///
/// Requests go out through `common.tx`. Incoming data and notifications arrive on the
/// per-socket channels whose senders are registered in `common`, keyed by `socket_id`.
#[derive(Debug)]
struct RemoteSocket {
    /// Key for this socket in the shared `common` tables and in every request sent for it.
    socket_id: SocketId,
    /// State shared with the client and sibling sockets (request table, channel
    /// senders, handlers, id seeds, outgoing sender).
    common: Arc<RemoteCommon>,
    /// Stream bytes already received but not yet handed out by `try_recv`.
    rx_buffer: BytesMut,
    /// Chunks of stream data for this socket; appended to `rx_buffer` when read.
    rx_recv: mpsc::Receiver<Vec<u8>>,
    /// Datagrams (payload plus address) consumed by `try_recv_from`.
    rx_recv_with_addr: mpsc::Receiver<DataWithAddr>,
    /// Waker used to build the `Context` for the non-blocking `poll_send` calls.
    tx_waker: Waker,
    /// Connections accepted on this socket while it is listening.
    rx_accept: mpsc::Receiver<SocketWithAddr>,
    /// Amounts added to `send_available` by `poll_write_ready`.
    rx_sent: mpsc::Receiver<u64>,
    /// Child id sent in the outstanding `BeginAccept` request, together with the
    /// receiver already registered for that child's stream data.
    pending_accept: Option<(SocketId, mpsc::Receiver<Vec<u8>>)>,
    /// Datagrams taken from `rx_recv_with_addr` by `poll_read_ready`.
    buffer_recv_with_addr: VecDeque<DataWithAddr>,
    /// Accepted connections taken from `rx_accept` by `poll_read_ready`; `try_accept`
    /// consumes these before reading the channel.
    buffer_accept: VecDeque<SocketWithAddr>,
    /// Send credit reported by `poll_write_ready`; some `try_*` methods reset it to 0
    /// when they hit `WouldBlock`.
    send_available: u64,
    /// Whether dropping this value sends `Close` and unregisters this socket's entries
    /// in `common`. Cleared by `transition_socket` and `release_socket_bindings`.
    owns_socket_bindings: bool,
}
920impl Drop for RemoteSocket {
921    fn drop(&mut self) {
922        if !self.owns_socket_bindings {
923            return;
924        }
925        let _ = self.io_socket_fire_and_forget(RequestType::Close);
926        self.release_socket_bindings();
927    }
928}
929
930impl RemoteSocket {
931    fn release_socket_bindings(&mut self) {
932        self.owns_socket_bindings = false;
933        self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
934        self.common
935            .recv_with_addr_tx
936            .lock()
937            .unwrap()
938            .remove(&self.socket_id);
939        self.common
940            .accept_tx
941            .lock()
942            .unwrap()
943            .remove(&self.socket_id);
944        self.common.sent_tx.lock().unwrap().remove(&self.socket_id);
945        self.common.handlers.lock().unwrap().remove(&self.socket_id);
946
947        if let Some((child_id, _)) = self.pending_accept.take() {
948            self.common.recv_tx.lock().unwrap().remove(&child_id);
949        }
950    }
951
952    async fn io_socket(&self, req: RequestType) -> ResponseType {
953        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
954        let mut req_rx = {
955            let (tx, rx) = mpsc::channel(1);
956            let mut guard = self.common.requests.lock().unwrap();
957            guard.insert(req_id, RequestTx { tx });
958            rx
959        };
960        if let Err(err) = self
961            .common
962            .tx
963            .send(MessageRequest::Socket {
964                socket: self.socket_id,
965                req_id: Some(req_id),
966                req,
967            })
968            .await
969        {
970            return ResponseType::Err(err);
971        };
972        req_rx.recv().await.unwrap()
973    }
974
975    fn io_socket_fire_and_forget(&self, req: RequestType) -> Result<()> {
976        self.common.tx.send_with_driver(MessageRequest::Socket {
977            socket: self.socket_id,
978            req_id: None,
979            req,
980        })
981    }
982
983    fn touch_begin_accept(&mut self) -> Result<()> {
984        if self.pending_accept.is_some() {
985            return Ok(());
986        }
987        let child_id: SocketId = self
988            .common
989            .socket_seed
990            .fetch_add(1, Ordering::SeqCst)
991            .into();
992        self.io_socket_fire_and_forget(RequestType::BeginAccept(child_id))?;
993
994        let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
995        self.common.recv_tx.lock().unwrap().insert(child_id, tx);
996
997        self.pending_accept.replace((child_id, rx_recv));
998        Ok(())
999    }
1000
1001    fn transition_socket(&mut self) -> RemoteSocket {
1002        let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1);
1003        let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1);
1004        let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1);
1005        let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1);
1006
1007        self.owns_socket_bindings = false;
1008
1009        RemoteSocket {
1010            socket_id: self.socket_id,
1011            common: self.common.clone(),
1012            rx_buffer: std::mem::take(&mut self.rx_buffer),
1013            rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv),
1014            rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr),
1015            tx_waker: self.tx_waker.clone(),
1016            rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept),
1017            rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent),
1018            pending_accept: self.pending_accept.take(),
1019            buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr),
1020            buffer_accept: std::mem::take(&mut self.buffer_accept),
1021            send_available: self.send_available,
1022            owns_socket_bindings: true,
1023        }
1024    }
1025}
1026
1027impl VirtualIoSource for RemoteSocket {
1028    fn remove_handler(&mut self) {
1029        self.common.handlers.lock().unwrap().remove(&self.socket_id);
1030    }
1031
1032    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1033        if !self.rx_buffer.is_empty() {
1034            return Poll::Ready(Ok(self.rx_buffer.len()));
1035        }
1036        match self.rx_recv.poll_recv(cx) {
1037            Poll::Ready(Some(data)) => {
1038                self.rx_buffer.extend_from_slice(&data);
1039                return Poll::Ready(Ok(self.rx_buffer.len()));
1040            }
1041            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1042            Poll::Pending => {}
1043        }
1044        if !self.buffer_recv_with_addr.is_empty() {
1045            let total = self
1046                .buffer_recv_with_addr
1047                .iter()
1048                .map(|a| a.data.len())
1049                .sum();
1050            return Poll::Ready(Ok(total));
1051        }
1052        match self.rx_recv_with_addr.poll_recv(cx) {
1053            Poll::Ready(Some(data)) => self.buffer_recv_with_addr.push_back(data),
1054            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1055            Poll::Pending => {}
1056        }
1057        if !self.buffer_accept.is_empty() {
1058            return Poll::Ready(Ok(self.buffer_accept.len()));
1059        }
1060        match self.rx_accept.poll_recv(cx) {
1061            Poll::Ready(Some(data)) => self.buffer_accept.push_back(data),
1062            Poll::Ready(None) => {}
1063            Poll::Pending => {}
1064        }
1065        Poll::Pending
1066    }
1067
1068    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize>> {
1069        if self.send_available > 0 {
1070            return Poll::Ready(Ok(self.send_available as usize));
1071        }
1072        match self.rx_sent.poll_recv(cx) {
1073            Poll::Ready(Some(amt)) => {
1074                self.send_available += amt;
1075                return Poll::Ready(Ok(self.send_available as usize));
1076            }
1077            Poll::Ready(None) => return Poll::Ready(Ok(0)),
1078            Poll::Pending => {}
1079        }
1080        Poll::Pending
1081    }
1082}
1083
1084impl VirtualSocket for RemoteSocket {
1085    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1086        self.io_socket_fire_and_forget(RequestType::SetTtl(ttl))
1087    }
1088
1089    fn ttl(&self) -> Result<u32> {
1090        match block_on(self.io_socket(RequestType::GetTtl)) {
1091            ResponseType::Err(err) => Err(err),
1092            ResponseType::Ttl(ttl) => Ok(ttl),
1093            res => {
1094                tracing::debug!("invalid response to get TTL request - {res:?}");
1095                Err(NetworkError::IOError)
1096            }
1097        }
1098    }
1099
1100    fn addr_local(&self) -> Result<SocketAddr> {
1101        match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1102            ResponseType::Err(err) => Err(err),
1103            ResponseType::SocketAddr(addr) => Ok(addr),
1104            res => {
1105                tracing::debug!("invalid response to address local request - {res:?}");
1106                Err(NetworkError::IOError)
1107            }
1108        }
1109    }
1110
1111    fn status(&self) -> Result<crate::SocketStatus> {
1112        match block_on(self.io_socket(RequestType::GetStatus)) {
1113            ResponseType::Err(err) => Err(err),
1114            ResponseType::Status(status) => Ok(status),
1115            res => {
1116                tracing::debug!("invalid response to status request - {res:?}");
1117                Err(NetworkError::IOError)
1118            }
1119        }
1120    }
1121
1122    fn set_handler(
1123        &mut self,
1124        handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1125    ) -> Result<()> {
1126        self.common
1127            .handlers
1128            .lock()
1129            .unwrap()
1130            .insert(self.socket_id, handler);
1131        Ok(())
1132    }
1133}
1134
1135impl VirtualTcpListener for RemoteSocket {
1136    fn try_accept(&mut self) -> Result<(Box<dyn VirtualTcpSocket + Sync>, SocketAddr)> {
1137        // We may already have accepted a connection in the `poll_read_ready` method
1138        self.touch_begin_accept()?;
1139        let accepted = if let Some(child) = self.buffer_accept.pop_front() {
1140            child
1141        } else {
1142            self.rx_accept.try_recv().map_err(|err| match err {
1143                TryRecvError::Empty => NetworkError::WouldBlock,
1144                TryRecvError::Disconnected => NetworkError::ConnectionAborted,
1145            })?
1146        };
1147
1148        // This placed here will mean there is always an accept request pending at the
1149        // server as the constructor invokes this method and we invoke it here after
1150        // receiving a child connection.
1151        let mut rx_recv = None;
1152        if let Some((rx_socket, existing_rx_recv)) = self.pending_accept.take()
1153            && accepted.socket == rx_socket
1154        {
1155            rx_recv.replace(existing_rx_recv);
1156        }
1157        let rx_recv = match rx_recv {
1158            Some(rx_recv) => rx_recv,
1159            None => {
1160                let (tx, rx_recv) = tokio::sync::mpsc::channel(100);
1161                self.common
1162                    .recv_tx
1163                    .lock()
1164                    .unwrap()
1165                    .insert(accepted.socket, tx);
1166                rx_recv
1167            }
1168        };
1169        self.touch_begin_accept().ok();
1170
1171        let (tx, rx_recv_with_addr) = tokio::sync::mpsc::channel(100);
1172        self.common
1173            .recv_with_addr_tx
1174            .lock()
1175            .unwrap()
1176            .insert(accepted.socket, tx);
1177
1178        let (tx, rx_accept) = tokio::sync::mpsc::channel(100);
1179        self.common
1180            .accept_tx
1181            .lock()
1182            .unwrap()
1183            .insert(accepted.socket, tx);
1184
1185        let (tx, rx_sent) = tokio::sync::mpsc::channel(100);
1186        self.common
1187            .sent_tx
1188            .lock()
1189            .unwrap()
1190            .insert(accepted.socket, tx);
1191
1192        let socket = RemoteSocket {
1193            socket_id: accepted.socket,
1194            common: self.common.clone(),
1195            rx_buffer: BytesMut::new(),
1196            rx_recv,
1197            rx_recv_with_addr,
1198            rx_accept,
1199            rx_sent,
1200            pending_accept: None,
1201            tx_waker: TxWaker::new(&self.common).as_waker(),
1202            buffer_accept: Default::default(),
1203            buffer_recv_with_addr: Default::default(),
1204            send_available: 0,
1205            owns_socket_bindings: true,
1206        };
1207        Ok((Box::new(socket), accepted.addr))
1208    }
1209
1210    fn set_handler(
1211        &mut self,
1212        handler: Box<dyn virtual_mio::InterestHandler + Send + Sync>,
1213    ) -> Result<()> {
1214        VirtualSocket::set_handler(self, handler)
1215    }
1216
1217    fn addr_local(&self) -> Result<SocketAddr> {
1218        match block_on(self.io_socket(RequestType::GetAddrLocal)) {
1219            ResponseType::Err(err) => Err(err),
1220            ResponseType::SocketAddr(addr) => Ok(addr),
1221            res => {
1222                tracing::debug!("invalid response to addr local request - {res:?}");
1223                Err(NetworkError::IOError)
1224            }
1225        }
1226    }
1227
1228    fn set_ttl(&mut self, ttl: u8) -> Result<()> {
1229        self.io_socket_fire_and_forget(RequestType::SetTtl(ttl as u32))
1230    }
1231
1232    fn ttl(&self) -> Result<u8> {
1233        match block_on(self.io_socket(RequestType::GetTtl)) {
1234            ResponseType::Err(err) => Err(err),
1235            ResponseType::Ttl(val) => Ok(val.try_into().map_err(|_| NetworkError::InvalidData)?),
1236            res => {
1237                tracing::debug!("invalid response to get TTL request - {res:?}");
1238                Err(NetworkError::IOError)
1239            }
1240        }
1241    }
1242}
1243
1244impl VirtualTcpBoundSocket for RemoteSocket {
1245    fn addr_local(&self) -> Result<SocketAddr> {
1246        VirtualSocket::addr_local(self)
1247    }
1248
1249    fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
1250        match block_on(self.io_socket(RequestType::ListenBound)) {
1251            ResponseType::Err(err) => Err(err),
1252            ResponseType::None => {
1253                let mut socket = self.transition_socket();
1254                socket.touch_begin_accept().ok();
1255                Ok(Box::new(socket))
1256            }
1257            res => {
1258                tracing::debug!("invalid response to listen bound request - {res:?}");
1259                Err(NetworkError::IOError)
1260            }
1261        }
1262    }
1263
1264    fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
1265        match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
1266            ResponseType::Err(err) => Err(err),
1267            ResponseType::None => Ok(Box::new(self.transition_socket())),
1268            res => {
1269                tracing::debug!("invalid response to connect bound request - {res:?}");
1270                Err(NetworkError::IOError)
1271            }
1272        }
1273    }
1274
1275    fn set_ttl(&mut self, ttl: u32) -> Result<()> {
1276        VirtualSocket::set_ttl(self, ttl)
1277    }
1278
1279    fn ttl(&self) -> Result<u32> {
1280        VirtualSocket::ttl(self)
1281    }
1282}
1283
1284impl VirtualRawSocket for RemoteSocket {
1285    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1286        let mut cx = Context::from_waker(&self.tx_waker);
1287        match self.common.tx.poll_send(
1288            &mut cx,
1289            MessageRequest::Send {
1290                socket: self.socket_id,
1291                data: data.to_vec(),
1292                req_id: None,
1293            },
1294        ) {
1295            Poll::Ready(Ok(())) => Ok(data.len()),
1296            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1297                self.send_available = 0;
1298                Err(NetworkError::WouldBlock)
1299            }
1300            Poll::Ready(Err(err)) => Err(err),
1301        }
1302    }
1303
1304    fn try_flush(&mut self) -> Result<()> {
1305        let mut cx = Context::from_waker(&self.tx_waker);
1306        match self.common.tx.poll_send(
1307            &mut cx,
1308            MessageRequest::Socket {
1309                socket: self.socket_id,
1310                req: RequestType::Flush,
1311                req_id: None,
1312            },
1313        ) {
1314            Poll::Ready(Ok(())) => Ok(()),
1315            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1316                self.send_available = 0;
1317                Err(NetworkError::WouldBlock)
1318            }
1319            Poll::Ready(Err(err)) => Err(err),
1320        }
1321    }
1322
1323    fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1324        loop {
1325            if !self.rx_buffer.is_empty() {
1326                let amt = self.rx_buffer.len().min(buf.len());
1327                let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1328                buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1329                if !peek {
1330                    self.rx_buffer.advance(amt);
1331                }
1332                return Ok(amt);
1333            }
1334            match self.rx_recv.try_recv() {
1335                Ok(data) => self.rx_buffer.extend_from_slice(&data),
1336                Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1337                Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1338            }
1339        }
1340    }
1341
1342    fn set_promiscuous(&mut self, promiscuous: bool) -> Result<()> {
1343        self.io_socket_fire_and_forget(RequestType::SetPromiscuous(promiscuous))
1344    }
1345
1346    fn promiscuous(&self) -> Result<bool> {
1347        match block_on(self.io_socket(RequestType::GetPromiscuous)) {
1348            ResponseType::Err(err) => Err(err),
1349            ResponseType::Flag(val) => Ok(val),
1350            res => {
1351                tracing::debug!("invalid response to get promiscuous request - {res:?}");
1352                Err(NetworkError::IOError)
1353            }
1354        }
1355    }
1356}
1357
1358impl VirtualConnectionlessSocket for RemoteSocket {
1359    fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize> {
1360        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1361        let mut cx = Context::from_waker(&self.tx_waker);
1362        match self.common.tx.poll_send(
1363            &mut cx,
1364            MessageRequest::SendTo {
1365                socket: self.socket_id,
1366                data: data.to_vec(),
1367                addr,
1368                req_id: Some(req_id),
1369            },
1370        ) {
1371            Poll::Ready(Ok(())) => Ok(data.len()),
1372            Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => {
1373                self.send_available = 0;
1374                Err(NetworkError::WouldBlock)
1375            }
1376            Poll::Ready(Err(err)) => Err(err),
1377        }
1378    }
1379
1380    fn try_recv_from(
1381        &mut self,
1382        buf: &mut [std::mem::MaybeUninit<u8>],
1383        peek: bool,
1384    ) -> Result<(usize, SocketAddr)> {
1385        // FIXME: A proper peek implementation would require correct use of a
1386        // buffer. We don't use this implementation AFAIK, so I'll skip over it
1387        // for the time being.
1388        if peek {
1389            return Err(NetworkError::Unsupported);
1390        }
1391
1392        match self.rx_recv_with_addr.try_recv() {
1393            Ok(received) => {
1394                let amt = buf.len().min(received.data.len());
1395                let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1396                buf[..amt].copy_from_slice(&received.data[..amt]);
1397                Ok((amt, received.addr))
1398            }
1399            Err(TryRecvError::Disconnected) => Err(NetworkError::ConnectionAborted),
1400            Err(TryRecvError::Empty) => Err(NetworkError::WouldBlock),
1401        }
1402    }
1403}
1404
1405impl VirtualUdpSocket for RemoteSocket {
1406    fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
1407        self.io_socket_fire_and_forget(RequestType::SetBroadcast(broadcast))
1408    }
1409
1410    fn broadcast(&self) -> Result<bool> {
1411        match block_on(self.io_socket(RequestType::GetBroadcast)) {
1412            ResponseType::Err(err) => Err(err),
1413            ResponseType::Flag(val) => Ok(val),
1414            res => {
1415                tracing::debug!("invalid response to get broadcast request - {res:?}");
1416                Err(NetworkError::IOError)
1417            }
1418        }
1419    }
1420
1421    fn set_multicast_loop_v4(&mut self, val: bool) -> Result<()> {
1422        self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV4(val))
1423    }
1424
1425    fn multicast_loop_v4(&self) -> Result<bool> {
1426        match block_on(self.io_socket(RequestType::GetMulticastLoopV4)) {
1427            ResponseType::Err(err) => Err(err),
1428            ResponseType::Flag(val) => Ok(val),
1429            res => {
1430                tracing::debug!("invalid response to get multicast loop v4 request - {res:?}");
1431                Err(NetworkError::IOError)
1432            }
1433        }
1434    }
1435
1436    fn set_multicast_loop_v6(&mut self, val: bool) -> Result<()> {
1437        self.io_socket_fire_and_forget(RequestType::SetMulticastLoopV6(val))
1438    }
1439
1440    fn multicast_loop_v6(&self) -> Result<bool> {
1441        match block_on(self.io_socket(RequestType::GetMulticastLoopV6)) {
1442            ResponseType::Err(err) => Err(err),
1443            ResponseType::Flag(val) => Ok(val),
1444            res => {
1445                tracing::debug!("invalid response to get multicast loop v6 request - {res:?}");
1446                Err(NetworkError::IOError)
1447            }
1448        }
1449    }
1450
1451    fn set_multicast_ttl_v4(&mut self, ttl: u32) -> Result<()> {
1452        self.io_socket_fire_and_forget(RequestType::SetMulticastTtlV4(ttl))
1453    }
1454
1455    fn multicast_ttl_v4(&self) -> Result<u32> {
1456        match block_on(self.io_socket(RequestType::GetMulticastTtlV4)) {
1457            ResponseType::Err(err) => Err(err),
1458            ResponseType::Ttl(ttl) => Ok(ttl),
1459            res => {
1460                tracing::debug!("invalid response to get multicast TTL v4 request - {res:?}");
1461                Err(NetworkError::IOError)
1462            }
1463        }
1464    }
1465
1466    fn join_multicast_v4(
1467        &mut self,
1468        multiaddr: std::net::Ipv4Addr,
1469        iface: std::net::Ipv4Addr,
1470    ) -> Result<()> {
1471        self.io_socket_fire_and_forget(RequestType::JoinMulticastV4 { multiaddr, iface })
1472    }
1473
1474    fn leave_multicast_v4(
1475        &mut self,
1476        multiaddr: std::net::Ipv4Addr,
1477        iface: std::net::Ipv4Addr,
1478    ) -> Result<()> {
1479        self.io_socket_fire_and_forget(RequestType::LeaveMulticastV4 { multiaddr, iface })
1480    }
1481
1482    fn join_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1483        self.io_socket_fire_and_forget(RequestType::JoinMulticastV6 { multiaddr, iface })
1484    }
1485
1486    fn leave_multicast_v6(&mut self, multiaddr: std::net::Ipv6Addr, iface: u32) -> Result<()> {
1487        self.io_socket_fire_and_forget(RequestType::LeaveMulticastV6 { multiaddr, iface })
1488    }
1489
1490    fn addr_peer(&self) -> Result<Option<SocketAddr>> {
1491        match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1492            ResponseType::Err(err) => Err(err),
1493            ResponseType::None => Ok(None),
1494            ResponseType::SocketAddr(addr) => Ok(Some(addr)),
1495            res => {
1496                tracing::debug!("invalid response to addr peer request - {res:?}");
1497                Err(NetworkError::IOError)
1498            }
1499        }
1500    }
1501}
1502
// No methods are overridden here: the ICMP socket relies on the trait's provided
// items and on the supertrait impls implemented above for `RemoteSocket`.
impl VirtualIcmpSocket for RemoteSocket {}
1504
1505impl VirtualConnectedSocket for RemoteSocket {
1506    fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
1507        self.io_socket_fire_and_forget(RequestType::SetLinger(linger))
1508    }
1509
1510    fn linger(&self) -> Result<Option<Duration>> {
1511        match block_on(self.io_socket(RequestType::GetLinger)) {
1512            ResponseType::Err(err) => Err(err),
1513            ResponseType::None => Ok(None),
1514            ResponseType::Duration(val) => Ok(Some(val)),
1515            res => {
1516                tracing::debug!("invalid response to get linger request - {res:?}");
1517                Err(NetworkError::IOError)
1518            }
1519        }
1520    }
1521
1522    fn try_send(&mut self, data: &[u8]) -> Result<usize> {
1523        let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
1524        let mut cx = Context::from_waker(&self.tx_waker);
1525        match self.common.tx.poll_send(
1526            &mut cx,
1527            MessageRequest::Send {
1528                socket: self.socket_id,
1529                data: data.to_vec(),
1530                req_id: Some(req_id),
1531            },
1532        ) {
1533            Poll::Ready(Ok(())) => Ok(data.len()),
1534            Poll::Ready(Err(err)) => Err(err),
1535            Poll::Pending => Err(NetworkError::WouldBlock),
1536        }
1537    }
1538
1539    fn try_flush(&mut self) -> Result<()> {
1540        let mut cx = Context::from_waker(&self.tx_waker);
1541        match self.common.tx.poll_send(
1542            &mut cx,
1543            MessageRequest::Socket {
1544                socket: self.socket_id,
1545                req: RequestType::Flush,
1546                req_id: None,
1547            },
1548        ) {
1549            Poll::Ready(Ok(())) => Ok(()),
1550            Poll::Ready(Err(err)) => Err(err),
1551            Poll::Pending => Err(NetworkError::WouldBlock),
1552        }
1553    }
1554
1555    fn close(&mut self) -> Result<()> {
1556        let ret = self.io_socket_fire_and_forget(RequestType::Close);
1557        if ret.is_ok() {
1558            self.release_socket_bindings();
1559        }
1560        ret
1561    }
1562
1563    fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
1564        loop {
1565            if !self.rx_buffer.is_empty() {
1566                let amt = self.rx_buffer.len().min(buf.len());
1567                let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
1568                buf[..amt].copy_from_slice(&self.rx_buffer[..amt]);
1569                if !peek {
1570                    self.rx_buffer.advance(amt);
1571                }
1572                return Ok(amt);
1573            }
1574            match self.rx_recv.try_recv() {
1575                Ok(data) => self.rx_buffer.extend_from_slice(&data),
1576                Err(TryRecvError::Disconnected) => return Err(NetworkError::ConnectionAborted),
1577                Err(TryRecvError::Empty) => return Err(NetworkError::WouldBlock),
1578            }
1579        }
1580    }
1581}
1582
1583impl VirtualTcpSocket for RemoteSocket {
1584    fn set_recv_buf_size(&mut self, size: usize) -> Result<()> {
1585        self.io_socket_fire_and_forget(RequestType::SetRecvBufSize(size as u64))
1586    }
1587
1588    fn recv_buf_size(&self) -> Result<usize> {
1589        match block_on(self.io_socket(RequestType::GetRecvBufSize)) {
1590            ResponseType::Err(err) => Err(err),
1591            ResponseType::Amount(amt) => Ok(amt.try_into().map_err(|_| NetworkError::IOError)?),
1592            res => {
1593                tracing::debug!("invalid response to get recv buf size request - {res:?}");
1594                Err(NetworkError::IOError)
1595            }
1596        }
1597    }
1598
1599    fn set_send_buf_size(&mut self, size: usize) -> Result<()> {
1600        self.io_socket_fire_and_forget(RequestType::SetSendBufSize(size as u64))
1601    }
1602
1603    fn send_buf_size(&self) -> Result<usize> {
1604        match block_on(self.io_socket(RequestType::GetSendBufSize)) {
1605            ResponseType::Err(err) => Err(err),
1606            ResponseType::Amount(val) => Ok(val.try_into().map_err(|_| NetworkError::IOError)?),
1607            res => {
1608                tracing::debug!("invalid response to get send buf size request - {res:?}");
1609                Err(NetworkError::IOError)
1610            }
1611        }
1612    }
1613
1614    fn set_nodelay(&mut self, reuse: bool) -> Result<()> {
1615        self.io_socket_fire_and_forget(RequestType::SetNoDelay(reuse))
1616    }
1617
1618    fn nodelay(&self) -> Result<bool> {
1619        match block_on(self.io_socket(RequestType::GetNoDelay)) {
1620            ResponseType::Err(err) => Err(err),
1621            ResponseType::Flag(val) => Ok(val),
1622            res => {
1623                tracing::debug!("invalid response to get nodelay request - {res:?}");
1624                Err(NetworkError::IOError)
1625            }
1626        }
1627    }
1628
1629    fn set_keepalive(&mut self, keep_alive: bool) -> Result<()> {
1630        self.io_socket_fire_and_forget(RequestType::SetKeepAlive(keep_alive))
1631    }
1632
1633    fn keepalive(&self) -> Result<bool> {
1634        match block_on(self.io_socket(RequestType::GetKeepAlive)) {
1635            ResponseType::Err(err) => Err(err),
1636            ResponseType::Flag(val) => Ok(val),
1637            res => {
1638                tracing::debug!("invalid response to get nodelay request - {res:?}");
1639                Err(NetworkError::IOError)
1640            }
1641        }
1642    }
1643
1644    fn set_dontroute(&mut self, dont_route: bool) -> Result<()> {
1645        self.io_socket_fire_and_forget(RequestType::SetDontRoute(dont_route))
1646    }
1647
1648    fn dontroute(&self) -> Result<bool> {
1649        match block_on(self.io_socket(RequestType::GetDontRoute)) {
1650            ResponseType::Err(err) => Err(err),
1651            ResponseType::Flag(val) => Ok(val),
1652            res => {
1653                tracing::debug!("invalid response to get nodelay request - {res:?}");
1654                Err(NetworkError::IOError)
1655            }
1656        }
1657    }
1658
1659    fn addr_peer(&self) -> Result<SocketAddr> {
1660        match block_on(self.io_socket(RequestType::GetAddrPeer)) {
1661            ResponseType::Err(err) => Err(err),
1662            ResponseType::SocketAddr(addr) => Ok(addr),
1663            res => {
1664                tracing::debug!("invalid response to addr peer request - {res:?}");
1665                Err(NetworkError::IOError)
1666            }
1667        }
1668    }
1669
1670    fn shutdown(&mut self, how: std::net::Shutdown) -> Result<()> {
1671        let shutdown = match how {
1672            std::net::Shutdown::Read => meta::Shutdown::Read,
1673            std::net::Shutdown::Write => meta::Shutdown::Write,
1674            std::net::Shutdown::Both => meta::Shutdown::Both,
1675        };
1676        self.io_socket_fire_and_forget(RequestType::Shutdown(shutdown))
1677    }
1678
1679    fn is_closed(&self) -> bool {
1680        match block_on(self.io_socket(RequestType::IsClosed)) {
1681            ResponseType::Flag(val) => val,
1682            _ => false,
1683        }
1684    }
1685}