Line data Source code
1 : use std::pin::Pin;
2 : use std::sync::Arc;
3 : use std::task::{Context, Poll, ready};
4 :
5 : use anyhow::Context as _;
6 : use bytes::{Buf, BufMut, Bytes, BytesMut};
7 : use framed_websockets::{Frame, OpCode, WebSocketServer};
8 : use futures::{Sink, Stream};
9 : use hyper::upgrade::OnUpgrade;
10 : use hyper_util::rt::TokioIo;
11 : use pin_project_lite::pin_project;
12 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
13 : use tracing::warn;
14 :
15 : use crate::cancellation::CancellationHandler;
16 : use crate::config::ProxyConfig;
17 : use crate::context::RequestContext;
18 : use crate::error::ReportableError;
19 : use crate::metrics::Metrics;
20 : use crate::pglb::{ClientMode, handle_connection};
21 : use crate::proxy::ErrorSource;
22 : use crate::rate_limiter::EndpointRateLimiter;
23 :
24 : pin_project! {
25 : /// This is a wrapper around a [`WebSocketStream`] that
26 : /// implements [`AsyncRead`] and [`AsyncWrite`].
27 : pub(crate) struct WebSocketRw<S> {
28 : #[pin]
29 : stream: WebSocketServer<S>,
30 : recv: Bytes,
31 : send: BytesMut,
32 : }
33 : }
34 :
35 : impl<S> WebSocketRw<S> {
36 1 : pub(crate) fn new(stream: WebSocketServer<S>) -> Self {
37 1 : Self {
38 1 : stream,
39 1 : recv: Bytes::new(),
40 1 : send: BytesMut::new(),
41 1 : }
42 1 : }
43 : }
44 :
45 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
46 1 : fn poll_write(
47 1 : self: Pin<&mut Self>,
48 1 : cx: &mut Context<'_>,
49 1 : buf: &[u8],
50 1 : ) -> Poll<io::Result<usize>> {
51 1 : let this = self.project();
52 1 : let mut stream = this.stream;
53 :
54 1 : ready!(stream.as_mut().poll_ready(cx).map_err(io::Error::other))?;
55 :
56 1 : this.send.put(buf);
57 1 : match stream.as_mut().start_send(Frame::binary(this.send.split())) {
58 1 : Ok(()) => Poll::Ready(Ok(buf.len())),
59 0 : Err(e) => Poll::Ready(Err(io::Error::other(e))),
60 : }
61 1 : }
62 :
63 1 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64 1 : let stream = self.project().stream;
65 1 : stream.poll_flush(cx).map_err(io::Error::other)
66 1 : }
67 :
68 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69 0 : let stream = self.project().stream;
70 0 : stream.poll_close(cx).map_err(io::Error::other)
71 0 : }
72 : }
73 :
74 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
75 3 : fn poll_read(
76 3 : mut self: Pin<&mut Self>,
77 3 : cx: &mut Context<'_>,
78 3 : buf: &mut ReadBuf<'_>,
79 3 : ) -> Poll<io::Result<()>> {
80 3 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
81 2 : let len = std::cmp::min(bytes.len(), buf.remaining());
82 2 : buf.put_slice(&bytes[..len]);
83 2 : self.consume(len);
84 2 : Poll::Ready(Ok(()))
85 3 : }
86 : }
87 :
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
    /// Returns the unconsumed part of the current data frame, receiving further
    /// frames from the websocket while nothing is buffered.
    ///
    /// Frame handling:
    /// - `Ping` / `Pong`: skipped here.
    /// - `Text`: logged and returned as an error; only binary data is expected.
    /// - `Binary` / `Continuation`: payload becomes the new read buffer.
    /// - `Close`, or the stream ending: reported as EOF (an empty slice).
    ///
    /// Errors from the underlying stream are converted into `io::Error`.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        // Please refer to poll_fill_buf's documentation.
        // (An empty slice returned from `poll_fill_buf` signals EOF.)
        const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));

        let mut this = self.project();
        loop {
            // Serve leftover bytes before polling for another frame.
            if !this.recv.chunk().is_empty() {
                // NOTE(review): the slice is presumably re-borrowed inside the
                // branch so the returned borrow exists only on this path and
                // does not overlap the mutable uses of `this` below.
                let chunk = (*this.recv).chunk();
                return Poll::Ready(Ok(chunk));
            }

            // `transpose` implies the stream yields `Result` items; a stream
            // error is converted and propagated by `?`.
            let res = ready!(this.stream.as_mut().poll_next(cx));
            match res.transpose().map_err(io::Error::other)? {
                Some(message) => match message.opcode {
                    // NOTE(review): control frames are ignored here; whether
                    // pings are answered is up to `WebSocketServer`.
                    OpCode::Ping => {}
                    OpCode::Pong => {}
                    OpCode::Text => {
                        // We expect to see only binary messages.
                        let error = "unexpected text message in the websocket";
                        warn!(length = message.payload.len(), error);
                        return Poll::Ready(Err(io::Error::other(error)));
                    }
                    OpCode::Binary | OpCode::Continuation => {
                        // Only reached once the previous payload has been fully
                        // consumed (see the check at the top of the loop).
                        debug_assert!(this.recv.is_empty());
                        // An empty payload simply leaves `recv` empty, so the
                        // loop polls again instead of signalling EOF.
                        *this.recv = message.payload.freeze();
                    }
                    OpCode::Close => return EOF,
                },
                None => return EOF,
            }
        }
    }

    /// Marks `amount` bytes of the buffered payload as read.
    // NOTE(review): `Bytes::advance` panics if `amount` exceeds the buffered
    // length, so callers must not consume more than `poll_fill_buf` returned.
    fn consume(self: Pin<&mut Self>, amount: usize) {
        self.project().recv.advance(amount);
    }
}
126 :
127 : #[allow(clippy::too_many_arguments)]
128 0 : pub(crate) async fn serve_websocket(
129 0 : config: &'static ProxyConfig,
130 0 : auth_backend: &'static crate::auth::Backend<'static, ()>,
131 0 : ctx: RequestContext,
132 0 : websocket: OnUpgrade,
133 0 : cancellation_handler: Arc<CancellationHandler>,
134 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
135 0 : hostname: Option<String>,
136 0 : cancellations: tokio_util::task::task_tracker::TaskTracker,
137 0 : ) -> anyhow::Result<()> {
138 0 : let websocket = websocket.await?;
139 0 : let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
140 :
141 0 : let conn_gauge = Metrics::get()
142 0 : .proxy
143 0 : .client_connections
144 0 : .guard(crate::metrics::Protocol::Ws);
145 :
146 0 : let res = Box::pin(handle_connection(
147 0 : config,
148 0 : auth_backend,
149 0 : &ctx,
150 0 : cancellation_handler,
151 0 : WebSocketRw::new(websocket),
152 0 : ClientMode::Websockets { hostname },
153 0 : endpoint_rate_limiter,
154 0 : conn_gauge,
155 0 : cancellations,
156 0 : ))
157 0 : .await;
158 :
159 0 : match res {
160 0 : Err(e) => {
161 0 : ctx.set_error_kind(e.get_error_kind());
162 0 : Err(e.into())
163 : }
164 : Ok(None) => {
165 0 : ctx.set_success();
166 0 : Ok(())
167 : }
168 0 : Ok(Some(p)) => {
169 0 : ctx.set_success();
170 0 : ctx.log_connect();
171 0 : match p.proxy_pass().await {
172 0 : Ok(()) => Ok(()),
173 0 : Err(ErrorSource::Client(err)) => Err(err).context("client"),
174 0 : Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
175 : }
176 : }
177 : }
178 0 : }
179 :
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use framed_websockets::WebSocketServer;
    use futures::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
    use tokio::task::JoinSet;
    use tokio_tungstenite::WebSocketStream;
    use tokio_tungstenite::tungstenite::Message;
    use tokio_tungstenite::tungstenite::protocol::Role;

    use super::WebSocketRw;

    /// Exchanges one binary message in each direction between a tungstenite
    /// client and a `WebSocketRw` server over an in-memory duplex pipe.
    /// It then checks that the client's close is observed by the server as EOF.
    #[tokio::test]
    async fn websocket_stream_wrapper_happy_path() {
        let (client_io, server_io) = duplex(1024);
        let mut tasks = JoinSet::new();

        // Client side: a regular websocket speaking binary messages.
        tasks.spawn(async move {
            let mut ws = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;

            let greeting = Message::Binary(b"hello world".to_vec());
            ws.send(greeting).await.unwrap();

            let reply = ws.next().await.unwrap().unwrap();
            assert_eq!(reply, Message::Binary(b"websockets are cool".to_vec()));

            ws.close(None).await.unwrap();
        });

        // Server side: the same connection viewed as a byte stream.
        tasks.spawn(async move {
            let server = WebSocketServer::after_handshake(server_io);
            let mut rw = pin!(WebSocketRw::new(server));
            let mut scratch = vec![0; 1024];

            let got = rw.read(&mut scratch).await.unwrap();
            assert_eq!(&scratch[..got], b"hello world");

            rw.write_all(b"websockets are cool").await.unwrap();
            rw.flush().await.unwrap();

            let trailing = rw.read_to_end(&mut scratch).await.unwrap();
            assert_eq!(trailing, 0);
        });

        for _ in 0..2 {
            tasks.join_next().await.unwrap().unwrap();
        }
    }
}
|