Line data Source code
1 : use std::pin::Pin;
2 : use std::sync::Arc;
3 : use std::{io, task};
4 :
5 : use rustls::ServerConfig;
6 : use thiserror::Error;
7 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
8 : use tokio_rustls::server::TlsStream;
9 :
10 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
11 : use crate::metrics::Metrics;
12 : use crate::pqproto::{
13 : BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
14 : read_message, read_startup,
15 : };
16 : use crate::tls::TlsServerEndPoint;
17 :
18 : /// Stream wrapper which implements libpq's protocol.
19 : ///
20 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
21 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
22 : /// to pass random malformed bytes through the connection).
23 : pub struct PqStream<S> {
24 : stream: S,
25 : read: Vec<u8>,
26 : write: WriteBuf,
27 : }
28 :
29 : impl<S> PqStream<S> {
30 35 : pub fn get_ref(&self) -> &S {
31 35 : &self.stream
32 35 : }
33 :
34 : /// Construct a new libpq protocol wrapper over a stream without the first startup message.
35 : #[cfg(test)]
36 3 : pub fn new_skip_handshake(stream: S) -> Self {
37 3 : Self {
38 3 : stream,
39 3 : read: Vec::new(),
40 3 : write: WriteBuf::new(),
41 3 : }
42 3 : }
43 : }
44 :
45 : impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
46 : /// Construct a new libpq protocol wrapper and read the first startup message.
47 : ///
48 : /// This is not cancel safe.
49 42 : pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
50 42 : let startup = read_startup(&mut stream).await?;
51 42 : Ok((
52 42 : Self {
53 42 : stream,
54 42 : read: Vec::new(),
55 42 : write: WriteBuf::new(),
56 42 : },
57 42 : startup,
58 42 : ))
59 42 : }
60 :
61 : /// Tell the client that encryption is not supported.
62 : ///
63 : /// This is not cancel safe
64 0 : pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
65 : // N for No.
66 0 : self.write.encryption(b'N');
67 0 : self.flush().await?;
68 0 : read_startup(&mut self.stream).await
69 0 : }
70 : }
71 :
72 : impl<S: AsyncRead + Unpin> PqStream<S> {
73 : /// Read a raw postgres packet, which will respect the max length requested.
74 : /// This is not cancel safe.
75 26 : async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
76 26 : let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
77 25 : if actual_tag != tag {
78 0 : return Err(io::Error::other(format!(
79 0 : "incorrect message tag, expected {:?}, got {:?}",
80 0 : tag as char, actual_tag as char,
81 0 : )));
82 25 : }
83 25 : Ok(msg)
84 26 : }
85 :
86 : /// Read a postgres password message, which will respect the max length requested.
87 : /// This is not cancel safe.
88 26 : pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
89 : // passwords are usually pretty short
90 : // and SASL SCRAM messages are no longer than 256 bytes in my testing
91 : // (a few hashes and random bytes, encoded into base64).
92 : const MAX_PASSWORD_LENGTH: u32 = 512;
93 26 : self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
94 26 : .await
95 26 : }
96 : }
97 :
98 : #[derive(Debug)]
99 : pub struct ReportedError {
100 : source: anyhow::Error,
101 : error_kind: ErrorKind,
102 : }
103 :
104 : impl ReportedError {
105 1 : pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
106 1 : let error_kind = e.get_error_kind();
107 1 : Self {
108 1 : source: e.into(),
109 1 : error_kind,
110 1 : }
111 1 : }
112 : }
113 :
/// Formats a [`ReportedError`] by delegating to the wrapped `anyhow::Error`.
impl std::fmt::Display for ReportedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE(review): `fmt` is not imported by name here, so this call resolves through the
        // error's Deref target. Consider `std::fmt::Display::fmt(&self.source, f)` to make the
        // intended trait explicit.
        self.source.fmt(f)
    }
}
119 :
120 : impl std::error::Error for ReportedError {
121 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
122 0 : self.source.source()
123 0 : }
124 : }
125 :
126 : impl ReportableError for ReportedError {
127 0 : fn get_error_kind(&self) -> ErrorKind {
128 0 : self.error_kind
129 0 : }
130 : }
131 :
132 : impl<S: AsyncWrite + Unpin> PqStream<S> {
133 : /// Tell the client that we are willing to accept SSL.
134 : /// This is not cancel safe
135 20 : pub async fn accept_tls(mut self) -> io::Result<S> {
136 : // S for SSL.
137 20 : self.write.encryption(b'S');
138 20 : self.flush().await?;
139 20 : Ok(self.stream)
140 20 : }
141 :
142 : /// Assert that we are using direct TLS.
143 0 : pub fn accept_direct_tls(self) -> S {
144 0 : self.stream
145 0 : }
146 :
147 : /// Write a raw message to the internal buffer.
148 0 : pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
149 0 : self.write.write_raw(size_hint, tag, f);
150 0 : }
151 :
152 : /// Write the message into an internal buffer
153 56 : pub fn write_message(&mut self, message: BeMessage<'_>) {
154 56 : message.write_message(&mut self.write);
155 56 : }
156 :
157 : /// Write the buffer to the socket until we have some more space again.
158 0 : pub async fn write_if_full(&mut self) -> io::Result<()> {
159 0 : while self.write.occupied_len() > 2048 {
160 0 : self.stream.write_buf(&mut self.write).await?;
161 : }
162 :
163 0 : Ok(())
164 0 : }
165 :
166 : /// Flush the output buffer into the underlying stream.
167 : ///
168 : /// This is cancel safe.
169 62 : pub async fn flush(&mut self) -> io::Result<()> {
170 62 : self.stream.write_all_buf(&mut self.write).await?;
171 62 : self.write.reset();
172 :
173 62 : self.stream.flush().await?;
174 :
175 62 : Ok(())
176 62 : }
177 :
178 : /// Flush the output buffer into the underlying stream.
179 : ///
180 : /// This is cancel safe.
181 7 : pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
182 7 : self.flush().await?;
183 7 : Ok(self.stream)
184 7 : }
185 :
186 : /// Write the error message to the client, then re-throw it.
187 : ///
188 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
189 : /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
190 1 : pub(crate) async fn throw_error<E>(
191 1 : &mut self,
192 1 : error: E,
193 1 : ctx: Option<&crate::context::RequestContext>,
194 1 : ) -> ReportedError
195 1 : where
196 1 : E: UserFacingError + Into<anyhow::Error>,
197 1 : {
198 1 : let error_kind = error.get_error_kind();
199 1 : let msg = error.to_string_client();
200 :
201 1 : if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
202 0 : tracing::info!(
203 0 : kind = error_kind.to_metric_label(),
204 : msg,
205 0 : "forwarding error to user"
206 : );
207 1 : }
208 :
209 : let probe_msg;
210 1 : let mut msg = &*msg;
211 1 : if let Some(ctx) = ctx
212 0 : && ctx.get_testodrome_id().is_some()
213 : {
214 0 : let tag = match error_kind {
215 0 : ErrorKind::User => "client",
216 0 : ErrorKind::ClientDisconnect => "client",
217 0 : ErrorKind::RateLimit => "proxy",
218 0 : ErrorKind::ServiceRateLimit => "proxy",
219 0 : ErrorKind::Quota => "proxy",
220 0 : ErrorKind::Service => "proxy",
221 0 : ErrorKind::ControlPlane => "controlplane",
222 0 : ErrorKind::Postgres => "other",
223 0 : ErrorKind::Compute => "compute",
224 : };
225 0 : probe_msg = typed_json::json!({
226 0 : "tag": tag,
227 0 : "msg": msg,
228 0 : "cold_start_info": ctx.cold_start_info(),
229 : })
230 0 : .to_string();
231 0 : msg = &probe_msg;
232 1 : }
233 :
234 : // TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
235 1 : self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
236 :
237 1 : self.flush()
238 1 : .await
239 1 : .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
240 :
241 1 : ReportedError::new(error)
242 1 : }
243 : }
244 :
245 : /// Wrapper for upgrading raw streams into secure streams.
246 : pub enum Stream<S> {
247 : /// We always begin with a raw stream,
248 : /// which may then be upgraded into a secure stream.
249 : Raw { raw: S },
250 : Tls {
251 : /// We box [`TlsStream`] since it can be quite large.
252 : tls: Box<TlsStream<S>>,
253 : /// Channel binding parameter
254 : tls_server_end_point: TlsServerEndPoint,
255 : },
256 : }
257 :
258 : impl<S: Unpin> Unpin for Stream<S> {}
259 :
260 : impl<S> Stream<S> {
261 : /// Construct a new instance from a raw stream.
262 25 : pub fn from_raw(raw: S) -> Self {
263 25 : Self::Raw { raw }
264 25 : }
265 :
266 : /// Return SNI hostname when it's available.
267 0 : pub fn sni_hostname(&self) -> Option<&str> {
268 0 : match self {
269 0 : Stream::Raw { .. } => None,
270 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
271 : }
272 0 : }
273 :
274 15 : pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
275 15 : match self {
276 3 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
277 : Stream::Tls {
278 12 : tls_server_end_point,
279 : ..
280 12 : } => *tls_server_end_point,
281 : }
282 15 : }
283 : }
284 :
/// Failure modes of [`Stream::upgrade`].
#[derive(Debug, Error)]
// NOTE(review): every variant carries its own `#[error]` message, so this enum-level message
// looks redundant; confirm how the `thiserror` version in use treats it before removing.
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
    /// `upgrade` was called on a stream that is already TLS.
    #[error("Bad state reached: can't upgrade TLS stream")]
    AlreadyTls,

    /// Wraps an I/O error, e.g. one raised by the TLS handshake in `upgrade`.
    #[error("Can't upgrade stream: IO error: {0}")]
    Io(#[from] io::Error),
}
294 :
295 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
296 : /// If possible, upgrade raw stream into a secure TLS-based stream.
297 0 : pub async fn upgrade(
298 0 : self,
299 0 : cfg: Arc<ServerConfig>,
300 0 : record_handshake_error: bool,
301 0 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
302 0 : match self {
303 0 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
304 0 : .accept(raw)
305 0 : .await
306 0 : .inspect_err(|_| {
307 0 : if record_handshake_error {
308 0 : Metrics::get().proxy.tls_handshake_failures.inc();
309 0 : }
310 0 : })?),
311 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
312 : }
313 0 : }
314 : }
315 :
316 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
317 202 : fn poll_read(
318 202 : mut self: Pin<&mut Self>,
319 202 : context: &mut task::Context<'_>,
320 202 : buf: &mut ReadBuf<'_>,
321 202 : ) -> task::Poll<io::Result<()>> {
322 202 : match &mut *self {
323 36 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
324 166 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
325 : }
326 202 : }
327 : }
328 :
329 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
330 71 : fn poll_write(
331 71 : mut self: Pin<&mut Self>,
332 71 : context: &mut task::Context<'_>,
333 71 : buf: &[u8],
334 71 : ) -> task::Poll<io::Result<usize>> {
335 71 : match &mut *self {
336 27 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
337 44 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
338 : }
339 71 : }
340 :
341 78 : fn poll_flush(
342 78 : mut self: Pin<&mut Self>,
343 78 : context: &mut task::Context<'_>,
344 78 : ) -> task::Poll<io::Result<()>> {
345 78 : match &mut *self {
346 27 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
347 51 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
348 : }
349 78 : }
350 :
351 0 : fn poll_shutdown(
352 0 : mut self: Pin<&mut Self>,
353 0 : context: &mut task::Context<'_>,
354 0 : ) -> task::Poll<io::Result<()>> {
355 0 : match &mut *self {
356 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
357 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
358 : }
359 0 : }
360 : }
|