Line data Source code
1 : use std::pin::Pin;
2 : use std::sync::Arc;
3 : use std::{io, task};
4 :
5 : use rustls::ServerConfig;
6 : use thiserror::Error;
7 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
8 : use tokio_rustls::server::TlsStream;
9 :
10 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
11 : use crate::metrics::Metrics;
12 : use crate::pqproto::{
13 : BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
14 : read_message, read_startup,
15 : };
16 : use crate::tls::TlsServerEndPoint;
17 :
18 : /// Stream wrapper which implements libpq's protocol.
19 : ///
20 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
21 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
22 : /// to pass random malformed bytes through the connection).
23 : pub struct PqStream<S> {
24 : stream: S,
25 : read: Vec<u8>,
26 : write: WriteBuf,
27 : }
28 :
29 : impl<S> PqStream<S> {
30 35 : pub fn get_ref(&self) -> &S {
31 35 : &self.stream
32 35 : }
33 :
34 : /// Construct a new libpq protocol wrapper over a stream without the first startup message.
35 : #[cfg(test)]
36 3 : pub fn new_skip_handshake(stream: S) -> Self {
37 3 : Self {
38 3 : stream,
39 3 : read: Vec::new(),
40 3 : write: WriteBuf::new(),
41 3 : }
42 3 : }
43 : }
44 :
45 : impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
46 : /// Construct a new libpq protocol wrapper and read the first startup message.
47 : ///
48 : /// This is not cancel safe.
49 42 : pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
50 42 : let startup = read_startup(&mut stream).await?;
51 42 : Ok((
52 42 : Self {
53 42 : stream,
54 42 : read: Vec::new(),
55 42 : write: WriteBuf::new(),
56 42 : },
57 42 : startup,
58 42 : ))
59 42 : }
60 :
61 : /// Tell the client that encryption is not supported.
62 : ///
63 : /// This is not cancel safe
64 0 : pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
65 : // N for No.
66 0 : self.write.encryption(b'N');
67 0 : self.flush().await?;
68 0 : read_startup(&mut self.stream).await
69 0 : }
70 : }
71 :
72 : impl<S: AsyncRead + Unpin> PqStream<S> {
73 : /// Read a raw postgres packet, which will respect the max length requested.
74 : /// This is not cancel safe.
75 26 : async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
76 26 : let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
77 25 : if actual_tag != tag {
78 0 : return Err(io::Error::other(format!(
79 0 : "incorrect message tag, expected {:?}, got {:?}",
80 0 : tag as char, actual_tag as char,
81 0 : )));
82 25 : }
83 25 : Ok(msg)
84 26 : }
85 :
86 : /// Read a postgres password message, which will respect the max length requested.
87 : /// This is not cancel safe.
88 26 : pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
89 : // passwords are usually pretty short
90 : // and SASL SCRAM messages are no longer than 256 bytes in my testing
91 : // (a few hashes and random bytes, encoded into base64).
92 : const MAX_PASSWORD_LENGTH: u32 = 512;
93 26 : self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
94 26 : .await
95 26 : }
96 : }
97 :
98 : #[derive(Debug)]
99 : pub struct ReportedError {
100 : source: anyhow::Error,
101 : error_kind: ErrorKind,
102 : }
103 :
104 : impl ReportedError {
105 1 : pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
106 1 : let error_kind = e.get_error_kind();
107 1 : Self {
108 1 : source: e.into(),
109 1 : error_kind,
110 1 : }
111 1 : }
112 : }
113 :
/// Formats a [`ReportedError`] by delegating to the wrapped `source` error.
impl std::fmt::Display for ReportedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to the wrapped error so its message is shown unchanged.
        self.source.fmt(f)
    }
}
119 :
120 : impl std::error::Error for ReportedError {
121 0 : fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
122 0 : self.source.source()
123 0 : }
124 : }
125 :
126 : impl ReportableError for ReportedError {
127 0 : fn get_error_kind(&self) -> ErrorKind {
128 0 : self.error_kind
129 0 : }
130 : }
131 :
132 : impl<S: AsyncWrite + Unpin> PqStream<S> {
133 : /// Tell the client that we are willing to accept SSL.
134 : /// This is not cancel safe
135 20 : pub async fn accept_tls(mut self) -> io::Result<S> {
136 : // S for SSL.
137 20 : self.write.encryption(b'S');
138 20 : self.flush().await?;
139 20 : Ok(self.stream)
140 20 : }
141 :
142 : /// Assert that we are using direct TLS.
143 0 : pub fn accept_direct_tls(self) -> S {
144 0 : self.stream
145 0 : }
146 :
147 : /// Write a raw message to the internal buffer.
148 0 : pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
149 0 : self.write.write_raw(size_hint, tag, f);
150 0 : }
151 :
152 : /// Write the message into an internal buffer
153 56 : pub fn write_message(&mut self, message: BeMessage<'_>) {
154 56 : message.write_message(&mut self.write);
155 56 : }
156 :
157 : /// Flush the output buffer into the underlying stream.
158 : ///
159 : /// This is cancel safe.
160 62 : pub async fn flush(&mut self) -> io::Result<()> {
161 62 : self.stream.write_all_buf(&mut self.write).await?;
162 62 : self.write.reset();
163 :
164 62 : self.stream.flush().await?;
165 :
166 62 : Ok(())
167 62 : }
168 :
169 : /// Flush the output buffer into the underlying stream.
170 : ///
171 : /// This is cancel safe.
172 7 : pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
173 7 : self.flush().await?;
174 7 : Ok(self.stream)
175 7 : }
176 :
177 : /// Write the error message to the client, then re-throw it.
178 : ///
179 : /// Trait [`UserFacingError`] acts as an allowlist for error types.
180 : /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
181 1 : pub(crate) async fn throw_error<E>(
182 1 : &mut self,
183 1 : error: E,
184 1 : ctx: Option<&crate::context::RequestContext>,
185 1 : ) -> ReportedError
186 1 : where
187 1 : E: UserFacingError + Into<anyhow::Error>,
188 1 : {
189 1 : let error_kind = error.get_error_kind();
190 1 : let msg = error.to_string_client();
191 :
192 1 : if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
193 0 : tracing::info!(
194 0 : kind = error_kind.to_metric_label(),
195 : msg,
196 0 : "forwarding error to user"
197 : );
198 1 : }
199 :
200 : let probe_msg;
201 1 : let mut msg = &*msg;
202 1 : if let Some(ctx) = ctx
203 0 : && ctx.get_testodrome_id().is_some()
204 : {
205 0 : let tag = match error_kind {
206 0 : ErrorKind::User => "client",
207 0 : ErrorKind::ClientDisconnect => "client",
208 0 : ErrorKind::RateLimit => "proxy",
209 0 : ErrorKind::ServiceRateLimit => "proxy",
210 0 : ErrorKind::Quota => "proxy",
211 0 : ErrorKind::Service => "proxy",
212 0 : ErrorKind::ControlPlane => "controlplane",
213 0 : ErrorKind::Postgres => "other",
214 0 : ErrorKind::Compute => "compute",
215 : };
216 0 : probe_msg = typed_json::json!({
217 0 : "tag": tag,
218 0 : "msg": msg,
219 0 : "cold_start_info": ctx.cold_start_info(),
220 : })
221 0 : .to_string();
222 0 : msg = &probe_msg;
223 1 : }
224 :
225 : // TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
226 1 : self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
227 :
228 1 : self.flush()
229 1 : .await
230 1 : .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
231 :
232 1 : ReportedError::new(error)
233 1 : }
234 : }
235 :
236 : /// Wrapper for upgrading raw streams into secure streams.
237 : pub enum Stream<S> {
238 : /// We always begin with a raw stream,
239 : /// which may then be upgraded into a secure stream.
240 : Raw { raw: S },
241 : Tls {
242 : /// We box [`TlsStream`] since it can be quite large.
243 : tls: Box<TlsStream<S>>,
244 : /// Channel binding parameter
245 : tls_server_end_point: TlsServerEndPoint,
246 : },
247 : }
248 :
249 : impl<S: Unpin> Unpin for Stream<S> {}
250 :
251 : impl<S> Stream<S> {
252 : /// Construct a new instance from a raw stream.
253 25 : pub fn from_raw(raw: S) -> Self {
254 25 : Self::Raw { raw }
255 25 : }
256 :
257 : /// Return SNI hostname when it's available.
258 0 : pub fn sni_hostname(&self) -> Option<&str> {
259 0 : match self {
260 0 : Stream::Raw { .. } => None,
261 0 : Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
262 : }
263 0 : }
264 :
265 15 : pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
266 15 : match self {
267 3 : Stream::Raw { .. } => TlsServerEndPoint::Undefined,
268 : Stream::Tls {
269 12 : tls_server_end_point,
270 : ..
271 12 : } => *tls_server_end_point,
272 : }
273 15 : }
274 : }
275 :
/// Failure modes of [`Stream::upgrade`].
// NOTE(review): every variant declares its own `#[error]`, so the enum-level message below
// looks unused — confirm against thiserror's semantics before relying on it. Also, `Io`
// embeds `{0}` in its message while `#[from]` makes the same error its source, so chain
// printers may show the I/O error twice.
#[derive(Debug, Error)]
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
    /// `upgrade` was called on a stream that is already TLS.
    #[error("Bad state reached: can't upgrade TLS stream")]
    AlreadyTls,

    /// An I/O error, e.g. from the TLS handshake performed by `upgrade`.
    #[error("Can't upgrade stream: IO error: {0}")]
    Io(#[from] io::Error),
}
285 :
286 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
287 : /// If possible, upgrade raw stream into a secure TLS-based stream.
288 0 : pub async fn upgrade(
289 0 : self,
290 0 : cfg: Arc<ServerConfig>,
291 0 : record_handshake_error: bool,
292 0 : ) -> Result<TlsStream<S>, StreamUpgradeError> {
293 0 : match self {
294 0 : Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
295 0 : .accept(raw)
296 0 : .await
297 0 : .inspect_err(|_| {
298 0 : if record_handshake_error {
299 0 : Metrics::get().proxy.tls_handshake_failures.inc();
300 0 : }
301 0 : })?),
302 0 : Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
303 : }
304 0 : }
305 : }
306 :
307 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
308 199 : fn poll_read(
309 199 : mut self: Pin<&mut Self>,
310 199 : context: &mut task::Context<'_>,
311 199 : buf: &mut ReadBuf<'_>,
312 199 : ) -> task::Poll<io::Result<()>> {
313 199 : match &mut *self {
314 36 : Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
315 163 : Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
316 : }
317 199 : }
318 : }
319 :
320 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
321 71 : fn poll_write(
322 71 : mut self: Pin<&mut Self>,
323 71 : context: &mut task::Context<'_>,
324 71 : buf: &[u8],
325 71 : ) -> task::Poll<io::Result<usize>> {
326 71 : match &mut *self {
327 27 : Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
328 44 : Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
329 : }
330 71 : }
331 :
332 78 : fn poll_flush(
333 78 : mut self: Pin<&mut Self>,
334 78 : context: &mut task::Context<'_>,
335 78 : ) -> task::Poll<io::Result<()>> {
336 78 : match &mut *self {
337 27 : Self::Raw { raw } => Pin::new(raw).poll_flush(context),
338 51 : Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
339 : }
340 78 : }
341 :
342 0 : fn poll_shutdown(
343 0 : mut self: Pin<&mut Self>,
344 0 : context: &mut task::Context<'_>,
345 0 : ) -> task::Poll<io::Result<()>> {
346 0 : match &mut *self {
347 0 : Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
348 0 : Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
349 : }
350 0 : }
351 : }
|