LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: f2bfe5dc5ab550768e936d6bc7b94d9b2e2d4cc9.info Lines: 40.8 % 309 126
Test Date: 2025-01-27 20:39:28 Functions: 30.8 % 39 12

            Line data    Source code
       1              : use std::net::{IpAddr, SocketAddr};
       2              : use std::sync::Arc;
       3              : 
       4              : use dashmap::DashMap;
       5              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
       6              : use postgres_client::tls::MakeTlsConnect;
       7              : use postgres_client::CancelToken;
       8              : use pq_proto::CancelKeyData;
       9              : use thiserror::Error;
      10              : use tokio::net::TcpStream;
      11              : use tokio::sync::Mutex;
      12              : use tracing::{debug, info};
      13              : use uuid::Uuid;
      14              : 
      15              : use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
      16              : use crate::auth::{check_peer_addr_is_in_list, AuthError, IpPattern};
      17              : use crate::config::ComputeConfig;
      18              : use crate::context::RequestContext;
      19              : use crate::error::ReportableError;
      20              : use crate::ext::LockExt;
      21              : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
      22              : use crate::rate_limiter::LeakyBucketRateLimiter;
      23              : use crate::redis::cancellation_publisher::{
      24              :     CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
      25              : };
      26              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      27              : 
      28              : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
      29              : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
      30              : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
      31              : 
      32              : type IpSubnetKey = IpNet;
      33              : 
      34              : /// Enables serving `CancelRequest`s.
      35              : ///
      36              : /// If `CancellationPublisher` is available, the cancel request will be used to publish the cancellation key to other proxy instances.
      37              : pub struct CancellationHandler<P> {
      38              :     compute_config: &'static ComputeConfig,
      39              :     map: CancelMap,
      40              :     client: P,
      41              :     /// This field is used for monitoring purposes.
      42              :     /// Represents the source of the cancellation request.
      43              :     from: CancellationSource,
      44              :     // rate limiter of cancellation requests
      45              :     limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
      46              : }
      47              : 
      48              : #[derive(Debug, Error)]
      49              : pub(crate) enum CancelError {
      50              :     #[error("{0}")]
      51              :     IO(#[from] std::io::Error),
      52              : 
      53              :     #[error("{0}")]
      54              :     Postgres(#[from] postgres_client::Error),
      55              : 
      56              :     #[error("rate limit exceeded")]
      57              :     RateLimit,
      58              : 
      59              :     #[error("IP is not allowed")]
      60              :     IpNotAllowed,
      61              : 
      62              :     #[error("Authentication backend error")]
      63              :     AuthError(#[from] AuthError),
      64              : }
      65              : 
      66              : impl ReportableError for CancelError {
      67            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      68            0 :         match self {
      69            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
      70            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
      71            0 :                 crate::error::ErrorKind::Postgres
      72              :             }
      73            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
      74            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
      75            0 :             CancelError::IpNotAllowed => crate::error::ErrorKind::User,
      76            0 :             CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
      77              :         }
      78            0 :     }
      79              : }
      80              : 
      81              : impl<P: CancellationPublisher> CancellationHandler<P> {
      82              :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      83            1 :     pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
      84              :         // we intentionally generate a random "backend pid" and "secret key" here.
      85              :         // we use the corresponding u64 as an identifier for the
      86              :         // actual endpoint+pid+secret for postgres/pgbouncer.
      87              :         //
      88              :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
      89              :         // of overlap between our computes as most pids are small (~100).
      90            1 :         let key = loop {
      91            1 :             let key = rand::random();
      92            1 : 
      93            1 :             // Random key collisions are unlikely to happen here, but they're still possible,
      94            1 :             // which is why we have to take care not to rewrite an existing key.
      95            1 :             match self.map.entry(key) {
      96            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue,
      97            1 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
      98            1 :                     e.insert(None);
      99            1 :                 }
     100            1 :             }
     101            1 :             break key;
     102            1 :         };
     103            1 : 
     104            1 :         debug!("registered new query cancellation key {key}");
     105            1 :         Session {
     106            1 :             key,
     107            1 :             cancellation_handler: self,
     108            1 :         }
     109            1 :     }
     110              : 
     111              :     /// Used for cancelling only in notifications; will be removed.
     112            1 :     pub(crate) async fn cancel_session(
     113            1 :         &self,
     114            1 :         key: CancelKeyData,
     115            1 :         session_id: Uuid,
     116            1 :         peer_addr: IpAddr,
     117            1 :         check_allowed: bool,
     118            1 :     ) -> Result<(), CancelError> {
     119            1 :         // TODO: check for unspecified address is only for backward compatibility, should be removed
     120            1 :         if !peer_addr.is_unspecified() {
     121            1 :             let subnet_key = match peer_addr {
     122            1 :                 IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
     123            0 :                 IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     124              :             };
     125            1 :             if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     126              :                 // log only the subnet part of the IP address to know which subnet is rate limited
     127            0 :                 tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     128            0 :                 Metrics::get()
     129            0 :                     .proxy
     130            0 :                     .cancellation_requests_total
     131            0 :                     .inc(CancellationRequest {
     132            0 :                         source: self.from,
     133            0 :                         kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     134            0 :                     });
     135            0 :                 return Err(CancelError::RateLimit);
     136            1 :             }
     137            0 :         }
     138              : 
     139              :         // NB: we should immediately release the lock after cloning the token.
     140            1 :         let cancel_state = self.map.get(&key).and_then(|x| x.clone());
     141            1 :         let Some(cancel_closure) = cancel_state else {
     142            1 :             tracing::warn!("query cancellation key not found: {key}");
     143            1 :             Metrics::get()
     144            1 :                 .proxy
     145            1 :                 .cancellation_requests_total
     146            1 :                 .inc(CancellationRequest {
     147            1 :                     source: self.from,
     148            1 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     149            1 :                 });
     150            1 : 
     151            1 :             if session_id == Uuid::nil() {
     152              :                 // was already published, do not publish it again
     153            0 :                 return Ok(());
     154            1 :             }
     155            1 : 
     156            1 :             match self.client.try_publish(key, session_id, peer_addr).await {
     157            1 :                 Ok(()) => {} // do nothing
     158            0 :                 Err(e) => {
     159            0 :                     // log it here since cancel_session could be spawned in a task
     160            0 :                     tracing::error!("failed to publish cancellation key: {key}, error: {e}");
     161            0 :                     return Err(CancelError::IO(std::io::Error::new(
     162            0 :                         std::io::ErrorKind::Other,
     163            0 :                         e.to_string(),
     164            0 :                     )));
     165              :                 }
     166              :             }
     167            1 :             return Ok(());
     168              :         };
     169              : 
     170            0 :         if check_allowed
     171            0 :             && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
     172              :         {
     173              :             // log it here since cancel_session could be spawned in a task
     174            0 :             tracing::warn!("IP is not allowed to cancel the query: {key}");
     175            0 :             return Err(CancelError::IpNotAllowed);
     176            0 :         }
     177            0 : 
     178            0 :         Metrics::get()
     179            0 :             .proxy
     180            0 :             .cancellation_requests_total
     181            0 :             .inc(CancellationRequest {
     182            0 :                 source: self.from,
     183            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     184            0 :             });
     185            0 :         info!(
     186            0 :             "cancelling query per user's request using key {key}, hostname {}, address: {}",
     187              :             cancel_closure.hostname, cancel_closure.socket_addr
     188              :         );
     189            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     190            1 :     }
     191              : 
     192              :     /// Try to cancel a running query for the corresponding connection.
     193              :     /// If the cancellation key is not found, it will be published to Redis.
     194              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     195              :     /// Will fetch IP allowlist internally.
     196              :     ///
     197              :     /// Returns a `Result` primarily for tests.
     198            0 :     pub(crate) async fn cancel_session_auth<T: BackendIpAllowlist>(
     199            0 :         &self,
     200            0 :         key: CancelKeyData,
     201            0 :         ctx: RequestContext,
     202            0 :         check_allowed: bool,
     203            0 :         auth_backend: &T,
     204            0 :     ) -> Result<(), CancelError> {
     205            0 :         // TODO: check for unspecified address is only for backward compatibility, should be removed
     206            0 :         if !ctx.peer_addr().is_unspecified() {
     207            0 :             let subnet_key = match ctx.peer_addr() {
     208            0 :                 IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
     209            0 :                 IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     210              :             };
     211            0 :             if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     212              :                 // log only the subnet part of the IP address to know which subnet is rate limited
     213            0 :                 tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     214            0 :                 Metrics::get()
     215            0 :                     .proxy
     216            0 :                     .cancellation_requests_total
     217            0 :                     .inc(CancellationRequest {
     218            0 :                         source: self.from,
     219            0 :                         kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     220            0 :                     });
     221            0 :                 return Err(CancelError::RateLimit);
     222            0 :             }
     223            0 :         }
     224              : 
     225              :         // NB: we should immediately release the lock after cloning the token.
     226            0 :         let cancel_state = self.map.get(&key).and_then(|x| x.clone());
     227            0 :         let Some(cancel_closure) = cancel_state else {
     228            0 :             tracing::warn!("query cancellation key not found: {key}");
     229            0 :             Metrics::get()
     230            0 :                 .proxy
     231            0 :                 .cancellation_requests_total
     232            0 :                 .inc(CancellationRequest {
     233            0 :                     source: self.from,
     234            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     235            0 :                 });
     236            0 : 
     237            0 :             if ctx.session_id() == Uuid::nil() {
     238              :                 // was already published, do not publish it again
     239            0 :                 return Ok(());
     240            0 :             }
     241            0 : 
     242            0 :             match self
     243            0 :                 .client
     244            0 :                 .try_publish(key, ctx.session_id(), ctx.peer_addr())
     245            0 :                 .await
     246              :             {
     247            0 :                 Ok(()) => {} // do nothing
     248            0 :                 Err(e) => {
     249            0 :                     // log it here since cancel_session could be spawned in a task
     250            0 :                     tracing::error!("failed to publish cancellation key: {key}, error: {e}");
     251            0 :                     return Err(CancelError::IO(std::io::Error::new(
     252            0 :                         std::io::ErrorKind::Other,
     253            0 :                         e.to_string(),
     254            0 :                     )));
     255              :                 }
     256              :             }
     257            0 :             return Ok(());
     258              :         };
     259              : 
     260            0 :         let ip_allowlist = auth_backend
     261            0 :             .get_allowed_ips(&ctx, &cancel_closure.user_info)
     262            0 :             .await
     263            0 :             .map_err(CancelError::AuthError)?;
     264              : 
     265            0 :         if check_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
     266              :             // log it here since cancel_session could be spawned in a task
     267            0 :             tracing::warn!("IP is not allowed to cancel the query: {key}");
     268            0 :             return Err(CancelError::IpNotAllowed);
     269            0 :         }
     270            0 : 
     271            0 :         Metrics::get()
     272            0 :             .proxy
     273            0 :             .cancellation_requests_total
     274            0 :             .inc(CancellationRequest {
     275            0 :                 source: self.from,
     276            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     277            0 :             });
     278            0 :         info!("cancelling query per user's request using key {key}");
     279            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     280            0 :     }
     281              : 
     282              :     #[cfg(test)]
     283            1 :     fn contains(&self, session: &Session<P>) -> bool {
     284            1 :         self.map.contains_key(&session.key)
     285            1 :     }
     286              : 
     287              :     #[cfg(test)]
     288            1 :     fn is_empty(&self) -> bool {
     289            1 :         self.map.is_empty()
     290            1 :     }
     291              : }
     292              : 
     293              : impl CancellationHandler<()> {
     294            2 :     pub fn new(
     295            2 :         compute_config: &'static ComputeConfig,
     296            2 :         map: CancelMap,
     297            2 :         from: CancellationSource,
     298            2 :     ) -> Self {
     299            2 :         Self {
     300            2 :             compute_config,
     301            2 :             map,
     302            2 :             client: (),
     303            2 :             from,
     304            2 :             limiter: Arc::new(std::sync::Mutex::new(
     305            2 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     306            2 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     307            2 :                     64,
     308            2 :                 ),
     309            2 :             )),
     310            2 :         }
     311            2 :     }
     312              : }
     313              : 
     314              : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
     315            0 :     pub fn new(
     316            0 :         compute_config: &'static ComputeConfig,
     317            0 :         map: CancelMap,
     318            0 :         client: Option<Arc<Mutex<P>>>,
     319            0 :         from: CancellationSource,
     320            0 :     ) -> Self {
     321            0 :         Self {
     322            0 :             compute_config,
     323            0 :             map,
     324            0 :             client,
     325            0 :             from,
     326            0 :             limiter: Arc::new(std::sync::Mutex::new(
     327            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     328            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     329            0 :                     64,
     330            0 :                 ),
     331            0 :             )),
     332            0 :         }
     333            0 :     }
     334              : }
     335              : 
     336              : /// This should've been a [`std::future::Future`], but
     337              : /// it's impossible to name a type of an unboxed future
     338              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     339              : #[derive(Clone)]
     340              : pub struct CancelClosure {
     341              :     socket_addr: SocketAddr,
     342              :     cancel_token: CancelToken,
     343              :     ip_allowlist: Vec<IpPattern>,
     344              :     hostname: String, // for pg_sni router
     345              :     user_info: ComputeUserInfo,
     346              : }
     347              : 
     348              : impl CancelClosure {
     349            0 :     pub(crate) fn new(
     350            0 :         socket_addr: SocketAddr,
     351            0 :         cancel_token: CancelToken,
     352            0 :         ip_allowlist: Vec<IpPattern>,
     353            0 :         hostname: String,
     354            0 :         user_info: ComputeUserInfo,
     355            0 :     ) -> Self {
     356            0 :         Self {
     357            0 :             socket_addr,
     358            0 :             cancel_token,
     359            0 :             ip_allowlist,
     360            0 :             hostname,
     361            0 :             user_info,
     362            0 :         }
     363            0 :     }
     364              :     /// Cancels the query running on user's compute node.
     365            0 :     pub(crate) async fn try_cancel_query(
     366            0 :         self,
     367            0 :         compute_config: &ComputeConfig,
     368            0 :     ) -> Result<(), CancelError> {
     369            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     370              : 
     371            0 :         let mut mk_tls =
     372            0 :             crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
     373            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     374            0 :             &mut mk_tls,
     375            0 :             &self.hostname,
     376            0 :         )
     377            0 :         .map_err(|e| {
     378            0 :             CancelError::IO(std::io::Error::new(
     379            0 :                 std::io::ErrorKind::Other,
     380            0 :                 e.to_string(),
     381            0 :             ))
     382            0 :         })?;
     383              : 
     384            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     385            0 :         debug!("query was cancelled");
     386            0 :         Ok(())
     387            0 :     }
     388              : 
     389              :     /// Obsolete (will be removed after moving CancelMap to Redis), only for notifications
     390            0 :     pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
     391            0 :         self.ip_allowlist = ip_allowlist;
     392            0 :     }
     393              : }
     394              : 
     395              : /// Helper for registering query cancellation tokens.
     396              : pub(crate) struct Session<P> {
     397              :     /// The user-facing key identifying this session.
     398              :     key: CancelKeyData,
     399              :     /// The [`CancelMap`] this session belongs to.
     400              :     cancellation_handler: Arc<CancellationHandler<P>>,
     401              : }
     402              : 
     403              : impl<P> Session<P> {
     404              :     /// Store the cancel token for the given session.
     405              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     406            0 :     pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     407            0 :         debug!("enabling query cancellation for this session");
     408            0 :         self.cancellation_handler
     409            0 :             .map
     410            0 :             .insert(self.key, Some(cancel_closure));
     411            0 : 
     412            0 :         self.key
     413            0 :     }
     414              : }
     415              : 
     416              : impl<P> Drop for Session<P> {
     417            1 :     fn drop(&mut self) {
     418            1 :         self.cancellation_handler.map.remove(&self.key);
     419            1 :         debug!("dropped query cancellation key {}", &self.key);
     420            1 :     }
     421              : }
     422              : 
     423              : #[cfg(test)]
     424              : #[expect(clippy::unwrap_used)]
     425              : mod tests {
     426              :     use std::time::Duration;
     427              : 
     428              :     use super::*;
     429              :     use crate::config::RetryConfig;
     430              :     use crate::tls::client_config::compute_client_config_with_certs;
     431              : 
     432            2 :     fn config() -> ComputeConfig {
     433            2 :         let retry = RetryConfig {
     434            2 :             base_delay: Duration::from_secs(1),
     435            2 :             max_retries: 5,
     436            2 :             backoff_factor: 2.0,
     437            2 :         };
     438            2 : 
     439            2 :         ComputeConfig {
     440            2 :             retry,
     441            2 :             tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
     442            2 :             timeout: Duration::from_secs(2),
     443            2 :         }
     444            2 :     }
     445              : 
     446              :     #[tokio::test]
     447            1 :     async fn check_session_drop() -> anyhow::Result<()> {
     448            1 :         let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
     449            1 :             Box::leak(Box::new(config())),
     450            1 :             CancelMap::default(),
     451            1 :             CancellationSource::FromRedis,
     452            1 :         ));
     453            1 : 
     454            1 :         let session = cancellation_handler.clone().get_session();
     455            1 :         assert!(cancellation_handler.contains(&session));
     456            1 :         drop(session);
     457            1 :         // Check that the session has been dropped.
     458            1 :         assert!(cancellation_handler.is_empty());
     459            1 : 
     460            1 :         Ok(())
     461            1 :     }
     462              : 
     463              :     #[tokio::test]
     464            1 :     async fn cancel_session_noop_regression() {
     465            1 :         let handler = CancellationHandler::<()>::new(
     466            1 :             Box::leak(Box::new(config())),
     467            1 :             CancelMap::default(),
     468            1 :             CancellationSource::Local,
     469            1 :         );
     470            1 :         handler
     471            1 :             .cancel_session(
     472            1 :                 CancelKeyData {
     473            1 :                     backend_pid: 0,
     474            1 :                     cancel_key: 0,
     475            1 :                 },
     476            1 :                 Uuid::new_v4(),
     477            1 :                 "127.0.0.1".parse().unwrap(),
     478            1 :                 true,
     479            1 :             )
     480            1 :             .await
     481            1 :             .unwrap();
     482            1 :     }
     483              : }
        

Generated by: LCOV version 2.1-beta