1use std::{
2 pin::Pin,
3 sync::{Arc, Mutex},
4 task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use bincode::config;
9use bytes::Bytes;
10use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
11#[cfg(feature = "hyper")]
12use hyper_util::rt::tokio::TokioIo;
13use serde::Serialize;
14#[cfg(feature = "tokio-tungstenite")]
15use tokio::net::TcpStream;
16use tokio::sync::{
17 mpsc::{self, error::TrySendError},
18 oneshot,
19};
20
21use crate::{NetworkError, io_err_into_net_error};
22
/// Cloneable, shared list of task wakers.
///
/// Every clone points at the same underlying list, so a waker registered through one
/// handle is woken by `wake` on any other handle.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker` unless an already registered waker reports (via
    /// [`Waker::will_wake`]) that it would wake the same task.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes every registered waker and leaves the list empty.
    ///
    /// The list is moved out and the lock released *before* any waker runs, so a waker
    /// that synchronously calls back into [`RemoteTxWakers::add`] cannot deadlock on the
    /// (non-reentrant) mutex; such a re-registration is kept for the next `wake`.
    pub fn wake(&self) {
        let pending = std::mem::take(&mut *self.lock());
        for waker in pending {
            waker.wake();
        }
    }

    /// Locks the waker list, recovering the data if the mutex was poisoned.
    ///
    /// A plain `Vec<Waker>` has no invariant that a panic could leave half-updated, so
    /// continuing with the stored list is preferable to propagating the panic.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
39
40pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
41
42#[derive(derive_more::Debug)]
43pub(crate) enum RemoteTx<T>
44where
45 T: Serialize,
46{
47 Mpsc {
48 tx: mpsc::Sender<T>,
49 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
50 wakers: RemoteTxWakers,
51 },
52 Stream {
53 #[debug(ignore)]
54 tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
55 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
56 wakers: RemoteTxWakers,
57 },
58 #[cfg(feature = "hyper")]
59 HyperWebSocket {
60 tx: Arc<
61 tokio::sync::Mutex<
62 futures_util::stream::SplitSink<
63 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
64 hyper_tungstenite::tungstenite::Message,
65 >,
66 >,
67 >,
68 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
69 wakers: RemoteTxWakers,
70 format: crate::meta::FrameSerializationFormat,
71 },
72 #[cfg(feature = "tokio-tungstenite")]
73 TokioWebSocket {
74 tx: Arc<
75 tokio::sync::Mutex<
76 futures_util::stream::SplitSink<
77 tokio_tungstenite::WebSocketStream<
78 tokio_tungstenite::MaybeTlsStream<TcpStream>,
79 >,
80 tokio_tungstenite::tungstenite::Message,
81 >,
82 >,
83 >,
84 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
85 wakers: RemoteTxWakers,
86 format: crate::meta::FrameSerializationFormat,
87 },
88}
89impl<T> RemoteTx<T>
90where
91 T: Serialize + Send + Sync + 'static,
92{
93 pub(crate) async fn send(&self, req: T) -> Result<()> {
94 match self {
95 RemoteTx::Mpsc { tx, .. } => tx
96 .send(req)
97 .await
98 .map_err(|_| NetworkError::ConnectionAborted),
99 RemoteTx::Stream { tx, work, .. } => {
100 let (tx_done, rx_done) = oneshot::channel();
101 let tx = tx.clone();
102 work.send(Box::pin(async move {
103 let job = async {
104 let mut tx_guard = tx.lock_owned().await;
105 tx_guard.send(req).await.map_err(io_err_into_net_error)
106 };
107 tx_done.send(job.await).ok();
108 }))
109 .map_err(|_| NetworkError::ConnectionAborted)?;
110
111 rx_done
112 .await
113 .unwrap_or(Err(NetworkError::ConnectionAborted))
114 }
115 #[cfg(feature = "hyper")]
116 RemoteTx::HyperWebSocket { tx, format, .. } => {
117 let data = match format {
118 crate::meta::FrameSerializationFormat::Bincode => {
119 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
120 tracing::warn!("failed to serialize message - {err}");
121 NetworkError::IOError
122 })?
123 }
124 format => {
125 tracing::warn!("format not currently supported - {format:?}");
126 return Err(NetworkError::IOError);
127 }
128 };
129 let mut tx = tx.lock().await;
130 tx.send(hyper_tungstenite::tungstenite::Message::Binary(
131 Bytes::from_owner(data),
132 ))
133 .await
134 .map_err(|_| NetworkError::ConnectionAborted)
135 }
136 #[cfg(feature = "tokio-tungstenite")]
137 RemoteTx::TokioWebSocket { tx, format, .. } => {
138 let data = match format {
139 crate::meta::FrameSerializationFormat::Bincode => {
140 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
141 tracing::warn!("failed to serialize message - {err}");
142 NetworkError::IOError
143 })?
144 }
145 format => {
146 tracing::warn!("format not currently supported - {format:?}");
147 return Err(NetworkError::IOError);
148 }
149 };
150 let mut tx = tx.lock().await;
151 tx.send(tokio_tungstenite::tungstenite::Message::Binary(
152 Bytes::from_owner(data),
153 ))
154 .await
155 .map_err(|_| NetworkError::ConnectionAborted)
156 }
157 }
158 }
159
160 pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
161 match self {
162 RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
163 Ok(()) => Poll::Ready(Ok(())),
164 Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
165 Err(TrySendError::Full(_)) => {
166 wakers.add(cx.waker());
167 Poll::Pending
168 }
169 },
170 RemoteTx::Stream { tx, work, wakers } => {
171 let mut tx_guard = match tx.clone().try_lock_owned() {
172 Ok(lock) => lock,
173 Err(_) => {
174 wakers.add(cx.waker());
175 return Poll::Pending;
176 }
177 };
178 match tx_guard.poll_ready_unpin(cx) {
179 Poll::Ready(Ok(())) => {}
180 Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
181 Poll::Pending => return Poll::Pending,
182 }
183 let mut job = Box::pin(async move {
184 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
185 tracing::error!("failed to send remaining bytes for request - {}", err);
186 }
187 });
188
189 if job.as_mut().poll(cx).is_ready() {
191 return Poll::Ready(Ok(()));
192 }
193
194 work.send(job).map_err(|err| {
197 tracing::error!("failed to send remaining bytes for request - {}", err);
198 NetworkError::ConnectionAborted
199 })?;
200 Poll::Ready(Ok(()))
201 }
202 #[cfg(feature = "hyper")]
203 RemoteTx::HyperWebSocket {
204 tx,
205 format,
206 work,
207 wakers,
208 ..
209 } => {
210 let mut tx_guard = match tx.clone().try_lock_owned() {
211 Ok(lock) => lock,
212 Err(_) => {
213 wakers.add(cx.waker());
214 return Poll::Pending;
215 }
216 };
217 match tx_guard.poll_ready_unpin(cx) {
218 Poll::Ready(Ok(())) => {}
219 Poll::Ready(Err(err)) => {
220 tracing::warn!("failed to poll web socket for readiness - {err}");
221 return Poll::Ready(Err(NetworkError::IOError));
222 }
223 Poll::Pending => return Poll::Pending,
224 }
225
226 let data = match format {
227 crate::meta::FrameSerializationFormat::Bincode => {
228 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
229 tracing::warn!("failed to serialize message - {err}");
230 NetworkError::IOError
231 })?
232 }
233 format => {
234 tracing::warn!("format not currently supported - {format:?}");
235 return Poll::Ready(Err(NetworkError::IOError));
236 }
237 };
238
239 let mut job = Box::pin(async move {
240 if let Err(err) = tx_guard
241 .send(hyper_tungstenite::tungstenite::Message::Binary(
242 Bytes::from_owner(data),
243 ))
244 .await
245 {
246 tracing::error!("failed to send remaining bytes for request - {}", err);
247 }
248 });
249
250 if job.as_mut().poll(cx).is_ready() {
252 return Poll::Ready(Ok(()));
253 }
254
255 work.send(job).map_err(|err| {
258 tracing::error!("failed to send remaining bytes for request - {}", err);
259 NetworkError::ConnectionAborted
260 })?;
261 Poll::Ready(Ok(()))
262 }
263 #[cfg(feature = "tokio-tungstenite")]
264 RemoteTx::TokioWebSocket {
265 tx,
266 format,
267 work,
268 wakers,
269 ..
270 } => {
271 let mut tx_guard = match tx.clone().try_lock_owned() {
272 Ok(lock) => lock,
273 Err(_) => {
274 wakers.add(cx.waker());
275 return Poll::Pending;
276 }
277 };
278 match tx_guard.poll_ready_unpin(cx) {
279 Poll::Ready(Ok(())) => {}
280 Poll::Ready(Err(err)) => {
281 tracing::warn!("failed to poll web socket for readiness - {err}");
282 return Poll::Ready(Err(NetworkError::IOError));
283 }
284 Poll::Pending => return Poll::Pending,
285 }
286
287 let data = match format {
288 crate::meta::FrameSerializationFormat::Bincode => {
289 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
290 tracing::warn!("failed to serialize message - {err}");
291 NetworkError::IOError
292 })?
293 }
294 format => {
295 tracing::warn!("format not currently supported - {format:?}");
296 return Poll::Ready(Err(NetworkError::IOError));
297 }
298 };
299
300 let mut job = Box::pin(async move {
301 if let Err(err) = tx_guard
302 .send(tokio_tungstenite::tungstenite::Message::Binary(
303 Bytes::from_owner(data),
304 ))
305 .await
306 {
307 tracing::error!("failed to send remaining bytes for request - {}", err);
308 }
309 });
310
311 if job.as_mut().poll(cx).is_ready() {
313 return Poll::Ready(Ok(()));
314 }
315
316 work.send(job).map_err(|err| {
319 tracing::error!("failed to send remaining bytes for request - {}", err);
320 NetworkError::ConnectionAborted
321 })?;
322 Poll::Ready(Ok(()))
323 }
324 }
325 }
326
327 pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
328 match self {
329 RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
330 Ok(()) => Ok(()),
331 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
332 Err(TrySendError::Full(req)) => {
333 let tx = tx.clone();
334 work.send(Box::pin(async move {
335 tx.send(req).await.ok();
336 }))
337 .ok();
338 Ok(())
339 }
340 },
341 RemoteTx::Stream { tx, work, .. } => {
342 let mut tx_guard = match tx.clone().try_lock_owned() {
343 Ok(lock) => lock,
344 Err(_) => {
345 let tx = tx.clone();
346 work.send(Box::pin(async move {
347 let mut tx_guard = tx.lock().await;
348 tx_guard.send(req).await.ok();
349 }))
350 .ok();
351 return Ok(());
352 }
353 };
354
355 let waker = NoopWaker::new_waker();
356 let mut cx = Context::from_waker(&waker);
357
358 let mut job = Box::pin(async move {
359 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
360 tracing::error!("failed to send remaining bytes for request - {}", err);
361 }
362 });
363
364 if job.as_mut().poll(&mut cx).is_ready() {
366 return Ok(());
367 }
368
369 work.send(job).map_err(|err| {
372 tracing::error!("failed to send remaining bytes for request - {}", err);
373 NetworkError::ConnectionAborted
374 })?;
375 Ok(())
376 }
377 #[cfg(feature = "hyper")]
378 RemoteTx::HyperWebSocket {
379 tx, format, work, ..
380 } => {
381 let data = match format {
382 crate::meta::FrameSerializationFormat::Bincode => {
383 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
384 tracing::warn!("failed to serialize message - {err}");
385 NetworkError::IOError
386 })?
387 }
388 format => {
389 tracing::warn!("format not currently supported - {format:?}");
390 return Err(NetworkError::IOError);
391 }
392 };
393
394 let mut tx_guard = match tx.clone().try_lock_owned() {
395 Ok(lock) => lock,
396 Err(_) => {
397 let tx = tx.clone();
398 work.send(Box::pin(async move {
399 let mut tx_guard = tx.lock().await;
400 tx_guard
401 .send(hyper_tungstenite::tungstenite::Message::Binary(
402 Bytes::from_owner(data),
403 ))
404 .await
405 .ok();
406 }))
407 .ok();
408 return Ok(());
409 }
410 };
411
412 let waker = NoopWaker::new_waker();
413 let mut cx = Context::from_waker(&waker);
414
415 let mut job = Box::pin(async move {
416 if let Err(err) = tx_guard
417 .send(hyper_tungstenite::tungstenite::Message::Binary(
418 Bytes::from_owner(data),
419 ))
420 .await
421 {
422 tracing::error!("failed to send remaining bytes for request - {}", err);
423 }
424 });
425
426 if job.as_mut().poll(&mut cx).is_ready() {
428 return Ok(());
429 }
430
431 work.send(job).map_err(|err| {
434 tracing::error!("failed to send remaining bytes for request - {}", err);
435 NetworkError::ConnectionAborted
436 })?;
437 Ok(())
438 }
439 #[cfg(feature = "tokio-tungstenite")]
440 RemoteTx::TokioWebSocket {
441 tx, format, work, ..
442 } => {
443 let data = match format {
444 crate::meta::FrameSerializationFormat::Bincode => {
445 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
446 tracing::warn!("failed to serialize message - {err}");
447 NetworkError::IOError
448 })?
449 }
450 format => {
451 tracing::warn!("format not currently supported - {format:?}");
452 return Err(NetworkError::IOError);
453 }
454 };
455
456 let mut tx_guard = match tx.clone().try_lock_owned() {
457 Ok(lock) => lock,
458 Err(_) => {
459 let tx = tx.clone();
460 work.send(Box::pin(async move {
461 let mut tx_guard = tx.lock().await;
462 tx_guard
463 .send(tokio_tungstenite::tungstenite::Message::Binary(
464 Bytes::from_owner(data),
465 ))
466 .await
467 .ok();
468 }))
469 .ok();
470 return Ok(());
471 }
472 };
473
474 let waker = NoopWaker::new_waker();
475 let mut cx = Context::from_waker(&waker);
476
477 let mut job = Box::pin(async move {
478 if let Err(err) = tx_guard
479 .send(tokio_tungstenite::tungstenite::Message::Binary(
480 Bytes::from_owner(data),
481 ))
482 .await
483 {
484 tracing::error!("failed to send remaining bytes for request - {}", err);
485 }
486 });
487
488 if job.as_mut().poll(&mut cx).is_ready() {
490 return Ok(());
491 }
492
493 work.send(job).map_err(|err| {
496 tracing::error!("failed to send remaining bytes for request - {}", err);
497 NetworkError::ConnectionAborted
498 })?;
499 Ok(())
500 }
501 }
502 }
503}
504
505#[derive(derive_more::Debug)]
506pub(crate) enum RemoteRx<T>
507where
508 T: serde::de::DeserializeOwned,
509{
510 Mpsc {
511 rx: mpsc::Receiver<T>,
512 wakers: RemoteTxWakers,
513 },
514 Stream {
515 #[debug(ignore)]
516 rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
517 },
518 #[cfg(feature = "hyper")]
519 HyperWebSocket {
520 rx: futures_util::stream::SplitStream<
521 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
522 >,
523 format: crate::meta::FrameSerializationFormat,
524 },
525 #[cfg(feature = "tokio-tungstenite")]
526 TokioWebSocket {
527 rx: futures_util::stream::SplitStream<
528 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
529 >,
530 format: crate::meta::FrameSerializationFormat,
531 },
532}
533impl<T> RemoteRx<T>
534where
535 T: serde::de::DeserializeOwned,
536{
537 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
538 loop {
539 return match self {
540 RemoteRx::Mpsc { rx, wakers } => {
541 let ret = Pin::new(rx).poll_recv(cx);
542 if ret.is_ready() {
543 wakers.wake();
544 }
545 ret
546 }
547 RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
548 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
549 Poll::Ready(Some(Err(err))) => {
550 tracing::debug!("failed to read from channel - {}", err);
551 Poll::Ready(None)
552 }
553 Poll::Ready(None) => Poll::Ready(None),
554 Poll::Pending => Poll::Pending,
555 },
556 #[cfg(feature = "hyper")]
557 RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
558 Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
559 match format {
560 crate::meta::FrameSerializationFormat::Bincode => {
561 return match bincode::serde::decode_from_slice(
562 &msg,
563 config::legacy(),
564 ) {
565 Ok((msg, _)) => Poll::Ready(Some(msg)),
566 Err(err) => {
567 tracing::warn!("failed to deserialize message - {}", err);
568 continue;
569 }
570 };
571 }
572 format => {
573 tracing::warn!("format not currently supported - {format:?}");
574 continue;
575 }
576 }
577 }
578 Poll::Ready(Some(Ok(msg))) => {
579 tracing::warn!("unsupported message from channel - {}", msg);
580 continue;
581 }
582 Poll::Ready(Some(Err(err))) => {
583 tracing::debug!("failed to read from channel - {}", err);
584 Poll::Ready(None)
585 }
586 Poll::Ready(None) => Poll::Ready(None),
587 Poll::Pending => Poll::Pending,
588 },
589 #[cfg(feature = "tokio-tungstenite")]
590 RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
591 Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
592 match format {
593 crate::meta::FrameSerializationFormat::Bincode => {
594 return match bincode::serde::decode_from_slice(
595 &msg,
596 config::legacy(),
597 ) {
598 Ok((msg, _)) => Poll::Ready(Some(msg)),
599 Err(err) => {
600 tracing::warn!("failed to deserialize message - {}", err);
601 continue;
602 }
603 };
604 }
605 format => {
606 tracing::warn!("format not currently supported - {format:?}");
607 continue;
608 }
609 }
610 }
611 Poll::Ready(Some(Ok(msg))) => {
612 tracing::warn!("unsupported message from channel - {}", msg);
613 continue;
614 }
615 Poll::Ready(Some(Err(err))) => {
616 tracing::debug!("failed to read from channel - {}", err);
617 Poll::Ready(None)
618 }
619 Poll::Ready(None) => Poll::Ready(None),
620 Poll::Pending => Poll::Pending,
621 },
622 };
623 }
624 }
625}
626
/// Waker whose wake-ups do nothing.
///
/// `send_with_driver` uses it to poll a send job once from a context that has no task
/// to wake.
struct NoopWaker;

impl std::task::Wake for NoopWaker {
    /// Intentionally empty: there is no task to notify.
    fn wake(self: Arc<Self>) {}
}

impl NoopWaker {
    /// Builds a fresh [`Waker`] backed by a `NoopWaker`.
    fn new_waker() -> Waker {
        Arc::new(NoopWaker).into()
    }
}