LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: aca806cab4756d7eb6a304846130f4a73a5d5393.info Lines: 0.0 % 371 0
Test Date: 2025-04-24 20:31:15 Functions: 0.0 % 49 0

            Line data    Source code
       1              : use std::net::{IpAddr, SocketAddr};
       2              : use std::sync::Arc;
       3              : 
       4              : use anyhow::{Context, anyhow};
       5              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
       6              : use postgres_client::CancelToken;
       7              : use postgres_client::tls::MakeTlsConnect;
       8              : use pq_proto::CancelKeyData;
       9              : use redis::{FromRedisValue, Pipeline, Value, pipe};
      10              : use serde::{Deserialize, Serialize};
      11              : use thiserror::Error;
      12              : use tokio::net::TcpStream;
      13              : use tokio::sync::{mpsc, oneshot};
      14              : use tracing::{debug, info, warn};
      15              : 
      16              : use crate::auth::backend::ComputeUserInfo;
      17              : use crate::auth::{AuthError, check_peer_addr_is_in_list};
      18              : use crate::config::ComputeConfig;
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::ControlPlaneApi;
      21              : use crate::error::ReportableError;
      22              : use crate::ext::LockExt;
      23              : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
      24              : use crate::protocol2::ConnectionInfoExtra;
      25              : use crate::rate_limiter::LeakyBucketRateLimiter;
      26              : use crate::redis::keys::KeyPrefix;
      27              : use crate::redis::kv_ops::RedisKVClient;
      28              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      29              : 
