wasmer_wasix/runners/dproxy/
socket_manager.rs

1use std::{
2    future::poll_fn,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    sync::{
5        Arc,
6        atomic::{AtomicBool, Ordering},
7    },
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use tokio::sync::broadcast;
13use virtual_net::{LoopbackNetworking, tcp_pair::TcpSocketHalf};
14
/// Shared callback that is polled to discover the address the proxied HTTP server listens on.
///
/// `SocketManager` drives it with `poll_fn` until it yields `Poll::Ready`, and then uses only
/// the port of the returned address. The callback must be thread-safe (`Send + Sync`) so the
/// manager can be shared across tasks.
pub type PollListeningFn =
    Arc<dyn Fn(&mut Context<'_>) -> Poll<SocketAddr> + Sync + Send + 'static>;
17
18#[derive(derive_more::Debug)]
19pub struct SocketManager {
20    #[debug(ignore)]
21    poll_listening: PollListeningFn,
22    loopback_networking: LoopbackNetworking,
23    proxy_connect_init_timeout: Duration,
24    proxy_connect_nominal_timeout: Duration,
25    is_running: AtomicBool,
26    is_terminated: AtomicBool,
27    terminate_all: broadcast::Sender<()>,
28}
29
30impl SocketManager {
31    pub fn new(
32        poll_listening: PollListeningFn,
33        loopback_networking: LoopbackNetworking,
34        proxy_connect_init_timeout: Duration,
35        proxy_connect_nominal_timeout: Duration,
36    ) -> Self {
37        Self {
38            poll_listening,
39            loopback_networking,
40            proxy_connect_init_timeout,
41            proxy_connect_nominal_timeout,
42            is_running: AtomicBool::new(false),
43            is_terminated: AtomicBool::new(false),
44            terminate_all: broadcast::channel(1).0,
45        }
46    }
47
48    pub fn shutdown(&self) {
49        self.is_terminated.store(true, Ordering::SeqCst);
50        self.terminate_all.send(()).ok();
51    }
52
53    pub fn terminate_rx(&self) -> broadcast::Receiver<()> {
54        self.terminate_all.subscribe()
55    }
56
57    pub async fn acquire_http_socket(&self) -> anyhow::Result<TcpSocketHalf> {
58        let mut rx_terminate = self.terminate_all.subscribe();
59
60        if self.is_terminated.load(Ordering::SeqCst) {
61            return Err(anyhow::anyhow!(
62                "failed to open HTTP socket as the instance has terminated"
63            ));
64        }
65        let connect_timeout = if self.is_running.load(Ordering::SeqCst) {
66            self.proxy_connect_nominal_timeout
67        } else {
68            self.proxy_connect_init_timeout
69        };
70
71        let ret = tokio::select! {
72            socket = tokio::time::timeout(connect_timeout, self.open_proxy_http_socket()) => socket??,
73            _ = rx_terminate.recv() => {
74                return Err(anyhow::anyhow!(
75                    "failed to open HTTP socket as the instance has terminated"
76                ));
77            }
78        };
79        self.is_running.store(true, Ordering::Relaxed);
80        Ok(ret)
81    }
82
83    pub async fn open_proxy_http_socket(&self) -> anyhow::Result<TcpSocketHalf> {
84        // We need to find the destination address
85        let poll_listening = self.poll_listening.clone();
86        let port = poll_fn(|cx| poll_listening(cx)).await.port();
87        let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
88
89        // Open a connection directly to the loopback port
90        // (or at least try to)
91        self.loopback_networking
92            .loopback_connect_to(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), dst)
93            .ok_or_else(|| {
94                tracing::debug!(
95                    "proxy connection attempt failed - could not connect to http server socket as the loopback socket is not open",
96                );
97                anyhow::anyhow!("failed to open HTTP socket as the loopback socket is not open")
98            })
99    }
100}