1use std::{
2    pin::Pin,
3    sync::{Arc, Mutex},
4    task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
9#[cfg(feature = "hyper")]
10use hyper_util::rt::tokio::TokioIo;
11use serde::Serialize;
12#[cfg(feature = "tokio-tungstenite")]
13use tokio::net::TcpStream;
14use tokio::sync::{
15    mpsc::{self, error::TrySendError},
16    oneshot,
17};
18
19use crate::{NetworkError, io_err_into_net_error};
20
/// Registry of task wakers parked until a remote transmit half can make progress.
///
/// Clones share the same underlying list, so a waker registered through one
/// handle is woken through any other.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker` to be woken by the next call to [`RemoteTxWakers::wake`].
    ///
    /// A waker that would wake the same task as one already registered
    /// (according to [`Waker::will_wake`]) is not stored twice.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes, and unregisters, every waker added since the previous call.
    ///
    /// The list is moved out of the mutex (and the lock released) before any
    /// waker runs, so a waker that re-enters this registry - for example by
    /// registering itself again - cannot deadlock on the lock, and a panicking
    /// waker cannot poison it.
    pub fn wake(&self) {
        let pending = std::mem::take(&mut *self.lock());
        pending.into_iter().for_each(Waker::wake);
    }

    /// Locks the waker list, recovering the data if the mutex was poisoned.
    ///
    /// The `Vec` is only ever pushed to or taken wholesale, so a panic while the
    /// lock was held cannot leave it half-updated.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
37
38pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
39
40#[derive(derive_more::Debug)]
41pub(crate) enum RemoteTx<T>
42where
43    T: Serialize,
44{
45    Mpsc {
46        tx: mpsc::Sender<T>,
47        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
48        wakers: RemoteTxWakers,
49    },
50    Stream {
51        #[debug(ignore)]
52        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
53        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
54        wakers: RemoteTxWakers,
55    },
56    #[cfg(feature = "hyper")]
57    HyperWebSocket {
58        tx: Arc<
59            tokio::sync::Mutex<
60                futures_util::stream::SplitSink<
61                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
62                    hyper_tungstenite::tungstenite::Message,
63                >,
64            >,
65        >,
66        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
67        wakers: RemoteTxWakers,
68        format: crate::meta::FrameSerializationFormat,
69    },
70    #[cfg(feature = "tokio-tungstenite")]
71    TokioWebSocket {
72        tx: Arc<
73            tokio::sync::Mutex<
74                futures_util::stream::SplitSink<
75                    tokio_tungstenite::WebSocketStream<
76                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
77                    >,
78                    tokio_tungstenite::tungstenite::Message,
79                >,
80            >,
81        >,
82        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
83        wakers: RemoteTxWakers,
84        format: crate::meta::FrameSerializationFormat,
85    },
86}
87impl<T> RemoteTx<T>
88where
89    T: Serialize + Send + Sync + 'static,
90{
91    pub(crate) async fn send(&self, req: T) -> Result<()> {
92        match self {
93            RemoteTx::Mpsc { tx, .. } => tx
94                .send(req)
95                .await
96                .map_err(|_| NetworkError::ConnectionAborted),
97            RemoteTx::Stream { tx, work, .. } => {
98                let (tx_done, rx_done) = oneshot::channel();
99                let tx = tx.clone();
100                work.send(Box::pin(async move {
101                    let job = async {
102                        let mut tx_guard = tx.lock_owned().await;
103                        tx_guard.send(req).await.map_err(io_err_into_net_error)
104                    };
105                    tx_done.send(job.await).ok();
106                }))
107                .map_err(|_| NetworkError::ConnectionAborted)?;
108
109                rx_done
110                    .await
111                    .unwrap_or(Err(NetworkError::ConnectionAborted))
112            }
113            #[cfg(feature = "hyper")]
114            RemoteTx::HyperWebSocket { tx, format, .. } => {
115                let data = match format {
116                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
117                        .map_err(|err| {
118                            tracing::warn!("failed to serialize message - {err}");
119                            NetworkError::IOError
120                        })?,
121                    format => {
122                        tracing::warn!("format not currently supported - {format:?}");
123                        return Err(NetworkError::IOError);
124                    }
125                };
126                let mut tx = tx.lock().await;
127                tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
128                    .await
129                    .map_err(|_| NetworkError::ConnectionAborted)
130            }
131            #[cfg(feature = "tokio-tungstenite")]
132            RemoteTx::TokioWebSocket { tx, format, .. } => {
133                let data = match format {
134                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
135                        .map_err(|err| {
136                            tracing::warn!("failed to serialize message - {err}");
137                            NetworkError::IOError
138                        })?,
139                    format => {
140                        tracing::warn!("format not currently supported - {format:?}");
141                        return Err(NetworkError::IOError);
142                    }
143                };
144                let mut tx = tx.lock().await;
145                tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
146                    .await
147                    .map_err(|_| NetworkError::ConnectionAborted)
148            }
149        }
150    }
151
152    pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
153        match self {
154            RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
155                Ok(()) => Poll::Ready(Ok(())),
156                Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
157                Err(TrySendError::Full(_)) => {
158                    wakers.add(cx.waker());
159                    Poll::Pending
160                }
161            },
162            RemoteTx::Stream { tx, work, wakers } => {
163                let mut tx_guard = match tx.clone().try_lock_owned() {
164                    Ok(lock) => lock,
165                    Err(_) => {
166                        wakers.add(cx.waker());
167                        return Poll::Pending;
168                    }
169                };
170                match tx_guard.poll_ready_unpin(cx) {
171                    Poll::Ready(Ok(())) => {}
172                    Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
173                    Poll::Pending => return Poll::Pending,
174                }
175                let mut job = Box::pin(async move {
176                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
177                        tracing::error!("failed to send remaining bytes for request - {}", err);
178                    }
179                });
180
181                if job.as_mut().poll(cx).is_ready() {
183                    return Poll::Ready(Ok(()));
184                }
185
186                work.send(job).map_err(|err| {
189                    tracing::error!("failed to send remaining bytes for request - {}", err);
190                    NetworkError::ConnectionAborted
191                })?;
192                Poll::Ready(Ok(()))
193            }
194            #[cfg(feature = "hyper")]
195            RemoteTx::HyperWebSocket {
196                tx,
197                format,
198                work,
199                wakers,
200                ..
201            } => {
202                let mut tx_guard = match tx.clone().try_lock_owned() {
203                    Ok(lock) => lock,
204                    Err(_) => {
205                        wakers.add(cx.waker());
206                        return Poll::Pending;
207                    }
208                };
209                match tx_guard.poll_ready_unpin(cx) {
210                    Poll::Ready(Ok(())) => {}
211                    Poll::Ready(Err(err)) => {
212                        tracing::warn!("failed to poll web socket for readiness - {err}");
213                        return Poll::Ready(Err(NetworkError::IOError));
214                    }
215                    Poll::Pending => return Poll::Pending,
216                }
217
218                let data = match format {
219                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
220                        .map_err(|err| {
221                            tracing::warn!("failed to serialize message - {err}");
222                            NetworkError::IOError
223                        })?,
224                    format => {
225                        tracing::warn!("format not currently supported - {format:?}");
226                        return Poll::Ready(Err(NetworkError::IOError));
227                    }
228                };
229
230                let mut job = Box::pin(async move {
231                    if let Err(err) = tx_guard
232                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
233                        .await
234                    {
235                        tracing::error!("failed to send remaining bytes for request - {}", err);
236                    }
237                });
238
239                if job.as_mut().poll(cx).is_ready() {
241                    return Poll::Ready(Ok(()));
242                }
243
244                work.send(job).map_err(|err| {
247                    tracing::error!("failed to send remaining bytes for request - {}", err);
248                    NetworkError::ConnectionAborted
249                })?;
250                Poll::Ready(Ok(()))
251            }
252            #[cfg(feature = "tokio-tungstenite")]
253            RemoteTx::TokioWebSocket {
254                tx,
255                format,
256                work,
257                wakers,
258                ..
259            } => {
260                let mut tx_guard = match tx.clone().try_lock_owned() {
261                    Ok(lock) => lock,
262                    Err(_) => {
263                        wakers.add(cx.waker());
264                        return Poll::Pending;
265                    }
266                };
267                match tx_guard.poll_ready_unpin(cx) {
268                    Poll::Ready(Ok(())) => {}
269                    Poll::Ready(Err(err)) => {
270                        tracing::warn!("failed to poll web socket for readiness - {err}");
271                        return Poll::Ready(Err(NetworkError::IOError));
272                    }
273                    Poll::Pending => return Poll::Pending,
274                }
275
276                let data = match format {
277                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
278                        .map_err(|err| {
279                            tracing::warn!("failed to serialize message - {err}");
280                            NetworkError::IOError
281                        })?,
282                    format => {
283                        tracing::warn!("format not currently supported - {format:?}");
284                        return Poll::Ready(Err(NetworkError::IOError));
285                    }
286                };
287
288                let mut job = Box::pin(async move {
289                    if let Err(err) = tx_guard
290                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
291                        .await
292                    {
293                        tracing::error!("failed to send remaining bytes for request - {}", err);
294                    }
295                });
296
297                if job.as_mut().poll(cx).is_ready() {
299                    return Poll::Ready(Ok(()));
300                }
301
302                work.send(job).map_err(|err| {
305                    tracing::error!("failed to send remaining bytes for request - {}", err);
306                    NetworkError::ConnectionAborted
307                })?;
308                Poll::Ready(Ok(()))
309            }
310        }
311    }
312
313    pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
314        match self {
315            RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
316                Ok(()) => Ok(()),
317                Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
318                Err(TrySendError::Full(req)) => {
319                    let tx = tx.clone();
320                    work.send(Box::pin(async move {
321                        tx.send(req).await.ok();
322                    }))
323                    .ok();
324                    Ok(())
325                }
326            },
327            RemoteTx::Stream { tx, work, .. } => {
328                let mut tx_guard = match tx.clone().try_lock_owned() {
329                    Ok(lock) => lock,
330                    Err(_) => {
331                        let tx = tx.clone();
332                        work.send(Box::pin(async move {
333                            let mut tx_guard = tx.lock().await;
334                            tx_guard.send(req).await.ok();
335                        }))
336                        .ok();
337                        return Ok(());
338                    }
339                };
340
341                let waker = NoopWaker::new_waker();
342                let mut cx = Context::from_waker(&waker);
343
344                let mut job = Box::pin(async move {
345                    if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
346                        tracing::error!("failed to send remaining bytes for request - {}", err);
347                    }
348                });
349
350                if job.as_mut().poll(&mut cx).is_ready() {
352                    return Ok(());
353                }
354
355                work.send(job).map_err(|err| {
358                    tracing::error!("failed to send remaining bytes for request - {}", err);
359                    NetworkError::ConnectionAborted
360                })?;
361                Ok(())
362            }
363            #[cfg(feature = "hyper")]
364            RemoteTx::HyperWebSocket {
365                tx, format, work, ..
366            } => {
367                let data = match format {
368                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
369                        .map_err(|err| {
370                            tracing::warn!("failed to serialize message - {err}");
371                            NetworkError::IOError
372                        })?,
373                    format => {
374                        tracing::warn!("format not currently supported - {format:?}");
375                        return Err(NetworkError::IOError);
376                    }
377                };
378
379                let mut tx_guard = match tx.clone().try_lock_owned() {
380                    Ok(lock) => lock,
381                    Err(_) => {
382                        let tx = tx.clone();
383                        work.send(Box::pin(async move {
384                            let mut tx_guard = tx.lock().await;
385                            tx_guard
386                                .send(hyper_tungstenite::tungstenite::Message::Binary(data))
387                                .await
388                                .ok();
389                        }))
390                        .ok();
391                        return Ok(());
392                    }
393                };
394
395                let waker = NoopWaker::new_waker();
396                let mut cx = Context::from_waker(&waker);
397
398                let mut job = Box::pin(async move {
399                    if let Err(err) = tx_guard
400                        .send(hyper_tungstenite::tungstenite::Message::Binary(data))
401                        .await
402                    {
403                        tracing::error!("failed to send remaining bytes for request - {}", err);
404                    }
405                });
406
407                if job.as_mut().poll(&mut cx).is_ready() {
409                    return Ok(());
410                }
411
412                work.send(job).map_err(|err| {
415                    tracing::error!("failed to send remaining bytes for request - {}", err);
416                    NetworkError::ConnectionAborted
417                })?;
418                Ok(())
419            }
420            #[cfg(feature = "tokio-tungstenite")]
421            RemoteTx::TokioWebSocket {
422                tx, format, work, ..
423            } => {
424                let data = match format {
425                    crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
426                        .map_err(|err| {
427                            tracing::warn!("failed to serialize message - {err}");
428                            NetworkError::IOError
429                        })?,
430                    format => {
431                        tracing::warn!("format not currently supported - {format:?}");
432                        return Err(NetworkError::IOError);
433                    }
434                };
435
436                let mut tx_guard = match tx.clone().try_lock_owned() {
437                    Ok(lock) => lock,
438                    Err(_) => {
439                        let tx = tx.clone();
440                        work.send(Box::pin(async move {
441                            let mut tx_guard = tx.lock().await;
442                            tx_guard
443                                .send(tokio_tungstenite::tungstenite::Message::Binary(data))
444                                .await
445                                .ok();
446                        }))
447                        .ok();
448                        return Ok(());
449                    }
450                };
451
452                let waker = NoopWaker::new_waker();
453                let mut cx = Context::from_waker(&waker);
454
455                let mut job = Box::pin(async move {
456                    if let Err(err) = tx_guard
457                        .send(tokio_tungstenite::tungstenite::Message::Binary(data))
458                        .await
459                    {
460                        tracing::error!("failed to send remaining bytes for request - {}", err);
461                    }
462                });
463
464                if job.as_mut().poll(&mut cx).is_ready() {
466                    return Ok(());
467                }
468
469                work.send(job).map_err(|err| {
472                    tracing::error!("failed to send remaining bytes for request - {}", err);
473                    NetworkError::ConnectionAborted
474                })?;
475                Ok(())
476            }
477        }
478    }
479}
480
481#[derive(derive_more::Debug)]
482pub(crate) enum RemoteRx<T>
483where
484    T: serde::de::DeserializeOwned,
485{
486    Mpsc {
487        rx: mpsc::Receiver<T>,
488        wakers: RemoteTxWakers,
489    },
490    Stream {
491        #[debug(ignore)]
492        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
493    },
494    #[cfg(feature = "hyper")]
495    HyperWebSocket {
496        rx: futures_util::stream::SplitStream<
497            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
498        >,
499        format: crate::meta::FrameSerializationFormat,
500    },
501    #[cfg(feature = "tokio-tungstenite")]
502    TokioWebSocket {
503        rx: futures_util::stream::SplitStream<
504            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
505        >,
506        format: crate::meta::FrameSerializationFormat,
507    },
508}
509impl<T> RemoteRx<T>
510where
511    T: serde::de::DeserializeOwned,
512{
513    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
514        loop {
515            return match self {
516                RemoteRx::Mpsc { rx, wakers } => {
517                    let ret = Pin::new(rx).poll_recv(cx);
518                    if ret.is_ready() {
519                        wakers.wake();
520                    }
521                    ret
522                }
523                RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
524                    Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
525                    Poll::Ready(Some(Err(err))) => {
526                        tracing::debug!("failed to read from channel - {}", err);
527                        Poll::Ready(None)
528                    }
529                    Poll::Ready(None) => Poll::Ready(None),
530                    Poll::Pending => Poll::Pending,
531                },
532                #[cfg(feature = "hyper")]
533                RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
534                    Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
535                        match format {
536                            crate::meta::FrameSerializationFormat::Bincode => {
537                                return match bincode::deserialize(&msg) {
538                                    Ok(msg) => Poll::Ready(Some(msg)),
539                                    Err(err) => {
540                                        tracing::warn!("failed to deserialize message - {}", err);
541                                        continue;
542                                    }
543                                };
544                            }
545                            format => {
546                                tracing::warn!("format not currently supported - {format:?}");
547                                continue;
548                            }
549                        }
550                    }
551                    Poll::Ready(Some(Ok(msg))) => {
552                        tracing::warn!("unsupported message from channel - {}", msg);
553                        continue;
554                    }
555                    Poll::Ready(Some(Err(err))) => {
556                        tracing::debug!("failed to read from channel - {}", err);
557                        Poll::Ready(None)
558                    }
559                    Poll::Ready(None) => Poll::Ready(None),
560                    Poll::Pending => Poll::Pending,
561                },
562                #[cfg(feature = "tokio-tungstenite")]
563                RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
564                    Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
565                        match format {
566                            crate::meta::FrameSerializationFormat::Bincode => {
567                                return match bincode::deserialize(&msg) {
568                                    Ok(msg) => Poll::Ready(Some(msg)),
569                                    Err(err) => {
570                                        tracing::warn!("failed to deserialize message - {}", err);
571                                        continue;
572                                    }
573                                };
574                            }
575                            format => {
576                                tracing::warn!("format not currently supported - {format:?}");
577                                continue;
578                            }
579                        }
580                    }
581                    Poll::Ready(Some(Ok(msg))) => {
582                        tracing::warn!("unsupported message from channel - {}", msg);
583                        continue;
584                    }
585                    Poll::Ready(Some(Err(err))) => {
586                        tracing::debug!("failed to read from channel - {}", err);
587                        Poll::Ready(None)
588                    }
589                    Poll::Ready(None) => Poll::Ready(None),
590                    Poll::Pending => Poll::Pending,
591                },
592            };
593        }
594    }
595}
596
/// Wake target that ignores every wake-up.
///
/// Lets a future be polled once, synchronously, when there is no task context
/// to register interest with.
struct NoopWaker;

impl NoopWaker {
    /// Builds a [`Waker`] whose wake calls do nothing.
    fn new_waker() -> Waker {
        Arc::new(NoopWaker).into()
    }
}

impl std::task::Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Intentionally empty: there is nobody to notify.
    }
}