1use crate::{
2 InterestHandler, NetworkError, SocketStatus, VirtualConnectedSocket, VirtualIoSource,
3 VirtualSocket, VirtualTcpSocket, net_error_into_io_err,
4};
5use bytes::{Buf, Bytes};
6use futures_util::Future;
7use smoltcp::storage::RingBuffer;
8use std::io::{self};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::sync::Mutex;
12use std::task::{Context, Waker};
13use std::time::Duration;
14use std::{net::SocketAddr, task::Poll};
15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader};
16use virtual_mio::{ArcInterestHandler, InterestType};
17
/// Lifecycle of one direction of an in-memory TCP pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Data can still flow in this direction.
    Alive,
    /// A `dead_on_drop` handle was dropped while the buffer was still alive.
    Dead,
    /// Closed explicitly, or a handle without `dead_on_drop` was dropped.
    Closed,
    /// Shut down via `shutdown` / `poll_shutdown`.
    Shutdown,
}
25
26#[derive(Debug)]
27struct SocketBufferState {
28 buffer: RingBuffer<'static, u8>,
29 push_handler: Option<ArcInterestHandler>,
30 pull_handler: Option<ArcInterestHandler>,
31 wakers: Vec<Waker>,
32 state: State,
33 halt_immediate_poll_write: bool,
35}
36
37impl SocketBufferState {
38 fn add_waker(&mut self, waker: &Waker) {
39 if !self.wakers.iter().any(|w| w.will_wake(waker)) {
40 self.wakers.push(waker.clone());
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
46pub(crate) struct SocketBuffer {
47 state: Arc<Mutex<SocketBufferState>>,
48 dead_on_drop: bool,
49}
50
51impl Drop for SocketBuffer {
52 fn drop(&mut self) {
53 if self.state() == State::Alive {
54 if self.dead_on_drop {
55 self.set_state(State::Dead);
56 } else {
57 self.set_state(State::Closed);
58 }
59 }
60 }
61}
62
63impl SocketBuffer {
64 fn new(max_size: usize) -> Self {
65 Self {
66 state: Arc::new(Mutex::new(SocketBufferState {
67 buffer: RingBuffer::new(vec![0; max_size]),
68 push_handler: None,
69 pull_handler: None,
70 wakers: Vec::new(),
71 state: State::Alive,
72 halt_immediate_poll_write: false,
73 })),
74 dead_on_drop: false,
75 }
76 }
77
78 pub fn set_push_handler(&self, mut handler: ArcInterestHandler) {
79 let mut state = self.state.lock().unwrap();
80 if state.state != State::Alive {
81 handler.push_interest(InterestType::Closed);
82 }
83 if !state.buffer.is_empty() {
84 handler.push_interest(InterestType::Readable);
85 }
86 state.push_handler.replace(handler);
87 }
88
89 pub fn set_pull_handler(&self, mut handler: ArcInterestHandler) {
90 let mut state = self.state.lock().unwrap();
91 if state.state != State::Alive {
92 handler.push_interest(InterestType::Closed);
93 }
94 if !state.buffer.is_full() && state.pull_handler.is_none() {
95 handler.push_interest(InterestType::Writable);
96 }
97 state.pull_handler.replace(handler);
98 }
99
100 pub fn clear_push_handler(&self) {
101 let mut state = self.state.lock().unwrap();
102 state.push_handler.take();
103 }
104
105 pub fn clear_pull_handler(&self) {
106 let mut state = self.state.lock().unwrap();
107 state.pull_handler.take();
108 }
109
110 pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
111 let mut state = self.state.lock().unwrap();
112 if !state.buffer.is_empty() {
113 return Poll::Ready(Ok(state.buffer.len()));
114 }
115 match state.state {
116 State::Alive => {
117 if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
118 state.wakers.push(cx.waker().clone());
119 }
120 Poll::Pending
121 }
122 State::Dead => {
123 tracing::trace!("poll_read_ready: socket is dead");
124 Poll::Ready(Err(NetworkError::ConnectionReset))
125 }
126 State::Closed | State::Shutdown => {
127 tracing::trace!("poll_read_ready: socket is closed or shutdown");
128 Poll::Ready(Ok(0))
129 }
130 }
131 }
132
133 pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
134 let mut state = self.state.lock().unwrap();
135 match state.state {
136 State::Alive => {
137 if !state.buffer.is_full() && !state.halt_immediate_poll_write {
138 state.halt_immediate_poll_write = true;
139 return Poll::Ready(Ok(state.buffer.window()));
140 }
141 if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
142 state.wakers.push(cx.waker().clone());
143 }
144 Poll::Pending
145 }
146 State::Dead => {
147 tracing::trace!("poll_write_ready: socket is dead");
148 Poll::Ready(Err(NetworkError::ConnectionReset))
149 }
150 State::Closed | State::Shutdown => {
151 tracing::trace!("poll_write_ready: socket is closed or shutdown");
152 Poll::Ready(Ok(0))
153 }
154 }
155 }
156
157 fn set_state(&self, new_state: State) {
158 let mut state = self.state.lock().unwrap();
159 state.state = new_state;
160 if let Some(handler) = state.pull_handler.as_mut() {
161 handler.push_interest(InterestType::Closed);
162 }
163 if let Some(handler) = state.push_handler.as_mut() {
164 handler.push_interest(InterestType::Closed);
165 }
166 state.wakers.drain(..).for_each(|w| w.wake());
167 }
168
169 fn state(&self) -> State {
170 let state = self.state.lock().unwrap();
171 state.state
172 }
173
174 pub fn try_send(
175 &self,
176 data: &[u8],
177 all_or_nothing: bool,
178 waker: Option<&Waker>,
179 ) -> crate::Result<usize> {
180 let mut state = self.state.lock().unwrap();
181 if state.state != State::Alive {
182 return Err(NetworkError::ConnectionReset);
183 }
184 state.halt_immediate_poll_write = false;
185 let available = state.buffer.window();
186 if available == 0 {
187 if let Some(waker) = waker {
188 state.add_waker(waker)
189 }
190 return Err(NetworkError::WouldBlock);
191 }
192 if data.len() > available {
193 if all_or_nothing {
194 if let Some(waker) = waker {
195 state.add_waker(waker)
196 }
197 return Err(NetworkError::WouldBlock);
198 }
199 let amt = state.buffer.enqueue_slice(&data[..available]);
200 return Ok(amt);
201 }
202 let amt = state.buffer.enqueue_slice(data);
203
204 if let Some(handler) = state.push_handler.as_mut() {
205 handler.push_interest(InterestType::Readable);
206 }
207 state.wakers.drain(..).for_each(|w| w.wake());
208 Ok(amt)
209 }
210
211 pub async fn send(&self, data: Bytes) -> crate::Result<()> {
212 struct Poller<'a> {
213 this: &'a SocketBuffer,
214 data: Bytes,
215 }
216 impl Future for Poller<'_> {
217 type Output = crate::Result<()>;
218 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 loop {
220 if self.data.is_empty() {
221 return Poll::Ready(Ok(()));
222 }
223 return match self.this.try_send(&self.data, false, Some(cx.waker())) {
224 Ok(amt) => {
225 self.data.advance(amt);
226 continue;
227 }
228 Err(NetworkError::WouldBlock) => Poll::Pending,
229 Err(err) => Poll::Ready(Err(err)),
230 };
231 }
232 }
233 }
234 Poller { this: self, data }.await
235 }
236
237 pub fn try_read(
238 &self,
239 buf: &mut [std::mem::MaybeUninit<u8>],
240 peek: bool,
241 waker: Option<&Waker>,
242 ) -> crate::Result<usize> {
243 let mut state = self.state.lock().unwrap();
244 if state.buffer.is_empty() {
245 return match state.state {
246 State::Alive => {
247 if let Some(waker) = waker {
248 state.add_waker(waker)
249 }
250 Err(NetworkError::WouldBlock)
251 }
252 State::Dead => {
253 tracing::trace!("try_read: socket is dead");
254 return Ok(0);
258 }
259 State::Closed | State::Shutdown => {
260 tracing::trace!("try_read: socket is closed or shutdown");
261 return Ok(0);
262 }
263 };
264 }
265
266 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
267 let amt = buf.len().min(state.buffer.len());
268 let amt = if peek {
269 state.buffer.read_allocated(0, &mut buf[..amt])
270 } else {
271 state.buffer.dequeue_slice(&mut buf[..amt])
272 };
273
274 if let Some(handler) = state.pull_handler.as_mut() {
275 handler.push_interest(InterestType::Writable);
276 }
277 state.wakers.drain(..).for_each(|w| w.wake());
278 Ok(amt)
279 }
280
281 pub fn set_max_size(&self, new_size: usize) {
282 let mut state = self.state.lock().unwrap();
283 state.halt_immediate_poll_write = false;
284
285 let mut existing: Vec<u8> = vec![0; state.buffer.len()];
286 if !state.buffer.is_empty() {
287 let amt = state.buffer.dequeue_slice(&mut existing[..]);
288 existing.resize(amt, 0);
289 }
290
291 state.buffer = RingBuffer::new(vec![0; new_size]);
292 if !existing.is_empty() {
293 let _ = state.buffer.enqueue_slice(&existing[..]);
294 }
295 }
296
297 pub fn max_size(&self) -> usize {
298 let state = self.state.lock().unwrap();
299 state.buffer.capacity()
300 }
301}
302
303impl AsyncWrite for SocketBuffer {
304 fn poll_write(
305 self: Pin<&mut Self>,
306 cx: &mut Context<'_>,
307 buf: &[u8],
308 ) -> Poll<Result<usize, io::Error>> {
309 match self.try_send(buf, false, Some(cx.waker())) {
310 Ok(amt) => Poll::Ready(Ok(amt)),
311 Err(NetworkError::WouldBlock) => Poll::Pending,
312 Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
313 }
314 }
315
316 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
317 Poll::Ready(Ok(()))
318 }
319
320 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
321 self.set_state(State::Shutdown);
322 Poll::Ready(Ok(()))
323 }
324}
325
326impl AsyncRead for SocketBuffer {
327 fn poll_read(
328 self: Pin<&mut Self>,
329 cx: &mut Context<'_>,
330 buf: &mut tokio::io::ReadBuf<'_>,
331 ) -> Poll<io::Result<()>> {
332 match self.try_read(unsafe { buf.unfilled_mut() }, false, Some(cx.waker())) {
333 Ok(amt) => {
334 unsafe { buf.assume_init(amt) };
335 buf.advance(amt);
336 Poll::Ready(Ok(()))
337 }
338 Err(NetworkError::WouldBlock) => Poll::Pending,
339 Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
340 }
341 }
342}
343
344#[derive(Debug)]
345pub struct TcpSocketHalf {
346 addr_local: SocketAddr,
347 addr_peer: SocketAddr,
348 tx: SocketBuffer,
349 rx: SocketBuffer,
350 ttl: u32,
351}
352
353impl TcpSocketHalf {
354 pub fn channel(
355 max_buffer_size: usize,
356 addr1: SocketAddr,
357 addr2: SocketAddr,
358 ) -> (TcpSocketHalf, TcpSocketHalf) {
359 let mut buffer1 = SocketBuffer::new(max_buffer_size);
360 buffer1.dead_on_drop = true;
361
362 let mut buffer2 = SocketBuffer::new(max_buffer_size);
363 buffer2.dead_on_drop = true;
364
365 let half1 = Self {
366 tx: buffer1.clone(),
367 rx: buffer2.clone(),
368 addr_local: addr1,
369 addr_peer: addr2,
370 ttl: 64,
371 };
372 let half2 = Self {
373 tx: buffer2,
374 rx: buffer1,
375 addr_local: addr2,
376 addr_peer: addr1,
377 ttl: 64,
378 };
379 (half1, half2)
380 }
381
382 pub fn is_active(&self) -> bool {
383 self.tx.state() == State::Alive
384 }
385
386 pub fn close(&self) -> crate::Result<()> {
387 self.tx.set_state(State::Closed);
388 self.rx.set_state(State::Closed);
389 Ok(())
390 }
391}
392
393impl VirtualIoSource for TcpSocketHalf {
394 fn remove_handler(&mut self) {
395 self.tx.clear_pull_handler();
396 self.rx.clear_push_handler();
397 }
398
399 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
400 self.rx.poll_read_ready(cx)
401 }
402
403 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
404 self.tx.poll_write_ready(cx)
405 }
406}
407
408impl VirtualSocket for TcpSocketHalf {
409 fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> {
410 self.ttl = ttl;
411 Ok(())
412 }
413
414 fn ttl(&self) -> crate::Result<u32> {
415 Ok(self.ttl)
416 }
417
418 fn addr_local(&self) -> crate::Result<SocketAddr> {
419 Ok(self.addr_local)
420 }
421
422 fn status(&self) -> crate::Result<SocketStatus> {
423 Ok(match self.tx.state() {
424 State::Alive => SocketStatus::Opened,
425 State::Dead => SocketStatus::Failed,
426 State::Closed => SocketStatus::Closed,
427 State::Shutdown => SocketStatus::Closed,
428 })
429 }
430
431 fn set_handler(
432 &mut self,
433 handler: Box<dyn InterestHandler + Send + Sync>,
434 ) -> crate::Result<()> {
435 let handler = ArcInterestHandler::new(handler);
436 self.tx.set_pull_handler(handler.clone());
437 self.rx.set_push_handler(handler);
438 Ok(())
439 }
440}
441
442impl VirtualConnectedSocket for TcpSocketHalf {
443 fn set_linger(&mut self, _linger: Option<Duration>) -> crate::Result<()> {
444 Ok(())
445 }
446
447 fn linger(&self) -> crate::Result<Option<Duration>> {
448 Ok(None)
449 }
450
451 fn try_send(&mut self, data: &[u8]) -> crate::Result<usize> {
452 self.tx.try_send(data, false, None)
453 }
454
455 fn try_flush(&mut self) -> crate::Result<()> {
456 Ok(())
457 }
458
459 fn close(&mut self) -> crate::Result<()> {
460 self.tx.set_state(State::Closed);
461 self.rx.set_state(State::Closed);
462 Ok(())
463 }
464
465 fn try_recv(
466 &mut self,
467 buf: &mut [std::mem::MaybeUninit<u8>],
468 peek: bool,
469 ) -> crate::Result<usize> {
470 self.rx.try_read(buf, peek, None)
471 }
472}
473
474impl VirtualTcpSocket for TcpSocketHalf {
475 fn set_recv_buf_size(&mut self, size: usize) -> crate::Result<()> {
476 self.rx.set_max_size(size);
477 Ok(())
478 }
479
480 fn recv_buf_size(&self) -> crate::Result<usize> {
481 Ok(self.rx.max_size())
482 }
483
484 fn set_send_buf_size(&mut self, size: usize) -> crate::Result<()> {
485 self.tx.set_max_size(size);
486 Ok(())
487 }
488
489 fn send_buf_size(&self) -> crate::Result<usize> {
490 Ok(self.tx.max_size())
491 }
492
493 fn set_nodelay(&mut self, _reuse: bool) -> crate::Result<()> {
494 Ok(())
495 }
496
497 fn nodelay(&self) -> crate::Result<bool> {
498 Ok(true)
499 }
500
501 fn set_keepalive(&mut self, _keepalive: bool) -> crate::Result<()> {
502 Ok(())
503 }
504
505 fn keepalive(&self) -> crate::Result<bool> {
506 Ok(false)
507 }
508
509 fn set_dontroute(&mut self, _keepalive: bool) -> crate::Result<()> {
510 Ok(())
511 }
512
513 fn dontroute(&self) -> crate::Result<bool> {
514 Ok(false)
515 }
516
517 fn addr_peer(&self) -> crate::Result<SocketAddr> {
518 Ok(self.addr_peer)
519 }
520
521 fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> {
522 match how {
523 std::net::Shutdown::Both => {
524 self.tx.set_state(State::Shutdown);
525 self.rx.set_state(State::Shutdown);
526 }
527 std::net::Shutdown::Read => {
528 self.rx.set_state(State::Shutdown);
529 }
530 std::net::Shutdown::Write => {
531 self.tx.set_state(State::Shutdown);
532 }
533 }
534 Ok(())
535 }
536
537 fn is_closed(&self) -> bool {
538 self.tx.state() != State::Alive
539 }
540}
541
542#[allow(unused)]
543#[derive(Debug)]
544pub struct TcpSocketHalfTx {
545 addr_local: SocketAddr,
546 addr_peer: SocketAddr,
547 tx: SocketBuffer,
548 ttl: u32,
549}
550
551impl TcpSocketHalfTx {
552 pub fn poll_send(&self, cx: &mut Context<'_>, data: &[u8]) -> Poll<io::Result<usize>> {
553 match self.tx.try_send(data, false, Some(cx.waker())) {
554 Ok(amt) => Poll::Ready(Ok(amt)),
555 Err(NetworkError::WouldBlock) => Poll::Pending,
556 Err(err) => Poll::Ready(Err(net_error_into_io_err(err))),
557 }
558 }
559
560 pub fn try_send(&self, data: &[u8]) -> io::Result<usize> {
561 self.tx
562 .try_send(data, false, None)
563 .map_err(net_error_into_io_err)
564 }
565
566 pub async fn send_ext(&self, data: Bytes, non_blocking: bool) -> io::Result<()> {
567 if non_blocking {
568 self.tx
569 .try_send(&data, true, None)
570 .map_err(net_error_into_io_err)
571 .map(|_| ())
572 } else {
573 self.tx.send(data).await.map_err(net_error_into_io_err)
574 }
575 }
576
577 pub async fn send(&self, data: Bytes) -> io::Result<()> {
578 self.send_ext(data, false).await
579 }
580
581 pub fn close(&self) -> crate::Result<()> {
582 self.tx.set_state(State::Closed);
583 Ok(())
584 }
585}
586
587impl AsyncWrite for TcpSocketHalfTx {
588 fn poll_write(
589 mut self: Pin<&mut Self>,
590 cx: &mut Context<'_>,
591 buf: &[u8],
592 ) -> Poll<Result<usize, io::Error>> {
593 Pin::new(&mut self.tx).poll_write(cx, buf)
594 }
595
596 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
597 Pin::new(&mut self.tx).poll_flush(cx)
598 }
599
600 fn poll_shutdown(
601 mut self: Pin<&mut Self>,
602 cx: &mut Context<'_>,
603 ) -> Poll<Result<(), io::Error>> {
604 Pin::new(&mut self.tx).poll_shutdown(cx)
605 }
606}
607
608#[allow(unused)]
609#[derive(Debug)]
610pub struct TcpSocketHalfRx {
611 addr_local: SocketAddr,
612 addr_peer: SocketAddr,
613 rx: BufReader<SocketBuffer>,
614 ttl: u32,
615}
616
617impl TcpSocketHalfRx {
618 pub fn buffer(&self) -> &[u8] {
619 self.rx.buffer()
620 }
621
622 pub fn close(&mut self) -> crate::Result<()> {
623 self.rx.get_mut().set_state(State::Closed);
624 Ok(())
625 }
626
627 #[allow(dead_code)]
628 pub(crate) fn inner(&mut self) -> &BufReader<SocketBuffer> {
629 &self.rx
630 }
631
632 #[allow(dead_code)]
633 pub(crate) fn inner_mut(&mut self) -> &mut BufReader<SocketBuffer> {
634 &mut self.rx
635 }
636}
637
638impl AsyncRead for TcpSocketHalfRx {
639 fn poll_read(
640 mut self: Pin<&mut Self>,
641 cx: &mut Context<'_>,
642 buf: &mut tokio::io::ReadBuf<'_>,
643 ) -> Poll<io::Result<()>> {
644 Pin::new(&mut self.rx).poll_read(cx, buf)
645 }
646}
647
648impl TcpSocketHalfRx {
649 pub fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
650 Pin::new(&mut self.rx).poll_fill_buf(cx)
651 }
652
653 pub fn consume(&mut self, amt: usize) {
654 Pin::new(&mut self.rx).consume(amt)
655 }
656}
657
658impl TcpSocketHalf {
659 pub fn split(self) -> (TcpSocketHalfTx, TcpSocketHalfRx) {
660 let tx = TcpSocketHalfTx {
661 tx: self.tx,
662 addr_local: self.addr_local,
663 addr_peer: self.addr_peer,
664 ttl: self.ttl,
665 };
666 let rx = TcpSocketHalfRx {
667 rx: BufReader::new(self.rx),
668 addr_local: self.addr_local,
669 addr_peer: self.addr_peer,
670 ttl: self.ttl,
671 };
672 (tx, rx)
673 }
674
675 pub fn combine(tx: TcpSocketHalfTx, rx: TcpSocketHalfRx) -> Self {
676 Self {
677 tx: tx.tx,
678 rx: rx.rx.into_inner(),
679 addr_local: tx.addr_local,
680 addr_peer: tx.addr_peer,
681 ttl: tx.ttl,
682 }
683 }
684}