virtual_net/
server.rs

1use crate::meta::{FrameSerializationFormat, ResponseType};
2use crate::rx_tx::{RemoteRx, RemoteTx, RemoteTxWakers};
3use crate::{IpCidr, IpRoute, NetworkError, StreamSecurity, VirtualIcmpSocket};
4use crate::{
5    VirtualNetworking, VirtualRawSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
6    meta::{MessageRequest, MessageResponse, RequestType, SocketId},
7};
8use futures_util::stream::FuturesOrdered;
9#[cfg(any(feature = "hyper", feature = "tokio-tungstenite"))]
10use futures_util::stream::{SplitSink, SplitStream};
11use futures_util::{Sink, Stream};
12use futures_util::{StreamExt, future::BoxFuture};
13use std::collections::HashSet;
14use std::mem::MaybeUninit;
15use std::net::IpAddr;
16use std::task::Waker;
17use std::time::Duration;
18
19#[cfg(feature = "hyper")]
20use hyper_util::rt::tokio::TokioIo;
21use std::{
22    collections::HashMap,
23    future::Future,
24    net::SocketAddr,
25    pin::Pin,
26    sync::{Arc, Mutex},
27    task::{Context, Poll},
28};
29use tokio::{
30    io::{AsyncRead, AsyncWrite},
31    sync::mpsc,
32};
33use tokio_serde::SymmetricallyFramed;
34use tokio_serde::formats::SymmetricalBincode;
35#[cfg(feature = "cbor")]
36use tokio_serde::formats::SymmetricalCbor;
37#[cfg(feature = "json")]
38use tokio_serde::formats::SymmetricalJson;
39#[cfg(feature = "messagepack")]
40use tokio_serde::formats::SymmetricalMessagePack;
41use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
42use virtual_mio::InterestHandler;
43
44type BackgroundTask = Option<BoxFuture<'static, ()>>;
45
46#[derive(Debug, Clone)]
47pub struct RemoteNetworkingServer {
48    #[allow(dead_code)]
49    common: Arc<RemoteAdapterCommon>,
50    inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
51}
52
53impl RemoteNetworkingServer {
54    fn new(
55        tx: RemoteTx<MessageResponse>,
56        rx: RemoteRx<MessageRequest>,
57        work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
58        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
59    ) -> (Self, RemoteNetworkingServerDriver) {
60        let common = RemoteAdapterCommon {
61            tx,
62            rx: Mutex::new(rx),
63            sockets: Default::default(),
64            socket_accept: Default::default(),
65            handler: Default::default(),
66            stall_rx: Default::default(),
67        };
68        let common = Arc::new(common);
69
70        let driver = RemoteNetworkingServerDriver {
71            more_work: work,
72            tasks: Default::default(),
73            common: common.clone(),
74            inner: inner.clone(),
75        };
76        let networking = Self { common, inner };
77
78        (networking, driver)
79    }
80    /// Creates a new interface on the remote location using
81    /// a unique interface ID and a pair of channels
82    pub fn new_from_mpsc(
83        tx: mpsc::Sender<MessageResponse>,
84        rx: mpsc::Receiver<MessageRequest>,
85        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
86    ) -> (Self, RemoteNetworkingServerDriver) {
87        let (tx_work, rx_work) = mpsc::unbounded_channel();
88        let tx_wakers = RemoteTxWakers::default();
89
90        let tx = RemoteTx::Mpsc {
91            tx,
92            work: tx_work,
93            wakers: tx_wakers.clone(),
94        };
95        let rx = RemoteRx::Mpsc {
96            rx,
97            wakers: tx_wakers,
98        };
99        Self::new(tx, rx, rx_work, inner)
100    }
101
102    /// Creates a new interface on the remote location using
103    /// a unique interface ID and a pair of channels
104    pub fn new_from_async_io<TX, RX>(
105        tx: TX,
106        rx: RX,
107        format: FrameSerializationFormat,
108        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
109    ) -> (Self, RemoteNetworkingServerDriver)
110    where
111        TX: AsyncWrite + Send + 'static,
112        RX: AsyncRead + Send + 'static,
113    {
114        let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
115        let tx: Pin<Box<dyn Sink<MessageResponse, Error = std::io::Error> + Send + 'static>> =
116            match format {
117                FrameSerializationFormat::Bincode => {
118                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
119                }
120                #[cfg(feature = "json")]
121                FrameSerializationFormat::Json => {
122                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
123                }
124                #[cfg(feature = "messagepack")]
125                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
126                    tx,
127                    SymmetricalMessagePack::default(),
128                )),
129                #[cfg(feature = "cbor")]
130                FrameSerializationFormat::Cbor => {
131                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
132                }
133            };
134
135        let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
136        let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageRequest>> + Send + 'static>> =
137            match format {
138                FrameSerializationFormat::Bincode => {
139                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
140                }
141                #[cfg(feature = "json")]
142                FrameSerializationFormat::Json => {
143                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
144                }
145                #[cfg(feature = "messagepack")]
146                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
147                    rx,
148                    SymmetricalMessagePack::default(),
149                )),
150                #[cfg(feature = "cbor")]
151                FrameSerializationFormat::Cbor => {
152                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
153                }
154            };
155
156        let (tx_work, rx_work) = mpsc::unbounded_channel();
157
158        let tx = RemoteTx::Stream {
159            tx: Arc::new(tokio::sync::Mutex::new(tx)),
160            work: tx_work,
161            wakers: RemoteTxWakers::default(),
162        };
163        let rx = RemoteRx::Stream { rx };
164        Self::new(tx, rx, rx_work, inner)
165    }
166
167    /// Creates a new interface on the remote location using
168    /// a unique interface ID and a pair of channels
169    #[cfg(feature = "hyper")]
170    pub fn new_from_hyper_ws_io(
171        tx: SplitSink<
172            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
173            hyper_tungstenite::tungstenite::Message,
174        >,
175        rx: SplitStream<hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>>,
176        format: FrameSerializationFormat,
177        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
178    ) -> (Self, RemoteNetworkingServerDriver) {
179        let (tx_work, rx_work) = mpsc::unbounded_channel();
180
181        let tx = RemoteTx::HyperWebSocket {
182            tx: Arc::new(tokio::sync::Mutex::new(tx)),
183            work: tx_work,
184            wakers: RemoteTxWakers::default(),
185            format,
186        };
187        let rx = RemoteRx::HyperWebSocket { rx, format };
188        Self::new(tx, rx, rx_work, inner)
189    }
190}
191
192#[async_trait::async_trait]
193impl VirtualNetworking for RemoteNetworkingServer {
194    async fn bridge(
195        &self,
196        network: &str,
197        access_token: &str,
198        security: StreamSecurity,
199    ) -> Result<(), NetworkError> {
200        self.inner.bridge(network, access_token, security).await
201    }
202
203    async fn unbridge(&self) -> Result<(), NetworkError> {
204        self.inner.unbridge().await
205    }
206
207    async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>, NetworkError> {
208        self.inner.dhcp_acquire().await
209    }
210
211    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<(), NetworkError> {
212        self.inner.ip_add(ip, prefix).await
213    }
214
215    async fn ip_remove(&self, ip: IpAddr) -> Result<(), NetworkError> {
216        self.inner.ip_remove(ip).await
217    }
218
219    async fn ip_clear(&self) -> Result<(), NetworkError> {
220        self.inner.ip_clear().await
221    }
222
223    async fn ip_list(&self) -> Result<Vec<IpCidr>, NetworkError> {
224        self.inner.ip_list().await
225    }
226
227    async fn mac(&self) -> Result<[u8; 6], NetworkError> {
228        self.inner.mac().await
229    }
230
231    async fn gateway_set(&self, ip: IpAddr) -> Result<(), NetworkError> {
232        self.inner.gateway_set(ip).await
233    }
234
235    async fn route_add(
236        &self,
237        cidr: IpCidr,
238        via_router: IpAddr,
239        preferred_until: Option<Duration>,
240        expires_at: Option<Duration>,
241    ) -> Result<(), NetworkError> {
242        self.inner
243            .route_add(cidr, via_router, preferred_until, expires_at)
244            .await
245    }
246
247    async fn route_remove(&self, cidr: IpAddr) -> Result<(), NetworkError> {
248        self.inner.route_remove(cidr).await
249    }
250
251    async fn route_clear(&self) -> Result<(), NetworkError> {
252        self.inner.route_clear().await
253    }
254
255    async fn route_list(&self) -> Result<Vec<IpRoute>, NetworkError> {
256        self.inner.route_list().await
257    }
258
259    async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>, NetworkError> {
260        self.inner.bind_raw().await
261    }
262
263    async fn listen_tcp(
264        &self,
265        addr: SocketAddr,
266        only_v6: bool,
267        reuse_port: bool,
268        reuse_addr: bool,
269    ) -> Result<Box<dyn VirtualTcpListener + Sync>, NetworkError> {
270        self.inner
271            .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
272            .await
273    }
274
275    async fn bind_udp(
276        &self,
277        addr: SocketAddr,
278        reuse_port: bool,
279        reuse_addr: bool,
280    ) -> Result<Box<dyn VirtualUdpSocket + Sync>, NetworkError> {
281        self.inner.bind_udp(addr, reuse_port, reuse_addr).await
282    }
283
284    async fn bind_icmp(
285        &self,
286        addr: IpAddr,
287    ) -> Result<Box<dyn VirtualIcmpSocket + Sync>, NetworkError> {
288        self.inner.bind_icmp(addr).await
289    }
290
291    async fn connect_tcp(
292        &self,
293        addr: SocketAddr,
294        peer: SocketAddr,
295    ) -> Result<Box<dyn VirtualTcpSocket + Sync>, NetworkError> {
296        self.inner.connect_tcp(addr, peer).await
297    }
298
299    async fn resolve(
300        &self,
301        host: &str,
302        port: Option<u16>,
303        dns_server: Option<IpAddr>,
304    ) -> Result<Vec<IpAddr>, NetworkError> {
305        self.inner.resolve(host, port, dns_server).await
306    }
307}
308
309pin_project_lite::pin_project! {
310    pub struct RemoteNetworkingServerDriver {
311        common: Arc<RemoteAdapterCommon>,
312        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
313        #[pin]
314        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
315        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
316    }
317}
318
319impl Future for RemoteNetworkingServerDriver {
320    type Output = ();
321
322    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
323        // We register the waker into the interest of the sockets so
324        // that it is woken when something is ready to read or write
325        let readable = {
326            let mut guard = self.common.handler.state.lock().unwrap();
327            if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
328                guard.driver_wakers.push(cx.waker().clone());
329            }
330            guard.readable.drain().collect()
331        };
332        let readable: Vec<_> = readable;
333
334        {
335            // When a socket is marked as readable then we should drain all the data
336            // from it and start sending it to the client
337            let common = self.common.clone();
338            let mut guard = common.sockets.lock().unwrap();
339            for socket_id in readable {
340                if let Some(task) = guard
341                    .get_mut(&socket_id)
342                    .map(|s| s.drain_reads_and_accepts(&common, socket_id))
343                    .unwrap_or(None)
344                {
345                    self.tasks.push_back(task);
346                }
347            }
348        }
349
350        // This guard will be held while the pipeline is not currently
351        // stalled by some back pressure. It is only acquired when there
352        // is background tasks being processed
353        let mut not_stalled_guard = None;
354
355        // We loop until the waker is registered with the receiving stream
356        // and all the background tasks
357        loop {
358            // Background tasks are sent to this driver in certain circumstances
359            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
360                self.tasks.push_back(work);
361            }
362
363            // Background work basically stalls the stream until its all processed
364            // which creates back pressure on the client so that they don't overload
365            // the system
366            match self.tasks.poll_next_unpin(cx) {
367                Poll::Ready(Some(_)) => continue,
368                Poll::Ready(None) => {
369                    not_stalled_guard.take();
370                }
371                Poll::Pending if not_stalled_guard.is_none() => {
372                    match self.common.stall_rx.clone().try_lock_owned() {
373                        Ok(guard) => {
374                            not_stalled_guard.replace(guard);
375                        }
376                        _ => {
377                            return Poll::Pending;
378                        }
379                    }
380                }
381                Poll::Pending => {}
382            };
383
384            // We grab the next message sent by the client to us
385            let msg = {
386                let mut rx_guard = self.common.rx.lock().unwrap();
387                rx_guard.poll(cx)
388            };
389            return match msg {
390                Poll::Ready(Some(msg)) => {
391                    if let Some(task) = self.process(msg) {
392                        // With some messages we process there are background tasks that need to
393                        // be further driver to completion by the driver
394                        self.tasks.push_back(task)
395                    };
396                    continue;
397                }
398                Poll::Ready(None) => Poll::Ready(()),
399                Poll::Pending => Poll::Pending,
400            };
401        }
402    }
403}
404
405impl RemoteNetworkingServerDriver {
406    fn process(&mut self, msg: MessageRequest) -> BackgroundTask {
407        match msg {
408            MessageRequest::Send {
409                socket,
410                data,
411                req_id,
412            } => self.process_send(socket, data, req_id),
413            MessageRequest::SendTo {
414                socket,
415                data,
416                addr,
417                req_id,
418            } => self.process_send_to(socket, data, addr, req_id),
419            MessageRequest::Interface { req, req_id } => self.process_interface(req, req_id),
420            MessageRequest::Socket {
421                socket,
422                req,
423                req_id,
424            } => self.process_socket(socket, req, req_id),
425            MessageRequest::Reconnect => None,
426        }
427    }
428
429    fn process_send(
430        &mut self,
431        socket_id: SocketId,
432        data: Vec<u8>,
433        req_id: Option<u64>,
434    ) -> BackgroundTask {
435        let mut guard = self.common.sockets.lock().unwrap();
436        guard
437            .get_mut(&socket_id)
438            .map(|s| s.send(&self.common, socket_id, data, req_id))
439            .unwrap_or_else(|| {
440                tracing::debug!("orphaned socket {:?}", socket_id);
441                None
442            })
443    }
444
445    fn process_send_to(
446        &mut self,
447        socket_id: SocketId,
448        data: Vec<u8>,
449        addr: SocketAddr,
450        req_id: Option<u64>,
451    ) -> BackgroundTask {
452        let mut guard = self.common.sockets.lock().unwrap();
453        guard
454            .get_mut(&socket_id)
455            .map(|s| {
456                req_id.and_then(|req_id| s.send_to(&self.common, socket_id, data, addr, req_id))
457            })
458            .unwrap_or(None)
459    }
460
461    fn process_async<F>(future: F) -> BackgroundTask
462    where
463        F: Future<Output = BackgroundTask> + Send + 'static,
464    {
465        Some(Box::pin(async move {
466            let background_task = future.await;
467            if let Some(background_task) = background_task {
468                background_task.await;
469            }
470        }))
471    }
472
473    fn process_async_inner<F, Fut, T>(
474        &self,
475        work: F,
476        transmute: T,
477        req_id: Option<u64>,
478    ) -> BackgroundTask
479    where
480        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
481        Fut: Future + Send + 'static,
482        T: FnOnce(Fut::Output) -> ResponseType + Send + 'static,
483    {
484        let inner = self.inner.clone();
485        let common = self.common.clone();
486        Self::process_async(async move {
487            let future = work(inner);
488            let ret = future.await;
489            req_id.and_then(|req_id| {
490                common.send(MessageResponse::ResponseToRequest {
491                    req_id,
492                    res: transmute(ret),
493                })
494            })
495        })
496    }
497
498    fn process_async_noop<F, Fut>(&self, work: F, req_id: Option<u64>) -> BackgroundTask
499    where
500        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
501        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
502    {
503        self.process_async_inner(
504            work,
505            move |ret| match ret {
506                Ok(()) => ResponseType::None,
507                Err(err) => ResponseType::Err(err),
508            },
509            req_id,
510        )
511    }
512
513    fn process_async_new_socket<F, Fut>(
514        &self,
515        work: F,
516        socket_id: SocketId,
517        req_id: Option<u64>,
518    ) -> BackgroundTask
519    where
520        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
521        Fut: Future<Output = Result<RemoteAdapterSocket, NetworkError>> + Send + 'static,
522    {
523        let common = self.common.clone();
524        self.process_async_inner(
525            work,
526            move |ret| match ret {
527                Ok(mut socket) => {
528                    let handler = Box::new(common.handler.clone().for_socket(socket_id));
529
530                    let err = match &mut socket {
531                        RemoteAdapterSocket::TcpListener { .. } => {
532                            // we do not attach the handler immediately with new TPC listeners as we
533                            // only want it to trigger when the BeginAccept message is received with
534                            // a child ID we can actually use
535                            Ok(())
536                        }
537                        RemoteAdapterSocket::TcpSocket(s) => s.set_handler(handler),
538                        RemoteAdapterSocket::UdpSocket(s) => s.set_handler(handler),
539                        RemoteAdapterSocket::IcmpSocket(s) => s.set_handler(handler),
540                        RemoteAdapterSocket::RawSocket(s) => s.set_handler(handler),
541                    };
542                    if let Err(err) = err {
543                        return ResponseType::Err(err);
544                    }
545
546                    let mut guard = common.sockets.lock().unwrap();
547                    guard.insert(socket_id, socket);
548
549                    ResponseType::Socket(socket_id)
550                }
551                Err(err) => ResponseType::Err(err),
552            },
553            req_id,
554        )
555    }
556
557    fn process_inner<F, R, T>(
558        &self,
559        work: F,
560        transmute: T,
561        socket_id: SocketId,
562        req_id: Option<u64>,
563    ) -> BackgroundTask
564    where
565        F: FnOnce(&mut RemoteAdapterSocket) -> R + Send + 'static,
566        T: FnOnce(R) -> ResponseType + Send + 'static,
567    {
568        let ret = {
569            let mut guard = self.common.sockets.lock().unwrap();
570            let socket = match guard.get_mut(&socket_id) {
571                Some(s) => s,
572                None => {
573                    return req_id.and_then(|req_id| {
574                        self.common.send(MessageResponse::ResponseToRequest {
575                            req_id,
576                            res: ResponseType::Err(NetworkError::InvalidFd),
577                        })
578                    });
579                }
580            };
581            work(socket)
582        };
583        req_id.and_then(|req_id| {
584            self.common.send(MessageResponse::ResponseToRequest {
585                req_id,
586                res: transmute(ret),
587            })
588        })
589    }
590
591    fn process_inner_noop<F>(
592        &self,
593        work: F,
594        socket_id: SocketId,
595        req_id: Option<u64>,
596    ) -> BackgroundTask
597    where
598        F: FnOnce(&mut RemoteAdapterSocket) -> Result<(), NetworkError> + Send + 'static,
599    {
600        self.process_inner(
601            work,
602            move |ret| match ret {
603                Ok(()) => ResponseType::None,
604                Err(err) => ResponseType::Err(err),
605            },
606            socket_id,
607            req_id,
608        )
609    }
610
611    fn process_inner_begin_accept(
612        &self,
613        socket_id: SocketId,
614        child_id: SocketId,
615        req_id: Option<u64>,
616    ) -> BackgroundTask {
617        // We record the child socket so it can be used on the next accepted socket
618        {
619            let mut guard = self.common.socket_accept.lock().unwrap();
620            guard.insert(socket_id, child_id);
621        }
622
623        // Now we attach the handler to the main listening socket
624        let mut handler = Box::new(self.common.handler.clone().for_socket(socket_id));
625        handler.push_interest(virtual_mio::InterestType::Readable);
626        self.process_inner_noop(
627            move |socket| match socket {
628                RemoteAdapterSocket::TcpListener {
629                    socket: s,
630                    next_accept,
631                    ..
632                } => {
633                    next_accept.replace(child_id);
634                    s.set_handler(handler)
635                }
636                _ => {
637                    // only the TCP listener needs its socket set as the other sockets
638                    // set their handlers when the socket is created instead. we need to
639                    // delay setting the handler so we have a child ID to use and return
640                    // to the client when a socket is accepted, thus we can not accept them
641                    // immediately
642                    Err(NetworkError::Unsupported)
643                }
644            },
645            socket_id,
646            req_id,
647        )
648    }
649
650    fn process_interface(&mut self, req: RequestType, req_id: Option<u64>) -> BackgroundTask {
651        match req {
652            RequestType::Bridge {
653                network,
654                access_token,
655                security,
656            } => self.process_async_noop(
657                move |inner| async move { inner.bridge(&network, &access_token, security).await },
658                req_id,
659            ),
660            RequestType::Unbridge => {
661                self.process_async_noop(move |inner| async move { inner.unbridge().await }, req_id)
662            }
663            RequestType::DhcpAcquire => self.process_async_inner(
664                move |inner| async move { inner.dhcp_acquire().await },
665                |ret| match ret {
666                    Ok(ips) => ResponseType::IpAddressList(ips),
667                    Err(err) => ResponseType::Err(err),
668                },
669                req_id,
670            ),
671            RequestType::IpAdd { ip, prefix } => self.process_async_noop(
672                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
673                    inner.ip_add(ip, prefix).await
674                },
675                req_id,
676            ),
677            RequestType::IpRemove(ip) => self.process_async_noop(
678                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
679                    inner.ip_remove(ip).await
680                },
681                req_id,
682            ),
683            RequestType::IpClear => self.process_async_noop(
684                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
685                    inner.ip_clear().await
686                },
687                req_id,
688            ),
689            RequestType::GetIpList => self.process_async_inner(
690                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
691                    inner.ip_list().await
692                },
693                |ret| match ret {
694                    Ok(cidr) => ResponseType::CidrList(cidr),
695                    Err(err) => ResponseType::Err(err),
696                },
697                req_id,
698            ),
699            RequestType::GetMac => {
700                self.process_async_inner(
701                    move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
702                        inner.mac().await
703                    },
704                    |ret| match ret {
705                        Ok(mac) => ResponseType::Mac(mac),
706                        Err(err) => ResponseType::Err(err),
707                    },
708                    req_id,
709                )
710            }
711            RequestType::GatewaySet(ip) => self.process_async_noop(
712                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
713                    inner.gateway_set(ip).await
714                },
715                req_id,
716            ),
717            RequestType::RouteAdd {
718                cidr,
719                via_router,
720                preferred_until,
721                expires_at,
722            } => self.process_async_noop(
723                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
724                    inner
725                        .route_add(cidr, via_router, preferred_until, expires_at)
726                        .await
727                },
728                req_id,
729            ),
730            RequestType::RouteRemove(ip) => self.process_async_noop(
731                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
732                    inner.route_remove(ip).await
733                },
734                req_id,
735            ),
736            RequestType::RouteClear => self.process_async_noop(
737                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
738                    inner.route_clear().await
739                },
740                req_id,
741            ),
742            RequestType::GetRouteList => self.process_async_inner(
743                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
744                    inner.route_list().await
745                },
746                |ret| match ret {
747                    Ok(routes) => ResponseType::RouteList(routes),
748                    Err(err) => ResponseType::Err(err),
749                },
750                req_id,
751            ),
752            RequestType::BindRaw(socket_id) => self.process_async_new_socket(
753                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
754                    Ok(RemoteAdapterSocket::RawSocket(inner.bind_raw().await?))
755                },
756                socket_id,
757                req_id,
758            ),
759            RequestType::ListenTcp {
760                socket_id,
761                addr,
762                only_v6,
763                reuse_port,
764                reuse_addr,
765            } => self.process_async_new_socket(
766                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
767                    Ok(RemoteAdapterSocket::TcpListener {
768                        socket: inner
769                            .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
770                            .await?,
771                        next_accept: None,
772                    })
773                },
774                socket_id,
775                req_id,
776            ),
777            RequestType::BindUdp {
778                socket_id,
779                addr,
780                reuse_port,
781                reuse_addr,
782            } => self.process_async_new_socket(
783                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
784                    Ok(RemoteAdapterSocket::UdpSocket(
785                        inner.bind_udp(addr, reuse_port, reuse_addr).await?,
786                    ))
787                },
788                socket_id,
789                req_id,
790            ),
791            RequestType::BindIcmp { socket_id, addr } => self.process_async_new_socket(
792                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
793                    Ok(RemoteAdapterSocket::IcmpSocket(
794                        inner.bind_icmp(addr).await?,
795                    ))
796                },
797                socket_id,
798                req_id,
799            ),
800            RequestType::ConnectTcp {
801                socket_id,
802                addr,
803                peer,
804            } => self.process_async_new_socket(
805                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
806                    Ok(RemoteAdapterSocket::TcpSocket(
807                        inner.connect_tcp(addr, peer).await?,
808                    ))
809                },
810                socket_id,
811                req_id,
812            ),
813            RequestType::Resolve {
814                host,
815                port,
816                dns_server,
817            } => self.process_async_inner(
818                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
819                    inner.resolve(&host, port, dns_server).await
820                },
821                |ret| match ret {
822                    Ok(ips) => ResponseType::IpAddressList(ips),
823                    Err(err) => ResponseType::Err(err),
824                },
825                req_id,
826            ),
827            _ => req_id.and_then(|req_id| {
828                self.common.send(MessageResponse::ResponseToRequest {
829                    req_id,
830                    res: ResponseType::Err(NetworkError::Unsupported),
831                })
832            }),
833        }
834    }
835
836    fn process_socket(
837        &mut self,
838        socket_id: SocketId,
839        req: RequestType,
840        req_id: Option<u64>,
841    ) -> BackgroundTask {
842        match req {
843            RequestType::Flush => self.process_inner_noop(
844                move |socket| match socket {
845                    RemoteAdapterSocket::TcpSocket(s) => s.try_flush(),
846                    RemoteAdapterSocket::RawSocket(s) => s.try_flush(),
847                    _ => Err(NetworkError::Unsupported),
848                },
849                socket_id,
850                req_id,
851            ),
852            RequestType::Close => self.process_inner_noop(
853                move |socket| match socket {
854                    RemoteAdapterSocket::TcpSocket(s) => s.close(),
855                    _ => Err(NetworkError::Unsupported),
856                },
857                socket_id,
858                req_id,
859            ),
860            RequestType::BeginAccept(child_id) => {
861                self.process_inner_begin_accept(socket_id, child_id, req_id)
862            }
863            RequestType::GetAddrLocal => self.process_inner(
864                move |socket| match socket {
865                    RemoteAdapterSocket::TcpSocket(s) => s.addr_local(),
866                    RemoteAdapterSocket::TcpListener { socket: s, .. } => s.addr_local(),
867                    RemoteAdapterSocket::UdpSocket(s) => s.addr_local(),
868                    RemoteAdapterSocket::IcmpSocket(s) => s.addr_local(),
869                    RemoteAdapterSocket::RawSocket(s) => s.addr_local(),
870                },
871                |ret| match ret {
872                    Ok(addr) => ResponseType::SocketAddr(addr),
873                    Err(err) => ResponseType::Err(err),
874                },
875                socket_id,
876                req_id,
877            ),
878            RequestType::GetAddrPeer => self.process_inner(
879                move |socket| match socket {
880                    RemoteAdapterSocket::TcpSocket(s) => s.addr_peer().map(Some),
881                    RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
882                    RemoteAdapterSocket::UdpSocket(s) => s.addr_peer(),
883                    RemoteAdapterSocket::IcmpSocket(_) => Err(NetworkError::Unsupported),
884                    RemoteAdapterSocket::RawSocket(_) => Err(NetworkError::Unsupported),
885                },
886                |ret| match ret {
887                    Ok(Some(addr)) => ResponseType::SocketAddr(addr),
888                    Ok(None) => ResponseType::None,
889                    Err(err) => ResponseType::Err(err),
890                },
891                socket_id,
892                req_id,
893            ),
894            RequestType::SetTtl(ttl) => self.process_inner_noop(
895                move |socket| match socket {
896                    RemoteAdapterSocket::TcpSocket(s) => s.set_ttl(ttl),
897                    RemoteAdapterSocket::TcpListener { socket: s, .. } => {
898                        s.set_ttl(ttl.try_into().unwrap_or_default())
899                    }
900                    RemoteAdapterSocket::UdpSocket(s) => s.set_ttl(ttl),
901                    RemoteAdapterSocket::IcmpSocket(s) => s.set_ttl(ttl),
902                    RemoteAdapterSocket::RawSocket(s) => s.set_ttl(ttl),
903                },
904                socket_id,
905                req_id,
906            ),
907            RequestType::GetTtl => self.process_inner(
908                move |socket| match socket {
909                    RemoteAdapterSocket::TcpSocket(s) => s.ttl(),
910                    RemoteAdapterSocket::TcpListener { socket: s, .. } => s.ttl().map(|t| t as u32),
911                    RemoteAdapterSocket::UdpSocket(s) => s.ttl(),
912                    RemoteAdapterSocket::IcmpSocket(s) => s.ttl(),
913                    RemoteAdapterSocket::RawSocket(s) => s.ttl(),
914                },
915                |ret| match ret {
916                    Ok(ttl) => ResponseType::Ttl(ttl),
917                    Err(err) => ResponseType::Err(err),
918                },
919                socket_id,
920                req_id,
921            ),
922            RequestType::GetStatus => self.process_inner(
923                move |socket| match socket {
924                    RemoteAdapterSocket::TcpSocket(s) => s.status(),
925                    RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
926                    RemoteAdapterSocket::UdpSocket(s) => s.status(),
927                    RemoteAdapterSocket::IcmpSocket(s) => s.status(),
928                    RemoteAdapterSocket::RawSocket(s) => s.status(),
929                },
930                |ret| match ret {
931                    Ok(status) => ResponseType::Status(status),
932                    Err(err) => ResponseType::Err(err),
933                },
934                socket_id,
935                req_id,
936            ),
937            RequestType::SetLinger(linger) => self.process_inner_noop(
938                move |socket| match socket {
939                    RemoteAdapterSocket::TcpSocket(s) => s.set_linger(linger),
940                    _ => Err(NetworkError::Unsupported),
941                },
942                socket_id,
943                req_id,
944            ),
945            RequestType::GetLinger => self.process_inner(
946                move |socket| match socket {
947                    RemoteAdapterSocket::TcpSocket(s) => s.linger(),
948                    _ => Err(NetworkError::Unsupported),
949                },
950                |ret| match ret {
951                    Ok(Some(time)) => ResponseType::Duration(time),
952                    Ok(None) => ResponseType::None,
953                    Err(err) => ResponseType::Err(err),
954                },
955                socket_id,
956                req_id,
957            ),
958            RequestType::SetPromiscuous(promiscuous) => self.process_inner_noop(
959                move |socket| match socket {
960                    RemoteAdapterSocket::RawSocket(s) => s.set_promiscuous(promiscuous),
961                    _ => Err(NetworkError::Unsupported),
962                },
963                socket_id,
964                req_id,
965            ),
966            RequestType::GetPromiscuous => self.process_inner(
967                move |socket| match socket {
968                    RemoteAdapterSocket::RawSocket(s) => s.promiscuous(),
969                    _ => Err(NetworkError::Unsupported),
970                },
971                |ret| match ret {
972                    Ok(flag) => ResponseType::Flag(flag),
973                    Err(err) => ResponseType::Err(err),
974                },
975                socket_id,
976                req_id,
977            ),
978            RequestType::SetRecvBufSize(size) => self.process_inner_noop(
979                move |socket| match socket {
980                    RemoteAdapterSocket::TcpSocket(s) => {
981                        s.set_recv_buf_size(size.try_into().unwrap_or_default())
982                    }
983                    _ => Err(NetworkError::Unsupported),
984                },
985                socket_id,
986                req_id,
987            ),
988            RequestType::GetRecvBufSize => self.process_inner(
989                move |socket| match socket {
990                    RemoteAdapterSocket::TcpSocket(s) => s.recv_buf_size(),
991                    _ => Err(NetworkError::Unsupported),
992                },
993                |ret| match ret {
994                    Ok(amt) => ResponseType::Amount(amt as u64),
995                    Err(err) => ResponseType::Err(err),
996                },
997                socket_id,
998                req_id,
999            ),
1000            RequestType::SetSendBufSize(size) => self.process_inner_noop(
1001                move |socket| match socket {
1002                    RemoteAdapterSocket::TcpSocket(s) => {
1003                        s.set_send_buf_size(size.try_into().unwrap_or_default())
1004                    }
1005                    _ => Err(NetworkError::Unsupported),
1006                },
1007                socket_id,
1008                req_id,
1009            ),
1010            RequestType::GetSendBufSize => self.process_inner(
1011                move |socket| match socket {
1012                    RemoteAdapterSocket::TcpSocket(s) => s.send_buf_size(),
1013                    _ => Err(NetworkError::Unsupported),
1014                },
1015                |ret| match ret {
1016                    Ok(amt) => ResponseType::Amount(amt as u64),
1017                    Err(err) => ResponseType::Err(err),
1018                },
1019                socket_id,
1020                req_id,
1021            ),
1022            RequestType::SetNoDelay(reuse) => self.process_inner_noop(
1023                move |socket| match socket {
1024                    RemoteAdapterSocket::TcpSocket(s) => s.set_nodelay(reuse),
1025                    _ => Err(NetworkError::Unsupported),
1026                },
1027                socket_id,
1028                req_id,
1029            ),
1030            RequestType::GetNoDelay => self.process_inner(
1031                move |socket| match socket {
1032                    RemoteAdapterSocket::TcpSocket(s) => s.nodelay(),
1033                    _ => Err(NetworkError::Unsupported),
1034                },
1035                |ret| match ret {
1036                    Ok(flag) => ResponseType::Flag(flag),
1037                    Err(err) => ResponseType::Err(err),
1038                },
1039                socket_id,
1040                req_id,
1041            ),
1042            RequestType::SetKeepAlive(val) => self.process_inner_noop(
1043                move |socket| match socket {
1044                    RemoteAdapterSocket::TcpSocket(s) => s.set_keepalive(val),
1045                    _ => Err(NetworkError::Unsupported),
1046                },
1047                socket_id,
1048                req_id,
1049            ),
1050            RequestType::GetKeepAlive => self.process_inner(
1051                move |socket| match socket {
1052                    RemoteAdapterSocket::TcpSocket(s) => s.keepalive(),
1053                    _ => Err(NetworkError::Unsupported),
1054                },
1055                |ret| match ret {
1056                    Ok(flag) => ResponseType::Flag(flag),
1057                    Err(err) => ResponseType::Err(err),
1058                },
1059                socket_id,
1060                req_id,
1061            ),
1062            RequestType::SetDontRoute(val) => self.process_inner_noop(
1063                move |socket| match socket {
1064                    RemoteAdapterSocket::TcpSocket(s) => s.set_dontroute(val),
1065                    _ => Err(NetworkError::Unsupported),
1066                },
1067                socket_id,
1068                req_id,
1069            ),
1070            RequestType::GetDontRoute => self.process_inner(
1071                move |socket| match socket {
1072                    RemoteAdapterSocket::TcpSocket(s) => s.dontroute(),
1073                    _ => Err(NetworkError::Unsupported),
1074                },
1075                |ret| match ret {
1076                    Ok(flag) => ResponseType::Flag(flag),
1077                    Err(err) => ResponseType::Err(err),
1078                },
1079                socket_id,
1080                req_id,
1081            ),
1082            RequestType::Shutdown(shutdown) => self.process_inner_noop(
1083                move |socket| match socket {
1084                    RemoteAdapterSocket::TcpSocket(s) => s.shutdown(match shutdown {
1085                        crate::meta::Shutdown::Read => std::net::Shutdown::Read,
1086                        crate::meta::Shutdown::Write => std::net::Shutdown::Write,
1087                        crate::meta::Shutdown::Both => std::net::Shutdown::Both,
1088                    }),
1089                    _ => Err(NetworkError::Unsupported),
1090                },
1091                socket_id,
1092                req_id,
1093            ),
1094            RequestType::IsClosed => self.process_inner(
1095                move |socket| match socket {
1096                    RemoteAdapterSocket::TcpSocket(s) => Ok(s.is_closed()),
1097                    _ => Err(NetworkError::Unsupported),
1098                },
1099                |ret| match ret {
1100                    Ok(flag) => ResponseType::Flag(flag),
1101                    Err(err) => ResponseType::Err(err),
1102                },
1103                socket_id,
1104                req_id,
1105            ),
1106            RequestType::SetBroadcast(broadcast) => self.process_inner_noop(
1107                move |socket| match socket {
1108                    RemoteAdapterSocket::UdpSocket(s) => s.set_broadcast(broadcast),
1109                    _ => Err(NetworkError::Unsupported),
1110                },
1111                socket_id,
1112                req_id,
1113            ),
1114            RequestType::GetBroadcast => self.process_inner(
1115                move |socket| match socket {
1116                    RemoteAdapterSocket::UdpSocket(s) => s.broadcast(),
1117                    _ => Err(NetworkError::Unsupported),
1118                },
1119                |ret| match ret {
1120                    Ok(flag) => ResponseType::Flag(flag),
1121                    Err(err) => ResponseType::Err(err),
1122                },
1123                socket_id,
1124                req_id,
1125            ),
1126            RequestType::SetMulticastLoopV4(val) => self.process_inner_noop(
1127                move |socket| match socket {
1128                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v4(val),
1129                    _ => Err(NetworkError::Unsupported),
1130                },
1131                socket_id,
1132                req_id,
1133            ),
1134            RequestType::GetMulticastLoopV4 => self.process_inner(
1135                move |socket| match socket {
1136                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v4(),
1137                    _ => Err(NetworkError::Unsupported),
1138                },
1139                |ret| match ret {
1140                    Ok(flag) => ResponseType::Flag(flag),
1141                    Err(err) => ResponseType::Err(err),
1142                },
1143                socket_id,
1144                req_id,
1145            ),
1146            RequestType::SetMulticastLoopV6(val) => self.process_inner_noop(
1147                move |socket| match socket {
1148                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v6(val),
1149                    _ => Err(NetworkError::Unsupported),
1150                },
1151                socket_id,
1152                req_id,
1153            ),
1154            RequestType::GetMulticastLoopV6 => self.process_inner(
1155                move |socket| match socket {
1156                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v6(),
1157                    _ => Err(NetworkError::Unsupported),
1158                },
1159                |ret| match ret {
1160                    Ok(flag) => ResponseType::Flag(flag),
1161                    Err(err) => ResponseType::Err(err),
1162                },
1163                socket_id,
1164                req_id,
1165            ),
1166            RequestType::SetMulticastTtlV4(ttl) => self.process_inner_noop(
1167                move |socket| match socket {
1168                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_ttl_v4(ttl),
1169                    _ => Err(NetworkError::Unsupported),
1170                },
1171                socket_id,
1172                req_id,
1173            ),
1174            RequestType::GetMulticastTtlV4 => self.process_inner(
1175                move |socket| match socket {
1176                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_ttl_v4(),
1177                    _ => Err(NetworkError::Unsupported),
1178                },
1179                |ret| match ret {
1180                    Ok(ttl) => ResponseType::Ttl(ttl),
1181                    Err(err) => ResponseType::Err(err),
1182                },
1183                socket_id,
1184                req_id,
1185            ),
1186            RequestType::JoinMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1187                move |socket| match socket {
1188                    RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v4(multiaddr, iface),
1189                    _ => Err(NetworkError::Unsupported),
1190                },
1191                socket_id,
1192                req_id,
1193            ),
1194            RequestType::LeaveMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1195                move |socket| match socket {
1196                    RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v4(multiaddr, iface),
1197                    _ => Err(NetworkError::Unsupported),
1198                },
1199                socket_id,
1200                req_id,
1201            ),
1202            RequestType::JoinMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1203                move |socket| match socket {
1204                    RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v6(multiaddr, iface),
1205                    _ => Err(NetworkError::Unsupported),
1206                },
1207                socket_id,
1208                req_id,
1209            ),
1210            RequestType::LeaveMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1211                move |socket| match socket {
1212                    RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v6(multiaddr, iface),
1213                    _ => Err(NetworkError::Unsupported),
1214                },
1215                socket_id,
1216                req_id,
1217            ),
1218            _ => req_id.and_then(|req_id| {
1219                self.common.send(MessageResponse::ResponseToRequest {
1220                    req_id,
1221                    res: ResponseType::Err(NetworkError::Unsupported),
1222                })
1223            }),
1224        }
1225    }
1226}
1227
1228#[derive(Debug)]
1229enum RemoteAdapterSocket {
1230    TcpListener {
1231        socket: Box<dyn VirtualTcpListener + Sync + 'static>,
1232        next_accept: Option<SocketId>,
1233    },
1234    TcpSocket(Box<dyn VirtualTcpSocket + Sync + 'static>),
1235    UdpSocket(Box<dyn VirtualUdpSocket + Sync + 'static>),
1236    RawSocket(Box<dyn VirtualRawSocket + Sync + 'static>),
1237    IcmpSocket(Box<dyn VirtualIcmpSocket + Sync + 'static>),
1238}
1239
1240impl RemoteAdapterSocket {
1241    pub fn send(
1242        &mut self,
1243        common: &Arc<RemoteAdapterCommon>,
1244        socket_id: SocketId,
1245        data: Vec<u8>,
1246        req_id: Option<u64>,
1247    ) -> BackgroundTask {
1248        match self {
1249            Self::TcpSocket(this) => match this.try_send(&data) {
1250                Ok(amount) => {
1251                    if let Some(req_id) = req_id {
1252                        common.send(MessageResponse::Sent {
1253                            socket_id,
1254                            req_id,
1255                            amount: amount as u64,
1256                        })
1257                    } else {
1258                        None
1259                    }
1260                }
1261                Err(NetworkError::WouldBlock) => {
1262                    let common = common.clone();
1263                    Some(Box::pin(async move {
1264                        // We will stall the receiver so that back pressure is sent back to the
1265                        // sender and they don't overwhelm us with transmitting data.
1266                        let _stall_rx = common.stall_rx.clone().lock_owned().await;
1267
1268                        // We use a poller here that uses the handler to wake itself up
1269                        struct Poller {
1270                            common: Arc<RemoteAdapterCommon>,
1271                            socket_id: SocketId,
1272                            data: Vec<u8>,
1273                            req_id: Option<u64>,
1274                        }
1275                        impl Future for Poller {
1276                            type Output = BackgroundTask;
1277                            fn poll(
1278                                self: Pin<&mut Self>,
1279                                cx: &mut Context<'_>,
1280                            ) -> Poll<Self::Output> {
1281                                // We make sure the waker is registered with the interest driver which will
1282                                // wake up this poller when there is writeability
1283                                let mut guard = self.common.handler.state.lock().unwrap();
1284                                if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
1285                                    guard.driver_wakers.push(cx.waker().clone());
1286                                }
1287                                drop(guard);
1288
1289                                let mut guard = self.common.sockets.lock().unwrap();
1290                                if let Some(RemoteAdapterSocket::TcpSocket(socket)) =
1291                                    guard.get_mut(&self.socket_id)
1292                                {
1293                                    match socket.try_send(&self.data) {
1294                                        Ok(amount) => {
1295                                            if let Some(req_id) = self.req_id {
1296                                                return Poll::Ready(self.common.send(
1297                                                    MessageResponse::Sent {
1298                                                        socket_id: self.socket_id,
1299                                                        req_id,
1300                                                        amount: amount as u64,
1301                                                    },
1302                                                ));
1303                                            } else {
1304                                                return Poll::Ready(None);
1305                                            }
1306                                        }
1307                                        Err(NetworkError::WouldBlock) => return Poll::Pending,
1308                                        Err(error) => {
1309                                            if let Some(req_id) = self.req_id {
1310                                                return Poll::Ready(self.common.send(
1311                                                    MessageResponse::SendError {
1312                                                        socket_id: self.socket_id,
1313                                                        req_id,
1314                                                        error,
1315                                                    },
1316                                                ));
1317                                            } else {
1318                                                return Poll::Ready(None);
1319                                            }
1320                                        }
1321                                    }
1322                                }
1323                                Poll::Ready(None)
1324                            }
1325                        }
1326
1327                        // Run the poller until this message is sent, or the socket fails
1328                        let background_task = Poller {
1329                            common,
1330                            socket_id,
1331                            data,
1332                            req_id,
1333                        }
1334                        .await;
1335
1336                        // There might be more work left to finish off the send operation
1337                        if let Some(background_task) = background_task {
1338                            background_task.await;
1339                        }
1340                    }))
1341                }
1342                Err(error) => {
1343                    if let Some(req_id) = req_id {
1344                        common.send(MessageResponse::SendError {
1345                            socket_id,
1346                            req_id,
1347                            error,
1348                        })
1349                    } else {
1350                        None
1351                    }
1352                }
1353            },
1354            Self::RawSocket(this) => {
1355                // when the RAW socket is overloaded we just silently drop the packet
1356                // rather than buffering it and retrying later - Ethernet packets are
1357                // not lossless. In reality most socket drivers under this remote socket
1358                // will always succeed on `try_send` with RawSockets as they are always
1359                // processed.
1360                if let Err(err) = this.try_send(&data) {
1361                    tracing::debug!("failed to send raw packet - {}", err);
1362                }
1363                None
1364            }
1365            _ => {
1366                if let Some(req_id) = req_id {
1367                    common.send(MessageResponse::SendError {
1368                        socket_id,
1369                        req_id,
1370                        error: NetworkError::Unsupported,
1371                    })
1372                } else {
1373                    None
1374                }
1375            }
1376        }
1377    }
1378    pub fn send_to(
1379        &mut self,
1380        common: &Arc<RemoteAdapterCommon>,
1381        socket_id: SocketId,
1382        data: Vec<u8>,
1383        addr: SocketAddr,
1384        req_id: u64,
1385    ) -> BackgroundTask {
1386        match self {
1387            Self::UdpSocket(this) => {
1388                // when the UDP socket is overloaded we just silently drop the packet
1389                // rather than buffering it and retrying later
1390                this.try_send_to(&data, addr).ok();
1391                None
1392            }
1393
1394            Self::IcmpSocket(this) => {
1395                // when the ICMP socket is overloaded we just silently drop the packet
1396                // rather than buffering it and retrying later
1397                this.try_send_to(&data, addr).ok();
1398                None
1399            }
1400            _ => common.send(MessageResponse::SendError {
1401                socket_id,
1402                req_id,
1403                error: NetworkError::Unsupported,
1404            }),
1405        }
1406    }
1407    pub fn drain_reads_and_accepts(
1408        &mut self,
1409        common: &Arc<RemoteAdapterCommon>,
1410        socket_id: SocketId,
1411    ) -> BackgroundTask {
1412        // We loop reading the socket until all the pending reads are either
1413        // being processed in a background task or they are empty
1414        let mut ret: FuturesOrdered<BoxFuture<'static, ()>> = Default::default();
1415        loop {
1416            break match self {
1417                Self::TcpListener {
1418                    socket,
1419                    next_accept,
1420                } => {
1421                    if next_accept.is_some() {
1422                        match socket.try_accept() {
1423                            Ok((mut child_socket, addr)) => {
1424                                let child_id = next_accept.take().unwrap();
1425
1426                                // We set the handler on the socket so that it can
1427                                let handler = Box::new(common.handler.clone().for_socket(child_id));
1428                                child_socket.set_handler(handler).ok();
1429
1430                                // We will fix up the socket in the background then notify
1431                                // the client of the new socket
1432                                let common = common.clone();
1433                                ret.push_back(Box::pin(async move {
1434                                    // Next we record the socket so that it is active
1435                                    {
1436                                        let child_socket =
1437                                            RemoteAdapterSocket::TcpSocket(child_socket);
1438                                        let mut guard = common.sockets.lock().unwrap();
1439                                        guard.insert(child_id, child_socket);
1440                                    }
1441
1442                                    // Lastly we tell the client about the new socket
1443                                    if let Some(task) = common.send(MessageResponse::FinishAccept {
1444                                        socket_id,
1445                                        child_id,
1446                                        addr,
1447                                    }) {
1448                                        task.await;
1449                                    }
1450                                }));
1451                            }
1452                            Err(NetworkError::WouldBlock) => {}
1453                            Err(err) => {
1454                                tracing::error!("failed to accept socket - {}", err);
1455                            }
1456                        }
1457                    }
1458                }
1459                Self::TcpSocket(this) => {
1460                    let mut chunk: [MaybeUninit<u8>; 10240] =
1461                        unsafe { MaybeUninit::uninit().assume_init() };
1462                    match this.try_recv(&mut chunk, false) {
1463                        Ok(0) => {}
1464                        Ok(amt) => {
1465                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1466                            let chunk_unsafe: &mut [u8] =
1467                                unsafe { std::mem::transmute(chunk_unsafe) };
1468                            if let Some(task) = common.send(MessageResponse::Recv {
1469                                socket_id,
1470                                data: chunk_unsafe.to_vec(),
1471                            }) {
1472                                ret.push_back(task);
1473                            }
1474                            continue;
1475                        }
1476                        Err(_) => {}
1477                    }
1478                }
1479                Self::UdpSocket(this) => {
1480                    let mut chunk: [MaybeUninit<u8>; 10240] =
1481                        unsafe { MaybeUninit::uninit().assume_init() };
1482                    match this.try_recv_from(&mut chunk, false) {
1483                        Ok((0, _)) => {}
1484                        Ok((amt, addr)) => {
1485                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1486                            let chunk_unsafe: &mut [u8] =
1487                                unsafe { std::mem::transmute(chunk_unsafe) };
1488                            if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1489                                socket_id,
1490                                data: chunk_unsafe.to_vec(),
1491                                addr,
1492                            }) {
1493                                ret.push_back(task);
1494                            }
1495                            continue;
1496                        }
1497                        Err(_) => {}
1498                    }
1499                }
1500                Self::IcmpSocket(this) => {
1501                    let mut chunk: [MaybeUninit<u8>; 10240] =
1502                        unsafe { MaybeUninit::uninit().assume_init() };
1503                    match this.try_recv_from(&mut chunk, false) {
1504                        Ok((0, _)) => {}
1505                        Ok((amt, addr)) => {
1506                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1507                            let chunk_unsafe: &mut [u8] =
1508                                unsafe { std::mem::transmute(chunk_unsafe) };
1509                            if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1510                                socket_id,
1511                                data: chunk_unsafe.to_vec(),
1512                                addr,
1513                            }) {
1514                                ret.push_back(task);
1515                            }
1516                            continue;
1517                        }
1518                        Err(_) => {}
1519                    }
1520                }
1521                Self::RawSocket(this) => {
1522                    let mut chunk: [MaybeUninit<u8>; 10240] =
1523                        unsafe { MaybeUninit::uninit().assume_init() };
1524                    match this.try_recv(&mut chunk, false) {
1525                        Ok(0) => {}
1526                        Ok(amt) => {
1527                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1528                            let chunk_unsafe: &mut [u8] =
1529                                unsafe { std::mem::transmute(chunk_unsafe) };
1530                            if let Some(task) = common.send(MessageResponse::Recv {
1531                                socket_id,
1532                                data: chunk_unsafe.to_vec(),
1533                            }) {
1534                                ret.push_back(task);
1535                            }
1536                            continue;
1537                        }
1538                        Err(_) => {}
1539                    }
1540                }
1541            };
1542        }
1543
1544        if ret.is_empty() {
1545            // There is nothing to process so we are done
1546            None
1547        } else {
1548            Some(Box::pin(async move {
1549                // Processes all the background tasks until completion
1550                let mut stream = ret;
1551                loop {
1552                    let (next, s) = StreamExt::into_future(stream).await;
1553                    if next.is_none() {
1554                        break;
1555                    }
1556                    stream = s;
1557                }
1558            }))
1559        }
1560    }
1561}
1562
/// Interest bookkeeping shared, through an `Arc<Mutex<_>>`, by every
/// `RemoteAdapterHandler` cloned or derived (via `for_socket`) from the same handler.
#[derive(Debug, Default)]
struct RemoteAdapterHandlerState {
    /// Sockets for which a `Readable` interest has been pushed and not yet popped.
    readable: HashSet<SocketId>,
    /// Wakers that are woken and removed every time any interest is pushed.
    /// NOTE(review): where these wakers get registered is not visible in this section.
    driver_wakers: Vec<Waker>,
}
1568
/// `InterestHandler` implementation that records socket readiness into shared state.
///
/// A handler whose `socket_id` is `None` (e.g. the `Default` one) only wakes the
/// driver wakers when interest is pushed. A handler bound to a socket via
/// `for_socket` also tracks `Readable` interest for that socket. Clones share the
/// same underlying state because `state` is an `Arc`.
#[derive(Debug, Default, Clone)]
struct RemoteAdapterHandler {
    /// Socket this handler reports for; `None` when it is not bound to a socket.
    socket_id: Option<SocketId>,
    /// Interest state shared with every handler derived from the same origin.
    state: Arc<Mutex<RemoteAdapterHandlerState>>,
}
1574impl RemoteAdapterHandler {
1575    pub fn for_socket(self, id: SocketId) -> Self {
1576        Self {
1577            socket_id: Some(id),
1578            state: self.state,
1579        }
1580    }
1581}
1582impl InterestHandler for RemoteAdapterHandler {
1583    fn push_interest(&mut self, interest: virtual_mio::InterestType) {
1584        let mut guard = self.state.lock().unwrap();
1585        guard.driver_wakers.drain(..).for_each(|w| w.wake());
1586        let socket_id = match self.socket_id {
1587            Some(s) => s,
1588            None => return,
1589        };
1590        if interest == virtual_mio::InterestType::Readable {
1591            guard.readable.insert(socket_id);
1592        }
1593    }
1594
1595    fn pop_interest(&mut self, interest: virtual_mio::InterestType) -> bool {
1596        let mut guard = self.state.lock().unwrap();
1597        let socket_id = match self.socket_id {
1598            Some(s) => s,
1599            None => return false,
1600        };
1601        if interest == virtual_mio::InterestType::Readable {
1602            return guard.readable.remove(&socket_id);
1603        }
1604        false
1605    }
1606
1607    fn has_interest(&self, interest: virtual_mio::InterestType) -> bool {
1608        let guard = self.state.lock().unwrap();
1609        let socket_id = match self.socket_id {
1610            Some(s) => s,
1611            None => return false,
1612        };
1613        if interest == virtual_mio::InterestType::Readable {
1614            return guard.readable.contains(&socket_id);
1615        }
1616        false
1617    }
1618}
1619
1620type SocketMap<T> = HashMap<SocketId, T>;
1621
/// State shared (through an `Arc`) by `RemoteNetworkingServer` and the background
/// tasks created from it (see `RemoteAdapterCommon::send`).
#[derive(Debug)]
struct RemoteAdapterCommon {
    /// Outbound channel carrying `MessageResponse`s back to the remote peer.
    tx: RemoteTx<MessageResponse>,
    /// Inbound channel of `MessageRequest`s from the remote peer.
    /// NOTE(review): presumably behind a `Mutex` so it can be polled through `&self`.
    rx: Mutex<RemoteRx<MessageRequest>>,
    /// Sockets tracked by this server, keyed by their `SocketId`.
    sockets: Mutex<SocketMap<RemoteAdapterSocket>>,
    /// Maps one socket id to another.
    /// NOTE(review): presumably listener id -> id to assign to its next accepted
    /// connection; confirm against the accept handling code.
    socket_accept: Mutex<SocketMap<SocketId>>,
    /// Interest handler for this adapter; per-socket handlers are presumably
    /// derived from it with `for_socket`.
    handler: RemoteAdapterHandler,

    // The stall guard prevents reads while it is held and there are background tasks running
    // (the idea is to create back pressure so that the task list does not grow infinitely).
    stall_rx: Arc<tokio::sync::Mutex<()>>,
}
1634impl RemoteAdapterCommon {
1635    fn send(self: &Arc<Self>, req: MessageResponse) -> BackgroundTask {
1636        let this = self.clone();
1637        Some(Box::pin(async move {
1638            if let Err(err) = this.tx.send(req).await {
1639                tracing::debug!("failed to send message - {}", err);
1640            }
1641        }))
1642    }
1643}