LCOV - code coverage report
Current view: top level - proxy/src - stream.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 91.7 % 120 110
Test Date: 2023-09-06 10:18:01 Functions: 46.5 % 144 67

            Line data    Source code
       1              : use crate::error::UserFacingError;
       2              : use anyhow::bail;
       3              : use bytes::BytesMut;
       4              : use pin_project_lite::pin_project;
       5              : use pq_proto::framed::{ConnectionError, Framed};
       6              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       7              : use rustls::ServerConfig;
       8              : use std::pin::Pin;
       9              : use std::sync::Arc;
      10              : use std::{io, task};
      11              : use thiserror::Error;
      12              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      13              : use tokio_rustls::server::TlsStream;
      14              : 
      15              : /// Stream wrapper which implements libpq's protocol.
      16              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      17              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      18              : /// to pass random malformed bytes through the connection).
      19              : pub struct PqStream<S> {
      20              :     framed: Framed<S>,
      21              : }
      22              : 
      23              : impl<S> PqStream<S> {
      24              :     /// Construct a new libpq protocol wrapper.
      25           80 :     pub fn new(stream: S) -> Self {
      26           80 :         Self {
      27           80 :             framed: Framed::new(stream),
      28           80 :         }
      29           80 :     }
      30              : 
      31              :     /// Extract the underlying stream and read buffer.
      32           64 :     pub fn into_inner(self) -> (S, BytesMut) {
      33           64 :         self.framed.into_inner()
      34           64 :     }
      35              : 
      36              :     /// Get a shared reference to the underlying stream.
      37           67 :     pub fn get_ref(&self) -> &S {
      38           67 :         self.framed.get_ref()
      39           67 :     }
      40              : }
      41              : 
      42            1 : fn err_connection() -> io::Error {
      43            1 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      44            1 : }
      45              : 
      46              : impl<S: AsyncRead + Unpin> PqStream<S> {
      47              :     /// Receive [`FeStartupPacket`], which is the first packet sent by a client.
      48           80 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      49           80 :         self.framed
      50           80 :             .read_startup_message()
      51           21 :             .await
      52           80 :             .map_err(ConnectionError::into_io_error)?
      53           80 :             .ok_or_else(err_connection)
      54           80 :     }
      55              : 
      56           61 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      57           61 :         self.framed
      58           61 :             .read_message()
      59           61 :             .await
      60           61 :             .map_err(ConnectionError::into_io_error)?
      61           61 :             .ok_or_else(err_connection)
      62           61 :     }
      63              : 
      64           61 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      65           61 :         match self.read_message().await? {
      66           61 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      67            0 :             bad => Err(io::Error::new(
      68            0 :                 io::ErrorKind::InvalidData,
      69            0 :                 format!("unexpected message type: {:?}", bad),
      70            0 :             )),
      71              :         }
      72           61 :     }
      73              : }
      74              : 
      75              : impl<S: AsyncWrite + Unpin> PqStream<S> {
      76              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
      77              :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      78          588 :         self.framed
      79          588 :             .write_message(message)
      80          588 :             .map_err(ProtocolError::into_io_error)?;
      81          588 :         Ok(self)
      82          588 :     }
      83              : 
      84              :     /// Write the message into an internal buffer and flush it.
      85          167 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
      86          167 :         self.write_message_noflush(message)?;
      87          167 :         self.flush().await?;
      88          167 :         Ok(self)
      89          167 :     }
      90              : 
      91              :     /// Flush the output buffer into the underlying stream.
      92          167 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
      93          167 :         self.framed.flush().await?;
      94          167 :         Ok(self)
      95          167 :     }
      96              : 
      97              :     /// Write the error message using [`Self::write_message`], then re-throw it.
      98              :     /// Allowing string literals is safe under the assumption that they do not contain any runtime info.
      99              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     100            5 :     pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
     101            4 :         tracing::info!("forwarding error to user: {error}");
     102            5 :         self.write_message(&BeMessage::ErrorResponse(error, None))
     103            0 :             .await?;
     104            5 :         bail!(error)
     105            5 :     }
     106              : 
     107              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     108              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     109            4 :     pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
     110            4 :     where
     111            4 :         E: UserFacingError + Into<anyhow::Error>,
     112            4 :     {
     113            4 :         let msg = error.to_string_client();
     114            4 :         tracing::info!("forwarding error to user: {msg}");
     115            4 :         self.write_message(&BeMessage::ErrorResponse(&msg, None))
     116            0 :             .await?;
     117            4 :         bail!(error)
     118            4 :     }
     119              : }
     120              : 
     121              : pin_project! {
     122              :     /// Wrapper for upgrading raw streams into secure streams.
     123              :     /// NOTE: it should be possible to decompose this object as necessary.
     124              :     #[project = StreamProj]
     125              :     pub enum Stream<S> {
     126              :         /// We always begin with a raw stream,
     127              :         /// which may then be upgraded into a secure stream.
     128              :         Raw { #[pin] raw: S },
     129              :         /// We box [`TlsStream`] since it can be quite large.
     130              :         Tls { #[pin] tls: Box<TlsStream<S>> },
     131              :     }
     132              : }
     133              : 
     134              : impl<S> Stream<S> {
     135              :     /// Construct a new instance from a raw stream.
     136           44 :     pub fn from_raw(raw: S) -> Self {
     137           44 :         Self::Raw { raw }
     138           44 :     }
     139              : 
     140              :     /// Return SNI hostname when it's available.
     141           32 :     pub fn sni_hostname(&self) -> Option<&str> {
     142           32 :         match self {
     143            0 :             Stream::Raw { .. } => None,
     144           32 :             Stream::Tls { tls } => tls.get_ref().1.server_name(),
     145              :         }
     146           32 :     }
     147              : }
     148              : 
     149            0 : #[derive(Debug, Error)]
     150              : #[error("Can't upgrade TLS stream")]
     151              : pub enum StreamUpgradeError {
     152              :     #[error("Bad state reached: can't upgrade TLS stream")]
     153              :     AlreadyTls,
     154              : 
     155              :     #[error("Can't upgrade stream: IO error: {0}")]
     156              :     Io(#[from] io::Error),
     157              : }
     158              : 
     159              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     160              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     161           37 :     pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
     162           37 :         match self {
     163           37 :             Stream::Raw { raw } => {
     164           74 :                 let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
     165           37 :                 Ok(Stream::Tls { tls })
     166              :             }
     167            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     168              :         }
     169           37 :     }
     170              : }
     171              : 
     172              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     173          390 :     fn poll_read(
     174          390 :         self: Pin<&mut Self>,
     175          390 :         context: &mut task::Context<'_>,
     176          390 :         buf: &mut ReadBuf<'_>,
     177          390 :     ) -> task::Poll<io::Result<()>> {
     178          390 :         use StreamProj::*;
     179          390 :         match self.project() {
     180           46 :             Raw { raw } => raw.poll_read(context, buf),
     181          344 :             Tls { tls } => tls.poll_read(context, buf),
     182              :         }
     183          390 :     }
     184              : }
     185              : 
     186              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     187          195 :     fn poll_write(
     188          195 :         self: Pin<&mut Self>,
     189          195 :         context: &mut task::Context<'_>,
     190          195 :         buf: &[u8],
     191          195 :     ) -> task::Poll<io::Result<usize>> {
     192          195 :         use StreamProj::*;
     193          195 :         match self.project() {
     194           43 :             Raw { raw } => raw.poll_write(context, buf),
     195          152 :             Tls { tls } => tls.poll_write(context, buf),
     196              :         }
     197          195 :     }
     198              : 
     199          223 :     fn poll_flush(
     200          223 :         self: Pin<&mut Self>,
     201          223 :         context: &mut task::Context<'_>,
     202          223 :     ) -> task::Poll<io::Result<()>> {
     203          223 :         use StreamProj::*;
     204          223 :         match self.project() {
     205           43 :             Raw { raw } => raw.poll_flush(context),
     206          180 :             Tls { tls } => tls.poll_flush(context),
     207              :         }
     208          223 :     }
     209              : 
     210           28 :     fn poll_shutdown(
     211           28 :         self: Pin<&mut Self>,
     212           28 :         context: &mut task::Context<'_>,
     213           28 :     ) -> task::Poll<io::Result<()>> {
     214           28 :         use StreamProj::*;
     215           28 :         match self.project() {
     216            0 :             Raw { raw } => raw.poll_shutdown(context),
     217           28 :             Tls { tls } => tls.poll_shutdown(context),
     218              :         }
     219           28 :     }
     220              : }
        

Generated by: LCOV version 2.1-beta