virtual_net/
rx_tx.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use bincode::config;
9use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
10#[cfg(feature = "hyper")]
11use hyper_util::rt::tokio::TokioIo;
12use serde::Serialize;
13#[cfg(feature = "tokio-tungstenite")]
14use tokio::net::TcpStream;
15use tokio::sync::{
16    mpsc::{self, error::TrySendError},
17    oneshot,
18};
19
20use crate::{NetworkError, io_err_into_net_error};
21
/// Shared set of wakers for tasks that could not send yet (channel full or
/// sink lock contended) and want to be polled again later.
///
/// Cloning is cheap and shares the same underlying set.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker` to be woken by the next call to [`Self::wake`].
    ///
    /// A waker that `will_wake` the same task as one already registered is
    /// not stored twice.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes, and unregisters, every waker added since the previous call.
    pub fn wake(&self) {
        // Move the wakers out and release the lock *before* waking: a waker
        // may run arbitrary code (including calling `add` on this same set),
        // and `std::sync::Mutex` is not reentrant.
        let pending = std::mem::take(&mut *self.lock());
        pending.into_iter().for_each(Waker::wake);
    }

    /// Locks the waker list, recovering it if a previous holder panicked;
    /// a plain list of wakers has no invariant that a panic could break.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
38
39pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
40
41#[derive(derive_more::Debug)]
42pub(crate) enum RemoteTx<T>
43where
44    T: Serialize,
45{
46    Mpsc {
47        tx: mpsc::Sender<T>,
48        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
49        wakers: RemoteTxWakers,
50    },
51    Stream {
52        #[debug(ignore)]
53        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
54        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
55        wakers: RemoteTxWakers,
56    },
57    #[cfg(feature = "hyper")]
58    HyperWebSocket {
59        tx: Arc<
60            tokio::sync::Mutex<
61                futures_util::stream::SplitSink<
62                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
63                    hyper_tungstenite::tungstenite::Message,
64                >,
65            >,
66        >,
67        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
68        wakers: RemoteTxWakers,
69        format: crate::meta::FrameSerializationFormat,
70    },
71    #[cfg(feature = "tokio-tungstenite")]
72    TokioWebSocket {
73        tx: Arc<
74            tokio::sync::Mutex<
75                futures_util::stream::SplitSink<
76                    tokio_tungstenite::WebSocketStream<
77                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
78                    >,
79                    tokio_tungstenite::tungstenite::Message,
80                >,
81            >,
82        >,
83        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
84        wakers: RemoteTxWakers,
85        format: crate::meta::FrameSerializationFormat,
86    },
87}
88impl<T> RemoteTx<T>
89where
90    T: Serialize + Send + Sync + 'static,
91{
92    pub(crate) async fn send(&self, req: T) -> Result<()> {
93        match self {
94            RemoteTx::Mpsc { tx, .. } => tx
95                .send(req)
96                .await
97                .map_err(|_| NetworkError::ConnectionAborted),
98            RemoteTx::Stream { tx, work, .. } => {
99                let (tx_done, rx_done) = oneshot::channel();
100                let tx = tx.clone();
101                work.send(Box::pin(async move {
102                    let job = async {
103                        let mut tx_guard = tx.lock_owned().await;
104                        tx_guard.send(req).await.map_err(io_err_into_net_error)
105                    };
106                    tx_done.send(job.await).ok();
107                }))
108                .map_err(|_| NetworkError::ConnectionAborted)?;
109
110                rx_done
111                    .await
112                    .unwrap_or(Err(NetworkError::ConnectionAborted))
113            }
114            #[cfg(feature = "hyper")]
115            RemoteTx::HyperWebSocket { tx, format, .. } => {
116                let data = match format {
117                    crate::meta::FrameSerializationFormat::Bincode => {
118                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
119                            tracing::warn!("failed to serialize message - {err}");
120                            NetworkError::IOError
121                        })?
122                    }
123                    format => {
124                        tracing::warn!("format not currently supported - {format:?}");
125                        return Err(NetworkError::IOError);
126                    }
127                };
128                let mut tx = tx.lock().await;
129                tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
130                    .await
131                    .map_err(|_| NetworkError::ConnectionAborted)
132            }
133            #[cfg(feature = "tokio-tungstenite")]
134            RemoteTx::TokioWebSocket { tx, format, .. } => {
135                let data = match format {
136                    crate::meta::FrameSerializationFormat::Bincode => {
137                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
138                            tracing::warn!("failed to serialize message - {err}");
139                            NetworkError::IOError
140                        })?
141                    }
142                    format => {
143                        tracing::warn!("format not currently supported - {format:?}");
144                        return Err(NetworkError::IOError);
145                    }
146                };
147                let mut tx = tx.lock().await;
148                tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
149                    .await
150                    .map_err(|_| NetworkError::ConnectionAborted)
151            }
152        }
153    }
154
155    pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
156        match self {
157            RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
158                Ok(()) => Poll::Ready(Ok(())),
159                Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
160                Err(TrySendError::Full(_)) => {
161                    wakers.add(cx.waker());
162                    Poll::Pending
163                }
164            },
165            RemoteTx::Stream { tx, work, wakers } => {
166                let mut tx_guard = match tx.clone().try_lock_owned() {
167                    Ok(lock) => lock,
168                    Err(_) => {
169                        wakers.add(cx.waker());
170                        return Poll::Pending;
171                    }
172                };
173                match tx_guard.poll_ready_unpin(cx) {
174                    Poll::Ready(Ok(())) => {}
175                    Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
176                    Poll::Pending => return Poll::Pending,
177                }
178                let mut job = Box::pin(async move {
179                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
180                        tracing::error!("failed to send remaining bytes for request - {}", err);
181                    }
182                });
183
184                // First we try to finish it synchronously
185                if job.as_mut().poll(cx).is_ready() {
186                    return Poll::Ready(Ok(()));
187                }
188
189                // Otherwise we push it to the driver which will block all future send
190                // operations until it finishes
191                work.send(job).map_err(|err| {
192                    tracing::error!("failed to send remaining bytes for request - {}", err);
193                    NetworkError::ConnectionAborted
194                })?;
195                Poll::Ready(Ok(()))
196            }
197            #[cfg(feature = "hyper")]
198            RemoteTx::HyperWebSocket {
199                tx,
200                format,
201                work,
202                wakers,
203                ..
204            } => {
205                let mut tx_guard = match tx.clone().try_lock_owned() {
206                    Ok(lock) => lock,
207                    Err(_) => {
208                        wakers.add(cx.waker());
209                        return Poll::Pending;
210                    }
211                };
212                match tx_guard.poll_ready_unpin(cx) {
213                    Poll::Ready(Ok(())) => {}
214                    Poll::Ready(Err(err)) => {
215                        tracing::warn!("failed to poll web socket for readiness - {err}");
216                        return Poll::Ready(Err(NetworkError::IOError));
217                    }
218                    Poll::Pending => return Poll::Pending,
219                }
220
221                let data = match format {
222                    crate::meta::FrameSerializationFormat::Bincode => {
223                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
224                            tracing::warn!("failed to serialize message - {err}");
225                            NetworkError::IOError
226                        })?
227                    }
228                    format => {
229                        tracing::warn!("format not currently supported - {format:?}");
230                        return Poll::Ready(Err(NetworkError::IOError));
231                    }
232                };
233
234                let mut job = Box::pin(async move {
235                    if let Err(err) = tx_guard
236                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
237                        .await
238                    {
239                        tracing::error!("failed to send remaining bytes for request - {}", err);
240                    }
241                });
242
243                // First we try to finish it synchronously
244                if job.as_mut().poll(cx).is_ready() {
245                    return Poll::Ready(Ok(()));
246                }
247
248                // Otherwise we push it to the driver which will block all future send
249                // operations until it finishes
250                work.send(job).map_err(|err| {
251                    tracing::error!("failed to send remaining bytes for request - {}", err);
252                    NetworkError::ConnectionAborted
253                })?;
254                Poll::Ready(Ok(()))
255            }
256            #[cfg(feature = "tokio-tungstenite")]
257            RemoteTx::TokioWebSocket {
258                tx,
259                format,
260                work,
261                wakers,
262                ..
263            } => {
264                let mut tx_guard = match tx.clone().try_lock_owned() {
265                    Ok(lock) => lock,
266                    Err(_) => {
267                        wakers.add(cx.waker());
268                        return Poll::Pending;
269                    }
270                };
271                match tx_guard.poll_ready_unpin(cx) {
272                    Poll::Ready(Ok(())) => {}
273                    Poll::Ready(Err(err)) => {
274                        tracing::warn!("failed to poll web socket for readiness - {err}");
275                        return Poll::Ready(Err(NetworkError::IOError));
276                    }
277                    Poll::Pending => return Poll::Pending,
278                }
279
280                let data = match format {
281                    crate::meta::FrameSerializationFormat::Bincode => {
282                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
283                            tracing::warn!("failed to serialize message - {err}");
284                            NetworkError::IOError
285                        })?
286                    }
287                    format => {
288                        tracing::warn!("format not currently supported - {format:?}");
289                        return Poll::Ready(Err(NetworkError::IOError));
290                    }
291                };
292
293                let mut job = Box::pin(async move {
294                    if let Err(err) = tx_guard
295                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
296                        .await
297                    {
298                        tracing::error!("failed to send remaining bytes for request - {}", err);
299                    }
300                });
301
302                // First we try to finish it synchronously
303                if job.as_mut().poll(cx).is_ready() {
304                    return Poll::Ready(Ok(()));
305                }
306
307                // Otherwise we push it to the driver which will block all future send
308                // operations until it finishes
309                work.send(job).map_err(|err| {
310                    tracing::error!("failed to send remaining bytes for request - {}", err);
311                    NetworkError::ConnectionAborted
312                })?;
313                Poll::Ready(Ok(()))
314            }
315        }
316    }
317
318    pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
319        match self {
320            RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
321                Ok(()) => Ok(()),
322                Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
323                Err(TrySendError::Full(req)) => {
324                    let tx = tx.clone();
325                    work.send(Box::pin(async move {
326                        tx.send(req).await.ok();
327                    }))
328                    .ok();
329                    Ok(())
330                }
331            },
332            RemoteTx::Stream { tx, work, .. } => {
333                let mut tx_guard = match tx.clone().try_lock_owned() {
334                    Ok(lock) => lock,
335                    Err(_) => {
336                        let tx = tx.clone();
337                        work.send(Box::pin(async move {
338                            let mut tx_guard = tx.lock().await;
339                            tx_guard.send(req).await.ok();
340                        }))
341                        .ok();
342                        return Ok(());
343                    }
344                };
345
346                let waker = NoopWaker::new_waker();
347                let mut cx = Context::from_waker(&waker);
348
349                let mut job = Box::pin(async move {
350                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
351                        tracing::error!("failed to send remaining bytes for request - {}", err);
352                    }
353                });
354
355                // First we try to finish it synchronously
356                if job.as_mut().poll(&mut cx).is_ready() {
357                    return Ok(());
358                }
359
360                // Otherwise we push it to the driver which will block all future send
361                // operations until it finishes
362                work.send(job).map_err(|err| {
363                    tracing::error!("failed to send remaining bytes for request - {}", err);
364                    NetworkError::ConnectionAborted
365                })?;
366                Ok(())
367            }
368            #[cfg(feature = "hyper")]
369            RemoteTx::HyperWebSocket {
370                tx, format, work, ..
371            } => {
372                let data = match format {
373                    crate::meta::FrameSerializationFormat::Bincode => {
374                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
375                            tracing::warn!("failed to serialize message - {err}");
376                            NetworkError::IOError
377                        })?
378                    }
379                    format => {
380                        tracing::warn!("format not currently supported - {format:?}");
381                        return Err(NetworkError::IOError);
382                    }
383                };
384
385                let mut tx_guard = match tx.clone().try_lock_owned() {
386                    Ok(lock) => lock,
387                    Err(_) => {
388                        let tx = tx.clone();
389                        work.send(Box::pin(async move {
390                            let mut tx_guard = tx.lock().await;
391                            tx_guard
392                                .send(hyper_tungstenite::tungstenite::Message::Binary(data))
393                                .await
394                                .ok();
395                        }))
396                        .ok();
397                        return Ok(());
398                    }
399                };
400
401                let waker = NoopWaker::new_waker();
402                let mut cx = Context::from_waker(&waker);
403
404                let mut job = Box::pin(async move {
405                    if let Err(err) = tx_guard
406                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
407                        .await
408                    {
409                        tracing::error!("failed to send remaining bytes for request - {}", err);
410                    }
411                });
412
413                // First we try to finish it synchronously
414                if job.as_mut().poll(&mut cx).is_ready() {
415                    return Ok(());
416                }
417
418                // Otherwise we push it to the driver which will block all future send
419                // operations until it finishes
420                work.send(job).map_err(|err| {
421                    tracing::error!("failed to send remaining bytes for request - {}", err);
422                    NetworkError::ConnectionAborted
423                })?;
424                Ok(())
425            }
426            #[cfg(feature = "tokio-tungstenite")]
427            RemoteTx::TokioWebSocket {
428                tx, format, work, ..
429            } => {
430                let data = match format {
431                    crate::meta::FrameSerializationFormat::Bincode => {
432                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
433                            tracing::warn!("failed to serialize message - {err}");
434                            NetworkError::IOError
435                        })?
436                    }
437                    format => {
438                        tracing::warn!("format not currently supported - {format:?}");
439                        return Err(NetworkError::IOError);
440                    }
441                };
442
443                let mut tx_guard = match tx.clone().try_lock_owned() {
444                    Ok(lock) => lock,
445                    Err(_) => {
446                        let tx = tx.clone();
447                        work.send(Box::pin(async move {
448                            let mut tx_guard = tx.lock().await;
449                            tx_guard
450                                .send(tokio_tungstenite::tungstenite::Message::Binary(data))
451                                .await
452                                .ok();
453                        }))
454                        .ok();
455                        return Ok(());
456                    }
457                };
458
459                let waker = NoopWaker::new_waker();
460                let mut cx = Context::from_waker(&waker);
461
462                let mut job = Box::pin(async move {
463                    if let Err(err) = tx_guard
464                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
465                        .await
466                    {
467                        tracing::error!("failed to send remaining bytes for request - {}", err);
468                    }
469                });
470
471                // First we try to finish it synchronously
472                if job.as_mut().poll(&mut cx).is_ready() {
473                    return Ok(());
474                }
475
476                // Otherwise we push it to the driver which will block all future send
477                // operations until it finishes
478                work.send(job).map_err(|err| {
479                    tracing::error!("failed to send remaining bytes for request - {}", err);
480                    NetworkError::ConnectionAborted
481                })?;
482                Ok(())
483            }
484        }
485    }
486}
487
488#[derive(derive_more::Debug)]
489pub(crate) enum RemoteRx<T>
490where
491    T: serde::de::DeserializeOwned,
492{
493    Mpsc {
494        rx: mpsc::Receiver<T>,
495        wakers: RemoteTxWakers,
496    },
497    Stream {
498        #[debug(ignore)]
499        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
500    },
501    #[cfg(feature = "hyper")]
502    HyperWebSocket {
503        rx: futures_util::stream::SplitStream<
504            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
505        >,
506        format: crate::meta::FrameSerializationFormat,
507    },
508    #[cfg(feature = "tokio-tungstenite")]
509    TokioWebSocket {
510        rx: futures_util::stream::SplitStream<
511            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
512        >,
513        format: crate::meta::FrameSerializationFormat,
514    },
515}
516impl<T> RemoteRx<T>
517where
518    T: serde::de::DeserializeOwned,
519{
520    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
521        loop {
522            return match self {
523                RemoteRx::Mpsc { rx, wakers } => {
524                    let ret = Pin::new(rx).poll_recv(cx);
525                    if ret.is_ready() {
526                        wakers.wake();
527                    }
528                    ret
529                }
530                RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
531                    Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
532                    Poll::Ready(Some(Err(err))) => {
533                        tracing::debug!("failed to read from channel - {}", err);
534                        Poll::Ready(None)
535                    }
536                    Poll::Ready(None) => Poll::Ready(None),
537                    Poll::Pending => Poll::Pending,
538                },
539                #[cfg(feature = "hyper")]
540                RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
541                    Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
542                        match format {
543                            crate::meta::FrameSerializationFormat::Bincode => {
544                                return match bincode::serde::decode_from_slice(
545                                    &msg,
546                                    config::legacy(),
547                                ) {
548                                    Ok((msg, _)) => Poll::Ready(Some(msg)),
549                                    Err(err) => {
550                                        tracing::warn!("failed to deserialize message - {}", err);
551                                        continue;
552                                    }
553                                };
554                            }
555                            format => {
556                                tracing::warn!("format not currently supported - {format:?}");
557                                continue;
558                            }
559                        }
560                    }
561                    Poll::Ready(Some(Ok(msg))) => {
562                        tracing::warn!("unsupported message from channel - {}", msg);
563                        continue;
564                    }
565                    Poll::Ready(Some(Err(err))) => {
566                        tracing::debug!("failed to read from channel - {}", err);
567                        Poll::Ready(None)
568                    }
569                    Poll::Ready(None) => Poll::Ready(None),
570                    Poll::Pending => Poll::Pending,
571                },
572                #[cfg(feature = "tokio-tungstenite")]
573                RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
574                    Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
575                        match format {
576                            crate::meta::FrameSerializationFormat::Bincode => {
577                                return match bincode::serde::decode_from_slice(
578                                    &msg,
579                                    config::legacy(),
580                                ) {
581                                    Ok((msg, _)) => Poll::Ready(Some(msg)),
582                                    Err(err) => {
583                                        tracing::warn!("failed to deserialize message - {}", err);
584                                        continue;
585                                    }
586                                };
587                            }
588                            format => {
589                                tracing::warn!("format not currently supported - {format:?}");
590                                continue;
591                            }
592                        }
593                    }
594                    Poll::Ready(Some(Ok(msg))) => {
595                        tracing::warn!("unsupported message from channel - {}", msg);
596                        continue;
597                    }
598                    Poll::Ready(Some(Err(err))) => {
599                        tracing::debug!("failed to read from channel - {}", err);
600                        Poll::Ready(None)
601                    }
602                    Poll::Ready(None) => Poll::Ready(None),
603                    Poll::Pending => Poll::Pending,
604                },
605            };
606        }
607    }
608}
609
/// Waker whose wake-up does nothing.
///
/// Used to poll a future once without arranging for it to be polled again.
struct NoopWaker;

impl std::task::Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Deliberately empty: nothing is scheduled when this waker fires.
    }
}

impl NoopWaker {
    /// Builds a [`Waker`] backed by a fresh `NoopWaker`.
    fn new_waker() -> Waker {
        let inner: Arc<NoopWaker> = Arc::new(NoopWaker);
        inner.into()
    }
}