LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 62.2 % 185 115
Test Date: 2025-07-16 12:29:03 Functions: 20.5 % 117 24

            Line data    Source code
       1              : use std::pin::Pin;
       2              : use std::sync::Arc;
       3              : use std::{io, task};
       4              : 
       5              : use rustls::ServerConfig;
       6              : use thiserror::Error;
       7              : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
       8              : use tokio_rustls::server::TlsStream;
       9              : 
      10              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      11              : use crate::metrics::Metrics;
      12              : use crate::pqproto::{
      13              :     BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
      14              :     read_message, read_startup,
      15              : };
      16              : use crate::tls::TlsServerEndPoint;
      17              : 
      18              : /// Stream wrapper which implements libpq's protocol.
      19              : ///
      20              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      21              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      22              : /// to pass random malformed bytes through the connection).
      23              : pub struct PqStream<S> {
      24              :     stream: S,
      25              :     read: Vec<u8>,
      26              :     write: WriteBuf,
      27              : }
      28              : 
      29              : impl<S> PqStream<S> {
      30           35 :     pub fn get_ref(&self) -> &S {
      31           35 :         &self.stream
      32           35 :     }
      33              : 
      34              :     /// Construct a new libpq protocol wrapper over a stream without the first startup message.
      35              :     #[cfg(test)]
      36            3 :     pub fn new_skip_handshake(stream: S) -> Self {
      37            3 :         Self {
      38            3 :             stream,
      39            3 :             read: Vec::new(),
      40            3 :             write: WriteBuf::new(),
      41            3 :         }
      42            3 :     }
      43              : }
      44              : 
      45              : impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
      46              :     /// Construct a new libpq protocol wrapper and read the first startup message.
      47              :     ///
      48              :     /// This is not cancel safe.
      49           42 :     pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
      50           42 :         let startup = read_startup(&mut stream).await?;
      51           42 :         Ok((
      52           42 :             Self {
      53           42 :                 stream,
      54           42 :                 read: Vec::new(),
      55           42 :                 write: WriteBuf::new(),
      56           42 :             },
      57           42 :             startup,
      58           42 :         ))
      59           42 :     }
      60              : 
      61              :     /// Tell the client that encryption is not supported.
      62              :     ///
      63              :     /// This is not cancel safe
      64            0 :     pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
      65              :         // N for No.
      66            0 :         self.write.encryption(b'N');
      67            0 :         self.flush().await?;
      68            0 :         read_startup(&mut self.stream).await
      69            0 :     }
      70              : }
      71              : 
      72              : impl<S: AsyncRead + Unpin> PqStream<S> {
      73              :     /// Read a raw postgres packet, which will respect the max length requested.
      74              :     /// This is not cancel safe.
      75           26 :     async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
      76           26 :         let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
      77           25 :         if actual_tag != tag {
      78            0 :             return Err(io::Error::other(format!(
      79            0 :                 "incorrect message tag, expected {:?}, got {:?}",
      80            0 :                 tag as char, actual_tag as char,
      81            0 :             )));
      82           25 :         }
      83           25 :         Ok(msg)
      84           26 :     }
      85              : 
      86              :     /// Read a postgres password message, which will respect the max length requested.
      87              :     /// This is not cancel safe.
      88           26 :     pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
      89              :         // passwords are usually pretty short
      90              :         // and SASL SCRAM messages are no longer than 256 bytes in my testing
      91              :         // (a few hashes and random bytes, encoded into base64).
      92              :         const MAX_PASSWORD_LENGTH: u32 = 512;
      93           26 :         self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
      94           26 :             .await
      95           26 :     }
      96              : }
      97              : 
      98              : #[derive(Debug)]
      99              : pub struct ReportedError {
     100              :     source: anyhow::Error,
     101              :     error_kind: ErrorKind,
     102              : }
     103              : 
     104              : impl ReportedError {
     105            1 :     pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
     106            1 :         let error_kind = e.get_error_kind();
     107            1 :         Self {
     108            1 :             source: e.into(),
     109            1 :             error_kind,
     110            1 :         }
     111            1 :     }
     112              : }
     113              : 
     114              : impl std::fmt::Display for ReportedError {
     115            1 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     116            1 :         self.source.fmt(f)
     117            1 :     }
     118              : }
     119              : 
     120              : impl std::error::Error for ReportedError {
     121            0 :     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
     122            0 :         self.source.source()
     123            0 :     }
     124              : }
     125              : 
     126              : impl ReportableError for ReportedError {
     127            0 :     fn get_error_kind(&self) -> ErrorKind {
     128            0 :         self.error_kind
     129            0 :     }
     130              : }
     131              : 
     132              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     133              :     /// Tell the client that we are willing to accept SSL.
     134              :     /// This is not cancel safe
     135           20 :     pub async fn accept_tls(mut self) -> io::Result<S> {
     136              :         // S for SSL.
     137           20 :         self.write.encryption(b'S');
     138           20 :         self.flush().await?;
     139           20 :         Ok(self.stream)
     140           20 :     }
     141              : 
     142              :     /// Assert that we are using direct TLS.
     143            0 :     pub fn accept_direct_tls(self) -> S {
     144            0 :         self.stream
     145            0 :     }
     146              : 
     147              :     /// Write a raw message to the internal buffer.
     148            0 :     pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
     149            0 :         self.write.write_raw(size_hint, tag, f);
     150            0 :     }
     151              : 
     152              :     /// Write the message into an internal buffer
     153           56 :     pub fn write_message(&mut self, message: BeMessage<'_>) {
     154           56 :         message.write_message(&mut self.write);
     155           56 :     }
     156              : 
     157              :     /// Flush the output buffer into the underlying stream.
     158              :     ///
     159              :     /// This is cancel safe.
     160           62 :     pub async fn flush(&mut self) -> io::Result<()> {
     161           62 :         self.stream.write_all_buf(&mut self.write).await?;
     162           62 :         self.write.reset();
     163              : 
     164           62 :         self.stream.flush().await?;
     165              : 
     166           62 :         Ok(())
     167           62 :     }
     168              : 
     169              :     /// Flush the output buffer into the underlying stream.
     170              :     ///
     171              :     /// This is cancel safe.
     172            7 :     pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
     173            7 :         self.flush().await?;
     174            7 :         Ok(self.stream)
     175            7 :     }
     176              : 
     177              :     /// Write the error message to the client, then re-throw it.
     178              :     ///
     179              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     180              :     /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
     181            1 :     pub(crate) async fn throw_error<E>(
     182            1 :         &mut self,
     183            1 :         error: E,
     184            1 :         ctx: Option<&crate::context::RequestContext>,
     185            1 :     ) -> ReportedError
     186            1 :     where
     187            1 :         E: UserFacingError + Into<anyhow::Error>,
     188            1 :     {
     189            1 :         let error_kind = error.get_error_kind();
     190            1 :         let msg = error.to_string_client();
     191              : 
     192            1 :         if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
     193            0 :             tracing::info!(
     194            0 :                 kind = error_kind.to_metric_label(),
     195              :                 msg,
     196            0 :                 "forwarding error to user"
     197              :             );
     198            1 :         }
     199              : 
     200              :         let probe_msg;
     201            1 :         let mut msg = &*msg;
     202            1 :         if let Some(ctx) = ctx
     203            0 :             && ctx.get_testodrome_id().is_some()
     204              :         {
     205            0 :             let tag = match error_kind {
     206            0 :                 ErrorKind::User => "client",
     207            0 :                 ErrorKind::ClientDisconnect => "client",
     208            0 :                 ErrorKind::RateLimit => "proxy",
     209            0 :                 ErrorKind::ServiceRateLimit => "proxy",
     210            0 :                 ErrorKind::Quota => "proxy",
     211            0 :                 ErrorKind::Service => "proxy",
     212            0 :                 ErrorKind::ControlPlane => "controlplane",
     213            0 :                 ErrorKind::Postgres => "other",
     214            0 :                 ErrorKind::Compute => "compute",
     215              :             };
     216            0 :             probe_msg = typed_json::json!({
     217            0 :                 "tag": tag,
     218            0 :                 "msg": msg,
     219            0 :                 "cold_start_info": ctx.cold_start_info(),
     220              :             })
     221            0 :             .to_string();
     222            0 :             msg = &probe_msg;
     223            1 :         }
     224              : 
     225              :         // TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
     226            1 :         self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
     227              : 
     228            1 :         self.flush()
     229            1 :             .await
     230            1 :             .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
     231              : 
     232            1 :         ReportedError::new(error)
     233            1 :     }
     234              : }
     235              : 
     236              : /// Wrapper for upgrading raw streams into secure streams.
     237              : pub enum Stream<S> {
     238              :     /// We always begin with a raw stream,
     239              :     /// which may then be upgraded into a secure stream.
     240              :     Raw { raw: S },
     241              :     Tls {
     242              :         /// We box [`TlsStream`] since it can be quite large.
     243              :         tls: Box<TlsStream<S>>,
     244              :         /// Channel binding parameter
     245              :         tls_server_end_point: TlsServerEndPoint,
     246              :     },
     247              : }
     248              : 
     249              : impl<S: Unpin> Unpin for Stream<S> {}
     250              : 
     251              : impl<S> Stream<S> {
     252              :     /// Construct a new instance from a raw stream.
     253           25 :     pub fn from_raw(raw: S) -> Self {
     254           25 :         Self::Raw { raw }
     255           25 :     }
     256              : 
     257              :     /// Return SNI hostname when it's available.
     258            0 :     pub fn sni_hostname(&self) -> Option<&str> {
     259            0 :         match self {
     260            0 :             Stream::Raw { .. } => None,
     261            0 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     262              :         }
     263            0 :     }
     264              : 
     265           15 :     pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
     266           15 :         match self {
     267            3 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     268              :             Stream::Tls {
     269           12 :                 tls_server_end_point,
     270              :                 ..
     271           12 :             } => *tls_server_end_point,
     272              :         }
     273           15 :     }
     274              : }
     275              : 
     276              : #[derive(Debug, Error)]
     277              : #[error("Can't upgrade TLS stream")]
     278              : pub enum StreamUpgradeError {
     279              :     #[error("Bad state reached: can't upgrade TLS stream")]
     280              :     AlreadyTls,
     281              : 
     282              :     #[error("Can't upgrade stream: IO error: {0}")]
     283              :     Io(#[from] io::Error),
     284              : }
     285              : 
     286              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     287              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     288            0 :     pub async fn upgrade(
     289            0 :         self,
     290            0 :         cfg: Arc<ServerConfig>,
     291            0 :         record_handshake_error: bool,
     292            0 :     ) -> Result<TlsStream<S>, StreamUpgradeError> {
     293            0 :         match self {
     294            0 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
     295            0 :                 .accept(raw)
     296            0 :                 .await
     297            0 :                 .inspect_err(|_| {
     298            0 :                     if record_handshake_error {
     299            0 :                         Metrics::get().proxy.tls_handshake_failures.inc();
     300            0 :                     }
     301            0 :                 })?),
     302            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     303              :         }
     304            0 :     }
     305              : }
     306              : 
     307              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     308          199 :     fn poll_read(
     309          199 :         mut self: Pin<&mut Self>,
     310          199 :         context: &mut task::Context<'_>,
     311          199 :         buf: &mut ReadBuf<'_>,
     312          199 :     ) -> task::Poll<io::Result<()>> {
     313          199 :         match &mut *self {
     314           36 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     315          163 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     316              :         }
     317          199 :     }
     318              : }
     319              : 
     320              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     321           71 :     fn poll_write(
     322           71 :         mut self: Pin<&mut Self>,
     323           71 :         context: &mut task::Context<'_>,
     324           71 :         buf: &[u8],
     325           71 :     ) -> task::Poll<io::Result<usize>> {
     326           71 :         match &mut *self {
     327           27 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     328           44 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     329              :         }
     330           71 :     }
     331              : 
     332           78 :     fn poll_flush(
     333           78 :         mut self: Pin<&mut Self>,
     334           78 :         context: &mut task::Context<'_>,
     335           78 :     ) -> task::Poll<io::Result<()>> {
     336           78 :         match &mut *self {
     337           27 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     338           51 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     339              :         }
     340           78 :     }
     341              : 
     342            0 :     fn poll_shutdown(
     343            0 :         mut self: Pin<&mut Self>,
     344            0 :         context: &mut task::Context<'_>,
     345            0 :     ) -> task::Poll<io::Result<()>> {
     346            0 :         match &mut *self {
     347            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     348            0 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     349              :         }
     350            0 :     }
     351              : }
        

Generated by: LCOV version 2.1-beta