LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: aca806cab4756d7eb6a304846130f4a73a5d5393.info Lines: 54.6 % 205 112
Test Date: 2025-04-24 20:31:15 Functions: 27.9 % 147 41

            Line data    Source code
       1              : use std::pin::Pin;
       2              : use std::sync::Arc;
       3              : use std::{io, task};
       4              : 
       5              : use bytes::BytesMut;
       6              : use pq_proto::framed::{ConnectionError, Framed};
       7              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       8              : use rustls::ServerConfig;
       9              : use serde::{Deserialize, Serialize};
      10              : use thiserror::Error;
      11              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      12              : use tokio_rustls::server::TlsStream;
      13              : use tracing::debug;
      14              : 
      15              : use crate::control_plane::messages::ColdStartInfo;
      16              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      17              : use crate::metrics::Metrics;
      18              : use crate::tls::TlsServerEndPoint;
      19              : 
      20              : /// Stream wrapper which implements libpq's protocol.
      21              : ///
      22              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      23              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      24              : /// to pass random malformed bytes through the connection).
      25              : pub struct PqStream<S> {
      26              :     pub(crate) framed: Framed<S>,
      27              : }
      28              : 
      29              : impl<S> PqStream<S> {
      30              :     /// Construct a new libpq protocol wrapper.
      31           25 :     pub fn new(stream: S) -> Self {
      32           25 :         Self {
      33           25 :             framed: Framed::new(stream),
      34           25 :         }
      35           25 :     }
      36              : 
      37              :     /// Extract the underlying stream and read buffer.
      38            0 :     pub fn into_inner(self) -> (S, BytesMut) {
      39            0 :         self.framed.into_inner()
      40            0 :     }
      41              : 
      42              :     /// Get a shared reference to the underlying stream.
      43           35 :     pub(crate) fn get_ref(&self) -> &S {
      44           35 :         self.framed.get_ref()
      45           35 :     }
      46              : }
      47              : 
      48            0 : fn err_connection() -> io::Error {
      49            0 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      50            0 : }
      51              : 
      52              : impl<S: AsyncRead + Unpin> PqStream<S> {
      53              :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      54           42 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      55           42 :         self.framed
      56           42 :             .read_startup_message()
      57           42 :             .await
      58           42 :             .map_err(ConnectionError::into_io_error)?
      59           42 :             .ok_or_else(err_connection)
      60           42 :     }
      61              : 
      62           26 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      63           26 :         self.framed
      64           26 :             .read_message()
      65           26 :             .await
      66           26 :             .map_err(ConnectionError::into_io_error)?
      67           25 :             .ok_or_else(err_connection)
      68           26 :     }
      69              : 
      70           26 :     pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      71           26 :         match self.read_message().await? {
      72           25 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      73            0 :             bad => Err(io::Error::new(
      74            0 :                 io::ErrorKind::InvalidData,
      75            0 :                 format!("unexpected message type: {bad:?}"),
      76            0 :             )),
      77              :         }
      78           26 :     }
      79              : }
      80              : 
      81              : #[derive(Debug)]
      82              : pub struct ReportedError {
      83              :     source: anyhow::Error,
      84              :     error_kind: ErrorKind,
      85              : }
      86              : 
      87              : impl std::fmt::Display for ReportedError {
      88            1 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      89            1 :         self.source.fmt(f)
      90            1 :     }
      91              : }
      92              : 
      93              : impl std::error::Error for ReportedError {
      94            0 :     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
      95            0 :         self.source.source()
      96            0 :     }
      97              : }
      98              : 
      99              : impl ReportableError for ReportedError {
     100            0 :     fn get_error_kind(&self) -> ErrorKind {
     101            0 :         self.error_kind
     102            0 :     }
     103              : }
     104              : 
     105            0 : #[derive(Serialize, Deserialize, Debug)]
     106              : enum ErrorTag {
     107              :     #[serde(rename = "proxy")]
     108              :     Proxy,
     109              :     #[serde(rename = "compute")]
     110              :     Compute,
     111              :     #[serde(rename = "client")]
     112              :     Client,
     113              :     #[serde(rename = "controlplane")]
     114              :     ControlPlane,
     115              :     #[serde(rename = "other")]
     116              :     Other,
     117              : }
     118              : 
     119              : impl From<ErrorKind> for ErrorTag {
     120            0 :     fn from(error_kind: ErrorKind) -> Self {
     121            0 :         match error_kind {
     122            0 :             ErrorKind::User => Self::Client,
     123            0 :             ErrorKind::ClientDisconnect => Self::Client,
     124            0 :             ErrorKind::RateLimit => Self::Proxy,
     125            0 :             ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
     126            0 :             ErrorKind::Quota => Self::Proxy,
     127            0 :             ErrorKind::Service => Self::Proxy,
     128            0 :             ErrorKind::ControlPlane => Self::ControlPlane,
     129            0 :             ErrorKind::Postgres => Self::Other,
     130            0 :             ErrorKind::Compute => Self::Compute,
     131              :         }
     132            0 :     }
     133              : }
     134              : 
     135            0 : #[derive(Serialize, Deserialize, Debug)]
     136              : #[serde(rename_all = "snake_case")]
     137              : struct ProbeErrorData {
     138              :     tag: ErrorTag,
     139              :     msg: String,
     140              :     cold_start_info: Option<ColdStartInfo>,
     141              : }
     142              : 
     143              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     144              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
     145           77 :     pub(crate) fn write_message_noflush(
     146           77 :         &mut self,
     147           77 :         message: &BeMessage<'_>,
     148           77 :     ) -> io::Result<&mut Self> {
     149           77 :         self.framed
     150           77 :             .write_message(message)
     151           77 :             .map_err(ProtocolError::into_io_error)?;
     152           77 :         Ok(self)
     153           77 :     }
     154              : 
     155              :     /// Write the message into an internal buffer and flush it.
     156           54 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     157           54 :         self.write_message_noflush(message)?;
     158           54 :         self.flush().await?;
     159           54 :         Ok(self)
     160           54 :     }
     161              : 
     162              :     /// Flush the output buffer into the underlying stream.
     163           55 :     pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
     164           55 :         self.framed.flush().await?;
     165           55 :         Ok(self)
     166           55 :     }
     167              : 
     168              :     /// Writes message with the given error kind to the stream.
     169              :     /// Used only for probe queries
     170            1 :     async fn write_format_message(
     171            1 :         &mut self,
     172            1 :         msg: &str,
     173            1 :         error_kind: ErrorKind,
     174            1 :         ctx: Option<&crate::context::RequestContext>,
     175            1 :     ) -> String {
     176            1 :         let formatted_msg = match ctx {
     177            0 :             Some(ctx) if ctx.get_testodrome_id().is_some() => {
     178            0 :                 serde_json::to_string(&ProbeErrorData {
     179            0 :                     tag: ErrorTag::from(error_kind),
     180            0 :                     msg: msg.to_string(),
     181            0 :                     cold_start_info: Some(ctx.cold_start_info()),
     182            0 :                 })
     183            0 :                 .unwrap_or_default()
     184              :             }
     185            1 :             _ => msg.to_string(),
     186              :         };
     187              : 
     188              :         // already error case, ignore client IO error
     189            1 :         self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
     190            1 :             .await
     191            1 :             .inspect_err(|e| debug!("write_message failed: {e}"))
     192            1 :             .ok();
     193            1 : 
     194            1 :         formatted_msg
     195            1 :     }
     196              : 
     197              :     /// Write the error message using [`Self::write_format_message`], then re-throw it.
     198              :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     199              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     200              :     /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
     201            1 :     pub async fn throw_error_str<T>(
     202            1 :         &mut self,
     203            1 :         msg: &'static str,
     204            1 :         error_kind: ErrorKind,
     205            1 :         ctx: Option<&crate::context::RequestContext>,
     206            1 :     ) -> Result<T, ReportedError> {
     207            1 :         self.write_format_message(msg, error_kind, ctx).await;
     208              : 
     209            1 :         if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
     210            0 :             tracing::info!(
     211            0 :                 kind = error_kind.to_metric_label(),
     212            0 :                 msg,
     213            0 :                 "forwarding error to user"
     214              :             );
     215            1 :         }
     216              : 
     217            1 :         Err(ReportedError {
     218            1 :             source: anyhow::anyhow!(msg),
     219            1 :             error_kind,
     220            1 :         })
     221            1 :     }
     222              : 
     223              :     /// Write the error message using [`Self::write_format_message`], then re-throw it.
     224              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     225              :     /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
     226            0 :     pub(crate) async fn throw_error<T, E>(
     227            0 :         &mut self,
     228            0 :         error: E,
     229            0 :         ctx: Option<&crate::context::RequestContext>,
     230            0 :     ) -> Result<T, ReportedError>
     231            0 :     where
     232            0 :         E: UserFacingError + Into<anyhow::Error>,
     233            0 :     {
     234            0 :         let error_kind = error.get_error_kind();
     235            0 :         let msg = error.to_string_client();
     236            0 :         self.write_format_message(&msg, error_kind, ctx).await;
     237            0 :         if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
     238            0 :             tracing::info!(
     239            0 :                 kind=error_kind.to_metric_label(),
     240            0 :                 error=%error,
     241            0 :                 msg,
     242            0 :                 "forwarding error to user",
     243              :             );
     244            0 :         }
     245              : 
     246            0 :         Err(ReportedError {
     247            0 :             source: anyhow::anyhow!(error),
     248            0 :             error_kind,
     249            0 :         })
     250            0 :     }
     251              : }
     252              : 
     253              : /// Wrapper for upgrading raw streams into secure streams.
     254              : pub enum Stream<S> {
     255              :     /// We always begin with a raw stream,
     256              :     /// which may then be upgraded into a secure stream.
     257              :     Raw { raw: S },
     258              :     Tls {
     259              :         /// We box [`TlsStream`] since it can be quite large.
     260              :         tls: Box<TlsStream<S>>,
     261              :         /// Channel binding parameter
     262              :         tls_server_end_point: TlsServerEndPoint,
     263              :     },
     264              : }
     265              : 
     266              : impl<S: Unpin> Unpin for Stream<S> {}
     267              : 
     268              : impl<S> Stream<S> {
     269              :     /// Construct a new instance from a raw stream.
     270           25 :     pub fn from_raw(raw: S) -> Self {
     271           25 :         Self::Raw { raw }
     272           25 :     }
     273              : 
     274              :     /// Return SNI hostname when it's available.
     275            0 :     pub fn sni_hostname(&self) -> Option<&str> {
     276            0 :         match self {
     277            0 :             Stream::Raw { .. } => None,
     278            0 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     279              :         }
     280            0 :     }
     281              : 
     282           15 :     pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
     283           15 :         match self {
     284            3 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     285              :             Stream::Tls {
     286           12 :                 tls_server_end_point,
     287           12 :                 ..
     288           12 :             } => *tls_server_end_point,
     289              :         }
     290           15 :     }
     291              : }
     292              : 
     293              : #[derive(Debug, Error)]
     294              : #[error("Can't upgrade TLS stream")]
     295              : pub enum StreamUpgradeError {
     296              :     #[error("Bad state reached: can't upgrade TLS stream")]
     297              :     AlreadyTls,
     298              : 
     299              :     #[error("Can't upgrade stream: IO error: {0}")]
     300              :     Io(#[from] io::Error),
     301              : }
     302              : 
     303              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     304              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     305            0 :     pub async fn upgrade(
     306            0 :         self,
     307            0 :         cfg: Arc<ServerConfig>,
     308            0 :         record_handshake_error: bool,
     309            0 :     ) -> Result<TlsStream<S>, StreamUpgradeError> {
     310            0 :         match self {
     311            0 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
     312            0 :                 .accept(raw)
     313            0 :                 .await
     314            0 :                 .inspect_err(|_| {
     315            0 :                     if record_handshake_error {
     316            0 :                         Metrics::get().proxy.tls_handshake_failures.inc();
     317            0 :                     }
     318            0 :                 })?),
     319            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     320              :         }
     321            0 :     }
     322              : }
     323              : 
     324              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     325          154 :     fn poll_read(
     326          154 :         mut self: Pin<&mut Self>,
     327          154 :         context: &mut task::Context<'_>,
     328          154 :         buf: &mut ReadBuf<'_>,
     329          154 :     ) -> task::Poll<io::Result<()>> {
     330          154 :         match &mut *self {
     331           30 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     332          124 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     333              :         }
     334          154 :     }
     335              : }
     336              : 
     337              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     338           71 :     fn poll_write(
     339           71 :         mut self: Pin<&mut Self>,
     340           71 :         context: &mut task::Context<'_>,
     341           71 :         buf: &[u8],
     342           71 :     ) -> task::Poll<io::Result<usize>> {
     343           71 :         match &mut *self {
     344           27 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     345           44 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     346              :         }
     347           71 :     }
     348              : 
     349           71 :     fn poll_flush(
     350           71 :         mut self: Pin<&mut Self>,
     351           71 :         context: &mut task::Context<'_>,
     352           71 :     ) -> task::Poll<io::Result<()>> {
     353           71 :         match &mut *self {
     354           27 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     355           44 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     356              :         }
     357           71 :     }
     358              : 
     359            0 :     fn poll_shutdown(
     360            0 :         mut self: Pin<&mut Self>,
     361            0 :         context: &mut task::Context<'_>,
     362            0 :     ) -> task::Poll<io::Result<()>> {
     363            0 :         match &mut *self {
     364            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     365            0 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     366              :         }
     367            0 :     }
     368              : }
        

Generated by: LCOV version 2.1-beta