LCOV - code coverage report
Current view: top level - proxy/src/serverless - backend.rs (source / functions) Coverage Total Hit
Test: f26987deef05b637be3b9ae5d95c30faa25ab621.info Lines: 0.0 % 145 0
Test Date: 2025-07-31 11:15:47 Functions: 0.0 % 13 0

            Line data    Source code
       1              : use std::sync::Arc;
       2              : use std::time::Duration;
       3              : 
       4              : use ed25519_dalek::SigningKey;
       5              : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
       6              : use jose_jwk::jose_b64;
       7              : use postgres_client::error::SqlState;
       8              : use postgres_client::maybe_tls_stream::MaybeTlsStream;
       9              : use rand_core::OsRng;
      10              : use tracing::field::display;
      11              : use tracing::{debug, info};
      12              : 
      13              : use super::AsyncRW;
      14              : use super::conn_pool::poll_client;
      15              : use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
      16              : use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
      17              : use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
      18              : use crate::auth::backend::local::StaticAuthRules;
      19              : use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
      20              : use crate::auth::{self, AuthError};
      21              : use crate::compute;
      22              : use crate::compute_ctl::{
      23              :     ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
      24              : };
      25              : use crate::config::ProxyConfig;
      26              : use crate::context::RequestContext;
      27              : use crate::control_plane::client::ApiLockError;
      28              : use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
      29              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      30              : use crate::intern::{EndpointIdInt, RoleNameInt};
      31              : use crate::pqproto::StartupMessageParams;
      32              : use crate::proxy::{connect_auth, connect_compute};
      33              : use crate::rate_limiter::EndpointRateLimiter;
      34              : use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
      35              : 
      36              : pub(crate) struct PoolingBackend {
      37              :     pub(crate) http_conn_pool:
      38              :         Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
      39              :     pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
      40              :     pub(crate) pool:
      41              :         Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
      42              : 
      43              :     pub(crate) config: &'static ProxyConfig,
      44              :     pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
      45              :     pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      46              : }
      47              : 
      48              : impl PoolingBackend {
      49            0 :     pub(crate) async fn authenticate_with_password(
      50            0 :         &self,
      51            0 :         ctx: &RequestContext,
      52            0 :         user_info: &ComputeUserInfo,
      53            0 :         password: &[u8],
      54            0 :     ) -> Result<ComputeCredentials, AuthError> {
      55            0 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
      56              : 
      57            0 :         let user_info = user_info.clone();
      58            0 :         let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
      59            0 :         let access_control = backend.get_endpoint_access_control(ctx).await?;
      60            0 :         access_control.check(
      61            0 :             ctx,
      62            0 :             self.config.authentication_config.ip_allowlist_check_enabled,
      63            0 :             self.config.authentication_config.is_vpc_acccess_proxy,
      64            0 :         )?;
      65              : 
      66            0 :         access_control.connection_attempt_rate_limit(
      67            0 :             ctx,
      68            0 :             &user_info.endpoint,
      69            0 :             &self.endpoint_rate_limiter,
      70            0 :         )?;
      71              : 
      72            0 :         let role_access = backend.get_role_secret(ctx).await?;
      73            0 :         let Some(secret) = role_access.secret else {
      74              :             // If we don't have an authentication secret, for the http flow we can just return an error.
      75            0 :             info!("authentication info not found");
      76            0 :             return Err(AuthError::password_failed(&*user_info.user));
      77              :         };
      78              : 
      79            0 :         let ep = EndpointIdInt::from(&user_info.endpoint);
      80            0 :         let role = RoleNameInt::from(&user_info.user);
      81            0 :         let auth_outcome = crate::auth::validate_password_and_exchange(
      82            0 :             &self.config.authentication_config.scram_thread_pool,
      83            0 :             ep,
      84            0 :             role,
      85            0 :             password,
      86            0 :             secret,
      87            0 :         )
      88            0 :         .await?;
      89            0 :         let res = match auth_outcome {
      90            0 :             crate::sasl::Outcome::Success(key) => {
      91            0 :                 info!("user successfully authenticated");
      92            0 :                 Ok(key)
      93              :             }
      94            0 :             crate::sasl::Outcome::Failure(reason) => {
      95            0 :                 info!("auth backend failed with an error: {reason}");
      96            0 :                 Err(AuthError::password_failed(&*user_info.user))
      97              :             }
      98              :         };
      99            0 :         res.map(|key| ComputeCredentials {
     100            0 :             info: user_info,
     101            0 :             keys: key,
     102            0 :         })
     103            0 :     }
     104              : 
     105            0 :     pub(crate) async fn authenticate_with_jwt(
     106            0 :         &self,
     107            0 :         ctx: &RequestContext,
     108            0 :         user_info: &ComputeUserInfo,
     109            0 :         jwt: String,
     110            0 :     ) -> Result<ComputeCredentials, AuthError> {
     111            0 :         ctx.set_auth_method(crate::context::AuthMethod::Jwt);
     112              : 
     113            0 :         match &self.auth_backend {
     114            0 :             crate::auth::Backend::ControlPlane(console, ()) => {
     115            0 :                 let keys = self
     116            0 :                     .config
     117            0 :                     .authentication_config
     118            0 :                     .jwks_cache
     119            0 :                     .check_jwt(
     120            0 :                         ctx,
     121            0 :                         user_info.endpoint.clone(),
     122            0 :                         &user_info.user,
     123            0 :                         &**console,
     124            0 :                         &jwt,
     125            0 :                     )
     126            0 :                     .await?;
     127              : 
     128            0 :                 Ok(ComputeCredentials {
     129            0 :                     info: user_info.clone(),
     130            0 :                     keys,
     131            0 :                 })
     132              :             }
     133              :             crate::auth::Backend::Local(_) => {
     134            0 :                 let keys = self
     135            0 :                     .config
     136            0 :                     .authentication_config
     137            0 :                     .jwks_cache
     138            0 :                     .check_jwt(
     139            0 :                         ctx,
     140            0 :                         user_info.endpoint.clone(),
     141            0 :                         &user_info.user,
     142            0 :                         &StaticAuthRules,
     143            0 :                         &jwt,
     144            0 :                     )
     145            0 :                     .await?;
     146              : 
     147            0 :                 Ok(ComputeCredentials {
     148            0 :                     info: user_info.clone(),
     149            0 :                     keys,
     150            0 :                 })
     151              :             }
     152              :         }
     153            0 :     }
     154              : 
     155              :     // Wake up the destination if needed. Code here is a bit involved because
     156              :     // we reuse the code from the usual proxy and we need to prepare few structures
     157              :     // that this code expects.
     158              :     #[tracing::instrument(skip_all, fields(
     159              :         pid = tracing::field::Empty,
     160              :         compute_id = tracing::field::Empty,
     161              :         conn_id = tracing::field::Empty,
     162              :     ))]
     163              :     pub(crate) async fn connect_to_compute(
     164              :         &self,
     165              :         ctx: &RequestContext,
     166              :         conn_info: ConnInfo,
     167              :         keys: ComputeCredentials,
     168              :         force_new: bool,
     169              :     ) -> Result<Client<postgres_client::Client>, HttpConnError> {
     170              :         let maybe_client = if force_new {
     171              :             debug!("pool: pool is disabled");
     172              :             None
     173              :         } else {
     174              :             debug!("pool: looking for an existing connection");
     175              :             self.pool.get(ctx, &conn_info)?
     176              :         };
     177              : 
     178              :         if let Some(client) = maybe_client {
     179              :             return Ok(client);
     180              :         }
     181              :         let conn_id = uuid::Uuid::new_v4();
     182              :         tracing::Span::current().record("conn_id", display(conn_id));
     183              :         info!(%conn_id, "pool: opening a new connection '{conn_info}'");
     184              :         let backend = self.auth_backend.as_ref().map(|()| keys.info);
     185              : 
     186              :         let mut params = StartupMessageParams::default();
     187              :         params.insert("database", &conn_info.dbname);
     188              :         params.insert("user", &conn_info.user_info.user);
     189              : 
     190              :         let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys);
     191              :         auth_info.set_startup_params(&params, true);
     192              : 
     193              :         let node = connect_auth::connect_to_compute_and_auth(
     194              :             ctx,
     195              :             self.config,
     196              :             &backend,
     197              :             auth_info,
     198              :             connect_compute::TlsNegotiation::Postgres,
     199              :         )
     200              :         .await?;
     201              : 
     202              :         let (client, connection) = postgres_client::connect::managed(
     203              :             node.stream,
     204              :             Some(node.socket_addr.ip()),
     205              :             postgres_client::config::Host::Tcp(node.hostname.to_string()),
     206              :             node.socket_addr.port(),
     207              :             node.ssl_mode,
     208              :             Some(self.config.connect_to_compute.timeout),
     209              :         )
     210              :         .await?;
     211              : 
     212              :         Ok(poll_client(
     213              :             self.pool.clone(),
     214              :             ctx,
     215              :             conn_info,
     216              :             client,
     217              :             connection,
     218              :             conn_id,
     219              :             node.aux,
     220              :         ))
     221              :     }
     222              : 
     223              :     // Wake up the destination if needed
     224              :     #[tracing::instrument(skip_all, fields(
     225              :         compute_id = tracing::field::Empty,
     226              :         conn_id = tracing::field::Empty,
     227              :     ))]
     228              :     pub(crate) async fn connect_to_local_proxy(
     229              :         &self,
     230              :         ctx: &RequestContext,
     231              :         conn_info: ConnInfo,
     232              :     ) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
     233              :         debug!("pool: looking for an existing connection");
     234              :         if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
     235              :             return Ok(client);
     236              :         }
     237              : 
     238              :         let conn_id = uuid::Uuid::new_v4();
     239              :         tracing::Span::current().record("conn_id", display(conn_id));
     240              :         debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
     241              :         let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
     242            0 :             user: conn_info.user_info.user.clone(),
     243            0 :             endpoint: EndpointId::from(format!(
     244            0 :                 "{}{LOCAL_PROXY_SUFFIX}",
     245            0 :                 conn_info.user_info.endpoint.normalize()
     246              :             )),
     247            0 :             options: conn_info.user_info.options.clone(),
     248            0 :         });
     249              : 
     250              :         let node = connect_compute::connect_to_compute(
     251              :             ctx,
     252              :             self.config,
     253              :             &backend,
     254              :             connect_compute::TlsNegotiation::Direct,
     255              :         )
     256              :         .await?;
     257              : 
     258              :         let stream = match node.stream.into_framed().into_inner() {
     259              :             MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW,
     260              :             MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW,
     261              :         };
     262              : 
     263              :         let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
     264              :             .timer(TokioTimer::new())
     265              :             .keep_alive_interval(Duration::from_secs(20))
     266              :             .keep_alive_while_idle(true)
     267              :             .keep_alive_timeout(Duration::from_secs(5))
     268              :             .handshake(TokioIo::new(stream))
     269              :             .await
     270              :             .map_err(LocalProxyConnError::H2)?;
     271              : 
     272              :         Ok(poll_http2_client(
     273              :             self.http_conn_pool.clone(),
     274              :             ctx,
     275              :             &conn_info,
     276              :             client,
     277              :             connection,
     278              :             conn_id,
     279              :             node.aux.clone(),
     280              :         ))
     281              :     }
     282              : 
     283              :     /// Connect to postgres over localhost.
     284              :     ///
     285              :     /// We expect postgres to be started here, so we won't do any retries.
     286              :     ///
     287              :     /// # Panics
     288              :     ///
     289              :     /// Panics if called with a non-local_proxy backend.
     290              :     #[tracing::instrument(skip_all, fields(
     291              :         pid = tracing::field::Empty,
     292              :         conn_id = tracing::field::Empty,
     293              :     ))]
     294              :     pub(crate) async fn connect_to_local_postgres(
     295              :         &self,
     296              :         ctx: &RequestContext,
     297              :         conn_info: ConnInfo,
     298              :         disable_pg_session_jwt: bool,
     299              :     ) -> Result<Client<postgres_client::Client>, HttpConnError> {
     300              :         if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
     301              :             return Ok(client);
     302              :         }
     303              : 
     304              :         let local_backend = match &self.auth_backend {
     305              :             auth::Backend::ControlPlane(_, ()) => {
     306              :                 unreachable!("only local_proxy can connect to local postgres")
     307              :             }
     308              :             auth::Backend::Local(local) => local,
     309              :         };
     310              : 
     311              :         if !self.local_pool.initialized(&conn_info) {
     312              :             // only install and grant usage one at a time.
     313              :             let _permit = local_backend
     314              :                 .initialize
     315              :                 .acquire()
     316              :                 .await
     317              :                 .expect("semaphore should never be closed");
     318              : 
     319              :             // check again for race
     320              :             if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
     321              :                 local_backend
     322              :                     .compute_ctl
     323              :                     .install_extension(&ExtensionInstallRequest {
     324              :                         extension: EXT_NAME,
     325              :                         database: conn_info.dbname.clone(),
     326              :                         version: EXT_VERSION,
     327              :                     })
     328              :                     .await?;
     329              : 
     330              :                 local_backend
     331              :                     .compute_ctl
     332              :                     .grant_role(&SetRoleGrantsRequest {
     333              :                         schema: EXT_SCHEMA,
     334              :                         privileges: vec![Privilege::Usage],
     335              :                         database: conn_info.dbname.clone(),
     336              :                         role: conn_info.user_info.user.clone(),
     337              :                     })
     338              :                     .await?;
     339              : 
     340              :                 self.local_pool.set_initialized(&conn_info);
     341              :             }
     342              :         }
     343              : 
     344              :         let conn_id = uuid::Uuid::new_v4();
     345              :         tracing::Span::current().record("conn_id", display(conn_id));
     346              :         info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
     347              : 
     348              :         let (key, jwk) = create_random_jwk();
     349              : 
     350              :         let mut config = local_backend
     351              :             .node_info
     352              :             .conn_info
     353              :             .to_postgres_client_config();
     354              :         config
     355              :             .user(&conn_info.user_info.user)
     356              :             .dbname(&conn_info.dbname);
     357              :         if !disable_pg_session_jwt {
     358              :             config.set_param(
     359              :                 "options",
     360              :                 &format!(
     361              :                     "-c pg_session_jwt.jwk={}",
     362              :                     serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
     363              :                 ),
     364              :             );
     365              :         }
     366              : 
     367              :         let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
     368              :         let (client, connection) = config.connect(&postgres_client::NoTls).await?;
     369              :         drop(pause);
     370              : 
     371              :         let pid = client.get_process_id();
     372              :         tracing::Span::current().record("pid", pid);
     373              : 
     374              :         let mut handle = local_conn_pool::poll_client(
     375              :             self.local_pool.clone(),
     376              :             ctx,
     377              :             conn_info,
     378              :             client,
     379              :             connection,
     380              :             key,
     381              :             conn_id,
     382              :             local_backend.node_info.aux.clone(),
     383              :         );
     384              : 
     385              :         {
     386              :             let (client, mut discard) = handle.inner();
     387              :             debug!("setting up backend session state");
     388              : 
     389              :             // initiates the auth session
     390              :             if !disable_pg_session_jwt
     391              :                 && let Err(e) = client.batch_execute("select auth.init();").await
     392              :             {
     393              :                 discard.discard();
     394              :                 return Err(e.into());
     395              :             }
     396              : 
     397              :             info!("backend session state initialized");
     398              :         }
     399              : 
     400              :         Ok(handle)
     401              :     }
     402              : }
     403              : 
     404            0 : fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
     405            0 :     let key = SigningKey::generate(&mut OsRng);
     406              : 
     407            0 :     let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
     408            0 :         crv: jose_jwk::OkpCurves::Ed25519,
     409            0 :         x: jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
     410            0 :         d: None,
     411            0 :     });
     412              : 
     413            0 :     (key, jwk)
     414            0 : }
     415              : 
     416              : #[derive(Debug, thiserror::Error)]
     417              : pub(crate) enum HttpConnError {
     418              :     #[error("pooled connection closed at inconsistent state")]
     419              :     ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
     420              :     #[error("could not connect to compute")]
     421              :     ConnectError(#[from] compute::ConnectionError),
     422              :     #[error("could not connect to postgres in compute")]
     423              :     PostgresConnectionError(#[from] postgres_client::Error),
     424              :     #[error("could not connect to local-proxy in compute")]
     425              :     LocalProxyConnectionError(#[from] LocalProxyConnError),
     426              :     #[error("could not parse JWT payload")]
     427              :     JwtPayloadError(serde_json::Error),
     428              : 
     429              :     #[error("could not install extension: {0}")]
     430              :     ComputeCtl(#[from] ComputeCtlError),
     431              :     #[error("could not get auth info")]
     432              :     GetAuthInfo(#[from] GetAuthInfoError),
     433              :     #[error("user not authenticated")]
     434              :     AuthError(#[from] AuthError),
     435              :     #[error("wake_compute returned error")]
     436              :     WakeCompute(#[from] WakeComputeError),
     437              :     #[error("error acquiring resource permit: {0}")]
     438              :     TooManyConnectionAttempts(#[from] ApiLockError),
     439              : }
     440              : 
     441              : impl From<connect_auth::AuthError> for HttpConnError {
     442            0 :     fn from(value: connect_auth::AuthError) -> Self {
     443            0 :         match value {
     444            0 :             connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => {
     445            0 :                 Self::PostgresConnectionError(error)
     446              :             }
     447            0 :             connect_auth::AuthError::Connect(error) => Self::ConnectError(error),
     448              :         }
     449            0 :     }
     450              : }
     451              : 
     452              : #[derive(Debug, thiserror::Error)]
     453              : pub(crate) enum LocalProxyConnError {
     454              :     #[error("could not establish h2 connection")]
     455              :     H2(#[from] hyper::Error),
     456              : }
     457              : 
     458              : impl ReportableError for HttpConnError {
     459            0 :     fn get_error_kind(&self) -> ErrorKind {
     460            0 :         match self {
     461            0 :             HttpConnError::ConnectError(e) => e.get_error_kind(),
     462            0 :             HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
     463            0 :             HttpConnError::PostgresConnectionError(p) => match p.as_db_error() {
     464              :                 // user provided a wrong database name
     465            0 :                 Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
     466              :                 // postgres rejected the connection
     467            0 :                 Some(_) => ErrorKind::Postgres,
     468              :                 // couldn't even reach postgres
     469            0 :                 None => ErrorKind::Compute,
     470              :             },
     471            0 :             HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
     472            0 :             HttpConnError::ComputeCtl(_) => ErrorKind::Service,
     473            0 :             HttpConnError::JwtPayloadError(_) => ErrorKind::User,
     474            0 :             HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
     475            0 :             HttpConnError::AuthError(a) => a.get_error_kind(),
     476            0 :             HttpConnError::WakeCompute(w) => w.get_error_kind(),
     477            0 :             HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
     478              :         }
     479            0 :     }
     480              : }
     481              : 
     482              : impl UserFacingError for HttpConnError {
     483            0 :     fn to_string_client(&self) -> String {
     484            0 :         match self {
     485            0 :             HttpConnError::ConnectError(p) => p.to_string_client(),
     486            0 :             HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
     487            0 :             HttpConnError::PostgresConnectionError(p) => p.to_string(),
     488            0 :             HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
     489            0 :             HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
     490            0 :             HttpConnError::JwtPayloadError(p) => p.to_string(),
     491            0 :             HttpConnError::GetAuthInfo(c) => c.to_string_client(),
     492            0 :             HttpConnError::AuthError(c) => c.to_string_client(),
     493            0 :             HttpConnError::WakeCompute(c) => c.to_string_client(),
     494              :             HttpConnError::TooManyConnectionAttempts(_) => {
     495            0 :                 "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
     496              :             }
     497              :         }
     498            0 :     }
     499              : }
     500              : 
     501              : impl ReportableError for LocalProxyConnError {
     502            0 :     fn get_error_kind(&self) -> ErrorKind {
     503            0 :         match self {
     504            0 :             LocalProxyConnError::H2(_) => ErrorKind::Compute,
     505              :         }
     506            0 :     }
     507              : }
     508              : 
     509              : impl UserFacingError for LocalProxyConnError {
     510            0 :     fn to_string_client(&self) -> String {
     511            0 :         "Could not establish HTTP connection to the database".to_string()
     512            0 :     }
     513              : }
        

Generated by: LCOV version 2.1-beta