virtual_net/
rx_tx.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use bincode::config;
9use bytes::Bytes;
10use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
11#[cfg(feature = "hyper")]
12use hyper_util::rt::tokio::TokioIo;
13use serde::Serialize;
14#[cfg(feature = "tokio-tungstenite")]
15use tokio::net::TcpStream;
16use tokio::sync::{
17    mpsc::{self, error::TrySendError},
18    oneshot,
19};
20
21use crate::{NetworkError, io_err_into_net_error};
22
/// A shared, de-duplicated list of task wakers belonging to senders that are
/// waiting on a `RemoteTx` (e.g. for channel capacity or the sink lock).
///
/// Clones share the same underlying list.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}
impl RemoteTxWakers {
    /// Registers `waker` so the next [`RemoteTxWakers::wake`] call notifies it.
    ///
    /// A waker that `will_wake` the same task as an already registered one is
    /// not added a second time.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes and removes every registered waker.
    ///
    /// The list is moved out of the mutex before any waker runs. A `wake`
    /// implementation that re-registers (calls `add`) on the same thread
    /// therefore cannot deadlock on the non-reentrant mutex, and a panicking
    /// waker cannot poison it.
    pub fn wake(&self) {
        // The temporary guard is dropped at the end of this statement.
        let pending = std::mem::take(&mut *self.lock());
        pending.into_iter().for_each(Waker::wake);
    }

