LCOV - differential code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 91.7 % 120 110 10 110
Current Date: 2023-10-19 02:04:12 Functions: 46.5 % 144 67 77 67
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use crate::error::UserFacingError;
       2                 : use anyhow::bail;
       3                 : use bytes::BytesMut;
       4                 : use pin_project_lite::pin_project;
       5                 : use pq_proto::framed::{ConnectionError, Framed};
       6                 : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       7                 : use rustls::ServerConfig;
       8                 : use std::pin::Pin;
       9                 : use std::sync::Arc;
      10                 : use std::{io, task};
      11                 : use thiserror::Error;
      12                 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      13                 : use tokio_rustls::server::TlsStream;
      14                 : 
      15                 : /// Stream wrapper which implements libpq's protocol.
      16                 : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      17                 : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      18                 : /// to pass random malformed bytes through the connection).
      19                 : pub struct PqStream<S> {
      20                 :     framed: Framed<S>,
      21                 : }
      22                 : 
      23                 : impl<S> PqStream<S> {
      24                 :     /// Construct a new libpq protocol wrapper.
      25 CBC          84 :     pub fn new(stream: S) -> Self {
      26              84 :         Self {
      27              84 :             framed: Framed::new(stream),
      28              84 :         }
      29              84 :     }
      30                 : 
      31                 :     /// Extract the underlying stream and read buffer.
      32              68 :     pub fn into_inner(self) -> (S, BytesMut) {
      33              68 :         self.framed.into_inner()
      34              68 :     }
      35                 : 
      36                 :     /// Get a shared reference to the underlying stream.
      37              71 :     pub fn get_ref(&self) -> &S {
      38              71 :         self.framed.get_ref()
      39              71 :     }
      40                 : }
      41                 : 
      42               1 : fn err_connection() -> io::Error {
      43               1 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      44               1 : }
      45                 : 
      46                 : impl<S: AsyncRead + Unpin> PqStream<S> {
      47                 :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      48              84 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      49              84 :         self.framed
      50              84 :             .read_startup_message()
      51              24 :             .await
      52              84 :             .map_err(ConnectionError::into_io_error)?
      53              84 :             .ok_or_else(err_connection)
      54              84 :     }
      55                 : 
      56              65 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      57              65 :         self.framed
      58              65 :             .read_message()
      59              65 :             .await
      60              65 :             .map_err(ConnectionError::into_io_error)?
      61              65 :             .ok_or_else(err_connection)
      62              65 :     }
      63                 : 
      64              65 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      65              65 :         match self.read_message().await? {
      66              65 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      67 UBC           0 :             bad => Err(io::Error::new(
      68               0 :                 io::ErrorKind::InvalidData,
      69               0 :                 format!("unexpected message type: {:?}", bad),
      70               0 :             )),
      71                 :         }
      72 CBC          65 :     }
      73                 : }
      74                 : 
      75                 : impl<S: AsyncWrite + Unpin> PqStream<S> {
      76                 :     /// Write the message into an internal buffer, but don't flush the underlying stream.
      77                 :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      78             628 :         self.framed
      79             628 :             .write_message(message)
      80             628 :             .map_err(ProtocolError::into_io_error)?;
      81             628 :         Ok(self)
      82             628 :     }
      83                 : 
      84                 :     /// Write the message into an internal buffer and flush it.
      85             177 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      86             177 :         self.write_message_noflush(message)?;
      87             177 :         self.flush().await?;
      88             177 :         Ok(self)
      89             177 :     }
      90                 : 
      91                 :     /// Flush the output buffer into the underlying stream.
      92             177 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
      93             177 :         self.framed.flush().await?;
      94             177 :         Ok(self)
      95             177 :     }
      96                 : 
      97                 :     /// Write the error message using [`Self::write_message`], then re-throw it.
      98                 :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
      99                 :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     100               5 :     pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
     101               4 :         tracing::info!("forwarding error to user: {error}");
     102               5 :         self.write_message(&BeMessage::ErrorResponse(error, None))
     103 UBC           0 :             .await?;
     104 CBC           5 :         bail!(error)
     105               5 :     }
     106                 : 
     107                 :     /// Write the error message using [`Self::write_message`], then re-throw it.
     108                 :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     109               4 :     pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
     110               4 :     where
     111               4 :         E: UserFacingError + Into<anyhow::Error>,
     112               4 :     {
     113               4 :         let msg = error.to_string_client();
     114               4 :         tracing::info!("forwarding error to user: {msg}");
     115               4 :         self.write_message(&BeMessage::ErrorResponse(&msg, None))
     116 UBC           0 :             .await?;
     117 CBC           4 :         bail!(error)
     118               4 :     }
     119                 : }
     120                 : 
     121                 : pin_project! {
     122                 :     /// Wrapper for upgrading raw streams into secure streams.
     123                 :     /// NOTE: it should be possible to decompose this object as necessary.
     124                 :     #[project = StreamProj]
     125                 :     pub enum Stream<S> {
     126                 :         /// We always begin with a raw stream,
     127                 :         /// which may then be upgraded into a secure stream.
     128                 :         Raw { #[pin] raw: S },
     129                 :         /// We box [`TlsStream`] since it can be quite large.
     130                 :         Tls { #[pin] tls: Box<TlsStream<S>> },
     131                 :     }
     132                 : }
     133                 : 
     134                 : impl<S> Stream<S> {
     135                 :     /// Construct a new instance from a raw stream.
     136              46 :     pub fn from_raw(raw: S) -> Self {
     137              46 :         Self::Raw { raw }
     138              46 :     }
     139                 : 
     140                 :     /// Return SNI hostname when it's available.
     141              34 :     pub fn sni_hostname(&self) -> Option<&str> {
     142              34 :         match self {
     143 UBC           0 :             Stream::Raw { .. } => None,
     144 CBC          34 :             Stream::Tls { tls } => tls.get_ref().1.server_name(),
     145                 :         }
     146              34 :     }
     147                 : }
     148                 : 
     149 UBC           0 : #[derive(Debug, Error)]
     150                 : #[error("Can't upgrade TLS stream")]
     151                 : pub enum StreamUpgradeError {
     152                 :     #[error("Bad state reached: can't upgrade TLS stream")]
     153                 :     AlreadyTls,
     154                 : 
     155                 :     #[error("Can't upgrade stream: IO error: {0}")]
     156                 :     Io(#[from] io::Error),
     157                 : }
     158                 : 
     159                 : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     160                 :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     161 CBC          39 :     pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
     162              39 :         match self {
     163              39 :             Stream::Raw { raw } => {
     164              78 :                 let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
     165              39 :                 Ok(Stream::Tls { tls })
     166                 :             }
     167 UBC           0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     168                 :         }
     169 CBC          39 :     }
     170                 : }
     171                 : 
     172                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     173             418 :     fn poll_read(
     174             418 :         self: Pin<&mut Self>,
     175             418 :         context: &mut task::Context<'_>,
     176             418 :         buf: &mut ReadBuf<'_>,
     177             418 :     ) -> task::Poll<io::Result<()>> {
     178             418 :         use StreamProj::*;
     179             418 :         match self.project() {
     180              48 :             Raw { raw } => raw.poll_read(context, buf),
     181             370 :             Tls { tls } => tls.poll_read(context, buf),
     182                 :         }
     183             418 :     }
     184                 : }
     185                 : 
     186                 : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     187             207 :     fn poll_write(
     188             207 :         self: Pin<&mut Self>,
     189             207 :         context: &mut task::Context<'_>,
     190             207 :         buf: &[u8],
     191             207 :     ) -> task::Poll<io::Result<usize>> {
     192             207 :         use StreamProj::*;
     193             207 :         match self.project() {
     194              45 :             Raw { raw } => raw.poll_write(context, buf),
     195             162 :             Tls { tls } => tls.poll_write(context, buf),
     196                 :         }
     197             207 :     }
     198                 : 
     199             237 :     fn poll_flush(
     200             237 :         self: Pin<&mut Self>,
     201             237 :         context: &mut task::Context<'_>,
     202             237 :     ) -> task::Poll<io::Result<()>> {
     203             237 :         use StreamProj::*;
     204             237 :         match self.project() {
     205              45 :             Raw { raw } => raw.poll_flush(context),
     206             192 :             Tls { tls } => tls.poll_flush(context),
     207                 :         }
     208             237 :     }
     209                 : 
     210              30 :     fn poll_shutdown(
     211              30 :         self: Pin<&mut Self>,
     212              30 :         context: &mut task::Context<'_>,
     213              30 :     ) -> task::Poll<io::Result<()>> {
     214              30 :         use StreamProj::*;
     215              30 :         match self.project() {
     216 UBC           0 :             Raw { raw } => raw.poll_shutdown(context),
     217 CBC          30 :             Tls { tls } => tls.poll_shutdown(context),
     218                 :         }
     219              30 :     }
     220                 : }
        

Generated by: LCOV version 2.1-beta