LCOV - code coverage report
Current view: top level - proxy/src - compute.rs (source / functions) Coverage Total Hit
Test: 465a86b0c1fda0069b3e0f6c1c126e6b635a1f72.info Lines: 15.4 % 272 42
Test Date: 2024-06-25 15:47:26 Functions: 20.6 % 34 7

            Line data    Source code
       1              : use crate::{
       2              :     auth::parse_endpoint_param,
       3              :     cancellation::CancelClosure,
       4              :     console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     metrics::{Metrics, NumDbConnectionsGuard},
       8              :     proxy::neon_option,
       9              :     Host,
      10              : };
      11              : use futures::{FutureExt, TryFutureExt};
      12              : use itertools::Itertools;
      13              : use once_cell::sync::OnceCell;
      14              : use pq_proto::StartupMessageParams;
      15              : use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
      16              : use std::{io, net::SocketAddr, sync::Arc, time::Duration};
      17              : use thiserror::Error;
      18              : use tokio::net::TcpStream;
      19              : use tokio_postgres::tls::MakeTlsConnect;
      20              : use tokio_postgres_rustls::MakeRustlsConnect;
      21              : use tracing::{error, info, warn};
      22              : 
      23              : const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      24              : 
      25            0 : #[derive(Debug, Error)]
      26              : pub enum ConnectionError {
      27              :     /// This error doesn't seem to reveal any secrets; for instance,
      28              :     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
      29              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      30              :     Postgres(#[from] tokio_postgres::Error),
      31              : 
      32              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      33              :     CouldNotConnect(#[from] io::Error),
      34              : 
      35              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      36              :     TlsError(#[from] InvalidDnsNameError),
      37              : 
      38              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      39              :     WakeComputeError(#[from] WakeComputeError),
      40              : 
      41              :     #[error("error acquiring resource permit: {0}")]
      42              :     TooManyConnectionAttempts(#[from] ApiLockError),
      43              : }
      44              : 
      45              : impl UserFacingError for ConnectionError {
      46            0 :     fn to_string_client(&self) -> String {
      47            0 :         use ConnectionError::*;
      48            0 :         match self {
      49              :             // This helps us drop irrelevant library-specific prefixes.
      50              :             // TODO: propagate severity level and other parameters.
      51            0 :             Postgres(err) => match err.as_db_error() {
      52            0 :                 Some(err) => {
      53            0 :                     let msg = err.message();
      54            0 : 
      55            0 :                     if msg.starts_with("unsupported startup parameter: ")
      56            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      57              :                     {
      58            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      59              :                     } else {
      60            0 :                         msg.to_owned()
      61              :                     }
      62              :                 }
      63            0 :                 None => err.to_string(),
      64              :             },
      65            0 :             WakeComputeError(err) => err.to_string_client(),
      66              :             TooManyConnectionAttempts(_) => {
      67            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      68              :             }
      69            0 :             _ => COULD_NOT_CONNECT.to_owned(),
      70              :         }
      71            0 :     }
      72              : }
      73              : 
      74              : impl ReportableError for ConnectionError {
      75            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      76            0 :         match self {
      77            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      78            0 :                 crate::error::ErrorKind::Postgres
      79              :             }
      80            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      81            0 :             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
      82            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      83            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      84            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      85              :         }
      86            0 :     }
      87              : }
      88              : 
      89              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      90              : pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
      91              : 
      92              : /// A config for establishing a connection to compute node.
      93              : /// Eventually, `tokio_postgres` will be replaced with something better.
      94              : /// Newtype allows us to implement methods on top of it.
      95              : #[derive(Clone, Default)]
      96              : pub struct ConnCfg(Box<tokio_postgres::Config>);
      97              : 
      98              : /// Creation and initialization routines.
      99              : impl ConnCfg {
     100           20 :     pub fn new() -> Self {
     101           20 :         Self::default()
     102           20 :     }
     103              : 
     104              :     /// Reuse password or auth keys from the other config.
     105            8 :     pub fn reuse_password(&mut self, other: Self) {
     106            8 :         if let Some(password) = other.get_auth() {
     107            8 :             self.auth(password);
     108            8 :         }
     109            8 :     }
     110              : 
     111            0 :     pub fn get_host(&self) -> Result<Host, WakeComputeError> {
     112            0 :         match self.0.get_hosts() {
     113            0 :             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
     114              :             // we should not have multiple addresses or unix addresses.
     115            0 :             _ => Err(WakeComputeError::BadComputeAddress(
     116            0 :                 "invalid compute address".into(),
     117            0 :             )),
     118              :         }
     119            0 :     }
     120              : 
     121              :     /// Apply startup message params to the connection config.
     122            0 :     pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
     123            0 :         let mut client_encoding = false;
     124            0 :         for (k, v) in params.iter() {
     125            0 :             match k {
     126            0 :                 "user" => {
     127              :                     // Only set `user` if it's not present in the config.
     128              :                     // Link auth flow takes username from the console's response.
     129            0 :                     if self.get_user().is_none() {
     130            0 :                         self.user(v);
     131            0 :                     }
     132              :                 }
     133            0 :                 "database" => {
     134              :                     // Only set `dbname` if it's not present in the config.
     135              :                     // Link auth flow takes dbname from the console's response.
     136            0 :                     if self.get_dbname().is_none() {
     137            0 :                         self.dbname(v);
     138            0 :                     }
     139              :                 }
     140            0 :                 "options" => {
     141              :                     // Don't add `options` if they were only used for specifying a project.
     142              :                     // Connection pools don't support `options`, because they affect backend startup.
     143            0 :                     if let Some(options) = filtered_options(v) {
     144            0 :                         self.options(&options);
     145            0 :                     }
     146              :                 }
     147              : 
     148              :                 // the special ones in tokio-postgres that we don't want being set by the user
     149            0 :                 "dbname" => {}
     150            0 :                 "password" => {}
     151            0 :                 "sslmode" => {}
     152            0 :                 "host" => {}
     153            0 :                 "port" => {}
     154            0 :                 "connect_timeout" => {}
     155            0 :                 "keepalives" => {}
     156            0 :                 "keepalives_idle" => {}
     157            0 :                 "keepalives_interval" => {}
     158            0 :                 "keepalives_retries" => {}
     159            0 :                 "target_session_attrs" => {}
     160            0 :                 "channel_binding" => {}
     161            0 :                 "max_backend_message_size" => {}
     162              : 
     163            0 :                 "client_encoding" => {
     164            0 :                     client_encoding = true;
     165            0 :                     // only error should be from bad null bytes,
     166            0 :                     // but we've already checked for those.
     167            0 :                     _ = self.param("client_encoding", v);
     168            0 :                 }
     169              : 
     170            0 :                 _ => {
     171            0 :                     // only error should be from bad null bytes,
     172            0 :                     // but we've already checked for those.
     173            0 :                     _ = self.param(k, v);
     174            0 :                 }
     175              :             }
     176              :         }
     177            0 :         if !client_encoding {
     178            0 :             // for compatibility since we removed it from tokio-postgres
     179            0 :             self.param("client_encoding", "UTF8").unwrap();
     180            0 :         }
     181            0 :     }
     182              : }
     183              : 
     184              : impl std::ops::Deref for ConnCfg {
     185              :     type Target = tokio_postgres::Config;
     186              : 
     187            8 :     fn deref(&self) -> &Self::Target {
     188            8 :         &self.0
     189            8 :     }
     190              : }
     191              : 
     192              : /// For now, let's make it easier to setup the config.
     193              : impl std::ops::DerefMut for ConnCfg {
     194           20 :     fn deref_mut(&mut self) -> &mut Self::Target {
     195           20 :         &mut self.0
     196           20 :     }
     197              : }
     198              : 
     199              : impl ConnCfg {
     200              :     /// Establish a raw TCP connection to the compute node.
     201            0 :     async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
     202            0 :         use tokio_postgres::config::Host;
     203            0 : 
     204            0 :         // wrap TcpStream::connect with timeout
     205            0 :         let connect_with_timeout = |host, port| {
     206            0 :             tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
     207            0 :                 move |res| match res {
     208            0 :                     Ok(tcpstream_connect_res) => tcpstream_connect_res,
     209            0 :                     Err(_) => Err(io::Error::new(
     210            0 :                         io::ErrorKind::TimedOut,
     211            0 :                         format!("exceeded connection timeout {timeout:?}"),
     212            0 :                     )),
     213            0 :                 },
     214            0 :             )
     215            0 :         };
     216              : 
     217            0 :         let connect_once = |host, port| {
     218            0 :             info!("trying to connect to compute node at {host}:{port}");
     219            0 :             connect_with_timeout(host, port).and_then(|socket| async {
     220            0 :                 let socket_addr = socket.peer_addr()?;
     221            0 :                 // This prevents load balancer from severing the connection.
     222            0 :                 socket2::SockRef::from(&socket).set_keepalive(true)?;
     223            0 :                 Ok((socket_addr, socket))
     224            0 :             })
     225            0 :         };
     226              : 
     227              :         // We can't reuse connection establishing logic from `tokio_postgres` here,
     228              :         // because it has no means for extracting the underlying socket which we
     229              :         // require for our business.
     230            0 :         let mut connection_error = None;
     231            0 :         let ports = self.0.get_ports();
     232            0 :         let hosts = self.0.get_hosts();
     233            0 :         // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
     234            0 :         if ports.len() > 1 && ports.len() != hosts.len() {
     235            0 :             return Err(io::Error::new(
     236            0 :                 io::ErrorKind::Other,
     237            0 :                 format!(
     238            0 :                     "bad compute config, \
     239            0 :                      ports and hosts entries' count does not match: {:?}",
     240            0 :                     self.0
     241            0 :                 ),
     242            0 :             ));
     243            0 :         }
     244              : 
     245            0 :         for (i, host) in hosts.iter().enumerate() {
     246            0 :             let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
     247            0 :             let host = match host {
     248            0 :                 Host::Tcp(host) => host.as_str(),
     249            0 :                 Host::Unix(_) => continue, // unix sockets are not welcome here
     250              :             };
     251              : 
     252            0 :             match connect_once(host, *port).await {
     253            0 :                 Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
     254            0 :                 Err(err) => {
     255            0 :                     // We can't throw an error here, as there might be more hosts to try.
     256            0 :                     warn!("couldn't connect to compute node at {host}:{port}: {err}");
     257            0 :                     connection_error = Some(err);
     258              :                 }
     259              :             }
     260              :         }
     261              : 
     262            0 :         Err(connection_error.unwrap_or_else(|| {
     263            0 :             io::Error::new(
     264            0 :                 io::ErrorKind::Other,
     265            0 :                 format!("bad compute config: {:?}", self.0),
     266            0 :             )
     267            0 :         }))
     268            0 :     }
     269              : }
     270              : 
     271              : pub struct PostgresConnection {
     272              :     /// Socket connected to a compute node.
     273              :     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
     274              :         tokio::net::TcpStream,
     275              :         tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
     276              :     >,
     277              :     /// PostgreSQL connection parameters.
     278              :     pub params: std::collections::HashMap<String, String>,
     279              :     /// Query cancellation token.
     280              :     pub cancel_closure: CancelClosure,
     281              :     /// Labels for proxy's metrics.
     282              :     pub aux: MetricsAuxInfo,
     283              : 
     284              :     _guage: NumDbConnectionsGuard<'static>,
     285              : }
     286              : 
     287              : impl ConnCfg {
     288              :     /// Connect to a corresponding compute node.
     289            0 :     pub async fn connect(
     290            0 :         &self,
     291            0 :         ctx: &mut RequestMonitoring,
     292            0 :         allow_self_signed_compute: bool,
     293            0 :         aux: MetricsAuxInfo,
     294            0 :         timeout: Duration,
     295            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     296            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     297            0 :         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
     298            0 :         drop(pause);
     299              : 
     300            0 :         let client_config = if allow_self_signed_compute {
     301              :             // Allow all certificates for creating the connection
     302            0 :             let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
     303            0 :             rustls::ClientConfig::builder()
     304            0 :                 .dangerous()
     305            0 :                 .with_custom_certificate_verifier(verifier)
     306              :         } else {
     307            0 :             let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
     308            0 :             rustls::ClientConfig::builder().with_root_certificates(root_store)
     309              :         };
     310            0 :         let client_config = client_config.with_no_client_auth();
     311            0 : 
     312            0 :         let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
     313            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     314            0 :             &mut mk_tls,
     315            0 :             host,
     316            0 :         )?;
     317              : 
     318              :         // connect_raw() will not use TLS if sslmode is "disable"
     319            0 :         let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
     320            0 :         let (client, connection) = self.0.connect_raw(stream, tls).await?;
     321            0 :         drop(pause);
     322            0 :         tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
     323            0 :         let stream = connection.stream.into_inner();
     324            0 : 
     325            0 :         info!(
     326            0 :             cold_start_info = ctx.cold_start_info.as_str(),
     327            0 :             "connected to compute node at {host} ({socket_addr}) sslmode={:?}",
     328            0 :             self.0.get_ssl_mode()
     329              :         );
     330              : 
     331              :         // This is very ugly but as of now there's no better way to
     332              :         // extract the connection parameters from tokio-postgres' connection.
     333              :         // TODO: solve this problem in a more elegant manner (e.g. the new library).
     334            0 :         let params = connection.parameters;
     335            0 : 
     336            0 :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     337            0 :         // Yet another reason to rework the connection establishing code.
     338            0 :         let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
     339            0 : 
     340            0 :         let connection = PostgresConnection {
     341            0 :             stream,
     342            0 :             params,
     343            0 :             cancel_closure,
     344            0 :             aux,
     345            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
     346            0 :         };
     347            0 : 
     348            0 :         Ok(connection)
     349            0 :     }
     350              : }
     351              : 
     352              : /// Retrieve `options` from a startup message, dropping all proxy-specific flags.
     353           12 : fn filtered_options(options: &str) -> Option<String> {
     354           12 :     #[allow(unstable_name_collisions)]
     355           12 :     let options: String = StartupMessageParams::parse_options_raw(options)
     356           26 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     357           12 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     358           12 :         .collect();
     359           12 : 
     360           12 :     // Don't even bother with empty options.
     361           12 :     if options.is_empty() {
     362            6 :         return None;
     363            6 :     }
     364            6 : 
     365            6 :     Some(options)
     366           12 : }
     367              : 
     368            0 : fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
     369            0 :     let der_certs = rustls_native_certs::load_native_certs()?;
     370            0 :     let mut store = rustls::RootCertStore::empty();
     371            0 :     store.add_parsable_certificates(der_certs);
     372            0 :     Ok(Arc::new(store))
     373            0 : }
     374              : static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
     375              : 
     376              : #[derive(Debug)]
     377              : struct AcceptEverythingVerifier;
     378              : impl ServerCertVerifier for AcceptEverythingVerifier {
     379            0 :     fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
     380            0 :         use rustls::SignatureScheme::*;
     381            0 :         // The schemes for which `SignatureScheme::supported_in_tls13` returns true.
     382            0 :         vec![
     383            0 :             ECDSA_NISTP521_SHA512,
     384            0 :             ECDSA_NISTP384_SHA384,
     385            0 :             ECDSA_NISTP256_SHA256,
     386            0 :             RSA_PSS_SHA512,
     387            0 :             RSA_PSS_SHA384,
     388            0 :             RSA_PSS_SHA256,
     389            0 :             ED25519,
     390            0 :         ]
     391            0 :     }
     392            0 :     fn verify_server_cert(
     393            0 :         &self,
     394            0 :         _end_entity: &rustls::pki_types::CertificateDer<'_>,
     395            0 :         _intermediates: &[rustls::pki_types::CertificateDer<'_>],
     396            0 :         _server_name: &rustls::pki_types::ServerName<'_>,
     397            0 :         _ocsp_response: &[u8],
     398            0 :         _now: rustls::pki_types::UnixTime,
     399            0 :     ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
     400            0 :         Ok(rustls::client::danger::ServerCertVerified::assertion())
     401            0 :     }
     402            0 :     fn verify_tls12_signature(
     403            0 :         &self,
     404            0 :         _message: &[u8],
     405            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     406            0 :         _dss: &rustls::DigitallySignedStruct,
     407            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     408            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     409            0 :     }
     410            0 :     fn verify_tls13_signature(
     411            0 :         &self,
     412            0 :         _message: &[u8],
     413            0 :         _cert: &rustls::pki_types::CertificateDer<'_>,
     414            0 :         _dss: &rustls::DigitallySignedStruct,
     415            0 :     ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
     416            0 :         Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
     417            0 :     }
     418              : }
     419              : 
     420              : #[cfg(test)]
     421              : mod tests {
     422              :     use super::*;
     423              : 
     424              :     #[test]
     425            2 :     fn test_filtered_options() {
     426            2 :         // Empty options is unlikely to be useful anyway.
     427            2 :         assert_eq!(filtered_options(""), None);
     428              : 
     429              :         // It's likely that clients will only use options to specify endpoint/project.
     430            2 :         let params = "project=foo";
     431            2 :         assert_eq!(filtered_options(params), None);
     432              : 
     433              :         // Same, because unescaped whitespaces are no-op.
     434            2 :         let params = " project=foo ";
     435            2 :         assert_eq!(filtered_options(params), None);
     436              : 
     437            2 :         let params = r"\  project=foo \ ";
     438            2 :         assert_eq!(filtered_options(params).as_deref(), Some(r"\  \ "));
     439              : 
     440            2 :         let params = "project = foo";
     441            2 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     442              : 
     443            2 :         let params = "project = foo neon_endpoint_type:read_write   neon_lsn:0/2";
     444            2 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     445            2 :     }
     446              : }
        

Generated by: LCOV version 2.1-beta