LCOV - code coverage report
Current view: top level - proxy/src/compute - mod.rs (source / functions) Coverage Total Hit
Test: 553e39c2773e5840c720c90d86e56f89a4330d43.info Lines: 13.5 % 215 29
Test Date: 2025-06-13 20:01:21 Functions: 15.0 % 20 3

            Line data    Source code
       1              : mod tls;
       2              : 
       3              : use std::fmt::Debug;
       4              : use std::io;
       5              : use std::net::{IpAddr, SocketAddr};
       6              : 
       7              : use futures::{FutureExt, TryFutureExt};
       8              : use itertools::Itertools;
       9              : use postgres_client::config::{AuthKeys, SslMode};
      10              : use postgres_client::maybe_tls_stream::MaybeTlsStream;
      11              : use postgres_client::tls::MakeTlsConnect;
      12              : use postgres_client::{CancelToken, NoTls, RawConnection};
      13              : use postgres_protocol::message::backend::NoticeResponseBody;
      14              : use thiserror::Error;
      15              : use tokio::net::{TcpStream, lookup_host};
      16              : use tracing::{debug, error, info, warn};
      17              : 
      18              : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
      19              : use crate::auth::parse_endpoint_param;
      20              : use crate::cancellation::CancelClosure;
      21              : use crate::compute::tls::TlsError;
      22              : use crate::config::ComputeConfig;
      23              : use crate::context::RequestContext;
      24              : use crate::control_plane::client::ApiLockError;
      25              : use crate::control_plane::errors::WakeComputeError;
      26              : use crate::control_plane::messages::MetricsAuxInfo;
      27              : use crate::error::{ReportableError, UserFacingError};
      28              : use crate::metrics::{Metrics, NumDbConnectionsGuard};
      29              : use crate::pqproto::StartupMessageParams;
      30              : use crate::proxy::neon_option;
      31              : use crate::types::Host;
      32              : 
      33              : pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
      34              : 
      35              : #[derive(Debug, Error)]
      36              : pub(crate) enum ConnectionError {
      37              :     /// This error doesn't seem to reveal any secrets; for instance,
      38              :     /// `postgres_client::error::Kind` doesn't contain ip addresses and such.
      39              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      40              :     Postgres(#[from] postgres_client::Error),
      41              : 
      42              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      43              :     TlsError(#[from] TlsError),
      44              : 
      45              :     #[error("{COULD_NOT_CONNECT}: {0}")]
      46              :     WakeComputeError(#[from] WakeComputeError),
      47              : 
      48              :     #[error("error acquiring resource permit: {0}")]
      49              :     TooManyConnectionAttempts(#[from] ApiLockError),
      50              : }
      51              : 
      52              : impl UserFacingError for ConnectionError {
      53            0 :     fn to_string_client(&self) -> String {
      54            0 :         match self {
      55              :             // This helps us drop irrelevant library-specific prefixes.
      56              :             // TODO: propagate severity level and other parameters.
      57            0 :             ConnectionError::Postgres(err) => match err.as_db_error() {
      58            0 :                 Some(err) => {
      59            0 :                     let msg = err.message();
      60            0 : 
      61            0 :                     if msg.starts_with("unsupported startup parameter: ")
      62            0 :                         || msg.starts_with("unsupported startup parameter in options: ")
      63              :                     {
      64            0 :                         format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
      65              :                     } else {
      66            0 :                         msg.to_owned()
      67              :                     }
      68              :                 }
      69            0 :                 None => err.to_string(),
      70              :             },
      71            0 :             ConnectionError::WakeComputeError(err) => err.to_string_client(),
      72              :             ConnectionError::TooManyConnectionAttempts(_) => {
      73            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
      74              :             }
      75            0 :             ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
      76              :         }
      77            0 :     }
      78              : }
      79              : 
      80              : impl ReportableError for ConnectionError {
      81            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      82            0 :         match self {
      83            0 :             ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
      84            0 :                 crate::error::ErrorKind::Postgres
      85              :             }
      86            0 :             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
      87            0 :             ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
      88            0 :             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
      89            0 :             ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
      90              :         }
      91            0 :     }
      92              : }
      93              : 
      94              : /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
      95              : pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
      96              : 
      97              : #[derive(Clone)]
      98              : pub enum Auth {
      99              :     /// Only used during console-redirect.
     100              :     Password(Vec<u8>),
     101              :     /// Used by sql-over-http, ws, tcp.
     102              :     Scram(Box<ScramKeys>),
     103              : }
     104              : 
     105              : /// A config for authenticating to the compute node.
     106              : pub(crate) struct AuthInfo {
     107              :     /// None for local-proxy, as we use trust-based localhost auth.
     108              :     /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
     109              :     /// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
     110              :     auth: Option<Auth>,
     111              :     server_params: StartupMessageParams,
     112              : 
     113              :     /// Console redirect sets user and database, we shouldn't re-use those from the params.
     114              :     skip_db_user: bool,
     115              : }
     116              : 
     117              : /// Contains only the data needed to establish a secure connection to compute.
     118              : #[derive(Clone)]
     119              : pub struct ConnectInfo {
     120              :     pub host_addr: Option<IpAddr>,
     121              :     pub host: Host,
     122              :     pub port: u16,
     123              :     pub ssl_mode: SslMode,
     124              : }
     125              : 
     126              : /// Creation and initialization routines.
     127              : impl AuthInfo {
     128            0 :     pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
     129            0 :         let mut server_params = StartupMessageParams::default();
     130            0 :         server_params.insert("database", db);
     131            0 :         server_params.insert("user", user);
     132            0 :         Self {
     133            0 :             auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
     134            0 :             server_params,
     135            0 :             skip_db_user: true,
     136            0 :         }
     137            0 :     }
     138              : 
     139            0 :     pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
     140            0 :         Self {
     141            0 :             auth: match keys {
     142            0 :                 ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
     143            0 :                     Some(Auth::Scram(Box::new(auth_keys)))
     144              :                 }
     145            0 :                 ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
     146              :             },
     147            0 :             server_params: StartupMessageParams::default(),
     148            0 :             skip_db_user: false,
     149            0 :         }
     150            0 :     }
     151              : }
     152              : 
     153              : impl ConnectInfo {
     154            0 :     pub fn to_postgres_client_config(&self) -> postgres_client::Config {
     155            0 :         let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
     156            0 :         config.ssl_mode(self.ssl_mode);
     157            0 :         if let Some(host_addr) = self.host_addr {
     158            0 :             config.set_host_addr(host_addr);
     159            0 :         }
     160            0 :         config
     161            0 :     }
     162              : }
     163              : 
     164              : impl AuthInfo {
     165            0 :     fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
     166            0 :         match &self.auth {
     167            0 :             Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
     168            0 :             Some(Auth::Password(pw)) => config.password(pw),
     169            0 :             None => &mut config,
     170              :         };
     171            0 :         for (k, v) in self.server_params.iter() {
     172            0 :             config.set_param(k, v);
     173            0 :         }
     174            0 :         config
     175            0 :     }
     176              : 
     177              :     /// Apply startup message params to the connection config.
     178            0 :     pub(crate) fn set_startup_params(
     179            0 :         &mut self,
     180            0 :         params: &StartupMessageParams,
     181            0 :         arbitrary_params: bool,
     182            0 :     ) {
     183            0 :         if !arbitrary_params {
     184            0 :             self.server_params.insert("client_encoding", "UTF8");
     185            0 :         }
     186            0 :         for (k, v) in params.iter() {
     187            0 :             match k {
     188              :                 // Only set `user` if it's not present in the config.
     189              :                 // Console redirect auth flow takes username from the console's response.
     190            0 :                 "user" | "database" if self.skip_db_user => {}
     191            0 :                 "options" => {
     192            0 :                     if let Some(options) = filtered_options(v) {
     193            0 :                         self.server_params.insert(k, &options);
     194            0 :                     }
     195              :                 }
     196            0 :                 "user" | "database" | "application_name" | "replication" => {
     197            0 :                     self.server_params.insert(k, v);
     198            0 :                 }
     199              : 
     200              :                 // if we allow arbitrary params, then we forward them through.
     201              :                 // this is a flag for a period of backwards compatibility
     202            0 :                 k if arbitrary_params => {
     203            0 :                     self.server_params.insert(k, v);
     204            0 :                 }
     205            0 :                 _ => {}
     206              :             }
     207              :         }
     208            0 :     }
     209              : }
     210              : 
     211              : impl ConnectInfo {
     212              :     /// Establish a raw TCP+TLS connection to the compute node.
     213            0 :     async fn connect_raw(
     214            0 :         &self,
     215            0 :         config: &ComputeConfig,
     216            0 :     ) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
     217            0 :         let timeout = config.timeout;
     218            0 : 
     219            0 :         // wrap TcpStream::connect with timeout
     220            0 :         let connect_with_timeout = |addrs| {
     221            0 :             tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
     222            0 :                 Ok(tcpstream_connect_res) => tcpstream_connect_res,
     223            0 :                 Err(_) => Err(io::Error::new(
     224            0 :                     io::ErrorKind::TimedOut,
     225            0 :                     format!("exceeded connection timeout {timeout:?}"),
     226            0 :                 )),
     227            0 :             })
     228            0 :         };
     229              : 
     230            0 :         let connect_once = |addrs| {
     231            0 :             debug!("trying to connect to compute node at {addrs:?}");
     232            0 :             connect_with_timeout(addrs).and_then(|stream| async {
     233            0 :                 let socket_addr = stream.peer_addr()?;
     234            0 :                 let socket = socket2::SockRef::from(&stream);
     235            0 :                 // Disable Nagle's algorithm to not introduce latency between
     236            0 :                 // client and compute.
     237            0 :                 socket.set_nodelay(true)?;
     238              :                 // This prevents load balancer from severing the connection.
     239            0 :                 socket.set_keepalive(true)?;
     240            0 :                 Ok((socket_addr, stream))
     241            0 :             })
     242            0 :         };
     243              : 
     244              :         // We can't reuse connection establishing logic from `postgres_client` here,
     245              :         // because it has no means for extracting the underlying socket which we
     246              :         // require for our business.
     247            0 :         let port = self.port;
     248            0 :         let host = &*self.host;
     249              : 
     250            0 :         let addrs = match self.host_addr {
     251            0 :             Some(addr) => vec![SocketAddr::new(addr, port)],
     252            0 :             None => lookup_host((host, port)).await?.collect(),
     253              :         };
     254              : 
     255            0 :         match connect_once(&*addrs).await {
     256            0 :             Ok((sockaddr, stream)) => Ok((
     257            0 :                 sockaddr,
     258            0 :                 tls::connect_tls(stream, self.ssl_mode, config, host).await?,
     259              :             )),
     260            0 :             Err(err) => {
     261            0 :                 warn!("couldn't connect to compute node at {host}:{port}: {err}");
     262            0 :                 Err(TlsError::Connection(err))
     263              :             }
     264              :         }
     265            0 :     }
     266              : }
     267              : 
     268              : type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
     269              : 
     270              : pub(crate) struct PostgresConnection {
     271              :     /// Socket connected to a compute node.
     272              :     pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
     273              :     /// PostgreSQL connection parameters.
     274              :     pub(crate) params: std::collections::HashMap<String, String>,
     275              :     /// Query cancellation token.
     276              :     pub(crate) cancel_closure: CancelClosure,
     277              :     /// Labels for proxy's metrics.
     278              :     pub(crate) aux: MetricsAuxInfo,
     279              :     /// Notices received from compute after authenticating
     280              :     pub(crate) delayed_notice: Vec<NoticeResponseBody>,
     281              : 
     282              :     _guage: NumDbConnectionsGuard<'static>,
     283              : }
     284              : 
     285              : impl ConnectInfo {
     286              :     /// Connect to a corresponding compute node.
     287            0 :     pub(crate) async fn connect(
     288            0 :         &self,
     289            0 :         ctx: &RequestContext,
     290            0 :         aux: MetricsAuxInfo,
     291            0 :         auth: &AuthInfo,
     292            0 :         config: &ComputeConfig,
     293            0 :         user_info: ComputeUserInfo,
     294            0 :     ) -> Result<PostgresConnection, ConnectionError> {
     295            0 :         let mut tmp_config = auth.enrich(self.to_postgres_client_config());
     296            0 :         // we setup SSL early in `ConnectInfo::connect_raw`.
     297            0 :         tmp_config.ssl_mode(SslMode::Disable);
     298            0 : 
     299            0 :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     300            0 :         let (socket_addr, stream) = self.connect_raw(config).await?;
     301            0 :         let connection = tmp_config.connect_raw(stream, NoTls).await?;
     302            0 :         drop(pause);
     303            0 : 
     304            0 :         let RawConnection {
     305            0 :             stream,
     306            0 :             parameters,
     307            0 :             delayed_notice,
     308            0 :             process_id,
     309            0 :             secret_key,
     310            0 :         } = connection;
     311            0 : 
     312            0 :         tracing::Span::current().record("pid", tracing::field::display(process_id));
     313            0 :         tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
     314            0 :         let MaybeTlsStream::Raw(stream) = stream.into_inner();
     315            0 : 
     316            0 :         // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
     317            0 :         info!(
     318            0 :             cold_start_info = ctx.cold_start_info().as_str(),
     319            0 :             "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
     320            0 :             self.host,
     321            0 :             self.ssl_mode,
     322            0 :             ctx.get_proxy_latency(),
     323            0 :             ctx.get_testodrome_id().unwrap_or_default(),
     324              :         );
     325              : 
     326              :         // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
     327              :         // Yet another reason to rework the connection establishing code.
     328            0 :         let cancel_closure = CancelClosure::new(
     329            0 :             socket_addr,
     330            0 :             CancelToken {
     331            0 :                 socket_config: None,
     332            0 :                 ssl_mode: self.ssl_mode,
     333            0 :                 process_id,
     334            0 :                 secret_key,
     335            0 :             },
     336            0 :             self.host.to_string(),
     337            0 :             user_info,
     338            0 :         );
     339            0 : 
     340            0 :         let connection = PostgresConnection {
     341            0 :             stream,
     342            0 :             params: parameters,
     343            0 :             delayed_notice,
     344            0 :             cancel_closure,
     345            0 :             aux,
     346            0 :             _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
     347            0 :         };
     348            0 : 
     349            0 :         Ok(connection)
     350            0 :     }
     351              : }
     352              : 
     353              : /// Retrieve `options` from a startup message, dropping all proxy-secific flags.
     354            6 : fn filtered_options(options: &str) -> Option<String> {
     355            6 :     #[allow(unstable_name_collisions)]
     356            6 :     let options: String = StartupMessageParams::parse_options_raw(options)
     357           14 :         .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
     358            6 :         .intersperse(" ") // TODO: use impl from std once it's stabilized
     359            6 :         .collect();
     360            6 : 
     361            6 :     // Don't even bother with empty options.
     362            6 :     if options.is_empty() {
     363            3 :         return None;
     364            3 :     }
     365            3 : 
     366            3 :     Some(options)
     367            6 : }
     368              : 
     369              : #[cfg(test)]
     370              : mod tests {
     371              :     use super::*;
     372              : 
     373              :     #[test]
     374            1 :     fn test_filtered_options() {
     375            1 :         // Empty options is unlikely to be useful anyway.
     376            1 :         let params = "";
     377            1 :         assert_eq!(filtered_options(params), None);
     378              : 
     379              :         // It's likely that clients will only use options to specify endpoint/project.
     380            1 :         let params = "project=foo";
     381            1 :         assert_eq!(filtered_options(params), None);
     382              : 
     383              :         // Same, because unescaped whitespaces are no-op.
     384            1 :         let params = " project=foo ";
     385            1 :         assert_eq!(filtered_options(params).as_deref(), None);
     386              : 
     387            1 :         let params = r"\  project=foo \ ";
     388            1 :         assert_eq!(filtered_options(params).as_deref(), Some(r"\  \ "));
     389              : 
     390            1 :         let params = "project = foo";
     391            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     392              : 
     393            1 :         let params = "project = foo neon_endpoint_type:read_write   neon_lsn:0/2 neon_proxy_params_compat:true";
     394            1 :         assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
     395            1 :     }
     396              : }
        

Generated by: LCOV version 2.1-beta