LCOV - differential code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 91.0 % 122 111 11 111
Current Date: 2024-01-09 02:06:09 Functions: 51.5 % 167 86 81 86
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use crate::config::TlsServerEndPoint;
       2                 : use crate::error::UserFacingError;
       3                 : use anyhow::bail;
       4                 : use bytes::BytesMut;
       5                 : 
       6                 : use pq_proto::framed::{ConnectionError, Framed};
       7                 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       8                 : use rustls::ServerConfig;
       9                 : use std::pin::Pin;
      10                 : use std::sync::Arc;
      11                 : use std::{io, task};
      12                 : use thiserror::Error;
      13                 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      14                 : use tokio_rustls::server::TlsStream;
      15                 : 
      16                 : /// Stream wrapper which implements libpq's protocol.
      17                 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      18                 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      19                 : /// to pass random malformed bytes through the connection).
      20                 : pub struct PqStream<S> {
      21                 :     pub(crate) framed: Framed<S>,
      22                 : }
      23                 : 
      24                 : impl<S> PqStream<S> {
      25                 :     /// Construct a new libpq protocol wrapper.
      26 CBC         153 :     pub fn new(stream: S) -> Self {
      27             153 :         Self {
      28             153 :             framed: Framed::new(stream),
      29             153 :         }
      30             153 :     }
      31                 : 
      32                 :     /// Extract the underlying stream and read buffer.
      33             108 :     pub fn into_inner(self) -> (S, BytesMut) {
      34             108 :         self.framed.into_inner()
      35             108 :     }
      36                 : 
      37                 :     /// Get a shared reference to the underlying stream.
      38             169 :     pub fn get_ref(&self) -> &S {
      39             169 :         self.framed.get_ref()
      40             169 :     }
      41                 : }
      42                 : 
      43               1 : fn err_connection() -> io::Error {
      44               1 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      45               1 : }
      46                 : 
      47                 : impl<S: AsyncRead + Unpin> PqStream<S> {
      48                 :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      49             153 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      50             153 :         self.framed
      51             153 :             .read_startup_message()
      52               9 :             .await
      53             153 :             .map_err(ConnectionError::into_io_error)?
      54             153 :             .ok_or_else(err_connection)
      55             153 :     }
      56                 : 
      57              97 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      58              97 :         self.framed
      59              97 :             .read_message()
      60              93 :             .await
      61              97 :             .map_err(ConnectionError::into_io_error)?
      62              96 :             .ok_or_else(err_connection)
      63              97 :     }
      64                 : 
      65              97 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      66              97 :         match self.read_message().await? {
      67              96 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      68 UBC           0 :             bad => Err(io::Error::new(
      69               0 :                 io::ErrorKind::InvalidData,
      70               0 :                 format!("unexpected message type: {:?}", bad),
      71               0 :             )),
      72                 :         }
      73 CBC          97 :     }
      74                 : }
      75                 : 
      76                 : impl<S: AsyncWrite + Unpin> PqStream<S> {
      77                 :     /// Write the message into an internal buffer, but don't flush the underlying stream.
      78             866 :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      79             866 :         self.framed
      80             866 :             .write_message(message)
      81             866 :             .map_err(ProtocolError::into_io_error)?;
      82             866 :         Ok(self)
      83             866 :     }
      84                 : 
      85                 :     /// Write the message into an internal buffer and flush it.
      86             276 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      87             276 :         self.write_message_noflush(message)?;
      88             276 :         self.flush().await?;
      89             276 :         Ok(self)
      90             276 :     }
      91                 : 
      92                 :     /// Flush the output buffer into the underlying stream.
      93             276 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
      94             276 :         self.framed.flush().await?;
      95             276 :         Ok(self)
      96             276 :     }
      97                 : 
      98                 :     /// Write the error message using [`Self::write_message`], then re-throw it.
      99                 :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     100                 :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     101              12 :     pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
     102              11 :         tracing::info!("forwarding error to user: {error}");
     103              12 :         self.write_message(&BeMessage::ErrorResponse(error, None))
     104 UBC           0 :             .await?;
     105 CBC          12 :         bail!(error)
     106              12 :     }
     107                 : 
     108                 :     /// Write the error message using [`Self::write_message`], then re-throw it.
     109                 :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     110              11 :     pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
     111              11 :     where
     112              11 :         E: UserFacingError + Into<anyhow::Error>,
     113              11 :     {
     114              11 :         let msg = error.to_string_client();
     115              11 :         tracing::info!("forwarding error to user: {msg}");
     116              11 :         self.write_message(&BeMessage::ErrorResponse(&msg, None))
     117 UBC           0 :             .await?;
     118 CBC          11 :         bail!(error)
     119              11 :     }
     120                 : }
     121                 : 
     122                 : /// Wrapper for upgrading raw streams into secure streams.
     123                 : pub enum Stream<S> {
     124                 :     /// We always begin with a raw stream,
     125                 :     /// which may then be upgraded into a secure stream.
     126                 :     Raw { raw: S },
     127                 :     Tls {
     128                 :         /// We box [`TlsStream`] since it can be quite large.
     129                 :         tls: Box<TlsStream<S>>,
     130                 :         /// Channel binding parameter
     131                 :         tls_server_end_point: TlsServerEndPoint,
     132                 :     },
     133                 : }
     134                 : 
     135                 : impl<S: Unpin> Unpin for Stream<S> {}
     136                 : 
     137                 : impl<S> Stream<S> {
     138                 :     /// Construct a new instance from a raw stream.
     139              84 :     pub fn from_raw(raw: S) -> Self {
     140              84 :         Self::Raw { raw }
     141              84 :     }
     142                 : 
     143                 :     /// Return SNI hostname when it's available.
     144              50 :     pub fn sni_hostname(&self) -> Option<&str> {
     145              50 :         match self {
     146 UBC           0 :             Stream::Raw { .. } => None,
     147 CBC          50 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     148                 :         }
     149              50 :     }
     150                 : 
     151              51 :     pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
     152              51 :         match self {
     153 UBC           0 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     154                 :             Stream::Tls {
     155 CBC          51 :                 tls_server_end_point,
     156              51 :                 ..
     157              51 :             } => *tls_server_end_point,
     158                 :         }
     159              51 :     }
     160                 : }
     161                 : 
     162 UBC           0 : #[derive(Debug, Error)]
     163                 : #[error("Can't upgrade TLS stream")]
     164                 : pub enum StreamUpgradeError {
     165                 :     #[error("Bad state reached: can't upgrade TLS stream")]
     166                 :     AlreadyTls,
     167                 : 
     168                 :     #[error("Can't upgrade stream: IO error: {0}")]
     169                 :     Io(#[from] io::Error),
     170                 : }
     171                 : 
     172                 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     173                 :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     174 CBC          70 :     pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
     175              70 :         match self {
     176             147 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
     177 UBC           0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     178                 :         }
     179 CBC          70 :     }
     180                 : }
     181                 : 
     182                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     183             638 :     fn poll_read(
     184             638 :         mut self: Pin<&mut Self>,
     185             638 :         context: &mut task::Context<'_>,
     186             638 :         buf: &mut ReadBuf<'_>,
     187             638 :     ) -> task::Poll<io::Result<()>> {
     188             638 :         match &mut *self {
     189              84 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     190             554 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     191                 :         }
     192             638 :     }
     193                 : }
     194                 : 
     195                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     196             331 :     fn poll_write(
     197             331 :         mut self: Pin<&mut Self>,
     198             331 :         context: &mut task::Context<'_>,
     199             331 :         buf: &[u8],
     200             331 :     ) -> task::Poll<io::Result<usize>> {
     201             331 :         match &mut *self {
     202              83 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     203             248 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     204                 :         }
     205             331 :     }
     206                 : 
     207             370 :     fn poll_flush(
     208             370 :         mut self: Pin<&mut Self>,
     209             370 :         context: &mut task::Context<'_>,
     210             370 :     ) -> task::Poll<io::Result<()>> {
     211             370 :         match &mut *self {
     212              83 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     213             287 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     214                 :         }
     215             370 :     }
     216                 : 
     217              39 :     fn poll_shutdown(
     218              39 :         mut self: Pin<&mut Self>,
     219              39 :         context: &mut task::Context<'_>,
     220              39 :     ) -> task::Poll<io::Result<()>> {
     221              39 :         match &mut *self {
     222 UBC           0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     223 CBC          39 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     224                 :         }
     225              39 :     }
     226                 : }
        

Generated by: LCOV version 2.1-beta