virtual_net/
server.rs

1use crate::meta::{FrameSerializationFormat, ResponseType};
2use crate::rx_tx::{RemoteRx, RemoteTx, RemoteTxWakers};
3use crate::{IpCidr, IpRoute, NetworkError, SocketStatus, StreamSecurity, VirtualIcmpSocket};
4use crate::{
5    VirtualNetworking, VirtualRawSocket, VirtualTcpBoundSocket, VirtualTcpListener,
6    VirtualTcpSocket, VirtualUdpSocket,
7    meta::{MessageRequest, MessageResponse, RequestType, SocketId},
8};
9use futures_util::stream::FuturesOrdered;
10#[cfg(any(feature = "hyper", feature = "tokio-tungstenite"))]
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::{Sink, Stream};
13use futures_util::{StreamExt, future::BoxFuture};
14use std::collections::HashSet;
15use std::mem::MaybeUninit;
16use std::net::IpAddr;
17use std::task::Waker;
18use std::time::Duration;
19
20#[cfg(feature = "hyper")]
21use hyper_util::rt::tokio::TokioIo;
22use std::{
23    collections::HashMap,
24    future::Future,
25    net::SocketAddr,
26    pin::Pin,
27    sync::{Arc, Mutex},
28    task::{Context, Poll},
29};
30use tokio::{
31    io::{AsyncRead, AsyncWrite},
32    sync::mpsc,
33};
34use tokio_serde::SymmetricallyFramed;
35use tokio_serde::formats::SymmetricalBincode;
36#[cfg(feature = "cbor")]
37use tokio_serde::formats::SymmetricalCbor;
38#[cfg(feature = "json")]
39use tokio_serde::formats::SymmetricalJson;
40#[cfg(feature = "messagepack")]
41use tokio_serde::formats::SymmetricalMessagePack;
42use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
43use virtual_mio::InterestHandler;
44
/// Follow-up work produced while handling a request that the driver still has to
/// poll to completion; `None` when the request was fully handled synchronously.
///
/// This is `Option<BoxFuture<'static, ()>>` written out in full.
type BackgroundTask = Option<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>;
46
47#[derive(Debug, Clone)]
48pub struct RemoteNetworkingServer {
49    #[allow(dead_code)]
50    common: Arc<RemoteAdapterCommon>,
51    inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
52}
53
54impl RemoteNetworkingServer {
55    fn new(
56        tx: RemoteTx<MessageResponse>,
57        rx: RemoteRx<MessageRequest>,
58        work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
59        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
60    ) -> (Self, RemoteNetworkingServerDriver) {
61        let common = RemoteAdapterCommon {
62            tx,
63            rx: Mutex::new(rx),
64            sockets: Default::default(),
65            socket_accept: Default::default(),
66            handler: Default::default(),
67            stall_rx: Default::default(),
68        };
69        let common = Arc::new(common);
70
71        let driver = RemoteNetworkingServerDriver {
72            more_work: work,
73            tasks: Default::default(),
74            common: common.clone(),
75            inner: inner.clone(),
76        };
77        let networking = Self { common, inner };
78
79        (networking, driver)
80    }
81    /// Creates a new interface on the remote location using
82    /// a unique interface ID and a pair of channels
83    pub fn new_from_mpsc(
84        tx: mpsc::Sender<MessageResponse>,
85        rx: mpsc::Receiver<MessageRequest>,
86        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
87    ) -> (Self, RemoteNetworkingServerDriver) {
88        let (tx_work, rx_work) = mpsc::unbounded_channel();
89        let tx_wakers = RemoteTxWakers::default();
90
91        let tx = RemoteTx::Mpsc {
92            tx,
93            work: tx_work,
94            wakers: tx_wakers.clone(),
95        };
96        let rx = RemoteRx::Mpsc {
97            rx,
98            wakers: tx_wakers,
99        };
100        Self::new(tx, rx, rx_work, inner)
101    }
102
103    /// Creates a new interface on the remote location using
104    /// a unique interface ID and a pair of channels
105    pub fn new_from_async_io<TX, RX>(
106        tx: TX,
107        rx: RX,
108        format: FrameSerializationFormat,
109        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
110    ) -> (Self, RemoteNetworkingServerDriver)
111    where
112        TX: AsyncWrite + Send + 'static,
113        RX: AsyncRead + Send + 'static,
114    {
115        let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
116        let tx: Pin<Box<dyn Sink<MessageResponse, Error = std::io::Error> + Send + 'static>> =
117            match format {
118                FrameSerializationFormat::Bincode => {
119                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
120                }
121                #[cfg(feature = "json")]
122                FrameSerializationFormat::Json => {
123                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
124                }
125                #[cfg(feature = "messagepack")]
126                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
127                    tx,
128                    SymmetricalMessagePack::default(),
129                )),
130                #[cfg(feature = "cbor")]
131                FrameSerializationFormat::Cbor => {
132                    Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
133                }
134            };
135
136        let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
137        let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageRequest>> + Send + 'static>> =
138            match format {
139                FrameSerializationFormat::Bincode => {
140                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
141                }
142                #[cfg(feature = "json")]
143                FrameSerializationFormat::Json => {
144                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
145                }
146                #[cfg(feature = "messagepack")]
147                FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
148                    rx,
149                    SymmetricalMessagePack::default(),
150                )),
151                #[cfg(feature = "cbor")]
152                FrameSerializationFormat::Cbor => {
153                    Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
154                }
155            };
156
157        let (tx_work, rx_work) = mpsc::unbounded_channel();
158
159        let tx = RemoteTx::Stream {
160            tx: Arc::new(tokio::sync::Mutex::new(tx)),
161            work: tx_work,
162            wakers: RemoteTxWakers::default(),
163        };
164        let rx = RemoteRx::Stream { rx };
165        Self::new(tx, rx, rx_work, inner)
166    }
167
168    /// Creates a new interface on the remote location using
169    /// a unique interface ID and a pair of channels
170    #[cfg(feature = "hyper")]
171    pub fn new_from_hyper_ws_io(
172        tx: SplitSink<
173            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
174            hyper_tungstenite::tungstenite::Message,
175        >,
176        rx: SplitStream<hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>>,
177        format: FrameSerializationFormat,
178        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
179    ) -> (Self, RemoteNetworkingServerDriver) {
180        let (tx_work, rx_work) = mpsc::unbounded_channel();
181
182        let tx = RemoteTx::HyperWebSocket {
183            tx: Arc::new(tokio::sync::Mutex::new(tx)),
184            work: tx_work,
185            wakers: RemoteTxWakers::default(),
186            format,
187        };
188        let rx = RemoteRx::HyperWebSocket { rx, format };
189        Self::new(tx, rx, rx_work, inner)
190    }
191
192    #[cfg(all(test, feature = "tokio"))]
193    pub(crate) fn socket_count_for_test(&self) -> usize {
194        self.common.sockets.lock().unwrap().len()
195    }
196}
197
198#[async_trait::async_trait]
199impl VirtualNetworking for RemoteNetworkingServer {
200    async fn bridge(
201        &self,
202        network: &str,
203        access_token: &str,
204        security: StreamSecurity,
205    ) -> Result<(), NetworkError> {
206        self.inner.bridge(network, access_token, security).await
207    }
208
209    async fn unbridge(&self) -> Result<(), NetworkError> {
210        self.inner.unbridge().await
211    }
212
213    async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>, NetworkError> {
214        self.inner.dhcp_acquire().await
215    }
216
217    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<(), NetworkError> {
218        self.inner.ip_add(ip, prefix).await
219    }
220
221    async fn ip_remove(&self, ip: IpAddr) -> Result<(), NetworkError> {
222        self.inner.ip_remove(ip).await
223    }
224
225    async fn ip_clear(&self) -> Result<(), NetworkError> {
226        self.inner.ip_clear().await
227    }
228
229    async fn ip_list(&self) -> Result<Vec<IpCidr>, NetworkError> {
230        self.inner.ip_list().await
231    }
232
233    async fn mac(&self) -> Result<[u8; 6], NetworkError> {
234        self.inner.mac().await
235    }
236
237    async fn gateway_set(&self, ip: IpAddr) -> Result<(), NetworkError> {
238        self.inner.gateway_set(ip).await
239    }
240
241    async fn route_add(
242        &self,
243        cidr: IpCidr,
244        via_router: IpAddr,
245        preferred_until: Option<Duration>,
246        expires_at: Option<Duration>,
247    ) -> Result<(), NetworkError> {
248        self.inner
249            .route_add(cidr, via_router, preferred_until, expires_at)
250            .await
251    }
252
253    async fn route_remove(&self, cidr: IpAddr) -> Result<(), NetworkError> {
254        self.inner.route_remove(cidr).await
255    }
256
257    async fn route_clear(&self) -> Result<(), NetworkError> {
258        self.inner.route_clear().await
259    }
260
261    async fn route_list(&self) -> Result<Vec<IpRoute>, NetworkError> {
262        self.inner.route_list().await
263    }
264
265    async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>, NetworkError> {
266        self.inner.bind_raw().await
267    }
268
269    async fn listen_tcp(
270        &self,
271        addr: SocketAddr,
272        only_v6: bool,
273        reuse_port: bool,
274        reuse_addr: bool,
275    ) -> Result<Box<dyn VirtualTcpListener + Sync>, NetworkError> {
276        self.inner
277            .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
278            .await
279    }
280
281    async fn bind_udp(
282        &self,
283        addr: SocketAddr,
284        reuse_port: bool,
285        reuse_addr: bool,
286    ) -> Result<Box<dyn VirtualUdpSocket + Sync>, NetworkError> {
287        self.inner.bind_udp(addr, reuse_port, reuse_addr).await
288    }
289
290    async fn bind_icmp(
291        &self,
292        addr: IpAddr,
293    ) -> Result<Box<dyn VirtualIcmpSocket + Sync>, NetworkError> {
294        self.inner.bind_icmp(addr).await
295    }
296
297    async fn connect_tcp(
298        &self,
299        addr: SocketAddr,
300        peer: SocketAddr,
301    ) -> Result<Box<dyn VirtualTcpSocket + Sync>, NetworkError> {
302        self.inner.connect_tcp(addr, peer).await
303    }
304
305    async fn resolve(
306        &self,
307        host: &str,
308        port: Option<u16>,
309        dns_server: Option<IpAddr>,
310    ) -> Result<Vec<IpAddr>, NetworkError> {
311        self.inner.resolve(host, port, dns_server).await
312    }
313}
314
315pin_project_lite::pin_project! {
316    pub struct RemoteNetworkingServerDriver {
317        common: Arc<RemoteAdapterCommon>,
318        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
319        #[pin]
320        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
321        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
322    }
323}
324
325impl Future for RemoteNetworkingServerDriver {
326    type Output = ();
327
328    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
329        // We register the waker into the interest of the sockets so
330        // that it is woken when something is ready to read or write
331        let readable = {
332            let mut guard = self.common.handler.state.lock().unwrap();
333            if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
334                guard.driver_wakers.push(cx.waker().clone());
335            }
336            guard.readable.drain().collect()
337        };
338        let readable: Vec<_> = readable;
339
340        {
341            // When a socket is marked as readable then we should drain all the data
342            // from it and start sending it to the client
343            let common = self.common.clone();
344            let mut guard = common.sockets.lock().unwrap();
345            for socket_id in readable {
346                if let Some(task) = guard
347                    .get_mut(&socket_id)
348                    .map(|s| s.drain_reads_and_accepts(&common, socket_id))
349                    .unwrap_or(None)
350                {
351                    self.tasks.push_back(task);
352                }
353            }
354        }
355
356        // This guard will be held while the pipeline is not currently
357        // stalled by some back pressure. It is only acquired when there
358        // is background tasks being processed
359        let mut not_stalled_guard = None;
360
361        // We loop until the waker is registered with the receiving stream
362        // and all the background tasks
363        loop {
364            // Background tasks are sent to this driver in certain circumstances
365            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
366                self.tasks.push_back(work);
367            }
368
369            // Background work basically stalls the stream until its all processed
370            // which creates back pressure on the client so that they don't overload
371            // the system
372            match self.tasks.poll_next_unpin(cx) {
373                Poll::Ready(Some(_)) => continue,
374                Poll::Ready(None) => {
375                    not_stalled_guard.take();
376                }
377                Poll::Pending if not_stalled_guard.is_none() => {
378                    match self.common.stall_rx.clone().try_lock_owned() {
379                        Ok(guard) => {
380                            not_stalled_guard.replace(guard);
381                        }
382                        _ => {
383                            return Poll::Pending;
384                        }
385                    }
386                }
387                Poll::Pending => {}
388            };
389
390            // We grab the next message sent by the client to us
391            let msg = {
392                let mut rx_guard = self.common.rx.lock().unwrap();
393                rx_guard.poll(cx)
394            };
395            return match msg {
396                Poll::Ready(Some(msg)) => {
397                    if let Some(task) = self.process(msg) {
398                        // With some messages we process there are background tasks that need to
399                        // be further driver to completion by the driver
400                        self.tasks.push_back(task)
401                    };
402                    continue;
403                }
404                Poll::Ready(None) => Poll::Ready(()),
405                Poll::Pending => Poll::Pending,
406            };
407        }
408    }
409}
410
411impl RemoteNetworkingServerDriver {
412    fn process(&mut self, msg: MessageRequest) -> BackgroundTask {
413        match msg {
414            MessageRequest::Send {
415                socket,
416                data,
417                req_id,
418            } => self.process_send(socket, data, req_id),
419            MessageRequest::SendTo {
420                socket,
421                data,
422                addr,
423                req_id,
424            } => self.process_send_to(socket, data, addr, req_id),
425            MessageRequest::Interface { req, req_id } => self.process_interface(req, req_id),
426            MessageRequest::Socket {
427                socket,
428                req,
429                req_id,
430            } => self.process_socket(socket, req, req_id),
431            MessageRequest::Reconnect => None,
432        }
433    }
434
435    fn process_send(
436        &mut self,
437        socket_id: SocketId,
438        data: Vec<u8>,
439        req_id: Option<u64>,
440    ) -> BackgroundTask {
441        let mut guard = self.common.sockets.lock().unwrap();
442        guard
443            .get_mut(&socket_id)
444            .map(|s| s.send(&self.common, socket_id, data, req_id))
445            .unwrap_or_else(|| {
446                tracing::debug!("orphaned socket {:?}", socket_id);
447                None
448            })
449    }
450
451    fn process_send_to(
452        &mut self,
453        socket_id: SocketId,
454        data: Vec<u8>,
455        addr: SocketAddr,
456        req_id: Option<u64>,
457    ) -> BackgroundTask {
458        let mut guard = self.common.sockets.lock().unwrap();
459        guard
460            .get_mut(&socket_id)
461            .map(|s| {
462                req_id.and_then(|req_id| s.send_to(&self.common, socket_id, data, addr, req_id))
463            })
464            .unwrap_or(None)
465    }
466
467    fn process_async<F>(future: F) -> BackgroundTask
468    where
469        F: Future<Output = BackgroundTask> + Send + 'static,
470    {
471        Some(Box::pin(async move {
472            let background_task = future.await;
473            if let Some(background_task) = background_task {
474                background_task.await;
475            }
476        }))
477    }
478
479    fn process_async_inner<F, Fut, T>(
480        &self,
481        work: F,
482        transmute: T,
483        req_id: Option<u64>,
484    ) -> BackgroundTask
485    where
486        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
487        Fut: Future + Send + 'static,
488        T: FnOnce(Fut::Output) -> ResponseType + Send + 'static,
489    {
490        let inner = self.inner.clone();
491        let common = self.common.clone();
492        Self::process_async(async move {
493            let future = work(inner);
494            let ret = future.await;
495            req_id.and_then(|req_id| {
496                common.send(MessageResponse::ResponseToRequest {
497                    req_id,
498                    res: transmute(ret),
499                })
500            })
501        })
502    }
503
504    fn process_async_noop<F, Fut>(&self, work: F, req_id: Option<u64>) -> BackgroundTask
505    where
506        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
507        Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
508    {
509        self.process_async_inner(
510            work,
511            move |ret| match ret {
512                Ok(()) => ResponseType::None,
513                Err(err) => ResponseType::Err(err),
514            },
515            req_id,
516        )
517    }
518
519    fn process_async_new_socket<F, Fut>(
520        &self,
521        work: F,
522        socket_id: SocketId,
523        req_id: Option<u64>,
524    ) -> BackgroundTask
525    where
526        F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
527        Fut: Future<Output = Result<RemoteAdapterSocket, NetworkError>> + Send + 'static,
528    {
529        let common = self.common.clone();
530        self.process_async_inner(
531            work,
532            move |ret| match ret {
533                Ok(mut socket) => {
534                    let handler = Box::new(common.handler.clone().for_socket(socket_id));
535
536                    let err = match &mut socket {
537                        RemoteAdapterSocket::TcpListener { .. } => {
538                            // we do not attach the handler immediately with new TPC listeners as we
539                            // only want it to trigger when the BeginAccept message is received with
540                            // a child ID we can actually use
541                            Ok(())
542                        }
543                        RemoteAdapterSocket::BoundTcp(_) => Ok(()),
544                        RemoteAdapterSocket::TcpSocket(s) => s.set_handler(handler),
545                        RemoteAdapterSocket::UdpSocket(s) => s.set_handler(handler),
546                        RemoteAdapterSocket::IcmpSocket(s) => s.set_handler(handler),
547                        RemoteAdapterSocket::RawSocket(s) => s.set_handler(handler),
548                    };
549                    if let Err(err) = err {
550                        return ResponseType::Err(err);
551                    }
552
553                    let mut guard = common.sockets.lock().unwrap();
554                    guard.insert(socket_id, socket);
555
556                    ResponseType::Socket(socket_id)
557                }
558                Err(err) => ResponseType::Err(err),
559            },
560            req_id,
561        )
562    }
563
564    fn process_inner<F, R, T>(
565        &self,
566        work: F,
567        transmute: T,
568        socket_id: SocketId,
569        req_id: Option<u64>,
570    ) -> BackgroundTask
571    where
572        F: FnOnce(&mut RemoteAdapterSocket) -> R + Send + 'static,
573        T: FnOnce(R) -> ResponseType + Send + 'static,
574    {
575        let ret = {
576            let mut guard = self.common.sockets.lock().unwrap();
577            let socket = match guard.get_mut(&socket_id) {
578                Some(s) => s,
579                None => {
580                    return req_id.and_then(|req_id| {
581                        self.common.send(MessageResponse::ResponseToRequest {
582                            req_id,
583                            res: ResponseType::Err(NetworkError::InvalidFd),
584                        })
585                    });
586                }
587            };
588            work(socket)
589        };
590        req_id.and_then(|req_id| {
591            self.common.send(MessageResponse::ResponseToRequest {
592                req_id,
593                res: transmute(ret),
594            })
595        })
596    }
597
598    fn process_inner_noop<F>(
599        &self,
600        work: F,
601        socket_id: SocketId,
602        req_id: Option<u64>,
603    ) -> BackgroundTask
604    where
605        F: FnOnce(&mut RemoteAdapterSocket) -> Result<(), NetworkError> + Send + 'static,
606    {
607        self.process_inner(
608            work,
609            move |ret| match ret {
610                Ok(()) => ResponseType::None,
611                Err(err) => ResponseType::Err(err),
612            },
613            socket_id,
614            req_id,
615        )
616    }
617
618    fn process_inner_begin_accept(
619        &self,
620        socket_id: SocketId,
621        child_id: SocketId,
622        req_id: Option<u64>,
623    ) -> BackgroundTask {
624        // We record the child socket so it can be used on the next accepted socket
625        {
626            let mut guard = self.common.socket_accept.lock().unwrap();
627            guard.insert(socket_id, child_id);
628        }
629
630        // Now we attach the handler to the main listening socket
631        let mut handler = Box::new(self.common.handler.clone().for_socket(socket_id));
632        handler.push_interest(virtual_mio::InterestType::Readable);
633        self.process_inner_noop(
634            move |socket| match socket {
635                RemoteAdapterSocket::TcpListener {
636                    socket: s,
637                    next_accept,
638                    ..
639                } => {
640                    next_accept.replace(child_id);
641                    s.set_handler(handler)
642                }
643                _ => {
644                    // only the TCP listener needs its socket set as the other sockets
645                    // set their handlers when the socket is created instead. we need to
646                    // delay setting the handler so we have a child ID to use and return
647                    // to the client when a socket is accepted, thus we can not accept them
648                    // immediately
649                    Err(NetworkError::Unsupported)
650                }
651            },
652            socket_id,
653            req_id,
654        )
655    }
656
657    fn process_interface(&mut self, req: RequestType, req_id: Option<u64>) -> BackgroundTask {
658        match req {
659            RequestType::Bridge {
660                network,
661                access_token,
662                security,
663            } => self.process_async_noop(
664                move |inner| async move { inner.bridge(&network, &access_token, security).await },
665                req_id,
666            ),
667            RequestType::Unbridge => {
668                self.process_async_noop(move |inner| async move { inner.unbridge().await }, req_id)
669            }
670            RequestType::DhcpAcquire => self.process_async_inner(
671                move |inner| async move { inner.dhcp_acquire().await },
672                |ret| match ret {
673                    Ok(ips) => ResponseType::IpAddressList(ips),
674                    Err(err) => ResponseType::Err(err),
675                },
676                req_id,
677            ),
678            RequestType::IpAdd { ip, prefix } => self.process_async_noop(
679                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
680                    inner.ip_add(ip, prefix).await
681                },
682                req_id,
683            ),
684            RequestType::IpRemove(ip) => self.process_async_noop(
685                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
686                    inner.ip_remove(ip).await
687                },
688                req_id,
689            ),
690            RequestType::IpClear => self.process_async_noop(
691                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
692                    inner.ip_clear().await
693                },
694                req_id,
695            ),
696            RequestType::GetIpList => self.process_async_inner(
697                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
698                    inner.ip_list().await
699                },
700                |ret| match ret {
701                    Ok(cidr) => ResponseType::CidrList(cidr),
702                    Err(err) => ResponseType::Err(err),
703                },
704                req_id,
705            ),
706            RequestType::GetMac => {
707                self.process_async_inner(
708                    move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
709                        inner.mac().await
710                    },
711                    |ret| match ret {
712                        Ok(mac) => ResponseType::Mac(mac),
713                        Err(err) => ResponseType::Err(err),
714                    },
715                    req_id,
716                )
717            }
718            RequestType::GatewaySet(ip) => self.process_async_noop(
719                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
720                    inner.gateway_set(ip).await
721                },
722                req_id,
723            ),
724            RequestType::RouteAdd {
725                cidr,
726                via_router,
727                preferred_until,
728                expires_at,
729            } => self.process_async_noop(
730                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
731                    inner
732                        .route_add(cidr, via_router, preferred_until, expires_at)
733                        .await
734                },
735                req_id,
736            ),
737            RequestType::RouteRemove(ip) => self.process_async_noop(
738                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
739                    inner.route_remove(ip).await
740                },
741                req_id,
742            ),
743            RequestType::RouteClear => self.process_async_noop(
744                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
745                    inner.route_clear().await
746                },
747                req_id,
748            ),
749            RequestType::GetRouteList => self.process_async_inner(
750                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
751                    inner.route_list().await
752                },
753                |ret| match ret {
754                    Ok(routes) => ResponseType::RouteList(routes),
755                    Err(err) => ResponseType::Err(err),
756                },
757                req_id,
758            ),
759            RequestType::BindRaw(socket_id) => self.process_async_new_socket(
760                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
761                    Ok(RemoteAdapterSocket::RawSocket(inner.bind_raw().await?))
762                },
763                socket_id,
764                req_id,
765            ),
766            RequestType::BindTcp {
767                socket_id,
768                addr,
769                only_v6,
770                reuse_port,
771                reuse_addr,
772            } => self.process_async_new_socket(
773                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
774                    Ok(RemoteAdapterSocket::BoundTcp(
775                        inner
776                            .bind_tcp(addr, only_v6, reuse_port, reuse_addr)
777                            .await?,
778                    ))
779                },
780                socket_id,
781                req_id,
782            ),
783            RequestType::ListenTcp {
784                socket_id,
785                addr,
786                only_v6,
787                reuse_port,
788                reuse_addr,
789            } => self.process_async_new_socket(
790                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
791                    Ok(RemoteAdapterSocket::TcpListener {
792                        socket: inner
793                            .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
794                            .await?,
795                        next_accept: None,
796                    })
797                },
798                socket_id,
799                req_id,
800            ),
801            RequestType::BindUdp {
802                socket_id,
803                addr,
804                reuse_port,
805                reuse_addr,
806            } => self.process_async_new_socket(
807                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
808                    Ok(RemoteAdapterSocket::UdpSocket(
809                        inner.bind_udp(addr, reuse_port, reuse_addr).await?,
810                    ))
811                },
812                socket_id,
813                req_id,
814            ),
815            RequestType::BindIcmp { socket_id, addr } => self.process_async_new_socket(
816                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
817                    Ok(RemoteAdapterSocket::IcmpSocket(
818                        inner.bind_icmp(addr).await?,
819                    ))
820                },
821                socket_id,
822                req_id,
823            ),
824            RequestType::ConnectTcp {
825                socket_id,
826                addr,
827                peer,
828            } => self.process_async_new_socket(
829                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
830                    Ok(RemoteAdapterSocket::TcpSocket(
831                        inner.connect_tcp(addr, peer).await?,
832                    ))
833                },
834                socket_id,
835                req_id,
836            ),
837            RequestType::Resolve {
838                host,
839                port,
840                dns_server,
841            } => self.process_async_inner(
842                move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
843                    inner.resolve(&host, port, dns_server).await
844                },
845                |ret| match ret {
846                    Ok(ips) => ResponseType::IpAddressList(ips),
847                    Err(err) => ResponseType::Err(err),
848                },
849                req_id,
850            ),
851            _ => req_id.and_then(|req_id| {
852                self.common.send(MessageResponse::ResponseToRequest {
853                    req_id,
854                    res: ResponseType::Err(NetworkError::Unsupported),
855                })
856            }),
857        }
858    }
859
860    fn process_socket(
861        &mut self,
862        socket_id: SocketId,
863        req: RequestType,
864        req_id: Option<u64>,
865    ) -> BackgroundTask {
866        match req {
867            RequestType::Flush => self.process_inner_noop(
868                move |socket| match socket {
869                    RemoteAdapterSocket::TcpSocket(s) => s.try_flush(),
870                    RemoteAdapterSocket::RawSocket(s) => s.try_flush(),
871                    _ => Err(NetworkError::Unsupported),
872                },
873                socket_id,
874                req_id,
875            ),
876            RequestType::Close => {
877                let res = {
878                    let mut guard = self.common.sockets.lock().unwrap();
879                    self.common.socket_accept.lock().unwrap().remove(&socket_id);
880                    match guard.remove(&socket_id) {
881                        Some(RemoteAdapterSocket::TcpSocket(mut socket)) => socket.close(),
882                        Some(_) => Ok(()),
883                        None => Err(NetworkError::InvalidFd),
884                    }
885                };
886                req_id.and_then(|req_id| {
887                    self.common.send(MessageResponse::ResponseToRequest {
888                        req_id,
889                        res: match res {
890                            Ok(()) => ResponseType::None,
891                            Err(err) => ResponseType::Err(err),
892                        },
893                    })
894                })
895            }
896            RequestType::ListenBound => {
897                let res = {
898                    let mut guard = self.common.sockets.lock().unwrap();
899                    match guard.get_mut(&socket_id) {
900                        Some(socket) => match socket {
901                            RemoteAdapterSocket::BoundTcp(bound) => match bound.listen() {
902                                Ok(listener) => {
903                                    *socket = RemoteAdapterSocket::TcpListener {
904                                        socket: listener,
905                                        next_accept: None,
906                                    };
907                                    Ok(())
908                                }
909                                Err(err) => Err(err),
910                            },
911                            _ => Err(NetworkError::Unsupported),
912                        },
913                        None => Err(NetworkError::InvalidFd),
914                    }
915                };
916                req_id.and_then(|req_id| {
917                    self.common.send(MessageResponse::ResponseToRequest {
918                        req_id,
919                        res: match res {
920                            Ok(()) => ResponseType::None,
921                            Err(err) => ResponseType::Err(err),
922                        },
923                    })
924                })
925            }
926            RequestType::ConnectBound { peer } => {
927                let res = {
928                    let mut guard = self.common.sockets.lock().unwrap();
929                    match guard.get_mut(&socket_id) {
930                        Some(socket) => match socket {
931                            RemoteAdapterSocket::BoundTcp(bound) => match bound.connect(peer) {
932                                Ok(connected) => {
933                                    *socket = RemoteAdapterSocket::TcpSocket(connected);
934                                    Ok(())
935                                }
936                                Err(err) => Err(err),
937                            },
938                            _ => Err(NetworkError::Unsupported),
939                        },
940                        None => Err(NetworkError::InvalidFd),
941                    }
942                };
943                req_id.and_then(|req_id| {
944                    self.common.send(MessageResponse::ResponseToRequest {
945                        req_id,
946                        res: match res {
947                            Ok(()) => ResponseType::None,
948                            Err(err) => ResponseType::Err(err),
949                        },
950                    })
951                })
952            }
953            RequestType::BeginAccept(child_id) => {
954                self.process_inner_begin_accept(socket_id, child_id, req_id)
955            }
956            RequestType::GetAddrLocal => self.process_inner(
957                move |socket| match socket {
958                    RemoteAdapterSocket::BoundTcp(s) => s.addr_local(),
959                    RemoteAdapterSocket::TcpSocket(s) => s.addr_local(),
960                    RemoteAdapterSocket::TcpListener { socket: s, .. } => s.addr_local(),
961                    RemoteAdapterSocket::UdpSocket(s) => s.addr_local(),
962                    RemoteAdapterSocket::IcmpSocket(s) => s.addr_local(),
963                    RemoteAdapterSocket::RawSocket(s) => s.addr_local(),
964                },
965                |ret| match ret {
966                    Ok(addr) => ResponseType::SocketAddr(addr),
967                    Err(err) => ResponseType::Err(err),
968                },
969                socket_id,
970                req_id,
971            ),
972            RequestType::GetAddrPeer => self.process_inner(
973                move |socket| match socket {
974                    RemoteAdapterSocket::BoundTcp(_) => Err(NetworkError::Unsupported),
975                    RemoteAdapterSocket::TcpSocket(s) => s.addr_peer().map(Some),
976                    RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
977                    RemoteAdapterSocket::UdpSocket(s) => s.addr_peer(),
978                    RemoteAdapterSocket::IcmpSocket(_) => Err(NetworkError::Unsupported),
979                    RemoteAdapterSocket::RawSocket(_) => Err(NetworkError::Unsupported),
980                },
981                |ret| match ret {
982                    Ok(Some(addr)) => ResponseType::SocketAddr(addr),
983                    Ok(None) => ResponseType::None,
984                    Err(err) => ResponseType::Err(err),
985                },
986                socket_id,
987                req_id,
988            ),
989            RequestType::SetTtl(ttl) => self.process_inner_noop(
990                move |socket| match socket {
991                    RemoteAdapterSocket::BoundTcp(s) => s.set_ttl(ttl),
992                    RemoteAdapterSocket::TcpSocket(s) => s.set_ttl(ttl),
993                    RemoteAdapterSocket::TcpListener { socket: s, .. } => {
994                        s.set_ttl(ttl.try_into().unwrap_or_default())
995                    }
996                    RemoteAdapterSocket::UdpSocket(s) => s.set_ttl(ttl),
997                    RemoteAdapterSocket::IcmpSocket(s) => s.set_ttl(ttl),
998                    RemoteAdapterSocket::RawSocket(s) => s.set_ttl(ttl),
999                },
1000                socket_id,
1001                req_id,
1002            ),
1003            RequestType::GetTtl => self.process_inner(
1004                move |socket| match socket {
1005                    RemoteAdapterSocket::BoundTcp(s) => s.ttl(),
1006                    RemoteAdapterSocket::TcpSocket(s) => s.ttl(),
1007                    RemoteAdapterSocket::TcpListener { socket: s, .. } => s.ttl().map(|t| t as u32),
1008                    RemoteAdapterSocket::UdpSocket(s) => s.ttl(),
1009                    RemoteAdapterSocket::IcmpSocket(s) => s.ttl(),
1010                    RemoteAdapterSocket::RawSocket(s) => s.ttl(),
1011                },
1012                |ret| match ret {
1013                    Ok(ttl) => ResponseType::Ttl(ttl),
1014                    Err(err) => ResponseType::Err(err),
1015                },
1016                socket_id,
1017                req_id,
1018            ),
1019            RequestType::GetStatus => self.process_inner(
1020                move |socket| match socket {
1021                    RemoteAdapterSocket::BoundTcp(_) => Ok(SocketStatus::Opened),
1022                    RemoteAdapterSocket::TcpSocket(s) => s.status(),
1023                    RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
1024                    RemoteAdapterSocket::UdpSocket(s) => s.status(),
1025                    RemoteAdapterSocket::IcmpSocket(s) => s.status(),
1026                    RemoteAdapterSocket::RawSocket(s) => s.status(),
1027                },
1028                |ret| match ret {
1029                    Ok(status) => ResponseType::Status(status),
1030                    Err(err) => ResponseType::Err(err),
1031                },
1032                socket_id,
1033                req_id,
1034            ),
1035            RequestType::SetLinger(linger) => self.process_inner_noop(
1036                move |socket| match socket {
1037                    RemoteAdapterSocket::TcpSocket(s) => s.set_linger(linger),
1038                    _ => Err(NetworkError::Unsupported),
1039                },
1040                socket_id,
1041                req_id,
1042            ),
1043            RequestType::GetLinger => self.process_inner(
1044                move |socket| match socket {
1045                    RemoteAdapterSocket::TcpSocket(s) => s.linger(),
1046                    _ => Err(NetworkError::Unsupported),
1047                },
1048                |ret| match ret {
1049                    Ok(Some(time)) => ResponseType::Duration(time),
1050                    Ok(None) => ResponseType::None,
1051                    Err(err) => ResponseType::Err(err),
1052                },
1053                socket_id,
1054                req_id,
1055            ),
1056            RequestType::SetPromiscuous(promiscuous) => self.process_inner_noop(
1057                move |socket| match socket {
1058                    RemoteAdapterSocket::RawSocket(s) => s.set_promiscuous(promiscuous),
1059                    _ => Err(NetworkError::Unsupported),
1060                },
1061                socket_id,
1062                req_id,
1063            ),
1064            RequestType::GetPromiscuous => self.process_inner(
1065                move |socket| match socket {
1066                    RemoteAdapterSocket::RawSocket(s) => s.promiscuous(),
1067                    _ => Err(NetworkError::Unsupported),
1068                },
1069                |ret| match ret {
1070                    Ok(flag) => ResponseType::Flag(flag),
1071                    Err(err) => ResponseType::Err(err),
1072                },
1073                socket_id,
1074                req_id,
1075            ),
1076            RequestType::SetRecvBufSize(size) => self.process_inner_noop(
1077                move |socket| match socket {
1078                    RemoteAdapterSocket::TcpSocket(s) => {
1079                        s.set_recv_buf_size(size.try_into().unwrap_or_default())
1080                    }
1081                    _ => Err(NetworkError::Unsupported),
1082                },
1083                socket_id,
1084                req_id,
1085            ),
1086            RequestType::GetRecvBufSize => self.process_inner(
1087                move |socket| match socket {
1088                    RemoteAdapterSocket::TcpSocket(s) => s.recv_buf_size(),
1089                    _ => Err(NetworkError::Unsupported),
1090                },
1091                |ret| match ret {
1092                    Ok(amt) => ResponseType::Amount(amt as u64),
1093                    Err(err) => ResponseType::Err(err),
1094                },
1095                socket_id,
1096                req_id,
1097            ),
1098            RequestType::SetSendBufSize(size) => self.process_inner_noop(
1099                move |socket| match socket {
1100                    RemoteAdapterSocket::TcpSocket(s) => {
1101                        s.set_send_buf_size(size.try_into().unwrap_or_default())
1102                    }
1103                    _ => Err(NetworkError::Unsupported),
1104                },
1105                socket_id,
1106                req_id,
1107            ),
1108            RequestType::GetSendBufSize => self.process_inner(
1109                move |socket| match socket {
1110                    RemoteAdapterSocket::TcpSocket(s) => s.send_buf_size(),
1111                    _ => Err(NetworkError::Unsupported),
1112                },
1113                |ret| match ret {
1114                    Ok(amt) => ResponseType::Amount(amt as u64),
1115                    Err(err) => ResponseType::Err(err),
1116                },
1117                socket_id,
1118                req_id,
1119            ),
1120            RequestType::SetNoDelay(reuse) => self.process_inner_noop(
1121                move |socket| match socket {
1122                    RemoteAdapterSocket::TcpSocket(s) => s.set_nodelay(reuse),
1123                    _ => Err(NetworkError::Unsupported),
1124                },
1125                socket_id,
1126                req_id,
1127            ),
1128            RequestType::GetNoDelay => self.process_inner(
1129                move |socket| match socket {
1130                    RemoteAdapterSocket::TcpSocket(s) => s.nodelay(),
1131                    _ => Err(NetworkError::Unsupported),
1132                },
1133                |ret| match ret {
1134                    Ok(flag) => ResponseType::Flag(flag),
1135                    Err(err) => ResponseType::Err(err),
1136                },
1137                socket_id,
1138                req_id,
1139            ),
1140            RequestType::SetKeepAlive(val) => self.process_inner_noop(
1141                move |socket| match socket {
1142                    RemoteAdapterSocket::TcpSocket(s) => s.set_keepalive(val),
1143                    _ => Err(NetworkError::Unsupported),
1144                },
1145                socket_id,
1146                req_id,
1147            ),
1148            RequestType::GetKeepAlive => self.process_inner(
1149                move |socket| match socket {
1150                    RemoteAdapterSocket::TcpSocket(s) => s.keepalive(),
1151                    _ => Err(NetworkError::Unsupported),
1152                },
1153                |ret| match ret {
1154                    Ok(flag) => ResponseType::Flag(flag),
1155                    Err(err) => ResponseType::Err(err),
1156                },
1157                socket_id,
1158                req_id,
1159            ),
1160            RequestType::SetDontRoute(val) => self.process_inner_noop(
1161                move |socket| match socket {
1162                    RemoteAdapterSocket::TcpSocket(s) => s.set_dontroute(val),
1163                    _ => Err(NetworkError::Unsupported),
1164                },
1165                socket_id,
1166                req_id,
1167            ),
1168            RequestType::GetDontRoute => self.process_inner(
1169                move |socket| match socket {
1170                    RemoteAdapterSocket::TcpSocket(s) => s.dontroute(),
1171                    _ => Err(NetworkError::Unsupported),
1172                },
1173                |ret| match ret {
1174                    Ok(flag) => ResponseType::Flag(flag),
1175                    Err(err) => ResponseType::Err(err),
1176                },
1177                socket_id,
1178                req_id,
1179            ),
1180            RequestType::Shutdown(shutdown) => self.process_inner_noop(
1181                move |socket| match socket {
1182                    RemoteAdapterSocket::TcpSocket(s) => s.shutdown(match shutdown {
1183                        crate::meta::Shutdown::Read => std::net::Shutdown::Read,
1184                        crate::meta::Shutdown::Write => std::net::Shutdown::Write,
1185                        crate::meta::Shutdown::Both => std::net::Shutdown::Both,
1186                    }),
1187                    _ => Err(NetworkError::Unsupported),
1188                },
1189                socket_id,
1190                req_id,
1191            ),
1192            RequestType::IsClosed => self.process_inner(
1193                move |socket| match socket {
1194                    RemoteAdapterSocket::TcpSocket(s) => Ok(s.is_closed()),
1195                    _ => Err(NetworkError::Unsupported),
1196                },
1197                |ret| match ret {
1198                    Ok(flag) => ResponseType::Flag(flag),
1199                    Err(err) => ResponseType::Err(err),
1200                },
1201                socket_id,
1202                req_id,
1203            ),
1204            RequestType::SetBroadcast(broadcast) => self.process_inner_noop(
1205                move |socket| match socket {
1206                    RemoteAdapterSocket::UdpSocket(s) => s.set_broadcast(broadcast),
1207                    _ => Err(NetworkError::Unsupported),
1208                },
1209                socket_id,
1210                req_id,
1211            ),
1212            RequestType::GetBroadcast => self.process_inner(
1213                move |socket| match socket {
1214                    RemoteAdapterSocket::UdpSocket(s) => s.broadcast(),
1215                    _ => Err(NetworkError::Unsupported),
1216                },
1217                |ret| match ret {
1218                    Ok(flag) => ResponseType::Flag(flag),
1219                    Err(err) => ResponseType::Err(err),
1220                },
1221                socket_id,
1222                req_id,
1223            ),
1224            RequestType::SetMulticastLoopV4(val) => self.process_inner_noop(
1225                move |socket| match socket {
1226                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v4(val),
1227                    _ => Err(NetworkError::Unsupported),
1228                },
1229                socket_id,
1230                req_id,
1231            ),
1232            RequestType::GetMulticastLoopV4 => self.process_inner(
1233                move |socket| match socket {
1234                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v4(),
1235                    _ => Err(NetworkError::Unsupported),
1236                },
1237                |ret| match ret {
1238                    Ok(flag) => ResponseType::Flag(flag),
1239                    Err(err) => ResponseType::Err(err),
1240                },
1241                socket_id,
1242                req_id,
1243            ),
1244            RequestType::SetMulticastLoopV6(val) => self.process_inner_noop(
1245                move |socket| match socket {
1246                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v6(val),
1247                    _ => Err(NetworkError::Unsupported),
1248                },
1249                socket_id,
1250                req_id,
1251            ),
1252            RequestType::GetMulticastLoopV6 => self.process_inner(
1253                move |socket| match socket {
1254                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v6(),
1255                    _ => Err(NetworkError::Unsupported),
1256                },
1257                |ret| match ret {
1258                    Ok(flag) => ResponseType::Flag(flag),
1259                    Err(err) => ResponseType::Err(err),
1260                },
1261                socket_id,
1262                req_id,
1263            ),
1264            RequestType::SetMulticastTtlV4(ttl) => self.process_inner_noop(
1265                move |socket| match socket {
1266                    RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_ttl_v4(ttl),
1267                    _ => Err(NetworkError::Unsupported),
1268                },
1269                socket_id,
1270                req_id,
1271            ),
1272            RequestType::GetMulticastTtlV4 => self.process_inner(
1273                move |socket| match socket {
1274                    RemoteAdapterSocket::UdpSocket(s) => s.multicast_ttl_v4(),
1275                    _ => Err(NetworkError::Unsupported),
1276                },
1277                |ret| match ret {
1278                    Ok(ttl) => ResponseType::Ttl(ttl),
1279                    Err(err) => ResponseType::Err(err),
1280                },
1281                socket_id,
1282                req_id,
1283            ),
1284            RequestType::JoinMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1285                move |socket| match socket {
1286                    RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v4(multiaddr, iface),
1287                    _ => Err(NetworkError::Unsupported),
1288                },
1289                socket_id,
1290                req_id,
1291            ),
1292            RequestType::LeaveMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1293                move |socket| match socket {
1294                    RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v4(multiaddr, iface),
1295                    _ => Err(NetworkError::Unsupported),
1296                },
1297                socket_id,
1298                req_id,
1299            ),
1300            RequestType::JoinMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1301                move |socket| match socket {
1302                    RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v6(multiaddr, iface),
1303                    _ => Err(NetworkError::Unsupported),
1304                },
1305                socket_id,
1306                req_id,
1307            ),
1308            RequestType::LeaveMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1309                move |socket| match socket {
1310                    RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v6(multiaddr, iface),
1311                    _ => Err(NetworkError::Unsupported),
1312                },
1313                socket_id,
1314                req_id,
1315            ),
1316            _ => req_id.and_then(|req_id| {
1317                self.common.send(MessageResponse::ResponseToRequest {
1318                    req_id,
1319                    res: ResponseType::Err(NetworkError::Unsupported),
1320                })
1321            }),
1322        }
1323    }
1324}
1325
/// A server-side socket stored in the shared socket table and looked up by
/// `SocketId` when a client request targets it.
#[derive(Debug)]
enum RemoteAdapterSocket {
    /// A TCP socket that is bound but not yet listening or connected. It is
    /// replaced in place by `TcpListener` (on `ListenBound`) or `TcpSocket`
    /// (on `ConnectBound`).
    BoundTcp(Box<dyn VirtualTcpBoundSocket + Sync + 'static>),
    /// A listening TCP socket.
    TcpListener {
        socket: Box<dyn VirtualTcpListener + Sync + 'static>,
        /// Id to assign to the next accepted connection. It is taken when a
        /// connection is accepted; `None` means no accept is pending.
        /// NOTE(review): presumably populated by `BeginAccept` — confirm.
        next_accept: Option<SocketId>,
    },
    /// A connected TCP stream.
    TcpSocket(Box<dyn VirtualTcpSocket + Sync + 'static>),
    /// A UDP socket.
    UdpSocket(Box<dyn VirtualUdpSocket + Sync + 'static>),
    /// A raw socket.
    RawSocket(Box<dyn VirtualRawSocket + Sync + 'static>),
    /// An ICMP socket.
    IcmpSocket(Box<dyn VirtualIcmpSocket + Sync + 'static>),
}
1338
1339impl RemoteAdapterSocket {
1340    pub fn send(
1341        &mut self,
1342        common: &Arc<RemoteAdapterCommon>,
1343        socket_id: SocketId,
1344        data: Vec<u8>,
1345        req_id: Option<u64>,
1346    ) -> BackgroundTask {
1347        match self {
1348            Self::TcpSocket(this) => match this.try_send(&data) {
1349                Ok(amount) => {
1350                    if let Some(req_id) = req_id {
1351                        common.send(MessageResponse::Sent {
1352                            socket_id,
1353                            req_id,
1354                            amount: amount as u64,
1355                        })
1356                    } else {
1357                        None
1358                    }
1359                }
1360                Err(NetworkError::WouldBlock) => {
1361                    let common = common.clone();
1362                    Some(Box::pin(async move {
1363                        // We will stall the receiver so that back pressure is sent back to the
1364                        // sender and they don't overwhelm us with transmitting data.
1365                        let _stall_rx = common.stall_rx.clone().lock_owned().await;
1366
1367                        // We use a poller here that uses the handler to wake itself up
1368                        struct Poller {
1369                            common: Arc<RemoteAdapterCommon>,
1370                            socket_id: SocketId,
1371                            data: Vec<u8>,
1372                            req_id: Option<u64>,
1373                        }
1374                        impl Future for Poller {
1375                            type Output = BackgroundTask;
1376                            fn poll(
1377                                self: Pin<&mut Self>,
1378                                cx: &mut Context<'_>,
1379                            ) -> Poll<Self::Output> {
1380                                // We make sure the waker is registered with the interest driver which will
1381                                // wake up this poller when there is writeability
1382                                let mut guard = self.common.handler.state.lock().unwrap();
1383                                if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
1384                                    guard.driver_wakers.push(cx.waker().clone());
1385                                }
1386                                drop(guard);
1387
1388                                let mut guard = self.common.sockets.lock().unwrap();
1389                                if let Some(RemoteAdapterSocket::TcpSocket(socket)) =
1390                                    guard.get_mut(&self.socket_id)
1391                                {
1392                                    match socket.try_send(&self.data) {
1393                                        Ok(amount) => {
1394                                            if let Some(req_id) = self.req_id {
1395                                                return Poll::Ready(self.common.send(
1396                                                    MessageResponse::Sent {
1397                                                        socket_id: self.socket_id,
1398                                                        req_id,
1399                                                        amount: amount as u64,
1400                                                    },
1401                                                ));
1402                                            } else {
1403                                                return Poll::Ready(None);
1404                                            }
1405                                        }
1406                                        Err(NetworkError::WouldBlock) => return Poll::Pending,
1407                                        Err(error) => {
1408                                            if let Some(req_id) = self.req_id {
1409                                                return Poll::Ready(self.common.send(
1410                                                    MessageResponse::SendError {
1411                                                        socket_id: self.socket_id,
1412                                                        req_id,
1413                                                        error,
1414                                                    },
1415                                                ));
1416                                            } else {
1417                                                return Poll::Ready(None);
1418                                            }
1419                                        }
1420                                    }
1421                                }
1422                                Poll::Ready(None)
1423                            }
1424                        }
1425
1426                        // Run the poller until this message is sent, or the socket fails
1427                        let background_task = Poller {
1428                            common,
1429                            socket_id,
1430                            data,
1431                            req_id,
1432                        }
1433                        .await;
1434
1435                        // There might be more work left to finish off the send operation
1436                        if let Some(background_task) = background_task {
1437                            background_task.await;
1438                        }
1439                    }))
1440                }
1441                Err(error) => {
1442                    if let Some(req_id) = req_id {
1443                        common.send(MessageResponse::SendError {
1444                            socket_id,
1445                            req_id,
1446                            error,
1447                        })
1448                    } else {
1449                        None
1450                    }
1451                }
1452            },
1453            Self::RawSocket(this) => {
1454                // when the RAW socket is overloaded we just silently drop the packet
1455                // rather than buffering it and retrying later - Ethernet packets are
1456                // not lossless. In reality most socket drivers under this remote socket
1457                // will always succeed on `try_send` with RawSockets as they are always
1458                // processed.
1459                if let Err(err) = this.try_send(&data) {
1460                    tracing::debug!("failed to send raw packet - {}", err);
1461                }
1462                None
1463            }
1464            _ => {
1465                if let Some(req_id) = req_id {
1466                    common.send(MessageResponse::SendError {
1467                        socket_id,
1468                        req_id,
1469                        error: NetworkError::Unsupported,
1470                    })
1471                } else {
1472                    None
1473                }
1474            }
1475        }
1476    }
1477    pub fn send_to(
1478        &mut self,
1479        common: &Arc<RemoteAdapterCommon>,
1480        socket_id: SocketId,
1481        data: Vec<u8>,
1482        addr: SocketAddr,
1483        req_id: u64,
1484    ) -> BackgroundTask {
1485        match self {
1486            Self::UdpSocket(this) => {
1487                // when the UDP socket is overloaded we just silently drop the packet
1488                // rather than buffering it and retrying later
1489                this.try_send_to(&data, addr).ok();
1490                None
1491            }
1492
1493            Self::IcmpSocket(this) => {
1494                // when the ICMP socket is overloaded we just silently drop the packet
1495                // rather than buffering it and retrying later
1496                this.try_send_to(&data, addr).ok();
1497                None
1498            }
1499            _ => common.send(MessageResponse::SendError {
1500                socket_id,
1501                req_id,
1502                error: NetworkError::Unsupported,
1503            }),
1504        }
1505    }
1506    pub fn drain_reads_and_accepts(
1507        &mut self,
1508        common: &Arc<RemoteAdapterCommon>,
1509        socket_id: SocketId,
1510    ) -> BackgroundTask {
1511        // We loop reading the socket until all the pending reads are either
1512        // being processed in a background task or they are empty
1513        let mut ret: FuturesOrdered<BoxFuture<'static, ()>> = Default::default();
1514        loop {
1515            break match self {
1516                Self::BoundTcp(_) => {}
1517                Self::TcpListener {
1518                    socket,
1519                    next_accept,
1520                } => {
1521                    if next_accept.is_some() {
1522                        match socket.try_accept() {
1523                            Ok((mut child_socket, addr)) => {
1524                                let child_id = next_accept.take().unwrap();
1525
1526                                // We set the handler on the socket so that it can
1527                                let handler = Box::new(common.handler.clone().for_socket(child_id));
1528                                child_socket.set_handler(handler).ok();
1529
1530                                // We will fix up the socket in the background then notify
1531                                // the client of the new socket
1532                                let common = common.clone();
1533                                ret.push_back(Box::pin(async move {
1534                                    // Next we record the socket so that it is active
1535                                    {
1536                                        let child_socket =
1537                                            RemoteAdapterSocket::TcpSocket(child_socket);
1538                                        let mut guard = common.sockets.lock().unwrap();
1539                                        guard.insert(child_id, child_socket);
1540                                    }
1541
1542                                    // Lastly we tell the client about the new socket
1543                                    if let Some(task) = common.send(MessageResponse::FinishAccept {
1544                                        socket_id,
1545                                        child_id,
1546                                        addr,
1547                                    }) {
1548                                        task.await;
1549                                    }
1550                                }));
1551                            }
1552                            Err(NetworkError::WouldBlock) => {}
1553                            Err(err) => {
1554                                tracing::error!("failed to accept socket - {}", err);
1555                            }
1556                        }
1557                    }
1558                }
1559                Self::TcpSocket(this) => {
1560                    let mut chunk: [MaybeUninit<u8>; 10240] =
1561                        unsafe { MaybeUninit::uninit().assume_init() };
1562                    match this.try_recv(&mut chunk, false) {
1563                        Ok(0) => {}
1564                        Ok(amt) => {
1565                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1566                            let chunk_unsafe: &mut [u8] =
1567                                unsafe { std::mem::transmute(chunk_unsafe) };
1568                            if let Some(task) = common.send(MessageResponse::Recv {
1569                                socket_id,
1570                                data: chunk_unsafe.to_vec(),
1571                            }) {
1572                                ret.push_back(task);
1573                            }
1574                            continue;
1575                        }
1576                        Err(_) => {}
1577                    }
1578                }
1579                Self::UdpSocket(this) => {
1580                    let mut chunk: [MaybeUninit<u8>; 10240] =
1581                        unsafe { MaybeUninit::uninit().assume_init() };
1582                    match this.try_recv_from(&mut chunk, false) {
1583                        Ok((0, _)) => {}
1584                        Ok((amt, addr)) => {
1585                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1586                            let chunk_unsafe: &mut [u8] =
1587                                unsafe { std::mem::transmute(chunk_unsafe) };
1588                            if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1589                                socket_id,
1590                                data: chunk_unsafe.to_vec(),
1591                                addr,
1592                            }) {
1593                                ret.push_back(task);
1594                            }
1595                            continue;
1596                        }
1597                        Err(_) => {}
1598                    }
1599                }
1600                Self::IcmpSocket(this) => {
1601                    let mut chunk: [MaybeUninit<u8>; 10240] =
1602                        unsafe { MaybeUninit::uninit().assume_init() };
1603                    match this.try_recv_from(&mut chunk, false) {
1604                        Ok((0, _)) => {}
1605                        Ok((amt, addr)) => {
1606                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1607                            let chunk_unsafe: &mut [u8] =
1608                                unsafe { std::mem::transmute(chunk_unsafe) };
1609                            if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1610                                socket_id,
1611                                data: chunk_unsafe.to_vec(),
1612                                addr,
1613                            }) {
1614                                ret.push_back(task);
1615                            }
1616                            continue;
1617                        }
1618                        Err(_) => {}
1619                    }
1620                }
1621                Self::RawSocket(this) => {
1622                    let mut chunk: [MaybeUninit<u8>; 10240] =
1623                        unsafe { MaybeUninit::uninit().assume_init() };
1624                    match this.try_recv(&mut chunk, false) {
1625                        Ok(0) => {}
1626                        Ok(amt) => {
1627                            let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1628                            let chunk_unsafe: &mut [u8] =
1629                                unsafe { std::mem::transmute(chunk_unsafe) };
1630                            if let Some(task) = common.send(MessageResponse::Recv {
1631                                socket_id,
1632                                data: chunk_unsafe.to_vec(),
1633                            }) {
1634                                ret.push_back(task);
1635                            }
1636                            continue;
1637                        }
1638                        Err(_) => {}
1639                    }
1640                }
1641            };
1642        }
1643
1644        if ret.is_empty() {
1645            // There is nothing to process so we are done
1646            None
1647        } else {
1648            Some(Box::pin(async move {
1649                // Processes all the background tasks until completion
1650                let mut stream = ret;
1651                loop {
1652                    let (next, s) = StreamExt::into_future(stream).await;
1653                    if next.is_none() {
1654                        break;
1655                    }
1656                    stream = s;
1657                }
1658            }))
1659        }
1660    }
1661}
1662
1663#[derive(Debug, Default)]
1664struct RemoteAdapterHandlerState {
1665    readable: HashSet<SocketId>,
1666    driver_wakers: Vec<Waker>,
1667}
1668
1669#[derive(Debug, Default, Clone)]
1670struct RemoteAdapterHandler {
1671    socket_id: Option<SocketId>,
1672    state: Arc<Mutex<RemoteAdapterHandlerState>>,
1673}
1674impl RemoteAdapterHandler {
1675    pub fn for_socket(self, id: SocketId) -> Self {
1676        Self {
1677            socket_id: Some(id),
1678            state: self.state,
1679        }
1680    }
1681}
1682impl InterestHandler for RemoteAdapterHandler {
1683    fn push_interest(&mut self, interest: virtual_mio::InterestType) {
1684        let mut guard = self.state.lock().unwrap();
1685        guard.driver_wakers.drain(..).for_each(|w| w.wake());
1686        let socket_id = match self.socket_id {
1687            Some(s) => s,
1688            None => return,
1689        };
1690        if interest == virtual_mio::InterestType::Readable {
1691            guard.readable.insert(socket_id);
1692        }
1693    }
1694
1695    fn pop_interest(&mut self, interest: virtual_mio::InterestType) -> bool {
1696        let mut guard = self.state.lock().unwrap();
1697        let socket_id = match self.socket_id {
1698            Some(s) => s,
1699            None => return false,
1700        };
1701        if interest == virtual_mio::InterestType::Readable {
1702            return guard.readable.remove(&socket_id);
1703        }
1704        false
1705    }
1706
1707    fn has_interest(&self, interest: virtual_mio::InterestType) -> bool {
1708        let guard = self.state.lock().unwrap();
1709        let socket_id = match self.socket_id {
1710            Some(s) => s,
1711            None => return false,
1712        };
1713        if interest == virtual_mio::InterestType::Readable {
1714            return guard.readable.contains(&socket_id);
1715        }
1716        false
1717    }
1718}
1719
1720type SocketMap<T> = HashMap<SocketId, T>;
1721
1722#[derive(Debug)]
1723struct RemoteAdapterCommon {
1724    tx: RemoteTx<MessageResponse>,
1725    rx: Mutex<RemoteRx<MessageRequest>>,
1726    sockets: Mutex<SocketMap<RemoteAdapterSocket>>,
1727    socket_accept: Mutex<SocketMap<SocketId>>,
1728    handler: RemoteAdapterHandler,
1729
1730    // The stall guard will prevent reads while its held and there are background tasks running
1731    // (the idea behind this is to create back pressure so that the task list infinitely grow)
1732    stall_rx: Arc<tokio::sync::Mutex<()>>,
1733}
1734impl RemoteAdapterCommon {
1735    fn send(self: &Arc<Self>, req: MessageResponse) -> BackgroundTask {
1736        let this = self.clone();
1737        Some(Box::pin(async move {
1738            if let Err(err) = this.tx.send(req).await {
1739                tracing::debug!("failed to send message - {}", err);
1740            }
1741        }))
1742    }
1743}