|             Line data    Source code 
       1              : use futures::{FutureExt, TryFutureExt};
       2              : use thiserror::Error;
       3              : use tokio::io::{AsyncRead, AsyncWrite};
       4              : use tracing::{debug, info, warn};
       5              : 
       6              : use crate::auth::endpoint_sni;
       7              : use crate::config::TlsConfig;
       8              : use crate::context::RequestContext;
       9              : use crate::error::ReportableError;
      10              : use crate::metrics::Metrics;
      11              : use crate::pglb::TlsRequired;
      12              : use crate::pqproto::{
      13              :     BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
      14              : };
      15              : use crate::stream::{PqStream, Stream, StreamUpgradeError};
      16              : use crate::tls::PG_ALPN_PROTOCOL;
      17              : 
/// Failures that can occur while [`handshake`] negotiates encryption with the client
/// and reads its startup packet.
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {
    /// The client sent data before the server answered its encryption request.
    /// Raised by [`handshake`] when buffered direct-TLS bytes could not all be fed
    /// into the TLS session.
    #[error("data is sent before server replied with EncryptionResponse")]
    EarlyData,

    /// The client deviated from the expected startup flow. Examples include a repeated
    /// or out-of-place encryption request, an unexpected ALPN protocol, or an
    /// unsupported protocol version.
    #[error("protocol violation")]
    ProtocolViolation,

    /// Upgrading the client stream to TLS failed.
    #[error("{0}")]
    StreamUpgradeError(#[from] StreamUpgradeError),

    /// Transport-level I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),

    /// An error that has already been surfaced to the client (e.g. "TLS required").
    /// NOTE(review): presumably produced by `PqStream::throw_error` after writing the
    /// error to the client — confirm against `crate::stream`.
    #[error("{0}")]
    ReportedError(#[from] crate::stream::ReportedError),
}
      35              : 
      36              : impl ReportableError for HandshakeError {
      37            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      38            0 :         match self {
      39            0 :             HandshakeError::EarlyData => crate::error::ErrorKind::User,
      40            0 :             HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
      41            0 :             HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
      42            0 :                 StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
      43            0 :                 StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      44              :             },
      45            0 :             HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      46            0 :             HandshakeError::ReportedError(e) => e.get_error_kind(),
      47              :         }
      48            0 :     }
      49              : }
      50              : 
