use std::{
net::{IpAddr, SocketAddr},
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
time::Duration,
};
use virtual_net::{
host::LocalNetworking, loopback::LoopbackNetworking, IpCidr, IpRoute, NetworkError,
StreamSecurity, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualTcpListener,
VirtualTcpSocket, VirtualUdpSocket,
};
/// Listener registry shared by clones of [`LocalWithLoopbackNetworking`].
///
/// Records which addresses have been registered as listening, plus the wakers
/// of tasks waiting (in `poll_listening`) for the first such address.
#[derive(Default, Debug)]
struct LocalWithLoopbackNetworkingListening {
    /// Addresses registered via `register_listener`, kept sorted by port.
    addresses: Vec<SocketAddr>,
    /// Wakers of tasks that polled before any address was registered.
    wakers: Vec<Waker>,
}
/// Networking implementation that delegates to an inner networking backend
/// (the host's [`LocalNetworking`] when built with `new`), except for TCP
/// listeners, which are opened on a [`LoopbackNetworking`] instance.
///
/// Clones share the same listener registry because it sits behind an `Arc`.
#[derive(Clone, Debug)]
pub struct LocalWithLoopbackNetworking {
    /// Backend used for every operation other than `listen_tcp`.
    inner_networking: Arc<dyn VirtualNetworking + Send + Sync + 'static>,
    /// Registered listening addresses and the wakers waiting for one.
    local_listening: Arc<Mutex<LocalWithLoopbackNetworkingListening>>,
    /// Network on which `listen_tcp` creates listeners.
    loopback_networking: LoopbackNetworking,
}
impl LocalWithLoopbackNetworking {
pub fn new() -> Self {
lazy_static::lazy_static! {
static ref LOCAL_NETWORKING: Arc<LocalNetworking> = Arc::new(LocalNetworking::default());
}
Self {
local_listening: Default::default(),
inner_networking: LOCAL_NETWORKING.clone(),
loopback_networking: LoopbackNetworking::new(),
}
}
pub fn poll_listening(&self, cx: &mut Context<'_>) -> Poll<SocketAddr> {
let mut listening = self.local_listening.lock().unwrap();
if let Some(addr) = listening.addresses.first() {
return Poll::Ready(*addr);
}
if !listening.wakers.iter().any(|w| w.will_wake(cx.waker())) {
listening.wakers.push(cx.waker().clone());
}
Poll::Pending
}
pub fn register_listener(&self, addr: SocketAddr) {
let mut listening = self.local_listening.lock().unwrap();
listening.addresses.push(addr);
listening.addresses.sort_by_key(|a| a.port());
listening.wakers.drain(..).for_each(|w| w.wake());
}
pub fn loopback_networking(&self) -> LoopbackNetworking {
self.loopback_networking.clone()
}
}
/// Every operation is forwarded unchanged to `inner_networking`, except
/// `listen_tcp`, which opens the listener on `loopback_networking`.
#[async_trait::async_trait]
impl VirtualNetworking for LocalWithLoopbackNetworking {
    /// Delegates to the inner networking backend.
    async fn bridge(
        &self,
        network: &str,
        access_token: &str,
        security: StreamSecurity,
    ) -> Result<(), NetworkError> {
        self.inner_networking
            .bridge(network, access_token, security)
            .await
    }

    /// Delegates to the inner networking backend.
    async fn unbridge(&self) -> Result<(), NetworkError> {
        self.inner_networking.unbridge().await
    }

    /// Delegates to the inner networking backend.
    async fn dhcp_acquire(&self) -> Result<Vec<IpAddr>, NetworkError> {
        self.inner_networking.dhcp_acquire().await
    }

    /// Delegates to the inner networking backend.
    async fn ip_add(&self, ip: IpAddr, prefix: u8) -> Result<(), NetworkError> {
        self.inner_networking.ip_add(ip, prefix).await
    }

    /// Delegates to the inner networking backend.
    async fn ip_remove(&self, ip: IpAddr) -> Result<(), NetworkError> {
        self.inner_networking.ip_remove(ip).await
    }

