LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 91.0 % 122 111
Test Date: 2024-02-07 07:37:29 Functions: 51.5 % 167 86

            Line data    Source code
       1              : use crate::config::TlsServerEndPoint;
       2              : use crate::error::UserFacingError;
       3              : use anyhow::bail;
       4              : use bytes::BytesMut;
       5              : 
       6              : use pq_proto::framed::{ConnectionError, Framed};
       7              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       8              : use rustls::ServerConfig;
       9              : use std::pin::Pin;
      10              : use std::sync::Arc;
      11              : use std::{io, task};
      12              : use thiserror::Error;
      13              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      14              : use tokio_rustls::server::TlsStream;
      15              : 
/// Stream wrapper which implements libpq's protocol.
///
/// All traffic goes through a [`Framed`] codec, so callers exchange whole
/// frontend ([`FeMessage`]) / backend ([`BeMessage`]) messages instead of raw bytes.
///
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
    /// Message framer wrapping the underlying stream `S` (crate-visible, not public).
    pub(crate) framed: Framed<S>,
}
      23              : 
      24              : impl<S> PqStream<S> {
      25              :     /// Construct a new libpq protocol wrapper.
      26          197 :     pub fn new(stream: S) -> Self {
      27          197 :         Self {
      28          197 :             framed: Framed::new(stream),
      29          197 :         }
      30          197 :     }
      31              : 
      32              :     /// Extract the underlying stream and read buffer.
      33          130 :     pub fn into_inner(self) -> (S, BytesMut) {
      34          130 :         self.framed.into_inner()
      35          130 :     }
      36              : 
      37              :     /// Get a shared reference to the underlying stream.
      38          204 :     pub fn get_ref(&self) -> &S {
      39          204 :         self.framed.get_ref()
      40          204 :     }
      41              : }
      42              : 
      43            1 : fn err_connection() -> io::Error {
      44            1 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      45            1 : }
      46              : 
      47              : impl<S: AsyncRead + Unpin> PqStream<S> {
      48              :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      49          197 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      50          197 :         self.framed
      51          197 :             .read_startup_message()
      52           19 :             .await
      53          197 :             .map_err(ConnectionError::into_io_error)?
      54          197 :             .ok_or_else(err_connection)
      55          197 :     }
      56              : 
      57          121 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      58          121 :         self.framed
      59          121 :             .read_message()
      60          121 :             .await
      61          121 :             .map_err(ConnectionError::into_io_error)?
      62          119 :             .ok_or_else(err_connection)
      63          121 :     }
      64              : 
      65          121 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      66          121 :         match self.read_message().await? {
      67          119 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      68            0 :             bad => Err(io::Error::new(
      69            0 :                 io::ErrorKind::InvalidData,
      70            0 :                 format!("unexpected message type: {:?}", bad),
      71            0 :             )),
      72              :         }
      73          121 :     }
      74              : }
      75              : 
      76              : impl<S: AsyncWrite + Unpin> PqStream<S> {
      77              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
      78          955 :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      79          955 :         self.framed
      80          955 :             .write_message(message)
      81          955 :             .map_err(ProtocolError::into_io_error)?;
      82          955 :         Ok(self)
      83          955 :     }
      84              : 
      85              :     /// Write the message into an internal buffer and flush it.
      86          336 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      87          336 :         self.write_message_noflush(message)?;
      88          336 :         self.flush().await?;
      89          336 :         Ok(self)
      90          336 :     }
      91              : 
      92              :     /// Flush the output buffer into the underlying stream.
      93          336 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
      94          336 :         self.framed.flush().await?;
      95          336 :         Ok(self)
      96          336 :     }
      97              : 
      98              :     /// Write the error message using [`Self::write_message`], then re-throw it.
      99              :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     100              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     101           13 :     pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
     102           11 :         tracing::info!("forwarding error to user: {error}");
     103           13 :         self.write_message(&BeMessage::ErrorResponse(error, None))
     104            0 :             .await?;
     105           13 :         bail!(error)
     106           13 :     }
     107              : 
     108              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     109              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     110           11 :     pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
     111           11 :     where
     112           11 :         E: UserFacingError + Into<anyhow::Error>,
     113           11 :     {
     114           11 :         let msg = error.to_string_client();
     115           11 :         tracing::info!("forwarding error to user: {msg}");
     116           11 :         self.write_message(&BeMessage::ErrorResponse(&msg, None))
     117            0 :             .await?;
     118           11 :         bail!(error)
     119           11 :     }
     120              : }
     121              : 
