LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 157166bf1e7b60cf936c3c96f6e44d24268705a4.info Lines: 0.0 % 261 0
Test Date: 2025-07-08 19:05:57 Functions: 0.0 % 39 0

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::net::{IpAddr, SocketAddr};
       3              : use std::pin::pin;
       4              : use std::sync::{Arc, OnceLock};
       5              : use std::time::Duration;
       6              : 
       7              : use anyhow::anyhow;
       8              : use futures::FutureExt;
       9              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
      10              : use postgres_client::RawCancelToken;
      11              : use postgres_client::tls::MakeTlsConnect;
      12              : use redis::{Cmd, FromRedisValue, Value};
      13              : use serde::{Deserialize, Serialize};
      14              : use thiserror::Error;
      15              : use tokio::net::TcpStream;
      16              : use tokio::time::timeout;
      17              : use tracing::{debug, error, info};
      18              : 
      19              : use crate::auth::AuthError;
      20              : use crate::auth::backend::ComputeUserInfo;
      21              : use crate::batch::{BatchQueue, QueueProcessing};
      22              : use crate::config::ComputeConfig;
      23              : use crate::context::RequestContext;
      24              : use crate::control_plane::ControlPlaneApi;
      25              : use crate::error::ReportableError;
      26              : use crate::ext::LockExt;
      27              : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
      28              : use crate::pqproto::CancelKeyData;
      29              : use crate::rate_limiter::LeakyBucketRateLimiter;
      30              : use crate::redis::keys::KeyPrefix;
      31              : use crate::redis::kv_ops::RedisKVClient;
      32              : 
      33              : type IpSubnetKey = IpNet;
      34              : 
      35              : const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
      36              : const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
      37              : 
      38              : // Message types for sending through mpsc channel
      39              : pub enum CancelKeyOp {
      40              :     StoreCancelKey {
      41              :         key: CancelKeyData,
      42              :         value: Box<str>,
      43              :         expire: std::time::Duration,
      44              :     },
      45              :     GetCancelData {
      46              :         key: CancelKeyData,
      47              :     },
      48              : }
      49              : 
      50              : pub struct Pipeline {
      51              :     inner: redis::Pipeline,
      52              :     replies: usize,
      53              : }
      54              : 
      55              : impl Pipeline {
      56            0 :     fn with_capacity(n: usize) -> Self {
      57            0 :         Self {
      58            0 :             inner: redis::Pipeline::with_capacity(n),
      59            0 :             replies: 0,
      60            0 :         }
      61            0 :     }
      62              : 
      63            0 :     async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
      64            0 :         let responses = self.replies;
      65            0 :         let batch_size = self.inner.len();
      66              : 
      67            0 :         if !client.credentials_refreshed() {
      68            0 :             tracing::debug!(
      69            0 :                 "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
      70              :             );
      71            0 :             tokio::time::sleep(Duration::from_secs(5)).await;
      72            0 :         }
      73              : 
      74            0 :         match client.query(&self.inner).await {
      75              :             // for each reply, we expect that many values.
      76            0 :             Ok(Value::Array(values)) if values.len() == responses => {
      77            0 :                 debug!(
      78              :                     batch_size,
      79            0 :                     responses, "successfully completed cancellation jobs",
      80              :                 );
      81            0 :                 values.into_iter().map(Ok).collect()
      82              :             }
      83            0 :             Ok(value) => {
      84            0 :                 error!(batch_size, ?value, "unexpected redis return value");
      85            0 :                 std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
      86            0 :                     .take(responses)
      87            0 :                     .collect()
      88              :             }
      89            0 :             Err(err) => {
      90            0 :                 std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
      91            0 :                     .take(responses)
      92            0 :                     .collect()
      93              :             }
      94              :         }
      95            0 :     }
      96              : 
      97            0 :     fn add_command_with_reply(&mut self, cmd: Cmd) {
      98            0 :         self.inner.add_command(cmd);
      99            0 :         self.replies += 1;
     100            0 :     }
     101              : 
     102            0 :     fn add_command_no_reply(&mut self, cmd: Cmd) {
     103            0 :         self.inner.add_command(cmd).ignore();
     104            0 :     }
     105              : }
     106              : 
     107              : impl CancelKeyOp {
     108            0 :     fn register(&self, pipe: &mut Pipeline) {
     109            0 :         match self {
     110            0 :             CancelKeyOp::StoreCancelKey { key, value, expire } => {
     111            0 :                 let key = KeyPrefix::Cancel(*key).build_redis_key();
     112            0 :                 pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
     113            0 :                 pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
     114            0 :             }
     115            0 :             CancelKeyOp::GetCancelData { key } => {
     116            0 :                 let key = KeyPrefix::Cancel(*key).build_redis_key();
     117            0 :                 pipe.add_command_with_reply(Cmd::hget(key, "data"));
     118            0 :             }
     119              :         }
     120            0 :     }
     121              : }
     122              : 
     123              : pub struct CancellationProcessor {
     124              :     pub client: RedisKVClient,
     125              :     pub batch_size: usize,
     126              : }
     127              : 
     128              : impl QueueProcessing for CancellationProcessor {
     129              :     type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
     130              :     type Res = anyhow::Result<redis::Value>;
     131              : 
     132            0 :     fn batch_size(&self, _queue_size: usize) -> usize {
     133            0 :         self.batch_size
     134            0 :     }
     135              : 
     136            0 :     async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
     137            0 :         if !self.client.credentials_refreshed() {
     138              :             // this will cause a timeout for cancellation operations
     139            0 :             tracing::debug!(
     140            0 :                 "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
     141              :             );
     142            0 :             tokio::time::sleep(Duration::from_secs(5)).await;
     143            0 :         }
     144              : 
     145            0 :         let mut pipeline = Pipeline::with_capacity(batch.len());
     146              : 
     147            0 :         let batch_size = batch.len();
     148            0 :         debug!(batch_size, "running cancellation jobs");
     149              : 
     150            0 :         for (_, op) in &batch {
     151            0 :             op.register(&mut pipeline);
     152            0 :         }
     153              : 
     154            0 :         pipeline.execute(&mut self.client).await
     155            0 :     }
     156              : }
     157              : 
     158              : /// Enables serving `CancelRequest`s.
     159              : ///
     160              : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
     161              : pub struct CancellationHandler {
     162              :     compute_config: &'static ComputeConfig,
     163              :     // rate limiter of cancellation requests
     164              :     limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
     165              :     tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
     166              : }
     167              : 
     168              : #[derive(Debug, Error)]
     169              : pub(crate) enum CancelError {
     170              :     #[error("{0}")]
     171              :     IO(#[from] std::io::Error),
     172              : 
     173              :     #[error("{0}")]
     174              :     Postgres(#[from] postgres_client::Error),
     175              : 
     176              :     #[error("rate limit exceeded")]
     177              :     RateLimit,
     178              : 
     179              :     #[error("Authentication error")]
     180              :     AuthError(#[from] AuthError),
     181              : 
     182              :     #[error("key not found")]
     183              :     NotFound,
     184              : 
     185              :     #[error("proxy service error")]
     186              :     InternalError,
     187              : }
     188              : 
     189              : impl ReportableError for CancelError {
     190            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     191            0 :         match self {
     192            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
     193            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
     194            0 :                 crate::error::ErrorKind::Postgres
     195              :             }
     196            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
     197            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
     198            0 :             CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
     199            0 :             CancelError::InternalError => crate::error::ErrorKind::Service,
     200              :         }
     201            0 :     }
     202              : }
     203              : 
     204              : impl CancellationHandler {
     205            0 :     pub fn new(compute_config: &'static ComputeConfig) -> Self {
     206            0 :         Self {
     207            0 :             compute_config,
     208            0 :             tx: OnceLock::new(),
     209            0 :             limiter: Arc::new(std::sync::Mutex::new(
     210            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     211            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     212            0 :                     64,
     213            0 :                 ),
     214            0 :             )),
     215            0 :         }
     216            0 :     }
     217              : 
     218            0 :     pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
     219            0 :         self.tx
     220            0 :             .set(queue)
     221            0 :             .map_err(|_| {})
     222            0 :             .expect("cancellation queue should be registered once");
     223            0 :     }
     224              : 
     225            0 :     pub(crate) fn get_key(self: Arc<Self>) -> Session {
     226              :         // we intentionally generate a random "backend pid" and "secret key" here.
     227              :         // we use the corresponding u64 as an identifier for the
     228              :         // actual endpoint+pid+secret for postgres/pgbouncer.
     229              :         //
     230              :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
     231              :         // of overlap between our computes as most pids are small (~100).
     232              : 
     233            0 :         let key: CancelKeyData = rand::random();
     234              : 
     235            0 :         debug!("registered new query cancellation key {key}");
     236            0 :         Session {
     237            0 :             key,
     238            0 :             cancellation_handler: self,
     239            0 :         }
     240            0 :     }
     241              : 
     242              :     /// This is not cancel safe
     243            0 :     async fn get_cancel_key(
     244            0 :         &self,
     245            0 :         key: CancelKeyData,
     246            0 :     ) -> Result<Option<CancelClosure>, CancelError> {
     247            0 :         let guard = Metrics::get()
     248            0 :             .proxy
     249            0 :             .cancel_channel_size
     250            0 :             .guard(RedisMsgKind::HGet);
     251            0 :         let op = CancelKeyOp::GetCancelData { key };
     252              : 
     253            0 :         let Some(tx) = self.tx.get() else {
     254            0 :             tracing::warn!("cancellation handler is not available");
     255            0 :             return Err(CancelError::InternalError);
     256              :         };
     257              : 
     258              :         const TIMEOUT: Duration = Duration::from_secs(5);
     259            0 :         let result = timeout(
     260            0 :             TIMEOUT,
     261            0 :             tx.call((guard, op), std::future::pending::<Infallible>()),
     262            0 :         )
     263            0 :         .await
     264            0 :         .map_err(|_| {
     265            0 :             tracing::warn!("timed out waiting to receive GetCancelData response");
     266            0 :             CancelError::RateLimit
     267            0 :         })?
     268              :         // cannot be cancelled
     269            0 :         .unwrap_or_else(|x| match x {})
     270            0 :         .map_err(|e| {
     271            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     272            0 :             CancelError::InternalError
     273            0 :         })?;
     274              : 
     275            0 :         let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
     276            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     277            0 :             CancelError::InternalError
     278            0 :         })?;
     279              : 
     280            0 :         let cancel_closure: CancelClosure =
     281            0 :             serde_json::from_str(&cancel_state_str).map_err(|e| {
     282            0 :                 tracing::warn!("failed to deserialize cancel state: {e}");
     283            0 :                 CancelError::InternalError
     284            0 :             })?;
     285              : 
     286            0 :         Ok(Some(cancel_closure))
     287            0 :     }
     288              : 
     289              :     /// Try to cancel a running query for the corresponding connection.
     290              :     /// If the cancellation key is not found, it will be published to Redis.
     291              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     292              :     /// Will fetch IP allowlist internally.
     293              :     ///
     294              :     /// return Result primarily for tests
     295              :     ///
     296              :     /// This is not cancel safe
     297            0 :     pub(crate) async fn cancel_session<T: ControlPlaneApi>(
     298            0 :         &self,
     299            0 :         key: CancelKeyData,
     300            0 :         ctx: RequestContext,
     301            0 :         check_ip_allowed: bool,
     302            0 :         check_vpc_allowed: bool,
     303            0 :         auth_backend: &T,
     304            0 :     ) -> Result<(), CancelError> {
     305            0 :         let subnet_key = match ctx.peer_addr() {
     306            0 :             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
     307            0 :             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     308              :         };
     309              : 
     310            0 :         let allowed = {
     311            0 :             let rate_limit_config = None;
     312            0 :             let limiter = self.limiter.lock_propagate_poison();
     313            0 :             limiter.check(subnet_key, rate_limit_config, 1)
     314              :         };
     315            0 :         if !allowed {
     316              :             // log only the subnet part of the IP address to know which subnet is rate limited
     317            0 :             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     318            0 :             Metrics::get()
     319            0 :                 .proxy
     320            0 :                 .cancellation_requests_total
     321            0 :                 .inc(CancellationRequest {
     322            0 :                     kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     323            0 :                 });
     324            0 :             return Err(CancelError::RateLimit);
     325            0 :         }
     326              : 
     327            0 :         let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
     328            0 :             tracing::warn!("failed to receive RedisOp response: {e}");
     329            0 :             CancelError::InternalError
     330            0 :         })?;
     331              : 
     332            0 :         let Some(cancel_closure) = cancel_state else {
     333            0 :             tracing::warn!("query cancellation key not found: {key}");
     334            0 :             Metrics::get()
     335            0 :                 .proxy
     336            0 :                 .cancellation_requests_total
     337            0 :                 .inc(CancellationRequest {
     338            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     339            0 :                 });
     340            0 :             return Err(CancelError::NotFound);
     341              :         };
     342              : 
     343            0 :         let info = &cancel_closure.user_info;
     344            0 :         let access_controls = auth_backend
     345            0 :             .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
     346            0 :             .await
     347            0 :             .map_err(|e| CancelError::AuthError(e.into()))?;
     348              : 
     349            0 :         access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
     350              : 
     351            0 :         Metrics::get()
     352            0 :             .proxy
     353            0 :             .cancellation_requests_total
     354            0 :             .inc(CancellationRequest {
     355            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     356            0 :             });
     357            0 :         info!("cancelling query per user's request using key {key}");
     358            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     359            0 :     }
     360              : }
     361              : 
     362              : /// This should've been a [`std::future::Future`], but
     363              : /// it's impossible to name a type of an unboxed future
     364              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     365            0 : #[derive(Debug, Clone, Serialize, Deserialize)]
     366              : pub struct CancelClosure {
     367              :     socket_addr: SocketAddr,
     368              :     cancel_token: RawCancelToken,
     369              :     hostname: String, // for pg_sni router
     370              :     user_info: ComputeUserInfo,
     371              : }
     372              : 
     373              : impl CancelClosure {
     374            0 :     pub(crate) fn new(
     375            0 :         socket_addr: SocketAddr,
     376            0 :         cancel_token: RawCancelToken,
     377            0 :         hostname: String,
     378            0 :         user_info: ComputeUserInfo,
     379            0 :     ) -> Self {
     380            0 :         Self {
     381            0 :             socket_addr,
     382            0 :             cancel_token,
     383            0 :             hostname,
     384            0 :             user_info,
     385            0 :         }
     386            0 :     }
     387              :     /// Cancels the query running on user's compute node.
     388            0 :     pub(crate) async fn try_cancel_query(
     389            0 :         &self,
     390            0 :         compute_config: &ComputeConfig,
     391            0 :     ) -> Result<(), CancelError> {
     392            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     393              : 
     394            0 :         let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     395            0 :             compute_config,
     396            0 :             &self.hostname,
     397              :         )
     398            0 :         .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
     399              : 
     400            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     401            0 :         debug!("query was cancelled");
     402            0 :         Ok(())
     403            0 :     }
     404              : }
     405              : 
     406              : /// Helper for registering query cancellation tokens.
     407              : pub(crate) struct Session {
     408              :     /// The user-facing key identifying this session.
     409              :     key: CancelKeyData,
     410              :     cancellation_handler: Arc<CancellationHandler>,
     411              : }
     412              : 
     413              : impl Session {
     414            0 :     pub(crate) fn key(&self) -> &CancelKeyData {
     415            0 :         &self.key
     416            0 :     }
     417              : 
     418              :     /// Ensure the cancel key is continously refreshed,
     419              :     /// but stop when the channel is dropped.
     420              :     ///
     421              :     /// This is not cancel safe
     422            0 :     pub(crate) async fn maintain_cancel_key(
     423            0 :         &self,
     424            0 :         session_id: uuid::Uuid,
     425            0 :         cancel: tokio::sync::oneshot::Receiver<Infallible>,
     426            0 :         cancel_closure: &CancelClosure,
     427            0 :         compute_config: &ComputeConfig,
     428            0 :     ) {
     429            0 :         let Some(tx) = self.cancellation_handler.tx.get() else {
     430            0 :             tracing::warn!("cancellation handler is not available");
     431              :             // don't exit, as we only want to exit if cancelled externally.
     432            0 :             std::future::pending().await
     433              :         };
     434              : 
     435            0 :         let closure_json = serde_json::to_string(&cancel_closure)
     436            0 :             .expect("serialising to json string should not fail")
     437            0 :             .into_boxed_str();
     438              : 
     439            0 :         let mut cancel = pin!(cancel);
     440              : 
     441              :         loop {
     442            0 :             let guard = Metrics::get()
     443            0 :                 .proxy
     444            0 :                 .cancel_channel_size
     445            0 :                 .guard(RedisMsgKind::HSet);
     446            0 :             let op = CancelKeyOp::StoreCancelKey {
     447            0 :                 key: self.key,
     448            0 :                 value: closure_json.clone(),
     449            0 :                 expire: CANCEL_KEY_TTL,
     450            0 :             };
     451              : 
     452            0 :             tracing::debug!(
     453              :                 src=%self.key,
     454              :                 dest=?cancel_closure.cancel_token,
     455            0 :                 "registering cancellation key"
     456              :             );
     457              : 
     458            0 :             match tx.call((guard, op), cancel.as_mut()).await {
     459              :                 Ok(Ok(_)) => {
     460            0 :                     tracing::debug!(
     461              :                         src=%self.key,
     462              :                         dest=?cancel_closure.cancel_token,
     463            0 :                         "registered cancellation key"
     464              :                     );
     465              : 
     466              :                     // wait before continuing.
     467            0 :                     tokio::time::sleep(CANCEL_KEY_REFRESH).await;
     468              :                 }
     469              :                 // retry immediately.
     470            0 :                 Ok(Err(error)) => {
     471            0 :                     tracing::warn!(?error, "error registering cancellation key");
     472              :                 }
     473            0 :                 Err(Err(_cancelled)) => break,
     474              :             }
     475              :         }
     476              : 
     477            0 :         if let Err(err) = cancel_closure
     478            0 :             .try_cancel_query(compute_config)
     479            0 :             .boxed()
     480            0 :             .await
     481              :         {
     482            0 :             tracing::warn!(
     483              :                 ?session_id,
     484              :                 ?err,
     485            0 :                 "could not cancel the query in the database"
     486              :             );
     487            0 :         }
     488            0 :     }
     489              : }
        

Generated by: LCOV version 2.1-beta