|             Line data    Source code 
       1              : use crate::config::TlsServerEndPoint;
       2              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
       3              : use crate::metrics::Metrics;
       4              : use bytes::BytesMut;
       5              : 
       6              : use pq_proto::framed::{ConnectionError, Framed};
       7              : use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
       8              : use rustls::ServerConfig;
       9              : use std::pin::Pin;
      10              : use std::sync::Arc;
      11              : use std::{io, task};
      12              : use thiserror::Error;
      13              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      14              : use tokio_rustls::server::TlsStream;
      15              : 
      16              : /// Stream wrapper which implements libpq's protocol.
      17              : /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
      18              : /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
      19              : /// to pass random malformed bytes through the connection).
      20              : pub struct PqStream<S> {
      21              :     pub(crate) framed: Framed<S>,
      22              : }
      23              : 
      24              : impl<S> PqStream<S> {
      25              :     /// Construct a new libpq protocol wrapper.
      26           90 :     pub fn new(stream: S) -> Self {
      27           90 :         Self {
      28           90 :             framed: Framed::new(stream),
      29           90 :         }
      30           90 :     }
      31              : 
      32              :     /// Extract the underlying stream and read buffer.
      33           40 :     pub fn into_inner(self) -> (S, BytesMut) {
      34           40 :         self.framed.into_inner()
      35           40 :     }
      36              : 
      37              :     /// Get a shared reference to the underlying stream.
      38           70 :     pub fn get_ref(&self) -> &S {
      39           70 :         self.framed.get_ref()
      40           70 :     }
      41              : }
      42              : 
      43            0 : fn err_connection() -> io::Error {
      44            0 :     io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
      45            0 : }
      46              : 
      47              : impl<S: AsyncRead + Unpin> PqStream<S> {
      48              :     /// Receive [`FeStartupPacket`], which is a first packet sent by a client.
      49           84 :     pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
      50           84 :         self.framed
      51           84 :             .read_startup_message()
      52           14 :             .await
      53           84 :             .map_err(ConnectionError::into_io_error)?
      54           84 :             .ok_or_else(err_connection)
      55           84 :     }
      56              : 
      57           52 :     async fn read_message(&mut self) -> io::Result<FeMessage> {
      58           52 :         self.framed
      59           52 :             .read_message()
      60           52 :             .await
      61           52 :             .map_err(ConnectionError::into_io_error)?
      62           50 :             .ok_or_else(err_connection)
      63           52 :     }
      64              : 
      65           52 :     pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
      66           52 :         match self.read_message().await? {
      67           50 :             FeMessage::PasswordMessage(msg) => Ok(msg),
      68            0 :             bad => Err(io::Error::new(
      69            0 :                 io::ErrorKind::InvalidData,
      70            0 :                 format!("unexpected message type: {:?}", bad),
      71            0 :             )),
      72              :         }
      73           52 :     }
      74              : }
      75              : 
      76              : #[derive(Debug)]
      77              : pub struct ReportedError {
      78              :     source: anyhow::Error,
      79              :     error_kind: ErrorKind,
      80              : }
      81              : 
      82              : impl std::fmt::Display for ReportedError {
      83            2 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      84            2 :         self.source.fmt(f)
      85            2 :     }
      86              : }
      87              : 
      88              : impl std::error::Error for ReportedError {
      89            0 :     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
      90            0 :         self.source.source()
      91            0 :     }
      92              : }
      93              : 
      94              : impl ReportableError for ReportedError {
      95            0 :     fn get_error_kind(&self) -> ErrorKind {
      96            0 :         self.error_kind
      97            0 :     }
      98              : }
      99              : 
     100              : impl<S: AsyncWrite + Unpin> PqStream<S> {
     101              :     /// Write the message into an internal buffer, but don't flush the underlying stream.
     102          154 :     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     103          154 :         self.framed
     104          154 :             .write_message(message)
     105          154 :             .map_err(ProtocolError::into_io_error)?;
     106          154 :         Ok(self)
     107          154 :     }
     108              : 
     109              :     /// Write the message into an internal buffer and flush it.
     110          120 :     pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
     111          120 :         self.write_message_noflush(message)?;
     112          120 :         self.flush().await?;
     113          120 :         Ok(self)
     114          120 :     }
     115              : 
     116              :     /// Flush the output buffer into the underlying stream.
     117          120 :     pub async fn flush(&mut self) -> io::Result<&mut Self> {
     118          120 :         self.framed.flush().await?;
     119          120 :         Ok(self)
     120          120 :     }
     121              : 
     122              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     123              :     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     124              :     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
     125            2 :     pub async fn throw_error_str<T>(
     126            2 :         &mut self,
     127            2 :         msg: &'static str,
     128            2 :         error_kind: ErrorKind,
     129            2 :     ) -> Result<T, ReportedError> {
     130            2 :         tracing::info!(
     131            0 :             kind = error_kind.to_metric_label(),
     132            0 :             msg,
     133            0 :             "forwarding error to user"
     134              :         );
     135              : 
     136              :         // already error case, ignore client IO error
     137            2 :         let _: Result<_, std::io::Error> = self
     138            2 :             .write_message(&BeMessage::ErrorResponse(msg, None))
     139            0 :             .await;
     140              : 
     141            2 :         Err(ReportedError {
     142            2 :             source: anyhow::anyhow!(msg),
     143            2 :             error_kind,
     144            2 :         })
     145            2 :     }
     146              : 
     147              :     /// Write the error message using [`Self::write_message`], then re-throw it.
     148              :     /// Trait [`UserFacingError`] acts as an allowlist for error types.
     149            0 :     pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
     150            0 :     where
     151            0 :         E: UserFacingError + Into<anyhow::Error>,
     152            0 :     {
     153            0 :         let error_kind = error.get_error_kind();
     154            0 :         let msg = error.to_string_client();
     155            0 :         tracing::info!(
     156            0 :             kind=error_kind.to_metric_label(),
     157            0 :             error=%error,
     158            0 :             msg,
     159            0 :             "forwarding error to user"
     160              :         );
     161              : 
     162              :         // already error case, ignore client IO error
     163            0 :         let _: Result<_, std::io::Error> = self
     164            0 :             .write_message(&BeMessage::ErrorResponse(&msg, None))
     165            0 :             .await;
     166              : 
     167            0 :         Err(ReportedError {
     168            0 :             source: anyhow::anyhow!(error),
     169            0 :             error_kind,
     170            0 :         })
     171            0 :     }
     172              : }
     173              : 
     174              : /// Wrapper for upgrading raw streams into secure streams.
     175              : pub enum Stream<S> {
     176              :     /// We always begin with a raw stream,
     177              :     /// which may then be upgraded into a secure stream.
     178              :     Raw { raw: S },
     179              :     Tls {
     180              :         /// We box [`TlsStream`] since it can be quite large.
     181              :         tls: Box<TlsStream<S>>,
     182              :         /// Channel binding parameter
     183              :         tls_server_end_point: TlsServerEndPoint,
     184              :     },
     185              : }
     186              : 
     187              : impl<S: Unpin> Unpin for Stream<S> {}
     188              : 
     189              : impl<S> Stream<S> {
     190              :     /// Construct a new instance from a raw stream.
     191           50 :     pub fn from_raw(raw: S) -> Self {
     192           50 :         Self::Raw { raw }
     193           50 :     }
     194              : 
     195              :     /// Return SNI hostname when it's available.
     196            0 :     pub fn sni_hostname(&self) -> Option<&str> {
     197            0 :         match self {
     198            0 :             Stream::Raw { .. } => None,
     199            0 :             Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
     200              :         }
     201            0 :     }
     202              : 
     203           30 :     pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
     204           30 :         match self {
     205            6 :             Stream::Raw { .. } => TlsServerEndPoint::Undefined,
     206              :             Stream::Tls {
     207           24 :                 tls_server_end_point,
     208           24 :                 ..
     209           24 :             } => *tls_server_end_point,
     210              :         }
     211           30 :     }
     212              : }
     213              : 
     214            0 : #[derive(Debug, Error)]
     215              : #[error("Can't upgrade TLS stream")]
     216              : pub enum StreamUpgradeError {
     217              :     #[error("Bad state reached: can't upgrade TLS stream")]
     218              :     AlreadyTls,
     219              : 
     220              :     #[error("Can't upgrade stream: IO error: {0}")]
     221              :     Io(#[from] io::Error),
     222              : }
     223              : 
     224              : impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
     225              :     /// If possible, upgrade raw stream into a secure TLS-based stream.
     226           40 :     pub async fn upgrade(
     227           40 :         self,
     228           40 :         cfg: Arc<ServerConfig>,
     229           40 :         record_handshake_error: bool,
     230           40 :     ) -> Result<TlsStream<S>, StreamUpgradeError> {
     231           40 :         match self {
     232           40 :             Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
     233           40 :                 .accept(raw)
     234           94 :                 .await
     235           40 :                 .inspect_err(|_| {
     236            0 :                     if record_handshake_error {
     237            0 :                         Metrics::get().proxy.tls_handshake_failures.inc()
     238            0 :                     }
     239           40 :                 })?),
     240            0 :             Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
     241              :         }
     242           40 :     }
     243              : }
     244              : 
     245              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
     246          308 :     fn poll_read(
     247          308 :         mut self: Pin<&mut Self>,
     248          308 :         context: &mut task::Context<'_>,
     249          308 :         buf: &mut ReadBuf<'_>,
     250          308 :     ) -> task::Poll<io::Result<()>> {
     251          308 :         match &mut *self {
     252           60 :             Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
     253          248 :             Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
     254              :         }
     255          308 :     }
     256              : }
     257              : 
     258              : impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
     259          152 :     fn poll_write(
     260          152 :         mut self: Pin<&mut Self>,
     261          152 :         context: &mut task::Context<'_>,
     262          152 :         buf: &[u8],
     263          152 :     ) -> task::Poll<io::Result<usize>> {
     264          152 :         match &mut *self {
     265           54 :             Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
     266           98 :             Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
     267              :         }
     268          152 :     }
     269              : 
     270          152 :     fn poll_flush(
     271          152 :         mut self: Pin<&mut Self>,
     272          152 :         context: &mut task::Context<'_>,
     273          152 :     ) -> task::Poll<io::Result<()>> {
     274          152 :         match &mut *self {
     275           54 :             Self::Raw { raw } => Pin::new(raw).poll_flush(context),
     276           98 :             Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
     277              :         }
     278          152 :     }
     279              : 
     280            0 :     fn poll_shutdown(
     281            0 :         mut self: Pin<&mut Self>,
     282            0 :         context: &mut task::Context<'_>,
     283            0 :     ) -> task::Poll<io::Result<()>> {
     284            0 :         match &mut *self {
     285            0 :             Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
     286            0 :             Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
     287              :         }
     288            0 :     }
     289              : }
         |