LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 18.1 % 265 48
Test Date: 2024-10-22 22:13:45 Functions: 20.6 % 34 7

            Line data    Source code
       1              : use std::io;
       2              : use std::net::SocketAddr;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use futures::{FutureExt, TryFutureExt};
       7              : use itertools::Itertools;
       8              : use once_cell::sync::OnceCell;
       9              : use pq_proto::StartupMessageParams;
      10              : use rustls::client::danger::ServerCertVerifier;
      11              : use rustls::crypto::aws_lc_rs;
      12              : use rustls::pki_types::InvalidDnsNameError;
      13              : use thiserror::Error;
      14              : use tokio::net::TcpStream;
      15              : use tokio_postgres::tls::MakeTlsConnect;
      16              : use tokio_postgres_rustls::MakeRustlsConnect;
      17              : use tracing::{error, info, warn};
      18              : 
      19              : use crate::auth::parse_endpoint_param;
      20              : use crate::cancellation::CancelClosure;
      21              : use crate::context::RequestMonitoring;
      22              : use crate::control_plane::errors::WakeComputeError;
      23              : use crate::control_plane::messages::MetricsAuxInfo;
      24              : use crate::control_plane::provider::ApiLockError;
      25              : use crate::error::{ReportableError, UserFacingError};
      26              : use crate::metrics::{Metrics, NumDbConnectionsGuard};
      27              : use crate::proxy::neon_option;
      28              : use crate::Host;
      29              : 
      30              : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      31              : 
      32            0 : #[derive(Debug, Error)]
      33              : pub(crate) enum ConnectionError {
      34              :     /// This error doesn't seem to reveal any secrets; for instance,
      35              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      36              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      37              :     Postgres(#[from] tokio_postgres::Error),
      38              : 
      39              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      40              :     CouldNotConnect(#[from] io::Error),
      41              : 
      42              :     #[error("Couldn't load native TLS certificates: {0:?}")]
      43              :     TlsCertificateError(Vec<rustls_native_certs::Error>),
      44              : 
      45              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      46              :     TlsError(#[from] InvalidDnsNameError),
      47              : 
      48              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      49              :     WakeComputeError(#[from] WakeComputeError),
      50              : 
      51              :     #[error("error acquiring resource permit: {0}")]
      52              :     TooManyConnectionAttempts(#[from] ApiLockError),
      53              : }
      54              : 
      55              : impl UserFacingError for ConnectionError {
      56            0 :     fn to_string_client(&self) -> String {
      57            0 :         match self {
      58              :             // This helps us drop irrelevant library-specific prefixes.
      59              :             // TODO: propagate severity level and other parameters.
      60            0 :             ConnectionError::Postgres(err) => match err.as_db_error() {
      61            0 :                 Some(err) => {
      62            0 :                     let msg = err.message();
      63            0 : 
      64            0 :                     if msg.starts_with("unsupported startup parameter: ")
      65            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      66              :                     {
      67            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      68              :                     } else {
      69            0 :                         msg.to_owned()
      70              :                     }
      71              :                 }
      72            0 :                 None => err.to_string(),
      73              :             },
      74            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      75              :             ConnectionError::TooManyConnectionAttempts(_) => {
      76            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      77              :             }
      78            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      79              :         }
      80            0 :     }
      81              : }
      82              : 
      83              : impl ReportableError for ConnectionError {
      84            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      85            0 :         match self {
      86            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      87            0 :                 crate::error::ErrorKind::Postgres
      88              :             }
      89            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      90            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      91            0 :             ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
      92            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      93            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      94            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      95              :         }
      96            0 :     }
      97              : }
      98              : 
      99              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
     100              : pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
     101              : 
     102              : /// A config for establishing a connection to compute node.
     103              : /// Eventually, `tokio_postgres` will be replaced with something better.
     104              : /// Newtype allows us to implement methods on top of it.
     105              : #[derive(Clone, Default)]
     106              : pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
     107              : 
     108              : /// Creation and initialization routines.
     109              : impl ConnCfg {
     110           10 :     pub(crate) fn new() -> Self {
     111           10 :         Self::default()
     112           10 :     }
     113              : 
     114              :     /// Reuse password or auth keys from the other config.
     115            4 :     pub(crate) fn reuse_password(&mut self, other: Self) {
     116            4 :         if let Some(password) = other.get_password() {
     117            4 :             self.password(password);
     118            4 :         }
     119              : 
     120            4 :         if let Some(keys) = other.get_auth_keys() {
     121            0 :             self.auth_keys(keys);
     122            4 :         }
     123            4 :     }
     124              : 
     125            0 :     pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
     126            0 :         match self.0.get_hosts() {
     127            0 :             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
     128              :             // we should not have multiple addresses or unix addresses.
     129            0 :             _ => Err(WakeComputeError::BadComputeAddress(
     130            0 :                 "invalid compute address".into(),
     131            0 :             )),
     132              :         }
     133            0 :     }
     134              : 
     135              :     /// Apply startup message params to the connection config.
     136            0 :     pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
     137              :         // Only set `user` if it's not present in the config.
     138              :         // Web auth flow takes username from the console's response.
     139            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     140            0 :             self.user(user);
     141            0 :         }
     142              : 
     143              :         // Only set `dbname` if it's not present in the config.
     144              :         // Web auth flow takes dbname from the console's response.
     145            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     146            0 :             self.dbname(dbname);
     147            0 :         }
     148              : 
     149              :         // Don't add `options` if they were only used for specifying a project.
     150              :         // Connection pools don't support `options`, because they affect backend startup.
     151            0 :         if let Some(options) = filtered_options(params) {
     152            0 :             self.options(&options);
     153            0 :         }
     154              : 
     155            0 :         if let Some(app_name) = params.get("application_name") {
     156            0 :             self.application_name(app_name);
     157            0 :         }
     158              : 
     159              :         // TODO: This is especially ugly...
     160            0 :         if let Some(replication) = params.get("replication") {
     161              :             use tokio_postgres::config::ReplicationMode;
     162            0 :             match replication {
     163            0 :                 "true" | "on" | "yes" | "1" => {
     164            0 :                     self.replication_mode(ReplicationMode::Physical);
     165            0 :                 }
     166            0 :                 "database" => {
     167            0 :                     self.replication_mode(ReplicationMode::Logical);
     168            0 :                 }
     169            0 :                 _other => {}
     170              :             }
     171            0 :         }
     172              : 
     173              :         // TODO: extend the list of the forwarded startup parameters.
     174              :         // Currently, tokio-postgres doesn't allow us to pass
     175              :         // arbitrary parameters, but the ones above are a good start.
     176              :         //
     177              :         // This and the reverse params problem can be better addressed
     178              :         // in a bespoke connection machinery (a new library for that sake).
     179            0 :     }
     180              : }
     181              : 
     182              : impl std::ops::Deref for ConnCfg {
     183              :     type Target = tokio_postgres::Config;
     184              : 
     185            8 :     fn deref(&self) -> &Self::Target {
     186            8 :         &self.0
     187            8 :     }
     188              : }
     189              : 
     190              : /// For now, let's make it easier to set up the config.
     191              : impl std::ops::DerefMut for ConnCfg {
     192           10 :     fn deref_mut(&mut self) -> &mut Self::Target {
     193           10 :         &mut self.0
     194           10 :     }
     195              : }
     196              : 
     197              : impl ConnCfg {
     198              :     /// Establish a raw TCP connection to the compute node.
     199            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     200              :         use tokio_postgres::config::Host;
     201              : 
     202              :         // wrap TcpStream::connect with timeout
     203            0 :         let connect_with_timeout = |host, port| {
     204            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     205            0 :                 move |res| match res {
     206            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     207            0 :                     Err(_) => Err(io::Error::new(
     208            0 :                         io::ErrorKind::TimedOut,
     209            0 :                         format!("exceeded connection timeout {timeout:?}"),
     210            0 :                     )),
     211            0 :                 },
     212            0 :             )
     213            0 :         };
     214              : 
     215            0 :         let connect_once = |host, port| {
     216            0 :             info!("trying to connect to compute node at {host}:{port}");
     217            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     218            0 :                 let socket_addr = socket.peer_addr()?;
     219              :                 // This prevents the load balancer from severing the connection.
     220            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     221            0 :                 Ok((socket_addr, socket))
     222            0 :             })
     223            0 :         };
     224              : 
     225              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     226              :         // because it has no means for extracting the underlying socket which we
     227              :         // require for our business.
     228            0 :         let mut connection_error = None;
     229            0 :         let ports = self.0.get_ports();
     230            0 :         let hosts = self.0.get_hosts();
     231            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     232            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     233            0 :             return Err(io::Error::new(
     234            0 :                 io::ErrorKind::Other,
     235            0 :                 format!(
     236            0 :                     "bad compute config, \
     237            0 :                      ports and hosts entries' count does not match: {:?}",
     238            0 :                     self.0
     239            0 :                 ),
     240            0 :             ));
     241            0 :         }
     242              : 
     243            0 :         for (i, host) in hosts.iter().enumerate() {
     244            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     245            0 :             let host = match host {
     246            0 :                 Host::Tcp(host) => host.as_str(),
     247            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     248              :             };
     249              : 
     250            0 :             match connect_once(host, *port).await {
     251            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     252            0 :                 Err(err) => {
     253            0 :                     // We can't throw an error here, as there might be more hosts to try.
     254            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     255            0 :                     connection_error = Some(err);
     256              :                 }
     257              :             }
     258              :         }
     259              : 
     260            0 :         Err(connection_error.unwrap_or_else(|| {
     261            0 :             io::Error::new(
     262            0 :                 io::ErrorKind::Other,
     263            0 :                 format!("bad compute config: {:?}", self.0),
     264            0 :             )
     265            0 :         }))
     266            0 :     }
     267              : }
     268              : 
     269              : pub(crate) struct PostgresConnection {
     270              :     /// Socket connected to a compute node.
     271              :     pub(crate) stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     272              :         tokio::net::TcpStream,
     273              :         tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
     274              :     >,
     275              :     /// PostgreSQL connection parameters.
     276              :     pub(crate) params: std::collections::HashMap<String, String>,
     277              :     /// Query cancellation token.
     278              :     pub(crate) cancel_closure: CancelClosure,
     279              :     /// Labels for proxy's metrics.
     280              :     pub(crate) aux: MetricsAuxInfo,
     281              : 
     282              :     _guage: NumDbConnectionsGuard<'static>,
     283              : }
     284              : 
     285              : impl ConnCfg {
     286              :     /// Connect to a corresponding compute node.
     287            0 :     pub(crate) async fn connect(
     288            0 :         &self,
     289            0 :         ctx: &RequestMonitoring,
     290            0 :         allow_self_signed_compute: bool,
     291            0 :         aux: MetricsAuxInfo,
     292            0 :         timeout: Duration,
     293            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     294            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     295            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     296            0 :         drop(pause);
     297              : 
     298            0 :         let client_config = if allow_self_signed_compute {
     299              :             // Allow all certificates for creating the connection
     300            0 :             let verifier = Arc::new(AcceptEverythingVerifier);
     301            0 :             rustls::ClientConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider()))
     302            0 :                 .with_safe_default_protocol_versions()
     303            0 :                 .expect("aws_lc_rs should support the default protocol versions")
     304            0 :                 .dangerous()
     305            0 :                 .with_custom_certificate_verifier(verifier)
     306              :         } else {
     307            0 :             let root_store = TLS_ROOTS
     308            0 :                 .get_or_try_init(load_certs)
     309            0 :                 .map_err(ConnectionError::TlsCertificateError)?
     310            0 :                 .clone();
     311            0 :             rustls::ClientConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider()))
     312            0 :                 .with_safe_default_protocol_versions()
     313            0 :                 .expect("aws_lc_rs should support the default protocol versions")
     314            0 :                 .with_root_certificates(root_store)
     315              :         };
     316            0 :         let client_config = client_config.with_no_client_auth();
     317            0 : 
     318            0 :         let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
     319            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     320            0 :             &mut mk_tls,
     321            0 :             host,
     322            0 :         )?;
     323              : 
     324              :         // connect_raw() will not use TLS if sslmode is "disable"
     325            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     326            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     327            0 :         drop(pause);
     328            0 :         tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
     329            0 :         let stream = connection.stream.into_inner();
     330            0 : 
     331            0 :         info!(
     332            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     333            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     334            0 :             self.0.get_ssl_mode()
     335              :         );
     336              : 
     337              :         // This is very ugly but as of now there's no better way to
     338              :         // extract the connection parameters from tokio-postgres' connection.
     339              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     340            0 :         let params = connection.parameters;
     341            0 : 
     342            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     343            0 :         // Yet another reason to rework the connection establishing code.
     344            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     345            0 : 
     346            0 :         let connection = PostgresConnection {
     347            0 :             stream,
     348            0 :             params,
     349            0 :             cancel_closure,
     350            0 :             aux,
     351            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     352            0 :         };
     353            0 : 
     354            0 :         Ok(connection)
     355            0 :     }
     356              : }
     357              : 
     358              : /// Retrieve `options` from a startup message, dropping all proxy-specific flags.
     359            6 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     360              :     #[allow(unstable_name_collisions)]
     361            6 :     let options: String = params
     362            6 :         .options_raw()?
     363           13 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     364            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     365            6 :         .collect();
     366            6 : 
     367            6 :     // Don't even bother with empty options.
     368            6 :     if options.is_empty() {
     369            3 :         return None;
     370            3 :     }
     371            3 : 
     372            3 :     Some(options)
     373            6 : }
     374              : 
     375            0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
     376            0 :     let der_certs = rustls_native_certs::load_native_certs();
     377            0 : 
     378            0 :     if !der_certs.errors.is_empty() {
     379            0 :         return Err(der_certs.errors);
     380            0 :     }
     381            0 : 
     382            0 :     let mut store = rustls::RootCertStore::empty();
     383            0 :     store.add_parsable_certificates(der_certs.certs);
     384            0 :     Ok(Arc::new(store))
     385            0 : }
     386              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     387              : 
     388              : #[derive(Debug)]
     389              : struct AcceptEverythingVerifier;
     390              : impl ServerCertVerifier for AcceptEverythingVerifier {
     391            0 :     fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
     392              :         use rustls::SignatureScheme;
     393              :         // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
     394            0 :         vec![
     395            0 :             SignatureScheme::ECDSA_NISTP521_SHA512,
     396            0 :             SignatureScheme::ECDSA_NISTP384_SHA384,
     397            0 :             SignatureScheme::ECDSA_NISTP256_SHA256,
     398            0 :             SignatureScheme::RSA_PSS_SHA512,
     399            0 :             SignatureScheme::RSA_PSS_SHA384,
     400            0 :             SignatureScheme::RSA_PSS_SHA256,
     401            0 :             SignatureScheme::ED25519,
     402            0 :         ]
     403            0 :     }
     404            0 :     fn verify_server_cert(
     405            0 :         &self,
     406            0 :         _end_entity: &rustls::pki_types::CertificateDer<'_>,
     407            0 :         _intermediates: &[rustls::pki_types::CertificateDer<'_>],
     408            0 :         _server_name: &rustls::pki_types::ServerName<'_>,
     409            0 :         _ocsp_response: &[u8],
     410            0 :         _now: rustls::pki_types::UnixTime,
     411            0 :     ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
     412            0 :         Ok(rustls::client::danger::ServerCertVerified::assertion())
     413            0 :     }
     414            0 :     fn verify_tls12_signature(
     415            0 :         &self,
     416            0 :         _message: &[u8],
     417            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     418            0 :         _dss: &rustls::DigitallySignedStruct,
     419            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     420            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     421            0 :     }
     422            0 :     fn verify_tls13_signature(
     423            0 :         &self,
     424            0 :         _message: &[u8],
     425            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     426            0 :         _dss: &rustls::DigitallySignedStruct,
     427            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     428            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     429            0 :     }
     430              : }
     431              : 
     432              : #[cfg(test)]
     433              : mod tests {
     434              :     use super::*;
     435              : 
     436              :     #[test]
     437            1 :     fn test_filtered_options() {
     438            1 :         // Empty options is unlikely to be useful anyway.
     439            1 :         let params = StartupMessageParams::new([("options", "")]);
     440            1 :         assert_eq!(filtered_options(&params), None);
     441              : 
     442              :         // It's likely that clients will only use options to specify endpoint/project.
     443            1 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     444            1 :         assert_eq!(filtered_options(&params), None);
     445              : 
     446              :         // Same, because unescaped whitespace is a no-op.
     447            1 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     448            1 :         assert_eq!(filtered_options(&params).as_deref(), None);
     449              : 
     450            1 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     451            1 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     452              : 
     453            1 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     454            1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     455              : 
     456            1 :         let params = StartupMessageParams::new([(
     457            1 :             "options",
     458            1 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     459            1 :         )]);
     460            1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     461            1 :     }
     462              : }
        

Generated by: LCOV version 2.1-beta