LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 0.0 % 312 0
Test Date: 2025-03-12 18:05:02 Functions: 0.0 % 52 0

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::net::{IpAddr, SocketAddr};
       3              : use std::sync::Arc;
       4              : 
       5              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
       6              : use postgres_client::CancelToken;
       7              : use postgres_client::tls::MakeTlsConnect;
       8              : use pq_proto::CancelKeyData;
       9              : use serde::{Deserialize, Serialize};
      10              : use thiserror::Error;
      11              : use tokio::net::TcpStream;
      12              : use tokio::sync::{mpsc, oneshot};
      13              : use tracing::{debug, info};
      14              : 
      15              : use crate::auth::backend::ComputeUserInfo;
      16              : use crate::auth::{AuthError, check_peer_addr_is_in_list};
      17              : use crate::config::ComputeConfig;
      18              : use crate::context::RequestContext;
      19              : use crate::control_plane::ControlPlaneApi;
      20              : use crate::error::ReportableError;
      21              : use crate::ext::LockExt;
      22              : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
      23              : use crate::protocol2::ConnectionInfoExtra;
      24              : use crate::rate_limiter::LeakyBucketRateLimiter;
      25              : use crate::redis::keys::KeyPrefix;
      26              : use crate::redis::kv_ops::RedisKVClient;
      27              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      28              : 
/// Key used by the cancellation rate limiter: the client's IP network, truncated to a
/// subnet before use (see `CancellationHandler::cancel_session`).
type IpSubnetKey = IpNet;
      30              : 
      31              : const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
      32              : const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
      33              : 
/// Message types for sending through the mpsc channel.
///
/// Each variant describes one Redis operation; they are consumed and executed by
/// [`handle_cancel_messages`].
pub enum CancelKeyOp {
    /// `HSET key field value`, followed by `EXPIRE key expire` if the HSET succeeded.
    StoreCancelKey {
        key: String,
        field: String,
        value: String,
        /// If present, receives the HSET error, or the EXPIRE result when HSET succeeded.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        /// NOTE(review): presumably accounts this op in the `cancel_channel_size` metric
        /// while it is alive — confirm against `crate::metrics`.
        _guard: CancelChannelSizeGuard<'static>,
        expire: i64, // TTL for key
    },
    /// `HGETALL key`; the field/value pairs (or the error) are sent back on `resp_tx`.
    GetCancelData {
        key: String,
        resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
        /// NOTE(review): presumably a channel-size metric guard, as above.
        _guard: CancelChannelSizeGuard<'static>,
    },
    /// `HDEL key field`.
    RemoveCancelKey {
        key: String,
        field: String,
        /// If present, receives the HDEL result.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        /// NOTE(review): presumably a channel-size metric guard, as above.
        _guard: CancelChannelSizeGuard<'static>,
    },
}
      56              : 
      57              : // Running as a separate task to accept messages through the rx channel
      58              : // In case of problems with RTT: switch to recv_many() + redis pipeline
      59            0 : pub async fn handle_cancel_messages(
      60            0 :     client: &mut RedisKVClient,
      61            0 :     mut rx: mpsc::Receiver<CancelKeyOp>,
      62            0 : ) -> anyhow::Result<Infallible> {
      63              :     loop {
      64            0 :         if let Some(msg) = rx.recv().await {
      65            0 :             match msg {
      66              :                 CancelKeyOp::StoreCancelKey {
      67            0 :                     key,
      68            0 :                     field,
      69            0 :                     value,
      70            0 :                     resp_tx,
      71            0 :                     _guard,
      72            0 :                     expire,
      73              :                 } => {
      74            0 :                     let res = client.hset(&key, field, value).await;
      75            0 :                     if let Some(resp_tx) = resp_tx {
      76            0 :                         if res.is_ok() {
      77            0 :                             resp_tx
      78            0 :                                 .send(client.expire(key, expire).await)
      79            0 :                                 .inspect_err(|e| {
      80            0 :                                     tracing::debug!(
      81            0 :                                         "failed to send StoreCancelKey response: {:?}",
      82              :                                         e
      83              :                                     );
      84            0 :                                 })
      85            0 :                                 .ok();
      86            0 :                         } else {
      87            0 :                             resp_tx
      88            0 :                                 .send(res)
      89            0 :                                 .inspect_err(|e| {
      90            0 :                                     tracing::debug!(
      91            0 :                                         "failed to send StoreCancelKey response: {:?}",
      92              :                                         e
      93              :                                     );
      94            0 :                                 })
      95            0 :                                 .ok();
      96            0 :                         }
      97            0 :                     } else if res.is_ok() {
      98            0 :                         drop(client.expire(key, expire).await);
      99              :                     } else {
     100            0 :                         tracing::warn!("failed to store cancel key: {:?}", res);
     101              :                     }
     102              :                 }
     103              :                 CancelKeyOp::GetCancelData {
     104            0 :                     key,
     105            0 :                     resp_tx,
     106            0 :                     _guard,
     107            0 :                 } => {
     108            0 :                     drop(resp_tx.send(client.hget_all(key).await));
     109              :                 }
     110              :                 CancelKeyOp::RemoveCancelKey {
     111            0 :                     key,
     112            0 :                     field,
     113            0 :                     resp_tx,
     114            0 :                     _guard,
     115              :                 } => {
     116            0 :                     if let Some(resp_tx) = resp_tx {
     117            0 :                         resp_tx
     118            0 :                             .send(client.hdel(key, field).await)
     119            0 :                             .inspect_err(|e| {
     120            0 :                                 tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
     121            0 :                             })
     122            0 :                             .ok();
     123            0 :                     } else {
     124            0 :                         drop(client.hdel(key, field).await);
     125              :                     }
     126              :                 }
     127              :             }
     128            0 :         }
     129              :     }
     130              : }
     131              : 