    /// Locks the waker list, recovering from poisoning. A plain `Vec<Waker>`
    /// has no invariant that an interrupted holder could have broken.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
39
40pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
41
42#[derive(derive_more::Debug)]
43pub(crate) enum RemoteTx<T>
44where
45    T: Serialize,
46{
47    Mpsc {
48        tx: mpsc::Sender<T>,
49        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
50        wakers: RemoteTxWakers,
51    },
52    Stream {
53        #[debug(ignore)]
54        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
55        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
56        wakers: RemoteTxWakers,
57    },
58    #[cfg(feature = "hyper")]
59    HyperWebSocket {
60        tx: Arc<
61            tokio::sync::Mutex<
62                futures_util::stream::SplitSink<
63                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
64                    hyper_tungstenite::tungstenite::Message,
65                >,
66            >,
67        >,
68        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
69        wakers: RemoteTxWakers,
70        format: crate::meta::FrameSerializationFormat,
71    },
72    #[cfg(feature = "tokio-tungstenite")]
73    TokioWebSocket {
74        tx: Arc<
75            tokio::sync::Mutex<
76                futures_util::stream::SplitSink<
77                    tokio_tungstenite::WebSocketStream<
78                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
79                    >,
80                    tokio_tungstenite::tungstenite::Message,
81                >,
82            >,
83        >,
84        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
85        wakers: RemoteTxWakers,
86        format: crate::meta::FrameSerializationFormat,
87    },
88}
89impl<T> RemoteTx<T>
90where
91    T: Serialize + Send + Sync + 'static,
92{
93    pub(crate) async fn send(&self, req: T) -> Result<()> {
94        match self {
95            RemoteTx::Mpsc { tx, .. } => tx
96                .send(req)
97                .await
98                .map_err(|_| NetworkError::ConnectionAborted),
99            RemoteTx::Stream { tx, work, .. } => {
100                let (tx_done, rx_done) = oneshot::channel();
101                let tx = tx.clone();
102                work.send(Box::pin(async move {
103                    let job = async {
104                        let mut tx_guard = tx.lock_owned().await;
105                        tx_guard.send(req).await.map_err(io_err_into_net_error)
106                    };
107                    tx_done.send(job.await).ok();
108                }))
109                .map_err(|_| NetworkError::ConnectionAborted)?;
110
111                rx_done
112                    .await
113                    .unwrap_or(Err(NetworkError::ConnectionAborted))
114            }
115            #[cfg(feature = "hyper")]
116            RemoteTx::HyperWebSocket { tx, format, .. } => {
117                let data = match format {
118                    crate::meta::FrameSerializationFormat::Bincode => {
119                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
120                            tracing::warn!("failed to serialize message - {err}");
121                            NetworkError::IOError
122                        })?
123                    }
124                    format => {
125                        tracing::warn!("format not currently supported - {format:?}");
126                        return Err(NetworkError::IOError);
127                    }
128                };
129                let mut tx = tx.lock().await;
130                tx.send(hyper_tungstenite::tungstenite::Message::Binary(
131                    Bytes::from_owner(data),
132                ))
133                .await
134                .map_err(|_| NetworkError::ConnectionAborted)
135            }
136            #[cfg(feature = "tokio-tungstenite")]
137            RemoteTx::TokioWebSocket { tx, format, .. } => {
138                let data = match format {
139                    crate::meta::FrameSerializationFormat::Bincode => {
140                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
141                            tracing::warn!("failed to serialize message - {err}");
142                            NetworkError::IOError
143                        })?
144                    }
145                    format => {
146                        tracing::warn!("format not currently supported - {format:?}");
147                        return Err(NetworkError::IOError);
148                    }
149                };
150                let mut tx = tx.lock().await;
151                tx.send(tokio_tungstenite::tungstenite::Message::Binary(
152                    Bytes::from_owner(data),
153                ))
154                .await
155                .map_err(|_| NetworkError::ConnectionAborted)
156            }
157        }
158    }
159
160    pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
161        match self {
162            RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
163                Ok(()) => Poll::Ready(Ok(())),
164                Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
165                Err(TrySendError::Full(_)) => {
166                    wakers.add(cx.waker());
167                    Poll::Pending
168                }
169            },
170            RemoteTx::Stream { tx, work, wakers } => {
171                let mut tx_guard = match tx.clone().try_lock_owned() {
172                    Ok(lock) => lock,
173                    Err(_) => {
174                        wakers.add(cx.waker());
175                        return Poll::Pending;
176                    }
177                };
178                match tx_guard.poll_ready_unpin(cx) {
179                    Poll::Ready(Ok(())) => {}
180                    Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
181                    Poll::Pending => return Poll::Pending,
182                }
183                let mut job = Box::pin(async move {
184                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
185                        tracing::error!("failed to send remaining bytes for request - {}", err);
186                    }
187                });
188
189                // First we try to finish it synchronously
190                if job.as_mut().poll(cx).is_ready() {
191                    return Poll::Ready(Ok(()));
192                }
193
194                // Otherwise we push it to the driver which will block all future send
195                // operations until it finishes
196                work.send(job).map_err(|err| {
197                    tracing::error!("failed to send remaining bytes for request - {}", err);
198                    NetworkError::ConnectionAborted
199                })?;
200                Poll::Ready(Ok(()))
201            }
202            #[cfg(feature = "hyper")]
203            RemoteTx::HyperWebSocket {
204                tx,
205                format,
206                work,
207                wakers,
208                ..
209            } => {
210                let mut tx_guard = match tx.clone().try_lock_owned() {
211                    Ok(lock) => lock,
212                    Err(_) => {
213                        wakers.add(cx.waker());
214                        return Poll::Pending;
215                    }
216                };
217                match tx_guard.poll_ready_unpin(cx) {
218                    Poll::Ready(Ok(())) => {}
219                    Poll::Ready(Err(err)) => {
220                        tracing::warn!("failed to poll web socket for readiness - {err}");
221                        return Poll::Ready(Err(NetworkError::IOError));
222                    }
223                    Poll::Pending => return Poll::Pending,
224                }
225
226                let data = match format {
227                    crate::meta::FrameSerializationFormat::Bincode => {
228                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
229                            tracing::warn!("failed to serialize message - {err}");
230                            NetworkError::IOError
231                        })?
232                    }
233                    format => {
234                        tracing::warn!("format not currently supported - {format:?}");
235                        return Poll::Ready(Err(NetworkError::IOError));
236                    }
237                };
238
239                let mut job = Box::pin(async move {
240                    if let Err(err) = tx_guard
241                        .send(hyper_tungstenite::tungstenite::Message::Binary(
242                            Bytes::from_owner(data),
243                        ))
244                        .await
245                    {
246                        tracing::error!("failed to send remaining bytes for request - {}", err);
247                    }
248                });
249
250                // First we try to finish it synchronously
251                if job.as_mut().poll(cx).is_ready() {
252                    return Poll::Ready(Ok(()));
253                }
254
255                // Otherwise we push it to the driver which will block all future send
256                // operations until it finishes
257                work.send(job).map_err(|err| {
258                    tracing::error!("failed to send remaining bytes for request - {}", err);
259                    NetworkError::ConnectionAborted
260                })?;
261                Poll::Ready(Ok(()))
262            }
263            #[cfg(feature = "tokio-tungstenite")]
264            RemoteTx::TokioWebSocket {
265                tx,
266                format,
267                work,
268                wakers,
269                ..
270            } => {
271                let mut tx_guard = match tx.clone().try_lock_owned() {
272                    Ok(lock) => lock,
273                    Err(_) => {
274                        wakers.add(cx.waker());
275                        return Poll::Pending;
276                    }
277                };
278                match tx_guard.poll_ready_unpin(cx) {
279                    Poll::Ready(Ok(())) => {}
280                    Poll::Ready(Err(err)) => {
281                        tracing::warn!("failed to poll web socket for readiness - {err}");
282                        return Poll::Ready(Err(NetworkError::IOError));
283                    }
284                    Poll::Pending => return Poll::Pending,
285                }
286
287                let data = match format {
288                    crate::meta::FrameSerializationFormat::Bincode => {
289                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
290                            tracing::warn!("failed to serialize message - {err}");
291                            NetworkError::IOError
292                        })?
293                    }
294                    format => {
295                        tracing::warn!("format not currently supported - {format:?}");
296                        return Poll::Ready(Err(NetworkError::IOError));
297                    }
298                };
299
300                let mut job = Box::pin(async move {
301                    if let Err(err) = tx_guard
302                        .send(tokio_tungstenite::tungstenite::Message::Binary(
303                            Bytes::from_owner(data),
304                        ))
305                        .await
306                    {
307                        tracing::error!("failed to send remaining bytes for request - {}", err);
308                    }
309                });
310
311                // First we try to finish it synchronously
312                if job.as_mut().poll(cx).is_ready() {
313                    return Poll::Ready(Ok(()));
314                }
315
316                // Otherwise we push it to the driver which will block all future send
317                // operations until it finishes
318                work.send(job).map_err(|err| {
319                    tracing::error!("failed to send remaining bytes for request - {}", err);
320                    NetworkError::ConnectionAborted
321                })?;
322                Poll::Ready(Ok(()))
323            }
324        }
325    }
326
327    pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
328        match self {
329            RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
330                Ok(()) => Ok(()),
331                Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
332                Err(TrySendError::Full(req)) => {
333                    let tx = tx.clone();
334                    work.send(Box::pin(async move {
335                        tx.send(req).await.ok();
336                    }))
337                    .ok();
338                    Ok(())
339                }
340            },
341            RemoteTx::Stream { tx, work, .. } => {
342                let mut tx_guard = match tx.clone().try_lock_owned() {
343                    Ok(lock) => lock,
344                    Err(_) => {
345                        let tx = tx.clone();
346                        work.send(Box::pin(async move {
347                            let mut tx_guard = tx.lock().await;
348                            tx_guard.send(req).await.ok();
349                        }))
350                        .ok();
351                        return Ok(());
352                    }
353                };
354
355                let waker = NoopWaker::new_waker();
356                let mut cx = Context::from_waker(&waker);
357
358                let mut job = Box::pin(async move {
359                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
360                        tracing::error!("failed to send remaining bytes for request - {}", err);
361                    }
362                });
363
364                // First we try to finish it synchronously
365                if job.as_mut().poll(&mut cx).is_ready() {
366                    return Ok(());
367                }
368
369                // Otherwise we push it to the driver which will block all future send
370                // operations until it finishes
371                work.send(job).map_err(|err| {
372                    tracing::error!("failed to send remaining bytes for request - {}", err);
373                    NetworkError::ConnectionAborted
374                })?;
375                Ok(())
376            }
377            #[cfg(feature = "hyper")]
378            RemoteTx::HyperWebSocket {
379                tx, format, work, ..
380            } => {
381                let data = match format {
382                    crate::meta::FrameSerializationFormat::Bincode => {
383                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
384                            tracing::warn!("failed to serialize message - {err}");
385                            NetworkError::IOError
386                        })?
387                    }
388                    format => {
389                        tracing::warn!("format not currently supported - {format:?}");
390                        return Err(NetworkError::IOError);
391                    }
392                };
393
394                let mut tx_guard = match tx.clone().try_lock_owned() {
395                    Ok(lock) => lock,
396                    Err(_) => {
397                        let tx = tx.clone();
398                        work.send(Box::pin(async move {
399                            let mut tx_guard = tx.lock().await;
400                            tx_guard
401                                .send(hyper_tungstenite::tungstenite::Message::Binary(
402                                    Bytes::from_owner(data),
403                                ))
404                                .await
405                                .ok();
406                        }))
407                        .ok();
408                        return Ok(());
409                    }
410                };
411
412                let waker = NoopWaker::new_waker();
413                let mut cx = Context::from_waker(&waker);
414
415                let mut job = Box::pin(async move {
416                    if let Err(err) = tx_guard
417                        .send(hyper_tungstenite::tungstenite::Message::Binary(
418                            Bytes::from_owner(data),
419                        ))
420                        .await
421                    {
422                        tracing::error!("failed to send remaining bytes for request - {}", err);
423                    }
424                });
425
426                // First we try to finish it synchronously
427                if job.as_mut().poll(&mut cx).is_ready() {
428                    return Ok(());
429                }
430
431                // Otherwise we push it to the driver which will block all future send
432                // operations until it finishes
433                work.send(job).map_err(|err| {
434                    tracing::error!("failed to send remaining bytes for request - {}", err);
435                    NetworkError::ConnectionAborted
436                })?;
437                Ok(())
438            }
439            #[cfg(feature = "tokio-tungstenite")]
440            RemoteTx::TokioWebSocket {
441                tx, format, work, ..
442            } => {
443                let data = match format {
444                    crate::meta::FrameSerializationFormat::Bincode => {
445                        bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
446                            tracing::warn!("failed to serialize message - {err}");
447                            NetworkError::IOError
448                        })?
449                    }
450                    format => {
451                        tracing::warn!("format not currently supported - {format:?}");
452                        return Err(NetworkError::IOError);
453                    }
454                };
455
456                let mut tx_guard = match tx.clone().try_lock_owned() {
457                    Ok(lock) => lock,
458                    Err(_) => {
459                        let tx = tx.clone();
460                        work.send(Box::pin(async move {
461                            let mut tx_guard = tx.lock().await;
462                            tx_guard
463                                .send(tokio_tungstenite::tungstenite::Message::Binary(
464                                    Bytes::from_owner(data),
465                                ))
466                                .await
467                                .ok();
468                        }))
469                        .ok();
470                        return Ok(());
471                    }
472                };
473
474                let waker = NoopWaker::new_waker();
475                let mut cx = Context::from_waker(&waker);
476
477                let mut job = Box::pin(async move {
478                    if let Err(err) = tx_guard
479                        .send(tokio_tungstenite::tungstenite::Message::Binary(
480                            Bytes::from_owner(data),
481                        ))
482                        .await
483                    {
484                        tracing::error!("failed to send remaining bytes for request - {}", err);
485                    }
486                });
487
488                // First we try to finish it synchronously
489                if job.as_mut().poll(&mut cx).is_ready() {
490                    return Ok(());
491                }
492
493                // Otherwise we push it to the driver which will block all future send
494                // operations until it finishes
495                work.send(job).map_err(|err| {
496                    tracing::error!("failed to send remaining bytes for request - {}", err);
497                    NetworkError::ConnectionAborted
498                })?;
499                Ok(())
500            }
501        }
502    }
503}
504
505#[derive(derive_more::Debug)]
506pub(crate) enum RemoteRx<T>
507where
508    T: serde::de::DeserializeOwned,
509{
510    Mpsc {
511        rx: mpsc::Receiver<T>,
512        wakers: RemoteTxWakers,
513    },
514    Stream {
515        #[debug(ignore)]
516        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
517    },
518    #[cfg(feature = "hyper")]
519    HyperWebSocket {
520        rx: futures_util::stream::SplitStream<
521            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
522        >,
523        format: crate::meta::FrameSerializationFormat,
524    },
525    #[cfg(feature = "tokio-tungstenite")]
526    TokioWebSocket {
527        rx: futures_util::stream::SplitStream<
528            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
529        >,
530        format: crate::meta::FrameSerializationFormat,
531    },
532}
533impl<T> RemoteRx<T>
534where
535    T: serde::de::DeserializeOwned,
536{
537    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
538        loop {
539            return match self {
540                RemoteRx::Mpsc { rx, wakers } => {
541                    let ret = Pin::new(rx).poll_recv(cx);
542                    if ret.is_ready() {
543                        wakers.wake();
544                    }
545                    ret
546                }
547                RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
548                    Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
549                    Poll::Ready(Some(Err(err))) => {
550                        tracing::debug!("failed to read from channel - {}", err);
551                        Poll::Ready(None)
552                    }
553                    Poll::Ready(None) => Poll::Ready(None),
554                    Poll::Pending => Poll::Pending,
555                },
556                #[cfg(feature = "hyper")]
557                RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
558                    Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
559                        match format {
560                            crate::meta::FrameSerializationFormat::Bincode => {
561                                return match bincode::serde::decode_from_slice(
562                                    &msg,
563                                    config::legacy(),
564                                ) {
565                                    Ok((msg, _)) => Poll::Ready(Some(msg)),
566                                    Err(err) => {
567                                        tracing::warn!("failed to deserialize message - {}", err);
568                                        continue;
569                                    }
570                                };
571                            }
572                            format => {
573                                tracing::warn!("format not currently supported - {format:?}");
574                                continue;
575                            }
576                        }
577                    }
578                    Poll::Ready(Some(Ok(msg))) => {
579                        tracing::warn!("unsupported message from channel - {}", msg);
580                        continue;
581                    }
582                    Poll::Ready(Some(Err(err))) => {
583                        tracing::debug!("failed to read from channel - {}", err);
584                        Poll::Ready(None)
585                    }
586                    Poll::Ready(None) => Poll::Ready(None),
587                    Poll::Pending => Poll::Pending,
588                },
589                #[cfg(feature = "tokio-tungstenite")]
590                RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
591                    Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
592                        match format {
593                            crate::meta::FrameSerializationFormat::Bincode => {
594                                return match bincode::serde::decode_from_slice(
595                                    &msg,
596                                    config::legacy(),
597                                ) {
598                                    Ok((msg, _)) => Poll::Ready(Some(msg)),
599                                    Err(err) => {
600                                        tracing::warn!("failed to deserialize message - {}", err);
601                                        continue;
602                                    }
603                                };
604                            }
605                            format => {
606                                tracing::warn!("format not currently supported - {format:?}");
607                                continue;
608                            }
609                        }
610                    }
611                    Poll::Ready(Some(Ok(msg))) => {
612                        tracing::warn!("unsupported message from channel - {}", msg);
613                        continue;
614                    }
615                    Poll::Ready(Some(Err(err))) => {
616                        tracing::debug!("failed to read from channel - {}", err);
617                        Poll::Ready(None)
618                    }
619                    Poll::Ready(None) => Poll::Ready(None),
620                    Poll::Pending => Poll::Pending,
621                },
622            };
623        }
624    }
625}
626
/// A [`Waker`] whose wake-ups are ignored.
///
/// `RemoteTx::send_with_driver` uses it to poll a write once from a
/// synchronous context before handing any unfinished work to its driver.
struct NoopWaker;

impl NoopWaker {
    /// Builds a fresh waker backed by a `NoopWaker`.
    fn new_waker() -> Waker {
        Arc::new(NoopWaker).into()
    }
}

impl std::task::Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Deliberately empty: there is no task to notify.
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Deliberately empty as well; skips the Arc clone the default would make.
    }
}