LCOV - code coverage report
Current view: top level - proxy/src/proxy - handshake.rs (source / functions) Coverage Total Hit
Test: 046155f5c3321e806c1c5acca9ccd26414587b38.info Lines: 16.0 % 25 4
Test Date: 2025-03-27 12:42:09 Functions: 25.0 % 16 4

            Line data    Source code
       1              : use bytes::Buf;
       2              : use pq_proto::framed::Framed;
       3              : use pq_proto::{
       4              :     BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
       5              : };
       6              : use thiserror::Error;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : use tracing::{debug, info, warn};
       9              : 
      10              : use crate::auth::endpoint_sni;
      11              : use crate::config::TlsConfig;
      12              : use crate::context::RequestContext;
      13              : use crate::error::ReportableError;
      14              : use crate::metrics::Metrics;
      15              : use crate::proxy::ERR_INSECURE_CONNECTION;
      16              : use crate::stream::{PqStream, Stream, StreamUpgradeError};
      17              : use crate::tls::PG_ALPN_PROTOCOL;
      18              : 
      19              : #[derive(Error, Debug)]
      20              : pub(crate) enum HandshakeError {
      21              :     #[error("data is sent before server replied with EncryptionResponse")]
      22              :     EarlyData,
      23              : 
      24              :     #[error("protocol violation")]
      25              :     ProtocolViolation,
      26              : 
      27              :     #[error("missing certificate")]
      28              :     MissingCertificate,
      29              : 
      30              :     #[error("{0}")]
      31              :     StreamUpgradeError(#[from] StreamUpgradeError),
      32              : 
      33              :     #[error("{0}")]
      34              :     Io(#[from] std::io::Error),
      35              : 
      36              :     #[error("{0}")]
      37              :     ReportedError(#[from] crate::stream::ReportedError),
      38              : }
      39              : 
      40              : impl ReportableError for HandshakeError {
      41            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      42            0 :         match self {
      43            0 :             HandshakeError::EarlyData => crate::error::ErrorKind::User,
      44            0 :             HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
      45              :             // This error should not happen, but will if we have no default certificate and
      46              :             // the client sends no SNI extension.
      47              :             // If they provide SNI then we can be sure there is a certificate that matches.
      48            0 :             HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
      49            0 :             HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
      50            0 :                 StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
      51            0 :                 StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      52              :             },
      53            0 :             HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      54            0 :             HandshakeError::ReportedError(e) => e.get_error_kind(),
      55              :         }
      56            0 :     }
      57              : }
      58              : 
      59              : pub(crate) enum HandshakeData<S> {
      60              :     Startup(PqStream<Stream<S>>, StartupMessageParams),
      61              :     Cancel(CancelKeyData),
      62              : }
      63              : 
      64              : /// Establish a (most probably, secure) connection with the client.
      65              : /// For better testing experience, `stream` can be any object satisfying the traits.
      66              : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
      67              : /// we also take an extra care of propagating only the select handshake errors to client.
      68              : #[tracing::instrument(skip_all)]
      69              : pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
      70              :     ctx: &RequestContext,
      71              :     stream: S,
      72              :     mut tls: Option<&TlsConfig>,
      73              :     record_handshake_error: bool,
      74              : ) -> Result<HandshakeData<S>, HandshakeError> {
      75              :     // Client may try upgrading to each protocol only once
      76              :     let (mut tried_ssl, mut tried_gss) = (false, false);
      77              : 
      78              :     const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
      79              :     const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
      80              : 
      81              :     let mut stream = PqStream::new(Stream::from_raw(stream));
      82              :     loop {
      83              :         let msg = stream.read_startup_packet().await?;
      84              :         match msg {
      85              :             FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
      86              :                 Stream::Raw { .. } if !tried_ssl => {
      87              :                     tried_ssl = true;
      88              : 
      89              :                     // We can't perform TLS handshake without a config
      90              :                     let have_tls = tls.is_some();
      91              :                     if !direct {
      92              :                         stream
      93              :                             .write_message(&Be::EncryptionResponse(have_tls))
      94              :                             .await?;
      95              :                     } else if !have_tls {
      96              :                         return Err(HandshakeError::ProtocolViolation);
      97              :                     }
      98              : 
      99              :                     if let Some(tls) = tls.take() {
     100              :                         // Upgrade raw stream into a secure TLS-backed stream.
     101              :                         // NOTE: We've consumed `tls`; this fact will be used later.
     102              : 
     103              :                         let Framed {
     104              :                             stream: raw,
     105              :                             read_buf,
     106              :                             write_buf,
     107              :                         } = stream.framed;
     108              : 
     109              :                         let Stream::Raw { raw } = raw else {
     110              :                             return Err(HandshakeError::StreamUpgradeError(
     111              :                                 StreamUpgradeError::AlreadyTls,
     112              :                             ));
     113              :                         };
     114              : 
     115              :                         let mut read_buf = read_buf.reader();
     116              :                         let mut res = Ok(());
     117              :                         let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
     118           20 :                             .accept_with(raw, |session| {
     119              :                                 // push the early data to the tls session
     120           20 :                                 while !read_buf.get_ref().is_empty() {
     121            0 :                                     match session.read_tls(&mut read_buf) {
     122            0 :                                         Ok(_) => {}
     123            0 :                                         Err(e) => {
     124            0 :                                             res = Err(e);
     125            0 :                                             break;
     126              :                                         }
     127              :                                     }
     128              :                                 }
     129           20 :                             });
     130              : 
     131              :                         res?;
     132              : 
     133              :                         let read_buf = read_buf.into_inner();
     134              :                         if !read_buf.is_empty() {
     135              :                             return Err(HandshakeError::EarlyData);
     136              :                         }
     137              : 
     138            0 :                         let tls_stream = accept.await.inspect_err(|_| {
     139            0 :                             if record_handshake_error {
     140            0 :                                 Metrics::get().proxy.tls_handshake_failures.inc();
     141            0 :                             }
     142            0 :                         })?;
     143              : 
     144              :                         let conn_info = tls_stream.get_ref().1;
     145              : 
     146              :                         // try parse endpoint
     147              :                         let ep = conn_info
     148              :                             .server_name()
     149           20 :                             .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
     150              :                         if let Some(ep) = ep {
     151              :                             ctx.set_endpoint_id(ep);
     152              :                         }
     153              : 
     154              :                         // check the ALPN, if exists, as required.
     155              :                         match conn_info.alpn_protocol() {
     156              :                             None | Some(PG_ALPN_PROTOCOL) => {}
     157              :                             Some(other) => {
     158              :                                 let alpn = String::from_utf8_lossy(other);
     159              :                                 warn!(%alpn, "unexpected ALPN");
     160              :                                 return Err(HandshakeError::ProtocolViolation);
     161              :                             }
     162              :                         }
     163              : 
     164              :                         let (_, tls_server_end_point) = tls
     165              :                             .cert_resolver
     166              :                             .resolve(conn_info.server_name())
     167              :                             .ok_or(HandshakeError::MissingCertificate)?;
     168              : 
     169              :                         stream = PqStream {
     170              :                             framed: Framed {
     171              :                                 stream: Stream::Tls {
     172              :                                     tls: Box::new(tls_stream),
     173              :                                     tls_server_end_point,
     174              :                                 },
     175              :                                 read_buf,
     176              :                                 write_buf,
     177              :                             },
     178              :                         };
     179              :                     }
     180              :                 }
     181              :                 _ => return Err(HandshakeError::ProtocolViolation),
     182              :             },
     183              :             FeStartupPacket::GssEncRequest => match stream.get_ref() {
     184              :                 Stream::Raw { .. } if !tried_gss => {
     185              :                     tried_gss = true;
     186              : 
     187              :                     // Currently, we don't support GSSAPI
     188              :                     stream.write_message(&Be::EncryptionResponse(false)).await?;
     189              :                 }
     190              :                 _ => return Err(HandshakeError::ProtocolViolation),
     191              :             },
     192              :             FeStartupPacket::StartupMessage { params, version }
     193              :                 if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
     194              :             {
     195              :                 // Check that the config has been consumed during upgrade
     196              :                 // OR we didn't provide it at all (for dev purposes).
     197              :                 if tls.is_some() {
     198              :                     return stream
     199              :                         .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
     200              :                         .await?;
     201              :                 }
     202              : 
     203              :                 // This log highlights the start of the connection.
     204              :                 // This contains useful information for debugging, not logged elsewhere, like role name and endpoint id.
     205              :                 info!(
     206              :                     ?version,
     207              :                     ?params,
     208              :                     session_type = "normal",
     209              :                     "successful handshake"
     210              :                 );
     211              :                 break Ok(HandshakeData::Startup(stream, params));
     212              :             }
     213              :             // downgrade protocol version
     214              :             FeStartupPacket::StartupMessage { params, version }
     215              :                 if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
     216              :             {
     217              :                 debug!(?version, "unsupported minor version");
     218              : 
     219              :                 // no protocol extensions are supported.
     220              :                 // <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
     221              :                 let mut unsupported = vec![];
     222              :                 for (k, _) in params.iter() {
     223              :                     if k.starts_with("_pq_.") {
     224              :                         unsupported.push(k);
     225              :                     }
     226              :                 }
     227              : 
     228              :                 // TODO: remove unsupported options so we don't send them to compute.
     229              : 
     230              :                 stream
     231              :                     .write_message(&Be::NegotiateProtocolVersion {
     232              :                         version: PG_PROTOCOL_LATEST,
     233              :                         options: &unsupported,
     234              :                     })
     235              :                     .await?;
     236              : 
     237              :                 info!(
     238              :                     ?version,
     239              :                     ?params,
     240              :                     session_type = "normal",
     241              :                     "successful handshake; unsupported minor version requested"
     242              :                 );
     243              :                 break Ok(HandshakeData::Startup(stream, params));
     244              :             }
     245              :             FeStartupPacket::StartupMessage { version, params } => {
     246              :                 warn!(
     247              :                     ?version,
     248              :                     ?params,
     249              :                     session_type = "normal",
     250              :                     "unsuccessful handshake; unsupported version"
     251              :                 );
     252              :                 return Err(HandshakeError::ProtocolViolation);
     253              :             }
     254              :             FeStartupPacket::CancelRequest(cancel_key_data) => {
     255              :                 info!(session_type = "cancellation", "successful handshake");
     256              :                 break Ok(HandshakeData::Cancel(cancel_key_data));
     257              :             }
     258              :         }
     259              :     }
     260              : }
        

Generated by: LCOV version 2.1-beta