LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: bb45db3982713bfd5bec075773079136e362195e.info Lines: 17.6 % 256 45
Test Date: 2024-12-11 15:53:32 Functions: 28.0 % 25 7

            Line data    Source code
       1              : use std::io;
       2              : use std::net::SocketAddr;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use futures::{FutureExt, TryFutureExt};
       7              : use itertools::Itertools;
       8              : use once_cell::sync::OnceCell;
       9              : use postgres_client::tls::MakeTlsConnect;
      10              : use postgres_client::{CancelToken, RawConnection};
      11              : use postgres_protocol::message::backend::NoticeResponseBody;
      12              : use pq_proto::StartupMessageParams;
      13              : use rustls::client::danger::ServerCertVerifier;
      14              : use rustls::crypto::ring;
      15              : use rustls::pki_types::InvalidDnsNameError;
      16              : use thiserror::Error;
      17              : use tokio::net::TcpStream;
      18              : use tracing::{debug, error, info, warn};
      19              : 
      20              : use crate::auth::parse_endpoint_param;
      21              : use crate::cancellation::CancelClosure;
      22              : use crate::context::RequestContext;
      23              : use crate::control_plane::client::ApiLockError;
      24              : use crate::control_plane::errors::WakeComputeError;
      25              : use crate::control_plane::messages::MetricsAuxInfo;
      26              : use crate::error::{ReportableError, UserFacingError};
      27              : use crate::metrics::{Metrics, NumDbConnectionsGuard};
      28              : use crate::postgres_rustls::MakeRustlsConnect;
      29              : use crate::proxy::neon_option;
      30              : use crate::types::Host;
      31              : 
      32              : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      33              : 
      34              : #[derive(Debug, Error)]
      35              : pub(crate) enum ConnectionError {
      36              :     /// This error doesn't seem to reveal any secrets; for instance,
      37              :     /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
      38              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      39              :     Postgres(#[from] postgres_client::Error),
      40              : 
      41              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      42              :     CouldNotConnect(#[from] io::Error),
      43              : 
      44              :     #[error("Couldn't load native TLS certificates: {0:?}")]
      45              :     TlsCertificateError(Vec<rustls_native_certs::Error>),
      46              : 
      47              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      48              :     TlsError(#[from] InvalidDnsNameError),
      49              : 
      50              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      51              :     WakeComputeError(#[from] WakeComputeError),
      52              : 
      53              :     #[error("error acquiring resource permit: {0}")]
      54              :     TooManyConnectionAttempts(#[from] ApiLockError),
      55              : }
      56              : 
      57              : impl UserFacingError for ConnectionError {
      58            0 :     fn to_string_client(&self) -> String {
      59            0 :         match self {
      60              :             // This helps us drop irrelevant library-specific prefixes.
      61              :             // TODO: propagate severity level and other parameters.
      62            0 :             ConnectionError::Postgres(err) => match err.as_db_error() {
      63            0 :                 Some(err) => {
      64            0 :                     let msg = err.message();
      65            0 : 
      66            0 :                     if msg.starts_with("unsupported startup parameter: ")
      67            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      68              :                     {
      69            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      70              :                     } else {
      71            0 :                         msg.to_owned()
      72              :                     }
      73              :                 }
      74            0 :                 None => err.to_string(),
      75              :             },
      76            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      77              :             ConnectionError::TooManyConnectionAttempts(_) => {
      78            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      79              :             }
      80            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      81              :         }
      82            0 :     }
      83              : }
      84              : 
      85              : impl ReportableError for ConnectionError {
      86            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      87            0 :         match self {
      88            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      89            0 :                 crate::error::ErrorKind::Postgres
      90              :             }
      91            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      92            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      93            0 :             ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
      94            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      95            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      96            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      97              :         }
      98            0 :     }
      99              : }
     100              : 
     101              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
     102              : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
     103              : 
     104              : /// A config for establishing a connection to compute node.
     105              : /// Eventually, `postgres_client` will be replaced with something better.
     106              : /// Newtype allows us to implement methods on top of it.
     107              : #[derive(Clone)]
     108              : pub(crate) struct ConnCfg(Box<postgres_client::Config>);
     109              : 
     110              : /// Creation and initialization routines.
     111              : impl ConnCfg {
     112           10 :     pub(crate) fn new(host: String, port: u16) -> Self {
     113           10 :         Self(Box::new(postgres_client::Config::new(host, port)))
     114           10 :     }
     115              : 
     116              :     /// Reuse password or auth keys from the other config.
     117            4 :     pub(crate) fn reuse_password(&mut self, other: Self) {
     118            4 :         if let Some(password) = other.get_password() {
     119            4 :             self.password(password);
     120            4 :         }
     121              : 
     122            4 :         if let Some(keys) = other.get_auth_keys() {
     123            0 :             self.auth_keys(keys);
     124            4 :         }
     125            4 :     }
     126              : 
     127            0 :     pub(crate) fn get_host(&self) -> Host {
     128            0 :         match self.0.get_host() {
     129            0 :             postgres_client::config::Host::Tcp(s) => s.into(),
     130            0 :         }
     131            0 :     }
     132              : 
     133              :     /// Apply startup message params to the connection config.
     134            0 :     pub(crate) fn set_startup_params(
     135            0 :         &mut self,
     136            0 :         params: &StartupMessageParams,
     137            0 :         arbitrary_params: bool,
     138            0 :     ) {
     139            0 :         if !arbitrary_params {
     140            0 :             self.set_param("client_encoding", "UTF8");
     141            0 :         }
     142            0 :         for (k, v) in params.iter() {
     143            0 :             match k {
     144              :                 // Only set `user` if it's not present in the config.
     145              :                 // Console redirect auth flow takes username from the console's response.
     146            0 :                 "user" if self.user_is_set() => continue,
     147            0 :                 "database" if self.db_is_set() => continue,
     148            0 :                 "options" => {
     149            0 :                     if let Some(options) = filtered_options(v) {
     150            0 :                         self.set_param(k, &options);
     151            0 :                     }
     152              :                 }
     153            0 :                 "user" | "database" | "application_name" | "replication" => {
     154            0 :                     self.set_param(k, v);
     155            0 :                 }
     156              : 
     157              :                 // if we allow arbitrary params, then we forward them through.
     158              :                 // this is a flag for a period of backwards compatibility
     159            0 :                 k if arbitrary_params => {
     160            0 :                     self.set_param(k, v);
     161            0 :                 }
     162            0 :                 _ => {}
     163              :             }
     164              :         }
     165            0 :     }
     166              : }
     167              : 
     168              : impl std::ops::Deref for ConnCfg {
     169              :     type Target = postgres_client::Config;
     170              : 
     171            8 :     fn deref(&self) -> &Self::Target {
     172            8 :         &self.0
     173            8 :     }
     174              : }
     175              : 
     176              : /// For now, let's make it easier to setup the config.
     177              : impl std::ops::DerefMut for ConnCfg {
     178           10 :     fn deref_mut(&mut self) -> &mut Self::Target {
     179           10 :         &mut self.0
     180           10 :     }
     181              : }
     182              : 
     183              : impl ConnCfg {
     184              :     /// Establish a raw TCP connection to the compute node.
     185            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     186              :         use postgres_client::config::Host;
     187              : 
     188              :         // wrap TcpStream::connect with timeout
     189            0 :         let connect_with_timeout = |host, port| {
     190            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     191            0 :                 move |res| match res {
     192            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     193            0 :                     Err(_) => Err(io::Error::new(
     194            0 :                         io::ErrorKind::TimedOut,
     195            0 :                         format!("exceeded connection timeout {timeout:?}"),
     196            0 :                     )),
     197            0 :                 },
     198            0 :             )
     199            0 :         };
     200              : 
     201            0 :         let connect_once = |host, port| {
     202            0 :             debug!("trying to connect to compute node at {host}:{port}");
     203            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     204            0 :                 let socket_addr = socket.peer_addr()?;
     205              :                 // This prevents load balancer from severing the connection.
     206            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     207            0 :                 Ok((socket_addr, socket))
     208            0 :             })
     209            0 :         };
     210              : 
     211              :         // We can't reuse connection establishing logic from `postgres_client` here,
     212              :         // because it has no means for extracting the underlying socket which we
     213              :         // require for our business.
     214            0 :         let port = self.0.get_port();
     215            0 :         let host = self.0.get_host();
     216            0 : 
     217            0 :         let host = match host {
     218            0 :             Host::Tcp(host) => host.as_str(),
     219            0 :         };
     220            0 : 
     221            0 :         match connect_once(host, port).await {
     222            0 :             Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
     223            0 :             Err(err) => {
     224            0 :                 warn!("couldn't connect to compute node at {host}:{port}: {err}");
     225            0 :                 Err(err)
     226              :             }
     227              :         }
     228            0 :     }
     229              : }
     230              : 
     231              : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
     232              : 
     233              : pub(crate) struct PostgresConnection {
     234              :     /// Socket connected to a compute node.
     235              :     pub(crate) stream:
     236              :         postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
     237              :     /// PostgreSQL connection parameters.
     238              :     pub(crate) params: std::collections::HashMap<String, String>,
     239              :     /// Query cancellation token.
     240              :     pub(crate) cancel_closure: CancelClosure,
     241              :     /// Labels for proxy's metrics.
     242              :     pub(crate) aux: MetricsAuxInfo,
     243              :     /// Notices received from compute after authenticating
     244              :     pub(crate) delayed_notice: Vec<NoticeResponseBody>,
     245              : 
     246              :     _guage: NumDbConnectionsGuard<'static>,
     247              : }
     248              : 
     249              : impl ConnCfg {
     250              :     /// Connect to a corresponding compute node.
     251            0 :     pub(crate) async fn connect(
     252            0 :         &self,
     253            0 :         ctx: &RequestContext,
     254            0 :         allow_self_signed_compute: bool,
     255            0 :         aux: MetricsAuxInfo,
     256            0 :         timeout: Duration,
     257            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     258            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     259            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     260            0 :         drop(pause);
     261              : 
     262            0 :         let client_config = if allow_self_signed_compute {
     263              :             // Allow all certificates for creating the connection
     264            0 :             let verifier = Arc::new(AcceptEverythingVerifier);
     265            0 :             rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
     266            0 :                 .with_safe_default_protocol_versions()
     267            0 :                 .expect("ring should support the default protocol versions")
     268            0 :                 .dangerous()
     269            0 :                 .with_custom_certificate_verifier(verifier)
     270              :         } else {
     271            0 :             let root_store = TLS_ROOTS
     272            0 :                 .get_or_try_init(load_certs)
     273            0 :                 .map_err(ConnectionError::TlsCertificateError)?
     274            0 :                 .clone();
     275            0 :             rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
     276            0 :                 .with_safe_default_protocol_versions()
     277            0 :                 .expect("ring should support the default protocol versions")
     278            0 :                 .with_root_certificates(root_store)
     279              :         };
     280            0 :         let client_config = client_config.with_no_client_auth();
     281            0 : 
     282            0 :         let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
     283            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     284            0 :             &mut mk_tls,
     285            0 :             host,
     286            0 :         )?;
     287              : 
     288              :         // connect_raw() will not use TLS if sslmode is "disable"
     289            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     290            0 :         let connection = self.0.connect_raw(stream, tls).await?;
     291            0 :         drop(pause);
     292            0 : 
     293            0 :         let RawConnection {
     294            0 :             stream,
     295            0 :             parameters,
     296            0 :             delayed_notice,
     297            0 :             process_id,
     298            0 :             secret_key,
     299            0 :         } = connection;
     300            0 : 
     301            0 :         tracing::Span::current().record("pid", tracing::field::display(process_id));
     302            0 :         let stream = stream.into_inner();
     303            0 : 
     304            0 :         // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
     305            0 :         info!(
     306            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     307            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     308            0 :             self.0.get_ssl_mode()
     309              :         );
     310              : 
     311              :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     312              :         // Yet another reason to rework the connection establishing code.
     313            0 :         let cancel_closure = CancelClosure::new(
     314            0 :             socket_addr,
     315            0 :             CancelToken {
     316            0 :                 socket_config: None,
     317            0 :                 ssl_mode: self.0.get_ssl_mode(),
     318            0 :                 process_id,
     319            0 :                 secret_key,
     320            0 :             },
     321            0 :             vec![],
     322            0 :         );
     323            0 : 
     324            0 :         let connection = PostgresConnection {
     325            0 :             stream,
     326            0 :             params: parameters,
     327            0 :             delayed_notice,
     328            0 :             cancel_closure,
     329            0 :             aux,
     330            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     331            0 :         };
     332            0 : 
     333            0 :         Ok(connection)
     334            0 :     }
     335              : }
     336              : 
     337              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     338            6 : fn filtered_options(options: &str) -> Option<String> {
     339            6 :     #[allow(unstable_name_collisions)]
     340            6 :     let options: String = StartupMessageParams::parse_options_raw(options)
     341           14 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     342            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     343            6 :         .collect();
     344            6 : 
     345            6 :     // Don't even bother with empty options.
     346            6 :     if options.is_empty() {
     347            3 :         return None;
     348            3 :     }
     349            3 : 
     350            3 :     Some(options)
     351            6 : }
     352              : 
     353            0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
     354            0 :     let der_certs = rustls_native_certs::load_native_certs();
     355            0 : 
     356            0 :     if !der_certs.errors.is_empty() {
     357            0 :         return Err(der_certs.errors);
     358            0 :     }
     359            0 : 
     360            0 :     let mut store = rustls::RootCertStore::empty();
     361            0 :     store.add_parsable_certificates(der_certs.certs);
     362            0 :     Ok(Arc::new(store))
     363            0 : }
     364              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     365              : 
     366              : #[derive(Debug)]
     367              : struct AcceptEverythingVerifier;
     368              : impl ServerCertVerifier for AcceptEverythingVerifier {
     369            0 :     fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
     370              :         use rustls::SignatureScheme;
     371              :         // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
     372            0 :         vec![
     373            0 :             SignatureScheme::ECDSA_NISTP521_SHA512,
     374            0 :             SignatureScheme::ECDSA_NISTP384_SHA384,
     375            0 :             SignatureScheme::ECDSA_NISTP256_SHA256,
     376            0 :             SignatureScheme::RSA_PSS_SHA512,
     377            0 :             SignatureScheme::RSA_PSS_SHA384,
     378            0 :             SignatureScheme::RSA_PSS_SHA256,
     379            0 :             SignatureScheme::ED25519,
     380            0 :         ]
     381            0 :     }
     382            0 :     fn verify_server_cert(
     383            0 :         &self,
     384            0 :         _end_entity: &rustls::pki_types::CertificateDer<'_>,
     385            0 :         _intermediates: &[rustls::pki_types::CertificateDer<'_>],
     386            0 :         _server_name: &rustls::pki_types::ServerName<'_>,
     387            0 :         _ocsp_response: &[u8],
     388            0 :         _now: rustls::pki_types::UnixTime,
     389            0 :     ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
     390            0 :         Ok(rustls::client::danger::ServerCertVerified::assertion())
     391            0 :     }
     392            0 :     fn verify_tls12_signature(
     393            0 :         &self,
     394            0 :         _message: &[u8],
     395            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     396            0 :         _dss: &rustls::DigitallySignedStruct,
     397            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     398            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     399            0 :     }
     400            0 :     fn verify_tls13_signature(
     401            0 :         &self,
     402            0 :         _message: &[u8],
     403            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     404            0 :         _dss: &rustls::DigitallySignedStruct,
     405            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     406            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     407            0 :     }
     408              : }
     409              : 
     410              : #[cfg(test)]
     411              : mod tests {
     412              :     use super::*;
     413              : 
     414              :     #[test]
     415            1 :     fn test_filtered_options() {
     416            1 :         // Empty options is unlikely to be useful anyway.
     417            1 :         let params = "";
     418            1 :         assert_eq!(filtered_options(params), None);
     419              : 
     420              :         // It's likely that clients will only use options to specify endpoint/project.
     421            1 :         let params = "project=foo";
     422            1 :         assert_eq!(filtered_options(params), None);
     423              : 
     424              :         // Same, because unescaped whitespaces are no-op.
     425            1 :         let params = " project=foo ";
     426            1 :         assert_eq!(filtered_options(params).as_deref(), None);
     427              : 
     428            1 :         let params = r"\  project=foo \ ";
     429            1 :         assert_eq!(filtered_options(params).as_deref(), Some(r"\  \ "));
     430              : 
     431            1 :         let params = "project = foo";
     432            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     433              : 
     434            1 :         let params = "project = foo neon_endpoint_type:read_write   neon_lsn:0/2 neon_proxy_params_compat:true";
     435            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     436            1 :     }
     437              : }
        

Generated by: LCOV version 2.1-beta