LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 19.1 % 251 48
Test Date: 2024-09-20 13:14:58 Functions: 20.6 % 34 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::{Metrics, NumDbConnectionsGuard},
       8              :     proxy::neon_option,
       9              :     Host,
      10              : };
      11              : use futures::{FutureExt, TryFutureExt};
      12              : use itertools::Itertools;
      13              : use once_cell::sync::OnceCell;
      14              : use pq_proto::StartupMessageParams;
      15              : use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
      16              : use std::{io, net::SocketAddr, sync::Arc, time::Duration};
      17              : use thiserror::Error;
      18              : use tokio::net::TcpStream;
      19              : use tokio_postgres::tls::MakeTlsConnect;
      20              : use tokio_postgres_rustls::MakeRustlsConnect;
      21              : use tracing::{error, info, warn};
      22              : 
      23              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      24              : 
      25            0 : #[derive(Debug, Error)]
      26              : pub(crate) enum ConnectionError {
      27              :     /// This error doesn't seem to reveal any secrets; for instance,
      28              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     Postgres(#[from] tokio_postgres::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     CouldNotConnect(#[from] io::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     TlsError(#[from] InvalidDnsNameError),
      37              : 
      38              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      39              :     WakeComputeError(#[from] WakeComputeError),
      40              : 
      41              :     #[error("error acquiring resource permit: {0}")]
      42              :     TooManyConnectionAttempts(#[from] ApiLockError),
      43              : }
      44              : 
      45              : impl UserFacingError for ConnectionError {
      46            0 :     fn to_string_client(&self) -> String {
      47            0 :         match self {
      48              :             // This helps us drop irrelevant library-specific prefixes.
      49              :             // TODO: propagate severity level and other parameters.
      50            0 :             ConnectionError::Postgres(err) => match err.as_db_error() {
      51            0 :                 Some(err) => {
      52            0 :                     let msg = err.message();
      53            0 : 
      54            0 :                     if msg.starts_with("unsupported startup parameter: ")
      55            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      56              :                     {
      57            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      58              :                     } else {
      59            0 :                         msg.to_owned()
      60              :                     }
      61              :                 }
      62            0 :                 None => err.to_string(),
      63              :             },
      64            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      65              :             ConnectionError::TooManyConnectionAttempts(_) => {
      66            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      67              :             }
      68            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      69              :         }
      70            0 :     }
      71              : }
      72              : 
      73              : impl ReportableError for ConnectionError {
      74            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      75            0 :         match self {
      76            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      77            0 :                 crate::error::ErrorKind::Postgres
      78              :             }
      79            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      80            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      81            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      82            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      83            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      84              :         }
      85            0 :     }
      86              : }
      87              : 
      88              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      89              : pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      90              : 
      91              : /// A config for establishing a connection to compute node.
      92              : /// Eventually, `tokio_postgres` will be replaced with something better.
      93              : /// Newtype allows us to implement methods on top of it.
      94              : #[derive(Clone, Default)]
      95              : pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
      96              : 
      97              : /// Creation and initialization routines.
      98              : impl ConnCfg {
      99           10 :     pub(crate) fn new() -> Self {
     100           10 :         Self::default()
     101           10 :     }
     102              : 
     103              :     /// Reuse password or auth keys from the other config.
     104            4 :     pub(crate) fn reuse_password(&mut self, other: Self) {
     105            4 :         if let Some(password) = other.get_password() {
     106            4 :             self.password(password);
     107            4 :         }
     108              : 
     109            4 :         if let Some(keys) = other.get_auth_keys() {
     110            0 :             self.auth_keys(keys);
     111            4 :         }
     112            4 :     }
     113              : 
     114            0 :     pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
     115            0 :         match self.0.get_hosts() {
     116            0 :             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
     117              :             // we should not have multiple address or unix addresses.
     118            0 :             _ => Err(WakeComputeError::BadComputeAddress(
     119            0 :                 "invalid compute address".into(),
     120            0 :             )),
     121              :         }
     122            0 :     }
     123              : 
     124              :     /// Apply startup message params to the connection config.
     125            0 :     pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
     126              :         // Only set `user` if it's not present in the config.
     127              :         // Web auth flow takes username from the console's response.
     128            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     129            0 :             self.user(user);
     130            0 :         }
     131              : 
     132              :         // Only set `dbname` if it's not present in the config.
     133              :         // Web auth flow takes dbname from the console's response.
     134            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     135            0 :             self.dbname(dbname);
     136            0 :         }
     137              : 
     138              :         // Don't add `options` if they were only used for specifying a project.
     139              :         // Connection pools don't support `options`, because they affect backend startup.
     140            0 :         if let Some(options) = filtered_options(params) {
     141            0 :             self.options(&options);
     142            0 :         }
     143              : 
     144            0 :         if let Some(app_name) = params.get("application_name") {
     145            0 :             self.application_name(app_name);
     146            0 :         }
     147              : 
     148              :         // TODO: This is especially ugly...
     149            0 :         if let Some(replication) = params.get("replication") {
     150              :             use tokio_postgres::config::ReplicationMode;
     151            0 :             match replication {
     152            0 :                 "true" | "on" | "yes" | "1" => {
     153            0 :                     self.replication_mode(ReplicationMode::Physical);
     154            0 :                 }
     155            0 :                 "database" => {
     156            0 :                     self.replication_mode(ReplicationMode::Logical);
     157            0 :                 }
     158            0 :                 _other => {}
     159              :             }
     160            0 :         }
     161              : 
     162              :         // TODO: extend the list of the forwarded startup parameters.
     163              :         // Currently, tokio-postgres doesn't allow us to pass
     164              :         // arbitrary parameters, but the ones above are a good start.
     165              :         //
     166              :         // This and the reverse params problem can be better addressed
     167              :         // in a bespoke connection machinery (a new library for that sake).
     168            0 :     }
     169              : }
     170              : 
     171              : impl std::ops::Deref for ConnCfg {
     172              :     type Target = tokio_postgres::Config;
     173              : 
     174            8 :     fn deref(&self) -> &Self::Target {
     175            8 :         &self.0
     176            8 :     }
     177              : }
     178              : 
     179              : /// For now, let's make it easier to setup the config.
     180              : impl std::ops::DerefMut for ConnCfg {
     181           10 :     fn deref_mut(&mut self) -> &mut Self::Target {
     182           10 :         &mut self.0
     183           10 :     }
     184              : }
     185              : 
     186              : impl ConnCfg {
     187              :     /// Establish a raw TCP connection to the compute node.
     188            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     189              :         use tokio_postgres::config::Host;
     190              : 
     191              :         // wrap TcpStream::connect with timeout
     192            0 :         let connect_with_timeout = |host, port| {
     193            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     194            0 :                 move |res| match res {
     195            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     196            0 :                     Err(_) => Err(io::Error::new(
     197            0 :                         io::ErrorKind::TimedOut,
     198            0 :                         format!("exceeded connection timeout {timeout:?}"),
     199            0 :                     )),
     200            0 :                 },
     201            0 :             )
     202            0 :         };
     203              : 
     204            0 :         let connect_once = |host, port| {
     205            0 :             info!("trying to connect to compute node at {host}:{port}");
     206            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     207            0 :                 let socket_addr = socket.peer_addr()?;
     208              :                 // This prevents load balancer from severing the connection.
     209            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     210            0 :                 Ok((socket_addr, socket))
     211            0 :             })
     212            0 :         };
     213              : 
     214              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     215              :         // because it has no means for extracting the underlying socket which we
     216              :         // require for our business.
     217            0 :         let mut connection_error = None;
     218            0 :         let ports = self.0.get_ports();
     219            0 :         let hosts = self.0.get_hosts();
     220            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     221            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     222            0 :             return Err(io::Error::new(
     223            0 :                 io::ErrorKind::Other,
     224            0 :                 format!(
     225            0 :                     "bad compute config, \
     226            0 :                      ports and hosts entries' count does not match: {:?}",
     227            0 :                     self.0
     228            0 :                 ),
     229            0 :             ));
     230            0 :         }
     231              : 
     232            0 :         for (i, host) in hosts.iter().enumerate() {
     233            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     234            0 :             let host = match host {
     235            0 :                 Host::Tcp(host) => host.as_str(),
     236            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     237              :             };
     238              : 
     239            0 :             match connect_once(host, *port).await {
     240            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     241            0 :                 Err(err) => {
     242            0 :                     // We can't throw an error here, as there might be more hosts to try.
     243            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     244            0 :                     connection_error = Some(err);
     245              :                 }
     246              :             }
     247              :         }
     248              : 
     249            0 :         Err(connection_error.unwrap_or_else(|| {
     250            0 :             io::Error::new(
     251            0 :                 io::ErrorKind::Other,
     252            0 :                 format!("bad compute config: {:?}", self.0),
     253            0 :             )
     254            0 :         }))
     255            0 :     }
     256              : }
     257              : 
     258              : pub(crate) struct PostgresConnection {
     259              :     /// Socket connected to a compute node.
     260              :     pub(crate) stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     261              :         tokio::net::TcpStream,
     262              :         tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
     263              :     >,
     264              :     /// PostgreSQL connection parameters.
     265              :     pub(crate) params: std::collections::HashMap<String, String>,
     266              :     /// Query cancellation token.
     267              :     pub(crate) cancel_closure: CancelClosure,
     268              :     /// Labels for proxy's metrics.
     269              :     pub(crate) aux: MetricsAuxInfo,
     270              : 
     271              :     _guage: NumDbConnectionsGuard<'static>,
     272              : }
     273              : 
     274              : impl ConnCfg {
     275              :     /// Connect to a corresponding compute node.
     276            0 :     pub(crate) async fn connect(
     277            0 :         &self,
     278            0 :         ctx: &RequestMonitoring,
     279            0 :         allow_self_signed_compute: bool,
     280            0 :         aux: MetricsAuxInfo,
     281            0 :         timeout: Duration,
     282            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     283            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     284            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     285            0 :         drop(pause);
     286              : 
     287            0 :         let client_config = if allow_self_signed_compute {
     288              :             // Allow all certificates for creating the connection
     289            0 :             let verifier = Arc::new(AcceptEverythingVerifier);
     290            0 :             rustls::ClientConfig::builder()
     291            0 :                 .dangerous()
     292            0 :                 .with_custom_certificate_verifier(verifier)
     293              :         } else {
     294            0 :             let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
     295            0 :             rustls::ClientConfig::builder().with_root_certificates(root_store)
     296              :         };
     297            0 :         let client_config = client_config.with_no_client_auth();
     298            0 : 
     299            0 :         let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
     300            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     301            0 :             &mut mk_tls,
     302            0 :             host,
     303            0 :         )?;
     304              : 
     305              :         // connect_raw() will not use TLS if sslmode is "disable"
     306            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     307            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     308            0 :         drop(pause);
     309            0 :         tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
     310            0 :         let stream = connection.stream.into_inner();
     311            0 : 
     312            0 :         info!(
     313            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     314            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     315            0 :             self.0.get_ssl_mode()
     316              :         );
     317              : 
     318              :         // This is very ugly but as of now there's no better way to
     319              :         // extract the connection parameters from tokio-postgres' connection.
     320              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     321            0 :         let params = connection.parameters;
     322            0 : 
     323            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     324            0 :         // Yet another reason to rework the connection establishing code.
     325            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     326            0 : 
     327            0 :         let connection = PostgresConnection {
     328            0 :             stream,
     329            0 :             params,
     330            0 :             cancel_closure,
     331            0 :             aux,
     332            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     333            0 :         };
     334            0 : 
     335            0 :         Ok(connection)
     336            0 :     }
     337              : }
     338              : 
     339              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     340            6 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     341              :     #[allow(unstable_name_collisions)]
     342            6 :     let options: String = params
     343            6 :         .options_raw()?
     344           13 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     345            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     346            6 :         .collect();
     347            6 : 
     348            6 :     // Don't even bother with empty options.
     349            6 :     if options.is_empty() {
     350            3 :         return None;
     351            3 :     }
     352            3 : 
     353            3 :     Some(options)
     354            6 : }
     355              : 
     356            0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
     357            0 :     let der_certs = rustls_native_certs::load_native_certs()?;
     358            0 :     let mut store = rustls::RootCertStore::empty();
     359            0 :     store.add_parsable_certificates(der_certs);
     360            0 :     Ok(Arc::new(store))
     361            0 : }
     362              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     363              : 
     364              : #[derive(Debug)]
     365              : struct AcceptEverythingVerifier;
     366              : impl ServerCertVerifier for AcceptEverythingVerifier {
     367            0 :     fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
     368              :         use rustls::SignatureScheme;
     369              :         // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
     370            0 :         vec![
     371            0 :             SignatureScheme::ECDSA_NISTP521_SHA512,
     372            0 :             SignatureScheme::ECDSA_NISTP384_SHA384,
     373            0 :             SignatureScheme::ECDSA_NISTP256_SHA256,
     374            0 :             SignatureScheme::RSA_PSS_SHA512,
     375            0 :             SignatureScheme::RSA_PSS_SHA384,
     376            0 :             SignatureScheme::RSA_PSS_SHA256,
     377            0 :             SignatureScheme::ED25519,
     378            0 :         ]
     379            0 :     }
     380            0 :     fn verify_server_cert(
     381            0 :         &self,
     382            0 :         _end_entity: &rustls::pki_types::CertificateDer<'_>,
     383            0 :         _intermediates: &[rustls::pki_types::CertificateDer<'_>],
     384            0 :         _server_name: &rustls::pki_types::ServerName<'_>,
     385            0 :         _ocsp_response: &[u8],
     386            0 :         _now: rustls::pki_types::UnixTime,
     387            0 :     ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
     388            0 :         Ok(rustls::client::danger::ServerCertVerified::assertion())
     389            0 :     }
     390            0 :     fn verify_tls12_signature(
     391            0 :         &self,
     392            0 :         _message: &[u8],
     393            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     394            0 :         _dss: &rustls::DigitallySignedStruct,
     395            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     396            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     397            0 :     }
     398            0 :     fn verify_tls13_signature(
     399            0 :         &self,
     400            0 :         _message: &[u8],
     401            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     402            0 :         _dss: &rustls::DigitallySignedStruct,
     403            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     404            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     405            0 :     }
     406              : }
     407              : 
     408              : #[cfg(test)]
     409              : mod tests {
     410              :     use super::*;
     411              : 
     412              :     #[test]
     413            1 :     fn test_filtered_options() {
     414            1 :         // Empty options is unlikely to be useful anyway.
     415            1 :         let params = StartupMessageParams::new([("options", "")]);
     416            1 :         assert_eq!(filtered_options(&params), None);
     417              : 
     418              :         // It's likely that clients will only use options to specify endpoint/project.
     419            1 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     420            1 :         assert_eq!(filtered_options(&params), None);
     421              : 
     422              :         // Same, because unescaped whitespaces are no-op.
     423            1 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     424            1 :         assert_eq!(filtered_options(&params).as_deref(), None);
     425              : 
     426            1 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     427            1 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     428              : 
     429            1 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     430            1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     431              : 
     432            1 :         let params = StartupMessageParams::new([(
     433            1 :             "options",
     434            1 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     435            1 :         )]);
     436            1 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     437            1 :     }
     438              : }
        

Generated by: LCOV version 2.1-beta