/// Key used by the cancel-request rate limiter: the client's address truncated
/// to its subnet (see `CancellationHandler::cancel_session`).
type IpSubnetKey = IpNet;
      31              : 
      32              : const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
      33              : const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
      34              : const BATCH_SIZE: usize = 8;
      35              : 
      36              : // Message types for sending through mpsc channel
      37              : pub enum CancelKeyOp {
      38              :     StoreCancelKey {
      39              :         key: String,
      40              :         field: String,
      41              :         value: String,
      42              :         resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
      43              :         _guard: CancelChannelSizeGuard<'static>,
      44              :         expire: i64, // TTL for key
      45              :     },
      46              :     GetCancelData {
      47              :         key: String,
      48              :         resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
      49              :         _guard: CancelChannelSizeGuard<'static>,
      50              :     },
      51              :     RemoveCancelKey {
      52              :         key: String,
      53              :         field: String,
      54              :         resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
      55              :         _guard: CancelChannelSizeGuard<'static>,
      56              :     },
      57              : }
      58              : 
      59              : impl CancelKeyOp {
      60            0 :     fn register(self, pipe: &mut Pipeline) -> Option<CancelReplyOp> {
      61            0 :         #[allow(clippy::used_underscore_binding)]
      62            0 :         match self {
      63              :             CancelKeyOp::StoreCancelKey {
      64            0 :                 key,
      65            0 :                 field,
      66            0 :                 value,
      67            0 :                 resp_tx,
      68            0 :                 _guard,
      69            0 :                 expire,
      70            0 :             } => {
      71            0 :                 pipe.hset(&key, field, value);
      72            0 :                 pipe.expire(key, expire);
      73            0 :                 let resp_tx = resp_tx?;
      74            0 :                 Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard })
      75              :             }
      76              :             CancelKeyOp::GetCancelData {
      77            0 :                 key,
      78            0 :                 resp_tx,
      79            0 :                 _guard,
      80            0 :             } => {
      81            0 :                 pipe.hgetall(key);
      82            0 :                 Some(CancelReplyOp::GetCancelData { resp_tx, _guard })
      83              :             }
      84              :             CancelKeyOp::RemoveCancelKey {
      85            0 :                 key,
      86            0 :                 field,
      87            0 :                 resp_tx,
      88            0 :                 _guard,
      89            0 :             } => {
      90            0 :                 pipe.hdel(key, field);
      91            0 :                 let resp_tx = resp_tx?;
      92            0 :                 Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard })
      93              :             }
      94              :         }
      95            0 :     }
      96              : }
      97              : 
      98              : // Message types for sending through mpsc channel
      99              : pub enum CancelReplyOp {
     100              :     StoreCancelKey {
     101              :         resp_tx: oneshot::Sender<anyhow::Result<()>>,
     102              :         _guard: CancelChannelSizeGuard<'static>,
     103              :     },
     104              :     GetCancelData {
     105              :         resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
     106              :         _guard: CancelChannelSizeGuard<'static>,
     107              :     },
     108              :     RemoveCancelKey {
     109              :         resp_tx: oneshot::Sender<anyhow::Result<()>>,
     110              :         _guard: CancelChannelSizeGuard<'static>,
     111              :     },
     112              : }
     113              : 
     114              : impl CancelReplyOp {
     115            0 :     fn send_err(self, e: anyhow::Error) {
     116            0 :         match self {
     117            0 :             CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
     118            0 :                 resp_tx
     119            0 :                     .send(Err(e))
     120            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     121            0 :                     .ok();
     122            0 :             }
     123            0 :             CancelReplyOp::GetCancelData { resp_tx, _guard } => {
     124            0 :                 resp_tx
     125            0 :                     .send(Err(e))
     126            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     127            0 :                     .ok();
     128            0 :             }
     129            0 :             CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
     130            0 :                 resp_tx
     131            0 :                     .send(Err(e))
     132            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     133            0 :                     .ok();
     134            0 :             }
     135              :         }
     136            0 :     }
     137              : 
     138            0 :     fn send_value(self, v: redis::Value) {
     139            0 :         match self {
     140            0 :             CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
     141            0 :                 let send =
     142            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     143            0 :                 resp_tx
     144            0 :                     .send(send)
     145            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     146            0 :                     .ok();
     147            0 :             }
     148            0 :             CancelReplyOp::GetCancelData { resp_tx, _guard } => {
     149            0 :                 let send =
     150            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     151            0 :                 resp_tx
     152            0 :                     .send(send)
     153            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     154            0 :                     .ok();
     155            0 :             }
     156            0 :             CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
     157            0 :                 let send =
     158            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     159            0 :                 resp_tx
     160            0 :                     .send(send)
     161            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     162            0 :                     .ok();
     163            0 :             }
     164              :         }
     165            0 :     }
     166              : }
     167              : 
     168              : // Running as a separate task to accept messages through the rx channel
     169            0 : pub async fn handle_cancel_messages(
     170            0 :     client: &mut RedisKVClient,
     171            0 :     mut rx: mpsc::Receiver<CancelKeyOp>,
     172            0 : ) -> anyhow::Result<()> {
     173            0 :     let mut batch = Vec::new();
     174            0 :     let mut replies = vec![];
     175              : 
     176              :     loop {
     177            0 :         if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 {
     178            0 :             warn!("shutting down cancellation queue");
     179            0 :             break Ok(());
     180            0 :         }
     181            0 : 
     182            0 :         let batch_size = batch.len();
     183            0 :         debug!(batch_size, "running cancellation jobs");
     184              : 
     185            0 :         let mut pipe = pipe();
     186            0 :         for msg in batch.drain(..) {
     187            0 :             if let Some(reply) = msg.register(&mut pipe) {
     188            0 :                 replies.push(reply);
     189            0 :             } else {
     190            0 :                 pipe.ignore();
     191            0 :             }
     192              :         }
     193              : 
     194            0 :         let responses = replies.len();
     195            0 : 
     196            0 :         match client.query(pipe).await {
     197              :             // for each reply, we expect that many values.
     198            0 :             Ok(Value::Array(values)) if values.len() == responses => {
     199            0 :                 debug!(
     200              :                     batch_size,
     201            0 :                     responses, "successfully completed cancellation jobs",
     202              :                 );
     203            0 :                 for (value, reply) in std::iter::zip(values, replies.drain(..)) {
     204            0 :                     reply.send_value(value);
     205            0 :                 }
     206              :             }
     207            0 :             Ok(value) => {
     208            0 :                 debug!(?value, "unexpected redis return value");
     209            0 :                 for reply in replies.drain(..) {
     210            0 :                     reply.send_err(anyhow!("incorrect response type from redis"));
     211            0 :                 }
     212              :             }
     213            0 :             Err(err) => {
     214            0 :                 for reply in replies.drain(..) {
     215            0 :                     reply.send_err(anyhow!("could not send cmd to redis: {err}"));
     216            0 :                 }
     217              :             }
     218              :         }
     219              : 
     220            0 :         replies.clear();
     221              :     }
     222            0 : }
     223              : 
