Line data Source code
1 : use crate::proxy::ErrorSource;
2 : use crate::{
3 : cancellation::CancellationHandlerMain,
4 : config::ProxyConfig,
5 : context::RequestMonitoring,
6 : error::{io_error, ReportableError},
7 : metrics::Metrics,
8 : proxy::{handle_client, ClientMode},
9 : rate_limiter::EndpointRateLimiter,
10 : };
11 : use anyhow::Context as _;
12 : use bytes::{Buf, BufMut, Bytes, BytesMut};
13 : use framed_websockets::{Frame, OpCode, WebSocketServer};
14 : use futures::{Sink, Stream};
15 : use hyper1::upgrade::OnUpgrade;
16 : use hyper_util::rt::TokioIo;
17 : use pin_project_lite::pin_project;
18 :
19 : use std::{
20 : pin::Pin,
21 : sync::Arc,
22 : task::{ready, Context, Poll},
23 : };
24 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
25 : use tracing::warn;
26 :
27 : pin_project! {
28 : /// This is a wrapper around a [`WebSocketStream`] that
29 : /// implements [`AsyncRead`] and [`AsyncWrite`].
30 : pub(crate) struct WebSocketRw<S> {
31 : #[pin]
32 : stream: WebSocketServer<S>,
33 : recv: Bytes,
34 : send: BytesMut,
35 : }
36 : }
37 :
38 : impl<S> WebSocketRw<S> {
39 1 : pub(crate) fn new(stream: WebSocketServer<S>) -> Self {
40 1 : Self {
41 1 : stream,
42 1 : recv: Bytes::new(),
43 1 : send: BytesMut::new(),
44 1 : }
45 1 : }
46 : }
47 :
48 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
49 1 : fn poll_write(
50 1 : self: Pin<&mut Self>,
51 1 : cx: &mut Context<'_>,
52 1 : buf: &[u8],
53 1 : ) -> Poll<io::Result<usize>> {
54 1 : let this = self.project();
55 1 : let mut stream = this.stream;
56 :
57 1 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
58 :
59 1 : this.send.put(buf);
60 1 : match stream.as_mut().start_send(Frame::binary(this.send.split())) {
61 1 : Ok(()) => Poll::Ready(Ok(buf.len())),
62 0 : Err(e) => Poll::Ready(Err(io_error(e))),
63 : }
64 1 : }
65 :
66 1 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67 1 : let stream = self.project().stream;
68 1 : stream.poll_flush(cx).map_err(io_error)
69 1 : }
70 :
71 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 0 : let stream = self.project().stream;
73 0 : stream.poll_close(cx).map_err(io_error)
74 0 : }
75 : }
76 :
77 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
78 3 : fn poll_read(
79 3 : mut self: Pin<&mut Self>,
80 3 : cx: &mut Context<'_>,
81 3 : buf: &mut ReadBuf<'_>,
82 3 : ) -> Poll<io::Result<()>> {
83 3 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
84 2 : let len = std::cmp::min(bytes.len(), buf.remaining());
85 2 : buf.put_slice(&bytes[..len]);
86 2 : self.consume(len);
87 2 : Poll::Ready(Ok(()))
88 3 : }
89 : }
90 :
91 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
92 3 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
93 : // Please refer to poll_fill_buf's documentation.
94 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
95 :
96 3 : let mut this = self.project();
97 : loop {
98 4 : if !this.recv.chunk().is_empty() {
99 1 : let chunk = (*this.recv).chunk();
100 1 : return Poll::Ready(Ok(chunk));
101 3 : }
102 :
103 3 : let res = ready!(this.stream.as_mut().poll_next(cx));
104 2 : match res.transpose().map_err(io_error)? {
105 2 : Some(message) => match message.opcode {
106 0 : OpCode::Ping => {}
107 0 : OpCode::Pong => {}
108 : OpCode::Text => {
109 : // We expect to see only binary messages.
110 0 : let error = "unexpected text message in the websocket";
111 0 : warn!(length = message.payload.len(), error);
112 0 : return Poll::Ready(Err(io_error(error)));
113 : }
114 : OpCode::Binary | OpCode::Continuation => {
115 1 : debug_assert!(this.recv.is_empty());
116 1 : *this.recv = message.payload.freeze();
117 : }
118 1 : OpCode::Close => return EOF,
119 : },
120 0 : None => return EOF,
121 : }
122 : }
123 3 : }
124 :
125 2 : fn consume(self: Pin<&mut Self>, amount: usize) {
126 2 : self.project().recv.advance(amount);
127 2 : }
128 : }
129 :
130 0 : pub(crate) async fn serve_websocket(
131 0 : config: &'static ProxyConfig,
132 0 : ctx: RequestMonitoring,
133 0 : websocket: OnUpgrade,
134 0 : cancellation_handler: Arc<CancellationHandlerMain>,
135 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
136 0 : hostname: Option<String>,
137 0 : ) -> anyhow::Result<()> {
138 0 : let websocket = websocket.await?;
139 0 : let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
140 0 :
141 0 : let conn_gauge = Metrics::get()
142 0 : .proxy
143 0 : .client_connections
144 0 : .guard(crate::metrics::Protocol::Ws);
145 :
146 0 : let res = Box::pin(handle_client(
147 0 : config,
148 0 : &ctx,
149 0 : cancellation_handler,
150 0 : WebSocketRw::new(websocket),
151 0 : ClientMode::Websockets { hostname },
152 0 : endpoint_rate_limiter,
153 0 : conn_gauge,
154 0 : ))
155 0 : .await;
156 :
157 0 : match res {
158 0 : Err(e) => {
159 0 : // todo: log and push to ctx the error kind
160 0 : ctx.set_error_kind(e.get_error_kind());
161 0 : Err(e.into())
162 : }
163 : Ok(None) => {
164 0 : ctx.set_success();
165 0 : Ok(())
166 : }
167 0 : Ok(Some(p)) => {
168 0 : ctx.set_success();
169 0 : ctx.log_connect();
170 0 : match p.proxy_pass().await {
171 0 : Ok(()) => Ok(()),
172 0 : Err(ErrorSource::Client(err)) => Err(err).context("client"),
173 0 : Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
174 : }
175 : }
176 : }
177 0 : }
178 :
179 : #[cfg(test)]
180 : mod tests {
181 : use std::pin::pin;
182 :
183 : use framed_websockets::WebSocketServer;
184 : use futures::{SinkExt, StreamExt};
185 : use tokio::{
186 : io::{duplex, AsyncReadExt, AsyncWriteExt},
187 : task::JoinSet,
188 : };
189 : use tokio_tungstenite::{
190 : tungstenite::{protocol::Role, Message},
191 : WebSocketStream,
192 : };
193 :
194 : use super::WebSocketRw;
195 :
196 : #[tokio::test]
197 1 : async fn websocket_stream_wrapper_happy_path() {
198 1 : let (stream1, stream2) = duplex(1024);
199 1 :
200 1 : let mut js = JoinSet::new();
201 1 :
202 1 : js.spawn(async move {
203 1 : let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
204 1 :
205 1 : client
206 1 : .send(Message::Binary(b"hello world".to_vec()))
207 1 : .await
208 1 : .unwrap();
209 1 :
210 1 : let message = client.next().await.unwrap().unwrap();
211 1 : assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
212 1 :
213 1 : client.close(None).await.unwrap();
214 1 : });
215 1 :
216 1 : js.spawn(async move {
217 1 : let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
218 1 :
219 1 : let mut buf = vec![0; 1024];
220 1 : let n = rw.read(&mut buf).await.unwrap();
221 1 : assert_eq!(&buf[..n], b"hello world");
222 1 :
223 1 : rw.write_all(b"websockets are cool").await.unwrap();
224 1 : rw.flush().await.unwrap();
225 1 :
226 1 : let n = rw.read_to_end(&mut buf).await.unwrap();
227 1 : assert_eq!(n, 0);
228 1 : });
229 1 :
230 1 : js.join_next().await.unwrap().unwrap();
231 1 : js.join_next().await.unwrap().unwrap();
232 1 : }
233 : }
|