/// Enables serving `CancelRequest`s.
///
/// Cancellation keys are stored in and looked up from Redis by sending [`CancelKeyOp`]s
/// over `tx`. Presumably the receiving end is the task running [`handle_cancel_messages`]
/// (verify at the construction site). When `tx` is `None`, Redis-backed operations fail
/// with [`CancelError::InternalError`].
pub struct CancellationHandler {
    /// Compute config (its TLS settings) used when forwarding a cancel request to a compute.
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
}
     141              : 
/// Errors that can occur while serving a cancellation request.
#[derive(Debug, Error)]
pub(crate) enum CancelError {
    /// I/O failure, e.g. connecting to the compute or setting up TLS towards it.
    #[error("{0}")]
    IO(#[from] std::io::Error),

    /// Error reported by `postgres_client` while sending the cancel request.
    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    /// The client's subnet exceeded the cancellation rate limit.
    #[error("rate limit exceeded")]
    RateLimit,

    /// The peer address is not in the endpoint's IP allowlist.
    #[error("IP is not allowed")]
    IpNotAllowed,

    /// The incoming VPC endpoint id is not allowed, or public access is blocked.
    #[error("VPC endpoint id is not allowed to connect")]
    VpcEndpointIdNotAllowed,

    /// Failure from the auth backend while fetching access rules, or a network-access
    /// violation (blocked VPC access, missing VPC endpoint id).
    #[error("Authentication backend error")]
    AuthError(#[from] AuthError),

    /// No session is stored under the given cancellation key.
    #[error("key not found")]
    NotFound,

    /// Proxy-side failure: Redis channel unavailable, or stored data malformed.
    #[error("proxy service error")]
    InternalError,
}
     168              : 
     169              : impl ReportableError for CancelError {
     170            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     171            0 :         match self {
     172            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
     173            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
     174            0 :                 crate::error::ErrorKind::Postgres
     175              :             }
     176            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
     177            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
     178              :             CancelError::IpNotAllowed
     179              :             | CancelError::VpcEndpointIdNotAllowed
     180            0 :             | CancelError::NotFound => crate::error::ErrorKind::User,
     181            0 :             CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
     182            0 :             CancelError::InternalError => crate::error::ErrorKind::Service,
     183              :         }
     184            0 :     }
     185              : }
     186              : 
     187              : impl CancellationHandler {
     188            0 :     pub fn new(
     189            0 :         compute_config: &'static ComputeConfig,
     190            0 :         tx: Option<mpsc::Sender<CancelKeyOp>>,
     191            0 :     ) -> Self {
     192            0 :         Self {
     193            0 :             compute_config,
     194            0 :             tx,
     195            0 :             limiter: Arc::new(std::sync::Mutex::new(
     196            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     197            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     198            0 :                     64,
     199            0 :                 ),
     200            0 :             )),
     201            0 :         }
     202            0 :     }
     203              : 
     204            0 :     pub(crate) fn get_key(self: &Arc<Self>) -> Session {
     205            0 :         // we intentionally generate a random "backend pid" and "secret key" here.
     206            0 :         // we use the corresponding u64 as an identifier for the
     207            0 :         // actual endpoint+pid+secret for postgres/pgbouncer.
     208            0 :         //
     209            0 :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
     210            0 :         // of overlap between our computes as most pids are small (~100).
     211            0 : 
     212            0 :         let key: CancelKeyData = rand::random();
     213            0 : 
     214            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     215            0 :         let redis_key = prefix_key.build_redis_key();
     216            0 : 
     217            0 :         debug!("registered new query cancellation key {key}");
     218            0 :         Session {
     219            0 :             key,
     220            0 :             redis_key,
     221            0 :             cancellation_handler: Arc::clone(self),
     222            0 :         }
     223            0 :     }
     224              : 
     225            0 :     async fn get_cancel_key(
     226            0 :         &self,
     227            0 :         key: CancelKeyData,
     228            0 :     ) -> Result<Option<CancelClosure>, CancelError> {
     229            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     230            0 :         let redis_key = prefix_key.build_redis_key();
     231            0 : 
     232            0 :         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
     233            0 :         let op = CancelKeyOp::GetCancelData {
     234            0 :             key: redis_key,
     235            0 :             resp_tx,
     236            0 :             _guard: Metrics::get()
     237            0 :                 .proxy
     238            0 :                 .cancel_channel_size
     239            0 :                 .guard(RedisMsgKind::HGetAll),
     240            0 :         };
     241              : 
     242            0 :         let Some(tx) = &self.tx else {
     243            0 :             tracing::warn!("cancellation handler is not available");
     244            0 :             return Err(CancelError::InternalError);
     245              :         };
     246              : 
     247            0 :         tx.send_timeout(op, REDIS_SEND_TIMEOUT)
     248            0 :             .await
     249            0 :             .map_err(|e| {
     250            0 :                 tracing::warn!("failed to send GetCancelData for {key}: {e}");
     251            0 :             })
     252            0 :             .map_err(|()| CancelError::InternalError)?;
     253              : 
     254            0 :         let result = resp_rx.await.map_err(|e| {
     255            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     256            0 :             CancelError::InternalError
     257            0 :         })?;
     258              : 
     259            0 :         let cancel_state_str: Option<String> = match result {
     260            0 :             Ok(mut state) => {
     261            0 :                 if state.len() == 1 {
     262            0 :                     Some(state.remove(0).1)
     263              :                 } else {
     264            0 :                     tracing::warn!("unexpected number of entries in cancel state: {state:?}");
     265            0 :                     return Err(CancelError::InternalError);
     266              :                 }
     267              :             }
     268            0 :             Err(e) => {
     269            0 :                 tracing::warn!("failed to receive cancel state from redis: {e}");
     270            0 :                 return Err(CancelError::InternalError);
     271              :             }
     272              :         };
     273              : 
     274            0 :         let cancel_state: Option<CancelClosure> = match cancel_state_str {
     275            0 :             Some(state) => {
     276            0 :                 let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
     277            0 :                     tracing::warn!("failed to deserialize cancel state: {e}");
     278            0 :                     CancelError::InternalError
     279            0 :                 })?;
     280            0 :                 Some(cancel_closure)
     281              :             }
     282            0 :             None => None,
     283              :         };
     284            0 :         Ok(cancel_state)
     285            0 :     }
     286              :     /// Try to cancel a running query for the corresponding connection.
     287              :     /// If the cancellation key is not found, it will be published to Redis.
     288              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     289              :     /// Will fetch IP allowlist internally.
     290              :     ///
     291              :     /// return Result primarily for tests
     292            0 :     pub(crate) async fn cancel_session<T: ControlPlaneApi>(
     293            0 :         &self,
     294            0 :         key: CancelKeyData,
     295            0 :         ctx: RequestContext,
     296            0 :         check_ip_allowed: bool,
     297            0 :         check_vpc_allowed: bool,
     298            0 :         auth_backend: &T,
     299            0 :     ) -> Result<(), CancelError> {
     300            0 :         let subnet_key = match ctx.peer_addr() {
     301            0 :             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
     302            0 :             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     303              :         };
     304            0 :         if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     305              :             // log only the subnet part of the IP address to know which subnet is rate limited
     306            0 :             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     307            0 :             Metrics::get()
     308            0 :                 .proxy
     309            0 :                 .cancellation_requests_total
     310            0 :                 .inc(CancellationRequest {
     311            0 :                     kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     312            0 :                 });
     313            0 :             return Err(CancelError::RateLimit);
     314            0 :         }
     315              : 
     316            0 :         let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
     317            0 :             tracing::warn!("failed to receive RedisOp response: {e}");
     318            0 :             CancelError::InternalError
     319            0 :         })?;
     320              : 
     321            0 :         let Some(cancel_closure) = cancel_state else {
     322            0 :             tracing::warn!("query cancellation key not found: {key}");
     323            0 :             Metrics::get()
     324            0 :                 .proxy
     325            0 :                 .cancellation_requests_total
     326            0 :                 .inc(CancellationRequest {
     327            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     328            0 :                 });
     329            0 :             return Err(CancelError::NotFound);
     330              :         };
     331              : 
     332            0 :         if check_ip_allowed {
     333            0 :             let ip_allowlist = auth_backend
     334            0 :                 .get_allowed_ips(&ctx, &cancel_closure.user_info)
     335            0 :                 .await
     336            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     337              : 
     338            0 :             if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
     339              :                 // log it here since cancel_session could be spawned in a task
     340            0 :                 tracing::warn!(
     341            0 :                     "IP is not allowed to cancel the query: {key}, address: {}",
     342            0 :                     ctx.peer_addr()
     343              :                 );
     344            0 :                 return Err(CancelError::IpNotAllowed);
     345            0 :             }
     346            0 :         }
     347              : 
     348              :         // check if a VPC endpoint ID is coming in and if yes, if it's allowed
     349            0 :         let access_blocks = auth_backend
     350            0 :             .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
     351            0 :             .await
     352            0 :             .map_err(|e| CancelError::AuthError(e.into()))?;
     353              : 
     354            0 :         if check_vpc_allowed {
     355            0 :             if access_blocks.vpc_access_blocked {
     356            0 :                 return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
     357            0 :             }
     358              : 
     359            0 :             let incoming_vpc_endpoint_id = match ctx.extra() {
     360            0 :                 None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
     361            0 :                 Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
     362            0 :                 Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
     363              :             };
     364              : 
     365            0 :             let allowed_vpc_endpoint_ids = auth_backend
     366            0 :                 .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
     367            0 :                 .await
     368            0 :                 .map_err(|e| CancelError::AuthError(e.into()))?;
     369              :             // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
     370            0 :             if !allowed_vpc_endpoint_ids.is_empty()
     371            0 :                 && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
     372              :             {
     373            0 :                 return Err(CancelError::VpcEndpointIdNotAllowed);
     374            0 :             }
     375            0 :         } else if access_blocks.public_access_blocked {
     376            0 :             return Err(CancelError::VpcEndpointIdNotAllowed);
     377            0 :         }
     378              : 
     379            0 :         Metrics::get()
     380            0 :             .proxy
     381            0 :             .cancellation_requests_total
     382            0 :             .inc(CancellationRequest {
     383            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     384            0 :             });
     385            0 :         info!("cancelling query per user's request using key {key}");
     386            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     387            0 :     }
     388              : }
     389              : 
