LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 71d97c4519b9017c5903db4dfe4edf4a84645500.info Lines: 21.0 % 214 45
Test Date: 2024-12-19 16:48:20 Functions: 33.3 % 21 7

            Line data    Source code
       1              : use std::io;
       2              : use std::net::SocketAddr;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use futures::{FutureExt, TryFutureExt};
       7              : use itertools::Itertools;
       8              : use once_cell::sync::OnceCell;
       9              : use postgres_client::tls::MakeTlsConnect;
      10              : use postgres_client::{CancelToken, RawConnection};
      11              : use postgres_protocol::message::backend::NoticeResponseBody;
      12              : use pq_proto::StartupMessageParams;
      13              : use rustls::crypto::ring;
      14              : use rustls::pki_types::InvalidDnsNameError;
      15              : use thiserror::Error;
      16              : use tokio::net::TcpStream;
      17              : use tracing::{debug, error, info, warn};
      18              : 
      19              : use crate::auth::parse_endpoint_param;
      20              : use crate::cancellation::CancelClosure;
      21              : use crate::context::RequestContext;
      22              : use crate::control_plane::client::ApiLockError;
      23              : use crate::control_plane::errors::WakeComputeError;
      24              : use crate::control_plane::messages::MetricsAuxInfo;
      25              : use crate::error::{ReportableError, UserFacingError};
      26              : use crate::metrics::{Metrics, NumDbConnectionsGuard};
      27              : use crate::postgres_rustls::MakeRustlsConnect;
      28              : use crate::proxy::neon_option;
      29              : use crate::types::Host;
      30              : 
      31              : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      32              : 
      33              : #[derive(Debug, Error)]
      34              : pub(crate) enum ConnectionError {
      35              :     /// This error doesn't seem to reveal any secrets; for instance,
      36              :     /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
      37              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      38              :     Postgres(#[from] postgres_client::Error),
      39              : 
      40              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      41              :     CouldNotConnect(#[from] io::Error),
      42              : 
      43              :     #[error("Couldn't load native TLS certificates: {0:?}")]
      44              :     TlsCertificateError(Vec<rustls_native_certs::Error>),
      45              : 
      46              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      47              :     TlsError(#[from] InvalidDnsNameError),
      48              : 
      49              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      50              :     WakeComputeError(#[from] WakeComputeError),
      51              : 
      52              :     #[error("error acquiring resource permit: {0}")]
      53              :     TooManyConnectionAttempts(#[from] ApiLockError),
      54              : }
      55              : 
      56              : impl UserFacingError for ConnectionError {
      57            0 :     fn to_string_client(&self) -> String {
      58            0 :         match self {
      59              :             // This helps us drop irrelevant library-specific prefixes.
      60              :             // TODO: propagate severity level and other parameters.
      61            0 :             ConnectionError::Postgres(err) => match err.as_db_error() {
      62            0 :                 Some(err) => {
      63            0 :                     let msg = err.message();
      64            0 : 
      65            0 :                     if msg.starts_with("unsupported startup parameter: ")
      66            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      67              :                     {
      68            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      69              :                     } else {
      70            0 :                         msg.to_owned()
      71              :                     }
      72              :                 }
      73            0 :                 None => err.to_string(),
      74              :             },
      75            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      76              :             ConnectionError::TooManyConnectionAttempts(_) => {
      77            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      78              :             }
      79            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      80              :         }
      81            0 :     }
      82              : }
      83              : 
      84              : impl ReportableError for ConnectionError {
      85            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      86            0 :         match self {
      87            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      88            0 :                 crate::error::ErrorKind::Postgres
      89              :             }
      90            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      91            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      92            0 :             ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
      93            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      94            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      95            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      96              :         }
      97            0 :     }
      98              : }
      99              : 
     100              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
     101              : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
     102              : 
     103              : /// A config for establishing a connection to compute node.
     104              : /// Eventually, `postgres_client` will be replaced with something better.
     105              : /// Newtype allows us to implement methods on top of it.
     106              : #[derive(Clone)]
     107              : pub(crate) struct ConnCfg(Box<postgres_client::Config>);
     108              : 
     109              : /// Creation and initialization routines.
     110              : impl ConnCfg {
     111           10 :     pub(crate) fn new(host: String, port: u16) -> Self {
     112           10 :         Self(Box::new(postgres_client::Config::new(host, port)))
     113           10 :     }
     114              : 
     115              :     /// Reuse password or auth keys from the other config.
     116            4 :     pub(crate) fn reuse_password(&mut self, other: Self) {
     117            4 :         if let Some(password) = other.get_password() {
     118            4 :             self.password(password);
     119            4 :         }
     120              : 
     121            4 :         if let Some(keys) = other.get_auth_keys() {
     122            0 :             self.auth_keys(keys);
     123            4 :         }
     124            4 :     }
     125              : 
     126            0 :     pub(crate) fn get_host(&self) -> Host {
     127            0 :         match self.0.get_host() {
     128            0 :             postgres_client::config::Host::Tcp(s) => s.into(),
     129            0 :         }
     130            0 :     }
     131              : 
     132              :     /// Apply startup message params to the connection config.
     133            0 :     pub(crate) fn set_startup_params(
     134            0 :         &mut self,
     135            0 :         params: &StartupMessageParams,
     136            0 :         arbitrary_params: bool,
     137            0 :     ) {
     138            0 :         if !arbitrary_params {
     139            0 :             self.set_param("client_encoding", "UTF8");
     140            0 :         }
     141            0 :         for (k, v) in params.iter() {
     142            0 :             match k {
     143              :                 // Only set `user` if it's not present in the config.
     144              :                 // Console redirect auth flow takes username from the console's response.
     145            0 :                 "user" if self.user_is_set() => continue,
     146            0 :                 "database" if self.db_is_set() => continue,
     147            0 :                 "options" => {
     148            0 :                     if let Some(options) = filtered_options(v) {
     149            0 :                         self.set_param(k, &options);
     150            0 :                     }
     151              :                 }
     152            0 :                 "user" | "database" | "application_name" | "replication" => {
     153            0 :                     self.set_param(k, v);
     154            0 :                 }
     155              : 
     156              :                 // if we allow arbitrary params, then we forward them through.
     157              :                 // this is a flag for a period of backwards compatibility
     158            0 :                 k if arbitrary_params => {
     159            0 :                     self.set_param(k, v);
     160            0 :                 }
     161            0 :                 _ => {}
     162              :             }
     163              :         }
     164            0 :     }
     165              : }
     166              : 
     167              : impl std::ops::Deref for ConnCfg {
     168              :     type Target = postgres_client::Config;
     169              : 
     170            8 :     fn deref(&self) -> &Self::Target {
     171            8 :         &self.0
     172            8 :     }
     173              : }
     174              : 
     175              : /// For now, let's make it easier to set up the config.
     176              : impl std::ops::DerefMut for ConnCfg {
     177           10 :     fn deref_mut(&mut self) -> &mut Self::Target {
     178           10 :         &mut self.0
     179           10 :     }
     180              : }
     181              : 
     182              : impl ConnCfg {
     183              :     /// Establish a raw TCP connection to the compute node.
     184            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     185              :         use postgres_client::config::Host;
     186              : 
     187              :         // wrap TcpStream::connect with timeout
     188            0 :         let connect_with_timeout = |host, port| {
     189            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     190            0 :                 move |res| match res {
     191            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     192            0 :                     Err(_) => Err(io::Error::new(
     193            0 :                         io::ErrorKind::TimedOut,
     194            0 :                         format!("exceeded connection timeout {timeout:?}"),
     195            0 :                     )),
     196            0 :                 },
     197            0 :             )
     198            0 :         };
     199              : 
     200            0 :         let connect_once = |host, port| {
     201            0 :             debug!("trying to connect to compute node at {host}:{port}");
     202            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     203            0 :                 let socket_addr = socket.peer_addr()?;
     204              :                 // This prevents load balancer from severing the connection.
     205            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     206            0 :                 Ok((socket_addr, socket))
     207            0 :             })
     208            0 :         };
     209              : 
     210              :         // We can't reuse connection establishing logic from `postgres_client` here,
     211              :         // because it has no means for extracting the underlying socket which we
     212              :         // require for our business.
     213            0 :         let port = self.0.get_port();
     214            0 :         let host = self.0.get_host();
     215            0 : 
     216            0 :         let host = match host {
     217            0 :             Host::Tcp(host) => host.as_str(),
     218            0 :         };
     219            0 : 
     220            0 :         match connect_once(host, port).await {
     221            0 :             Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
     222            0 :             Err(err) => {
     223            0 :                 warn!("couldn't connect to compute node at {host}:{port}: {err}");
     224            0 :                 Err(err)
     225              :             }
     226              :         }
     227            0 :     }
     228              : }
     229              : 
     230              : type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
     231              : 
     232              : pub(crate) struct PostgresConnection {
     233              :     /// Socket connected to a compute node.
     234              :     pub(crate) stream:
     235              :         postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
     236              :     /// PostgreSQL connection parameters.
     237              :     pub(crate) params: std::collections::HashMap<String, String>,
     238              :     /// Query cancellation token.
     239              :     pub(crate) cancel_closure: CancelClosure,
     240              :     /// Labels for proxy's metrics.
     241              :     pub(crate) aux: MetricsAuxInfo,
     242              :     /// Notices received from compute after authenticating
     243              :     pub(crate) delayed_notice: Vec<NoticeResponseBody>,
     244              : 
     245              :     _guage: NumDbConnectionsGuard<'static>,
     246              : }
     247              : 
     248              : impl ConnCfg {
     249              :     /// Connect to a corresponding compute node.
     250            0 :     pub(crate) async fn connect(
     251            0 :         &self,
     252            0 :         ctx: &RequestContext,
     253            0 :         aux: MetricsAuxInfo,
     254            0 :         timeout: Duration,
     255            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     256            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     257            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     258            0 :         drop(pause);
     259              : 
     260            0 :         let root_store = TLS_ROOTS
     261            0 :             .get_or_try_init(load_certs)
     262            0 :             .map_err(ConnectionError::TlsCertificateError)?
     263            0 :             .clone();
     264            0 : 
     265            0 :         let client_config =
     266            0 :             rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
     267            0 :                 .with_safe_default_protocol_versions()
     268            0 :                 .expect("ring should support the default protocol versions")
     269            0 :                 .with_root_certificates(root_store)
     270            0 :                 .with_no_client_auth();
     271            0 : 
     272            0 :         let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
     273            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     274            0 :             &mut mk_tls,
     275            0 :             host,
     276            0 :         )?;
     277              : 
     278              :         // connect_raw() will not use TLS if sslmode is "disable"
     279            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     280            0 :         let connection = self.0.connect_raw(stream, tls).await?;
     281            0 :         drop(pause);
     282            0 : 
     283            0 :         let RawConnection {
     284            0 :             stream,
     285            0 :             parameters,
     286            0 :             delayed_notice,
     287            0 :             process_id,
     288            0 :             secret_key,
     289            0 :         } = connection;
     290            0 : 
     291            0 :         tracing::Span::current().record("pid", tracing::field::display(process_id));
     292            0 :         let stream = stream.into_inner();
     293            0 : 
     294            0 :         // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
     295            0 :         info!(
     296            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     297            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     298            0 :             self.0.get_ssl_mode()
     299              :         );
     300              : 
     301              :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     302              :         // Yet another reason to rework the connection establishing code.
     303            0 :         let cancel_closure = CancelClosure::new(
     304            0 :             socket_addr,
     305            0 :             CancelToken {
     306            0 :                 socket_config: None,
     307            0 :                 ssl_mode: self.0.get_ssl_mode(),
     308            0 :                 process_id,
     309            0 :                 secret_key,
     310            0 :             },
     311            0 :             vec![],
     312            0 :             host.to_string(),
     313            0 :         );
     314            0 : 
     315            0 :         let connection = PostgresConnection {
     316            0 :             stream,
     317            0 :             params: parameters,
     318            0 :             delayed_notice,
     319            0 :             cancel_closure,
     320            0 :             aux,
     321            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     322            0 :         };
     323            0 : 
     324            0 :         Ok(connection)
     325            0 :     }
     326              : }
     327              : 
     328              : /// Retrieve `options` from a startup message, dropping all proxy-specific flags.
     329            6 : fn filtered_options(options: &str) -> Option<String> {
     330            6 :     #[allow(unstable_name_collisions)]
     331            6 :     let options: String = StartupMessageParams::parse_options_raw(options)
     332           14 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     333            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     334            6 :         .collect();
     335            6 : 
     336            6 :     // Don't even bother with empty options.
     337            6 :     if options.is_empty() {
     338            3 :         return None;
     339            3 :     }
     340            3 : 
     341            3 :     Some(options)
     342            6 : }
     343              : 
     344            0 : pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
     345            0 :     let der_certs = rustls_native_certs::load_native_certs();
     346            0 : 
     347            0 :     if !der_certs.errors.is_empty() {
     348            0 :         return Err(der_certs.errors);
     349            0 :     }
     350            0 : 
     351            0 :     let mut store = rustls::RootCertStore::empty();
     352            0 :     store.add_parsable_certificates(der_certs.certs);
     353            0 :     Ok(Arc::new(store))
     354            0 : }
     355              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     356              : 
     357              : #[cfg(test)]
     358              : mod tests {
     359              :     use super::*;
     360              : 
     361              :     #[test]
     362            1 :     fn test_filtered_options() {
     363            1 :         // Empty options is unlikely to be useful anyway.
     364            1 :         let params = "";
     365            1 :         assert_eq!(filtered_options(params), None);
     366              : 
     367              :         // It's likely that clients will only use options to specify endpoint/project.
     368            1 :         let params = "project=foo";
     369            1 :         assert_eq!(filtered_options(params), None);
     370              : 
     371              :         // Same, because unescaped whitespace is a no-op.
     372            1 :         let params = " project=foo ";
     373            1 :         assert_eq!(filtered_options(params).as_deref(), None);
     374              : 
     375            1 :         let params = r"\  project=foo \ ";
     376            1 :         assert_eq!(filtered_options(params).as_deref(), Some(r"\  \ "));
     377              : 
     378            1 :         let params = "project = foo";
     379            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     380              : 
     381            1 :         let params = "project = foo neon_endpoint_type:read_write   neon_lsn:0/2 neon_proxy_params_compat:true";
     382            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     383            1 :     }
     384              : }
        

Generated by: LCOV version 2.1-beta