LCOV - code coverage report
Current view: top level - proxy/src/serverless - websocket.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 69.7 % 119 83
Test Date: 2024-02-07 07:37:29 Functions: 34.5 % 29 10

            Line data    Source code
       1              : use crate::{
       2              :     cancellation::CancelMap,
       3              :     config::ProxyConfig,
       4              :     context::RequestMonitoring,
       5              :     error::io_error,
       6              :     proxy::{handle_client, ClientMode},
       7              :     rate_limiter::EndpointRateLimiter,
       8              : };
       9              : use bytes::{Buf, Bytes};
      10              : use futures::{Sink, Stream};
      11              : use hyper::upgrade::Upgraded;
      12              : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
      13              : use pin_project_lite::pin_project;
      14              : 
      15              : use std::{
      16              :     pin::Pin,
      17              :     sync::Arc,
      18              :     task::{ready, Context, Poll},
      19              : };
      20              : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
      21              : use tracing::warn;
      22              : 
      23              : // TODO: use `std::sync::Exclusive` once it's stabilized.
      24              : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
      25              : use sync_wrapper::SyncWrapper;
      26              : 
      27              : pin_project! {
      28              :     /// This is a wrapper around a [`WebSocketStream`] that
      29              :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      30              :     pub struct WebSocketRw<S = Upgraded> {
      31              :         #[pin]
      32              :         stream: SyncWrapper<WebSocketStream<S>>,
      33              :         bytes: Bytes,
      34              :     }
      35              : }
      36              : 
      37              : impl<S> WebSocketRw<S> {
      38            2 :     pub fn new(stream: WebSocketStream<S>) -> Self {
      39            2 :         Self {
      40            2 :             stream: stream.into(),
      41            2 :             bytes: Bytes::new(),
      42            2 :         }
      43            2 :     }
      44              : }
      45              : 
      46              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
      47            2 :     fn poll_write(
      48            2 :         self: Pin<&mut Self>,
      49            2 :         cx: &mut Context<'_>,
      50            2 :         buf: &[u8],
      51            2 :     ) -> Poll<io::Result<usize>> {
      52            2 :         let mut stream = self.project().stream.get_pin_mut();
      53              : 
      54            2 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      55            2 :         match stream.as_mut().start_send(Message::Binary(buf.into())) {
      56            2 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      57            0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      58              :         }
      59            2 :     }
      60              : 
      61            2 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      62            2 :         let stream = self.project().stream.get_pin_mut();
      63            2 :         stream.poll_flush(cx).map_err(io_error)
      64            2 :     }
      65              : 
      66            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      67            0 :         let stream = self.project().stream.get_pin_mut();
      68            0 :         stream.poll_close(cx).map_err(io_error)
      69            0 :     }
      70              : }
      71              : 
      72              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
      73            6 :     fn poll_read(
      74            6 :         mut self: Pin<&mut Self>,
      75            6 :         cx: &mut Context<'_>,
      76            6 :         buf: &mut ReadBuf<'_>,
      77            6 :     ) -> Poll<io::Result<()>> {
      78            6 :         if buf.remaining() > 0 {
      79            6 :             let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      80            4 :             let len = std::cmp::min(bytes.len(), buf.remaining());
      81            4 :             buf.put_slice(&bytes[..len]);
      82            4 :             self.consume(len);
      83            0 :         }
      84              : 
      85            4 :         Poll::Ready(Ok(()))
      86            6 :     }
      87              : }
      88              : 
      89              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
      90            6 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
      91            6 :         // Please refer to poll_fill_buf's documentation.
      92            6 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
      93            6 : 
      94            6 :         let mut this = self.project();
      95            8 :         loop {
      96            8 :             if !this.bytes.chunk().is_empty() {
      97            2 :                 let chunk = (*this.bytes).chunk();
      98            2 :                 return Poll::Ready(Ok(chunk));
      99            6 :             }
     100              : 
     101            6 :             let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
     102            4 :             match res.transpose().map_err(io_error)? {
     103            4 :                 Some(message) => match message {
     104            0 :                     Message::Ping(_) => {}
     105            0 :                     Message::Pong(_) => {}
     106            0 :                     Message::Text(text) => {
     107            0 :                         // We expect to see only binary messages.
     108            0 :                         let error = "unexpected text message in the websocket";
     109            0 :                         warn!(length = text.len(), error);
     110            0 :                         return Poll::Ready(Err(io_error(error)));
     111              :                     }
     112              :                     Message::Frame(_) => {
     113              :                         // This case is impossible according to Frame's doc.
     114            0 :                         panic!("unexpected raw frame in the websocket");
     115              :                     }
     116            2 :                     Message::Binary(chunk) => {
     117            2 :                         assert!(this.bytes.is_empty());
     118            2 :                         *this.bytes = Bytes::from(chunk);
     119              :                     }
     120            2 :                     Message::Close(_) => return EOF,
     121              :                 },
     122            0 :                 None => return EOF,
     123              :             }
     124              :         }
     125            6 :     }
     126              : 
     127            4 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     128            4 :         self.project().bytes.advance(amount);
     129            4 :     }
     130              : }
     131              : 
     132            0 : pub async fn serve_websocket(
     133            0 :     config: &'static ProxyConfig,
     134            0 :     ctx: &mut RequestMonitoring,
     135            0 :     websocket: HyperWebsocket,
     136            0 :     cancel_map: Arc<CancelMap>,
     137            0 :     hostname: Option<String>,
     138            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     139            0 : ) -> anyhow::Result<()> {
     140            0 :     let websocket = websocket.await?;
     141            0 :     handle_client(
     142            0 :         config,
     143            0 :         ctx,
     144            0 :         cancel_map,
     145            0 :         WebSocketRw::new(websocket),
     146            0 :         ClientMode::Websockets { hostname },
     147            0 :         endpoint_rate_limiter,
     148            0 :     )
     149            0 :     .await?;
     150            0 :     Ok(())
     151            0 : }
     152              : 
     153              : #[cfg(test)]
     154              : mod tests {
     155              :     use std::pin::pin;
     156              : 
     157              :     use futures::{SinkExt, StreamExt};
     158              :     use hyper_tungstenite::{
     159              :         tungstenite::{protocol::Role, Message},
     160              :         WebSocketStream,
     161              :     };
     162              :     use tokio::{
     163              :         io::{duplex, AsyncReadExt, AsyncWriteExt},
     164              :         task::JoinSet,
     165              :     };
     166              : 
     167              :     use super::WebSocketRw;
     168              : 
     169            2 :     #[tokio::test]
     170            2 :     async fn websocket_stream_wrapper_happy_path() {
     171            2 :         let (stream1, stream2) = duplex(1024);
     172            2 : 
     173            2 :         let mut js = JoinSet::new();
     174            2 : 
     175            2 :         js.spawn(async move {
     176            2 :             let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
     177              : 
     178            2 :             client
     179            2 :                 .send(Message::Binary(b"hello world".to_vec()))
     180            0 :                 .await
     181            2 :                 .unwrap();
     182              : 
     183            2 :             let message = client.next().await.unwrap().unwrap();
     184            2 :             assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
     185              : 
     186            2 :             client.close(None).await.unwrap();
     187            2 :         });
     188            2 : 
     189            2 :         js.spawn(async move {
     190            2 :             let mut rw = pin!(WebSocketRw::new(
     191            2 :                 WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
     192              :             ));
     193              : 
     194            2 :             let mut buf = vec![0; 1024];
     195            2 :             let n = rw.read(&mut buf).await.unwrap();
     196            2 :             assert_eq!(&buf[..n], b"hello world");
     197              : 
     198            2 :             rw.write_all(b"websockets are cool").await.unwrap();
     199            2 :             rw.flush().await.unwrap();
     200              : 
     201            2 :             let n = rw.read_to_end(&mut buf).await.unwrap();
     202            2 :             assert_eq!(n, 0);
     203            2 :         });
     204            2 : 
     205            2 :         js.join_next().await.unwrap().unwrap();
     206            2 :         js.join_next().await.unwrap().unwrap();
     207              :     }
     208              : }
        

Generated by: LCOV version 2.1-beta