Line data Source code
1 : use crate::{
2 : cancellation::CancellationHandlerMain,
3 : config::ProxyConfig,
4 : context::RequestMonitoring,
5 : error::{io_error, ReportableError},
6 : metrics::Metrics,
7 : proxy::{handle_client, ClientMode},
8 : rate_limiter::EndpointRateLimiter,
9 : };
10 : use bytes::{Buf, BufMut, Bytes, BytesMut};
11 : use framed_websockets::{Frame, OpCode, WebSocketServer};
12 : use futures::{Sink, Stream};
13 : use hyper1::upgrade::OnUpgrade;
14 : use hyper_util::rt::TokioIo;
15 : use pin_project_lite::pin_project;
16 :
17 : use std::{
18 : pin::Pin,
19 : sync::Arc,
20 : task::{ready, Context, Poll},
21 : };
22 : use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
23 : use tracing::warn;
24 :
25 : pin_project! {
26 : /// This is a wrapper around a [`WebSocketStream`] that
27 : /// implements [`AsyncRead`] and [`AsyncWrite`].
28 : pub struct WebSocketRw<S> {
29 : #[pin]
30 : stream: WebSocketServer<S>,
31 : recv: Bytes,
32 : send: BytesMut,
33 : }
34 : }
35 :
36 : impl<S> WebSocketRw<S> {
37 2 : pub fn new(stream: WebSocketServer<S>) -> Self {
38 2 : Self {
39 2 : stream,
40 2 : recv: Bytes::new(),
41 2 : send: BytesMut::new(),
42 2 : }
43 2 : }
44 : }
45 :
46 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
47 2 : fn poll_write(
48 2 : self: Pin<&mut Self>,
49 2 : cx: &mut Context<'_>,
50 2 : buf: &[u8],
51 2 : ) -> Poll<io::Result<usize>> {
52 2 : let this = self.project();
53 2 : let mut stream = this.stream;
54 :
55 2 : ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
56 :
57 2 : this.send.put(buf);
58 2 : match stream.as_mut().start_send(Frame::binary(this.send.split())) {
59 2 : Ok(()) => Poll::Ready(Ok(buf.len())),
60 0 : Err(e) => Poll::Ready(Err(io_error(e))),
61 : }
62 2 : }
63 :
64 2 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
65 2 : let stream = self.project().stream;
66 2 : stream.poll_flush(cx).map_err(io_error)
67 2 : }
68 :
69 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
70 0 : let stream = self.project().stream;
71 0 : stream.poll_close(cx).map_err(io_error)
72 0 : }
73 : }
74 :
75 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
76 6 : fn poll_read(
77 6 : mut self: Pin<&mut Self>,
78 6 : cx: &mut Context<'_>,
79 6 : buf: &mut ReadBuf<'_>,
80 6 : ) -> Poll<io::Result<()>> {
81 6 : let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
82 4 : let len = std::cmp::min(bytes.len(), buf.remaining());
83 4 : buf.put_slice(&bytes[..len]);
84 4 : self.consume(len);
85 4 : Poll::Ready(Ok(()))
86 6 : }
87 : }
88 :
89 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
90 6 : fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
91 6 : // Please refer to poll_fill_buf's documentation.
92 6 : const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
93 6 :
94 6 : let mut this = self.project();
95 8 : loop {
96 8 : if !this.recv.chunk().is_empty() {
97 2 : let chunk = (*this.recv).chunk();
98 2 : return Poll::Ready(Ok(chunk));
99 6 : }
100 :
101 6 : let res = ready!(this.stream.as_mut().poll_next(cx));
102 4 : match res.transpose().map_err(io_error)? {
103 4 : Some(message) => match message.opcode {
104 0 : OpCode::Ping => {}
105 0 : OpCode::Pong => {}
106 : OpCode::Text => {
107 : // We expect to see only binary messages.
108 0 : let error = "unexpected text message in the websocket";
109 0 : warn!(length = message.payload.len(), error);
110 0 : return Poll::Ready(Err(io_error(error)));
111 : }
112 : OpCode::Binary | OpCode::Continuation => {
113 2 : debug_assert!(this.recv.is_empty());
114 2 : *this.recv = message.payload.freeze();
115 : }
116 2 : OpCode::Close => return EOF,
117 : },
118 0 : None => return EOF,
119 : }
120 : }
121 6 : }
122 :
123 4 : fn consume(self: Pin<&mut Self>, amount: usize) {
124 4 : self.project().recv.advance(amount);
125 4 : }
126 : }
127 :
128 0 : pub async fn serve_websocket(
129 0 : config: &'static ProxyConfig,
130 0 : mut ctx: RequestMonitoring,
131 0 : websocket: OnUpgrade,
132 0 : cancellation_handler: Arc<CancellationHandlerMain>,
133 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
134 0 : hostname: Option<String>,
135 0 : ) -> anyhow::Result<()> {
136 0 : let websocket = websocket.await?;
137 0 : let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
138 0 :
139 0 : let conn_gauge = Metrics::get()
140 0 : .proxy
141 0 : .client_connections
142 0 : .guard(crate::metrics::Protocol::Ws);
143 :
144 0 : let res = Box::pin(handle_client(
145 0 : config,
146 0 : &mut ctx,
147 0 : cancellation_handler,
148 0 : WebSocketRw::new(websocket),
149 0 : ClientMode::Websockets { hostname },
150 0 : endpoint_rate_limiter,
151 0 : conn_gauge,
152 0 : ))
153 0 : .await;
154 :
155 0 : match res {
156 0 : Err(e) => {
157 0 : // todo: log and push to ctx the error kind
158 0 : ctx.set_error_kind(e.get_error_kind());
159 0 : Err(e.into())
160 : }
161 : Ok(None) => {
162 0 : ctx.set_success();
163 0 : Ok(())
164 : }
165 0 : Ok(Some(p)) => {
166 0 : ctx.set_success();
167 0 : ctx.log_connect();
168 0 : p.proxy_pass().await
169 : }
170 : }
171 0 : }
172 :
173 : #[cfg(test)]
174 : mod tests {
175 : use std::pin::pin;
176 :
177 : use framed_websockets::WebSocketServer;
178 : use futures::{SinkExt, StreamExt};
179 : use tokio::{
180 : io::{duplex, AsyncReadExt, AsyncWriteExt},
181 : task::JoinSet,
182 : };
183 : use tokio_tungstenite::{
184 : tungstenite::{protocol::Role, Message},
185 : WebSocketStream,
186 : };
187 :
188 : use super::WebSocketRw;
189 :
190 : #[tokio::test]
191 2 : async fn websocket_stream_wrapper_happy_path() {
192 2 : let (stream1, stream2) = duplex(1024);
193 2 :
194 2 : let mut js = JoinSet::new();
195 2 :
196 2 : js.spawn(async move {
197 2 : let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
198 2 :
199 2 : client
200 2 : .send(Message::Binary(b"hello world".to_vec()))
201 2 : .await
202 2 : .unwrap();
203 2 :
204 2 : let message = client.next().await.unwrap().unwrap();
205 2 : assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
206 2 :
207 2 : client.close(None).await.unwrap();
208 2 : });
209 2 :
210 2 : js.spawn(async move {
211 2 : let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
212 2 :
213 2 : let mut buf = vec![0; 1024];
214 2 : let n = rw.read(&mut buf).await.unwrap();
215 2 : assert_eq!(&buf[..n], b"hello world");
216 2 :
217 2 : rw.write_all(b"websockets are cool").await.unwrap();
218 2 : rw.flush().await.unwrap();
219 2 :
220 2 : let n = rw.read_to_end(&mut buf).await.unwrap();
221 2 : assert_eq!(n, 0);
222 2 : });
223 2 :
224 2 : js.join_next().await.unwrap().unwrap();
225 2 : js.join_next().await.unwrap().unwrap();
226 2 : }
227 : }
|