    /// Delegates to the inner networking backend.
    async fn ip_clear(&self) -> Result<(), NetworkError> {
        self.inner_networking.ip_clear().await
    }

    /// Delegates to the inner networking backend.
    async fn ip_list(&self) -> Result<Vec<IpCidr>, NetworkError> {
        self.inner_networking.ip_list().await
    }

    /// Delegates to the inner networking backend.
    async fn mac(&self) -> Result<[u8; 6], NetworkError> {
        self.inner_networking.mac().await
    }

    /// Delegates to the inner networking backend.
    async fn gateway_set(&self, ip: IpAddr) -> Result<(), NetworkError> {
        self.inner_networking.gateway_set(ip).await
    }

    /// Delegates to the inner networking backend.
    async fn route_add(
        &self,
        cidr: IpCidr,
        via_router: IpAddr,
        preferred_until: Option<Duration>,
        expires_at: Option<Duration>,
    ) -> Result<(), NetworkError> {
        self.inner_networking
            .route_add(cidr, via_router, preferred_until, expires_at)
            .await
    }

    /// Delegates to the inner networking backend.
    async fn route_remove(&self, cidr: IpAddr) -> Result<(), NetworkError> {
        self.inner_networking.route_remove(cidr).await
    }

    /// Delegates to the inner networking backend.
    async fn route_clear(&self) -> Result<(), NetworkError> {
        self.inner_networking.route_clear().await
    }

    /// Delegates to the inner networking backend.
    async fn route_list(&self) -> Result<Vec<IpRoute>, NetworkError> {
        self.inner_networking.route_list().await
    }

    /// Delegates to the inner networking backend.
    async fn bind_raw(&self) -> Result<Box<dyn VirtualRawSocket + Sync>, NetworkError> {
        self.inner_networking.bind_raw().await
    }

    /// Opens a TCP listener on the loopback network, not on the inner backend.
    ///
    /// On success, `addr` is recorded via `register_listener`, which wakes any
    /// task waiting in `poll_listening`.
    ///
    /// NOTE(review): `addr` is registered exactly as passed in. If a caller
    /// passes port 0, the recorded port may differ from the one actually bound.
    /// Confirm how `LoopbackNetworking` handles ephemeral ports.
    async fn listen_tcp(
        &self,
        addr: SocketAddr,
        only_v6: bool,
        reuse_port: bool,
        reuse_addr: bool,
    ) -> Result<Box<dyn VirtualTcpListener + Sync>, NetworkError> {
        tracing::debug!(%addr, "opening listener on loopback networking");
        let ret: Result<Box<dyn VirtualTcpListener + Sync>, NetworkError> = self
            .loopback_networking
            .listen_tcp(addr, only_v6, reuse_port, reuse_addr)
            .await;
        if ret.is_ok() {
            tracing::debug!(%addr, "registering listener on loopback networking");
            self.register_listener(addr);
        }
        ret
    }

    /// Delegates to the inner networking backend.
    async fn bind_udp(
        &self,
        addr: SocketAddr,
        reuse_port: bool,
        reuse_addr: bool,
    ) -> Result<Box<dyn VirtualUdpSocket + Sync>, NetworkError> {
        self.inner_networking
            .bind_udp(addr, reuse_port, reuse_addr)
            .await
    }

    /// Delegates to the inner networking backend.
    async fn bind_icmp(
        &self,
        addr: IpAddr,
    ) -> Result<Box<dyn VirtualIcmpSocket + Sync>, NetworkError> {
        self.inner_networking.bind_icmp(addr).await
    }

    /// Delegates to the inner networking backend. Outgoing TCP connections do
    /// not go through the loopback network.
    async fn connect_tcp(
        &self,
        addr: SocketAddr,
        peer: SocketAddr,
    ) -> Result<Box<dyn VirtualTcpSocket + Sync>, NetworkError> {
        self.inner_networking.connect_tcp(addr, peer).await
    }

    /// Delegates to the inner networking backend.
    async fn resolve(
        &self,
        host: &str,
        port: Option<u16>,
        dns_server: Option<IpAddr>,
    ) -> Result<Vec<IpAddr>, NetworkError> {
        self.inner_networking.resolve(host, port, dns_server).await
    }
}