virtual_net/
loopback.rs

1use std::collections::VecDeque;
2use std::net::SocketAddr;
3use std::sync::Mutex;
4use std::task::{Context, Poll, Waker};
5use std::{collections::HashMap, sync::Arc};
6
7use crate::tcp_pair::TcpSocketHalf;
8use crate::{
9    InterestHandler, IpAddr, IpCidr, Ipv4Addr, Ipv6Addr, NetworkError, VirtualIoSource,
10    VirtualNetworking, VirtualTcpListener, VirtualTcpSocket,
11};
12use virtual_mio::InterestType;
13
/// Buffer size limit (1 MiB) handed to `TcpSocketHalf::channel` whenever a
/// loopback connection is created.
const DEFAULT_MAX_BUFFER_SIZE: usize = 1024 * 1024;
15
16#[derive(Debug, Default)]
17struct LoopbackNetworkingState {
18    tcp_listeners: HashMap<SocketAddr, LoopbackTcpListener>,
19    ip_addresses: Vec<IpCidr>,
20}
21
22#[derive(Debug, Clone)]
23pub struct LoopbackNetworking {
24    state: Arc<Mutex<LoopbackNetworkingState>>,
25}
26
27impl LoopbackNetworking {
28    pub fn new() -> Self {
29        LoopbackNetworking {
30            state: Arc::new(Mutex::new(Default::default())),
31        }
32    }
33
34    pub fn loopback_connect_to(
35        &self,
36        mut local_addr: SocketAddr,
37        peer_addr: SocketAddr,
38    ) -> Option<TcpSocketHalf> {
39        let mut port = local_addr.port();
40        if port == 0 {
41            port = peer_addr.port();
42        }
43
44        local_addr = match local_addr.ip() {
45            IpAddr::V4(Ipv4Addr::UNSPECIFIED) => {
46                SocketAddr::new(Ipv4Addr::new(127, 0, 0, 100).into(), port)
47            }
48            IpAddr::V6(Ipv6Addr::UNSPECIFIED) => {
49                SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 100).into(), port)
50            }
51            ip => SocketAddr::new(ip, port),
52        };
53
54        let state = self.state.lock().unwrap();
55        if let Some(listener) = state.tcp_listeners.get(&peer_addr) {
56            Some(listener.connect_to(local_addr))
57        } else {
58            state
59                .tcp_listeners
60                .iter()
61                .next()
62                .map(|listener| listener.1.connect_to(local_addr))
63        }
64    }
65}
66
67impl Default for LoopbackNetworking {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73#[allow(unused_variables)]
74#[async_trait::async_trait]
75impl VirtualNetworking for LoopbackNetworking {
76    async fn dhcp_acquire(&self) -> crate::Result<Vec<IpAddr>> {
77        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
78            self.state.lock().unwrap();
79        state.ip_addresses.clear();
80        state.ip_addresses.push(IpCidr {
81            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
82            prefix: 32,
83        });
84        state.ip_addresses.push(IpCidr {
85            ip: IpAddr::V6(Ipv6Addr::LOCALHOST),
86            prefix: 128,
87        });
88        Ok(state.ip_addresses.iter().map(|cidr| cidr.ip).collect())
89    }
90
91    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> crate::Result<()> {
92        let mut state = self.state.lock().unwrap();
93        state.ip_addresses.push(IpCidr { ip, prefix });
94        Ok(())
95    }
96
97    async fn ip_remove(&self, ip: IpAddr) -> crate::Result<()> {
98        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
99            self.state.lock().unwrap();
100        state.ip_addresses.retain(|cidr| cidr.ip != ip);
101        Ok(())
102    }
103
104    async fn ip_clear(&self) -> crate::Result<()> {
105        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
106            self.state.lock().unwrap();
107        state.ip_addresses.clear();
108        Ok(())
109    }
110
111    async fn ip_list(&self) -> crate::Result<Vec<IpCidr>> {
112        let state: std::sync::MutexGuard<'_, LoopbackNetworkingState> = self.state.lock().unwrap();
113        Ok(state.ip_addresses.clone())
114    }
115
116    async fn listen_tcp(
117        &self,
118        mut addr: SocketAddr,
119        _only_v6: bool,
120        _reuse_port: bool,
121        _reuse_addr: bool,
122    ) -> crate::Result<Box<dyn VirtualTcpListener + Sync>> {
123        let listener = LoopbackTcpListener::new(addr);
124
125        if addr.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) {
126            addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), addr.port());
127        } else if addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) {
128            addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), addr.port());
129        }
130
131        let mut state = self.state.lock().unwrap();
132        state.tcp_listeners.insert(addr, listener.clone());
133
134        Ok(Box::new(listener))
135    }
136}
137
138#[derive(Debug)]
139struct LoopbackTcpListenerState {
140    handler: Option<Box<dyn InterestHandler + Send + Sync>>,
141    addr_local: SocketAddr,
142    backlog: VecDeque<TcpSocketHalf>,
143    wakers: Vec<Waker>,
144}
145
146#[derive(Debug, Clone)]
147pub struct LoopbackTcpListener {
148    state: Arc<Mutex<LoopbackTcpListenerState>>,
149}
150
151impl LoopbackTcpListener {
152    pub fn new(addr_local: SocketAddr) -> Self {
153        Self {
154            state: Arc::new(Mutex::new(LoopbackTcpListenerState {
155                handler: None,
156                addr_local,
157                backlog: Default::default(),
158                wakers: Default::default(),
159            })),
160        }
161    }
162
163    pub fn connect_to(&self, addr_local: SocketAddr) -> TcpSocketHalf {
164        let mut state = self.state.lock().unwrap();
165        let (half1, half2) =
166            TcpSocketHalf::channel(DEFAULT_MAX_BUFFER_SIZE, state.addr_local, addr_local);
167
168        state.backlog.push_back(half1);
169        if let Some(handler) = state.handler.as_mut() {
170            handler.push_interest(InterestType::Readable);
171        }
172        state.wakers.drain(..).for_each(|w| w.wake());
173
174        half2
175    }
176}
177
178impl VirtualIoSource for LoopbackTcpListener {
179    fn remove_handler(&mut self) {
180        let mut state = self.state.lock().unwrap();
181        state.handler.take();
182    }
183
184    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
185        let mut state = self.state.lock().unwrap();
186        if !state.backlog.is_empty() {
187            return Poll::Ready(Ok(state.backlog.len()));
188        }
189        if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
190            state.wakers.push(cx.waker().clone());
191        }
192        Poll::Pending
193    }
194
195    fn poll_write_ready(&mut self, _cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
196        Poll::Pending
197    }
198}
199
200impl VirtualTcpListener for LoopbackTcpListener {
201    fn try_accept(
202        &mut self,
203    ) -> crate::Result<(Box<dyn crate::VirtualTcpSocket + Sync>, SocketAddr)> {
204        let mut state = self.state.lock().unwrap();
205        let next = state.backlog.pop_front();
206        if let Some(next) = next {
207            let peer = next.addr_peer()?;
208            return Ok((Box::new(next), peer));
209        }
210        Err(NetworkError::WouldBlock)
211    }
212
213    fn set_handler(
214        &mut self,
215        mut handler: Box<dyn InterestHandler + Send + Sync>,
216    ) -> crate::Result<()> {
217        let mut state = self.state.lock().unwrap();
218        if !state.backlog.is_empty() {
219            handler.push_interest(InterestType::Readable);
220        }
221        state.handler.replace(handler);
222        Ok(())
223    }
224
225    fn addr_local(&self) -> crate::Result<SocketAddr> {
226        let state = self.state.lock().unwrap();
227        Ok(state.addr_local)
228    }
229
230    fn set_ttl(&mut self, _ttl: u8) -> crate::Result<()> {
231        Ok(())
232    }
233
234    fn ttl(&self) -> crate::Result<u8> {
235        Ok(64)
236    }
237}