LCOV - code coverage report
Current view: top level - proxy/src/serverless - local_conn_pool.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 28.6 % 262 75
Test Date: 2025-03-12 18:28:53 Functions: 14.3 % 28 4

            Line data    Source code
       1              : //! Manages the pool of connections between local_proxy and postgres.
       2              : //!
       3              : //! The pool is keyed by database and role_name, and can contain multiple connections
       4              : //! shared between users.
       5              : //!
       6              : //! The pool manages the pg_session_jwt extension used for authorizing
       7              : //! requests in the db.
       8              : //!
       9              : //! The first time a db/role pair is seen, local_proxy attempts to install the extension
      10              : //! and grant usage to the role on the given schema.
      11              : 
      12              : use std::collections::HashMap;
      13              : use std::pin::pin;
      14              : use std::sync::Arc;
      15              : use std::sync::atomic::AtomicUsize;
      16              : use std::task::{Poll, ready};
      17              : use std::time::Duration;
      18              : 
      19              : use ed25519_dalek::{Signature, Signer, SigningKey};
      20              : use futures::Future;
      21              : use futures::future::poll_fn;
      22              : use indexmap::IndexMap;
      23              : use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
      24              : use parking_lot::RwLock;
      25              : use postgres_client::AsyncMessage;
      26              : use postgres_client::tls::NoTlsStream;
      27              : use serde_json::value::RawValue;
      28              : use tokio::net::TcpStream;
      29              : use tokio::time::Instant;
      30              : use tokio_util::sync::CancellationToken;
      31              : use tracing::{Instrument, debug, error, info, info_span, warn};
      32              : 
      33              : use super::backend::HttpConnError;
      34              : use super::conn_pool_lib::{
      35              :     Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
      36              :     EndpointConnPool,
      37              : };
      38              : use super::sql_over_http::SqlOverHttpError;
      39              : use crate::context::RequestContext;
      40              : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
      41              : use crate::metrics::Metrics;
      42              : 
      43              : pub(crate) const EXT_NAME: &str = "pg_session_jwt";
      44              : pub(crate) const EXT_VERSION: &str = "0.2.0";
      45              : pub(crate) const EXT_SCHEMA: &str = "auth";
      46              : 
      47              : #[derive(Clone)]
      48              : pub(crate) struct ClientDataLocal {
      49              :     session: tokio::sync::watch::Sender<uuid::Uuid>,
      50              :     cancel: CancellationToken,
      51              :     key: SigningKey,
      52              :     jti: u64,
      53              : }
      54              : 
      55              : impl ClientDataLocal {
      56            0 :     pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
      57            0 :         &mut self.session
      58            0 :     }
      59              : 
      60            0 :     pub fn cancel(&mut self) {
      61            0 :         self.cancel.cancel();
      62            0 :     }
      63              : }
      64              : 
      65              : pub(crate) struct LocalConnPool<C: ClientInnerExt> {
      66              :     global_pool: Arc<RwLock<EndpointConnPool<C>>>,
      67              : 
      68              :     config: &'static crate::config::HttpConfig,
      69              : }
      70              : 
      71              : impl<C: ClientInnerExt> LocalConnPool<C> {
      72            0 :     pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
      73            0 :         Arc::new(Self {
      74            0 :             global_pool: Arc::new(RwLock::new(EndpointConnPool::new(
      75            0 :                 HashMap::new(),
      76            0 :                 0,
      77            0 :                 config.pool_options.max_conns_per_endpoint,
      78            0 :                 Arc::new(AtomicUsize::new(0)),
      79            0 :                 config.pool_options.max_total_conns,
      80            0 :                 String::from("local_pool"),
      81            0 :             ))),
      82            0 :             config,
      83            0 :         })
      84            0 :     }
      85              : 
      86            0 :     pub(crate) fn get_idle_timeout(&self) -> Duration {
      87            0 :         self.config.pool_options.idle_timeout
      88            0 :     }
      89              : 
      90            0 :     pub(crate) fn get(
      91            0 :         self: &Arc<Self>,
      92            0 :         ctx: &RequestContext,
      93            0 :         conn_info: &ConnInfo,
      94            0 :     ) -> Result<Option<Client<C>>, HttpConnError> {
      95            0 :         let client = self
      96            0 :             .global_pool
      97            0 :             .write()
      98            0 :             .get_conn_entry(conn_info.db_and_user())
      99            0 :             .map(|entry| entry.conn);
     100              : 
     101              :         // ok return cached connection if found and establish a new one otherwise
     102            0 :         if let Some(mut client) = client {
     103            0 :             if client.inner.is_closed() {
     104            0 :                 info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
     105            0 :                 return Ok(None);
     106            0 :             }
     107            0 : 
     108            0 :             tracing::Span::current()
     109            0 :                 .record("conn_id", tracing::field::display(client.get_conn_id()));
     110            0 :             tracing::Span::current().record(
     111            0 :                 "pid",
     112            0 :                 tracing::field::display(client.inner.get_process_id()),
     113            0 :             );
     114            0 :             debug!(
     115            0 :                 cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
     116            0 :                 "local_pool: reusing connection '{conn_info}'"
     117              :             );
     118              : 
     119            0 :             match client.get_data() {
     120            0 :                 ClientDataEnum::Local(data) => {
     121            0 :                     data.session().send(ctx.session_id())?;
     122              :                 }
     123              : 
     124            0 :                 ClientDataEnum::Remote(data) => {
     125            0 :                     data.session().send(ctx.session_id())?;
     126              :                 }
     127            0 :                 ClientDataEnum::Http(_) => (),
     128              :             }
     129              : 
     130            0 :             ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
     131            0 :             ctx.success();
     132            0 : 
     133            0 :             return Ok(Some(Client::new(
     134            0 :                 client,
     135            0 :                 conn_info.clone(),
     136            0 :                 Arc::downgrade(&self.global_pool),
     137            0 :             )));
     138            0 :         }
     139            0 :         Ok(None)
     140            0 :     }
     141              : 
     142            0 :     pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
     143            0 :         if let Some(pool) = self.global_pool.read().get_pool(conn_info.db_and_user()) {
     144            0 :             return pool.is_initialized();
     145            0 :         }
     146            0 :         false
     147            0 :     }
     148              : 
     149            0 :     pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
     150            0 :         if let Some(pool) = self
     151            0 :             .global_pool
     152            0 :             .write()
     153            0 :             .get_pool_mut(conn_info.db_and_user())
     154            0 :         {
     155            0 :             pool.set_initialized();
     156            0 :         }
     157            0 :     }
     158              : }
     159              : 
     160              : #[allow(clippy::too_many_arguments)]
     161            0 : pub(crate) fn poll_client<C: ClientInnerExt>(
     162            0 :     global_pool: Arc<LocalConnPool<C>>,
     163            0 :     ctx: &RequestContext,
     164            0 :     conn_info: ConnInfo,
     165            0 :     client: C,
     166            0 :     mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
     167            0 :     key: SigningKey,
     168            0 :     conn_id: uuid::Uuid,
     169            0 :     aux: MetricsAuxInfo,
     170            0 : ) -> Client<C> {
     171            0 :     let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
     172            0 :     let mut session_id = ctx.session_id();
     173            0 :     let (tx, mut rx) = tokio::sync::watch::channel(session_id);
     174              : 
     175            0 :     let span = info_span!(parent: None, "connection", %conn_id);
     176            0 :     let cold_start_info = ctx.cold_start_info();
     177            0 :     span.in_scope(|| {
     178            0 :         info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
     179            0 :     });
     180            0 :     let pool = Arc::downgrade(&global_pool);
     181            0 : 
     182            0 :     let db_user = conn_info.db_and_user();
     183            0 :     let idle = global_pool.get_idle_timeout();
     184            0 :     let cancel = CancellationToken::new();
     185            0 :     let cancelled = cancel.clone().cancelled_owned();
     186            0 : 
     187            0 :     tokio::spawn(
     188            0 :     async move {
     189            0 :         let _conn_gauge = conn_gauge;
     190            0 :         let mut idle_timeout = pin!(tokio::time::sleep(idle));
     191            0 :         let mut cancelled = pin!(cancelled);
     192            0 : 
     193            0 :         poll_fn(move |cx| {
     194            0 :             if cancelled.as_mut().poll(cx).is_ready() {
     195            0 :                 info!("connection dropped");
     196            0 :                 return Poll::Ready(())
     197            0 :             }
     198            0 : 
     199            0 :             match rx.has_changed() {
     200              :                 Ok(true) => {
     201            0 :                     session_id = *rx.borrow_and_update();
     202            0 :                     info!(%session_id, "changed session");
     203            0 :                     idle_timeout.as_mut().reset(Instant::now() + idle);
     204              :                 }
     205              :                 Err(_) => {
     206            0 :                     info!("connection dropped");
     207            0 :                     return Poll::Ready(())
     208              :                 }
     209            0 :                 _ => {}
     210              :             }
     211              : 
     212              :             // 5 minute idle connection timeout
     213            0 :             if idle_timeout.as_mut().poll(cx).is_ready() {
     214            0 :                 idle_timeout.as_mut().reset(Instant::now() + idle);
     215            0 :                 info!("connection idle");
     216            0 :                 if let Some(pool) = pool.clone().upgrade() {
     217              :                     // remove client from pool - should close the connection if it's idle.
     218              :                     // does nothing if the client is currently checked-out and in-use
     219            0 :                     if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
     220            0 :                         info!("idle connection removed");
     221            0 :                     }
     222            0 :                 }
     223            0 :             }
     224              : 
     225              :             loop {
     226            0 :                 let message = ready!(connection.poll_message(cx));
     227              : 
     228            0 :                 match message {
     229            0 :                     Some(Ok(AsyncMessage::Notice(notice))) => {
     230            0 :                         info!(%session_id, "notice: {}", notice);
     231              :                     }
     232            0 :                     Some(Ok(AsyncMessage::Notification(notif))) => {
     233            0 :                         warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
     234              :                     }
     235              :                     Some(Ok(_)) => {
     236            0 :                         warn!(%session_id, "unknown message");
     237              :                     }
     238            0 :                     Some(Err(e)) => {
     239            0 :                         error!(%session_id, "connection error: {}", e);
     240            0 :                         break
     241              :                     }
     242              :                     None => {
     243            0 :                         info!("connection closed");
     244            0 :                         break
     245              :                     }
     246              :                 }
     247              :             }
     248              : 
     249              :             // remove from connection pool
     250            0 :             if let Some(pool) = pool.clone().upgrade() {
     251            0 :                 if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
     252            0 :                     info!("closed connection removed");
     253            0 :                 }
     254            0 :             }
     255              : 
     256            0 :             Poll::Ready(())
     257            0 :         }).await;
     258              : 
     259            0 :     }
     260            0 :     .instrument(span));
     261            0 : 
     262            0 :     let inner = ClientInnerCommon {
     263            0 :         inner: client,
     264            0 :         aux,
     265            0 :         conn_id,
     266            0 :         data: ClientDataEnum::Local(ClientDataLocal {
     267            0 :             session: tx,
     268            0 :             cancel,
     269            0 :             key,
     270            0 :             jti: 0,
     271            0 :         }),
     272            0 :     };
     273            0 : 
     274            0 :     Client::new(inner, conn_info, Arc::downgrade(&global_pool.global_pool))
     275            0 : }
     276              : 
     277              : impl ClientInnerCommon<postgres_client::Client> {
     278            0 :     pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), SqlOverHttpError> {
     279            0 :         if let ClientDataEnum::Local(local_data) = &mut self.data {
     280            0 :             local_data.jti += 1;
     281            0 :             let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
     282              : 
     283            0 :             self.inner
     284            0 :                 .discard_all()
     285            0 :                 .await
     286            0 :                 .map_err(SqlOverHttpError::InternalPostgres)?;
     287              : 
     288              :             // initiates the auth session
     289              :             // this is safe from query injections as the jwt format free of any escape characters.
     290            0 :             let query = format!("select auth.jwt_session_init('{token}')");
     291            0 :             self.inner
     292            0 :                 .batch_execute(&query)
     293            0 :                 .await
     294            0 :                 .map_err(SqlOverHttpError::InternalPostgres)?;
     295              : 
     296            0 :             let pid = self.inner.get_process_id();
     297            0 :             info!(pid, jti = local_data.jti, "user session state init");
     298            0 :             Ok(())
     299              :         } else {
     300            0 :             panic!("unexpected client data type");
     301              :         }
     302            0 :     }
     303              : }
     304              : 
     305              : /// implements relatively efficient in-place json object key upserting
     306              : ///
     307              : /// only supports top-level keys
     308            1 : fn upsert_json_object(
     309            1 :     payload: &[u8],
     310            1 :     key: &str,
     311            1 :     value: &RawValue,
     312            1 : ) -> Result<String, serde_json::Error> {
     313            1 :     let mut payload = serde_json::from_slice::<IndexMap<&str, &RawValue>>(payload)?;
     314            1 :     payload.insert(key, value);
     315            1 :     serde_json::to_string(&payload)
     316            1 : }
     317              : 
     318            1 : fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result<String, HttpConnError> {
     319            1 :     let mut buffer = itoa::Buffer::new();
     320            1 : 
     321            1 :     // encode the jti integer to a json rawvalue
     322            1 :     let jti = serde_json::from_str::<&RawValue>(buffer.format(jti))
     323            1 :         .expect("itoa formatted integer should be guaranteed valid json");
     324              : 
     325              :     // update the jti in-place
     326            1 :     let payload =
     327            1 :         upsert_json_object(payload, "jti", jti).map_err(HttpConnError::JwtPayloadError)?;
     328              : 
     329              :     // sign the jwt
     330            1 :     let token = sign_jwt(sk, payload.as_bytes());
     331            1 : 
     332            1 :     Ok(token)
     333            1 : }
     334              : 
     335            1 : fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
     336            1 :     let header_len = 20;
     337            1 :     let payload_len = Base64UrlUnpadded::encoded_len(payload);
     338            1 :     let signature_len = Base64UrlUnpadded::encoded_len(&[0; 64]);
     339            1 :     let total_len = header_len + payload_len + signature_len + 2;
     340            1 : 
     341            1 :     let mut jwt = String::with_capacity(total_len);
     342            1 :     let cap = jwt.capacity();
     343            1 : 
     344            1 :     // we only need an empty header with the alg specified.
     345            1 :     // base64url(r#"{"alg":"EdDSA"}"#) == "eyJhbGciOiJFZERTQSJ9"
     346            1 :     jwt.push_str("eyJhbGciOiJFZERTQSJ9.");
     347            1 : 
     348            1 :     // encode the jwt payload in-place
     349            1 :     base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
     350            1 : 
     351            1 :     // create the signature from the encoded header || payload
     352            1 :     let sig: Signature = sk.sign(jwt.as_bytes());
     353            1 : 
     354            1 :     jwt.push('.');
     355            1 : 
     356            1 :     // encode the jwt signature in-place
     357            1 :     base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
     358            1 : 
     359            1 :     debug_assert_eq!(
     360            1 :         jwt.len(),
     361              :         total_len,
     362            0 :         "the jwt len should match our expected len"
     363              :     );
     364            1 :     debug_assert_eq!(jwt.capacity(), cap, "the jwt capacity should not change");
     365              : 
     366            1 :     jwt
     367            1 : }
     368              : 
     369              : #[cfg(test)]
     370              : #[expect(clippy::unwrap_used)]
     371              : mod tests {
     372              :     use ed25519_dalek::SigningKey;
     373              :     use typed_json::json;
     374              : 
     375              :     use super::resign_jwt;
     376              : 
     377              :     #[test]
     378            1 :     fn jwt_token_snapshot() {
     379            1 :         let key = SigningKey::from_bytes(&[1; 32]);
     380            1 :         let data =
     381            1 :             json!({"foo":"bar","jti":"foo\nbar","nested":{"jti":"tricky nesting"}}).to_string();
     382            1 : 
     383            1 :         let jwt = resign_jwt(&key, data.as_bytes(), 2).unwrap();
     384            1 : 
     385            1 :         // To validate the JWT, copy the JWT string and paste it into https://jwt.io/.
     386            1 :         // In the public-key box, paste the following jwk public key
     387            1 :         // `{"kty":"OKP","crv":"Ed25519","x":"iojj3XQJ8ZX9UtstPLpdcspnCb8dlBIb83SIAbQPb1w"}`
     388            1 :         // Note - jwt.io doesn't support EdDSA :(
     389            1 :         // https://github.com/jsonwebtoken/jsonwebtoken.github.io/issues/509
     390            1 : 
     391            1 :         // let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
     392            1 :         //     crv: jose_jwk::OkpCurves::Ed25519,
     393            1 :         //     x: jose_jwk::jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
     394            1 :         //     d: None,
     395            1 :         // });
     396            1 :         // println!("{}", serde_json::to_string(&jwk).unwrap());
     397            1 : 
     398            1 :         assert_eq!(
     399            1 :             jwt,
     400            1 :             "eyJhbGciOiJFZERTQSJ9.eyJmb28iOiJiYXIiLCJqdGkiOjIsIm5lc3RlZCI6eyJqdGkiOiJ0cmlja3kgbmVzdGluZyJ9fQ.Cvyc2By33KI0f0obystwdy8PN111L3Sc9_Mr2CU3XshtSqSdxuRxNEZGbb_RvyJf2IzheC_s7aBZ-jLeQ9N0Bg"
     401            1 :         );
     402            1 :     }
     403              : }
        

Generated by: LCOV version 2.1-beta