LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: 7179b4db0d82ca8088cc95c44c4be4232078509c.info Lines: 58.2 % 170 99
Test Date: 2024-11-21 16:46:58 Functions: 23.5 % 166 39

            Line data    Source code
       1              : use std::pin::Pin;
       2              : use std::sync::Arc;
       3              : use std::{io, task};
       4              : 
       5              : use bytes::BytesMut;
       6              : use pq_proto::framed::{ConnectionError, Framed};
       7              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       8              : use rustls::ServerConfig;
       9              : use thiserror::Error;
      10              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      11              : use tokio_rustls::server::TlsStream;
      12              : use tracing::debug;
      13              : 
      14              : use crate::config::TlsServerEndPoint;
      15              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      16              : use crate::metrics::Metrics;
      17              : 
      18              : /// Stream wrapper which implements libpq's protocol.
      19              : ///
      20              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      21              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      22              : /// to pass random malformed bytes through the connection).
      23              : pub struct PqStream<S> {
      24              :     pub(crate) framed: Framed<S>,
      25              : }
      26              : 
      27              : impl<S> PqStream<S> {
      28              :     /// Construct a new libpq protocol wrapper.
      29           25 :     pub fn new(stream: S) -> Self {
      30           25 :         Self {
      31           25 :             framed: Framed::new(stream),
      32           25 :         }
      33           25 :     }
      34              : 
      35              :     /// Extract the underlying stream and read buffer.
      36            0 :     pub fn into_inner(self) -> (S, BytesMut) {
      37            0 :         self.framed.into_inner()
      38            0 :     }
      39              : 
      40              :     /// Get a shared reference to the underlying stream.
      41           35 :     pub(crate) fn get_ref(&self) -> &S {
      42           35 :         self.framed.get_ref()
      43           35 :     }
      44              : }
      45              : 
      46            0 : fn err_connection() -> io::Error {
      47            0 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      48            0 : }
      49              : 
      50              : impl<S: AsyncRead + Unpin> PqStream<S> {
      51              :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      52           42 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      53           42 :         self.framed
      54           42 :             .read_startup_message()
      55            7 :             .await
      56           42 :             .map_err(ConnectionError::into_io_error)?
      57           42 :             .ok_or_else(err_connection)
      58           42 :     }
      59              : 
      60           26 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      61           26 :         self.framed
      62           26 :             .read_message()
      63           26 :             .await
      64           26 :             .map_err(ConnectionError::into_io_error)?
      65           25 :             .ok_or_else(err_connection)
      66           26 :     }
      67              : 
      68           26 :     pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      69           26 :         match self.read_message().await? {
      70           25 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      71            0 :             bad => Err(io::Error::new(
      72            0 :                 io::ErrorKind::InvalidData,
      73            0 :                 format!("unexpected message type: {bad:?}"),
      74            0 :             )),
      75              :         }
      76           26 :     }
      77              : }
      78              : 
      79              : #[derive(Debug)]
      80              : pub struct ReportedError {
      81              :     source: anyhow::Error,
      82              :     error_kind: ErrorKind,
      83              : }
      84              : 
      85              : impl std::fmt::Display for ReportedError {
      86            1 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      87            1 :         self.source.fmt(f)
      88            1 :     }
      89              : }
      90              : 
      91              : impl std::error::Error for ReportedError {
      92            0 :     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
      93            0 :         self.source.source()
      94            0 :     }
      95              : }
      96              : 
      97              : impl ReportableError for ReportedError {
      98            0 :     fn get_error_kind(&self) -> ErrorKind {
      99            0 :         self.error_kind
     100            0 :     }
     101              : }
     102              : 
     103              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     104              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
     105           77 :     pub(crate) fn write_message_noflush(
     106           77 :         &mut self,
     107           77 :         message: &BeMessage<'_>,
     108           77 :     ) -> io::Result<&mut Self> {
     109           77 :         self.framed
     110           77 :             .write_message(message)
     111           77 :             .map_err(ProtocolError::into_io_error)?;
     112           77 :         Ok(self)
     113           77 :     }
     114              : 
     115              :     /// Write the message into an internal buffer and flush it.
     116           60 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     117           60 :         self.write_message_noflush(message)?;
     118           60 :         self.flush().await?;
     119           60 :         Ok(self)
     120           60 :     }
     121              : 
     122              :     /// Flush the output buffer into the underlying stream.
     123           60 :     pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
     124           60 :         self.framed.flush().await?;
     125           60 :         Ok(self)
     126           60 :     }
     127              : 
     128              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     129              :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     130              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     131            1 :     pub async fn throw_error_str<T>(
     132            1 :         &mut self,
     133            1 :         msg: &'static str,
     134            1 :         error_kind: ErrorKind,
     135            1 :     ) -> Result<T, ReportedError> {
     136            1 :         // TODO: only log this for actually interesting errors
     137            1 :         tracing::info!(
     138            0 :             kind = error_kind.to_metric_label(),
     139            0 :             msg,
     140            0 :             "forwarding error to user"
     141              :         );
     142              : 
     143              :         // already error case, ignore client IO error
     144            1 :         self.write_message(&BeMessage::ErrorResponse(msg, None))
     145            0 :             .await
     146            1 :             .inspect_err(|e| debug!("write_message failed: {e}"))
     147            1 :             .ok();
     148            1 : 
     149            1 :         Err(ReportedError {
     150            1 :             source: anyhow::anyhow!(msg),
     151            1 :             error_kind,
     152            1 :         })
     153            1 :     }
     154              : 
     155              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     156              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     157            0 :     pub(crate) async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
     158            0 :     where
     159            0 :         E: UserFacingError + Into<anyhow::Error>,
     160            0 :     {
     161            0 :         let error_kind = error.get_error_kind();
     162            0 :         let msg = error.to_string_client();
     163            0 :         tracing::info!(
     164            0 :             kind=error_kind.to_metric_label(),
     165            0 :             error=%error,
     166            0 :             msg,
     167            0 :             "forwarding error to user"
     168              :         );
     169              : 
     170              :         // already error case, ignore client IO error
     171            0 :         self.write_message(&BeMessage::ErrorResponse(&msg, None))
     172            0 :             .await
     173            0 :             .inspect_err(|e| debug!("write_message failed: {e}"))
     174            0 :             .ok();
     175            0 : 
     176            0 :         Err(ReportedError {
     177            0 :             source: anyhow::anyhow!(error),
     178            0 :             error_kind,
     179            0 :         })
     180            0 :     }
     181              : }
     182              : 
     183              : /// Wrapper for upgrading raw streams into secure streams.
     184              : pub enum Stream<S> {
     185              :     /// We always begin with a raw stream,
     186              :     /// which may then be upgraded into a secure stream.
     187              :     Raw { raw: S },
     188              :     Tls {
     189              :         /// We box [`TlsStream`] since it can be quite large.
     190              :         tls: Box<TlsStream<S>>,
     191              :         /// Channel binding parameter
     192              :         tls_server_end_point: TlsServerEndPoint,
     193              :     },
     194              : }
     195              : 
     196              : impl<S: Unpin> Unpin for Stream<S> {}
     197              : 
     198              : impl<S> Stream<S> {
     199              :     /// Construct a new instance from a raw stream.
     200           25 :     pub fn from_raw(raw: S) -> Self {
     201           25 :         Self::Raw { raw }
     202           25 :     }
     203              : 
     204              :     /// Return SNI hostname when it's available.
     205            0 :     pub fn sni_hostname(&self) -> Option<&str> {
     206            0 :         match self {
     207            0 :             Stream::Raw { .. } => None,
     208            0 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     209              :         }
     210            0 :     }
     211              : 
     212           15 :     pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
     213           15 :         match self {
     214            3 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     215              :             Stream::Tls {
     216           12 :                 tls_server_end_point,
     217           12 :                 ..
     218           12 :             } => *tls_server_end_point,
     219              :         }
     220           15 :     }
     221              : }
     222              : 
     223            0 : #[derive(Debug, Error)]
     224              : #[error("Can't upgrade TLS stream")]
     225              : pub enum StreamUpgradeError {
     226              :     #[error("Bad state reached: can't upgrade TLS stream")]
     227              :     AlreadyTls,
     228              : 
     229              :     #[error("Can't upgrade stream: IO error: {0}")]
     230              :     Io(#[from] io::Error),
     231              : }
     232              : 
     233              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     234              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     235            0 :     pub async fn upgrade(
     236            0 :         self,
     237            0 :         cfg: Arc<ServerConfig>,
     238            0 :         record_handshake_error: bool,
     239            0 :     ) -> Result<TlsStream<S>, StreamUpgradeError> {
     240            0 :         match self {
     241            0 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
     242            0 :                 .accept(raw)
     243            0 :                 .await
     244            0 :                 .inspect_err(|_| {
     245            0 :                     if record_handshake_error {
     246            0 :                         Metrics::get().proxy.tls_handshake_failures.inc();
     247            0 :                     }
     248            0 :                 })?),
     249            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     250              :         }
     251            0 :     }
     252              : }
     253              : 
     254              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     255          155 :     fn poll_read(
     256          155 :         mut self: Pin<&mut Self>,
     257          155 :         context: &mut task::Context<'_>,
     258          155 :         buf: &mut ReadBuf<'_>,
     259          155 :     ) -> task::Poll<io::Result<()>> {
     260          155 :         match &mut *self {
     261           30 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     262          125 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     263              :         }
     264          155 :     }
     265              : }
     266              : 
     267              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     268           76 :     fn poll_write(
     269           76 :         mut self: Pin<&mut Self>,
     270           76 :         context: &mut task::Context<'_>,
     271           76 :         buf: &[u8],
     272           76 :     ) -> task::Poll<io::Result<usize>> {
     273           76 :         match &mut *self {
     274           27 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     275           49 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     276              :         }
     277           76 :     }
     278              : 
     279           76 :     fn poll_flush(
     280           76 :         mut self: Pin<&mut Self>,
     281           76 :         context: &mut task::Context<'_>,
     282           76 :     ) -> task::Poll<io::Result<()>> {
     283           76 :         match &mut *self {
     284           27 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     285           49 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     286              :         }
     287           76 :     }
     288              : 
     289            0 :     fn poll_shutdown(
     290            0 :         mut self: Pin<&mut Self>,
     291            0 :         context: &mut task::Context<'_>,
     292            0 :     ) -> task::Poll<io::Result<()>> {
     293            0 :         match &mut *self {
     294            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     295            0 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     296              :         }
     297            0 :     }
     298              : }
        

Generated by: LCOV version 2.1-beta