LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 92.2 % 154 142
Test Date: 2024-02-12 20:26:03 Functions: 52.0 % 171 89

            Line data    Source code
       1              : use crate::config::TlsServerEndPoint;
       2              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
       3              : use bytes::BytesMut;
       4              : 
       5              : use pq_proto::framed::{ConnectionError, Framed};
       6              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       7              : use rustls::ServerConfig;
       8              : use std::pin::Pin;
       9              : use std::sync::Arc;
      10              : use std::{io, task};
      11              : use thiserror::Error;
      12              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      13              : use tokio_rustls::server::TlsStream;
      14              : 
      15              : /// Stream wrapper which implements libpq's protocol.
      16              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      17              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      18              : /// to pass random malformed bytes through the connection).
      19              : pub struct PqStream<S> {
      20              :     pub(crate) framed: Framed<S>,
      21              : }
      22              : 
      23              : impl<S> PqStream<S> {
      24              :     /// Construct a new libpq protocol wrapper.
      25          201 :     pub fn new(stream: S) -> Self {
      26          201 :         Self {
      27          201 :             framed: Framed::new(stream),
      28          201 :         }
      29          201 :     }
      30              : 
      31              :     /// Extract the underlying stream and read buffer.
      32          134 :     pub fn into_inner(self) -> (S, BytesMut) {
      33          134 :         self.framed.into_inner()
      34          134 :     }
      35              : 
      36              :     /// Get a shared reference to the underlying stream.
      37          210 :     pub fn get_ref(&self) -> &S {
      38          210 :         self.framed.get_ref()
      39          210 :     }
      40              : }
      41              : 
      42            1 : fn err_connection() -> io::Error {
      43            1 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      44            1 : }
      45              : 
      46              : impl<S: AsyncRead + Unpin> PqStream<S> {
      47              :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      48          201 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      49          201 :         self.framed
      50          201 :             .read_startup_message()
      51           18 :             .await
      52          201 :             .map_err(ConnectionError::into_io_error)?
      53          201 :             .ok_or_else(err_connection)
      54          201 :     }
      55              : 
      56          125 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      57          125 :         self.framed
      58          125 :             .read_message()
      59          125 :             .await
      60          125 :             .map_err(ConnectionError::into_io_error)?
      61          123 :             .ok_or_else(err_connection)
      62          125 :     }
      63              : 
      64          125 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      65          125 :         match self.read_message().await? {
      66          123 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      67            0 :             bad => Err(io::Error::new(
      68            0 :                 io::ErrorKind::InvalidData,
      69            0 :                 format!("unexpected message type: {:?}", bad),
      70            0 :             )),
      71              :         }
      72          125 :     }
      73              : }
      74              : 
      75            0 : #[derive(Debug)]
      76              : pub struct ReportedError {
      77              :     source: anyhow::Error,
      78              :     error_kind: ErrorKind,
      79              : }
      80              : 
      81              : impl std::fmt::Display for ReportedError {
      82          112 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      83          112 :         self.source.fmt(f)
      84          112 :     }
      85              : }
      86              : 
      87              : impl std::error::Error for ReportedError {
      88           44 :     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
      89           44 :         self.source.source()
      90           44 :     }
      91              : }
      92              : 
      93              : impl ReportableError for ReportedError {
      94           22 :     fn get_error_kind(&self) -> ErrorKind {
      95           22 :         self.error_kind
      96           22 :     }
      97              : }
      98              : 
      99              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     100              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
     101          995 :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     102          995 :         self.framed
     103          995 :             .write_message(message)
     104          995 :             .map_err(ProtocolError::into_io_error)?;
     105          995 :         Ok(self)
     106          995 :     }
     107              : 
     108              :     /// Write the message into an internal buffer and flush it.
     109          346 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     110          346 :         self.write_message_noflush(message)?;
     111          346 :         self.flush().await?;
     112          346 :         Ok(self)
     113          346 :     }
     114              : 
     115              :     /// Flush the output buffer into the underlying stream.
     116          346 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
     117          346 :         self.framed.flush().await?;
     118          346 :         Ok(self)
     119          346 :     }
     120              : 
     121              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     122              :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     123              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     124           13 :     pub async fn throw_error_str<T>(
     125           13 :         &mut self,
     126           13 :         msg: &'static str,
     127           13 :         error_kind: ErrorKind,
     128           13 :     ) -> Result<T, ReportedError> {
     129           11 :         tracing::info!(
     130           11 :             kind = error_kind.to_metric_label(),
     131           11 :             msg,
     132           11 :             "forwarding error to user"
     133           11 :         );
     134              : 
     135              :         // already error case, ignore client IO error
     136           13 :         let _: Result<_, std::io::Error> = self
     137           13 :             .write_message(&BeMessage::ErrorResponse(msg, None))
     138            0 :             .await;
     139              : 
     140           13 :         Err(ReportedError {
     141           13 :             source: anyhow::anyhow!(msg),
     142           13 :             error_kind,
     143           13 :         })
     144           13 :     }
     145              : 
     146              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     147              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     148           11 :     pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
     149           11 :     where
     150           11 :         E: UserFacingError + Into<anyhow::Error>,
     151           11 :     {
     152           11 :         let error_kind = error.get_error_kind();
     153           11 :         let msg = error.to_string_client();
     154           11 :         tracing::info!(
     155           11 :             kind=error_kind.to_metric_label(),
     156           11 :             error=%error,
     157           11 :             msg,
     158           11 :             "forwarding error to user"
     159           11 :         );
     160              : 
     161              :         // already error case, ignore client IO error
     162           11 :         let _: Result<_, std::io::Error> = self
     163           11 :             .write_message(&BeMessage::ErrorResponse(&msg, None))
     164            0 :             .await;
     165              : 
     166           11 :         Err(ReportedError {
     167           11 :             source: anyhow::anyhow!(error),
     168           11 :             error_kind,
     169           11 :         })
     170           11 :     }
     171              : }
     172              : 
     173              : /// Wrapper for upgrading raw streams into secure streams.
     174              : pub enum Stream<S> {
     175              :     /// We always begin with a raw stream,
     176              :     /// which may then be upgraded into a secure stream.
     177              :     Raw { raw: S },
     178              :     Tls {
     179              :         /// We box [`TlsStream`] since it can be quite large.
     180              :         tls: Box<TlsStream<S>>,
     181              :         /// Channel binding parameter
     182              :         tls_server_end_point: TlsServerEndPoint,
     183              :     },
     184              : }
     185              : 
     186              : impl<S: Unpin> Unpin for Stream<S> {}
     187              : 
     188              : impl<S> Stream<S> {
     189              :     /// Construct a new instance from a raw stream.
     190          109 :     pub fn from_raw(raw: S) -> Self {
     191          109 :         Self::Raw { raw }
     192          109 :     }
     193              : 
     194              :     /// Return SNI hostname when it's available.
     195           53 :     pub fn sni_hostname(&self) -> Option<&str> {
     196           53 :         match self {
     197            0 :             Stream::Raw { .. } => None,
     198           53 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     199              :         }
     200           53 :     }
     201              : 
     202           66 :     pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
     203           66 :         match self {
     204            0 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     205              :             Stream::Tls {
     206           66 :                 tls_server_end_point,
     207           66 :                 ..
     208           66 :             } => *tls_server_end_point,
     209              :         }
     210           66 :     }
     211              : }
     212              : 
     213            0 : #[derive(Debug, Error)]
     214              : #[error("Can't upgrade TLS stream")]
     215              : pub enum StreamUpgradeError {
     216              :     #[error("Bad state reached: can't upgrade TLS stream")]
     217              :     AlreadyTls,
     218              : 
     219              :     #[error("Can't upgrade stream: IO error: {0}")]
     220              :     Io(#[from] io::Error),
     221              : }
     222              : 
     223              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     224              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     225           93 :     pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
     226           93 :         match self {
     227          200 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
     228            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     229              :         }
     230           93 :     }
     231              : }
     232              : 
     233              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     234          821 :     fn poll_read(
     235          821 :         mut self: Pin<&mut Self>,
     236          821 :         context: &mut task::Context<'_>,
     237          821 :         buf: &mut ReadBuf<'_>,
     238          821 :     ) -> task::Poll<io::Result<()>> {
     239          821 :         match &mut *self {
     240          109 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     241          712 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     242              :         }
     243          821 :     }
     244              : }
     245              : 
     246              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     247          420 :     fn poll_write(
     248          420 :         mut self: Pin<&mut Self>,
     249          420 :         context: &mut task::Context<'_>,
     250          420 :         buf: &[u8],
     251          420 :     ) -> task::Poll<io::Result<usize>> {
     252          420 :         match &mut *self {
     253          108 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     254          312 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     255              :         }
     256          420 :     }
     257              : 
     258          420 :     fn poll_flush(
     259          420 :         mut self: Pin<&mut Self>,
     260          420 :         context: &mut task::Context<'_>,
     261          420 :     ) -> task::Poll<io::Result<()>> {
     262          420 :         match &mut *self {
     263          108 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     264          312 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     265              :         }
     266          420 :     }
     267              : 
     268           42 :     fn poll_shutdown(
     269           42 :         mut self: Pin<&mut Self>,
     270           42 :         context: &mut task::Context<'_>,
     271           42 :     ) -> task::Poll<io::Result<()>> {
     272           42 :         match &mut *self {
     273            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     274           42 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     275              :         }
     276           42 :     }
     277              : }
        

Generated by: LCOV version 2.1-beta