virtual_net/
rx_tx.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
9#[cfg(feature = "hyper")]
10use hyper_util::rt::tokio::TokioIo;
11use serde::Serialize;
12#[cfg(feature = "tokio-tungstenite")]
13use tokio::net::TcpStream;
14use tokio::sync::{
15    mpsc::{self, error::TrySendError},
16    oneshot,
17};
18use virtual_mio::InlineWaker;
19
20use crate::{NetworkError, io_err_into_net_error};
21
/// Shared registry of tasks that are waiting to send on a [`RemoteTx`].
///
/// A sender that cannot make progress registers its waker with [`add`](Self::add); the
/// receiving side calls [`wake`](Self::wake) whenever a receive completes, which wakes and
/// forgets every registered waker. Clones share the same underlying registry.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker` unless an equivalent one (per [`Waker::will_wake`]) is already queued.
    pub fn add(&self, waker: &Waker) {
        let mut queued = self.wakers.lock().unwrap();
        if !queued.iter().any(|w| w.will_wake(waker)) {
            queued.push(waker.clone());
        }
    }

    /// Wakes and removes every registered waker.
    ///
    /// The list is taken out of the mutex *before* any waker runs. A waker that re-enters
    /// [`add`](Self::add) from inside `wake` (e.g. by polling its task inline) therefore
    /// neither deadlocks on the lock nor gets lost: it is queued for the next call to `wake`.
    pub fn wake(&self) {
        // The guard is a temporary of this statement, so the lock is released right here.
        let pending = std::mem::take(&mut *self.wakers.lock().unwrap());
        for waker in pending {
            waker.wake();
        }
    }
}
38
39pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
40
41#[derive(derive_more::Debug)]
42pub(crate) enum RemoteTx<T>
43where
44    T: Serialize,
45{
46    Mpsc {
47        tx: mpsc::Sender<T>,
48        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
49        wakers: RemoteTxWakers,
50    },
51    Stream {
52        #[debug(ignore)]
53        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
54        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
55        wakers: RemoteTxWakers,
56    },
57    #[cfg(feature = "hyper")]
58    HyperWebSocket {
59        tx: Arc<
60            tokio::sync::Mutex<
61                futures_util::stream::SplitSink<
62                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
63                    hyper_tungstenite::tungstenite::Message,
64                >,
65            >,
66        >,
67        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
68        wakers: RemoteTxWakers,
69        format: crate::meta::FrameSerializationFormat,
70    },
71    #[cfg(feature = "tokio-tungstenite")]
72    TokioWebSocket {
73        tx: Arc<
74            tokio::sync::Mutex<
75                futures_util::stream::SplitSink<
76                    tokio_tungstenite::WebSocketStream<
77                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
78                    >,
79                    tokio_tungstenite::tungstenite::Message,
80                >,
81            >,
82        >,
83        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
84        wakers: RemoteTxWakers,
85        format: crate::meta::FrameSerializationFormat,
86    },
87}
88impl<T> RemoteTx<T>
89where
90    T: Serialize + Send + Sync + 'static,
91{
92    pub(crate) async fn send(&self, req: T) -> Result<()> {
93        match self {
94            RemoteTx::Mpsc { tx, .. } => tx
95                .send(req)
96                .await
97                .map_err(|_| NetworkError::ConnectionAborted),
98            RemoteTx::Stream { tx, work, .. } => {
99                let (tx_done, rx_done) = oneshot::channel();
100                let tx = tx.clone();
101                work.send(Box::pin(async move {
102                    let job = async {
103                        let mut tx_guard = tx.lock_owned().await;
104                        tx_guard.send(req).await.map_err(io_err_into_net_error)
105                    };
106                    tx_done.send(job.await).ok();
107                }))
108                .map_err(|_| NetworkError::ConnectionAborted)?;
109
110                rx_done
111                    .await
112                    .unwrap_or(Err(NetworkError::ConnectionAborted))
113            }
114            #[cfg(feature = "hyper")]
115            RemoteTx::HyperWebSocket { tx, format, .. } => {
116                let data = match format {
117                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
118                        .map_err(|err| {
119                            tracing::warn!("failed to serialize message - {err}");
120                            NetworkError::IOError
121                        })?,
122                    format => {
123                        tracing::warn!("format not currently supported - {format:?}");
124                        return Err(NetworkError::IOError);
125                    }
126                };
127                let mut tx = tx.lock().await;
128                tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
129                    .await
130                    .map_err(|_| NetworkError::ConnectionAborted)
131            }
132            #[cfg(feature = "tokio-tungstenite")]
133            RemoteTx::TokioWebSocket { tx, format, .. } => {
134                let data = match format {
135                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
136                        .map_err(|err| {
137                            tracing::warn!("failed to serialize message - {err}");
138                            NetworkError::IOError
139                        })?,
140                    format => {
141                        tracing::warn!("format not currently supported - {format:?}");
142                        return Err(NetworkError::IOError);
143                    }
144                };
145                let mut tx = tx.lock().await;
146                tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
147                    .await
148                    .map_err(|_| NetworkError::ConnectionAborted)
149            }
150        }
151    }
152
153    pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
154        match self {
155            RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
156                Ok(()) => Poll::Ready(Ok(())),
157                Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
158                Err(TrySendError::Full(_)) => {
159                    wakers.add(cx.waker());
160                    Poll::Pending
161                }
162            },
163            RemoteTx::Stream { tx, work, wakers } => {
164                let mut tx_guard = match tx.clone().try_lock_owned() {
165                    Ok(lock) => lock,
166                    Err(_) => {
167                        wakers.add(cx.waker());
168                        return Poll::Pending;
169                    }
170                };
171                match tx_guard.poll_ready_unpin(cx) {
172                    Poll::Ready(Ok(())) => {}
173                    Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
174                    Poll::Pending => return Poll::Pending,
175                }
176                let mut job = Box::pin(async move {
177                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
178                        tracing::error!("failed to send remaining bytes for request - {}", err);
179                    }
180                });
181
182                // First we try to finish it synchronously
183                if job.as_mut().poll(cx).is_ready() {
184                    return Poll::Ready(Ok(()));
185                }
186
187                // Otherwise we push it to the driver which will block all future send
188                // operations until it finishes
189                work.send(job).map_err(|err| {
190                    tracing::error!("failed to send remaining bytes for request - {}", err);
191                    NetworkError::ConnectionAborted
192                })?;
193                Poll::Ready(Ok(()))
194            }
195            #[cfg(feature = "hyper")]
196            RemoteTx::HyperWebSocket {
197                tx,
198                format,
199                work,
200                wakers,
201                ..
202            } => {
203                let mut tx_guard = match tx.clone().try_lock_owned() {
204                    Ok(lock) => lock,
205                    Err(_) => {
206                        wakers.add(cx.waker());
207                        return Poll::Pending;
208                    }
209                };
210                match tx_guard.poll_ready_unpin(cx) {
211                    Poll::Ready(Ok(())) => {}
212                    Poll::Ready(Err(err)) => {
213                        tracing::warn!("failed to poll web socket for readiness - {err}");
214                        return Poll::Ready(Err(NetworkError::IOError));
215                    }
216                    Poll::Pending => return Poll::Pending,
217                }
218
219                let data = match format {
220                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
221                        .map_err(|err| {
222                            tracing::warn!("failed to serialize message - {err}");
223                            NetworkError::IOError
224                        })?,
225                    format => {
226                        tracing::warn!("format not currently supported - {format:?}");
227                        return Poll::Ready(Err(NetworkError::IOError));
228                    }
229                };
230
231                let mut job = Box::pin(async move {
232                    if let Err(err) = tx_guard
233                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
234                        .await
235                    {
236                        tracing::error!("failed to send remaining bytes for request - {}", err);
237                    }
238                });
239
240                // First we try to finish it synchronously
241                if job.as_mut().poll(cx).is_ready() {
242                    return Poll::Ready(Ok(()));
243                }
244
245                // Otherwise we push it to the driver which will block all future send
246                // operations until it finishes
247                work.send(job).map_err(|err| {
248                    tracing::error!("failed to send remaining bytes for request - {}", err);
249                    NetworkError::ConnectionAborted
250                })?;
251                Poll::Ready(Ok(()))
252            }
253            #[cfg(feature = "tokio-tungstenite")]
254            RemoteTx::TokioWebSocket {
255                tx,
256                format,
257                work,
258                wakers,
259                ..
260            } => {
261                let mut tx_guard = match tx.clone().try_lock_owned() {
262                    Ok(lock) => lock,
263                    Err(_) => {
264                        wakers.add(cx.waker());
265                        return Poll::Pending;
266                    }
267                };
268                match tx_guard.poll_ready_unpin(cx) {
269                    Poll::Ready(Ok(())) => {}
270                    Poll::Ready(Err(err)) => {
271                        tracing::warn!("failed to poll web socket for readiness - {err}");
272                        return Poll::Ready(Err(NetworkError::IOError));
273                    }
274                    Poll::Pending => return Poll::Pending,
275                }
276
277                let data = match format {
278                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
279                        .map_err(|err| {
280                            tracing::warn!("failed to serialize message - {err}");
281                            NetworkError::IOError
282                        })?,
283                    format => {
284                        tracing::warn!("format not currently supported - {format:?}");
285                        return Poll::Ready(Err(NetworkError::IOError));
286                    }
287                };
288
289                let mut job = Box::pin(async move {
290                    if let Err(err) = tx_guard
291                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
292                        .await
293                    {
294                        tracing::error!("failed to send remaining bytes for request - {}", err);
295                    }
296                });
297
298                // First we try to finish it synchronously
299                if job.as_mut().poll(cx).is_ready() {
300                    return Poll::Ready(Ok(()));
301                }
302
303                // Otherwise we push it to the driver which will block all future send
304                // operations until it finishes
305                work.send(job).map_err(|err| {
306                    tracing::error!("failed to send remaining bytes for request - {}", err);
307                    NetworkError::ConnectionAborted
308                })?;
309                Poll::Ready(Ok(()))
310            }
311        }
312    }
313
314    pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
315        match self {
316            RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
317                Ok(()) => Ok(()),
318                Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
319                Err(TrySendError::Full(req)) => {
320                    let tx = tx.clone();
321                    work.send(Box::pin(async move {
322                        tx.send(req).await.ok();
323                    }))
324                    .ok();
325                    Ok(())
326                }
327            },
328            RemoteTx::Stream { tx, work, .. } => {
329                let mut tx_guard = match tx.clone().try_lock_owned() {
330                    Ok(lock) => lock,
331                    Err(_) => {
332                        let tx = tx.clone();
333                        work.send(Box::pin(async move {
334                            let mut tx_guard = tx.lock().await;
335                            tx_guard.send(req).await.ok();
336                        }))
337                        .ok();
338                        return Ok(());
339                    }
340                };
341
342                let inline_waker = InlineWaker::new();
343                let waker = inline_waker.as_waker();
344                let mut cx = Context::from_waker(&waker);
345
346                let mut job = Box::pin(async move {
347                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
348                        tracing::error!("failed to send remaining bytes for request - {}", err);
349                    }
350                });
351
352                // First we try to finish it synchronously
353                if job.as_mut().poll(&mut cx).is_ready() {
354                    return Ok(());
355                }
356
357                // Otherwise we push it to the driver which will block all future send
358                // operations until it finishes
359                work.send(job).map_err(|err| {
360                    tracing::error!("failed to send remaining bytes for request - {}", err);
361                    NetworkError::ConnectionAborted
362                })?;
363                Ok(())
364            }
365            #[cfg(feature = "hyper")]
366            RemoteTx::HyperWebSocket {
367                tx, format, work, ..
368            } => {
369                let data = match format {
370                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
371                        .map_err(|err| {
372                            tracing::warn!("failed to serialize message - {err}");
373                            NetworkError::IOError
374                        })?,
375                    format => {
376                        tracing::warn!("format not currently supported - {format:?}");
377                        return Err(NetworkError::IOError);
378                    }
379                };
380
381                let mut tx_guard = match tx.clone().try_lock_owned() {
382                    Ok(lock) => lock,
383                    Err(_) => {
384                        let tx = tx.clone();
385                        work.send(Box::pin(async move {
386                            let mut tx_guard = tx.lock().await;
387                            tx_guard
388                                .send(hyper_tungstenite::tungstenite::Message::Binary(data))
389                                .await
390                                .ok();
391                        }))
392                        .ok();
393                        return Ok(());
394                    }
395                };
396
397                let inline_waker = InlineWaker::new();
398                let waker = inline_waker.as_waker();
399                let mut cx = Context::from_waker(&waker);
400
401                let mut job = Box::pin(async move {
402                    if let Err(err) = tx_guard
403                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
404                        .await
405                    {
406                        tracing::error!("failed to send remaining bytes for request - {}", err);
407                    }
408                });
409
410                // First we try to finish it synchronously
411                if job.as_mut().poll(&mut cx).is_ready() {
412                    return Ok(());
413                }
414
415                // Otherwise we push it to the driver which will block all future send
416                // operations until it finishes
417                work.send(job).map_err(|err| {
418                    tracing::error!("failed to send remaining bytes for request - {}", err);
419                    NetworkError::ConnectionAborted
420                })?;
421                Ok(())
422            }
423            #[cfg(feature = "tokio-tungstenite")]
424            RemoteTx::TokioWebSocket {
425                tx, format, work, ..
426            } => {
427                let data = match format {
428                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
429                        .map_err(|err| {
430                            tracing::warn!("failed to serialize message - {err}");
431                            NetworkError::IOError
432                        })?,
433                    format => {
434                        tracing::warn!("format not currently supported - {format:?}");
435                        return Err(NetworkError::IOError);
436                    }
437                };
438
439                let mut tx_guard = match tx.clone().try_lock_owned() {
440                    Ok(lock) => lock,
441                    Err(_) => {
442                        let tx = tx.clone();
443                        work.send(Box::pin(async move {
444                            let mut tx_guard = tx.lock().await;
445                            tx_guard
446                                .send(tokio_tungstenite::tungstenite::Message::Binary(data))
447                                .await
448                                .ok();
449                        }))
450                        .ok();
451                        return Ok(());
452                    }
453                };
454
455                let inline_waker = InlineWaker::new();
456                let waker = inline_waker.as_waker();
457                let mut cx = Context::from_waker(&waker);
458
459                let mut job = Box::pin(async move {
460                    if let Err(err) = tx_guard
461                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
462                        .await
463                    {
464                        tracing::error!("failed to send remaining bytes for request - {}", err);
465                    }
466                });
467
468                // First we try to finish it synchronously
469                if job.as_mut().poll(&mut cx).is_ready() {
470                    return Ok(());
471                }
472
473                // Otherwise we push it to the driver which will block all future send
474                // operations until it finishes
475                work.send(job).map_err(|err| {
476                    tracing::error!("failed to send remaining bytes for request - {}", err);
477                    NetworkError::ConnectionAborted
478                })?;
479                Ok(())
480            }
481        }
482    }
483}
484
485#[derive(derive_more::Debug)]
486pub(crate) enum RemoteRx<T>
487where
488    T: serde::de::DeserializeOwned,
489{
490    Mpsc {
491        rx: mpsc::Receiver<T>,
492        wakers: RemoteTxWakers,
493    },
494    Stream {
495        #[debug(ignore)]
496        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
497    },
498    #[cfg(feature = "hyper")]
499    HyperWebSocket {
500        rx: futures_util::stream::SplitStream<
501            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
502        >,
503        format: crate::meta::FrameSerializationFormat,
504    },
505    #[cfg(feature = "tokio-tungstenite")]
506    TokioWebSocket {
507        rx: futures_util::stream::SplitStream<
508            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
509        >,
510        format: crate::meta::FrameSerializationFormat,
511    },
512}
513impl<T> RemoteRx<T>
514where
515    T: serde::de::DeserializeOwned,
516{
517    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
518        loop {
519            return match self {
520                RemoteRx::Mpsc { rx, wakers } => {
521                    let ret = Pin::new(rx).poll_recv(cx);
522                    if ret.is_ready() {
523                        wakers.wake();
524                    }
525                    ret
526                }
527                RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
528                    Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
529                    Poll::Ready(Some(Err(err))) => {
530                        tracing::debug!("failed to read from channel - {}", err);
531                        Poll::Ready(None)
532                    }
533                    Poll::Ready(None) => Poll::Ready(None),
534                    Poll::Pending => Poll::Pending,
535                },
536                #[cfg(feature = "hyper")]
537                RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
538                    Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
539                        match format {
540                            crate::meta::FrameSerializationFormat::Bincode => {
541                                return match bincode::deserialize(&msg) {
542                                    Ok(msg) => Poll::Ready(Some(msg)),
543                                    Err(err) => {
544                                        tracing::warn!("failed to deserialize message - {}", err);
545                                        continue;
546                                    }
547                                };
548                            }
549                            format => {
550                                tracing::warn!("format not currently supported - {format:?}");
551                                continue;
552                            }
553                        }
554                    }
555                    Poll::Ready(Some(Ok(msg))) => {
556                        tracing::warn!("unsupported message from channel - {}", msg);
557                        continue;
558                    }
559                    Poll::Ready(Some(Err(err))) => {
560                        tracing::debug!("failed to read from channel - {}", err);
561                        Poll::Ready(None)
562                    }
563                    Poll::Ready(None) => Poll::Ready(None),
564                    Poll::Pending => Poll::Pending,
565                },
566                #[cfg(feature = "tokio-tungstenite")]
567                RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
568                    Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
569                        match format {
570                            crate::meta::FrameSerializationFormat::Bincode => {
571                                return match bincode::deserialize(&msg) {
572                                    Ok(msg) => Poll::Ready(Some(msg)),
573                                    Err(err) => {
574                                        tracing::warn!("failed to deserialize message - {}", err);
575                                        continue;
576                                    }
577                                };
578                            }
579                            format => {
580                                tracing::warn!("format not currently supported - {format:?}");
581                                continue;
582                            }
583                        }
584                    }
585                    Poll::Ready(Some(Ok(msg))) => {
586                        tracing::warn!("unsupported message from channel - {}", msg);
587                        continue;
588                    }
589                    Poll::Ready(Some(Err(err))) => {
590                        tracing::debug!("failed to read from channel - {}", err);
591                        Poll::Ready(None)
592                    }
593                    Poll::Ready(None) => Poll::Ready(None),
594                    Poll::Pending => Poll::Pending,
595                },
596            };
597        }
598    }
599}