/// Everything needed to send a cancel request to a compute later on.
///
/// Values are serialized to JSON and stored in Redis by `Session::write_cancel_key`, so
/// renaming fields changes the stored format.
///
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
    /// Compute address the cancel request is sent to.
    socket_addr: SocketAddr,
    /// Token used to issue the cancel request on the new connection.
    cancel_token: CancelToken,
    /// Name passed to the TLS connector when connecting to the compute.
    hostname: String, // for pg_sni router
    /// Identifies the endpoint/user whose access rules are checked before cancelling.
    user_info: ComputeUserInfo,
}
     400              : 
     401              : impl CancelClosure {
     402            0 :     pub(crate) fn new(
     403            0 :         socket_addr: SocketAddr,
     404            0 :         cancel_token: CancelToken,
     405            0 :         hostname: String,
     406            0 :         user_info: ComputeUserInfo,
     407            0 :     ) -> Self {
     408            0 :         Self {
     409            0 :             socket_addr,
     410            0 :             cancel_token,
     411            0 :             hostname,
     412            0 :             user_info,
     413            0 :         }
     414            0 :     }
     415              :     /// Cancels the query running on user's compute node.
     416            0 :     pub(crate) async fn try_cancel_query(
     417            0 :         self,
     418            0 :         compute_config: &ComputeConfig,
     419            0 :     ) -> Result<(), CancelError> {
     420            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     421              : 
     422            0 :         let mut mk_tls =
     423            0 :             crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
     424            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     425            0 :             &mut mk_tls,
     426            0 :             &self.hostname,
     427            0 :         )
     428            0 :         .map_err(|e| {
     429            0 :             CancelError::IO(std::io::Error::new(
     430            0 :                 std::io::ErrorKind::Other,
     431            0 :                 e.to_string(),
     432            0 :             ))
     433            0 :         })?;
     434              : 
     435            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     436            0 :         debug!("query was cancelled");
     437            0 :         Ok(())
     438            0 :     }
     439              : }
     440              : 
