LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 472031e0b71f3195f7f21b1f2b20de09fd07bb56.info Lines: 0.0 % 393 0
Test Date: 2025-05-26 10:37:33 Functions: 0.0 % 55 0

            Line data    Source code
       1              : use std::net::{IpAddr, SocketAddr};
       2              : use std::sync::Arc;
       3              : 
       4              : use anyhow::{Context, anyhow};
       5              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
       6              : use postgres_client::CancelToken;
       7              : use postgres_client::tls::MakeTlsConnect;
       8              : use pq_proto::CancelKeyData;
       9              : use redis::{Cmd, FromRedisValue, Value};
      10              : use serde::{Deserialize, Serialize};
      11              : use thiserror::Error;
      12              : use tokio::net::TcpStream;
      13              : use tokio::sync::{mpsc, oneshot};
      14              : use tracing::{debug, error, info, warn};
      15              : 
      16              : use crate::auth::backend::ComputeUserInfo;
      17              : use crate::auth::{AuthError, check_peer_addr_is_in_list};
      18              : use crate::config::ComputeConfig;
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::ControlPlaneApi;
      21              : use crate::error::ReportableError;
      22              : use crate::ext::LockExt;
      23              : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
      24              : use crate::protocol2::ConnectionInfoExtra;
      25              : use crate::rate_limiter::LeakyBucketRateLimiter;
      26              : use crate::redis::keys::KeyPrefix;
      27              : use crate::redis::kv_ops::RedisKVClient;
      28              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      29              : 
      30              : type IpSubnetKey = IpNet;
      31              : 
      32              : const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
      33              : 
      34              : // Message types for sending through mpsc channel
      35              : pub enum CancelKeyOp {
      36              :     StoreCancelKey {
      37              :         key: String,
      38              :         field: String,
      39              :         value: String,
      40              :         resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
      41              :         _guard: CancelChannelSizeGuard<'static>,
      42              :         expire: i64, // TTL for key
      43              :     },
      44              :     GetCancelData {
      45              :         key: String,
      46              :         resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
      47              :         _guard: CancelChannelSizeGuard<'static>,
      48              :     },
      49              :     RemoveCancelKey {
      50              :         key: String,
      51              :         field: String,
      52              :         resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
      53              :         _guard: CancelChannelSizeGuard<'static>,
      54              :     },
      55              : }
      56              : 
      57              : pub struct Pipeline {
      58              :     inner: redis::Pipeline,
      59              :     replies: Vec<CancelReplyOp>,
      60              : }
      61              : 
      62              : impl Pipeline {
      63            0 :     fn with_capacity(n: usize) -> Self {
      64            0 :         Self {
      65            0 :             inner: redis::Pipeline::with_capacity(n),
      66            0 :             replies: Vec::with_capacity(n),
      67            0 :         }
      68            0 :     }
      69              : 
      70            0 :     async fn execute(&mut self, client: &mut RedisKVClient) {
      71            0 :         let responses = self.replies.len();
      72            0 :         let batch_size = self.inner.len();
      73            0 : 
      74            0 :         match client.query(&self.inner).await {
      75              :             // for each reply, we expect that many values.
      76            0 :             Ok(Value::Array(values)) if values.len() == responses => {
      77            0 :                 debug!(
      78              :                     batch_size,
      79            0 :                     responses, "successfully completed cancellation jobs",
      80              :                 );
      81            0 :                 for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
      82            0 :                     reply.send_value(value);
      83            0 :                 }
      84              :             }
      85            0 :             Ok(value) => {
      86            0 :                 error!(batch_size, ?value, "unexpected redis return value");
      87            0 :                 for reply in self.replies.drain(..) {
      88            0 :                     reply.send_err(anyhow!("incorrect response type from redis"));
      89            0 :                 }
      90              :             }
      91            0 :             Err(err) => {
      92            0 :                 for reply in self.replies.drain(..) {
      93            0 :                     reply.send_err(anyhow!("could not send cmd to redis: {err}"));
      94            0 :                 }
      95              :             }
      96              :         }
      97              : 
      98            0 :         self.inner.clear();
      99            0 :         self.replies.clear();
     100            0 :     }
     101              : 
     102            0 :     fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
     103            0 :         self.inner.add_command(cmd);
     104            0 :         self.replies.push(reply);
     105            0 :     }
     106              : 
     107            0 :     fn add_command_no_reply(&mut self, cmd: Cmd) {
     108            0 :         self.inner.add_command(cmd).ignore();
     109            0 :     }
     110              : 
     111            0 :     fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
     112            0 :         match reply {
     113            0 :             Some(reply) => self.add_command_with_reply(cmd, reply),
     114            0 :             None => self.add_command_no_reply(cmd),
     115              :         }
     116            0 :     }
     117              : }
     118              : 
     119              : impl CancelKeyOp {
     120            0 :     fn register(self, pipe: &mut Pipeline) {
     121            0 :         #[allow(clippy::used_underscore_binding)]
     122            0 :         match self {
     123              :             CancelKeyOp::StoreCancelKey {
     124            0 :                 key,
     125            0 :                 field,
     126            0 :                 value,
     127            0 :                 resp_tx,
     128            0 :                 _guard,
     129            0 :                 expire,
     130            0 :             } => {
     131            0 :                 let reply =
     132            0 :                     resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
     133            0 :                 pipe.add_command(Cmd::hset(&key, field, value), reply);
     134            0 :                 pipe.add_command_no_reply(Cmd::expire(key, expire));
     135            0 :             }
     136              :             CancelKeyOp::GetCancelData {
     137            0 :                 key,
     138            0 :                 resp_tx,
     139            0 :                 _guard,
     140            0 :             } => {
     141            0 :                 let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
     142            0 :                 pipe.add_command_with_reply(Cmd::hgetall(key), reply);
     143            0 :             }
     144              :             CancelKeyOp::RemoveCancelKey {
     145            0 :                 key,
     146            0 :                 field,
     147            0 :                 resp_tx,
     148            0 :                 _guard,
     149            0 :             } => {
     150            0 :                 let reply =
     151            0 :                     resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
     152            0 :                 pipe.add_command(Cmd::hdel(key, field), reply);
     153            0 :             }
     154              :         }
     155            0 :     }
     156              : }
     157              : 
     158              : // Message types for sending through mpsc channel
     159              : pub enum CancelReplyOp {
     160              :     StoreCancelKey {
     161              :         resp_tx: oneshot::Sender<anyhow::Result<()>>,
     162              :         _guard: CancelChannelSizeGuard<'static>,
     163              :     },
     164              :     GetCancelData {
     165              :         resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
     166              :         _guard: CancelChannelSizeGuard<'static>,
     167              :     },
     168              :     RemoveCancelKey {
     169              :         resp_tx: oneshot::Sender<anyhow::Result<()>>,
     170              :         _guard: CancelChannelSizeGuard<'static>,
     171              :     },
     172              : }
     173              : 
     174              : impl CancelReplyOp {
     175            0 :     fn send_err(self, e: anyhow::Error) {
     176            0 :         match self {
     177            0 :             CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
     178            0 :                 resp_tx
     179            0 :                     .send(Err(e))
     180            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     181            0 :                     .ok();
     182            0 :             }
     183            0 :             CancelReplyOp::GetCancelData { resp_tx, _guard } => {
     184            0 :                 resp_tx
     185            0 :                     .send(Err(e))
     186            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     187            0 :                     .ok();
     188            0 :             }
     189            0 :             CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
     190            0 :                 resp_tx
     191            0 :                     .send(Err(e))
     192            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     193            0 :                     .ok();
     194            0 :             }
     195              :         }
     196            0 :     }
     197              : 
     198            0 :     fn send_value(self, v: redis::Value) {
     199            0 :         match self {
     200            0 :             CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
     201            0 :                 let send =
     202            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     203            0 :                 resp_tx
     204            0 :                     .send(send)
     205            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     206            0 :                     .ok();
     207            0 :             }
     208            0 :             CancelReplyOp::GetCancelData { resp_tx, _guard } => {
     209            0 :                 let send =
     210            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     211            0 :                 resp_tx
     212            0 :                     .send(send)
     213            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     214            0 :                     .ok();
     215            0 :             }
     216            0 :             CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
     217            0 :                 let send =
     218            0 :                     FromRedisValue::from_owned_redis_value(v).context("could not parse value");
     219            0 :                 resp_tx
     220            0 :                     .send(send)
     221            0 :                     .inspect_err(|_| tracing::debug!("could not send reply"))
     222            0 :                     .ok();
     223            0 :             }
     224              :         }
     225            0 :     }
     226              : }
     227              : 
     228              : // Running as a separate task to accept messages through the rx channel
     229            0 : pub async fn handle_cancel_messages(
     230            0 :     client: &mut RedisKVClient,
     231            0 :     mut rx: mpsc::Receiver<CancelKeyOp>,
     232            0 :     batch_size: usize,
     233            0 : ) -> anyhow::Result<()> {
     234            0 :     let mut batch = Vec::with_capacity(batch_size);
     235            0 :     let mut pipeline = Pipeline::with_capacity(batch_size);
     236              : 
     237              :     loop {
     238            0 :         if rx.recv_many(&mut batch, batch_size).await == 0 {
     239            0 :             warn!("shutting down cancellation queue");
     240            0 :             break Ok(());
     241            0 :         }
     242            0 : 
     243            0 :         let batch_size = batch.len();
     244            0 :         debug!(batch_size, "running cancellation jobs");
     245              : 
     246            0 :         for msg in batch.drain(..) {
     247            0 :             msg.register(&mut pipeline);
     248            0 :         }
     249              : 
     250            0 :         pipeline.execute(client).await;
     251              :     }
     252            0 : }
     253              : 
     254              : /// Enables serving `CancelRequest`s.
     255              : ///
     256              : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
     257              : pub struct CancellationHandler {
     258              :     compute_config: &'static ComputeConfig,
     259              :     // rate limiter of cancellation requests
     260              :     limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
     261              :     tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
     262              : }
     263              : 
     264              : #[derive(Debug, Error)]
     265              : pub(crate) enum CancelError {
     266              :     #[error("{0}")]
     267              :     IO(#[from] std::io::Error),
     268              : 
     269              :     #[error("{0}")]
     270              :     Postgres(#[from] postgres_client::Error),
     271              : 
     272              :     #[error("rate limit exceeded")]
     273              :     RateLimit,
     274              : 
     275              :     #[error("IP is not allowed")]
     276              :     IpNotAllowed,
     277              : 
     278              :     #[error("VPC endpoint id is not allowed to connect")]
     279              :     VpcEndpointIdNotAllowed,
     280              : 
     281              :     #[error("Authentication backend error")]
     282              :     AuthError(#[from] AuthError),
     283              : 
     284              :     #[error("key not found")]
     285              :     NotFound,
     286              : 
     287              :     #[error("proxy service error")]
     288              :     InternalError,
     289              : }
     290              : 
     291              : impl ReportableError for CancelError {
     292            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     293            0 :         match self {
     294            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
     295            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
     296            0 :                 crate::error::ErrorKind::Postgres
     297              :             }
     298            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
     299            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
     300              :             CancelError::IpNotAllowed
     301              :             | CancelError::VpcEndpointIdNotAllowed
     302            0 :             | CancelError::NotFound => crate::error::ErrorKind::User,
     303            0 :             CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
     304            0 :             CancelError::InternalError => crate::error::ErrorKind::Service,
     305              :         }
     306            0 :     }
     307              : }
     308              : 
     309              : impl CancellationHandler {
     310            0 :     pub fn new(
     311            0 :         compute_config: &'static ComputeConfig,
     312            0 :         tx: Option<mpsc::Sender<CancelKeyOp>>,
     313            0 :     ) -> Self {
     314            0 :         Self {
     315            0 :             compute_config,
     316            0 :             tx,
     317            0 :             limiter: Arc::new(std::sync::Mutex::new(
     318            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     319            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     320            0 :                     64,
     321            0 :                 ),
     322            0 :             )),
     323            0 :         }
     324            0 :     }
     325              : 
     326            0 :     pub(crate) fn get_key(self: &Arc<Self>) -> Session {
     327            0 :         // we intentionally generate a random "backend pid" and "secret key" here.
     328            0 :         // we use the corresponding u64 as an identifier for the
     329            0 :         // actual endpoint+pid+secret for postgres/pgbouncer.
     330            0 :         //
     331            0 :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
     332            0 :         // of overlap between our computes as most pids are small (~100).
     333            0 : 
     334            0 :         let key: CancelKeyData = rand::random();
     335            0 : 
     336            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     337            0 :         let redis_key = prefix_key.build_redis_key();
     338            0 : 
     339            0 :         debug!("registered new query cancellation key {key}");
     340            0 :         Session {
     341            0 :             key,
     342            0 :             redis_key,
     343            0 :             cancellation_handler: Arc::clone(self),
     344            0 :         }
     345            0 :     }
     346              : 
     347            0 :     async fn get_cancel_key(
     348            0 :         &self,
     349            0 :         key: CancelKeyData,
     350            0 :     ) -> Result<Option<CancelClosure>, CancelError> {
     351            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     352            0 :         let redis_key = prefix_key.build_redis_key();
     353            0 : 
     354            0 :         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
     355            0 :         let op = CancelKeyOp::GetCancelData {
     356            0 :             key: redis_key,
     357            0 :             resp_tx,
     358            0 :             _guard: Metrics::get()
     359            0 :                 .proxy
     360            0 :                 .cancel_channel_size
     361            0 :                 .guard(RedisMsgKind::HGetAll),
     362            0 :         };
     363              : 
     364            0 :         let Some(tx) = &self.tx else {
     365            0 :             tracing::warn!("cancellation handler is not available");
     366            0 :             return Err(CancelError::InternalError);
     367              :         };
     368              : 
     369            0 :         tx.try_send(op)
     370            0 :             .map_err(|e| {
     371            0 :                 tracing::warn!("failed to send GetCancelData for {key}: {e}");
     372            0 :             })
     373            0 :             .map_err(|()| CancelError::InternalError)?;
     374              : 
     375            0 :         let result = resp_rx.await.map_err(|e| {
     376            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     377            0 :             CancelError::InternalError
     378            0 :         })?;
     379              : 
     380            0 :         let cancel_state_str: Option<String> = match result {
     381            0 :             Ok(mut state) => {
     382            0 :                 if state.len() == 1 {
     383            0 :                     Some(state.remove(0).1)
     384              :                 } else {
     385            0 :                     tracing::warn!("unexpected number of entries in cancel state: {state:?}");
     386            0 :                     return Err(CancelError::InternalError);
     387              :                 }
     388              :             }
     389            0 :             Err(e) => {
     390            0 :                 tracing::warn!("failed to receive cancel state from redis: {e}");
     391            0 :                 return Err(CancelError::InternalError);
     392              :             }
     393              :         };
     394              : 
     395            0 :         let cancel_state: Option<CancelClosure> = match cancel_state_str {
     396            0 :             Some(state) => {
     397            0 :                 let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
     398            0 :                     tracing::warn!("failed to deserialize cancel state: {e}");
     399            0 :                     CancelError::InternalError
     400            0 :                 })?;
     401            0 :                 Some(cancel_closure)
     402              :             }
     403            0 :             None => None,
     404              :         };
     405            0 :         Ok(cancel_state)
     406            0 :     }
     407              :     /// Try to cancel a running query for the corresponding connection.
     408              :     /// If the cancellation key is not found, it will be published to Redis.
     409              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     410              :     /// Will fetch IP allowlist internally.
     411              :     ///
     412              :     /// return Result primarily for tests
     413            0 :     pub(crate) async fn cancel_session<T: ControlPlaneApi>(
     414            0 :         &self,
     415            0 :         key: CancelKeyData,
     416            0 :         ctx: RequestContext,
     417            0 :         check_ip_allowed: bool,
     418            0 :         check_vpc_allowed: bool,
     419            0 :         auth_backend: &T,
     420            0 :     ) -> Result<(), CancelError> {
     421            0 :         let subnet_key = match ctx.peer_addr() {
     422            0 :             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
     423            0 :             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     424              :         };
     425            0 :         if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     426              :             // log only the subnet part of the IP address to know which subnet is rate limited
     427            0 :             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     428            0 :             Metrics::get()
     429            0 :                 .proxy
     430            0 :                 .cancellation_requests_total
     431            0 :                 .inc(CancellationRequest {
     432            0 :                     kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     433            0 :                 });
     434            0 :             return Err(CancelError::RateLimit);
     435            0 :         }
     436              : 
     437            0 :         let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
     438            0 :             tracing::warn!("failed to receive RedisOp response: {e}");
     439            0 :             CancelError::InternalError
     440            0 :         })?;
     441              : 
     442            0 :         let Some(cancel_closure) = cancel_state else {
     443            0 :             tracing::warn!("query cancellation key not found: {key}");
     444            0 :             Metrics::get()
     445            0 :                 .proxy
     446            0 :                 .cancellation_requests_total
     447            0 :                 .inc(CancellationRequest {
     448            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     449            0 :                 });
     450            0 :             return Err(CancelError::NotFound);
     451              :         };
     452              : 
     453            0 :         if check_ip_allowed {
     454            0 :             let ip_allowlist = auth_backend
     455            0 :                 .get_allowed_ips(&ctx, &cancel_closure.user_info)
     456            0 :                 .await
     457            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     458              : 
     459            0 :             if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
     460              :                 // log it here since cancel_session could be spawned in a task
     461            0 :                 tracing::warn!(
     462            0 :                     "IP is not allowed to cancel the query: {key}, address: {}",
     463            0 :                     ctx.peer_addr()
     464              :                 );
     465            0 :                 return Err(CancelError::IpNotAllowed);
     466            0 :             }
     467            0 :         }
     468              : 
     469              :         // check if a VPC endpoint ID is coming in and if yes, if it's allowed
     470            0 :         let access_blocks = auth_backend
     471            0 :             .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
     472            0 :             .await
     473            0 :             .map_err(|e| CancelError::AuthError(e.into()))?;
     474              : 
     475            0 :         if check_vpc_allowed {
     476            0 :             if access_blocks.vpc_access_blocked {
     477            0 :                 return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
     478            0 :             }
     479              : 
     480            0 :             let incoming_vpc_endpoint_id = match ctx.extra() {
     481            0 :                 None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
     482            0 :                 Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
     483            0 :                 Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
     484              :             };
     485              : 
     486            0 :             let allowed_vpc_endpoint_ids = auth_backend
     487            0 :                 .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
     488            0 :                 .await
     489            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     490              :             // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
     491            0 :             if !allowed_vpc_endpoint_ids.is_empty()
     492            0 :                 && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
     493              :             {
     494            0 :                 return Err(CancelError::VpcEndpointIdNotAllowed);
     495            0 :             }
     496            0 :         } else if access_blocks.public_access_blocked {
     497            0 :             return Err(CancelError::VpcEndpointIdNotAllowed);
     498            0 :         }
     499              : 
     500            0 :         Metrics::get()
     501            0 :             .proxy
     502            0 :             .cancellation_requests_total
     503            0 :             .inc(CancellationRequest {
     504            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     505            0 :             });
     506            0 :         info!("cancelling query per user's request using key {key}");
     507            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     508            0 :     }
     509              : }
     510              : 
     511              : /// This should've been a [`std::future::Future`], but
     512              : /// it's impossible to name a type of an unboxed future
     513              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     514            0 : #[derive(Clone, Serialize, Deserialize)]
     515              : pub struct CancelClosure {
     516              :     socket_addr: SocketAddr,
     517              :     cancel_token: CancelToken,
     518              :     hostname: String, // for pg_sni router
     519              :     user_info: ComputeUserInfo,
     520              : }
     521              : 
     522              : impl CancelClosure {
     523            0 :     pub(crate) fn new(
     524            0 :         socket_addr: SocketAddr,
     525            0 :         cancel_token: CancelToken,
     526            0 :         hostname: String,
     527            0 :         user_info: ComputeUserInfo,
     528            0 :     ) -> Self {
     529            0 :         Self {
     530            0 :             socket_addr,
     531            0 :             cancel_token,
     532            0 :             hostname,
     533            0 :             user_info,
     534            0 :         }
     535            0 :     }
     536              :     /// Cancels the query running on user's compute node.
     537            0 :     pub(crate) async fn try_cancel_query(
     538            0 :         self,
     539            0 :         compute_config: &ComputeConfig,
     540            0 :     ) -> Result<(), CancelError> {
     541            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     542              : 
     543            0 :         let mut mk_tls =
     544            0 :             crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
     545            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     546            0 :             &mut mk_tls,
     547            0 :             &self.hostname,
     548            0 :         )
     549            0 :         .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
     550              : 
     551            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     552            0 :         debug!("query was cancelled");
     553            0 :         Ok(())
     554            0 :     }
     555              : }
     556              : 
     557              : /// Helper for registering query cancellation tokens.
     558              : pub(crate) struct Session {
     559              :     /// The user-facing key identifying this session.
     560              :     key: CancelKeyData,
     561              :     redis_key: String,
     562              :     cancellation_handler: Arc<CancellationHandler>,
     563              : }
     564              : 
     565              : impl Session {
     566            0 :     pub(crate) fn key(&self) -> &CancelKeyData {
     567            0 :         &self.key
     568            0 :     }
     569              : 
     570              :     // Send the store key op to the cancellation handler and set TTL for the key
     571            0 :     pub(crate) fn write_cancel_key(
     572            0 :         &self,
     573            0 :         cancel_closure: CancelClosure,
     574            0 :     ) -> Result<(), CancelError> {
     575            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     576            0 :             tracing::warn!("cancellation handler is not available");
     577            0 :             return Err(CancelError::InternalError);
     578              :         };
     579              : 
     580            0 :         let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
     581            0 :             tracing::warn!("failed to serialize cancel closure: {e}");
     582            0 :             CancelError::InternalError
     583            0 :         })?;
     584              : 
     585            0 :         let op = CancelKeyOp::StoreCancelKey {
     586            0 :             key: self.redis_key.clone(),
     587            0 :             field: "data".to_string(),
     588            0 :             value: closure_json,
     589            0 :             resp_tx: None,
     590            0 :             _guard: Metrics::get()
     591            0 :                 .proxy
     592            0 :                 .cancel_channel_size
     593            0 :                 .guard(RedisMsgKind::HSet),
     594            0 :             expire: CANCEL_KEY_TTL,
     595            0 :         };
     596            0 : 
     597            0 :         let _ = tx.try_send(op).map_err(|e| {
     598            0 :             let key = self.key;
     599            0 :             tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
     600            0 :         });
     601            0 :         Ok(())
     602            0 :     }
     603              : 
     604            0 :     pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
     605            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     606            0 :             tracing::warn!("cancellation handler is not available");
     607            0 :             return Err(CancelError::InternalError);
     608              :         };
     609              : 
     610            0 :         let op = CancelKeyOp::RemoveCancelKey {
     611            0 :             key: self.redis_key.clone(),
     612            0 :             field: "data".to_string(),
     613            0 :             resp_tx: None,
     614            0 :             _guard: Metrics::get()
     615            0 :                 .proxy
     616            0 :                 .cancel_channel_size
     617            0 :                 .guard(RedisMsgKind::HDel),
     618            0 :         };
     619            0 : 
     620            0 :         let _ = tx.try_send(op).map_err(|e| {
     621            0 :             let key = self.key;
     622            0 :             tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
     623            0 :         });
     624            0 :         Ok(())
     625            0 :     }
     626              : }
        

Generated by: LCOV version 2.1-beta