|             Line data    Source code 
       1              : use std::pin::Pin;
       2              : use std::sync::Arc;
       3              : use std::{io, task};
       4              : 
       5              : use rustls::ServerConfig;
       6              : use thiserror::Error;
       7              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
       8              : use tokio_rustls::server::TlsStream;
       9              : 
      10              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      11              : use crate::metrics::Metrics;
      12              : use crate::pqproto::{
      13              :     BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
      14              :     read_message, read_startup,
      15              : };
      16              : use crate::tls::TlsServerEndPoint;
      17              : 
      18              : /// Stream wrapper which implements libpq's protocol.
      19              : ///
      20              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      21              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      22              : /// to pass random malformed bytes through the connection).
      23              : pub struct PqStream<S> {
      24              :     stream: S,
      25              :     read: Vec<u8>,
      26              :     write: WriteBuf,
      27              : }
      28              : 
      29              : impl<S> PqStream<S> {
      30           35 :     pub fn get_ref(&self) -> &S {
      31           35 :         &self.stream
      32           35 :     }
      33              : 
      34              :     /// Construct a new libpq protocol wrapper over a stream without the first startup message.
      35              :     #[cfg(test)]
      36            3 :     pub fn new_skip_handshake(stream: S) -> Self {
      37            3 :         Self {
      38            3 :             stream,
      39            3 :             read: Vec::new(),
      40            3 :             write: WriteBuf::new(),
      41            3 :         }
      42            3 :     }
      43              : }
      44              : 
      45              : impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
      46              :     /// Construct a new libpq protocol wrapper and read the first startup message.
      47              :     ///
      48              :     /// This is not cancel safe.
      49           42 :     pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
      50           42 :         let startup = read_startup(&mut stream).await?;
      51           42 :         Ok((
      52           42 :             Self {
      53           42 :                 stream,
      54           42 :                 read: Vec::new(),
      55           42 :                 write: WriteBuf::new(),
      56           42 :             },
      57           42 :             startup,
      58           42 :         ))
      59           42 :     }
      60              : 
      61              :     /// Tell the client that encryption is not supported.
      62              :     ///
      63              :     /// This is not cancel safe
      64            0 :     pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
      65              :         // N for No.
      66            0 :         self.write.encryption(b'N');
      67            0 :         self.flush().await?;
      68            0 :         read_startup(&mut self.stream).await
      69            0 :     }
      70              : }
      71              : 
      72              : impl<S: AsyncRead + Unpin> PqStream<S> {
      73              :     /// Read a raw postgres packet, which will respect the max length requested.
      74              :     /// This is not cancel safe.
      75           26 :     async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
      76           26 :         let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
      77           25 :         if actual_tag != tag {
      78            0 :             return Err(io::Error::other(format!(
      79            0 :                 "incorrect message tag, expected {:?}, got {:?}",
      80            0 :                 tag as char, actual_tag as char,
      81            0 :             )));
      82           25 :         }
      83           25 :         Ok(msg)
      84           26 :     }
      85              : 
      86              :     /// Read a postgres password message, which will respect the max length requested.
      87              :     /// This is not cancel safe.
      88           26 :     pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
      89              :         // passwords are usually pretty short
      90              :         // and SASL SCRAM messages are no longer than 256 bytes in my testing
      91              :         // (a few hashes and random bytes, encoded into base64).
      92              :         const MAX_PASSWORD_LENGTH: u32 = 512;
      93           26 :         self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
      94           26 :             .await
      95           26 :     }
      96              : }
      97              : 
      98              : #[derive(Debug)]
      99              : pub struct ReportedError {
     100              :     source: anyhow::Error,
     101              :     error_kind: ErrorKind,
     102              : }
     103              : 
     104              : impl ReportedError {
     105            1 :     pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
     106            1 :         let error_kind = e.get_error_kind();
     107            1 :         Self {
     108            1 :             source: e.into(),
     109            1 :             error_kind,
     110            1 :         }
     111            1 :     }
     112              : }
     113              : 
/// Displays exactly as the wrapped error does.
impl std::fmt::Display for ReportedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate formatting to the wrapped `anyhow::Error`.
        self.source.fmt(f)
    }
}
     119              : 
