LCOV - code coverage report
Current view: top level - proxy/src/proxy - handshake.rs (source / functions)
Test: d3dc80ba303d573d9d44a7f3290f1b1b48b7e1a0.info — Lines: 15.4 % coverage (Total: 13, Hit: 2)
Test Date: 2024-06-25 11:53:14 — Functions: 37.5 % coverage (Total: 16, Hit: 6)

            Line data    Source code
       1              : use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
       2              : use thiserror::Error;
       3              : use tokio::io::{AsyncRead, AsyncWrite};
       4              : use tracing::info;
       5              : 
       6              : use crate::{
       7              :     config::TlsConfig,
       8              :     error::ReportableError,
       9              :     proxy::ERR_INSECURE_CONNECTION,
      10              :     stream::{PqStream, Stream, StreamUpgradeError},
      11              : };
      12              : 
      13            2 : #[derive(Error, Debug)]
      14              : pub enum HandshakeError {
      15              :     #[error("data is sent before server replied with EncryptionResponse")]
      16              :     EarlyData,
      17              : 
      18              :     #[error("protocol violation")]
      19              :     ProtocolViolation,
      20              : 
      21              :     #[error("missing certificate")]
      22              :     MissingCertificate,
      23              : 
      24              :     #[error("{0}")]
      25              :     StreamUpgradeError(#[from] StreamUpgradeError),
      26              : 
      27              :     #[error("{0}")]
      28              :     Io(#[from] std::io::Error),
      29              : 
      30              :     #[error("{0}")]
      31              :     ReportedError(#[from] crate::stream::ReportedError),
      32              : }
      33              : 
      34              : impl ReportableError for HandshakeError {
      35            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      36            0 :         match self {
      37            0 :             HandshakeError::EarlyData => crate::error::ErrorKind::User,
      38            0 :             HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
      39              :             // This error should not happen, but will if we have no default certificate and
      40              :             // the client sends no SNI extension.
      41              :             // If they provide SNI then we can be sure there is a certificate that matches.
      42            0 :             HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
      43            0 :             HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
      44            0 :                 StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
      45            0 :                 StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      46              :             },
      47            0 :             HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      48            0 :             HandshakeError::ReportedError(e) => e.get_error_kind(),
      49              :         }
      50            0 :     }
      51              : }
      52              : 
      53              : pub enum HandshakeData<S> {
      54              :     Startup(PqStream<Stream<S>>, StartupMessageParams),
      55              :     Cancel(CancelKeyData),
      56              : }
      57              : 
      58              : /// Establish a (most probably, secure) connection with the client.
      59              : /// For better testing experience, `stream` can be any object satisfying the traits.
      60              : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
      61              : /// we also take an extra care of propagating only the select handshake errors to client.
      62           88 : #[tracing::instrument(skip_all)]
      63              : pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
      64              :     stream: S,
      65              :     mut tls: Option<&TlsConfig>,
      66              :     record_handshake_error: bool,
      67              : ) -> Result<HandshakeData<S>, HandshakeError> {
      68              :     // Client may try upgrading to each protocol only once
      69              :     let (mut tried_ssl, mut tried_gss) = (false, false);
      70              : 
      71              :     let mut stream = PqStream::new(Stream::from_raw(stream));
      72              :     loop {
      73              :         let msg = stream.read_startup_packet().await?;
      74              :         info!("received {msg:?}");
      75              : 
      76              :         use FeStartupPacket::*;
      77              :         match msg {
      78              :             SslRequest => match stream.get_ref() {
      79              :                 Stream::Raw { .. } if !tried_ssl => {
      80              :                     tried_ssl = true;
      81              : 
      82              :                     // We can't perform TLS handshake without a config
      83              :                     let enc = tls.is_some();
      84              :                     stream.write_message(&Be::EncryptionResponse(enc)).await?;
      85              :                     if let Some(tls) = tls.take() {
      86              :                         // Upgrade raw stream into a secure TLS-backed stream.
      87              :                         // NOTE: We've consumed `tls`; this fact will be used later.
      88              : 
      89              :                         let (raw, read_buf) = stream.into_inner();
      90              :                         // TODO: Normally, client doesn't send any data before
      91              :                         // server says TLS handshake is ok and read_buf is empy.
      92              :                         // However, you could imagine pipelining of postgres
      93              :                         // SSLRequest + TLS ClientHello in one hunk similar to
      94              :                         // pipelining in our node js driver. We should probably
      95              :                         // support that by chaining read_buf with the stream.
      96              :                         if !read_buf.is_empty() {
      97              :                             return Err(HandshakeError::EarlyData);
      98              :                         }
      99              :                         let tls_stream = raw
     100              :                             .upgrade(tls.to_server_config(), record_handshake_error)
     101              :                             .await?;
     102              : 
     103              :                         let (_, tls_server_end_point) = tls
     104              :                             .cert_resolver
     105              :                             .resolve(tls_stream.get_ref().1.server_name())
     106              :                             .ok_or(HandshakeError::MissingCertificate)?;
     107              : 
     108              :                         stream = PqStream::new(Stream::Tls {
     109              :                             tls: Box::new(tls_stream),
     110              :                             tls_server_end_point,
     111              :                         });
     112              :                     }
     113              :                 }
     114              :                 _ => return Err(HandshakeError::ProtocolViolation),
     115              :             },
     116              :             GssEncRequest => match stream.get_ref() {
     117              :                 Stream::Raw { .. } if !tried_gss => {
     118              :                     tried_gss = true;
     119              : 
     120              :                     // Currently, we don't support GSSAPI
     121              :                     stream.write_message(&Be::EncryptionResponse(false)).await?;
     122              :                 }
     123              :                 _ => return Err(HandshakeError::ProtocolViolation),
     124              :             },
     125              :             StartupMessage { params, .. } => {
     126              :                 // Check that the config has been consumed during upgrade
     127              :                 // OR we didn't provide it at all (for dev purposes).
     128              :                 if tls.is_some() {
     129              :                     return stream
     130              :                         .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
     131              :                         .await?;
     132              :                 }
     133              : 
     134              :                 info!(session_type = "normal", "successful handshake");
     135              :                 break Ok(HandshakeData::Startup(stream, params));
     136              :             }
     137              :             CancelRequest(cancel_key_data) => {
     138              :                 info!(session_type = "cancellation", "successful handshake");
     139              :                 break Ok(HandshakeData::Cancel(cancel_key_data));
     140              :             }
     141              :         }
     142              :     }
     143              : }
        

Generated by: LCOV version 2.1-beta