1use std::{
2 pin::Pin,
3 sync::{Arc, Mutex},
4 task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
9#[cfg(feature = "hyper")]
10use hyper_util::rt::tokio::TokioIo;
11use serde::Serialize;
12#[cfg(feature = "tokio-tungstenite")]
13use tokio::net::TcpStream;
14use tokio::sync::{
15 mpsc::{self, error::TrySendError},
16 oneshot,
17};
18
19use crate::{NetworkError, io_err_into_net_error};
20
/// Wakers of tasks parked in `RemoteTx::poll_send`, waiting to be retried.
///
/// Clones share a single waker list, so the sending and receiving halves of a
/// connection can each hold a copy.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}
impl RemoteTxWakers {
    /// Registers `waker`, unless a waker that `will_wake` the same task is
    /// already registered.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes, and forgets, every registered waker.
    ///
    /// The list is taken out of the mutex before any waker runs. No lock is
    /// held while arbitrary wake code executes, so a waker may call `add` on
    /// this same list without deadlocking.
    pub fn wake(&self) {
        let pending = std::mem::take(&mut *self.lock());
        pending.into_iter().for_each(Waker::wake);
    }

    /// Locks the waker list.
    ///
    /// Poisoning is recovered from: a `Vec<Waker>` cannot be left half-updated
    /// by a panicking holder.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
37
/// Boxed, type-erased sink of `T` used by the `RemoteTx::Stream` variant.
///
/// Write failures surface as `std::io::Error`. The `Send + 'static` bounds
/// allow the sink to be moved into boxed `'static` jobs on the work queue.
pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
39
/// Sending half of a remote connection, abstracted over the transport.
///
/// Every variant carries:
/// - a `work` queue of boxed `'static` futures, used to finish sends that
///   could not complete immediately;
/// - a `RemoteTxWakers` list, where `poll_send` parks the wakers of tasks that
///   must retry later.
#[derive(derive_more::Debug)]
pub(crate) enum RemoteTx<T>
where
    T: Serialize,
{
    /// In-process bounded tokio channel; requests are passed through as `T`
    /// without any encoding.
    Mpsc {
        tx: mpsc::Sender<T>,
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        wakers: RemoteTxWakers,
    },
    /// Type-erased sink shared behind an async mutex, so that direct sends and
    /// queued jobs take turns writing to it.
    Stream {
        // Excluded from `Debug`: the boxed trait object has no `Debug` impl.
        #[debug(ignore)]
        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        wakers: RemoteTxWakers,
    },
    /// Write half of a web socket upgraded through hyper. Requests are encoded
    /// according to `format` and sent as binary frames.
    #[cfg(feature = "hyper")]
    HyperWebSocket {
        tx: Arc<
            tokio::sync::Mutex<
                futures_util::stream::SplitSink<
                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
                    hyper_tungstenite::tungstenite::Message,
                >,
            >,
        >,
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        wakers: RemoteTxWakers,
        format: crate::meta::FrameSerializationFormat,
    },
    /// Write half of a tokio-tungstenite client web socket (plain or TLS TCP).
    /// Requests are encoded according to `format` and sent as binary frames.
    #[cfg(feature = "tokio-tungstenite")]
    TokioWebSocket {
        tx: Arc<
            tokio::sync::Mutex<
                futures_util::stream::SplitSink<
                    tokio_tungstenite::WebSocketStream<
                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
                    >,
                    tokio_tungstenite::tungstenite::Message,
                >,
            >,
        >,
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        wakers: RemoteTxWakers,
        format: crate::meta::FrameSerializationFormat,
    },
}
87impl<T> RemoteTx<T>
88where
89 T: Serialize + Send + Sync + 'static,
90{
91 pub(crate) async fn send(&self, req: T) -> Result<()> {
92 match self {
93 RemoteTx::Mpsc { tx, .. } => tx
94 .send(req)
95 .await
96 .map_err(|_| NetworkError::ConnectionAborted),
97 RemoteTx::Stream { tx, work, .. } => {
98 let (tx_done, rx_done) = oneshot::channel();
99 let tx = tx.clone();
100 work.send(Box::pin(async move {
101 let job = async {
102 let mut tx_guard = tx.lock_owned().await;
103 tx_guard.send(req).await.map_err(io_err_into_net_error)
104 };
105 tx_done.send(job.await).ok();
106 }))
107 .map_err(|_| NetworkError::ConnectionAborted)?;
108
109 rx_done
110 .await
111 .unwrap_or(Err(NetworkError::ConnectionAborted))
112 }
113 #[cfg(feature = "hyper")]
114 RemoteTx::HyperWebSocket { tx, format, .. } => {
115 let data = match format {
116 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
117 .map_err(|err| {
118 tracing::warn!("failed to serialize message - {err}");
119 NetworkError::IOError
120 })?,
121 format => {
122 tracing::warn!("format not currently supported - {format:?}");
123 return Err(NetworkError::IOError);
124 }
125 };
126 let mut tx = tx.lock().await;
127 tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
128 .await
129 .map_err(|_| NetworkError::ConnectionAborted)
130 }
131 #[cfg(feature = "tokio-tungstenite")]
132 RemoteTx::TokioWebSocket { tx, format, .. } => {
133 let data = match format {
134 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
135 .map_err(|err| {
136 tracing::warn!("failed to serialize message - {err}");
137 NetworkError::IOError
138 })?,
139 format => {
140 tracing::warn!("format not currently supported - {format:?}");
141 return Err(NetworkError::IOError);
142 }
143 };
144 let mut tx = tx.lock().await;
145 tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
146 .await
147 .map_err(|_| NetworkError::ConnectionAborted)
148 }
149 }
150 }
151
152 pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
153 match self {
154 RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
155 Ok(()) => Poll::Ready(Ok(())),
156 Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
157 Err(TrySendError::Full(_)) => {
158 wakers.add(cx.waker());
159 Poll::Pending
160 }
161 },
162 RemoteTx::Stream { tx, work, wakers } => {
163 let mut tx_guard = match tx.clone().try_lock_owned() {
164 Ok(lock) => lock,
165 Err(_) => {
166 wakers.add(cx.waker());
167 return Poll::Pending;
168 }
169 };
170 match tx_guard.poll_ready_unpin(cx) {
171 Poll::Ready(Ok(())) => {}
172 Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
173 Poll::Pending => return Poll::Pending,
174 }
175 let mut job = Box::pin(async move {
176 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
177 tracing::error!("failed to send remaining bytes for request - {}", err);
178 }
179 });
180
181 if job.as_mut().poll(cx).is_ready() {
183 return Poll::Ready(Ok(()));
184 }
185
186 work.send(job).map_err(|err| {
189 tracing::error!("failed to send remaining bytes for request - {}", err);
190 NetworkError::ConnectionAborted
191 })?;
192 Poll::Ready(Ok(()))
193 }
194 #[cfg(feature = "hyper")]
195 RemoteTx::HyperWebSocket {
196 tx,
197 format,
198 work,
199 wakers,
200 ..
201 } => {
202 let mut tx_guard = match tx.clone().try_lock_owned() {
203 Ok(lock) => lock,
204 Err(_) => {
205 wakers.add(cx.waker());
206 return Poll::Pending;
207 }
208 };
209 match tx_guard.poll_ready_unpin(cx) {
210 Poll::Ready(Ok(())) => {}
211 Poll::Ready(Err(err)) => {
212 tracing::warn!("failed to poll web socket for readiness - {err}");
213 return Poll::Ready(Err(NetworkError::IOError));
214 }
215 Poll::Pending => return Poll::Pending,
216 }
217
218 let data = match format {
219 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
220 .map_err(|err| {
221 tracing::warn!("failed to serialize message - {err}");
222 NetworkError::IOError
223 })?,
224 format => {
225 tracing::warn!("format not currently supported - {format:?}");
226 return Poll::Ready(Err(NetworkError::IOError));
227 }
228 };
229
230 let mut job = Box::pin(async move {
231 if let Err(err) = tx_guard
232 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
233 .await
234 {
235 tracing::error!("failed to send remaining bytes for request - {}", err);
236 }
237 });
238
239 if job.as_mut().poll(cx).is_ready() {
241 return Poll::Ready(Ok(()));
242 }
243
244 work.send(job).map_err(|err| {
247 tracing::error!("failed to send remaining bytes for request - {}", err);
248 NetworkError::ConnectionAborted
249 })?;
250 Poll::Ready(Ok(()))
251 }
252 #[cfg(feature = "tokio-tungstenite")]
253 RemoteTx::TokioWebSocket {
254 tx,
255 format,
256 work,
257 wakers,
258 ..
259 } => {
260 let mut tx_guard = match tx.clone().try_lock_owned() {
261 Ok(lock) => lock,
262 Err(_) => {
263 wakers.add(cx.waker());
264 return Poll::Pending;
265 }
266 };
267 match tx_guard.poll_ready_unpin(cx) {
268 Poll::Ready(Ok(())) => {}
269 Poll::Ready(Err(err)) => {
270 tracing::warn!("failed to poll web socket for readiness - {err}");
271 return Poll::Ready(Err(NetworkError::IOError));
272 }
273 Poll::Pending => return Poll::Pending,
274 }
275
276 let data = match format {
277 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
278 .map_err(|err| {
279 tracing::warn!("failed to serialize message - {err}");
280 NetworkError::IOError
281 })?,
282 format => {
283 tracing::warn!("format not currently supported - {format:?}");
284 return Poll::Ready(Err(NetworkError::IOError));
285 }
286 };
287
288 let mut job = Box::pin(async move {
289 if let Err(err) = tx_guard
290 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
291 .await
292 {
293 tracing::error!("failed to send remaining bytes for request - {}", err);
294 }
295 });
296
297 if job.as_mut().poll(cx).is_ready() {
299 return Poll::Ready(Ok(()));
300 }
301
302 work.send(job).map_err(|err| {
305 tracing::error!("failed to send remaining bytes for request - {}", err);
306 NetworkError::ConnectionAborted
307 })?;
308 Poll::Ready(Ok(()))
309 }
310 }
311 }
312
313 pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
314 match self {
315 RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
316 Ok(()) => Ok(()),
317 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
318 Err(TrySendError::Full(req)) => {
319 let tx = tx.clone();
320 work.send(Box::pin(async move {
321 tx.send(req).await.ok();
322 }))
323 .ok();
324 Ok(())
325 }
326 },
327 RemoteTx::Stream { tx, work, .. } => {
328 let mut tx_guard = match tx.clone().try_lock_owned() {
329 Ok(lock) => lock,
330 Err(_) => {
331 let tx = tx.clone();
332 work.send(Box::pin(async move {
333 let mut tx_guard = tx.lock().await;
334 tx_guard.send(req).await.ok();
335 }))
336 .ok();
337 return Ok(());
338 }
339 };
340
341 let waker = NoopWaker::new_waker();
342 let mut cx = Context::from_waker(&waker);
343
344 let mut job = Box::pin(async move {
345 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
346 tracing::error!("failed to send remaining bytes for request - {}", err);
347 }
348 });
349
350 if job.as_mut().poll(&mut cx).is_ready() {
352 return Ok(());
353 }
354
355 work.send(job).map_err(|err| {
358 tracing::error!("failed to send remaining bytes for request - {}", err);
359 NetworkError::ConnectionAborted
360 })?;
361 Ok(())
362 }
363 #[cfg(feature = "hyper")]
364 RemoteTx::HyperWebSocket {
365 tx, format, work, ..
366 } => {
367 let data = match format {
368 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
369 .map_err(|err| {
370 tracing::warn!("failed to serialize message - {err}");
371 NetworkError::IOError
372 })?,
373 format => {
374 tracing::warn!("format not currently supported - {format:?}");
375 return Err(NetworkError::IOError);
376 }
377 };
378
379 let mut tx_guard = match tx.clone().try_lock_owned() {
380 Ok(lock) => lock,
381 Err(_) => {
382 let tx = tx.clone();
383 work.send(Box::pin(async move {
384 let mut tx_guard = tx.lock().await;
385 tx_guard
386 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
387 .await
388 .ok();
389 }))
390 .ok();
391 return Ok(());
392 }
393 };
394
395 let waker = NoopWaker::new_waker();
396 let mut cx = Context::from_waker(&waker);
397
398 let mut job = Box::pin(async move {
399 if let Err(err) = tx_guard
400 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
401 .await
402 {
403 tracing::error!("failed to send remaining bytes for request - {}", err);
404 }
405 });
406
407 if job.as_mut().poll(&mut cx).is_ready() {
409 return Ok(());
410 }
411
412 work.send(job).map_err(|err| {
415 tracing::error!("failed to send remaining bytes for request - {}", err);
416 NetworkError::ConnectionAborted
417 })?;
418 Ok(())
419 }
420 #[cfg(feature = "tokio-tungstenite")]
421 RemoteTx::TokioWebSocket {
422 tx, format, work, ..
423 } => {
424 let data = match format {
425 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
426 .map_err(|err| {
427 tracing::warn!("failed to serialize message - {err}");
428 NetworkError::IOError
429 })?,
430 format => {
431 tracing::warn!("format not currently supported - {format:?}");
432 return Err(NetworkError::IOError);
433 }
434 };
435
436 let mut tx_guard = match tx.clone().try_lock_owned() {
437 Ok(lock) => lock,
438 Err(_) => {
439 let tx = tx.clone();
440 work.send(Box::pin(async move {
441 let mut tx_guard = tx.lock().await;
442 tx_guard
443 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
444 .await
445 .ok();
446 }))
447 .ok();
448 return Ok(());
449 }
450 };
451
452 let waker = NoopWaker::new_waker();
453 let mut cx = Context::from_waker(&waker);
454
455 let mut job = Box::pin(async move {
456 if let Err(err) = tx_guard
457 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
458 .await
459 {
460 tracing::error!("failed to send remaining bytes for request - {}", err);
461 }
462 });
463
464 if job.as_mut().poll(&mut cx).is_ready() {
466 return Ok(());
467 }
468
469 work.send(job).map_err(|err| {
472 tracing::error!("failed to send remaining bytes for request - {}", err);
473 NetworkError::ConnectionAborted
474 })?;
475 Ok(())
476 }
477 }
478 }
479}
480
/// Receiving half of a remote connection, abstracted over the transport.
#[derive(derive_more::Debug)]
pub(crate) enum RemoteRx<T>
where
    T: serde::de::DeserializeOwned,
{
    /// In-process tokio channel. `wakers` is woken each time `poll` yields a
    /// result, so senders parked on a full channel can retry.
    Mpsc {
        rx: mpsc::Receiver<T>,
        wakers: RemoteTxWakers,
    },
    /// Type-erased stream of already-decoded messages. `poll` treats a read
    /// error as end of stream.
    Stream {
        // Excluded from `Debug`: the boxed trait object has no `Debug` impl.
        #[debug(ignore)]
        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
    },
    /// Read half of a web socket upgraded through hyper. Binary frames are
    /// decoded according to `format`.
    #[cfg(feature = "hyper")]
    HyperWebSocket {
        rx: futures_util::stream::SplitStream<
            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
        >,
        format: crate::meta::FrameSerializationFormat,
    },
    /// Read half of a tokio-tungstenite client web socket (plain or TLS TCP).
    /// Binary frames are decoded according to `format`.
    #[cfg(feature = "tokio-tungstenite")]
    TokioWebSocket {
        rx: futures_util::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
        >,
        format: crate::meta::FrameSerializationFormat,
    },
}
509impl<T> RemoteRx<T>
510where
511 T: serde::de::DeserializeOwned,
512{
513 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
514 loop {
515 return match self {
516 RemoteRx::Mpsc { rx, wakers } => {
517 let ret = Pin::new(rx).poll_recv(cx);
518 if ret.is_ready() {
519 wakers.wake();
520 }
521 ret
522 }
523 RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
524 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
525 Poll::Ready(Some(Err(err))) => {
526 tracing::debug!("failed to read from channel - {}", err);
527 Poll::Ready(None)
528 }
529 Poll::Ready(None) => Poll::Ready(None),
530 Poll::Pending => Poll::Pending,
531 },
532 #[cfg(feature = "hyper")]
533 RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
534 Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
535 match format {
536 crate::meta::FrameSerializationFormat::Bincode => {
537 return match bincode::deserialize(&msg) {
538 Ok(msg) => Poll::Ready(Some(msg)),
539 Err(err) => {
540 tracing::warn!("failed to deserialize message - {}", err);
541 continue;
542 }
543 };
544 }
545 format => {
546 tracing::warn!("format not currently supported - {format:?}");
547 continue;
548 }
549 }
550 }
551 Poll::Ready(Some(Ok(msg))) => {
552 tracing::warn!("unsupported message from channel - {}", msg);
553 continue;
554 }
555 Poll::Ready(Some(Err(err))) => {
556 tracing::debug!("failed to read from channel - {}", err);
557 Poll::Ready(None)
558 }
559 Poll::Ready(None) => Poll::Ready(None),
560 Poll::Pending => Poll::Pending,
561 },
562 #[cfg(feature = "tokio-tungstenite")]
563 RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
564 Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
565 match format {
566 crate::meta::FrameSerializationFormat::Bincode => {
567 return match bincode::deserialize(&msg) {
568 Ok(msg) => Poll::Ready(Some(msg)),
569 Err(err) => {
570 tracing::warn!("failed to deserialize message - {}", err);
571 continue;
572 }
573 };
574 }
575 format => {
576 tracing::warn!("format not currently supported - {format:?}");
577 continue;
578 }
579 }
580 }
581 Poll::Ready(Some(Ok(msg))) => {
582 tracing::warn!("unsupported message from channel - {}", msg);
583 continue;
584 }
585 Poll::Ready(Some(Err(err))) => {
586 tracing::debug!("failed to read from channel - {}", err);
587 Poll::Ready(None)
588 }
589 Poll::Ready(None) => Poll::Ready(None),
590 Poll::Pending => Poll::Pending,
591 },
592 };
593 }
594 }
595}
596
/// A waker that ignores every wake-up request.
///
/// Lets a future be polled once when no real task context is available.
struct NoopWaker;

impl NoopWaker {
    /// Builds a `Waker` backed by a fresh, shared `NoopWaker`.
    fn new_waker() -> Waker {
        let inner: Arc<NoopWaker> = Arc::new(NoopWaker);
        inner.into()
    }
}

impl std::task::Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Deliberately empty: nobody is waiting to be woken.
    }
}