virtual_net/
tcp_pair.rs

1use crate::{
2    InterestHandler, NetworkError, SocketStatus, VirtualConnectedSocket, VirtualIoSource,
3    VirtualSocket, VirtualTcpSocket, net_error_into_io_err,
4};
5use bytes::{Buf, Bytes};
6use futures_util::Future;
7use smoltcp::storage::RingBuffer;
8use std::io::{self};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::sync::Mutex;
12use std::task::{Context, Waker};
13use std::time::Duration;
14use std::{net::SocketAddr, task::Poll};
15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader};
16use virtual_mio::{ArcInterestHandler, InterestType};
17
/// Lifecycle of one direction of an in-memory TCP pipe.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Open: data can be written and read.
    Alive,
    /// Torn down abruptly (set when a handle with `dead_on_drop` is dropped);
    /// readiness polls report `ConnectionReset`.
    Dead,
    /// Closed via `close`, or by dropping a handle without `dead_on_drop`.
    Closed,
    /// Shut down via `shutdown` / `poll_shutdown`.
    Shutdown,
}
25
26#[derive(Debug)]
27struct SocketBufferState {
28    buffer: RingBuffer<'static, u8>,
29    push_handler: Option<ArcInterestHandler>,
30    pull_handler: Option<ArcInterestHandler>,
31    wakers: Vec<Waker>,
32    state: State,
33    // This flag prevents a poll write ready storm
34    halt_immediate_poll_write: bool,
35}
36
37impl SocketBufferState {
38    fn add_waker(&mut self, waker: &Waker) {
39        if !self.wakers.iter().any(|w| w.will_wake(waker)) {
40            self.wakers.push(waker.clone());
41        }
42    }
43}
44
45#[derive(Debug, Clone)]
46pub(crate) struct SocketBuffer {
47    state: Arc<Mutex<SocketBufferState>>,
48    dead_on_drop: bool,
49}
50
51impl Drop for SocketBuffer {
52    fn drop(&mut self) {
53        if self.state() == State::Alive {
54            if self.dead_on_drop {
55                self.set_state(State::Dead);
56            } else {
57                self.set_state(State::Closed);
58            }
59        }
60    }
61}
62
63impl SocketBuffer {
64    fn new(max_size: usize) -> Self {
65        Self {
66            state: Arc::new(Mutex::new(SocketBufferState {
67                buffer: RingBuffer::new(vec![0; max_size]),
68                push_handler: None,
69                pull_handler: None,
70                wakers: Vec::new(),
71                state: State::Alive,
72                halt_immediate_poll_write: false,
73            })),
74            dead_on_drop: false,
75        }
76    }
77
78    pub fn set_push_handler(&self, mut handler: ArcInterestHandler) {
79        let mut state = self.state.lock().unwrap();
80        if state.state != State::Alive {
81            handler.push_interest(InterestType::Closed);
82        }
83        if !state.buffer.is_empty() {
84            handler.push_interest(InterestType::Readable);
85        }
86        state.push_handler.replace(handler);
87    }
88
89    pub fn set_pull_handler(&self, mut handler: ArcInterestHandler) {
90        let mut state = self.state.lock().unwrap();
91        if state.state != State::Alive {
92            handler.push_interest(InterestType::Closed);
93        }
94        if !state.buffer.is_full() && state.pull_handler.is_none() {
95            handler.push_interest(InterestType::Writable);
96        }
97        state.pull_handler.replace(handler);
98    }
99
100    pub fn clear_push_handler(&self) {
101        let mut state = self.state.lock().unwrap();
102        state.push_handler.take();
103    }
104
105    pub fn clear_pull_handler(&self) {
106        let mut state = self.state.lock().unwrap();
107        state.pull_handler.take();
108    }
109
110    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
111        let mut state = self.state.lock().unwrap();
112        if !state.buffer.is_empty() {
113            return Poll::Ready(Ok(state.buffer.len()));
114        }
115        match state.state {
116            State::Alive => {
117                if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
118                    state.wakers.push(cx.waker().clone());
119                }
120                Poll::Pending
121            }
122            State::Dead => {
123                tracing::trace!("poll_read_ready: socket is dead");
124                Poll::Ready(Err(NetworkError::ConnectionReset))
125            }
126            State::Closed | State::Shutdown => {
127                tracing::trace!("poll_read_ready: socket is closed or shutdown");
128                Poll::Ready(Ok(0))
129            }
130        }
131    }
132
133    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
134        let mut state = self.state.lock().unwrap();
135        match state.state {
136            State::Alive => {
137                if !state.buffer.is_full() && !state.halt_immediate_poll_write {
138                    state.halt_immediate_poll_write = true;
139                    return Poll::Ready(Ok(state.buffer.window()));
140                }
141                if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
142                    state.wakers.push(cx.waker().clone());
143                }
144                Poll::Pending
145            }
146            State::Dead => {
147                tracing::trace!("poll_write_ready: socket is dead");
148                Poll::Ready(Err(NetworkError::ConnectionReset))
149            }
150            State::Closed | State::Shutdown => {
151                tracing::trace!("poll_write_ready: socket is closed or shutdown");
152                Poll::Ready(Ok(0))
153            }
154        }
155    }
156
157    fn set_state(&self, new_state: State) {
158        let mut state = self.state.lock().unwrap();
159        state.state = new_state;
160        if let Some(handler) = state.pull_handler.as_mut() {
161            handler.push_interest(InterestType::Closed);
162        }
163        if let Some(handler) = state.push_handler.as_mut() {
164            handler.push_interest(InterestType::Closed);
165        }
166        state.wakers.drain(..).for_each(|w| w.wake());
167    }
168
169    fn state(&self) -> State {
170        let state = self.state.lock().unwrap();
171        state.state
172    }
173
174    pub fn try_send(
175        &self,
176        data: &[u8],
177        all_or_nothing: bool,
178        waker: Option<&Waker>,
179    ) -> crate::Result<usize> {
180        let mut state = self.state.lock().unwrap();
181        if state.state != State::Alive {
182            return Err(NetworkError::ConnectionReset);
183        }
184        state.halt_immediate_poll_write = false;
185        let available = state.buffer.window();
186        if available == 0 {
187            if let Some(waker) = waker {
188                state.add_waker(waker)
189            }
190            return Err(NetworkError::WouldBlock);
191        }
192        if data.len() > available {
193            if all_or_nothing {
194                if let Some(waker) = waker {
195                    state.add_waker(waker)
196                }
197                return Err(NetworkError::WouldBlock);
198            }
199            let amt = state.buffer.enqueue_slice(&data[..available]);
200            return Ok(amt);
201        }
202        let amt = state.buffer.enqueue_slice(data);
203
204        if let Some(handler) = state.push_handler.as_mut() {
205            handler.push_interest(InterestType::Readable);
206        }
207        state.wakers.drain(..).for_each(|w| w.wake());
208        Ok(amt)
209    }
210
211    pub async fn send(&self, data: Bytes) -> crate::Result<()> {
212        struct Poller<'a> {
213            this: &'a SocketBuffer,
214            data: Bytes,
215        }
216        impl Future for Poller<'_> {
217            type Output = crate::Result<()>;
218            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219                loop {
220                    if self.data.is_empty() {
221                        return Poll::Ready(Ok(()));
222                    }
223                    return match self.this.try_send(&self.data, false, Some(cx.waker())) {
224                        Ok(amt) => {
225                            self.data.advance(amt);
226                            continue;
227                        }
228                        Err(NetworkError::WouldBlock) => Poll::Pending,
229                        Err(err) => Poll::Ready(Err(err)),
230                    };
231                }
232            }
233        }
234        Poller { this: self, data }.await
235    }
236
237    pub fn try_read(
238        &self,
239        buf: &mut [std::mem::MaybeUninit<u8>],
240        peek: bool,
241        waker: Option<&Waker>,
242    ) -> crate::Result<usize> {
243        let mut state = self.state.lock().unwrap();
244        if state.buffer.is_empty() {
245            return match state.state {
246                State::Alive => {
247                    if let Some(waker) = waker {
248                        state.add_waker(waker)
249                    }
250                    Err(NetworkError::WouldBlock)
251                }
252                State::Dead => {
253                    tracing::trace!("try_read: socket is dead");
254                    // Note: Returning `ConnectionReset` here may seem logical as other functions return this
255                    // however this code path is not always handled properly. In particular `tokio` inside
256                    // WASIX will panic if it receives this code.
257                    return Ok(0);
258                }
259                State::Closed | State::Shutdown => {
260                    tracing::trace!("try_read: socket is closed or shutdown");
261                    return Ok(0);
262                }
263            };
264        }
265
266        let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
267        let amt = buf.len().min(state.buffer.len());
268        let amt = if peek {
269            state.buffer.read_allocated(0, &mut buf[..amt])
270        } else {
271            state.buffer.dequeue_slice(&mut buf[..amt])
272        };
273
274        if let Some(handler) = state.pull_handler.as_mut() {
275            handler.push_interest(InterestType::Writable);
276        }
277        state.wakers.drain(..).for_each(|w| w.wake());
278        Ok(amt)
279    }
280
281    pub fn set_max_size(&self, new_size: usize) {
282        let mut state = self.state.lock().unwrap();
283        state.halt_immediate_poll_write = false;
284
285        let mut existing: Vec<u8> = vec![0; state.buffer.len()];
286        if !state.buffer.is_empty() {
287            let amt = state.buffer.dequeue_slice(&mut existing[..]);
288            existing.resize(amt, 0);
289        }
290
291        state.buffer = RingBuffer::new(vec![0; new_size]);
292        if !existing.is_empty() {
293            let _ = state.buffer.enqueue_slice(&existing[..]);
294        }
295    }
296
297    pub fn max_size(&self) -> usize {
298        let state = self.state.lock().unwrap();
299        state.buffer.capacity()
300    }
301}
302
303impl AsyncWrite for SocketBuffer {
304    fn poll_write(
305        self: Pin<&mut Self>,
306        cx: &mut Context<'_>,
307        buf: &[u8],
308    ) -> Poll<Result<usize, io::Error>> {
309        match self.try_send(buf, false, Some(cx.waker())) {
310            Ok(amt) => Poll::Ready(Ok(amt)),
311            Err(NetworkError::WouldBlock) => Poll::Pending,
312            Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
313        }
314    }
315
316    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
317        Poll::Ready(Ok(()))
318    }
319
320    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
321        self.set_state(State::Shutdown);
322        Poll::Ready(Ok(()))
323    }
324}
325
326impl AsyncRead for SocketBuffer {
327    fn poll_read(
328        self: Pin<&mut Self>,
329        cx: &mut Context<'_>,
330        buf: &mut tokio::io::ReadBuf<'_>,
331    ) -> Poll<io::Result<()>> {
332        match self.try_read(unsafe { buf.unfilled_mut() }, false, Some(cx.waker())) {
333            Ok(amt) => {
334                unsafe { buf.assume_init(amt) };
335                buf.advance(amt);
336                Poll::Ready(Ok(()))
337            }
338            Err(NetworkError::WouldBlock) => Poll::Pending,
339            Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
340        }
341    }
342}
343
344#[derive(Debug)]
345pub struct TcpSocketHalf {
346    addr_local: SocketAddr,
347    addr_peer: SocketAddr,
348    tx: SocketBuffer,
349    rx: SocketBuffer,
350    ttl: u32,
351}
352
353impl TcpSocketHalf {
354    pub fn channel(
355        max_buffer_size: usize,
356        addr1: SocketAddr,
357        addr2: SocketAddr,
358    ) -> (TcpSocketHalf, TcpSocketHalf) {
359        let mut buffer1 = SocketBuffer::new(max_buffer_size);
360        buffer1.dead_on_drop = true;
361
362        let mut buffer2 = SocketBuffer::new(max_buffer_size);
363        buffer2.dead_on_drop = true;
364
365        let half1 = Self {
366            tx: buffer1.clone(),
367            rx: buffer2.clone(),
368            addr_local: addr1,
369            addr_peer: addr2,
370            ttl: 64,
371        };
372        let half2 = Self {
373            tx: buffer2,
374            rx: buffer1,
375            addr_local: addr2,
376            addr_peer: addr1,
377            ttl: 64,
378        };
379        (half1, half2)
380    }
381
382    pub fn is_active(&self) -> bool {
383        self.tx.state() == State::Alive
384    }
385
386    pub fn close(&self) -> crate::Result<()> {
387        self.tx.set_state(State::Closed);
388        self.rx.set_state(State::Closed);
389        Ok(())
390    }
391}
392
393impl VirtualIoSource for TcpSocketHalf {
394    fn remove_handler(&mut self) {
395        self.tx.clear_pull_handler();
396        self.rx.clear_push_handler();
397    }
398
399    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
400        self.rx.poll_read_ready(cx)
401    }
402
403    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
404        self.tx.poll_write_ready(cx)
405    }
406}
407
408impl VirtualSocket for TcpSocketHalf {
409    fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> {
410        self.ttl = ttl;
411        Ok(())
412    }
413
414    fn ttl(&self) -> crate::Result<u32> {
415        Ok(self.ttl)
416    }
417
418    fn addr_local(&self) -> crate::Result<SocketAddr> {
419        Ok(self.addr_local)
420    }
421
422    fn status(&self) -> crate::Result<SocketStatus> {
423        Ok(match self.tx.state() {
424            State::Alive => SocketStatus::Opened,
425            State::Dead => SocketStatus::Failed,
426            State::Closed => SocketStatus::Closed,
427            State::Shutdown => SocketStatus::Closed,
428        })
429    }
430
431    fn set_handler(
432        &mut self,
433        handler: Box<dyn InterestHandler + Send + Sync>,
434    ) -> crate::Result<()> {
435        let handler = ArcInterestHandler::new(handler);
436        self.tx.set_pull_handler(handler.clone());
437        self.rx.set_push_handler(handler);
438        Ok(())
439    }
440}
441
442impl VirtualConnectedSocket for TcpSocketHalf {
443    fn set_linger(&mut self, _linger: Option<Duration>) -> crate::Result<()> {
444        Ok(())
445    }
446
447    fn linger(&self) -> crate::Result<Option<Duration>> {
448        Ok(None)
449    }
450
451    fn try_send(&mut self, data: &[u8]) -> crate::Result<usize> {
452        self.tx.try_send(data, false, None)
453    }
454
455    fn try_flush(&mut self) -> crate::Result<()> {
456        Ok(())
457    }
458
459    fn close(&mut self) -> crate::Result<()> {
460        self.tx.set_state(State::Closed);
461        self.rx.set_state(State::Closed);
462        Ok(())
463    }
464
465    fn try_recv(
466        &mut self,
467        buf: &mut [std::mem::MaybeUninit<u8>],
468        peek: bool,
469    ) -> crate::Result<usize> {
470        self.rx.try_read(buf, peek, None)
471    }
472}
473
474impl VirtualTcpSocket for TcpSocketHalf {
475    fn set_recv_buf_size(&mut self, size: usize) -> crate::Result<()> {
476        self.rx.set_max_size(size);
477        Ok(())
478    }
479
480    fn recv_buf_size(&self) -> crate::Result<usize> {
481        Ok(self.rx.max_size())
482    }
483
484    fn set_send_buf_size(&mut self, size: usize) -> crate::Result<()> {
485        self.tx.set_max_size(size);
486        Ok(())
487    }
488
489    fn send_buf_size(&self) -> crate::Result<usize> {
490        Ok(self.tx.max_size())
491    }
492
493    fn set_nodelay(&mut self, _reuse: bool) -> crate::Result<()> {
494        Ok(())
495    }
496
497    fn nodelay(&self) -> crate::Result<bool> {
498        Ok(true)
499    }
500
501    fn set_keepalive(&mut self, _keepalive: bool) -> crate::Result<()> {
502        Ok(())
503    }
504
505    fn keepalive(&self) -> crate::Result<bool> {
506        Ok(false)
507    }
508
509    fn set_dontroute(&mut self, _keepalive: bool) -> crate::Result<()> {
510        Ok(())
511    }
512
513    fn dontroute(&self) -> crate::Result<bool> {
514        Ok(false)
515    }
516
517    fn addr_peer(&self) -> crate::Result<SocketAddr> {
518        Ok(self.addr_peer)
519    }
520
521    fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> {
522        match how {
523            std::net::Shutdown::Both => {
524                self.tx.set_state(State::Shutdown);
525                self.rx.set_state(State::Shutdown);
526            }
527            std::net::Shutdown::Read => {
528                self.rx.set_state(State::Shutdown);
529            }
530            std::net::Shutdown::Write => {
531                self.tx.set_state(State::Shutdown);
532            }
533        }
534        Ok(())
535    }
536
537    fn is_closed(&self) -> bool {
538        self.tx.state() != State::Alive
539    }
540}
541
542#[allow(unused)]
543#[derive(Debug)]
544pub struct TcpSocketHalfTx {
545    addr_local: SocketAddr,
546    addr_peer: SocketAddr,
547    tx: SocketBuffer,
548    ttl: u32,
549}
550
551impl TcpSocketHalfTx {
552    pub fn poll_send(&self, cx: &mut Context<'_>, data: &[u8]) -> Poll<io::Result<usize>> {
553        match self.tx.try_send(data, false, Some(cx.waker())) {
554            Ok(amt) => Poll::Ready(Ok(amt)),
555            Err(NetworkError::WouldBlock) => Poll::Pending,
556            Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
557        }
558    }
559
560    pub fn try_send(&self, data: &[u8]) -> io::Result<usize> {
561        self.tx
562            .try_send(data, false, None)
563            .map_err(net_error_into_io_err)
564    }
565
566    pub async fn send_ext(&self, data: Bytes, non_blocking: bool) -> io::Result<()> {
567        if non_blocking {
568            self.tx
569                .try_send(&data, true, None)
570                .map_err(net_error_into_io_err)
571                .map(|_| ())
572        } else {
573            self.tx.send(data).await.map_err(net_error_into_io_err)
574        }
575    }
576
577    pub async fn send(&self, data: Bytes) -> io::Result<()> {
578        self.send_ext(data, false).await
579    }
580
581    pub fn close(&self) -> crate::Result<()> {
582        self.tx.set_state(State::Closed);
583        Ok(())
584    }
585}
586
587impl AsyncWrite for TcpSocketHalfTx {
588    fn poll_write(
589        mut self: Pin<&mut Self>,
590        cx: &mut Context<'_>,
591        buf: &[u8],
592    ) -> Poll<Result<usize, io::Error>> {
593        Pin::new(&mut self.tx).poll_write(cx, buf)
594    }
595
596    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
597        Pin::new(&mut self.tx).poll_flush(cx)
598    }
599
600    fn poll_shutdown(
601        mut self: Pin<&mut Self>,
602        cx: &mut Context<'_>,
603    ) -> Poll<Result<(), io::Error>> {
604        Pin::new(&mut self.tx).poll_shutdown(cx)
605    }
606}
607
608#[allow(unused)]
609#[derive(Debug)]
610pub struct TcpSocketHalfRx {
611    addr_local: SocketAddr,
612    addr_peer: SocketAddr,
613    rx: BufReader<SocketBuffer>,
614    ttl: u32,
615}
616
617impl TcpSocketHalfRx {
618    pub fn buffer(&self) -> &[u8] {
619        self.rx.buffer()
620    }
621
622    pub fn close(&mut self) -> crate::Result<()> {
623        self.rx.get_mut().set_state(State::Closed);
624        Ok(())
625    }
626
627    #[allow(dead_code)]
628    pub(crate) fn inner(&mut self) -> &BufReader<SocketBuffer> {
629        &self.rx
630    }
631
632    #[allow(dead_code)]
633    pub(crate) fn inner_mut(&mut self) -> &mut BufReader<SocketBuffer> {
634        &mut self.rx
635    }
636}
637
638impl AsyncRead for TcpSocketHalfRx {
639    fn poll_read(
640        mut self: Pin<&mut Self>,
641        cx: &mut Context<'_>,
642        buf: &mut tokio::io::ReadBuf<'_>,
643    ) -> Poll<io::Result<()>> {
644        Pin::new(&mut self.rx).poll_read(cx, buf)
645    }
646}
647
648impl TcpSocketHalfRx {
649    pub fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
650        Pin::new(&mut self.rx).poll_fill_buf(cx)
651    }
652
653    pub fn consume(&mut self, amt: usize) {
654        Pin::new(&mut self.rx).consume(amt)
655    }
656}
657
658impl TcpSocketHalf {
659    pub fn split(self) -> (TcpSocketHalfTx, TcpSocketHalfRx) {
660        let tx = TcpSocketHalfTx {
661            tx: self.tx,
662            addr_local: self.addr_local,
663            addr_peer: self.addr_peer,
664            ttl: self.ttl,
665        };
666        let rx = TcpSocketHalfRx {
667            rx: BufReader::new(self.rx),
668            addr_local: self.addr_local,
669            addr_peer: self.addr_peer,
670            ttl: self.ttl,
671        };
672        (tx, rx)
673    }
674
675    pub fn combine(tx: TcpSocketHalfTx, rx: TcpSocketHalfRx) -> Self {
676        Self {
677            tx: tx.tx,
678            rx: rx.rx.into_inner(),
679            addr_local: tx.addr_local,
680            addr_peer: tx.addr_peer,
681            ttl: tx.ttl,
682        }
683    }
684}