|             Line data    Source code 
       1              : use std::pin::Pin;
       2              : use std::sync::Arc;
       3              : use std::task::{Context, Poll, ready};
       4              : 
       5              : use anyhow::Context as _;
       6              : use bytes::{Buf, BufMut, Bytes, BytesMut};
       7              : use framed_websockets::{Frame, OpCode, WebSocketServer};
       8              : use futures::{Sink, Stream};
       9              : use hyper::upgrade::OnUpgrade;
      10              : use hyper_util::rt::TokioIo;
      11              : use pin_project_lite::pin_project;
      12              : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
      13              : use tracing::warn;
      14              : 
      15              : use crate::cancellation::CancellationHandler;
      16              : use crate::config::ProxyConfig;
      17              : use crate::context::RequestContext;
      18              : use crate::error::ReportableError;
      19              : use crate::metrics::Metrics;
      20              : use crate::pglb::{ClientMode, handle_connection};
      21              : use crate::proxy::ErrorSource;
      22              : use crate::rate_limiter::EndpointRateLimiter;
      23              : 
      24              : pin_project! {
      25              :     /// This is a wrapper around a [`WebSocketServer`] that
      26              :     /// implements [`AsyncRead`], [`AsyncBufRead`] and [`AsyncWrite`].
                            ///
                            /// Bytes written through [`AsyncWrite`] are sent as binary frames, and the
                            /// payloads of received binary/continuation frames are exposed as a
                            /// contiguous byte stream on the read side.
      27              :     pub(crate) struct WebSocketRw<S> {
                                // The framed websocket that frames are read from and written to.
      28              :         #[pin]
      29              :         stream: WebSocketServer<S>,
                                // Payload of the last received data frame that has not been consumed yet.
      30              :         recv: Bytes,
                                // Staging buffer: `poll_write` copies the caller's bytes here and then
                                // splits them off as the payload of the outgoing frame.
      31              :         send: BytesMut,
      32              :     }
      33              : }
      34              : 
      35              : impl<S> WebSocketRw<S> {
                            /// Wraps `stream`, starting with empty receive and send buffers.
      36            1 :     pub(crate) fn new(stream: WebSocketServer<S>) -> Self {
      37            1 :         Self {
      38            1 :             stream,
      39            1 :             recv: Bytes::new(),
      40            1 :             send: BytesMut::new(),
      41            1 :         }
      42            1 :     }
      43              : }
      44              : 
      45              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
                            /// Sends the whole of `buf` as a single binary frame.
                            ///
                            /// Returns `Poll::Pending` until the underlying sink reports readiness.
                            /// On success the full `buf.len()` is reported as written. Errors from
                            /// the sink are wrapped with [`io::Error::other`].
                            ///
                            /// The frame is only handed to the sink via `start_send` here; per the
                            /// `Sink` contract a subsequent `poll_flush` is needed to make sure it
                            /// has actually been written out.
      46            1 :     fn poll_write(
      47            1 :         self: Pin<&mut Self>,
      48            1 :         cx: &mut Context<'_>,
      49            1 :         buf: &[u8],
      50            1 :     ) -> Poll<io::Result<usize>> {
      51            1 :         let this = self.project();
      52            1 :         let mut stream = this.stream;
      53              : 
                                // Check readiness first so that nothing is copied into `send` unless
                                // the sink can accept a frame right now.
      54            1 :         ready!(stream.as_mut().poll_ready(cx).map_err(io::Error::other))?;
      55              : 
                                // Copy the caller's bytes, then `split()` moves them into the frame and
                                // leaves `send` empty for the next call.
      56            1 :         this.send.put(buf);
      57            1 :         match stream.as_mut().start_send(Frame::binary(this.send.split())) {
      58            1 :             Ok(()) => Poll::Ready(Ok(buf.len())),
      59            0 :             Err(e) => Poll::Ready(Err(io::Error::other(e))),
      60              :         }
      61            1 :     }
      62              : 
                            /// Flushes the underlying sink, wrapping any error with [`io::Error::other`].
      63            1 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      64            1 :         let stream = self.project().stream;
      65            1 :         stream.poll_flush(cx).map_err(io::Error::other)
      66            1 :     }
      67              : 
                            /// Shuts down the write side by closing the underlying sink (`poll_close`),
                            /// wrapping any error with [`io::Error::other`].
      68            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      69            0 :         let stream = self.project().stream;
      70            0 :         stream.poll_close(cx).map_err(io::Error::other)
      71            0 :     }
      72              : }
      73              : 
      74              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
                            /// Copies buffered payload bytes into `buf`.
                            ///
                            /// Implemented on top of the [`AsyncBufRead`] methods: it obtains the
                            /// current chunk from `poll_fill_buf`, copies as much as fits into `buf`,
                            /// and consumes exactly that many bytes. When `poll_fill_buf` yields an
                            /// empty slice nothing is written to `buf`.
      75            3 :     fn poll_read(
      76            3 :         mut self: Pin<&mut Self>,
      77            3 :         cx: &mut Context<'_>,
      78            3 :         buf: &mut ReadBuf<'_>,
      79            3 :     ) -> Poll<io::Result<()>> {
      80            3 :         let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
                                // Never copy more than the caller's buffer can hold; the rest stays
                                // in `recv` for the next read.
      81            2 :         let len = std::cmp::min(bytes.len(), buf.remaining());
      82            2 :         buf.put_slice(&bytes[..len]);
      83            2 :         self.consume(len);
      84            2 :         Poll::Ready(Ok(()))
      85            3 :     }
      86              : }
      87              : 
      88              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
                            /// Returns the unread part of the current frame payload, pulling frames
                            /// from the websocket until there is data to return.
                            ///
                            /// Frame handling:
                            /// - `Ping` / `Pong`: skipped here, the loop moves on to the next frame.
                            /// - `Text`: logged with `warn!` and reported as an I/O error.
                            /// - `Binary` / `Continuation`: the payload becomes the new read buffer.
                            /// - `Close`, or the stream ending: reported as EOF (an empty slice).
                            ///
                            /// Errors produced by the websocket stream are wrapped with
                            /// [`io::Error::other`].
      89            3 :     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
      90              :         // Please refer to poll_fill_buf's documentation.
                                // (An empty slice is how `AsyncBufRead` signals end of stream.)
      91              :         const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
      92              : 
      93            3 :         let mut this = self.project();
      94              :         loop {
                                    // Serve leftover bytes from the previous frame before polling for more.
      95            4 :             if !this.recv.chunk().is_empty() {
      96            1 :                 let chunk = (*this.recv).chunk();
      97            1 :                 return Poll::Ready(Ok(chunk));
      98            3 :             }
      99              : 
     100            3 :             let res = ready!(this.stream.as_mut().poll_next(cx));
     101            2 :             match res.transpose().map_err(io::Error::other)? {
     102            2 :                 Some(message) => match message.opcode {
     103            0 :                     OpCode::Ping => {}
     104            0 :                     OpCode::Pong => {}
     105              :                     OpCode::Text => {
     106              :                         // We expect to see only binary messages.
     107            0 :                         let error = "unexpected text message in the websocket";
     108            0 :                         warn!(length = message.payload.len(), error);
     109            0 :                         return Poll::Ready(Err(io::Error::other(error)));
     110              :                     }
     111              :                     OpCode::Binary | OpCode::Continuation => {
                                                // The check at the top of the loop guarantees `recv` has no
                                                // unread bytes here. If the new payload is itself empty, the
                                                // next iteration simply polls for another frame.
     112            1 :                         debug_assert!(this.recv.is_empty());
     113            1 :                         *this.recv = message.payload.freeze();
     114              :                     }
     115            1 :                     OpCode::Close => return EOF,
     116              :                 },
     117            0 :                 None => return EOF,
     118              :             }
     119              :         }
     120            3 :     }
     121              : 
                            /// Marks `amount` bytes of the buffered payload as read by advancing `recv`.
                            ///
                            /// NOTE(review): `amount` is expected not to exceed the length of the slice
                            /// last returned by `poll_fill_buf`, as the `AsyncBufRead` contract requires.
     122            2 :     fn consume(self: Pin<&mut Self>, amount: usize) {
     123            2 :         self.project().recv.advance(amount);
     124            2 :     }
     125              : }
     126              : 
     127              : #[allow(clippy::too_many_arguments)]
                        /// Serves one client connection that arrives over an HTTP websocket upgrade.
                        ///
                        /// Awaits the upgrade, wraps the upgraded connection in a [`WebSocketRw`] and
                        /// passes it to `handle_connection` with `ClientMode::Websockets { hostname }`.
                        ///
                        /// Outcomes:
                        /// - the upgrade fails: the error is returned via `?`;
                        /// - `handle_connection` fails: the error kind is recorded on `ctx` and the
                        ///   error is returned;
                        /// - `handle_connection` returns `Ok(None)`: `ctx` is marked successful and
                        ///   `Ok(())` is returned;
                        /// - `handle_connection` returns `Ok(Some(p))`: `ctx` is marked successful,
                        ///   the connect is logged, and `p.proxy_pass()` is awaited; its errors are
                        ///   returned with a "client" or "compute" context depending on their source.
     128            0 : pub(crate) async fn serve_websocket(
     129            0 :     config: &'static ProxyConfig,
     130            0 :     auth_backend: &'static crate::auth::Backend<'static, ()>,
     131            0 :     ctx: RequestContext,
     132            0 :     websocket: OnUpgrade,
     133            0 :     cancellation_handler: Arc<CancellationHandler>,
     134            0 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     135            0 :     hostname: Option<String>,
     136            0 :     cancellations: tokio_util::task::task_tracker::TaskTracker,
     137            0 : ) -> anyhow::Result<()> {
     138            0 :     let websocket = websocket.await?;
     139            0 :     let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
     140              : 
                            // Guard on the client-connections metric for the websocket protocol; it is
                            // handed over to `handle_connection` below.
     141            0 :     let conn_gauge = Metrics::get()
     142            0 :         .proxy
     143            0 :         .client_connections
     144            0 :         .guard(crate::metrics::Protocol::Ws);
     145              : 
                            // NOTE(review): the future is boxed, presumably to keep this function's own
                            // future small; confirm before removing the `Box::pin`.
     146            0 :     let res = Box::pin(handle_connection(
     147            0 :         config,
     148            0 :         auth_backend,
     149            0 :         &ctx,
     150            0 :         cancellation_handler,
     151            0 :         WebSocketRw::new(websocket),
     152            0 :         ClientMode::Websockets { hostname },
     153            0 :         endpoint_rate_limiter,
     154            0 :         conn_gauge,
     155            0 :         cancellations,
     156            0 :     ))
     157            0 :     .await;
     158              : 
     159            0 :     match res {
     160            0 :         Err(e) => {
     161            0 :             ctx.set_error_kind(e.get_error_kind());
     162            0 :             Err(e.into())
     163              :         }
                                // Nothing to proxy: the connection was handled completely already.
     164              :         Ok(None) => {
     165            0 :             ctx.set_success();
     166            0 :             Ok(())
     167              :         }
     168            0 :         Ok(Some(p)) => {
     169            0 :             ctx.set_success();
     170            0 :             ctx.log_connect();
                                    // Label the failure with the side it came from.
     171            0 :             match p.proxy_pass().await {
     172            0 :                 Ok(()) => Ok(()),
     173            0 :                 Err(ErrorSource::Client(err)) => Err(err).context("client"),
     174            0 :                 Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
     175              :             }
     176              :         }
     177              :     }
     178            0 : }
     179              : 
     180              : #[cfg(test)]
     181              : mod tests {
     182              :     use std::pin::pin;
     183              : 
     184              :     use framed_websockets::WebSocketServer;
     185              :     use futures::{SinkExt, StreamExt};
     186              :     use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
     187              :     use tokio::task::JoinSet;
     188              :     use tokio_tungstenite::WebSocketStream;
     189              :     use tokio_tungstenite::tungstenite::Message;
     190              :     use tokio_tungstenite::tungstenite::protocol::Role;
     191              : 
     192              :     use super::WebSocketRw;
     193              : 
                            /// Round trip over an in-memory duplex pipe: a tungstenite client on one end
                            /// and a `WebSocketRw` on the other exchange one binary message each way,
                            /// then the client closes and the `WebSocketRw` side observes EOF.
     194              :     #[tokio::test]
     195            1 :     async fn websocket_stream_wrapper_happy_path() {
     196            1 :         let (stream1, stream2) = duplex(1024);
     197              : 
     198            1 :         let mut js = JoinSet::new();
     199              : 
                                // Client side: send a message, expect the reply, then close.
     200            1 :         js.spawn(async move {
     201            1 :             let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
     202              : 
     203            1 :             client
     204            1 :                 .send(Message::Binary(b"hello world".to_vec()))
     205            1 :                 .await
     206            1 :                 .unwrap();
     207              : 
     208            1 :             let message = client.next().await.unwrap().unwrap();
     209            1 :             assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
     210              : 
     211            1 :             client.close(None).await.unwrap();
     212            1 :         });
     213              : 
                                // Server side: plain byte reads and writes through `WebSocketRw`.
     214            1 :         js.spawn(async move {
     215            1 :             let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
     216              : 
     217            1 :             let mut buf = vec![0; 1024];
     218            1 :             let n = rw.read(&mut buf).await.unwrap();
     219            1 :             assert_eq!(&buf[..n], b"hello world");
     220              : 
                                    // The explicit flush makes sure the queued frame reaches the client.
     221            1 :             rw.write_all(b"websockets are cool").await.unwrap();
     222            1 :             rw.flush().await.unwrap();
     223              : 
                                    // After the client's close no further bytes are read before EOF.
     224            1 :             let n = rw.read_to_end(&mut buf).await.unwrap();
     225            1 :             assert_eq!(n, 0);
     226            1 :         });
     227              : 
                                // Both tasks must finish without panicking.
     228            1 :         js.join_next().await.unwrap().unwrap();
     229            1 :         js.join_next().await.unwrap().unwrap();
     230            1 :     }
     231              : }
         |