LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 960803fca14b2e843c565dddf575f7017d250bc3.info Lines: 18.6 % 258 48
Test Date: 2024-06-22 23:41:44 Functions: 20.6 % 34 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::{Metrics, NumDbConnectionsGuard},
       8              :     proxy::neon_option,
       9              :     Host,
      10              : };
      11              : use futures::{FutureExt, TryFutureExt};
      12              : use itertools::Itertools;
      13              : use once_cell::sync::OnceCell;
      14              : use pq_proto::StartupMessageParams;
      15              : use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
      16              : use std::{io, net::SocketAddr, sync::Arc, time::Duration};
      17              : use thiserror::Error;
      18              : use tokio::net::TcpStream;
      19              : use tokio_postgres::tls::MakeTlsConnect;
      20              : use tokio_postgres_rustls::MakeRustlsConnect;
      21              : use tracing::{error, info, warn};
      22              : 
      23              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      24              : 
      25            0 : #[derive(Debug, Error)]
      26              : pub enum ConnectionError {
      27              :     /// This error doesn't seem to reveal any secrets; for instance,
      28              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     Postgres(#[from] tokio_postgres::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     CouldNotConnect(#[from] io::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     TlsError(#[from] InvalidDnsNameError),
      37              : 
      38              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      39              :     WakeComputeError(#[from] WakeComputeError),
      40              : 
      41              :     #[error("error acquiring resource permit: {0}")]
      42              :     TooManyConnectionAttempts(#[from] ApiLockError),
      43              : }
      44              : 
      45              : impl UserFacingError for ConnectionError {
      46            0 :     fn to_string_client(&self) -> String {
      47            0 :         use ConnectionError::*;
      48            0 :         match self {
      49              :             // This helps us drop irrelevant library-specific prefixes.
      50              :             // TODO: propagate severity level and other parameters.
      51            0 :             Postgres(err) => match err.as_db_error() {
      52            0 :                 Some(err) => {
      53            0 :                     let msg = err.message();
      54            0 : 
      55            0 :                     if msg.starts_with("unsupported startup parameter: ")
      56            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      57              :                     {
      58            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      59              :                     } else {
      60            0 :                         msg.to_owned()
      61              :                     }
      62              :                 }
      63            0 :                 None => err.to_string(),
      64              :             },
      65            0 :             WakeComputeError(err) => err.to_string_client(),
      66              :             TooManyConnectionAttempts(_) => {
      67            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      68              :             }
      69            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      70              :         }
      71            0 :     }
      72              : }
      73              : 
      74              : impl ReportableError for ConnectionError {
      75            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      76            0 :         match self {
      77            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      78            0 :                 crate::error::ErrorKind::Postgres
      79              :             }
      80            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      81            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      82            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      83            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      84            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      85              :         }
      86            0 :     }
      87              : }
      88              : 
      89              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      90              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      91              : 
      92              : /// A config for establishing a connection to compute node.
      93              : /// Eventually, `tokio_postgres` will be replaced with something better.
      94              : /// Newtype allows us to implement methods on top of it.
      95              : #[derive(Clone, Default)]
      96              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      97              : 
      98              : /// Creation and initialization routines.
      99              : impl ConnCfg {
     100           20 :     pub fn new() -> Self {
     101           20 :         Self::default()
     102           20 :     }
     103              : 
     104              :     /// Reuse password or auth keys from the other config.
     105            8 :     pub fn reuse_password(&mut self, other: Self) {
     106            8 :         if let Some(password) = other.get_password() {
     107            8 :             self.password(password);
     108            8 :         }
     109              : 
     110            8 :         if let Some(keys) = other.get_auth_keys() {
     111            0 :             self.auth_keys(keys);
     112            8 :         }
     113            8 :     }
     114              : 
     115            0 :     pub fn get_host(&self) -> Result<Host, WakeComputeError> {
     116            0 :         match self.0.get_hosts() {
     117            0 :             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
     118              :             // we should not have multiple address or unix addresses.
     119            0 :             _ => Err(WakeComputeError::BadComputeAddress(
     120            0 :                 "invalid compute address".into(),
     121            0 :             )),
     122              :         }
     123            0 :     }
     124              : 
     125              :     /// Apply startup message params to the connection config.
     126            0 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     127              :         // Only set `user` if it's not present in the config.
     128              :         // Link auth flow takes username from the console's response.
     129            0 :         if let (None, Some(user)) = (self.get_user(), params.get("user")) {
     130            0 :             self.user(user);
     131            0 :         }
     132              : 
     133              :         // Only set `dbname` if it's not present in the config.
     134              :         // Link auth flow takes dbname from the console's response.
     135            0 :         if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
     136            0 :             self.dbname(dbname);
     137            0 :         }
     138              : 
     139              :         // Don't add `options` if they were only used for specifying a project.
     140              :         // Connection pools don't support `options`, because they affect backend startup.
     141            0 :         if let Some(options) = filtered_options(params) {
     142            0 :             self.options(&options);
     143            0 :         }
     144              : 
     145            0 :         if let Some(app_name) = params.get("application_name") {
     146            0 :             self.application_name(app_name);
     147            0 :         }
     148              : 
     149              :         // TODO: This is especially ugly...
     150            0 :         if let Some(replication) = params.get("replication") {
     151              :             use tokio_postgres::config::ReplicationMode;
     152            0 :             match replication {
     153            0 :                 "true" | "on" | "yes" | "1" => {
     154            0 :                     self.replication_mode(ReplicationMode::Physical);
     155            0 :                 }
     156            0 :                 "database" => {
     157            0 :                     self.replication_mode(ReplicationMode::Logical);
     158            0 :                 }
     159            0 :                 _other => {}
     160              :             }
     161            0 :         }
     162              : 
     163              :         // TODO: extend the list of the forwarded startup parameters.
     164              :         // Currently, tokio-postgres doesn't allow us to pass
     165              :         // arbitrary parameters, but the ones above are a good start.
     166              :         //
     167              :         // This and the reverse params problem can be better addressed
     168              :         // in a bespoke connection machinery (a new library for that sake).
     169            0 :     }
     170              : }
     171              : 
     172              : impl std::ops::Deref for ConnCfg {
     173              :     type Target = tokio_postgres::Config;
     174              : 
     175           16 :     fn deref(&self) -> &Self::Target {
     176           16 :         &self.0
     177           16 :     }
     178              : }
     179              : 
     180              : /// For now, let's make it easier to setup the config.
     181              : impl std::ops::DerefMut for ConnCfg {
     182           20 :     fn deref_mut(&mut self) -> &mut Self::Target {
     183           20 :         &mut self.0
     184           20 :     }
     185              : }
     186              : 
     187              : impl ConnCfg {
     188              :     /// Establish a raw TCP connection to the compute node.
     189            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     190            0 :         use tokio_postgres::config::Host;
     191            0 : 
     192            0 :         // wrap TcpStream::connect with timeout
     193            0 :         let connect_with_timeout = |host, port| {
     194            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     195            0 :                 move |res| match res {
     196            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     197            0 :                     Err(_) => Err(io::Error::new(
     198            0 :                         io::ErrorKind::TimedOut,
     199            0 :                         format!("exceeded connection timeout {timeout:?}"),
     200            0 :                     )),
     201            0 :                 },
     202            0 :             )
     203            0 :         };
     204              : 
     205            0 :         let connect_once = |host, port| {
     206            0 :             info!("trying to connect to compute node at {host}:{port}");
     207            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     208            0 :                 let socket_addr = socket.peer_addr()?;
     209            0 :                 // This prevents load balancer from severing the connection.
     210            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     211            0 :                 Ok((socket_addr, socket))
     212            0 :             })
     213            0 :         };
     214              : 
     215              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     216              :         // because it has no means for extracting the underlying socket which we
     217              :         // require for our business.
     218            0 :         let mut connection_error = None;
     219            0 :         let ports = self.0.get_ports();
     220            0 :         let hosts = self.0.get_hosts();
     221            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     222            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     223            0 :             return Err(io::Error::new(
     224            0 :                 io::ErrorKind::Other,
     225            0 :                 format!(
     226            0 :                     "bad compute config, \
     227            0 :                      ports and hosts entries' count does not match: {:?}",
     228            0 :                     self.0
     229            0 :                 ),
     230            0 :             ));
     231            0 :         }
     232              : 
     233            0 :         for (i, host) in hosts.iter().enumerate() {
     234            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     235            0 :             let host = match host {
     236            0 :                 Host::Tcp(host) => host.as_str(),
     237            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     238              :             };
     239              : 
     240            0 :             match connect_once(host, *port).await {
     241            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     242            0 :                 Err(err) => {
     243            0 :                     // We can't throw an error here, as there might be more hosts to try.
     244            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     245            0 :                     connection_error = Some(err);
     246              :                 }
     247              :             }
     248              :         }
     249              : 
     250            0 :         Err(connection_error.unwrap_or_else(|| {
     251            0 :             io::Error::new(
     252            0 :                 io::ErrorKind::Other,
     253            0 :                 format!("bad compute config: {:?}", self.0),
     254            0 :             )
     255            0 :         }))
     256            0 :     }
     257              : }
     258              : 
     259              : pub struct PostgresConnection {
     260              :     /// Socket connected to a compute node.
     261              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     262              :         tokio::net::TcpStream,
     263              :         tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
     264              :     >,
     265              :     /// PostgreSQL connection parameters.
     266              :     pub params: std::collections::HashMap<String, String>,
     267              :     /// Query cancellation token.
     268              :     pub cancel_closure: CancelClosure,
     269              :     /// Labels for proxy's metrics.
     270              :     pub aux: MetricsAuxInfo,
     271              : 
     272              :     _guage: NumDbConnectionsGuard<'static>,
     273              : }
     274              : 
     275              : impl ConnCfg {
     276              :     /// Connect to a corresponding compute node.
     277            0 :     pub async fn connect(
     278            0 :         &self,
     279            0 :         ctx: &mut RequestMonitoring,
     280            0 :         allow_self_signed_compute: bool,
     281            0 :         aux: MetricsAuxInfo,
     282            0 :         timeout: Duration,
     283            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     284            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     285            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     286            0 :         drop(pause);
     287              : 
     288            0 :         let client_config = if allow_self_signed_compute {
     289              :             // Allow all certificates for creating the connection
     290            0 :             let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
     291            0 :             rustls::ClientConfig::builder()
     292            0 :                 .dangerous()
     293            0 :                 .with_custom_certificate_verifier(verifier)
     294              :         } else {
     295            0 :             let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
     296            0 :             rustls::ClientConfig::builder().with_root_certificates(root_store)
     297              :         };
     298            0 :         let client_config = client_config.with_no_client_auth();
     299            0 : 
     300            0 :         let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
     301            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     302            0 :             &mut mk_tls,
     303            0 :             host,
     304            0 :         )?;
     305              : 
     306              :         // connect_raw() will not use TLS if sslmode is "disable"
     307            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     308            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     309            0 :         drop(pause);
     310            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     311            0 :         let stream = connection.stream.into_inner();
     312            0 : 
     313            0 :         info!(
     314            0 :             cold_start_info = ctx.cold_start_info.as_str(),
     315            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     316            0 :             self.0.get_ssl_mode()
     317              :         );
     318              : 
     319              :         // This is very ugly but as of now there's no better way to
     320              :         // extract the connection parameters from tokio-postgres' connection.
     321              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     322            0 :         let params = connection.parameters;
     323            0 : 
     324            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     325            0 :         // Yet another reason to rework the connection establishing code.
     326            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     327            0 : 
     328            0 :         let connection = PostgresConnection {
     329            0 :             stream,
     330            0 :             params,
     331            0 :             cancel_closure,
     332            0 :             aux,
     333            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
     334            0 :         };
     335            0 : 
     336            0 :         Ok(connection)
     337            0 :     }
     338              : }
     339              : 
     340              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     341           12 : fn filtered_options(params: &StartupMessageParams) -> Option<String> {
     342              :     #[allow(unstable_name_collisions)]
     343           12 :     let options: String = params
     344           12 :         .options_raw()?
     345           26 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     346           12 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     347           12 :         .collect();
     348           12 : 
     349           12 :     // Don't even bother with empty options.
     350           12 :     if options.is_empty() {
     351            6 :         return None;
     352            6 :     }
     353            6 : 
     354            6 :     Some(options)
     355           12 : }
     356              : 
     357            0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
     358            0 :     let der_certs = rustls_native_certs::load_native_certs()?;
     359            0 :     let mut store = rustls::RootCertStore::empty();
     360            0 :     store.add_parsable_certificates(der_certs);
     361            0 :     Ok(Arc::new(store))
     362            0 : }
     363              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     364              : 
     365              : #[derive(Debug)]
     366              : struct AcceptEverythingVerifier;
     367              : impl ServerCertVerifier for AcceptEverythingVerifier {
     368            0 :     fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
     369            0 :         use rustls::SignatureScheme::*;
     370            0 :         // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
     371            0 :         vec![
     372            0 :             ECDSA_NISTP521_SHA512,
     373            0 :             ECDSA_NISTP384_SHA384,
     374            0 :             ECDSA_NISTP256_SHA256,
     375            0 :             RSA_PSS_SHA512,
     376            0 :             RSA_PSS_SHA384,
     377            0 :             RSA_PSS_SHA256,
     378            0 :             ED25519,
     379            0 :         ]
     380            0 :     }
     381            0 :     fn verify_server_cert(
     382            0 :         &self,
     383            0 :         _end_entity: &rustls::pki_types::CertificateDer<'_>,
     384            0 :         _intermediates: &[rustls::pki_types::CertificateDer<'_>],
     385            0 :         _server_name: &rustls::pki_types::ServerName<'_>,
     386            0 :         _ocsp_response: &[u8],
     387            0 :         _now: rustls::pki_types::UnixTime,
     388            0 :     ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
     389            0 :         Ok(rustls::client::danger::ServerCertVerified::assertion())
     390            0 :     }
     391            0 :     fn verify_tls12_signature(
     392            0 :         &self,
     393            0 :         _message: &[u8],
     394            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     395            0 :         _dss: &rustls::DigitallySignedStruct,
     396            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     397            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     398            0 :     }
     399            0 :     fn verify_tls13_signature(
     400            0 :         &self,
     401            0 :         _message: &[u8],
     402            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     403            0 :         _dss: &rustls::DigitallySignedStruct,
     404            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     405            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     406            0 :     }
     407              : }
     408              : 
     409              : #[cfg(test)]
     410              : mod tests {
     411              :     use super::*;
     412              : 
     413              :     #[test]
     414            2 :     fn test_filtered_options() {
     415            2 :         // Empty options is unlikely to be useful anyway.
     416            2 :         let params = StartupMessageParams::new([("options", "")]);
     417            2 :         assert_eq!(filtered_options(&params), None);
     418              : 
     419              :         // It's likely that clients will only use options to specify endpoint/project.
     420            2 :         let params = StartupMessageParams::new([("options", "project=foo")]);
     421            2 :         assert_eq!(filtered_options(&params), None);
     422              : 
     423              :         // Same, because unescaped whitespaces are no-op.
     424            2 :         let params = StartupMessageParams::new([("options", " project=foo ")]);
     425            2 :         assert_eq!(filtered_options(&params).as_deref(), None);
     426              : 
     427            2 :         let params = StartupMessageParams::new([("options", r"\  project=foo \ ")]);
     428            2 :         assert_eq!(filtered_options(&params).as_deref(), Some(r"\  \ "));
     429              : 
     430            2 :         let params = StartupMessageParams::new([("options", "project = foo")]);
     431            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     432              : 
     433            2 :         let params = StartupMessageParams::new([(
     434            2 :             "options",
     435            2 :             "project = foo neon_endpoint_type:read_write   neon_lsn:0/2",
     436            2 :         )]);
     437            2 :         assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
     438            2 :     }
     439              : }
        

Generated by: LCOV version 2.1-beta