1use std::{
2 pin::Pin,
3 sync::{Arc, Mutex},
4 task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use bincode::config;
9use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
10#[cfg(feature = "hyper")]
11use hyper_util::rt::tokio::TokioIo;
12use serde::Serialize;
13#[cfg(feature = "tokio-tungstenite")]
14use tokio::net::TcpStream;
15use tokio::sync::{
16 mpsc::{self, error::TrySendError},
17 oneshot,
18};
19
20use crate::{NetworkError, io_err_into_net_error};
21
/// Shared, cheaply cloneable set of wakers for tasks waiting to retry a send.
///
/// Clones share the same underlying list. A waker is stored at most once
/// (deduplicated with [`Waker::will_wake`]), and [`RemoteTxWakers::wake`]
/// wakes and clears the whole set.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker`, unless an equivalent one (per `Waker::will_wake`)
    /// is already registered.
    pub fn add(&self, waker: &Waker) {
        let mut registered = self.lock();
        if !registered.iter().any(|w| w.will_wake(waker)) {
            registered.push(waker.clone());
        }
    }

    /// Wakes every registered waker once and empties the set.
    ///
    /// The set is taken out and the lock released *before* any waker runs.
    /// A waker that re-registers itself through `add` while being woken
    /// therefore cannot deadlock on the mutex, and a panicking waker cannot
    /// poison it.
    pub fn wake(&self) {
        let pending = std::mem::take(&mut *self.lock());
        pending.into_iter().for_each(Waker::wake);
    }

    /// Locks the waker list, recovering from poisoning. A `Vec<Waker>` stays
    /// structurally valid whatever point a panicking holder reached.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
38
39pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
40
41#[derive(derive_more::Debug)]
42pub(crate) enum RemoteTx<T>
43where
44 T: Serialize,
45{
46 Mpsc {
47 tx: mpsc::Sender<T>,
48 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
49 wakers: RemoteTxWakers,
50 },
51 Stream {
52 #[debug(ignore)]
53 tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
54 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
55 wakers: RemoteTxWakers,
56 },
57 #[cfg(feature = "hyper")]
58 HyperWebSocket {
59 tx: Arc<
60 tokio::sync::Mutex<
61 futures_util::stream::SplitSink<
62 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
63 hyper_tungstenite::tungstenite::Message,
64 >,
65 >,
66 >,
67 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
68 wakers: RemoteTxWakers,
69 format: crate::meta::FrameSerializationFormat,
70 },
71 #[cfg(feature = "tokio-tungstenite")]
72 TokioWebSocket {
73 tx: Arc<
74 tokio::sync::Mutex<
75 futures_util::stream::SplitSink<
76 tokio_tungstenite::WebSocketStream<
77 tokio_tungstenite::MaybeTlsStream<TcpStream>,
78 >,
79 tokio_tungstenite::tungstenite::Message,
80 >,
81 >,
82 >,
83 work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
84 wakers: RemoteTxWakers,
85 format: crate::meta::FrameSerializationFormat,
86 },
87}
88impl<T> RemoteTx<T>
89where
90 T: Serialize + Send + Sync + 'static,
91{
92 pub(crate) async fn send(&self, req: T) -> Result<()> {
93 match self {
94 RemoteTx::Mpsc { tx, .. } => tx
95 .send(req)
96 .await
97 .map_err(|_| NetworkError::ConnectionAborted),
98 RemoteTx::Stream { tx, work, .. } => {
99 let (tx_done, rx_done) = oneshot::channel();
100 let tx = tx.clone();
101 work.send(Box::pin(async move {
102 let job = async {
103 let mut tx_guard = tx.lock_owned().await;
104 tx_guard.send(req).await.map_err(io_err_into_net_error)
105 };
106 tx_done.send(job.await).ok();
107 }))
108 .map_err(|_| NetworkError::ConnectionAborted)?;
109
110 rx_done
111 .await
112 .unwrap_or(Err(NetworkError::ConnectionAborted))
113 }
114 #[cfg(feature = "hyper")]
115 RemoteTx::HyperWebSocket { tx, format, .. } => {
116 let data = match format {
117 crate::meta::FrameSerializationFormat::Bincode => {
118 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
119 tracing::warn!("failed to serialize message - {err}");
120 NetworkError::IOError
121 })?
122 }
123 format => {
124 tracing::warn!("format not currently supported - {format:?}");
125 return Err(NetworkError::IOError);
126 }
127 };
128 let mut tx = tx.lock().await;
129 tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
130 .await
131 .map_err(|_| NetworkError::ConnectionAborted)
132 }
133 #[cfg(feature = "tokio-tungstenite")]
134 RemoteTx::TokioWebSocket { tx, format, .. } => {
135 let data = match format {
136 crate::meta::FrameSerializationFormat::Bincode => {
137 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
138 tracing::warn!("failed to serialize message - {err}");
139 NetworkError::IOError
140 })?
141 }
142 format => {
143 tracing::warn!("format not currently supported - {format:?}");
144 return Err(NetworkError::IOError);
145 }
146 };
147 let mut tx = tx.lock().await;
148 tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
149 .await
150 .map_err(|_| NetworkError::ConnectionAborted)
151 }
152 }
153 }
154
155 pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
156 match self {
157 RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
158 Ok(()) => Poll::Ready(Ok(())),
159 Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
160 Err(TrySendError::Full(_)) => {
161 wakers.add(cx.waker());
162 Poll::Pending
163 }
164 },
165 RemoteTx::Stream { tx, work, wakers } => {
166 let mut tx_guard = match tx.clone().try_lock_owned() {
167 Ok(lock) => lock,
168 Err(_) => {
169 wakers.add(cx.waker());
170 return Poll::Pending;
171 }
172 };
173 match tx_guard.poll_ready_unpin(cx) {
174 Poll::Ready(Ok(())) => {}
175 Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
176 Poll::Pending => return Poll::Pending,
177 }
178 let mut job = Box::pin(async move {
179 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
180 tracing::error!("failed to send remaining bytes for request - {}", err);
181 }
182 });
183
184 if job.as_mut().poll(cx).is_ready() {
186 return Poll::Ready(Ok(()));
187 }
188
189 work.send(job).map_err(|err| {
192 tracing::error!("failed to send remaining bytes for request - {}", err);
193 NetworkError::ConnectionAborted
194 })?;
195 Poll::Ready(Ok(()))
196 }
197 #[cfg(feature = "hyper")]
198 RemoteTx::HyperWebSocket {
199 tx,
200 format,
201 work,
202 wakers,
203 ..
204 } => {
205 let mut tx_guard = match tx.clone().try_lock_owned() {
206 Ok(lock) => lock,
207 Err(_) => {
208 wakers.add(cx.waker());
209 return Poll::Pending;
210 }
211 };
212 match tx_guard.poll_ready_unpin(cx) {
213 Poll::Ready(Ok(())) => {}
214 Poll::Ready(Err(err)) => {
215 tracing::warn!("failed to poll web socket for readiness - {err}");
216 return Poll::Ready(Err(NetworkError::IOError));
217 }
218 Poll::Pending => return Poll::Pending,
219 }
220
221 let data = match format {
222 crate::meta::FrameSerializationFormat::Bincode => {
223 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
224 tracing::warn!("failed to serialize message - {err}");
225 NetworkError::IOError
226 })?
227 }
228 format => {
229 tracing::warn!("format not currently supported - {format:?}");
230 return Poll::Ready(Err(NetworkError::IOError));
231 }
232 };
233
234 let mut job = Box::pin(async move {
235 if let Err(err) = tx_guard
236 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
237 .await
238 {
239 tracing::error!("failed to send remaining bytes for request - {}", err);
240 }
241 });
242
243 if job.as_mut().poll(cx).is_ready() {
245 return Poll::Ready(Ok(()));
246 }
247
248 work.send(job).map_err(|err| {
251 tracing::error!("failed to send remaining bytes for request - {}", err);
252 NetworkError::ConnectionAborted
253 })?;
254 Poll::Ready(Ok(()))
255 }
256 #[cfg(feature = "tokio-tungstenite")]
257 RemoteTx::TokioWebSocket {
258 tx,
259 format,
260 work,
261 wakers,
262 ..
263 } => {
264 let mut tx_guard = match tx.clone().try_lock_owned() {
265 Ok(lock) => lock,
266 Err(_) => {
267 wakers.add(cx.waker());
268 return Poll::Pending;
269 }
270 };
271 match tx_guard.poll_ready_unpin(cx) {
272 Poll::Ready(Ok(())) => {}
273 Poll::Ready(Err(err)) => {
274 tracing::warn!("failed to poll web socket for readiness - {err}");
275 return Poll::Ready(Err(NetworkError::IOError));
276 }
277 Poll::Pending => return Poll::Pending,
278 }
279
280 let data = match format {
281 crate::meta::FrameSerializationFormat::Bincode => {
282 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
283 tracing::warn!("failed to serialize message - {err}");
284 NetworkError::IOError
285 })?
286 }
287 format => {
288 tracing::warn!("format not currently supported - {format:?}");
289 return Poll::Ready(Err(NetworkError::IOError));
290 }
291 };
292
293 let mut job = Box::pin(async move {
294 if let Err(err) = tx_guard
295 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
296 .await
297 {
298 tracing::error!("failed to send remaining bytes for request - {}", err);
299 }
300 });
301
302 if job.as_mut().poll(cx).is_ready() {
304 return Poll::Ready(Ok(()));
305 }
306
307 work.send(job).map_err(|err| {
310 tracing::error!("failed to send remaining bytes for request - {}", err);
311 NetworkError::ConnectionAborted
312 })?;
313 Poll::Ready(Ok(()))
314 }
315 }
316 }
317
318 pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
319 match self {
320 RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
321 Ok(()) => Ok(()),
322 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
323 Err(TrySendError::Full(req)) => {
324 let tx = tx.clone();
325 work.send(Box::pin(async move {
326 tx.send(req).await.ok();
327 }))
328 .ok();
329 Ok(())
330 }
331 },
332 RemoteTx::Stream { tx, work, .. } => {
333 let mut tx_guard = match tx.clone().try_lock_owned() {
334 Ok(lock) => lock,
335 Err(_) => {
336 let tx = tx.clone();
337 work.send(Box::pin(async move {
338 let mut tx_guard = tx.lock().await;
339 tx_guard.send(req).await.ok();
340 }))
341 .ok();
342 return Ok(());
343 }
344 };
345
346 let waker = NoopWaker::new_waker();
347 let mut cx = Context::from_waker(&waker);
348
349 let mut job = Box::pin(async move {
350 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
351 tracing::error!("failed to send remaining bytes for request - {}", err);
352 }
353 });
354
355 if job.as_mut().poll(&mut cx).is_ready() {
357 return Ok(());
358 }
359
360 work.send(job).map_err(|err| {
363 tracing::error!("failed to send remaining bytes for request - {}", err);
364 NetworkError::ConnectionAborted
365 })?;
366 Ok(())
367 }
368 #[cfg(feature = "hyper")]
369 RemoteTx::HyperWebSocket {
370 tx, format, work, ..
371 } => {
372 let data = match format {
373 crate::meta::FrameSerializationFormat::Bincode => {
374 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
375 tracing::warn!("failed to serialize message - {err}");
376 NetworkError::IOError
377 })?
378 }
379 format => {
380 tracing::warn!("format not currently supported - {format:?}");
381 return Err(NetworkError::IOError);
382 }
383 };
384
385 let mut tx_guard = match tx.clone().try_lock_owned() {
386 Ok(lock) => lock,
387 Err(_) => {
388 let tx = tx.clone();
389 work.send(Box::pin(async move {
390 let mut tx_guard = tx.lock().await;
391 tx_guard
392 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
393 .await
394 .ok();
395 }))
396 .ok();
397 return Ok(());
398 }
399 };
400
401 let waker = NoopWaker::new_waker();
402 let mut cx = Context::from_waker(&waker);
403
404 let mut job = Box::pin(async move {
405 if let Err(err) = tx_guard
406 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
407 .await
408 {
409 tracing::error!("failed to send remaining bytes for request - {}", err);
410 }
411 });
412
413 if job.as_mut().poll(&mut cx).is_ready() {
415 return Ok(());
416 }
417
418 work.send(job).map_err(|err| {
421 tracing::error!("failed to send remaining bytes for request - {}", err);
422 NetworkError::ConnectionAborted
423 })?;
424 Ok(())
425 }
426 #[cfg(feature = "tokio-tungstenite")]
427 RemoteTx::TokioWebSocket {
428 tx, format, work, ..
429 } => {
430 let data = match format {
431 crate::meta::FrameSerializationFormat::Bincode => {
432 bincode::serde::encode_to_vec(&req, config::legacy()).map_err(|err| {
433 tracing::warn!("failed to serialize message - {err}");
434 NetworkError::IOError
435 })?
436 }
437 format => {
438 tracing::warn!("format not currently supported - {format:?}");
439 return Err(NetworkError::IOError);
440 }
441 };
442
443 let mut tx_guard = match tx.clone().try_lock_owned() {
444 Ok(lock) => lock,
445 Err(_) => {
446 let tx = tx.clone();
447 work.send(Box::pin(async move {
448 let mut tx_guard = tx.lock().await;
449 tx_guard
450 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
451 .await
452 .ok();
453 }))
454 .ok();
455 return Ok(());
456 }
457 };
458
459 let waker = NoopWaker::new_waker();
460 let mut cx = Context::from_waker(&waker);
461
462 let mut job = Box::pin(async move {
463 if let Err(err) = tx_guard
464 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
465 .await
466 {
467 tracing::error!("failed to send remaining bytes for request - {}", err);
468 }
469 });
470
471 if job.as_mut().poll(&mut cx).is_ready() {
473 return Ok(());
474 }
475
476 work.send(job).map_err(|err| {
479 tracing::error!("failed to send remaining bytes for request - {}", err);
480 NetworkError::ConnectionAborted
481 })?;
482 Ok(())
483 }
484 }
485 }
486}
487
488#[derive(derive_more::Debug)]
489pub(crate) enum RemoteRx<T>
490where
491 T: serde::de::DeserializeOwned,
492{
493 Mpsc {
494 rx: mpsc::Receiver<T>,
495 wakers: RemoteTxWakers,
496 },
497 Stream {
498 #[debug(ignore)]
499 rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
500 },
501 #[cfg(feature = "hyper")]
502 HyperWebSocket {
503 rx: futures_util::stream::SplitStream<
504 hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
505 >,
506 format: crate::meta::FrameSerializationFormat,
507 },
508 #[cfg(feature = "tokio-tungstenite")]
509 TokioWebSocket {
510 rx: futures_util::stream::SplitStream<
511 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
512 >,
513 format: crate::meta::FrameSerializationFormat,
514 },
515}
516impl<T> RemoteRx<T>
517where
518 T: serde::de::DeserializeOwned,
519{
520 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
521 loop {
522 return match self {
523 RemoteRx::Mpsc { rx, wakers } => {
524 let ret = Pin::new(rx).poll_recv(cx);
525 if ret.is_ready() {
526 wakers.wake();
527 }
528 ret
529 }
530 RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
531 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
532 Poll::Ready(Some(Err(err))) => {
533 tracing::debug!("failed to read from channel - {}", err);
534 Poll::Ready(None)
535 }
536 Poll::Ready(None) => Poll::Ready(None),
537 Poll::Pending => Poll::Pending,
538 },
539 #[cfg(feature = "hyper")]
540 RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
541 Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
542 match format {
543 crate::meta::FrameSerializationFormat::Bincode => {
544 return match bincode::serde::decode_from_slice(
545 &msg,
546 config::legacy(),
547 ) {
548 Ok((msg, _)) => Poll::Ready(Some(msg)),
549 Err(err) => {
550 tracing::warn!("failed to deserialize message - {}", err);
551 continue;
552 }
553 };
554 }
555 format => {
556 tracing::warn!("format not currently supported - {format:?}");
557 continue;
558 }
559 }
560 }
561 Poll::Ready(Some(Ok(msg))) => {
562 tracing::warn!("unsupported message from channel - {}", msg);
563 continue;
564 }
565 Poll::Ready(Some(Err(err))) => {
566 tracing::debug!("failed to read from channel - {}", err);
567 Poll::Ready(None)
568 }
569 Poll::Ready(None) => Poll::Ready(None),
570 Poll::Pending => Poll::Pending,
571 },
572 #[cfg(feature = "tokio-tungstenite")]
573 RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
574 Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
575 match format {
576 crate::meta::FrameSerializationFormat::Bincode => {
577 return match bincode::serde::decode_from_slice(
578 &msg,
579 config::legacy(),
580 ) {
581 Ok((msg, _)) => Poll::Ready(Some(msg)),
582 Err(err) => {
583 tracing::warn!("failed to deserialize message - {}", err);
584 continue;
585 }
586 };
587 }
588 format => {
589 tracing::warn!("format not currently supported - {format:?}");
590 continue;
591 }
592 }
593 }
594 Poll::Ready(Some(Ok(msg))) => {
595 tracing::warn!("unsupported message from channel - {}", msg);
596 continue;
597 }
598 Poll::Ready(Some(Err(err))) => {
599 tracing::debug!("failed to read from channel - {}", err);
600 Poll::Ready(None)
601 }
602 Poll::Ready(None) => Poll::Ready(None),
603 Poll::Pending => Poll::Pending,
604 },
605 };
606 }
607 }
608}
609
/// A [`Waker`] implementation that discards every wake-up.
///
/// Lets a future be polled once from synchronous code, where there is no
/// task to wake.
struct NoopWaker;

impl NoopWaker {
    /// Builds a `Waker` that does nothing when woken.
    fn new_waker() -> Waker {
        let shared: Arc<Self> = Arc::new(NoopWaker);
        shared.into()
    }
}

impl std::task::Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Intentionally empty: wake-ups are discarded.
    }
}