/// Wrapper for upgrading raw streams into secure streams.
///
/// The active variant records whether the connection is still plaintext or TLS.
pub enum Stream<S> {
    /// We always begin with a raw stream,
    /// which may then be upgraded into a secure stream.
    Raw { raw: S },
    /// A stream carried over TLS.
    Tls {
        /// We box [`TlsStream`] since it can be quite large.
        tls: Box<TlsStream<S>>,
        /// Channel binding parameter, returned by `Stream::tls_server_end_point`.
        tls_server_end_point: TlsServerEndPoint,
    },
}
     134              : 
// `Stream<S>` is `Unpin` whenever `S` is. The `AsyncRead`/`AsyncWrite` impls for
// `Stream` rely on this to get `&mut Self` back out of `Pin<&mut Self>`.
// NOTE(review): this looks equivalent to the compiler's auto-trait impl (`Box` is
// always `Unpin`), provided `TlsServerEndPoint: Unpin`. Confirm before removing.
impl<S: Unpin> Unpin for Stream<S> {}
     136              : 
     137              : impl<S> Stream<S> {
     138              :     /// Construct a new instance from a raw stream.
     139          107 :     pub fn from_raw(raw: S) -> Self {
     140          107 :         Self::Raw { raw }
     141          107 :     }
     142              : 
     143              :     /// Return SNI hostname when it's available.
     144           51 :     pub fn sni_hostname(&self) -> Option<&str> {
     145           51 :         match self {
     146            0 :             Stream::Raw { .. } => None,
     147           51 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     148              :         }
     149           51 :     }
     150              : 
     151           64 :     pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
     152           64 :         match self {
     153            0 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     154              :             Stream::Tls {
     155           64 :                 tls_server_end_point,
     156           64 :                 ..
     157           64 :             } => *tls_server_end_point,
     158              :         }
     159           64 :     }
     160              : }
     161              : 
/// Error returned by [`Stream::upgrade`].
// NOTE(review): every variant has its own `#[error]`, so the enum-level message
// below looks unused. Confirm against thiserror's semantics before relying on it.
#[derive(Debug, Error)]
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
    /// `upgrade` was called on a stream that is already `Stream::Tls`.
    #[error("Bad state reached: can't upgrade TLS stream")]
    AlreadyTls,

    /// An I/O error, e.g. from the TLS accept performed in `Stream::upgrade`.
    // NOTE(review): `{0}` in the message plus `#[from]` (which also exposes the
    // error as `source()`) means error-chain reporters may print it twice.
    #[error("Can't upgrade stream: IO error: {0}")]
    Io(#[from] io::Error),
}
     171              : 
     172              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     173              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     174           91 :     pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
     175           91 :         match self {
     176          195 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
     177            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     178              :         }
     179           91 :     }
     180              : }
     181              : 
     182              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     183          806 :     fn poll_read(
     184          806 :         mut self: Pin<&mut Self>,
     185          806 :         context: &mut task::Context<'_>,
     186          806 :         buf: &mut ReadBuf<'_>,
     187          806 :     ) -> task::Poll<io::Result<()>> {
     188          806 :         match &mut *self {
     189          107 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     190          699 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     191              :         }
     192          806 :     }
     193              : }
     194              : 
     195              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     196          408 :     fn poll_write(
     197          408 :         mut self: Pin<&mut Self>,
     198          408 :         context: &mut task::Context<'_>,
     199          408 :         buf: &[u8],
     200          408 :     ) -> task::Poll<io::Result<usize>> {
     201          408 :         match &mut *self {
     202          106 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     203          302 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     204              :         }
     205          408 :     }
     206              : 
     207          448 :     fn poll_flush(
     208          448 :         mut self: Pin<&mut Self>,
     209          448 :         context: &mut task::Context<'_>,
     210          448 :     ) -> task::Poll<io::Result<()>> {
     211          448 :         match &mut *self {
     212          106 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     213          342 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     214              :         }
     215          448 :     }
     216              : 
     217           40 :     fn poll_shutdown(
     218           40 :         mut self: Pin<&mut Self>,
     219           40 :         context: &mut task::Context<'_>,
     220           40 :     ) -> task::Poll<io::Result<()>> {
     221           40 :         match &mut *self {
     222            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     223           40 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     224              :         }
     225           40 :     }
     226              : }
        

Generated by: LCOV version 2.1-beta