LCOV - differential code coverage report
Current view: top level - proxy/src/serverless - websocket.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 69.7 % 119 83 36 83
Current Date: 2024-01-09 02:06:09 Functions: 34.5 % 29 10 19 10
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     cancellation::CancelMap,
       3                 :     config::ProxyConfig,
       4                 :     context::RequestMonitoring,
       5                 :     error::io_error,
       6                 :     proxy::{handle_client, ClientMode},
       7                 :     rate_limiter::EndpointRateLimiter,
       8                 : };
       9                 : use bytes::{Buf, Bytes};
      10                 : use futures::{Sink, Stream};
      11                 : use hyper::upgrade::Upgraded;
      12                 : use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
      13                 : use pin_project_lite::pin_project;
      14                 : 
      15                 : use std::{
      16                 :     pin::Pin,
      17                 :     sync::Arc,
      18                 :     task::{ready, Context, Poll},
      19                 : };
      20                 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
      21                 : use tracing::warn;
      22                 : 
      23                 : // TODO: use `std::sync::Exclusive` once it's stabilized.
      24                 : // Tracking issue: https://github.com/rust-lang/rust/issues/98407.
      25                 : use sync_wrapper::SyncWrapper;
      26                 : 
      27                 : pin_project! {
      28                 :     /// This is a wrapper around a [`WebSocketStream`] that
      29                 :     /// implements [`AsyncRead`] and [`AsyncWrite`].
      30                 :     pub struct WebSocketRw<S = Upgraded> {
      31                 :         #[pin]
      32                 :         stream: SyncWrapper<WebSocketStream<S>>,
      33                 :         bytes: Bytes,
      34                 :     }
      35                 : }
      36                 : 
      37                 : impl<S> WebSocketRw<S> {
      38 CBC           1 :     pub fn new(stream: WebSocketStream<S>) -> Self {
      39               1 :         Self {
      40               1 :             stream: stream.into(),
      41               1 :             bytes: Bytes::new(),
      42               1 :         }
      43               1 :     }
      44                 : }
      45                 : 
      46                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
      47               1 :     fn poll_write(
      48               1 :         self: Pin<&mut Self>,
      49               1 :         cx: &mut Context<'_>,
      50               1 :         buf: &[u8],
      51               1 :     ) -> Poll<io::Result<usize>> {
      52               1 :         let mut stream = self.project().stream.get_pin_mut();
      53                 : 
      54               1 :         ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
      55               1 :         match stream.as_mut().start_send(Message::Binary(buf.into())) {
      56               1 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      57 UBC           0 :             Err(e) => Poll::Ready(Err(io_error(e))),
      58                 :         }
      59 CBC           1 :     }
      60                 : 
      61               1 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      62               1 :         let stream = self.project().stream.get_pin_mut();
      63               1 :         stream.poll_flush(cx).map_err(io_error)
      64               1 :     }
      65                 : 
      66 UBC           0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      67               0 :         let stream = self.project().stream.get_pin_mut();
      68               0 :         stream.poll_close(cx).map_err(io_error)
      69               0 :     }
      70                 : }
      71                 : 
      72                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
      73 CBC           3 :     fn poll_read(
      74               3 :         mut self: Pin<&mut Self>,
      75               3 :         cx: &mut Context<'_>,
      76               3 :         buf: &mut ReadBuf<'_>,
      77               3 :     ) -> Poll<io::Result<()>> {
      78               3 :         if buf.remaining() > 0 {
      79               3 :             let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
      80               2 :             let len = std::cmp::min(bytes.len(), buf.remaining());
      81               2 :             buf.put_slice(&bytes[..len]);
      82               2 :             self.consume(len);
      83 UBC           0 :         }
      84                 : 
      85 CBC           2 :         Poll::Ready(Ok(()))
      86               3 :     }
      87                 : }
      88                 : 
      89                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
      90               3 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
      91               3 :         // Please refer to poll_fill_buf's documentation.
      92               3 :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
      93               3 : 
      94               3 :         let mut this = self.project();
      95               4 :         loop {
      96               4 :             if !this.bytes.chunk().is_empty() {
      97               1 :                 let chunk = (*this.bytes).chunk();
      98               1 :                 return Poll::Ready(Ok(chunk));
      99               3 :             }
     100                 : 
     101               3 :             let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
     102               2 :             match res.transpose().map_err(io_error)? {
     103               2 :                 Some(message) => match message {
     104 UBC           0 :                     Message::Ping(_) => {}
     105               0 :                     Message::Pong(_) => {}
     106               0 :                     Message::Text(text) => {
     107               0 :                         // We expect to see only binary messages.
     108               0 :                         let error = "unexpected text message in the websocket";
     109               0 :                         warn!(length = text.len(), error);
     110               0 :                         return Poll::Ready(Err(io_error(error)));
     111                 :                     }
     112                 :                     Message::Frame(_) => {
     113                 :                         // This case is impossible according to Frame's doc.
     114               0 :                         panic!("unexpected raw frame in the websocket");
     115                 :                     }
     116 CBC           1 :                     Message::Binary(chunk) => {
     117               1 :                         assert!(this.bytes.is_empty());
     118               1 :                         *this.bytes = Bytes::from(chunk);
     119                 :                     }
     120               1 :                     Message::Close(_) => return EOF,
     121                 :                 },
     122 UBC           0 :                 None => return EOF,
     123                 :             }
     124                 :         }
     125 CBC           3 :     }
     126                 : 
     127               2 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     128               2 :         self.project().bytes.advance(amount);
     129               2 :     }
     130                 : }
     131                 : 
     132 UBC           0 : pub async fn serve_websocket(
     133               0 :     config: &'static ProxyConfig,
     134               0 :     ctx: &mut RequestMonitoring,
     135               0 :     websocket: HyperWebsocket,
     136               0 :     cancel_map: &CancelMap,
     137               0 :     hostname: Option<String>,
     138               0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     139               0 : ) -> anyhow::Result<()> {
     140               0 :     let websocket = websocket.await?;
     141               0 :     handle_client(
     142               0 :         config,
     143               0 :         ctx,
     144               0 :         cancel_map,
     145               0 :         WebSocketRw::new(websocket),
     146               0 :         ClientMode::Websockets { hostname },
     147               0 :         endpoint_rate_limiter,
     148               0 :     )
     149               0 :     .await?;
     150               0 :     Ok(())
     151               0 : }
     152                 : 
     153                 : #[cfg(test)]
     154                 : mod tests {
     155                 :     use std::pin::pin;
     156                 : 
     157                 :     use futures::{SinkExt, StreamExt};
     158                 :     use hyper_tungstenite::{
     159                 :         tungstenite::{protocol::Role, Message},
     160                 :         WebSocketStream,
     161                 :     };
     162                 :     use tokio::{
     163                 :         io::{duplex, AsyncReadExt, AsyncWriteExt},
     164                 :         task::JoinSet,
     165                 :     };
     166                 : 
     167                 :     use super::WebSocketRw;
     168                 : 
     169 CBC           1 :     #[tokio::test]
     170               1 :     async fn websocket_stream_wrapper_happy_path() {
     171               1 :         let (stream1, stream2) = duplex(1024);
     172               1 : 
     173               1 :         let mut js = JoinSet::new();
     174               1 : 
     175               1 :         js.spawn(async move {
     176               1 :             let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
     177                 : 
     178               1 :             client
     179               1 :                 .send(Message::Binary(b"hello world".to_vec()))
     180 UBC           0 :                 .await
     181 CBC           1 :                 .unwrap();
     182                 : 
     183               1 :             let message = client.next().await.unwrap().unwrap();
     184               1 :             assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
     185                 : 
     186               1 :             client.close(None).await.unwrap();
     187               1 :         });
     188               1 : 
     189               1 :         js.spawn(async move {
     190               1 :             let mut rw = pin!(WebSocketRw::new(
     191               1 :                 WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
     192                 :             ));
     193                 : 
     194               1 :             let mut buf = vec![0; 1024];
     195               1 :             let n = rw.read(&mut buf).await.unwrap();
     196               1 :             assert_eq!(&buf[..n], b"hello world");
     197                 : 
     198               1 :             rw.write_all(b"websockets are cool").await.unwrap();
     199               1 :             rw.flush().await.unwrap();
     200                 : 
     201               1 :             let n = rw.read_to_end(&mut buf).await.unwrap();
     202               1 :             assert_eq!(n, 0);
     203               1 :         });
     204               1 : 
     205               1 :         js.join_next().await.unwrap().unwrap();
     206               1 :         js.join_next().await.unwrap().unwrap();
     207                 :     }
     208                 : }
        

Generated by: LCOV version 2.1-beta