1use std::{
2 pin::Pin,
3 sync::{Arc, Mutex},
4 task::{Context, Poll, Waker},
5};
6
7use crate::Result;
8use futures_util::{Future, Sink, SinkExt, Stream, future::BoxFuture};
9#[cfg(feature = "hyper")]
10use hyper_util::rt::tokio::TokioIo;
11use serde::Serialize;
12#[cfg(feature = "tokio-tungstenite")]
13use tokio::net::TcpStream;
14use tokio::sync::{
15 mpsc::{self, error::TrySendError},
16 oneshot,
17};
18use virtual_mio::InlineWaker;
19
20use crate::{NetworkError, io_err_into_net_error};
21
/// Shared, de-duplicated list of wakers belonging to tasks that are waiting to send.
///
/// Tasks register themselves with [`RemoteTxWakers::add`] before returning `Poll::Pending`;
/// whoever frees capacity calls [`RemoteTxWakers::wake`], which wakes and forgets all of them.
#[derive(Debug, Clone, Default)]
pub(crate) struct RemoteTxWakers {
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl RemoteTxWakers {
    /// Registers `waker` to be woken by the next call to [`RemoteTxWakers::wake`].
    ///
    /// A waker that would wake the same task as an already registered one (per
    /// `Waker::will_wake`) is not stored twice.
    pub fn add(&self, waker: &Waker) {
        let mut guard = self.lock_list();
        if !guard.iter().any(|w| w.will_wake(waker)) {
            guard.push(waker.clone());
        }
    }

    /// Wakes every registered waker and clears the list.
    ///
    /// The list is moved out of the mutex *before* any waker runs: `Waker::wake` may execute
    /// arbitrary code synchronously (including calling `add` on this same set from this
    /// thread), which would deadlock or panic on the non-reentrant `std::sync::Mutex` if the
    /// lock were still held.
    pub fn wake(&self) {
        let pending = std::mem::take(&mut *self.lock_list());
        for waker in pending {
            waker.wake();
        }
    }

    /// Locks the waker list, recovering from poisoning: the `Vec` can never be observed in a
    /// half-updated state, so a panic elsewhere must not disable waking for good.
    fn lock_list(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
38
39pub(crate) type StreamSink<T> = Pin<Box<dyn Sink<T, Error = std::io::Error> + Send + 'static>>;
40
/// Sending half of a connection to a remote peer, abstracted over the transport.
///
/// Every variant carries a `work` queue: writes that cannot complete inline are boxed and
/// pushed onto it by the methods of this type. Whatever drains that queue is not visible in
/// this file (presumably a background driver task — confirm at the construction site).
#[derive(derive_more::Debug)]
pub(crate) enum RemoteTx<T>
where
    T: Serialize,
{
    /// In-process transport over a bounded tokio channel.
    Mpsc {
        /// Bounded channel sender; a `Full` result from `try_send` is treated as back-pressure.
        tx: mpsc::Sender<T>,
        /// Queue of boxed futures to be driven outside the caller (e.g. deferred sends).
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        /// Wakers of tasks that received `Poll::Pending` from `poll_send`.
        wakers: RemoteTxWakers,
    },
    /// Transport over a type-erased sink, shared behind an async mutex so that only one
    /// write is in progress at a time.
    Stream {
        // Excluded from `Debug`: the boxed `dyn Sink` has no `Debug` bound.
        #[debug(ignore)]
        tx: Arc<tokio::sync::Mutex<StreamSink<T>>>,
        /// Queue of boxed futures to be driven outside the caller (e.g. unfinished writes).
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        /// Wakers of tasks that received `Poll::Pending` from `poll_send`.
        wakers: RemoteTxWakers,
    },
    /// Write half of a WebSocket obtained through a hyper connection upgrade; each request
    /// is sent as a single binary frame.
    #[cfg(feature = "hyper")]
    HyperWebSocket {
        tx: Arc<
            tokio::sync::Mutex<
                futures_util::stream::SplitSink<
                    hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
                    hyper_tungstenite::tungstenite::Message,
                >,
            >,
        >,
        /// Queue of boxed futures to be driven outside the caller (e.g. unfinished writes).
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        /// Wakers of tasks that received `Poll::Pending` from `poll_send`.
        wakers: RemoteTxWakers,
        /// Serialization format for outgoing frames; only `Bincode` is handled by the send
        /// methods, other formats are rejected with `NetworkError::IOError`.
        format: crate::meta::FrameSerializationFormat,
    },
    /// Write half of a (possibly TLS) WebSocket client connection over TCP; each request is
    /// sent as a single binary frame.
    #[cfg(feature = "tokio-tungstenite")]
    TokioWebSocket {
        tx: Arc<
            tokio::sync::Mutex<
                futures_util::stream::SplitSink<
                    tokio_tungstenite::WebSocketStream<
                        tokio_tungstenite::MaybeTlsStream<TcpStream>,
                    >,
                    tokio_tungstenite::tungstenite::Message,
                >,
            >,
        >,
        /// Queue of boxed futures to be driven outside the caller (e.g. unfinished writes).
        work: mpsc::UnboundedSender<BoxFuture<'static, ()>>,
        /// Wakers of tasks that received `Poll::Pending` from `poll_send`.
        wakers: RemoteTxWakers,
        /// Serialization format for outgoing frames; only `Bincode` is handled by the send
        /// methods, other formats are rejected with `NetworkError::IOError`.
        format: crate::meta::FrameSerializationFormat,
    },
}
88impl<T> RemoteTx<T>
89where
90 T: Serialize + Send + Sync + 'static,
91{
92 pub(crate) async fn send(&self, req: T) -> Result<()> {
93 match self {
94 RemoteTx::Mpsc { tx, .. } => tx
95 .send(req)
96 .await
97 .map_err(|_| NetworkError::ConnectionAborted),
98 RemoteTx::Stream { tx, work, .. } => {
99 let (tx_done, rx_done) = oneshot::channel();
100 let tx = tx.clone();
101 work.send(Box::pin(async move {
102 let job = async {
103 let mut tx_guard = tx.lock_owned().await;
104 tx_guard.send(req).await.map_err(io_err_into_net_error)
105 };
106 tx_done.send(job.await).ok();
107 }))
108 .map_err(|_| NetworkError::ConnectionAborted)?;
109
110 rx_done
111 .await
112 .unwrap_or(Err(NetworkError::ConnectionAborted))
113 }
114 #[cfg(feature = "hyper")]
115 RemoteTx::HyperWebSocket { tx, format, .. } => {
116 let data = match format {
117 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
118 .map_err(|err| {
119 tracing::warn!("failed to serialize message - {err}");
120 NetworkError::IOError
121 })?,
122 format => {
123 tracing::warn!("format not currently supported - {format:?}");
124 return Err(NetworkError::IOError);
125 }
126 };
127 let mut tx = tx.lock().await;
128 tx.send(hyper_tungstenite::tungstenite::Message::Binary(data))
129 .await
130 .map_err(|_| NetworkError::ConnectionAborted)
131 }
132 #[cfg(feature = "tokio-tungstenite")]
133 RemoteTx::TokioWebSocket { tx, format, .. } => {
134 let data = match format {
135 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
136 .map_err(|err| {
137 tracing::warn!("failed to serialize message - {err}");
138 NetworkError::IOError
139 })?,
140 format => {
141 tracing::warn!("format not currently supported - {format:?}");
142 return Err(NetworkError::IOError);
143 }
144 };
145 let mut tx = tx.lock().await;
146 tx.send(tokio_tungstenite::tungstenite::Message::Binary(data))
147 .await
148 .map_err(|_| NetworkError::ConnectionAborted)
149 }
150 }
151 }
152
153 pub(crate) fn poll_send(&self, cx: &mut Context<'_>, req: T) -> Poll<Result<()>> {
154 match self {
155 RemoteTx::Mpsc { tx, wakers, .. } => match tx.try_send(req) {
156 Ok(()) => Poll::Ready(Ok(())),
157 Err(TrySendError::Closed(_)) => Poll::Ready(Err(NetworkError::ConnectionAborted)),
158 Err(TrySendError::Full(_)) => {
159 wakers.add(cx.waker());
160 Poll::Pending
161 }
162 },
163 RemoteTx::Stream { tx, work, wakers } => {
164 let mut tx_guard = match tx.clone().try_lock_owned() {
165 Ok(lock) => lock,
166 Err(_) => {
167 wakers.add(cx.waker());
168 return Poll::Pending;
169 }
170 };
171 match tx_guard.poll_ready_unpin(cx) {
172 Poll::Ready(Ok(())) => {}
173 Poll::Ready(Err(err)) => return Poll::Ready(Err(io_err_into_net_error(err))),
174 Poll::Pending => return Poll::Pending,
175 }
176 let mut job = Box::pin(async move {
177 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
178 tracing::error!("failed to send remaining bytes for request - {}", err);
179 }
180 });
181
182 if job.as_mut().poll(cx).is_ready() {
184 return Poll::Ready(Ok(()));
185 }
186
187 work.send(job).map_err(|err| {
190 tracing::error!("failed to send remaining bytes for request - {}", err);
191 NetworkError::ConnectionAborted
192 })?;
193 Poll::Ready(Ok(()))
194 }
195 #[cfg(feature = "hyper")]
196 RemoteTx::HyperWebSocket {
197 tx,
198 format,
199 work,
200 wakers,
201 ..
202 } => {
203 let mut tx_guard = match tx.clone().try_lock_owned() {
204 Ok(lock) => lock,
205 Err(_) => {
206 wakers.add(cx.waker());
207 return Poll::Pending;
208 }
209 };
210 match tx_guard.poll_ready_unpin(cx) {
211 Poll::Ready(Ok(())) => {}
212 Poll::Ready(Err(err)) => {
213 tracing::warn!("failed to poll web socket for readiness - {err}");
214 return Poll::Ready(Err(NetworkError::IOError));
215 }
216 Poll::Pending => return Poll::Pending,
217 }
218
219 let data = match format {
220 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
221 .map_err(|err| {
222 tracing::warn!("failed to serialize message - {err}");
223 NetworkError::IOError
224 })?,
225 format => {
226 tracing::warn!("format not currently supported - {format:?}");
227 return Poll::Ready(Err(NetworkError::IOError));
228 }
229 };
230
231 let mut job = Box::pin(async move {
232 if let Err(err) = tx_guard
233 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
234 .await
235 {
236 tracing::error!("failed to send remaining bytes for request - {}", err);
237 }
238 });
239
240 if job.as_mut().poll(cx).is_ready() {
242 return Poll::Ready(Ok(()));
243 }
244
245 work.send(job).map_err(|err| {
248 tracing::error!("failed to send remaining bytes for request - {}", err);
249 NetworkError::ConnectionAborted
250 })?;
251 Poll::Ready(Ok(()))
252 }
253 #[cfg(feature = "tokio-tungstenite")]
254 RemoteTx::TokioWebSocket {
255 tx,
256 format,
257 work,
258 wakers,
259 ..
260 } => {
261 let mut tx_guard = match tx.clone().try_lock_owned() {
262 Ok(lock) => lock,
263 Err(_) => {
264 wakers.add(cx.waker());
265 return Poll::Pending;
266 }
267 };
268 match tx_guard.poll_ready_unpin(cx) {
269 Poll::Ready(Ok(())) => {}
270 Poll::Ready(Err(err)) => {
271 tracing::warn!("failed to poll web socket for readiness - {err}");
272 return Poll::Ready(Err(NetworkError::IOError));
273 }
274 Poll::Pending => return Poll::Pending,
275 }
276
277 let data = match format {
278 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
279 .map_err(|err| {
280 tracing::warn!("failed to serialize message - {err}");
281 NetworkError::IOError
282 })?,
283 format => {
284 tracing::warn!("format not currently supported - {format:?}");
285 return Poll::Ready(Err(NetworkError::IOError));
286 }
287 };
288
289 let mut job = Box::pin(async move {
290 if let Err(err) = tx_guard
291 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
292 .await
293 {
294 tracing::error!("failed to send remaining bytes for request - {}", err);
295 }
296 });
297
298 if job.as_mut().poll(cx).is_ready() {
300 return Poll::Ready(Ok(()));
301 }
302
303 work.send(job).map_err(|err| {
306 tracing::error!("failed to send remaining bytes for request - {}", err);
307 NetworkError::ConnectionAborted
308 })?;
309 Poll::Ready(Ok(()))
310 }
311 }
312 }
313
314 pub(crate) fn send_with_driver(&self, req: T) -> Result<()> {
315 match self {
316 RemoteTx::Mpsc { tx, work, .. } => match tx.try_send(req) {
317 Ok(()) => Ok(()),
318 Err(TrySendError::Closed(_)) => Err(NetworkError::ConnectionAborted),
319 Err(TrySendError::Full(req)) => {
320 let tx = tx.clone();
321 work.send(Box::pin(async move {
322 tx.send(req).await.ok();
323 }))
324 .ok();
325 Ok(())
326 }
327 },
328 RemoteTx::Stream { tx, work, .. } => {
329 let mut tx_guard = match tx.clone().try_lock_owned() {
330 Ok(lock) => lock,
331 Err(_) => {
332 let tx = tx.clone();
333 work.send(Box::pin(async move {
334 let mut tx_guard = tx.lock().await;
335 tx_guard.send(req).await.ok();
336 }))
337 .ok();
338 return Ok(());
339 }
340 };
341
342 let inline_waker = InlineWaker::new();
343 let waker = inline_waker.as_waker();
344 let mut cx = Context::from_waker(&waker);
345
346 let mut job = Box::pin(async move {
347 if let Err(err) = tx_guard.send(req).await.map_err(io_err_into_net_error) {
348 tracing::error!("failed to send remaining bytes for request - {}", err);
349 }
350 });
351
352 if job.as_mut().poll(&mut cx).is_ready() {
354 return Ok(());
355 }
356
357 work.send(job).map_err(|err| {
360 tracing::error!("failed to send remaining bytes for request - {}", err);
361 NetworkError::ConnectionAborted
362 })?;
363 Ok(())
364 }
365 #[cfg(feature = "hyper")]
366 RemoteTx::HyperWebSocket {
367 tx, format, work, ..
368 } => {
369 let data = match format {
370 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
371 .map_err(|err| {
372 tracing::warn!("failed to serialize message - {err}");
373 NetworkError::IOError
374 })?,
375 format => {
376 tracing::warn!("format not currently supported - {format:?}");
377 return Err(NetworkError::IOError);
378 }
379 };
380
381 let mut tx_guard = match tx.clone().try_lock_owned() {
382 Ok(lock) => lock,
383 Err(_) => {
384 let tx = tx.clone();
385 work.send(Box::pin(async move {
386 let mut tx_guard = tx.lock().await;
387 tx_guard
388 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
389 .await
390 .ok();
391 }))
392 .ok();
393 return Ok(());
394 }
395 };
396
397 let inline_waker = InlineWaker::new();
398 let waker = inline_waker.as_waker();
399 let mut cx = Context::from_waker(&waker);
400
401 let mut job = Box::pin(async move {
402 if let Err(err) = tx_guard
403 .send(hyper_tungstenite::tungstenite::Message::Binary(data))
404 .await
405 {
406 tracing::error!("failed to send remaining bytes for request - {}", err);
407 }
408 });
409
410 if job.as_mut().poll(&mut cx).is_ready() {
412 return Ok(());
413 }
414
415 work.send(job).map_err(|err| {
418 tracing::error!("failed to send remaining bytes for request - {}", err);
419 NetworkError::ConnectionAborted
420 })?;
421 Ok(())
422 }
423 #[cfg(feature = "tokio-tungstenite")]
424 RemoteTx::TokioWebSocket {
425 tx, format, work, ..
426 } => {
427 let data = match format {
428 crate::meta::FrameSerializationFormat::Bincode => bincode::serialize(&req)
429 .map_err(|err| {
430 tracing::warn!("failed to serialize message - {err}");
431 NetworkError::IOError
432 })?,
433 format => {
434 tracing::warn!("format not currently supported - {format:?}");
435 return Err(NetworkError::IOError);
436 }
437 };
438
439 let mut tx_guard = match tx.clone().try_lock_owned() {
440 Ok(lock) => lock,
441 Err(_) => {
442 let tx = tx.clone();
443 work.send(Box::pin(async move {
444 let mut tx_guard = tx.lock().await;
445 tx_guard
446 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
447 .await
448 .ok();
449 }))
450 .ok();
451 return Ok(());
452 }
453 };
454
455 let inline_waker = InlineWaker::new();
456 let waker = inline_waker.as_waker();
457 let mut cx = Context::from_waker(&waker);
458
459 let mut job = Box::pin(async move {
460 if let Err(err) = tx_guard
461 .send(tokio_tungstenite::tungstenite::Message::Binary(data))
462 .await
463 {
464 tracing::error!("failed to send remaining bytes for request - {}", err);
465 }
466 });
467
468 if job.as_mut().poll(&mut cx).is_ready() {
470 return Ok(());
471 }
472
473 work.send(job).map_err(|err| {
476 tracing::error!("failed to send remaining bytes for request - {}", err);
477 NetworkError::ConnectionAborted
478 })?;
479 Ok(())
480 }
481 }
482 }
483}
484
/// Receiving half of a connection to a remote peer, abstracted over the transport.
#[derive(derive_more::Debug)]
pub(crate) enum RemoteRx<T>
where
    T: serde::de::DeserializeOwned,
{
    /// In-process transport over a bounded tokio channel.
    Mpsc {
        rx: mpsc::Receiver<T>,
        /// Woken every time `poll` returns `Poll::Ready`, so that parked senders can retry
        /// (presumably the same set held by the paired `RemoteTx::Mpsc` — confirm at the
        /// construction site).
        wakers: RemoteTxWakers,
    },
    /// Transport over a type-erased stream of already-decoded messages; the first `Err`
    /// item ends the stream as seen by `poll`.
    Stream {
        // Excluded from `Debug`: the boxed `dyn Stream` has no `Debug` bound.
        #[debug(ignore)]
        rx: Pin<Box<dyn Stream<Item = std::io::Result<T>> + Send + 'static>>,
    },
    /// Read half of a WebSocket obtained through a hyper connection upgrade.
    #[cfg(feature = "hyper")]
    HyperWebSocket {
        rx: futures_util::stream::SplitStream<
            hyper_tungstenite::WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
        >,
        /// Format used to decode binary frames; only `Bincode` is decoded by `poll`.
        format: crate::meta::FrameSerializationFormat,
    },
    /// Read half of a (possibly TLS) WebSocket client connection over TCP.
    #[cfg(feature = "tokio-tungstenite")]
    TokioWebSocket {
        rx: futures_util::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
        >,
        /// Format used to decode binary frames; only `Bincode` is decoded by `poll`.
        format: crate::meta::FrameSerializationFormat,
    },
}
513impl<T> RemoteRx<T>
514where
515 T: serde::de::DeserializeOwned,
516{
517 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
518 loop {
519 return match self {
520 RemoteRx::Mpsc { rx, wakers } => {
521 let ret = Pin::new(rx).poll_recv(cx);
522 if ret.is_ready() {
523 wakers.wake();
524 }
525 ret
526 }
527 RemoteRx::Stream { rx } => match rx.as_mut().poll_next(cx) {
528 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(msg)),
529 Poll::Ready(Some(Err(err))) => {
530 tracing::debug!("failed to read from channel - {}", err);
531 Poll::Ready(None)
532 }
533 Poll::Ready(None) => Poll::Ready(None),
534 Poll::Pending => Poll::Pending,
535 },
536 #[cfg(feature = "hyper")]
537 RemoteRx::HyperWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
538 Poll::Ready(Some(Ok(hyper_tungstenite::tungstenite::Message::Binary(msg)))) => {
539 match format {
540 crate::meta::FrameSerializationFormat::Bincode => {
541 return match bincode::deserialize(&msg) {
542 Ok(msg) => Poll::Ready(Some(msg)),
543 Err(err) => {
544 tracing::warn!("failed to deserialize message - {}", err);
545 continue;
546 }
547 };
548 }
549 format => {
550 tracing::warn!("format not currently supported - {format:?}");
551 continue;
552 }
553 }
554 }
555 Poll::Ready(Some(Ok(msg))) => {
556 tracing::warn!("unsupported message from channel - {}", msg);
557 continue;
558 }
559 Poll::Ready(Some(Err(err))) => {
560 tracing::debug!("failed to read from channel - {}", err);
561 Poll::Ready(None)
562 }
563 Poll::Ready(None) => Poll::Ready(None),
564 Poll::Pending => Poll::Pending,
565 },
566 #[cfg(feature = "tokio-tungstenite")]
567 RemoteRx::TokioWebSocket { rx, format } => match Pin::new(rx).poll_next(cx) {
568 Poll::Ready(Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(msg)))) => {
569 match format {
570 crate::meta::FrameSerializationFormat::Bincode => {
571 return match bincode::deserialize(&msg) {
572 Ok(msg) => Poll::Ready(Some(msg)),
573 Err(err) => {
574 tracing::warn!("failed to deserialize message - {}", err);
575 continue;
576 }
577 };
578 }
579 format => {
580 tracing::warn!("format not currently supported - {format:?}");
581 continue;
582 }
583 }
584 }
585 Poll::Ready(Some(Ok(msg))) => {
586 tracing::warn!("unsupported message from channel - {}", msg);
587 continue;
588 }
589 Poll::Ready(Some(Err(err))) => {
590 tracing::debug!("failed to read from channel - {}", err);
591 Poll::Ready(None)
592 }
593 Poll::Ready(None) => Poll::Ready(None),
594 Poll::Pending => Poll::Pending,
595 },
596 };
597 }
598 }
599}