1use std::collections::VecDeque;
2use std::net::SocketAddr;
3use std::sync::Mutex;
4use std::task::{Context, Poll, Waker};
5use std::{collections::HashMap, sync::Arc};
6
7use crate::tcp_pair::TcpSocketHalf;
8use crate::{
9 InterestHandler, IpAddr, IpCidr, Ipv4Addr, Ipv6Addr, NetworkError, VirtualIoSource,
10 VirtualNetworking, VirtualTcpListener, VirtualTcpSocket,
11};
12use virtual_mio::InterestType;
13
/// Buffer size, in bytes (1 MiB), handed to `TcpSocketHalf::channel` for every
/// loopback connection created by `LoopbackTcpListener::connect_to`.
const DEFAULT_MAX_BUFFER_SIZE: usize = 1024 * 1024;
15
16#[derive(Debug, Default)]
17struct LoopbackNetworkingState {
18 tcp_listeners: HashMap<SocketAddr, LoopbackTcpListener>,
19 ip_addresses: Vec<IpCidr>,
20}
21
22#[derive(Debug, Clone)]
23pub struct LoopbackNetworking {
24 state: Arc<Mutex<LoopbackNetworkingState>>,
25}
26
27impl LoopbackNetworking {
28 pub fn new() -> Self {
29 LoopbackNetworking {
30 state: Arc::new(Mutex::new(Default::default())),
31 }
32 }
33
34 pub fn loopback_connect_to(
35 &self,
36 mut local_addr: SocketAddr,
37 peer_addr: SocketAddr,
38 ) -> Option<TcpSocketHalf> {
39 let mut port = local_addr.port();
40 if port == 0 {
41 port = peer_addr.port();
42 }
43
44 local_addr = match local_addr.ip() {
45 IpAddr::V4(Ipv4Addr::UNSPECIFIED) => {
46 SocketAddr::new(Ipv4Addr::new(127, 0, 0, 100).into(), port)
47 }
48 IpAddr::V6(Ipv6Addr::UNSPECIFIED) => {
49 SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 100).into(), port)
50 }
51 ip => SocketAddr::new(ip, port),
52 };
53
54 let state = self.state.lock().unwrap();
55 if let Some(listener) = state.tcp_listeners.get(&peer_addr) {
56 Some(listener.connect_to(local_addr))
57 } else {
58 state
59 .tcp_listeners
60 .iter()
61 .next()
62 .map(|listener| listener.1.connect_to(local_addr))
63 }
64 }
65}
66
67impl Default for LoopbackNetworking {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73#[allow(unused_variables)]
74#[async_trait::async_trait]
75impl VirtualNetworking for LoopbackNetworking {
76 async fn dhcp_acquire(&self) -> crate::Result<Vec<IpAddr>> {
77 let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
78 self.state.lock().unwrap();
79 state.ip_addresses.clear();
80 state.ip_addresses.push(IpCidr {
81 ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
82 prefix: 32,
83 });
84 state.ip_addresses.push(IpCidr {
85 ip: IpAddr::V6(Ipv6Addr::LOCALHOST),
86 prefix: 128,
87 });
88 Ok(state.ip_addresses.iter().map(|cidr| cidr.ip).collect())
89 }
90
91 async fn ip_add(&self, ip: IpAddr, prefix: u8) -> crate::Result<()> {
92 let mut state = self.state.lock().unwrap();
93 state.ip_addresses.push(IpCidr { ip, prefix });
94 Ok(())
95 }
96
97 async fn ip_remove(&self, ip: IpAddr) -> crate::Result<()> {
98 let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
99 self.state.lock().unwrap();
100 state.ip_addresses.retain(|cidr| cidr.ip != ip);
101 Ok(())
102 }
103
104 async fn ip_clear(&self) -> crate::Result<()> {
105 let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
106 self.state.lock().unwrap();
107 state.ip_addresses.clear();
108 Ok(())
109 }
110
111 async fn ip_list(&self) -> crate::Result<Vec<IpCidr>> {
112 let state: std::sync::MutexGuard<'_, LoopbackNetworkingState> = self.state.lock().unwrap();
113 Ok(state.ip_addresses.clone())
114 }
115
116 async fn listen_tcp(
117 &self,
118 mut addr: SocketAddr,
119 _only_v6: bool,
120 _reuse_port: bool,
121 _reuse_addr: bool,
122 ) -> crate::Result<Box<dyn VirtualTcpListener + Sync>> {
123 let listener = LoopbackTcpListener::new(addr);
124
125 if addr.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) {
126 addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), addr.port());
127 } else if addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) {
128 addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), addr.port());
129 }
130
131 let mut state = self.state.lock().unwrap();
132 state.tcp_listeners.insert(addr, listener.clone());
133
134 Ok(Box::new(listener))
135 }
136}
137
138#[derive(Debug)]
139struct LoopbackTcpListenerState {
140 handler: Option<Box<dyn InterestHandler + Send + Sync>>,
141 addr_local: SocketAddr,
142 backlog: VecDeque<TcpSocketHalf>,
143 wakers: Vec<Waker>,
144}
145
146#[derive(Debug, Clone)]
147pub struct LoopbackTcpListener {
148 state: Arc<Mutex<LoopbackTcpListenerState>>,
149}
150
151impl LoopbackTcpListener {
152 pub fn new(addr_local: SocketAddr) -> Self {
153 Self {
154 state: Arc::new(Mutex::new(LoopbackTcpListenerState {
155 handler: None,
156 addr_local,
157 backlog: Default::default(),
158 wakers: Default::default(),
159 })),
160 }
161 }
162
163 pub fn connect_to(&self, addr_local: SocketAddr) -> TcpSocketHalf {
164 let mut state = self.state.lock().unwrap();
165 let (half1, half2) =
166 TcpSocketHalf::channel(DEFAULT_MAX_BUFFER_SIZE, state.addr_local, addr_local);
167
168 state.backlog.push_back(half1);
169 if let Some(handler) = state.handler.as_mut() {
170 handler.push_interest(InterestType::Readable);
171 }
172 state.wakers.drain(..).for_each(|w| w.wake());
173
174 half2
175 }
176}
177
178impl VirtualIoSource for LoopbackTcpListener {
179 fn remove_handler(&mut self) {
180 let mut state = self.state.lock().unwrap();
181 state.handler.take();
182 }
183
184 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
185 let mut state = self.state.lock().unwrap();
186 if !state.backlog.is_empty() {
187 return Poll::Ready(Ok(state.backlog.len()));
188 }
189 if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
190 state.wakers.push(cx.waker().clone());
191 }
192 Poll::Pending
193 }
194
195 fn poll_write_ready(&mut self, _cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
196 Poll::Pending
197 }
198}
199
200impl VirtualTcpListener for LoopbackTcpListener {
201 fn try_accept(
202 &mut self,
203 ) -> crate::Result<(Box<dyn crate::VirtualTcpSocket + Sync>, SocketAddr)> {
204 let mut state = self.state.lock().unwrap();
205 let next = state.backlog.pop_front();
206 if let Some(next) = next {
207 let peer = next.addr_peer()?;
208 return Ok((Box::new(next), peer));
209 }
210 Err(NetworkError::WouldBlock)
211 }
212
213 fn set_handler(
214 &mut self,
215 mut handler: Box<dyn InterestHandler + Send + Sync>,
216 ) -> crate::Result<()> {
217 let mut state = self.state.lock().unwrap();
218 if !state.backlog.is_empty() {
219 handler.push_interest(InterestType::Readable);
220 }
221 state.handler.replace(handler);
222 Ok(())
223 }
224
225 fn addr_local(&self) -> crate::Result<SocketAddr> {
226 let state = self.state.lock().unwrap();
227 Ok(state.addr_local)
228 }
229
230 fn set_ttl(&mut self, _ttl: u8) -> crate::Result<()> {
231 Ok(())
232 }
233
234 fn ttl(&self) -> crate::Result<u8> {
235 Ok(64)
236 }
237}