// virtual_net/loopback.rs

1use std::collections::{HashSet, VecDeque};
2use std::net::SocketAddr;
3use std::sync::Mutex;
4use std::task::{Context, Poll, Waker};
5use std::{collections::HashMap, sync::Arc};
6
7use crate::tcp_pair::TcpSocketHalf;
8use crate::{
9    InterestHandler, IpAddr, IpCidr, Ipv4Addr, Ipv6Addr, NetworkError, VirtualConnectedSocket,
10    VirtualIoSource, VirtualNetworking, VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener,
11    VirtualTcpSocket,
12};
13use virtual_mio::InterestType;
14
/// Buffer size, in bytes, passed to `TcpSocketHalf::channel` for every loopback
/// connection (1 MiB).
const DEFAULT_MAX_BUFFER_SIZE: usize = 1024 * 1024;
/// First port tried when a socket binds to port 0; the search wraps from
/// `u16::MAX` back to this value (range 49152..=65535).
const LOOPBACK_EPHEMERAL_PORT_START: u16 = 0xC000;
17
18#[derive(Debug)]
19struct LoopbackNetworkingState {
20    tcp_listeners: HashMap<SocketAddr, LoopbackTcpListener>,
21    tcp_bound: HashSet<SocketAddr>,
22    ip_addresses: Vec<IpCidr>,
23    next_ephemeral_port: u16,
24}
25
26impl Default for LoopbackNetworkingState {
27    fn default() -> Self {
28        Self {
29            tcp_listeners: HashMap::new(),
30            tcp_bound: HashSet::new(),
31            ip_addresses: Vec::new(),
32            next_ephemeral_port: LOOPBACK_EPHEMERAL_PORT_START,
33        }
34    }
35}
36
37#[derive(Debug, Clone)]
38pub struct LoopbackNetworking {
39    state: Arc<Mutex<LoopbackNetworkingState>>,
40}
41
42impl LoopbackNetworking {
43    pub fn new() -> Self {
44        LoopbackNetworking {
45            state: Arc::new(Mutex::new(Default::default())),
46        }
47    }
48
49    pub fn loopback_connect_to(
50        &self,
51        mut local_addr: SocketAddr,
52        peer_addr: SocketAddr,
53    ) -> Option<TcpSocketHalf> {
54        let mut port = local_addr.port();
55        if port == 0 {
56            port = peer_addr.port();
57        }
58
59        local_addr = match local_addr.ip() {
60            IpAddr::V4(Ipv4Addr::UNSPECIFIED) => {
61                SocketAddr::new(Ipv4Addr::new(127, 0, 0, 100).into(), port)
62            }
63            IpAddr::V6(Ipv6Addr::UNSPECIFIED) => {
64                SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 100).into(), port)
65            }
66            ip => SocketAddr::new(ip, port),
67        };
68
69        let peer_key = Self::normalize_listener_addr(peer_addr);
70        let state = self.state.lock().unwrap();
71        if let Some(listener) = state.tcp_listeners.get(&peer_key) {
72            Some(listener.connect_to(local_addr))
73        } else {
74            state
75                .tcp_listeners
76                .iter()
77                .next()
78                .map(|listener| listener.1.connect_to(local_addr))
79        }
80    }
81
82    fn allocate_tcp_bind_addr(
83        state: &mut LoopbackNetworkingState,
84        mut addr: SocketAddr,
85    ) -> crate::Result<SocketAddr> {
86        let is_available = |candidate: SocketAddr, state: &LoopbackNetworkingState| {
87            let key = Self::normalize_listener_addr(candidate);
88            !state.tcp_listeners.contains_key(&key) && !state.tcp_bound.contains(&key)
89        };
90
91        if addr.port() == 0 {
92            let start = state.next_ephemeral_port;
93            let mut candidate = start;
94            loop {
95                let candidate_addr = SocketAddr::new(addr.ip(), candidate);
96                if is_available(candidate_addr, state) {
97                    addr.set_port(candidate);
98                    let normalized = Self::normalize_listener_addr(addr);
99                    state.tcp_bound.insert(normalized);
100                    state.next_ephemeral_port = if candidate == u16::MAX {
101                        LOOPBACK_EPHEMERAL_PORT_START
102                    } else {
103                        candidate + 1
104                    };
105                    return Ok(normalized);
106                }
107
108                candidate = if candidate == u16::MAX {
109                    LOOPBACK_EPHEMERAL_PORT_START
110                } else {
111                    candidate + 1
112                };
113                if candidate == start {
114                    return Err(NetworkError::AddressInUse);
115                }
116            }
117        }
118
119        let reservation_key = Self::normalize_listener_addr(addr);
120        if state.tcp_listeners.contains_key(&reservation_key)
121            || state.tcp_bound.contains(&reservation_key)
122        {
123            return Err(NetworkError::AddressInUse);
124        }
125        state.tcp_bound.insert(reservation_key);
126        Ok(reservation_key)
127    }
128
129    fn normalize_listener_addr(mut addr: SocketAddr) -> SocketAddr {
130        if addr.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) {
131            addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), addr.port());
132        } else if addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) {
133            addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), addr.port());
134        }
135        addr
136    }
137}
138
139impl Default for LoopbackNetworking {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145#[allow(unused_variables)]
146#[async_trait::async_trait]
147impl VirtualNetworking for LoopbackNetworking {
148    async fn dhcp_acquire(&self) -> crate::Result<Vec<IpAddr>> {
149        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
150            self.state.lock().unwrap();
151        state.ip_addresses.clear();
152        state.ip_addresses.push(IpCidr {
153            ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
154            prefix: 32,
155        });
156        state.ip_addresses.push(IpCidr {
157            ip: IpAddr::V6(Ipv6Addr::LOCALHOST),
158            prefix: 128,
159        });
160        Ok(state.ip_addresses.iter().map(|cidr| cidr.ip).collect())
161    }
162
163    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> crate::Result<()> {
164        let mut state = self.state.lock().unwrap();
165        state.ip_addresses.push(IpCidr { ip, prefix });
166        Ok(())
167    }
168
169    async fn ip_remove(&self, ip: IpAddr) -> crate::Result<()> {
170        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
171            self.state.lock().unwrap();
172        state.ip_addresses.retain(|cidr| cidr.ip != ip);
173        Ok(())
174    }
175
176    async fn ip_clear(&self) -> crate::Result<()> {
177        let mut state: std::sync::MutexGuard<'_, LoopbackNetworkingState> =
178            self.state.lock().unwrap();
179        state.ip_addresses.clear();
180        Ok(())
181    }
182
183    async fn ip_list(&self) -> crate::Result<Vec<IpCidr>> {
184        let state: std::sync::MutexGuard<'_, LoopbackNetworkingState> = self.state.lock().unwrap();
185        Ok(state.ip_addresses.clone())
186    }
187
188    async fn listen_tcp(
189        &self,
190        addr: SocketAddr,
191        only_v6: bool,
192        reuse_port: bool,
193        reuse_addr: bool,
194    ) -> crate::Result<Box<dyn VirtualTcpListener + Sync>> {
195        self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
196            .await?
197            .listen()
198    }
199
200    async fn bind_tcp(
201        &self,
202        addr: SocketAddr,
203        only_v6: bool,
204        reuse_port: bool,
205        reuse_addr: bool,
206    ) -> crate::Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
207        let _ = (only_v6, reuse_port, reuse_addr);
208        let mut state = self.state.lock().unwrap();
209        let addr = Self::allocate_tcp_bind_addr(&mut state, addr)?;
210        Ok(Box::new(LoopbackTcpBoundSocket {
211            networking: self.clone(),
212            local_addr: addr,
213            reservation_key: Some(addr),
214            ttl: 64,
215        }))
216    }
217}
218
#[cfg(all(test, feature = "tokio"))]
impl LoopbackNetworking {
    /// Test helper with two effects:
    /// - registers a dummy listener (TTL 64) on every port of
    ///   `LOOPBACK_EPHEMERAL_PORT_START..=u16::MAX` for `ip`;
    /// - rewinds the ephemeral cursor to the start of that range.
    ///
    /// For a concrete (non-wildcard) `ip`, the next port-0 bind on that IP
    /// cannot find a free port.
    pub(crate) fn exhaust_tcp_ephemeral_ports_for_test(&self, ip: IpAddr) {
        let mut guard = self.state.lock().unwrap();
        let dummy_listeners = (LOOPBACK_EPHEMERAL_PORT_START..=u16::MAX)
            .map(|port| SocketAddr::new(ip, port))
            .map(|addr| (addr, LoopbackTcpListener::new(addr, 64)));
        guard.tcp_listeners.extend(dummy_listeners);
        guard.next_ephemeral_port = LOOPBACK_EPHEMERAL_PORT_START;
    }
}
232
233/// A connected TCP socket that keeps its local-port reservation in
234/// `LoopbackNetworkingState::tcp_bound` until it is explicitly closed or
235/// dropped, matching POSIX/Linux semantics where a connected socket holds
236/// its local port for its entire lifetime.
237#[derive(Debug)]
238struct LoopbackConnectedSocket {
239    inner: TcpSocketHalf,
240    networking: LoopbackNetworking,
241    /// `None` once the reservation has been released (after `close()` or `drop`).
242    reservation_key: Option<SocketAddr>,
243}
244
245impl LoopbackConnectedSocket {
246    fn release_reservation(&mut self) {
247        if let Some(key) = self.reservation_key.take() {
248            self.networking.state.lock().unwrap().tcp_bound.remove(&key);
249        }
250    }
251}
252
253impl Drop for LoopbackConnectedSocket {
254    fn drop(&mut self) {
255        self.release_reservation();
256    }
257}
258
259impl VirtualIoSource for LoopbackConnectedSocket {
260    fn remove_handler(&mut self) {
261        self.inner.remove_handler();
262    }
263
264    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
265        self.inner.poll_read_ready(cx)
266    }
267
268    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
269        self.inner.poll_write_ready(cx)
270    }
271}
272
273impl VirtualSocket for LoopbackConnectedSocket {
274    fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> {
275        self.inner.set_ttl(ttl)
276    }
277
278    fn ttl(&self) -> crate::Result<u32> {
279        self.inner.ttl()
280    }
281
282    fn addr_local(&self) -> crate::Result<SocketAddr> {
283        self.inner.addr_local()
284    }
285
286    fn status(&self) -> crate::Result<crate::SocketStatus> {
287        self.inner.status()
288    }
289
290    fn set_handler(
291        &mut self,
292        handler: Box<dyn InterestHandler + Send + Sync>,
293    ) -> crate::Result<()> {
294        self.inner.set_handler(handler)
295    }
296}
297
298impl VirtualConnectedSocket for LoopbackConnectedSocket {
299    fn set_linger(&mut self, linger: Option<std::time::Duration>) -> crate::Result<()> {
300        self.inner.set_linger(linger)
301    }
302
303    fn linger(&self) -> crate::Result<Option<std::time::Duration>> {
304        self.inner.linger()
305    }
306
307    fn try_send(&mut self, data: &[u8]) -> crate::Result<usize> {
308        self.inner.try_send(data)
309    }
310
311    fn try_flush(&mut self) -> crate::Result<()> {
312        self.inner.try_flush()
313    }
314
315    fn close(&mut self) -> crate::Result<()> {
316        self.release_reservation();
317        self.inner.close()
318    }
319
320    fn try_recv(
321        &mut self,
322        buf: &mut [std::mem::MaybeUninit<u8>],
323        peek: bool,
324    ) -> crate::Result<usize> {
325        self.inner.try_recv(buf, peek)
326    }
327}
328
329impl VirtualTcpSocket for LoopbackConnectedSocket {
330    fn set_recv_buf_size(&mut self, size: usize) -> crate::Result<()> {
331        self.inner.set_recv_buf_size(size)
332    }
333
334    fn recv_buf_size(&self) -> crate::Result<usize> {
335        self.inner.recv_buf_size()
336    }
337
338    fn set_send_buf_size(&mut self, size: usize) -> crate::Result<()> {
339        self.inner.set_send_buf_size(size)
340    }
341
342    fn send_buf_size(&self) -> crate::Result<usize> {
343        self.inner.send_buf_size()
344    }
345
346    fn set_nodelay(&mut self, reuse: bool) -> crate::Result<()> {
347        self.inner.set_nodelay(reuse)
348    }
349
350    fn nodelay(&self) -> crate::Result<bool> {
351        self.inner.nodelay()
352    }
353
354    fn set_keepalive(&mut self, keepalive: bool) -> crate::Result<()> {
355        self.inner.set_keepalive(keepalive)
356    }
357
358    fn keepalive(&self) -> crate::Result<bool> {
359        self.inner.keepalive()
360    }
361
362    fn set_dontroute(&mut self, dontroute: bool) -> crate::Result<()> {
363        self.inner.set_dontroute(dontroute)
364    }
365
366    fn dontroute(&self) -> crate::Result<bool> {
367        self.inner.dontroute()
368    }
369
370    fn addr_peer(&self) -> crate::Result<SocketAddr> {
371        self.inner.addr_peer()
372    }
373
374    fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> {
375        self.inner.shutdown(how)
376    }
377
378    fn is_closed(&self) -> bool {
379        self.inner.is_closed()
380    }
381}
382
383#[derive(Debug)]
384struct LoopbackTcpListenerState {
385    handler: Option<Box<dyn InterestHandler + Send + Sync>>,
386    addr_local: SocketAddr,
387    ttl: u8,
388    backlog: VecDeque<TcpSocketHalf>,
389    wakers: Vec<Waker>,
390}
391
392#[derive(Debug, Clone)]
393pub struct LoopbackTcpListener {
394    state: Arc<Mutex<LoopbackTcpListenerState>>,
395}
396
397impl LoopbackTcpListener {
398    pub fn new(addr_local: SocketAddr, ttl: u8) -> Self {
399        Self {
400            state: Arc::new(Mutex::new(LoopbackTcpListenerState {
401                handler: None,
402                addr_local,
403                ttl,
404                backlog: Default::default(),
405                wakers: Default::default(),
406            })),
407        }
408    }
409
410    pub fn connect_to(&self, addr_local: SocketAddr) -> TcpSocketHalf {
411        let mut state = self.state.lock().unwrap();
412        let (mut half1, half2) =
413            TcpSocketHalf::channel(DEFAULT_MAX_BUFFER_SIZE, state.addr_local, addr_local);
414        half1.set_ttl(u32::from(state.ttl)).ok();
415
416        state.backlog.push_back(half1);
417        if let Some(handler) = state.handler.as_mut() {
418            handler.push_interest(InterestType::Readable);
419        }
420        state.wakers.drain(..).for_each(|w| w.wake());
421
422        half2
423    }
424}
425
426impl VirtualIoSource for LoopbackTcpListener {
427    fn remove_handler(&mut self) {
428        let mut state = self.state.lock().unwrap();
429        state.handler.take();
430    }
431
432    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
433        let mut state = self.state.lock().unwrap();
434        if !state.backlog.is_empty() {
435            return Poll::Ready(Ok(state.backlog.len()));
436        }
437        if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
438            state.wakers.push(cx.waker().clone());
439        }
440        Poll::Pending
441    }
442
443    fn poll_write_ready(&mut self, _cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
444        Poll::Pending
445    }
446}
447
448impl VirtualTcpListener for LoopbackTcpListener {
449    fn try_accept(
450        &mut self,
451    ) -> crate::Result<(Box<dyn crate::VirtualTcpSocket + Sync>, SocketAddr)> {
452        let mut state = self.state.lock().unwrap();
453        let next = state.backlog.pop_front();
454        if let Some(next) = next {
455            let peer = next.addr_peer()?;
456            return Ok((Box::new(next), peer));
457        }
458        Err(NetworkError::WouldBlock)
459    }
460
461    fn set_handler(
462        &mut self,
463        mut handler: Box<dyn InterestHandler + Send + Sync>,
464    ) -> crate::Result<()> {
465        let mut state = self.state.lock().unwrap();
466        if !state.backlog.is_empty() {
467            handler.push_interest(InterestType::Readable);
468        }
469        state.handler.replace(handler);
470        Ok(())
471    }
472
473    fn addr_local(&self) -> crate::Result<SocketAddr> {
474        let state = self.state.lock().unwrap();
475        Ok(state.addr_local)
476    }
477
478    fn set_ttl(&mut self, ttl: u8) -> crate::Result<()> {
479        let mut state = self.state.lock().unwrap();
480        state.ttl = ttl;
481        Ok(())
482    }
483
484    fn ttl(&self) -> crate::Result<u8> {
485        let state = self.state.lock().unwrap();
486        Ok(state.ttl)
487    }
488}
489
490#[derive(Debug)]
491pub struct LoopbackTcpBoundSocket {
492    networking: LoopbackNetworking,
493    local_addr: SocketAddr,
494    reservation_key: Option<SocketAddr>,
495    ttl: u32,
496}
497
498impl Drop for LoopbackTcpBoundSocket {
499    fn drop(&mut self) {
500        if let Some(reservation_key) = self.reservation_key.take() {
501            let mut state = self.networking.state.lock().unwrap();
502            state.tcp_bound.remove(&reservation_key);
503        }
504    }
505}
506
507impl VirtualTcpBoundSocket for LoopbackTcpBoundSocket {
508    fn addr_local(&self) -> crate::Result<SocketAddr> {
509        Ok(self.local_addr)
510    }
511
512    fn listen(&mut self) -> crate::Result<Box<dyn VirtualTcpListener + Sync>> {
513        let listener =
514            LoopbackTcpListener::new(self.local_addr, u8::try_from(self.ttl).unwrap_or(u8::MAX));
515        let mut state = self.networking.state.lock().unwrap();
516        let reservation_key = self.reservation_key.ok_or(NetworkError::InvalidFd)?;
517        if !state.tcp_bound.remove(&reservation_key) {
518            return Err(NetworkError::InvalidFd);
519        }
520        if state.tcp_listeners.contains_key(&reservation_key) {
521            state.tcp_bound.insert(reservation_key);
522            return Err(NetworkError::AddressInUse);
523        }
524        state
525            .tcp_listeners
526            .insert(reservation_key, listener.clone());
527        self.reservation_key = None;
528        Ok(Box::new(listener))
529    }
530
531    fn connect(&mut self, peer: SocketAddr) -> crate::Result<Box<dyn VirtualTcpSocket + Sync>> {
532        let mut socket = self
533            .networking
534            .loopback_connect_to(self.local_addr, peer)
535            .ok_or(NetworkError::ConnectionRefused)?;
536        // Transfer the port reservation to the connected socket so that the
537        // local port stays in `tcp_bound` for the socket's entire lifetime,
538        // matching POSIX/Linux semantics (a connected socket holds its local
539        // port; rebinding it returns EADDRINUSE).
540        let reservation_key = self.reservation_key.take().ok_or(NetworkError::InvalidFd)?;
541        socket.set_ttl(self.ttl)?;
542        Ok(Box::new(LoopbackConnectedSocket {
543            inner: socket,
544            networking: self.networking.clone(),
545            reservation_key: Some(reservation_key),
546        }))
547    }
548
549    fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> {
550        self.ttl = ttl;
551        Ok(())
552    }
553
554    fn ttl(&self) -> crate::Result<u32> {
555        Ok(self.ttl)
556    }
557}