/// Helper for registering query cancellation tokens.
///
/// Created by `CancellationHandler::get_key`. It ties a random user-facing key to the
/// Redis key under which the session's [`CancelClosure`] is stored.
pub(crate) struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// Redis key built from `key` via `KeyPrefix::Cancel`.
    redis_key: String,
    /// Handler whose `tx` channel is used to send Redis ops.
    cancellation_handler: Arc<CancellationHandler>,
}
     448              : 
     449              : impl Session {
     450            0 :     pub(crate) fn key(&self) -> &CancelKeyData {
     451            0 :         &self.key
     452            0 :     }
     453              : 
     454              :     // Send the store key op to the cancellation handler and set TTL for the key
     455            0 :     pub(crate) async fn write_cancel_key(
     456            0 :         &self,
     457            0 :         cancel_closure: CancelClosure,
     458            0 :     ) -> Result<(), CancelError> {
     459            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     460            0 :             tracing::warn!("cancellation handler is not available");
     461            0 :             return Err(CancelError::InternalError);
     462              :         };
     463              : 
     464            0 :         let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
     465            0 :             tracing::warn!("failed to serialize cancel closure: {e}");
     466            0 :             CancelError::InternalError
     467            0 :         })?;
     468              : 
     469            0 :         let op = CancelKeyOp::StoreCancelKey {
     470            0 :             key: self.redis_key.clone(),
     471            0 :             field: "data".to_string(),
     472            0 :             value: closure_json,
     473            0 :             resp_tx: None,
     474            0 :             _guard: Metrics::get()
     475            0 :                 .proxy
     476            0 :                 .cancel_channel_size
     477            0 :                 .guard(RedisMsgKind::HSet),
     478            0 :             expire: CANCEL_KEY_TTL,
     479            0 :         };
     480            0 : 
     481            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     482            0 :             let key = self.key;
     483            0 :             tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
     484            0 :         });
     485            0 :         Ok(())
     486            0 :     }
     487              : 
     488            0 :     pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
     489            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     490            0 :             tracing::warn!("cancellation handler is not available");
     491            0 :             return Err(CancelError::InternalError);
     492              :         };
     493              : 
     494            0 :         let op = CancelKeyOp::RemoveCancelKey {
     495            0 :             key: self.redis_key.clone(),
     496            0 :             field: "data".to_string(),
     497            0 :             resp_tx: None,
     498            0 :             _guard: Metrics::get()
     499            0 :                 .proxy
     500            0 :                 .cancel_channel_size
     501            0 :                 .guard(RedisMsgKind::HDel),
     502            0 :         };
     503            0 : 
     504            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     505            0 :             let key = self.key;
     506            0 :             tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
     507            0 :         });
     508            0 :         Ok(())
     509            0 :     }
     510              : }
        

Generated by: LCOV version 2.1-beta