1use crate::meta::{FrameSerializationFormat, ResponseType};
2use crate::rx_tx::{RemoteRx, RemoteTx, RemoteTxWakers};
3use crate::{IpCidr, IpRoute, NetworkError, SocketStatus, StreamSecurity, VirtualIcmpSocket};
4use crate::{
5 VirtualNetworking, VirtualRawSocket, VirtualTcpBoundSocket, VirtualTcpListener,
6 VirtualTcpSocket, VirtualUdpSocket,
7 meta::{MessageRequest, MessageResponse, RequestType, SocketId},
8};
9use futures_util::stream::FuturesOrdered;
10#[cfg(any(feature = "hyper", feature = "tokio-tungstenite"))]
11use futures_util::stream::{SplitSink, SplitStream};
12use futures_util::{Sink, Stream};
13use futures_util::{StreamExt, future::BoxFuture};
14use std::collections::HashSet;
15use std::mem::MaybeUninit;
16use std::net::IpAddr;
17use std::task::Waker;
18use std::time::Duration;
19
20#[cfg(feature = "hyper")]
21use hyper_util::rt::tokio::TokioIo;
22use std::{
23 collections::HashMap,
24 future::Future,
25 net::SocketAddr,
26 pin::Pin,
27 sync::{Arc, Mutex},
28 task::{Context, Poll},
29};
30use tokio::{
31 io::{AsyncRead, AsyncWrite},
32 sync::mpsc,
33};
34use tokio_serde::SymmetricallyFramed;
35use tokio_serde::formats::SymmetricalBincode;
36#[cfg(feature = "cbor")]
37use tokio_serde::formats::SymmetricalCbor;
38#[cfg(feature = "json")]
39use tokio_serde::formats::SymmetricalJson;
40#[cfg(feature = "messagepack")]
41use tokio_serde::formats::SymmetricalMessagePack;
42use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
43use virtual_mio::InterestHandler;
44
/// Optional follow-up work produced while handling a request.
///
/// When `Some`, the driver pushes the future onto its ordered task set and
/// polls it to completion alongside other in-flight work.
type BackgroundTask = Option<BoxFuture<'static, ()>>;
46
/// Server side of the remote networking protocol.
///
/// Requests received from a remote client are executed against the wrapped
/// `inner` [`VirtualNetworking`] implementation; the paired
/// [`RemoteNetworkingServerDriver`] must be polled for that to happen. The
/// server also implements [`VirtualNetworking`] itself by delegating to `inner`.
#[derive(Debug, Clone)]
pub struct RemoteNetworkingServer {
    // State shared with the driver (transport, socket table, interest handler).
    // In the visible code it is only read by `socket_count_for_test`, hence the
    // `allow(dead_code)`.
    #[allow(dead_code)]
    common: Arc<RemoteAdapterCommon>,
    // Networking implementation that requests are ultimately executed against.
    inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
}
53
54impl RemoteNetworkingServer {
55 fn new(
56 tx: RemoteTx<MessageResponse>,
57 rx: RemoteRx<MessageRequest>,
58 work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
59 inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
60 ) -> (Self, RemoteNetworkingServerDriver) {
61 let common = RemoteAdapterCommon {
62 tx,
63 rx: Mutex::new(rx),
64 sockets: Default::default(),
65 socket_accept: Default::default(),
66 handler: Default::default(),
67 stall_rx: Default::default(),
68 };
69 let common = Arc::new(common);
70
71 let driver = RemoteNetworkingServerDriver {
72 more_work: work,
73 tasks: Default::default(),
74 common: common.clone(),
75 inner: inner.clone(),
76 };
77 let networking = Self { common, inner };
78
79 (networking, driver)
80 }
81 pub fn new_from_mpsc(
84 tx: mpsc::Sender<MessageResponse>,
85 rx: mpsc::Receiver<MessageRequest>,
86 inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
87 ) -> (Self, RemoteNetworkingServerDriver) {
88 let (tx_work, rx_work) = mpsc::unbounded_channel();
89 let tx_wakers = RemoteTxWakers::default();
90
91 let tx = RemoteTx::Mpsc {
92 tx,
93 work: tx_work,
94 wakers: tx_wakers.clone(),
95 };
96 let rx = RemoteRx::Mpsc {
97 rx,
98 wakers: tx_wakers,
99 };
100 Self::new(tx, rx, rx_work, inner)
101 }
102
103 pub fn new_from_async_io<TX, RX>(
106 tx: TX,
107 rx: RX,
108 format: FrameSerializationFormat,
109 inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
110 ) -> (Self, RemoteNetworkingServerDriver)
111 where
112 TX: AsyncWrite + Send + 'static,
113 RX: AsyncRead + Send + 'static,
114 {
115 let tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
116 let tx: Pin<Box<dyn Sink<MessageResponse, Error = std::io::Error> + Send + 'static>> =
117 match format {
118 FrameSerializationFormat::Bincode => {
119 Box::pin(SymmetricallyFramed::new(tx, SymmetricalBincode::default()))
120 }
121 #[cfg(feature = "json")]
122 FrameSerializationFormat::Json => {
123 Box::pin(SymmetricallyFramed::new(tx, SymmetricalJson::default()))
124 }
125 #[cfg(feature = "messagepack")]
126 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
127 tx,
128 SymmetricalMessagePack::default(),
129 )),
130 #[cfg(feature = "cbor")]
131 FrameSerializationFormat::Cbor => {
132 Box::pin(SymmetricallyFramed::new(tx, SymmetricalCbor::default()))
133 }
134 };
135
136 let rx = FramedRead::new(rx, LengthDelimitedCodec::new());
137 let rx: Pin<Box<dyn Stream<Item = std::io::Result<MessageRequest>> + Send + 'static>> =
138 match format {
139 FrameSerializationFormat::Bincode => {
140 Box::pin(SymmetricallyFramed::new(rx, SymmetricalBincode::default()))
141 }
142 #[cfg(feature = "json")]
143 FrameSerializationFormat::Json => {
144 Box::pin(SymmetricallyFramed::new(rx, SymmetricalJson::default()))
145 }
146 #[cfg(feature = "messagepack")]
147 FrameSerializationFormat::MessagePack => Box::pin(SymmetricallyFramed::new(
148 rx,
149 SymmetricalMessagePack::default(),
150 )),
151 #[cfg(feature = "cbor")]
152 FrameSerializationFormat::Cbor => {
153 Box::pin(SymmetricallyFramed::new(rx, SymmetricalCbor::default()))
154 }
155 };
156
157 let (tx_work, rx_work) = mpsc::unbounded_channel();
158
159 let tx = RemoteTx::Stream {
160 tx: Arc::new(tokio::sync::Mutex::new(tx)),
161 work: tx_work,
162 wakers: RemoteTxWakers::default(),
163 };
164 let rx = RemoteRx::Stream { rx };
165 Self::new(tx, rx, rx_work, inner)
166 }
167
168 #[cfg(feature = "hyper")]
171 pub fn new_from_hyper_ws_io(
172 tx: SplitSink<
173 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
174 hyper_tungstenite::tungstenite::Message,
175 >,
176 rx: SplitStream<hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>>,
177 format: FrameSerializationFormat,
178 inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
179 ) -> (Self, RemoteNetworkingServerDriver) {
180 let (tx_work, rx_work) = mpsc::unbounded_channel();
181
182 let tx = RemoteTx::HyperWebSocket {
183 tx: Arc::new(tokio::sync::Mutex::new(tx)),
184 work: tx_work,
185 wakers: RemoteTxWakers::default(),
186 format,
187 };
188 let rx = RemoteRx::HyperWebSocket { rx, format };
189 Self::new(tx, rx, rx_work, inner)
190 }
191
192 #[cfg(all(test, feature = "tokio"))]
193 pub(crate) fn socket_count_for_test(&self) -> usize {
194 self.common.sockets.lock().unwrap().len()
195 }
196}
197
198#[async_trait::async_trait]
199impl VirtualNetworking for RemoteNetworkingServer {
200 async fn bridge(
201 &self,
202 network: &str,
203 access_token: &str,
204 security: StreamSecurity,
205 ) -> Result<(), NetworkError> {
206 self.inner.bridge(network, access_token, security).await
207 }
208
209 async fn unbridge(&self) -> Result<(), NetworkError> {
210 self.inner.unbridge().await
211 }
212
213 async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>, NetworkError> {
214 self.inner.dhcp_acquire().await
215 }
216
217 async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<(), NetworkError> {
218 self.inner.ip_add(ip, prefix).await
219 }
220
221 async fn ip_remove(&self, ip: IpAddr) -> Result<(), NetworkError> {
222 self.inner.ip_remove(ip).await
223 }
224
225 async fn ip_clear(&self) -> Result<(), NetworkError> {
226 self.inner.ip_clear().await
227 }
228
229 async fn ip_list(&self) -> Result<Vec<IpCidr>, NetworkError> {
230 self.inner.ip_list().await
231 }
232
233 async fn mac(&self) -> Result<[u8; 6], NetworkError> {
234 self.inner.mac().await
235 }
236
237 async fn gateway_set(&self, ip: IpAddr) -> Result<(), NetworkError> {
238 self.inner.gateway_set(ip).await
239 }
240
241 async fn route_add(
242 &self,
243 cidr: IpCidr,
244 via_router: IpAddr,
245 preferred_until: Option<Duration>,
246 expires_at: Option<Duration>,
247 ) -> Result<(), NetworkError> {
248 self.inner
249 .route_add(cidr, via_router, preferred_until, expires_at)
250 .await
251 }
252
253 async fn route_remove(&self, cidr: IpAddr) -> Result<(), NetworkError> {
254 self.inner.route_remove(cidr).await
255 }
256
257 async fn route_clear(&self) -> Result<(), NetworkError> {
258 self.inner.route_clear().await
259 }
260
261 async fn route_list(&self) -> Result<Vec<IpRoute>, NetworkError> {
262 self.inner.route_list().await
263 }
264
265 async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>, NetworkError> {
266 self.inner.bind_raw().await
267 }
268
269 async fn listen_tcp(
270 &self,
271 addr: SocketAddr,
272 only_v6: bool,
273 reuse_port: bool,
274 reuse_addr: bool,
275 ) -> Result<Box<dyn VirtualTcpListener + Sync>, NetworkError> {
276 self.inner
277 .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
278 .await
279 }
280
281 async fn bind_udp(
282 &self,
283 addr: SocketAddr,
284 reuse_port: bool,
285 reuse_addr: bool,
286 ) -> Result<Box<dyn VirtualUdpSocket + Sync>, NetworkError> {
287 self.inner.bind_udp(addr, reuse_port, reuse_addr).await
288 }
289
290 async fn bind_icmp(
291 &self,
292 addr: IpAddr,
293 ) -> Result<Box<dyn VirtualIcmpSocket + Sync>, NetworkError> {
294 self.inner.bind_icmp(addr).await
295 }
296
297 async fn connect_tcp(
298 &self,
299 addr: SocketAddr,
300 peer: SocketAddr,
301 ) -> Result<Box<dyn VirtualTcpSocket + Sync>, NetworkError> {
302 self.inner.connect_tcp(addr, peer).await
303 }
304
305 async fn resolve(
306 &self,
307 host: &str,
308 port: Option<u16>,
309 dns_server: Option<IpAddr>,
310 ) -> Result<Vec<IpAddr>, NetworkError> {
311 self.inner.resolve(host, port, dns_server).await
312 }
313}
314
pin_project_lite::pin_project! {
    /// Future that drives a [`RemoteNetworkingServer`].
    ///
    /// Incoming requests are only read and processed while this future is
    /// being polled; it completes once the request transport is closed.
    pub struct RemoteNetworkingServerDriver {
        // State shared with the server handle (transport, socket table, handler).
        common: Arc<RemoteAdapterCommon>,
        // Receives background futures queued through the `work` sender held by `RemoteTx`.
        more_work: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
        // In-flight background tasks, polled in FIFO order.
        #[pin]
        tasks: FuturesOrdered<BoxFuture<'static, ()>>,
        // Networking implementation that requests are executed against.
        inner: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
    }
}
324
/// Each poll dispatches socket readiness, runs queued background work and
/// pulls new requests off the transport until the transport is closed.
impl Future for RemoteNetworkingServerDriver {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Register this task's waker with the shared interest handler (only if
        // an equivalent waker is not already stored) and take every socket id
        // that has been flagged readable since the last poll.
        let readable = {
            let mut guard = self.common.handler.state.lock().unwrap();
            if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
                guard.driver_wakers.push(cx.waker().clone());
            }
            guard.readable.drain().collect()
        };
        let readable: Vec<_> = readable;

        // Let each readable socket drain its pending reads/accepts; any
        // follow-up task it returns joins the ordered task set. Unknown socket
        // ids are skipped.
        {
            let common = self.common.clone();
            let mut guard = common.sockets.lock().unwrap();
            for socket_id in readable {
                if let Some(task) = guard
                    .get_mut(&socket_id)
                    .map(|s| s.drain_reads_and_accepts(&common, socket_id))
                    .unwrap_or(None)
                {
                    self.tasks.push_back(task);
                }
            }
        }

        // Holds the `stall_rx` lock while this poll reads new requests with
        // tasks still pending; released again once the task set is empty.
        let mut not_stalled_guard = None;

        loop {
            // Pull in any work queued by other components via the work channel.
            while let Poll::Ready(Some(work)) = Pin::new(&mut self.more_work).poll_recv(cx) {
                self.tasks.push_back(work);
            }

            match self.tasks.poll_next_unpin(cx) {
                // A task finished; loop again in case more are ready.
                Poll::Ready(Some(_)) => continue,
                // No tasks remain: release the stall lock if we held it.
                Poll::Ready(None) => {
                    not_stalled_guard.take();
                }
                // Tasks are still pending. Keep reading requests only if the
                // `stall_rx` lock can be acquired; otherwise yield.
                // NOTE(review): presumably a backpressure mechanism where whoever
                // holds `stall_rx` pauses request intake — confirm.
                Poll::Pending if not_stalled_guard.is_none() => {
                    match self.common.stall_rx.clone().try_lock_owned() {
                        Ok(guard) => {
                            not_stalled_guard.replace(guard);
                        }
                        _ => {
                            return Poll::Pending;
                        }
                    }
                }
                Poll::Pending => {}
            };

            // Poll the transport for the next request; the std mutex guard is
            // confined to this block and released before the request is processed.
            let msg = {
                let mut rx_guard = self.common.rx.lock().unwrap();
                rx_guard.poll(cx)
            };
            return match msg {
                // Handle the request, queue any resulting task, and loop again.
                Poll::Ready(Some(msg)) => {
                    if let Some(task) = self.process(msg) {
                        self.tasks.push_back(task)
                    };
                    continue;
                }
                // Transport closed: the driver completes.
                // NOTE(review): tasks still queued in `self.tasks` are not polled
                // further after this point — confirm that is intended.
                Poll::Ready(None) => Poll::Ready(()),
                Poll::Pending => Poll::Pending,
            };
        }
    }
}
410
411impl RemoteNetworkingServerDriver {
412 fn process(&mut self, msg: MessageRequest) -> BackgroundTask {
413 match msg {
414 MessageRequest::Send {
415 socket,
416 data,
417 req_id,
418 } => self.process_send(socket, data, req_id),
419 MessageRequest::SendTo {
420 socket,
421 data,
422 addr,
423 req_id,
424 } => self.process_send_to(socket, data, addr, req_id),
425 MessageRequest::Interface { req, req_id } => self.process_interface(req, req_id),
426 MessageRequest::Socket {
427 socket,
428 req,
429 req_id,
430 } => self.process_socket(socket, req, req_id),
431 MessageRequest::Reconnect => None,
432 }
433 }
434
435 fn process_send(
436 &mut self,
437 socket_id: SocketId,
438 data: Vec<u8>,
439 req_id: Option<u64>,
440 ) -> BackgroundTask {
441 let mut guard = self.common.sockets.lock().unwrap();
442 guard
443 .get_mut(&socket_id)
444 .map(|s| s.send(&self.common, socket_id, data, req_id))
445 .unwrap_or_else(|| {
446 tracing::debug!("orphaned socket {:?}", socket_id);
447 None
448 })
449 }
450
451 fn process_send_to(
452 &mut self,
453 socket_id: SocketId,
454 data: Vec<u8>,
455 addr: SocketAddr,
456 req_id: Option<u64>,
457 ) -> BackgroundTask {
458 let mut guard = self.common.sockets.lock().unwrap();
459 guard
460 .get_mut(&socket_id)
461 .map(|s| {
462 req_id.and_then(|req_id| s.send_to(&self.common, socket_id, data, addr, req_id))
463 })
464 .unwrap_or(None)
465 }
466
467 fn process_async<F>(future: F) -> BackgroundTask
468 where
469 F: Future<Output = BackgroundTask> + Send + 'static,
470 {
471 Some(Box::pin(async move {
472 let background_task = future.await;
473 if let Some(background_task) = background_task {
474 background_task.await;
475 }
476 }))
477 }
478
479 fn process_async_inner<F, Fut, T>(
480 &self,
481 work: F,
482 transmute: T,
483 req_id: Option<u64>,
484 ) -> BackgroundTask
485 where
486 F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
487 Fut: Future + Send + 'static,
488 T: FnOnce(Fut::Output) -> ResponseType + Send + 'static,
489 {
490 let inner = self.inner.clone();
491 let common = self.common.clone();
492 Self::process_async(async move {
493 let future = work(inner);
494 let ret = future.await;
495 req_id.and_then(|req_id| {
496 common.send(MessageResponse::ResponseToRequest {
497 req_id,
498 res: transmute(ret),
499 })
500 })
501 })
502 }
503
504 fn process_async_noop<F, Fut>(&self, work: F, req_id: Option<u64>) -> BackgroundTask
505 where
506 F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
507 Fut: Future<Output = Result<(), NetworkError>> + Send + 'static,
508 {
509 self.process_async_inner(
510 work,
511 move |ret| match ret {
512 Ok(()) => ResponseType::None,
513 Err(err) => ResponseType::Err(err),
514 },
515 req_id,
516 )
517 }
518
519 fn process_async_new_socket<F, Fut>(
520 &self,
521 work: F,
522 socket_id: SocketId,
523 req_id: Option<u64>,
524 ) -> BackgroundTask
525 where
526 F: FnOnce(Arc<dyn VirtualNetworking + Send + Sync>) -> Fut + Send + 'static,
527 Fut: Future<Output = Result<RemoteAdapterSocket, NetworkError>> + Send + 'static,
528 {
529 let common = self.common.clone();
530 self.process_async_inner(
531 work,
532 move |ret| match ret {
533 Ok(mut socket) => {
534 let handler = Box::new(common.handler.clone().for_socket(socket_id));
535
536 let err = match &mut socket {
537 RemoteAdapterSocket::TcpListener { .. } => {
538 Ok(())
542 }
543 RemoteAdapterSocket::BoundTcp(_) => Ok(()),
544 RemoteAdapterSocket::TcpSocket(s) => s.set_handler(handler),
545 RemoteAdapterSocket::UdpSocket(s) => s.set_handler(handler),
546 RemoteAdapterSocket::IcmpSocket(s) => s.set_handler(handler),
547 RemoteAdapterSocket::RawSocket(s) => s.set_handler(handler),
548 };
549 if let Err(err) = err {
550 return ResponseType::Err(err);
551 }
552
553 let mut guard = common.sockets.lock().unwrap();
554 guard.insert(socket_id, socket);
555
556 ResponseType::Socket(socket_id)
557 }
558 Err(err) => ResponseType::Err(err),
559 },
560 req_id,
561 )
562 }
563
564 fn process_inner<F, R, T>(
565 &self,
566 work: F,
567 transmute: T,
568 socket_id: SocketId,
569 req_id: Option<u64>,
570 ) -> BackgroundTask
571 where
572 F: FnOnce(&mut RemoteAdapterSocket) -> R + Send + 'static,
573 T: FnOnce(R) -> ResponseType + Send + 'static,
574 {
575 let ret = {
576 let mut guard = self.common.sockets.lock().unwrap();
577 let socket = match guard.get_mut(&socket_id) {
578 Some(s) => s,
579 None => {
580 return req_id.and_then(|req_id| {
581 self.common.send(MessageResponse::ResponseToRequest {
582 req_id,
583 res: ResponseType::Err(NetworkError::InvalidFd),
584 })
585 });
586 }
587 };
588 work(socket)
589 };
590 req_id.and_then(|req_id| {
591 self.common.send(MessageResponse::ResponseToRequest {
592 req_id,
593 res: transmute(ret),
594 })
595 })
596 }
597
598 fn process_inner_noop<F>(
599 &self,
600 work: F,
601 socket_id: SocketId,
602 req_id: Option<u64>,
603 ) -> BackgroundTask
604 where
605 F: FnOnce(&mut RemoteAdapterSocket) -> Result<(), NetworkError> + Send + 'static,
606 {
607 self.process_inner(
608 work,
609 move |ret| match ret {
610 Ok(()) => ResponseType::None,
611 Err(err) => ResponseType::Err(err),
612 },
613 socket_id,
614 req_id,
615 )
616 }
617
618 fn process_inner_begin_accept(
619 &self,
620 socket_id: SocketId,
621 child_id: SocketId,
622 req_id: Option<u64>,
623 ) -> BackgroundTask {
624 {
626 let mut guard = self.common.socket_accept.lock().unwrap();
627 guard.insert(socket_id, child_id);
628 }
629
630 let mut handler = Box::new(self.common.handler.clone().for_socket(socket_id));
632 handler.push_interest(virtual_mio::InterestType::Readable);
633 self.process_inner_noop(
634 move |socket| match socket {
635 RemoteAdapterSocket::TcpListener {
636 socket: s,
637 next_accept,
638 ..
639 } => {
640 next_accept.replace(child_id);
641 s.set_handler(handler)
642 }
643 _ => {
644 Err(NetworkError::Unsupported)
650 }
651 },
652 socket_id,
653 req_id,
654 )
655 }
656
657 fn process_interface(&mut self, req: RequestType, req_id: Option<u64>) -> BackgroundTask {
658 match req {
659 RequestType::Bridge {
660 network,
661 access_token,
662 security,
663 } => self.process_async_noop(
664 move |inner| async move { inner.bridge(&network, &access_token, security).await },
665 req_id,
666 ),
667 RequestType::Unbridge => {
668 self.process_async_noop(move |inner| async move { inner.unbridge().await }, req_id)
669 }
670 RequestType::DhcpAcquire => self.process_async_inner(
671 move |inner| async move { inner.dhcp_acquire().await },
672 |ret| match ret {
673 Ok(ips) => ResponseType::IpAddressList(ips),
674 Err(err) => ResponseType::Err(err),
675 },
676 req_id,
677 ),
678 RequestType::IpAdd { ip, prefix } => self.process_async_noop(
679 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
680 inner.ip_add(ip, prefix).await
681 },
682 req_id,
683 ),
684 RequestType::IpRemove(ip) => self.process_async_noop(
685 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
686 inner.ip_remove(ip).await
687 },
688 req_id,
689 ),
690 RequestType::IpClear => self.process_async_noop(
691 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
692 inner.ip_clear().await
693 },
694 req_id,
695 ),
696 RequestType::GetIpList => self.process_async_inner(
697 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
698 inner.ip_list().await
699 },
700 |ret| match ret {
701 Ok(cidr) => ResponseType::CidrList(cidr),
702 Err(err) => ResponseType::Err(err),
703 },
704 req_id,
705 ),
706 RequestType::GetMac => {
707 self.process_async_inner(
708 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
709 inner.mac().await
710 },
711 |ret| match ret {
712 Ok(mac) => ResponseType::Mac(mac),
713 Err(err) => ResponseType::Err(err),
714 },
715 req_id,
716 )
717 }
718 RequestType::GatewaySet(ip) => self.process_async_noop(
719 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
720 inner.gateway_set(ip).await
721 },
722 req_id,
723 ),
724 RequestType::RouteAdd {
725 cidr,
726 via_router,
727 preferred_until,
728 expires_at,
729 } => self.process_async_noop(
730 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
731 inner
732 .route_add(cidr, via_router, preferred_until, expires_at)
733 .await
734 },
735 req_id,
736 ),
737 RequestType::RouteRemove(ip) => self.process_async_noop(
738 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
739 inner.route_remove(ip).await
740 },
741 req_id,
742 ),
743 RequestType::RouteClear => self.process_async_noop(
744 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
745 inner.route_clear().await
746 },
747 req_id,
748 ),
749 RequestType::GetRouteList => self.process_async_inner(
750 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
751 inner.route_list().await
752 },
753 |ret| match ret {
754 Ok(routes) => ResponseType::RouteList(routes),
755 Err(err) => ResponseType::Err(err),
756 },
757 req_id,
758 ),
759 RequestType::BindRaw(socket_id) => self.process_async_new_socket(
760 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
761 Ok(RemoteAdapterSocket::RawSocket(inner.bind_raw().await?))
762 },
763 socket_id,
764 req_id,
765 ),
766 RequestType::BindTcp {
767 socket_id,
768 addr,
769 only_v6,
770 reuse_port,
771 reuse_addr,
772 } => self.process_async_new_socket(
773 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
774 Ok(RemoteAdapterSocket::BoundTcp(
775 inner
776 .bind_tcp(addr, only_v6, reuse_port, reuse_addr)
777 .await?,
778 ))
779 },
780 socket_id,
781 req_id,
782 ),
783 RequestType::ListenTcp {
784 socket_id,
785 addr,
786 only_v6,
787 reuse_port,
788 reuse_addr,
789 } => self.process_async_new_socket(
790 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
791 Ok(RemoteAdapterSocket::TcpListener {
792 socket: inner
793 .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
794 .await?,
795 next_accept: None,
796 })
797 },
798 socket_id,
799 req_id,
800 ),
801 RequestType::BindUdp {
802 socket_id,
803 addr,
804 reuse_port,
805 reuse_addr,
806 } => self.process_async_new_socket(
807 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
808 Ok(RemoteAdapterSocket::UdpSocket(
809 inner.bind_udp(addr, reuse_port, reuse_addr).await?,
810 ))
811 },
812 socket_id,
813 req_id,
814 ),
815 RequestType::BindIcmp { socket_id, addr } => self.process_async_new_socket(
816 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
817 Ok(RemoteAdapterSocket::IcmpSocket(
818 inner.bind_icmp(addr).await?,
819 ))
820 },
821 socket_id,
822 req_id,
823 ),
824 RequestType::ConnectTcp {
825 socket_id,
826 addr,
827 peer,
828 } => self.process_async_new_socket(
829 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
830 Ok(RemoteAdapterSocket::TcpSocket(
831 inner.connect_tcp(addr, peer).await?,
832 ))
833 },
834 socket_id,
835 req_id,
836 ),
837 RequestType::Resolve {
838 host,
839 port,
840 dns_server,
841 } => self.process_async_inner(
842 move |inner: Arc<dyn VirtualNetworking + Send + Sync>| async move {
843 inner.resolve(&host, port, dns_server).await
844 },
845 |ret| match ret {
846 Ok(ips) => ResponseType::IpAddressList(ips),
847 Err(err) => ResponseType::Err(err),
848 },
849 req_id,
850 ),
851 _ => req_id.and_then(|req_id| {
852 self.common.send(MessageResponse::ResponseToRequest {
853 req_id,
854 res: ResponseType::Err(NetworkError::Unsupported),
855 })
856 }),
857 }
858 }
859
860 fn process_socket(
861 &mut self,
862 socket_id: SocketId,
863 req: RequestType,
864 req_id: Option<u64>,
865 ) -> BackgroundTask {
866 match req {
867 RequestType::Flush => self.process_inner_noop(
868 move |socket| match socket {
869 RemoteAdapterSocket::TcpSocket(s) => s.try_flush(),
870 RemoteAdapterSocket::RawSocket(s) => s.try_flush(),
871 _ => Err(NetworkError::Unsupported),
872 },
873 socket_id,
874 req_id,
875 ),
876 RequestType::Close => {
877 let res = {
878 let mut guard = self.common.sockets.lock().unwrap();
879 self.common.socket_accept.lock().unwrap().remove(&socket_id);
880 match guard.remove(&socket_id) {
881 Some(RemoteAdapterSocket::TcpSocket(mut socket)) => socket.close(),
882 Some(_) => Ok(()),
883 None => Err(NetworkError::InvalidFd),
884 }
885 };
886 req_id.and_then(|req_id| {
887 self.common.send(MessageResponse::ResponseToRequest {
888 req_id,
889 res: match res {
890 Ok(()) => ResponseType::None,
891 Err(err) => ResponseType::Err(err),
892 },
893 })
894 })
895 }
896 RequestType::ListenBound => {
897 let res = {
898 let mut guard = self.common.sockets.lock().unwrap();
899 match guard.get_mut(&socket_id) {
900 Some(socket) => match socket {
901 RemoteAdapterSocket::BoundTcp(bound) => match bound.listen() {
902 Ok(listener) => {
903 *socket = RemoteAdapterSocket::TcpListener {
904 socket: listener,
905 next_accept: None,
906 };
907 Ok(())
908 }
909 Err(err) => Err(err),
910 },
911 _ => Err(NetworkError::Unsupported),
912 },
913 None => Err(NetworkError::InvalidFd),
914 }
915 };
916 req_id.and_then(|req_id| {
917 self.common.send(MessageResponse::ResponseToRequest {
918 req_id,
919 res: match res {
920 Ok(()) => ResponseType::None,
921 Err(err) => ResponseType::Err(err),
922 },
923 })
924 })
925 }
926 RequestType::ConnectBound { peer } => {
927 let res = {
928 let mut guard = self.common.sockets.lock().unwrap();
929 match guard.get_mut(&socket_id) {
930 Some(socket) => match socket {
931 RemoteAdapterSocket::BoundTcp(bound) => match bound.connect(peer) {
932 Ok(connected) => {
933 *socket = RemoteAdapterSocket::TcpSocket(connected);
934 Ok(())
935 }
936 Err(err) => Err(err),
937 },
938 _ => Err(NetworkError::Unsupported),
939 },
940 None => Err(NetworkError::InvalidFd),
941 }
942 };
943 req_id.and_then(|req_id| {
944 self.common.send(MessageResponse::ResponseToRequest {
945 req_id,
946 res: match res {
947 Ok(()) => ResponseType::None,
948 Err(err) => ResponseType::Err(err),
949 },
950 })
951 })
952 }
953 RequestType::BeginAccept(child_id) => {
954 self.process_inner_begin_accept(socket_id, child_id, req_id)
955 }
956 RequestType::GetAddrLocal => self.process_inner(
957 move |socket| match socket {
958 RemoteAdapterSocket::BoundTcp(s) => s.addr_local(),
959 RemoteAdapterSocket::TcpSocket(s) => s.addr_local(),
960 RemoteAdapterSocket::TcpListener { socket: s, .. } => s.addr_local(),
961 RemoteAdapterSocket::UdpSocket(s) => s.addr_local(),
962 RemoteAdapterSocket::IcmpSocket(s) => s.addr_local(),
963 RemoteAdapterSocket::RawSocket(s) => s.addr_local(),
964 },
965 |ret| match ret {
966 Ok(addr) => ResponseType::SocketAddr(addr),
967 Err(err) => ResponseType::Err(err),
968 },
969 socket_id,
970 req_id,
971 ),
972 RequestType::GetAddrPeer => self.process_inner(
973 move |socket| match socket {
974 RemoteAdapterSocket::BoundTcp(_) => Err(NetworkError::Unsupported),
975 RemoteAdapterSocket::TcpSocket(s) => s.addr_peer().map(Some),
976 RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
977 RemoteAdapterSocket::UdpSocket(s) => s.addr_peer(),
978 RemoteAdapterSocket::IcmpSocket(_) => Err(NetworkError::Unsupported),
979 RemoteAdapterSocket::RawSocket(_) => Err(NetworkError::Unsupported),
980 },
981 |ret| match ret {
982 Ok(Some(addr)) => ResponseType::SocketAddr(addr),
983 Ok(None) => ResponseType::None,
984 Err(err) => ResponseType::Err(err),
985 },
986 socket_id,
987 req_id,
988 ),
989 RequestType::SetTtl(ttl) => self.process_inner_noop(
990 move |socket| match socket {
991 RemoteAdapterSocket::BoundTcp(s) => s.set_ttl(ttl),
992 RemoteAdapterSocket::TcpSocket(s) => s.set_ttl(ttl),
993 RemoteAdapterSocket::TcpListener { socket: s, .. } => {
994 s.set_ttl(ttl.try_into().unwrap_or_default())
995 }
996 RemoteAdapterSocket::UdpSocket(s) => s.set_ttl(ttl),
997 RemoteAdapterSocket::IcmpSocket(s) => s.set_ttl(ttl),
998 RemoteAdapterSocket::RawSocket(s) => s.set_ttl(ttl),
999 },
1000 socket_id,
1001 req_id,
1002 ),
1003 RequestType::GetTtl => self.process_inner(
1004 move |socket| match socket {
1005 RemoteAdapterSocket::BoundTcp(s) => s.ttl(),
1006 RemoteAdapterSocket::TcpSocket(s) => s.ttl(),
1007 RemoteAdapterSocket::TcpListener { socket: s, .. } => s.ttl().map(|t| t as u32),
1008 RemoteAdapterSocket::UdpSocket(s) => s.ttl(),
1009 RemoteAdapterSocket::IcmpSocket(s) => s.ttl(),
1010 RemoteAdapterSocket::RawSocket(s) => s.ttl(),
1011 },
1012 |ret| match ret {
1013 Ok(ttl) => ResponseType::Ttl(ttl),
1014 Err(err) => ResponseType::Err(err),
1015 },
1016 socket_id,
1017 req_id,
1018 ),
1019 RequestType::GetStatus => self.process_inner(
1020 move |socket| match socket {
1021 RemoteAdapterSocket::BoundTcp(_) => Ok(SocketStatus::Opened),
1022 RemoteAdapterSocket::TcpSocket(s) => s.status(),
1023 RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported),
1024 RemoteAdapterSocket::UdpSocket(s) => s.status(),
1025 RemoteAdapterSocket::IcmpSocket(s) => s.status(),
1026 RemoteAdapterSocket::RawSocket(s) => s.status(),
1027 },
1028 |ret| match ret {
1029 Ok(status) => ResponseType::Status(status),
1030 Err(err) => ResponseType::Err(err),
1031 },
1032 socket_id,
1033 req_id,
1034 ),
1035 RequestType::SetLinger(linger) => self.process_inner_noop(
1036 move |socket| match socket {
1037 RemoteAdapterSocket::TcpSocket(s) => s.set_linger(linger),
1038 _ => Err(NetworkError::Unsupported),
1039 },
1040 socket_id,
1041 req_id,
1042 ),
1043 RequestType::GetLinger => self.process_inner(
1044 move |socket| match socket {
1045 RemoteAdapterSocket::TcpSocket(s) => s.linger(),
1046 _ => Err(NetworkError::Unsupported),
1047 },
1048 |ret| match ret {
1049 Ok(Some(time)) => ResponseType::Duration(time),
1050 Ok(None) => ResponseType::None,
1051 Err(err) => ResponseType::Err(err),
1052 },
1053 socket_id,
1054 req_id,
1055 ),
1056 RequestType::SetPromiscuous(promiscuous) => self.process_inner_noop(
1057 move |socket| match socket {
1058 RemoteAdapterSocket::RawSocket(s) => s.set_promiscuous(promiscuous),
1059 _ => Err(NetworkError::Unsupported),
1060 },
1061 socket_id,
1062 req_id,
1063 ),
1064 RequestType::GetPromiscuous => self.process_inner(
1065 move |socket| match socket {
1066 RemoteAdapterSocket::RawSocket(s) => s.promiscuous(),
1067 _ => Err(NetworkError::Unsupported),
1068 },
1069 |ret| match ret {
1070 Ok(flag) => ResponseType::Flag(flag),
1071 Err(err) => ResponseType::Err(err),
1072 },
1073 socket_id,
1074 req_id,
1075 ),
1076 RequestType::SetRecvBufSize(size) => self.process_inner_noop(
1077 move |socket| match socket {
1078 RemoteAdapterSocket::TcpSocket(s) => {
1079 s.set_recv_buf_size(size.try_into().unwrap_or_default())
1080 }
1081 _ => Err(NetworkError::Unsupported),
1082 },
1083 socket_id,
1084 req_id,
1085 ),
1086 RequestType::GetRecvBufSize => self.process_inner(
1087 move |socket| match socket {
1088 RemoteAdapterSocket::TcpSocket(s) => s.recv_buf_size(),
1089 _ => Err(NetworkError::Unsupported),
1090 },
1091 |ret| match ret {
1092 Ok(amt) => ResponseType::Amount(amt as u64),
1093 Err(err) => ResponseType::Err(err),
1094 },
1095 socket_id,
1096 req_id,
1097 ),
1098 RequestType::SetSendBufSize(size) => self.process_inner_noop(
1099 move |socket| match socket {
1100 RemoteAdapterSocket::TcpSocket(s) => {
1101 s.set_send_buf_size(size.try_into().unwrap_or_default())
1102 }
1103 _ => Err(NetworkError::Unsupported),
1104 },
1105 socket_id,
1106 req_id,
1107 ),
1108 RequestType::GetSendBufSize => self.process_inner(
1109 move |socket| match socket {
1110 RemoteAdapterSocket::TcpSocket(s) => s.send_buf_size(),
1111 _ => Err(NetworkError::Unsupported),
1112 },
1113 |ret| match ret {
1114 Ok(amt) => ResponseType::Amount(amt as u64),
1115 Err(err) => ResponseType::Err(err),
1116 },
1117 socket_id,
1118 req_id,
1119 ),
1120 RequestType::SetNoDelay(reuse) => self.process_inner_noop(
1121 move |socket| match socket {
1122 RemoteAdapterSocket::TcpSocket(s) => s.set_nodelay(reuse),
1123 _ => Err(NetworkError::Unsupported),
1124 },
1125 socket_id,
1126 req_id,
1127 ),
1128 RequestType::GetNoDelay => self.process_inner(
1129 move |socket| match socket {
1130 RemoteAdapterSocket::TcpSocket(s) => s.nodelay(),
1131 _ => Err(NetworkError::Unsupported),
1132 },
1133 |ret| match ret {
1134 Ok(flag) => ResponseType::Flag(flag),
1135 Err(err) => ResponseType::Err(err),
1136 },
1137 socket_id,
1138 req_id,
1139 ),
1140 RequestType::SetKeepAlive(val) => self.process_inner_noop(
1141 move |socket| match socket {
1142 RemoteAdapterSocket::TcpSocket(s) => s.set_keepalive(val),
1143 _ => Err(NetworkError::Unsupported),
1144 },
1145 socket_id,
1146 req_id,
1147 ),
1148 RequestType::GetKeepAlive => self.process_inner(
1149 move |socket| match socket {
1150 RemoteAdapterSocket::TcpSocket(s) => s.keepalive(),
1151 _ => Err(NetworkError::Unsupported),
1152 },
1153 |ret| match ret {
1154 Ok(flag) => ResponseType::Flag(flag),
1155 Err(err) => ResponseType::Err(err),
1156 },
1157 socket_id,
1158 req_id,
1159 ),
1160 RequestType::SetDontRoute(val) => self.process_inner_noop(
1161 move |socket| match socket {
1162 RemoteAdapterSocket::TcpSocket(s) => s.set_dontroute(val),
1163 _ => Err(NetworkError::Unsupported),
1164 },
1165 socket_id,
1166 req_id,
1167 ),
1168 RequestType::GetDontRoute => self.process_inner(
1169 move |socket| match socket {
1170 RemoteAdapterSocket::TcpSocket(s) => s.dontroute(),
1171 _ => Err(NetworkError::Unsupported),
1172 },
1173 |ret| match ret {
1174 Ok(flag) => ResponseType::Flag(flag),
1175 Err(err) => ResponseType::Err(err),
1176 },
1177 socket_id,
1178 req_id,
1179 ),
1180 RequestType::Shutdown(shutdown) => self.process_inner_noop(
1181 move |socket| match socket {
1182 RemoteAdapterSocket::TcpSocket(s) => s.shutdown(match shutdown {
1183 crate::meta::Shutdown::Read => std::net::Shutdown::Read,
1184 crate::meta::Shutdown::Write => std::net::Shutdown::Write,
1185 crate::meta::Shutdown::Both => std::net::Shutdown::Both,
1186 }),
1187 _ => Err(NetworkError::Unsupported),
1188 },
1189 socket_id,
1190 req_id,
1191 ),
1192 RequestType::IsClosed => self.process_inner(
1193 move |socket| match socket {
1194 RemoteAdapterSocket::TcpSocket(s) => Ok(s.is_closed()),
1195 _ => Err(NetworkError::Unsupported),
1196 },
1197 |ret| match ret {
1198 Ok(flag) => ResponseType::Flag(flag),
1199 Err(err) => ResponseType::Err(err),
1200 },
1201 socket_id,
1202 req_id,
1203 ),
1204 RequestType::SetBroadcast(broadcast) => self.process_inner_noop(
1205 move |socket| match socket {
1206 RemoteAdapterSocket::UdpSocket(s) => s.set_broadcast(broadcast),
1207 _ => Err(NetworkError::Unsupported),
1208 },
1209 socket_id,
1210 req_id,
1211 ),
1212 RequestType::GetBroadcast => self.process_inner(
1213 move |socket| match socket {
1214 RemoteAdapterSocket::UdpSocket(s) => s.broadcast(),
1215 _ => Err(NetworkError::Unsupported),
1216 },
1217 |ret| match ret {
1218 Ok(flag) => ResponseType::Flag(flag),
1219 Err(err) => ResponseType::Err(err),
1220 },
1221 socket_id,
1222 req_id,
1223 ),
1224 RequestType::SetMulticastLoopV4(val) => self.process_inner_noop(
1225 move |socket| match socket {
1226 RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v4(val),
1227 _ => Err(NetworkError::Unsupported),
1228 },
1229 socket_id,
1230 req_id,
1231 ),
1232 RequestType::GetMulticastLoopV4 => self.process_inner(
1233 move |socket| match socket {
1234 RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v4(),
1235 _ => Err(NetworkError::Unsupported),
1236 },
1237 |ret| match ret {
1238 Ok(flag) => ResponseType::Flag(flag),
1239 Err(err) => ResponseType::Err(err),
1240 },
1241 socket_id,
1242 req_id,
1243 ),
1244 RequestType::SetMulticastLoopV6(val) => self.process_inner_noop(
1245 move |socket| match socket {
1246 RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_loop_v6(val),
1247 _ => Err(NetworkError::Unsupported),
1248 },
1249 socket_id,
1250 req_id,
1251 ),
1252 RequestType::GetMulticastLoopV6 => self.process_inner(
1253 move |socket| match socket {
1254 RemoteAdapterSocket::UdpSocket(s) => s.multicast_loop_v6(),
1255 _ => Err(NetworkError::Unsupported),
1256 },
1257 |ret| match ret {
1258 Ok(flag) => ResponseType::Flag(flag),
1259 Err(err) => ResponseType::Err(err),
1260 },
1261 socket_id,
1262 req_id,
1263 ),
1264 RequestType::SetMulticastTtlV4(ttl) => self.process_inner_noop(
1265 move |socket| match socket {
1266 RemoteAdapterSocket::UdpSocket(s) => s.set_multicast_ttl_v4(ttl),
1267 _ => Err(NetworkError::Unsupported),
1268 },
1269 socket_id,
1270 req_id,
1271 ),
1272 RequestType::GetMulticastTtlV4 => self.process_inner(
1273 move |socket| match socket {
1274 RemoteAdapterSocket::UdpSocket(s) => s.multicast_ttl_v4(),
1275 _ => Err(NetworkError::Unsupported),
1276 },
1277 |ret| match ret {
1278 Ok(ttl) => ResponseType::Ttl(ttl),
1279 Err(err) => ResponseType::Err(err),
1280 },
1281 socket_id,
1282 req_id,
1283 ),
1284 RequestType::JoinMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1285 move |socket| match socket {
1286 RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v4(multiaddr, iface),
1287 _ => Err(NetworkError::Unsupported),
1288 },
1289 socket_id,
1290 req_id,
1291 ),
1292 RequestType::LeaveMulticastV4 { multiaddr, iface } => self.process_inner_noop(
1293 move |socket| match socket {
1294 RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v4(multiaddr, iface),
1295 _ => Err(NetworkError::Unsupported),
1296 },
1297 socket_id,
1298 req_id,
1299 ),
1300 RequestType::JoinMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1301 move |socket| match socket {
1302 RemoteAdapterSocket::UdpSocket(s) => s.join_multicast_v6(multiaddr, iface),
1303 _ => Err(NetworkError::Unsupported),
1304 },
1305 socket_id,
1306 req_id,
1307 ),
1308 RequestType::LeaveMulticastV6 { multiaddr, iface } => self.process_inner_noop(
1309 move |socket| match socket {
1310 RemoteAdapterSocket::UdpSocket(s) => s.leave_multicast_v6(multiaddr, iface),
1311 _ => Err(NetworkError::Unsupported),
1312 },
1313 socket_id,
1314 req_id,
1315 ),
1316 _ => req_id.and_then(|req_id| {
1317 self.common.send(MessageResponse::ResponseToRequest {
1318 req_id,
1319 res: ResponseType::Err(NetworkError::Unsupported),
1320 })
1321 }),
1322 }
1323 }
1324}
1325
1326#[derive(Debug)]
1327enum RemoteAdapterSocket {
1328 BoundTcp(Box<dyn VirtualTcpBoundSocket + Sync + 'static>),
1329 TcpListener {
1330 socket: Box<dyn VirtualTcpListener + Sync + 'static>,
1331 next_accept: Option<SocketId>,
1332 },
1333 TcpSocket(Box<dyn VirtualTcpSocket + Sync + 'static>),
1334 UdpSocket(Box<dyn VirtualUdpSocket + Sync + 'static>),
1335 RawSocket(Box<dyn VirtualRawSocket + Sync + 'static>),
1336 IcmpSocket(Box<dyn VirtualIcmpSocket + Sync + 'static>),
1337}
1338
1339impl RemoteAdapterSocket {
1340 pub fn send(
1341 &mut self,
1342 common: &Arc<RemoteAdapterCommon>,
1343 socket_id: SocketId,
1344 data: Vec<u8>,
1345 req_id: Option<u64>,
1346 ) -> BackgroundTask {
1347 match self {
1348 Self::TcpSocket(this) => match this.try_send(&data) {
1349 Ok(amount) => {
1350 if let Some(req_id) = req_id {
1351 common.send(MessageResponse::Sent {
1352 socket_id,
1353 req_id,
1354 amount: amount as u64,
1355 })
1356 } else {
1357 None
1358 }
1359 }
1360 Err(NetworkError::WouldBlock) => {
1361 let common = common.clone();
1362 Some(Box::pin(async move {
1363 let _stall_rx = common.stall_rx.clone().lock_owned().await;
1366
1367 struct Poller {
1369 common: Arc<RemoteAdapterCommon>,
1370 socket_id: SocketId,
1371 data: Vec<u8>,
1372 req_id: Option<u64>,
1373 }
1374 impl Future for Poller {
1375 type Output = BackgroundTask;
1376 fn poll(
1377 self: Pin<&mut Self>,
1378 cx: &mut Context<'_>,
1379 ) -> Poll<Self::Output> {
1380 let mut guard = self.common.handler.state.lock().unwrap();
1383 if !guard.driver_wakers.iter().any(|w| w.will_wake(cx.waker())) {
1384 guard.driver_wakers.push(cx.waker().clone());
1385 }
1386 drop(guard);
1387
1388 let mut guard = self.common.sockets.lock().unwrap();
1389 if let Some(RemoteAdapterSocket::TcpSocket(socket)) =
1390 guard.get_mut(&self.socket_id)
1391 {
1392 match socket.try_send(&self.data) {
1393 Ok(amount) => {
1394 if let Some(req_id) = self.req_id {
1395 return Poll::Ready(self.common.send(
1396 MessageResponse::Sent {
1397 socket_id: self.socket_id,
1398 req_id,
1399 amount: amount as u64,
1400 },
1401 ));
1402 } else {
1403 return Poll::Ready(None);
1404 }
1405 }
1406 Err(NetworkError::WouldBlock) => return Poll::Pending,
1407 Err(error) => {
1408 if let Some(req_id) = self.req_id {
1409 return Poll::Ready(self.common.send(
1410 MessageResponse::SendError {
1411 socket_id: self.socket_id,
1412 req_id,
1413 error,
1414 },
1415 ));
1416 } else {
1417 return Poll::Ready(None);
1418 }
1419 }
1420 }
1421 }
1422 Poll::Ready(None)
1423 }
1424 }
1425
1426 let background_task = Poller {
1428 common,
1429 socket_id,
1430 data,
1431 req_id,
1432 }
1433 .await;
1434
1435 if let Some(background_task) = background_task {
1437 background_task.await;
1438 }
1439 }))
1440 }
1441 Err(error) => {
1442 if let Some(req_id) = req_id {
1443 common.send(MessageResponse::SendError {
1444 socket_id,
1445 req_id,
1446 error,
1447 })
1448 } else {
1449 None
1450 }
1451 }
1452 },
1453 Self::RawSocket(this) => {
1454 if let Err(err) = this.try_send(&data) {
1460 tracing::debug!("failed to send raw packet - {}", err);
1461 }
1462 None
1463 }
1464 _ => {
1465 if let Some(req_id) = req_id {
1466 common.send(MessageResponse::SendError {
1467 socket_id,
1468 req_id,
1469 error: NetworkError::Unsupported,
1470 })
1471 } else {
1472 None
1473 }
1474 }
1475 }
1476 }
1477 pub fn send_to(
1478 &mut self,
1479 common: &Arc<RemoteAdapterCommon>,
1480 socket_id: SocketId,
1481 data: Vec<u8>,
1482 addr: SocketAddr,
1483 req_id: u64,
1484 ) -> BackgroundTask {
1485 match self {
1486 Self::UdpSocket(this) => {
1487 this.try_send_to(&data, addr).ok();
1490 None
1491 }
1492
1493 Self::IcmpSocket(this) => {
1494 this.try_send_to(&data, addr).ok();
1497 None
1498 }
1499 _ => common.send(MessageResponse::SendError {
1500 socket_id,
1501 req_id,
1502 error: NetworkError::Unsupported,
1503 }),
1504 }
1505 }
1506 pub fn drain_reads_and_accepts(
1507 &mut self,
1508 common: &Arc<RemoteAdapterCommon>,
1509 socket_id: SocketId,
1510 ) -> BackgroundTask {
1511 let mut ret: FuturesOrdered<BoxFuture<'static, ()>> = Default::default();
1514 loop {
1515 break match self {
1516 Self::BoundTcp(_) => {}
1517 Self::TcpListener {
1518 socket,
1519 next_accept,
1520 } => {
1521 if next_accept.is_some() {
1522 match socket.try_accept() {
1523 Ok((mut child_socket, addr)) => {
1524 let child_id = next_accept.take().unwrap();
1525
1526 let handler = Box::new(common.handler.clone().for_socket(child_id));
1528 child_socket.set_handler(handler).ok();
1529
1530 let common = common.clone();
1533 ret.push_back(Box::pin(async move {
1534 {
1536 let child_socket =
1537 RemoteAdapterSocket::TcpSocket(child_socket);
1538 let mut guard = common.sockets.lock().unwrap();
1539 guard.insert(child_id, child_socket);
1540 }
1541
1542 if let Some(task) = common.send(MessageResponse::FinishAccept {
1544 socket_id,
1545 child_id,
1546 addr,
1547 }) {
1548 task.await;
1549 }
1550 }));
1551 }
1552 Err(NetworkError::WouldBlock) => {}
1553 Err(err) => {
1554 tracing::error!("failed to accept socket - {}", err);
1555 }
1556 }
1557 }
1558 }
1559 Self::TcpSocket(this) => {
1560 let mut chunk: [MaybeUninit<u8>; 10240] =
1561 unsafe { MaybeUninit::uninit().assume_init() };
1562 match this.try_recv(&mut chunk, false) {
1563 Ok(0) => {}
1564 Ok(amt) => {
1565 let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1566 let chunk_unsafe: &mut [u8] =
1567 unsafe { std::mem::transmute(chunk_unsafe) };
1568 if let Some(task) = common.send(MessageResponse::Recv {
1569 socket_id,
1570 data: chunk_unsafe.to_vec(),
1571 }) {
1572 ret.push_back(task);
1573 }
1574 continue;
1575 }
1576 Err(_) => {}
1577 }
1578 }
1579 Self::UdpSocket(this) => {
1580 let mut chunk: [MaybeUninit<u8>; 10240] =
1581 unsafe { MaybeUninit::uninit().assume_init() };
1582 match this.try_recv_from(&mut chunk, false) {
1583 Ok((0, _)) => {}
1584 Ok((amt, addr)) => {
1585 let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1586 let chunk_unsafe: &mut [u8] =
1587 unsafe { std::mem::transmute(chunk_unsafe) };
1588 if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1589 socket_id,
1590 data: chunk_unsafe.to_vec(),
1591 addr,
1592 }) {
1593 ret.push_back(task);
1594 }
1595 continue;
1596 }
1597 Err(_) => {}
1598 }
1599 }
1600 Self::IcmpSocket(this) => {
1601 let mut chunk: [MaybeUninit<u8>; 10240] =
1602 unsafe { MaybeUninit::uninit().assume_init() };
1603 match this.try_recv_from(&mut chunk, false) {
1604 Ok((0, _)) => {}
1605 Ok((amt, addr)) => {
1606 let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1607 let chunk_unsafe: &mut [u8] =
1608 unsafe { std::mem::transmute(chunk_unsafe) };
1609 if let Some(task) = common.send(MessageResponse::RecvWithAddr {
1610 socket_id,
1611 data: chunk_unsafe.to_vec(),
1612 addr,
1613 }) {
1614 ret.push_back(task);
1615 }
1616 continue;
1617 }
1618 Err(_) => {}
1619 }
1620 }
1621 Self::RawSocket(this) => {
1622 let mut chunk: [MaybeUninit<u8>; 10240] =
1623 unsafe { MaybeUninit::uninit().assume_init() };
1624 match this.try_recv(&mut chunk, false) {
1625 Ok(0) => {}
1626 Ok(amt) => {
1627 let chunk_unsafe: &mut [MaybeUninit<u8>] = &mut chunk[..amt];
1628 let chunk_unsafe: &mut [u8] =
1629 unsafe { std::mem::transmute(chunk_unsafe) };
1630 if let Some(task) = common.send(MessageResponse::Recv {
1631 socket_id,
1632 data: chunk_unsafe.to_vec(),
1633 }) {
1634 ret.push_back(task);
1635 }
1636 continue;
1637 }
1638 Err(_) => {}
1639 }
1640 }
1641 };
1642 }
1643
1644 if ret.is_empty() {
1645 None
1647 } else {
1648 Some(Box::pin(async move {
1649 let mut stream = ret;
1651 loop {
1652 let (next, s) = StreamExt::into_future(stream).await;
1653 if next.is_none() {
1654 break;
1655 }
1656 stream = s;
1657 }
1658 }))
1659 }
1660 }
1661}
1662
1663#[derive(Debug, Default)]
1664struct RemoteAdapterHandlerState {
1665 readable: HashSet<SocketId>,
1666 driver_wakers: Vec<Waker>,
1667}
1668
1669#[derive(Debug, Default, Clone)]
1670struct RemoteAdapterHandler {
1671 socket_id: Option<SocketId>,
1672 state: Arc<Mutex<RemoteAdapterHandlerState>>,
1673}
1674impl RemoteAdapterHandler {
1675 pub fn for_socket(self, id: SocketId) -> Self {
1676 Self {
1677 socket_id: Some(id),
1678 state: self.state,
1679 }
1680 }
1681}
1682impl InterestHandler for RemoteAdapterHandler {
1683 fn push_interest(&mut self, interest: virtual_mio::InterestType) {
1684 let mut guard = self.state.lock().unwrap();
1685 guard.driver_wakers.drain(..).for_each(|w| w.wake());
1686 let socket_id = match self.socket_id {
1687 Some(s) => s,
1688 None => return,
1689 };
1690 if interest == virtual_mio::InterestType::Readable {
1691 guard.readable.insert(socket_id);
1692 }
1693 }
1694
1695 fn pop_interest(&mut self, interest: virtual_mio::InterestType) -> bool {
1696 let mut guard = self.state.lock().unwrap();
1697 let socket_id = match self.socket_id {
1698 Some(s) => s,
1699 None => return false,
1700 };
1701 if interest == virtual_mio::InterestType::Readable {
1702 return guard.readable.remove(&socket_id);
1703 }
1704 false
1705 }
1706
1707 fn has_interest(&self, interest: virtual_mio::InterestType) -> bool {
1708 let guard = self.state.lock().unwrap();
1709 let socket_id = match self.socket_id {
1710 Some(s) => s,
1711 None => return false,
1712 };
1713 if interest == virtual_mio::InterestType::Readable {
1714 return guard.readable.contains(&socket_id);
1715 }
1716 false
1717 }
1718}
1719
1720type SocketMap<T> = HashMap<SocketId, T>;
1721
1722#[derive(Debug)]
1723struct RemoteAdapterCommon {
1724 tx: RemoteTx<MessageResponse>,
1725 rx: Mutex<RemoteRx<MessageRequest>>,
1726 sockets: Mutex<SocketMap<RemoteAdapterSocket>>,
1727 socket_accept: Mutex<SocketMap<SocketId>>,
1728 handler: RemoteAdapterHandler,
1729
1730 stall_rx: Arc<tokio::sync::Mutex<()>>,
1733}
1734impl RemoteAdapterCommon {
1735 fn send(self: &Arc<Self>, req: MessageResponse) -> BackgroundTask {
1736 let this = self.clone();
1737 Some(Box::pin(async move {
1738 if let Err(err) = this.tx.send(req).await {
1739 tracing::debug!("failed to send message - {}", err);
1740 }
1741 }))
1742 }
1743}