/// Successful outcome of [`handshake`].
pub(crate) enum HandshakeData<S> {
    /// A regular session: the client stream (raw, or TLS-wrapped if the client upgraded)
    /// together with the startup parameters it sent.
    Startup(PqStream<Stream<S>>, StartupMessageParams),
    /// A cancellation request, carrying the client-supplied cancel key data.
    Cancel(CancelKeyData),
}
      55              : 
      56              : /// Establish a (most probably, secure) connection with the client.
      57              : /// For better testing experience, `stream` can be any object satisfying the traits.
      58              : /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
      59              : /// we also take an extra care of propagating only the select handshake errors to client.
      60              : #[tracing::instrument(skip_all)]
      61              : pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
      62              :     ctx: &RequestContext,
      63              :     stream: S,
      64              :     mut tls: Option<&TlsConfig>,
      65              :     record_handshake_error: bool,
      66              : ) -> Result<HandshakeData<S>, HandshakeError> {
      67              :     // Client may try upgrading to each protocol only once
      68              :     let (mut tried_ssl, mut tried_gss) = (false, false);
      69              : 
      70              :     const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
      71              :     const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
      72              : 
      73              :     let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
      74              :     loop {
      75              :         match msg {
      76              :             FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
      77              :                 Stream::Raw { .. } if !tried_ssl => {
      78              :                     tried_ssl = true;
      79              : 
      80              :                     if let Some(tls) = tls.take() {
      81              :                         // Upgrade raw stream into a secure TLS-backed stream.
      82              :                         // NOTE: We've consumed `tls`; this fact will be used later.
      83              : 
      84              :                         let mut read_buf;
      85              :                         let raw = if let Some(direct) = &direct {
      86              :                             read_buf = &direct[..];
      87              :                             stream.accept_direct_tls()
      88              :                         } else {
      89              :                             read_buf = &[];
      90              :                             stream.accept_tls().await?
      91              :                         };
      92              : 
      93              :                         let Stream::Raw { raw } = raw else {
      94              :                             return Err(HandshakeError::StreamUpgradeError(
      95              :                                 StreamUpgradeError::AlreadyTls,
      96              :                             ));
      97              :                         };
      98              : 
      99              :                         let mut res = Ok(());
     100              :                         let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
     101           20 :                             .accept_with(raw, |session| {
     102              :                                 // push the early data to the tls session
     103           20 :                                 while !read_buf.is_empty() {
     104            0 :                                     match session.read_tls(&mut read_buf) {
     105            0 :                                         Ok(_) => {}
     106            0 :                                         Err(e) => {
     107            0 :                                             res = Err(e);
     108            0 :                                             break;
     109              :                                         }
     110              :                                     }
     111              :                                 }
     112           20 :                             })
     113              :                             .map_ok(Box::new)
     114              :                             .boxed();
     115              : 
     116              :                         res?;
     117              : 
     118              :                         if !read_buf.is_empty() {
     119              :                             return Err(HandshakeError::EarlyData);
     120              :                         }
     121              : 
     122            0 :                         let tls_stream = accept.await.inspect_err(|_| {
     123            0 :                             if record_handshake_error {
     124            0 :                                 Metrics::get().proxy.tls_handshake_failures.inc();
     125            0 :                             }
     126            0 :                         })?;
     127              : 
     128              :                         let conn_info = tls_stream.get_ref().1;
     129              : 
     130              :                         // try parse endpoint
     131              :                         let ep = conn_info
     132              :                             .server_name()
     133           20 :                             .and_then(|sni| endpoint_sni(sni, &tls.common_names));
     134              :                         if let Some(ep) = ep {
     135              :                             ctx.set_endpoint_id(ep);
     136              :                         }
     137              : 
     138              :                         // check the ALPN, if exists, as required.
     139              :                         match conn_info.alpn_protocol() {
     140              :                             None | Some(PG_ALPN_PROTOCOL) => {}
     141              :                             Some(other) => {
     142              :                                 let alpn = String::from_utf8_lossy(other);
     143              :                                 warn!(%alpn, "unexpected ALPN");
     144              :                                 return Err(HandshakeError::ProtocolViolation);
     145              :                             }
     146              :                         }
     147              : 
     148              :                         let (_, tls_server_end_point) =
     149              :                             tls.cert_resolver.resolve(conn_info.server_name());
     150              : 
     151              :                         let tls = Stream::Tls {
     152              :                             tls: tls_stream,
     153              :                             tls_server_end_point,
     154              :                         };
     155              :                         (stream, msg) = PqStream::parse_startup(tls).await?;
     156              :                     } else {
     157              :                         if direct.is_some() {
     158              :                             // client sent us a ClientHello already, we can't do anything with it.
     159              :                             return Err(HandshakeError::ProtocolViolation);
     160              :                         }
     161              :                         msg = stream.reject_encryption().await?;
     162              :                     }
     163              :                 }
     164              :                 _ => return Err(HandshakeError::ProtocolViolation),
     165              :             },
     166              :             FeStartupPacket::GssEncRequest => match stream.get_ref() {
     167              :                 Stream::Raw { .. } if !tried_gss => {
     168              :                     tried_gss = true;
     169              : 
     170              :                     // Currently, we don't support GSSAPI
     171              :                     msg = stream.reject_encryption().await?;
     172              :                 }
     173              :                 _ => return Err(HandshakeError::ProtocolViolation),
     174              :             },
     175              :             FeStartupPacket::StartupMessage { params, version }
     176              :                 if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
     177              :             {
     178              :                 // Check that the config has been consumed during upgrade
     179              :                 // OR we didn't provide it at all (for dev purposes).
     180              :                 if tls.is_some() {
     181              :                     Err(stream.throw_error(TlsRequired, None).await)?;
     182              :                 }
     183              : 
     184              :                 // This log highlights the start of the connection.
     185              :                 // This contains useful information for debugging, not logged elsewhere, like role name and endpoint id.
     186              :                 info!(
     187              :                     ?version,
     188              :                     ?params,
     189              :                     session_type = "normal",
     190              :                     "successful handshake"
     191              :                 );
     192              :                 break Ok(HandshakeData::Startup(stream, params));
     193              :             }
     194              :             // downgrade protocol version
     195              :             FeStartupPacket::StartupMessage { params, version }
     196              :                 if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
     197              :             {
     198              :                 debug!(?version, "unsupported minor version");
     199              : 
     200              :                 // no protocol extensions are supported.
     201              :                 // <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
     202              :                 let mut unsupported = vec![];
     203              :                 let mut supported = StartupMessageParams::default();
     204              : 
     205              :                 for (k, v) in params.iter() {
     206              :                     if k.starts_with("_pq_.") {
     207              :                         unsupported.push(k);
     208              :                     } else {
     209              :                         supported.insert(k, v);
     210              :                     }
     211              :                 }
     212              : 
     213              :                 stream.write_message(BeMessage::NegotiateProtocolVersion {
     214              :                     version: PG_PROTOCOL_LATEST,
     215              :                     options: &unsupported,
     216              :                 });
     217              :                 stream.flush().await?;
     218              : 
     219              :                 info!(
     220              :                     ?version,
     221              :                     ?params,
     222              :                     session_type = "normal",
     223              :                     "successful handshake; unsupported minor version requested"
     224              :                 );
     225              :                 break Ok(HandshakeData::Startup(stream, supported));
     226              :             }
     227              :             FeStartupPacket::StartupMessage { version, params } => {
     228              :                 warn!(
     229              :                     ?version,
     230              :                     ?params,
     231              :                     session_type = "normal",
     232              :                     "unsuccessful handshake; unsupported version"
     233              :                 );
     234              :                 return Err(HandshakeError::ProtocolViolation);
     235              :             }
     236              :             FeStartupPacket::CancelRequest(cancel_key_data) => {
     237              :                 info!(session_type = "cancellation", "successful handshake");
     238              :                 break Ok(HandshakeData::Cancel(cancel_key_data));
     239              :             }
     240              :         }
     241              :     }
     242              : }
         |