/// Enables serving `CancelRequest`s.
///
/// Hands out per-connection cancellation keys (`get_key`) and serves incoming
/// cancel requests (`cancel_session`). Key data is read from and written to Redis
/// through the background task fed by `tx`; when `tx` is `None`, storing and
/// looking up keys fails with `CancelError::InternalError`.
pub struct CancellationHandler {
    // its `tls` settings are used when connecting to compute to deliver a cancel
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests, keyed by client subnet
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
}
     233              : 
     234              : #[derive(Debug, Error)]
     235              : pub(crate) enum CancelError {
     236              :     #[error("{0}")]
     237              :     IO(#[from] std::io::Error),
     238              : 
     239              :     #[error("{0}")]
     240              :     Postgres(#[from] postgres_client::Error),
     241              : 
     242              :     #[error("rate limit exceeded")]
     243              :     RateLimit,
     244              : 
     245              :     #[error("IP is not allowed")]
     246              :     IpNotAllowed,
     247              : 
     248              :     #[error("VPC endpoint id is not allowed to connect")]
     249              :     VpcEndpointIdNotAllowed,
     250              : 
     251              :     #[error("Authentication backend error")]
     252              :     AuthError(#[from] AuthError),
     253              : 
     254              :     #[error("key not found")]
     255              :     NotFound,
     256              : 
     257              :     #[error("proxy service error")]
     258              :     InternalError,
     259              : }
     260              : 
     261              : impl ReportableError for CancelError {
     262            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     263            0 :         match self {
     264            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
     265            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
     266            0 :                 crate::error::ErrorKind::Postgres
     267              :             }
     268            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
     269            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
     270              :             CancelError::IpNotAllowed
     271              :             | CancelError::VpcEndpointIdNotAllowed
     272            0 :             | CancelError::NotFound => crate::error::ErrorKind::User,
     273            0 :             CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
     274            0 :             CancelError::InternalError => crate::error::ErrorKind::Service,
     275              :         }
     276            0 :     }
     277              : }
     278              : 
     279              : impl CancellationHandler {
     280            0 :     pub fn new(
     281            0 :         compute_config: &'static ComputeConfig,
     282            0 :         tx: Option<mpsc::Sender<CancelKeyOp>>,
     283            0 :     ) -> Self {
     284            0 :         Self {
     285            0 :             compute_config,
     286            0 :             tx,
     287            0 :             limiter: Arc::new(std::sync::Mutex::new(
     288            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     289            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     290            0 :                     64,
     291            0 :                 ),
     292            0 :             )),
     293            0 :         }
     294            0 :     }
     295              : 
     296            0 :     pub(crate) fn get_key(self: &Arc<Self>) -> Session {
     297            0 :         // we intentionally generate a random "backend pid" and "secret key" here.
     298            0 :         // we use the corresponding u64 as an identifier for the
     299            0 :         // actual endpoint+pid+secret for postgres/pgbouncer.
     300            0 :         //
     301            0 :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
     302            0 :         // of overlap between our computes as most pids are small (~100).
     303            0 : 
     304            0 :         let key: CancelKeyData = rand::random();
     305            0 : 
     306            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     307            0 :         let redis_key = prefix_key.build_redis_key();
     308            0 : 
     309            0 :         debug!("registered new query cancellation key {key}");
     310            0 :         Session {
     311            0 :             key,
     312            0 :             redis_key,
     313            0 :             cancellation_handler: Arc::clone(self),
     314            0 :         }
     315            0 :     }
     316              : 
     317            0 :     async fn get_cancel_key(
     318            0 :         &self,
     319            0 :         key: CancelKeyData,
     320            0 :     ) -> Result<Option<CancelClosure>, CancelError> {
     321            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     322            0 :         let redis_key = prefix_key.build_redis_key();
     323            0 : 
     324            0 :         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
     325            0 :         let op = CancelKeyOp::GetCancelData {
     326            0 :             key: redis_key,
     327            0 :             resp_tx,
     328            0 :             _guard: Metrics::get()
     329            0 :                 .proxy
     330            0 :                 .cancel_channel_size
     331            0 :                 .guard(RedisMsgKind::HGetAll),
     332            0 :         };
     333              : 
     334            0 :         let Some(tx) = &self.tx else {
     335            0 :             tracing::warn!("cancellation handler is not available");
     336            0 :             return Err(CancelError::InternalError);
     337              :         };
     338              : 
     339            0 :         tx.send_timeout(op, REDIS_SEND_TIMEOUT)
     340            0 :             .await
     341            0 :             .map_err(|e| {
     342            0 :                 tracing::warn!("failed to send GetCancelData for {key}: {e}");
     343            0 :             })
     344            0 :             .map_err(|()| CancelError::InternalError)?;
     345              : 
     346            0 :         let result = resp_rx.await.map_err(|e| {
     347            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     348            0 :             CancelError::InternalError
     349            0 :         })?;
     350              : 
     351            0 :         let cancel_state_str: Option<String> = match result {
     352            0 :             Ok(mut state) => {
     353            0 :                 if state.len() == 1 {
     354            0 :                     Some(state.remove(0).1)
     355              :                 } else {
     356            0 :                     tracing::warn!("unexpected number of entries in cancel state: {state:?}");
     357            0 :                     return Err(CancelError::InternalError);
     358              :                 }
     359              :             }
     360            0 :             Err(e) => {
     361            0 :                 tracing::warn!("failed to receive cancel state from redis: {e}");
     362            0 :                 return Err(CancelError::InternalError);
     363              :             }
     364              :         };
     365              : 
     366            0 :         let cancel_state: Option<CancelClosure> = match cancel_state_str {
     367            0 :             Some(state) => {
     368            0 :                 let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
     369            0 :                     tracing::warn!("failed to deserialize cancel state: {e}");
     370            0 :                     CancelError::InternalError
     371            0 :                 })?;
     372            0 :                 Some(cancel_closure)
     373              :             }
     374            0 :             None => None,
     375              :         };
     376            0 :         Ok(cancel_state)
     377            0 :     }
     378              :     /// Try to cancel a running query for the corresponding connection.
     379              :     /// If the cancellation key is not found, it will be published to Redis.
     380              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     381              :     /// Will fetch IP allowlist internally.
     382              :     ///
     383              :     /// return Result primarily for tests
     384            0 :     pub(crate) async fn cancel_session<T: ControlPlaneApi>(
     385            0 :         &self,
     386            0 :         key: CancelKeyData,
     387            0 :         ctx: RequestContext,
     388            0 :         check_ip_allowed: bool,
     389            0 :         check_vpc_allowed: bool,
     390            0 :         auth_backend: &T,
     391            0 :     ) -> Result<(), CancelError> {
     392            0 :         let subnet_key = match ctx.peer_addr() {
     393            0 :             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
     394            0 :             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     395              :         };
     396            0 :         if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     397              :             // log only the subnet part of the IP address to know which subnet is rate limited
     398            0 :             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     399            0 :             Metrics::get()
     400            0 :                 .proxy
     401            0 :                 .cancellation_requests_total
     402            0 :                 .inc(CancellationRequest {
     403            0 :                     kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     404            0 :                 });
     405            0 :             return Err(CancelError::RateLimit);
     406            0 :         }
     407              : 
     408            0 :         let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
     409            0 :             tracing::warn!("failed to receive RedisOp response: {e}");
     410            0 :             CancelError::InternalError
     411            0 :         })?;
     412              : 
     413            0 :         let Some(cancel_closure) = cancel_state else {
     414            0 :             tracing::warn!("query cancellation key not found: {key}");
     415            0 :             Metrics::get()
     416            0 :                 .proxy
     417            0 :                 .cancellation_requests_total
     418            0 :                 .inc(CancellationRequest {
     419            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     420            0 :                 });
     421            0 :             return Err(CancelError::NotFound);
     422              :         };
     423              : 
     424            0 :         if check_ip_allowed {
     425            0 :             let ip_allowlist = auth_backend
     426            0 :                 .get_allowed_ips(&ctx, &cancel_closure.user_info)
     427            0 :                 .await
     428            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     429              : 
     430            0 :             if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
     431              :                 // log it here since cancel_session could be spawned in a task
     432            0 :                 tracing::warn!(
     433            0 :                     "IP is not allowed to cancel the query: {key}, address: {}",
     434            0 :                     ctx.peer_addr()
     435              :                 );
     436            0 :                 return Err(CancelError::IpNotAllowed);
     437            0 :             }
     438            0 :         }
     439              : 
     440              :         // check if a VPC endpoint ID is coming in and if yes, if it's allowed
     441            0 :         let access_blocks = auth_backend
     442            0 :             .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
     443            0 :             .await
     444            0 :             .map_err(|e| CancelError::AuthError(e.into()))?;
     445              : 
     446            0 :         if check_vpc_allowed {
     447            0 :             if access_blocks.vpc_access_blocked {
     448            0 :                 return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
     449            0 :             }
     450              : 
     451            0 :             let incoming_vpc_endpoint_id = match ctx.extra() {
     452            0 :                 None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
     453            0 :                 Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
     454            0 :                 Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
     455              :             };
     456              : 
     457            0 :             let allowed_vpc_endpoint_ids = auth_backend
     458            0 :                 .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
     459            0 :                 .await
     460            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     461              :             // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
     462            0 :             if !allowed_vpc_endpoint_ids.is_empty()
     463            0 :                 && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
     464              :             {
     465            0 :                 return Err(CancelError::VpcEndpointIdNotAllowed);
     466            0 :             }
     467            0 :         } else if access_blocks.public_access_blocked {
     468            0 :             return Err(CancelError::VpcEndpointIdNotAllowed);
     469            0 :         }
     470              : 
     471            0 :         Metrics::get()
     472            0 :             .proxy
     473            0 :             .cancellation_requests_total
     474            0 :             .inc(CancellationRequest {
     475            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     476            0 :             });
     477            0 :         info!("cancelling query per user's request using key {key}");
     478            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     479            0 :     }
     480              : }
     481              : 
     482              : /// This should've been a [`std::future::Future`], but
     483              : /// it's impossible to name a type of an unboxed future
     484              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     485            0 : #[derive(Clone, Serialize, Deserialize)]
     486              : pub struct CancelClosure {
     487              :     socket_addr: SocketAddr,
     488              :     cancel_token: CancelToken,
     489              :     hostname: String, // for pg_sni router
     490              :     user_info: ComputeUserInfo,
     491              : }
     492              : 
     493              : impl CancelClosure {
     494            0 :     pub(crate) fn new(
     495            0 :         socket_addr: SocketAddr,
     496            0 :         cancel_token: CancelToken,
     497            0 :         hostname: String,
     498            0 :         user_info: ComputeUserInfo,
     499            0 :     ) -> Self {
     500            0 :         Self {
     501            0 :             socket_addr,
     502            0 :             cancel_token,
     503            0 :             hostname,
     504            0 :             user_info,
     505            0 :         }
     506            0 :     }
     507              :     /// Cancels the query running on user's compute node.
     508            0 :     pub(crate) async fn try_cancel_query(
     509            0 :         self,
     510            0 :         compute_config: &ComputeConfig,
     511            0 :     ) -> Result<(), CancelError> {
     512            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     513              : 
     514            0 :         let mut mk_tls =
     515            0 :             crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
     516            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     517            0 :             &mut mk_tls,
     518            0 :             &self.hostname,
     519            0 :         )
     520            0 :         .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
     521              : 
     522            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     523            0 :         debug!("query was cancelled");
     524            0 :         Ok(())
     525            0 :     }
     526              : }
     527              : 
     528              : /// Helper for registering query cancellation tokens.
     529              : pub(crate) struct Session {
     530              :     /// The user-facing key identifying this session.
     531              :     key: CancelKeyData,
     532              :     redis_key: String,
     533              :     cancellation_handler: Arc<CancellationHandler>,
     534              : }
     535              : 
     536              : impl Session {
     537            0 :     pub(crate) fn key(&self) -> &CancelKeyData {
     538            0 :         &self.key
     539            0 :     }
     540              : 
     541              :     // Send the store key op to the cancellation handler and set TTL for the key
     542            0 :     pub(crate) async fn write_cancel_key(
     543            0 :         &self,
     544            0 :         cancel_closure: CancelClosure,
     545            0 :     ) -> Result<(), CancelError> {
     546            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     547            0 :             tracing::warn!("cancellation handler is not available");
     548            0 :             return Err(CancelError::InternalError);
     549              :         };
     550              : 
     551            0 :         let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
     552            0 :             tracing::warn!("failed to serialize cancel closure: {e}");
     553            0 :             CancelError::InternalError
     554            0 :         })?;
     555              : 
     556            0 :         let op = CancelKeyOp::StoreCancelKey {
     557            0 :             key: self.redis_key.clone(),
     558            0 :             field: "data".to_string(),
     559            0 :             value: closure_json,
     560            0 :             resp_tx: None,
     561            0 :             _guard: Metrics::get()
     562            0 :                 .proxy
     563            0 :                 .cancel_channel_size
     564            0 :                 .guard(RedisMsgKind::HSet),
     565            0 :             expire: CANCEL_KEY_TTL,
     566            0 :         };
     567            0 : 
     568            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     569            0 :             let key = self.key;
     570            0 :             tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
     571            0 :         });
     572            0 :         Ok(())
     573            0 :     }
     574              : 
     575            0 :     pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
     576            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     577            0 :             tracing::warn!("cancellation handler is not available");
     578            0 :             return Err(CancelError::InternalError);
     579              :         };
     580              : 
     581            0 :         let op = CancelKeyOp::RemoveCancelKey {
     582            0 :             key: self.redis_key.clone(),
     583            0 :             field: "data".to_string(),
     584            0 :             resp_tx: None,
     585            0 :             _guard: Metrics::get()
     586            0 :                 .proxy
     587            0 :                 .cancel_channel_size
     588            0 :                 .guard(RedisMsgKind::HDel),
     589            0 :         };
     590            0 : 
     591            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     592            0 :             let key = self.key;
     593            0 :             tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
     594            0 :         });
     595            0 :         Ok(())
     596            0 :     }
     597              : }
        

Generated by: LCOV version 2.1-beta