/// Error-trait impl for the wrapper.
impl std::error::Error for ReportedError {
    // Delegates to the wrapped `anyhow::Error`'s `source()`, i.e. the wrapper
    // itself does not appear as an extra link in the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.source.source()
    }
}
     125              : 
     126              : impl ReportableError for ReportedError {
     127            0 :     fn get_error_kind(&self) -> ErrorKind {
     128            0 :         self.error_kind
     129            0 :     }
     130              : }
     131              : 
     132              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     133              :     /// Tell the client that we are willing to accept SSL.
     134              :     /// This is not cancel safe
     135           20 :     pub async fn accept_tls(mut self) -> io::Result<S> {
     136              :         // S for SSL.
     137           20 :         self.write.encryption(b'S');
     138           20 :         self.flush().await?;
     139           20 :         Ok(self.stream)
     140           20 :     }
     141              : 
     142              :     /// Assert that we are using direct TLS.
     143            0 :     pub fn accept_direct_tls(self) -> S {
     144            0 :         self.stream
     145            0 :     }
     146              : 
     147              :     /// Write a raw message to the internal buffer.
     148            0 :     pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
     149            0 :         self.write.write_raw(size_hint, tag, f);
     150            0 :     }
     151              : 
     152              :     /// Write the message into an internal buffer
     153           56 :     pub fn write_message(&mut self, message: BeMessage<'_>) {
     154           56 :         message.write_message(&mut self.write);
     155           56 :     }
     156              : 
     157              :     /// Write the buffer to the socket until we have some more space again.
     158            0 :     pub async fn write_if_full(&mut self) -> io::Result<()> {
     159            0 :         while self.write.occupied_len() > 2048 {
     160            0 :             self.stream.write_buf(&mut self.write).await?;
     161              :         }
     162              : 
     163            0 :         Ok(())
     164            0 :     }
     165              : 
     166              :     /// Flush the output buffer into the underlying stream.
     167              :     ///
     168              :     /// This is cancel safe.
     169           62 :     pub async fn flush(&mut self) -> io::Result<()> {
     170           62 :         self.stream.write_all_buf(&mut self.write).await?;
     171           62 :         self.write.reset();
     172              : 
     173           62 :         self.stream.flush().await?;
     174              : 
     175           62 :         Ok(())
     176           62 :     }
     177              : 
     178              :     /// Flush the output buffer into the underlying stream.
     179              :     ///
     180              :     /// This is cancel safe.
     181            7 :     pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
     182            7 :         self.flush().await?;
     183            7 :         Ok(self.stream)
     184            7 :     }
     185              : 
     186              :     /// Write the error message to the client, then re-throw it.
     187              :     ///
     188              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     189              :     /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
     190            1 :     pub(crate) async fn throw_error<E>(
     191            1 :         &mut self,
     192            1 :         error: E,
     193            1 :         ctx: Option<&crate::context::RequestContext>,
     194            1 :     ) -> ReportedError
     195            1 :     where
     196            1 :         E: UserFacingError + Into<anyhow::Error>,
     197            1 :     {
     198            1 :         let error_kind = error.get_error_kind();
     199            1 :         let msg = error.to_string_client();
     200              : 
     201            1 :         if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
     202            0 :             tracing::info!(
     203            0 :                 kind = error_kind.to_metric_label(),
     204              :                 msg,
     205            0 :                 "forwarding error to user"
     206              :             );
     207            1 :         }
     208              : 
     209              :         let probe_msg;
     210            1 :         let mut msg = &*msg;
     211            1 :         if let Some(ctx) = ctx
     212            0 :             && ctx.get_testodrome_id().is_some()
     213              :         {
     214            0 :             let tag = match error_kind {
     215            0 :                 ErrorKind::User => "client",
     216            0 :                 ErrorKind::ClientDisconnect => "client",
     217            0 :                 ErrorKind::RateLimit => "proxy",
     218            0 :                 ErrorKind::ServiceRateLimit => "proxy",
     219            0 :                 ErrorKind::Quota => "proxy",
     220            0 :                 ErrorKind::Service => "proxy",
     221            0 :                 ErrorKind::ControlPlane => "controlplane",
     222            0 :                 ErrorKind::Postgres => "other",
     223            0 :                 ErrorKind::Compute => "compute",
     224              :             };
     225            0 :             probe_msg = typed_json::json!({
     226            0 :                 "tag": tag,
     227            0 :                 "msg": msg,
     228            0 :                 "cold_start_info": ctx.cold_start_info(),
     229              :             })
     230            0 :             .to_string();
     231            0 :             msg = &probe_msg;
     232            1 :         }
     233              : 
     234              :         // TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
     235            1 :         self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
     236              : 
     237            1 :         self.flush()
     238            1 :             .await
     239            1 :             .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
     240              : 
     241            1 :         ReportedError::new(error)
     242            1 :     }
     243              : }
     244              : 
     245              : /// Wrapper for upgrading raw streams into secure streams.
     246              : pub enum Stream<S> {
     247              :     /// We always begin with a raw stream,
     248              :     /// which may then be upgraded into a secure stream.
     249              :     Raw { raw: S },
     250              :     Tls {
     251              :         /// We box [`TlsStream`] since it can be quite large.
     252              :         tls: Box<TlsStream<S>>,
     253              :         /// Channel binding parameter
     254              :         tls_server_end_point: TlsServerEndPoint,
     255              :     },
     256              : }
     257              : 
     258              : impl<S: Unpin> Unpin for Stream<S> {}
     259              : 
     260              : impl<S> Stream<S> {
     261              :     /// Construct a new instance from a raw stream.
     262           25 :     pub fn from_raw(raw: S) -> Self {
     263           25 :         Self::Raw { raw }
     264           25 :     }
     265              : 
     266              :     /// Return SNI hostname when it's available.
     267            0 :     pub fn sni_hostname(&self) -> Option<&str> {
     268            0 :         match self {
     269            0 :             Stream::Raw { .. } => None,
     270            0 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     271              :         }
     272            0 :     }
     273              : 
     274           15 :     pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
     275           15 :         match self {
     276            3 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     277              :             Stream::Tls {
     278           12 :                 tls_server_end_point,
     279              :                 ..
     280           12 :             } => *tls_server_end_point,
     281              :         }
     282           15 :     }
     283              : }
     284              : 
     285              : #[derive(Debug, Error)]
     286              : #[error("Can't upgrade TLS stream")]
     287              : pub enum StreamUpgradeError {
     288              :     #[error("Bad state reached: can't upgrade TLS stream")]
     289              :     AlreadyTls,
     290              : 
     291              :     #[error("Can't upgrade stream: IO error: {0}")]
     292              :     Io(#[from] io::Error),
     293              : }
     294              : 
     295              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     296              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     297            0 :     pub async fn upgrade(
     298            0 :         self,
     299            0 :         cfg: Arc<ServerConfig>,
     300            0 :         record_handshake_error: bool,
     301            0 :     ) -> Result<TlsStream<S>, StreamUpgradeError> {
     302            0 :         match self {
     303            0 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
     304            0 :                 .accept(raw)
     305            0 :                 .await
     306            0 :                 .inspect_err(|_| {
     307            0 :                     if record_handshake_error {
     308            0 :                         Metrics::get().proxy.tls_handshake_failures.inc();
     309            0 :                     }
     310            0 :                 })?),
     311            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     312              :         }
     313            0 :     }
     314              : }
     315              : 
     316              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     317          202 :     fn poll_read(
     318          202 :         mut self: Pin<&mut Self>,
     319          202 :         context: &mut task::Context<'_>,
     320          202 :         buf: &mut ReadBuf<'_>,
     321          202 :     ) -> task::Poll<io::Result<()>> {
     322          202 :         match &mut *self {
     323           36 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     324          166 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     325              :         }
     326          202 :     }
     327              : }
     328              : 
     329              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     330           71 :     fn poll_write(
     331           71 :         mut self: Pin<&mut Self>,
     332           71 :         context: &mut task::Context<'_>,
     333           71 :         buf: &[u8],
     334           71 :     ) -> task::Poll<io::Result<usize>> {
     335           71 :         match &mut *self {
     336           27 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     337           44 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     338              :         }
     339           71 :     }
     340              : 
     341           78 :     fn poll_flush(
     342           78 :         mut self: Pin<&mut Self>,
     343           78 :         context: &mut task::Context<'_>,
     344           78 :     ) -> task::Poll<io::Result<()>> {
     345           78 :         match &mut *self {
     346           27 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     347           51 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     348              :         }
     349           78 :     }
     350              : 
     351            0 :     fn poll_shutdown(
     352            0 :         mut self: Pin<&mut Self>,
     353            0 :         context: &mut task::Context<'_>,
     354            0 :     ) -> task::Poll<io::Result<()>> {
     355            0 :         match &mut *self {
     356            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     357            0 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     358              :         }
     359            0 :     }
     360              : }
         |