LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 727bdccc1d7d53837da843959afb612f56da4e79.info Lines: 0.0 % 275 0
Test Date: 2025-01-30 15:18:43 Functions: 0.0 % 42 0

            Line data    Source code
       1              : use std::net::{IpAddr, SocketAddr};
       2              : use std::sync::Arc;
       3              : 
       4              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
       5              : use postgres_client::tls::MakeTlsConnect;
       6              : use postgres_client::CancelToken;
       7              : use pq_proto::CancelKeyData;
       8              : use serde::{Deserialize, Serialize};
       9              : use thiserror::Error;
      10              : use tokio::net::TcpStream;
      11              : use tokio::sync::mpsc;
      12              : use tracing::{debug, info};
      13              : 
      14              : use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
      15              : use crate::auth::{check_peer_addr_is_in_list, AuthError};
      16              : use crate::config::ComputeConfig;
      17              : use crate::context::RequestContext;
      18              : use crate::error::ReportableError;
      19              : use crate::ext::LockExt;
      20              : use crate::metrics::CancelChannelSizeGuard;
      21              : use crate::metrics::{CancellationRequest, Metrics, RedisMsgKind};
      22              : use crate::rate_limiter::LeakyBucketRateLimiter;
      23              : use crate::redis::keys::KeyPrefix;
      24              : use crate::redis::kv_ops::RedisKVClient;
      25              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      26              : use std::convert::Infallible;
      27              : use tokio::sync::oneshot;
      28              : 
      29              : type IpSubnetKey = IpNet;
      30              : 
      31              : const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
      32              : const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
      33              : 
/// Requests sent through the mpsc channel to the task running [`handle_cancel_messages`],
/// which performs the corresponding Redis hash operation for each one.
///
/// Every variant carries a `_guard` obtained from `Metrics::get().proxy.cancel_channel_size`;
/// the handler keeps it alive until the operation has been processed (presumably so the
/// metric reflects messages still in flight — TODO confirm against the metrics module).
pub enum CancelKeyOp {
    /// Set `field` to `value` in the Redis hash stored at `key` (HSET).
    StoreCancelKey {
        key: String,
        field: String,
        value: String,
        // When `Some`, receives the HSET result; `None` means fire-and-forget.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        _guard: CancelChannelSizeGuard<'static>,
        // NOTE(review): `handle_cancel_messages` currently ignores this value, so no TTL is
        // applied to the stored key.
        expire: i64, // TTL for key
    },
    /// Fetch all field/value pairs of the Redis hash stored at `key` (HGETALL).
    GetCancelData {
        key: String,
        // Always receives the HGETALL result.
        resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
    /// Delete `field` from the Redis hash stored at `key` (HDEL).
    RemoveCancelKey {
        key: String,
        field: String,
        // When `Some`, receives the HDEL result; `None` means fire-and-forget.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
}
      56              : 
      57              : // Running as a separate task to accept messages through the rx channel
      58              : // In case of problems with RTT: switch to recv_many() + redis pipeline
      59            0 : pub async fn handle_cancel_messages(
      60            0 :     client: &mut RedisKVClient,
      61            0 :     mut rx: mpsc::Receiver<CancelKeyOp>,
      62            0 : ) -> anyhow::Result<Infallible> {
      63              :     loop {
      64            0 :         if let Some(msg) = rx.recv().await {
      65            0 :             match msg {
      66              :                 CancelKeyOp::StoreCancelKey {
      67            0 :                     key,
      68            0 :                     field,
      69            0 :                     value,
      70            0 :                     resp_tx,
      71            0 :                     _guard,
      72              :                     expire: _,
      73              :                 } => {
      74            0 :                     if let Some(resp_tx) = resp_tx {
      75            0 :                         resp_tx
      76            0 :                             .send(client.hset(key, field, value).await)
      77            0 :                             .inspect_err(|e| {
      78            0 :                                 tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
      79            0 :                             })
      80            0 :                             .ok();
      81            0 :                     } else {
      82            0 :                         drop(client.hset(key, field, value).await);
      83              :                     }
      84              :                 }
      85              :                 CancelKeyOp::GetCancelData {
      86            0 :                     key,
      87            0 :                     resp_tx,
      88            0 :                     _guard,
      89            0 :                 } => {
      90            0 :                     drop(resp_tx.send(client.hget_all(key).await));
      91              :                 }
      92              :                 CancelKeyOp::RemoveCancelKey {
      93            0 :                     key,
      94            0 :                     field,
      95            0 :                     resp_tx,
      96            0 :                     _guard,
      97              :                 } => {
      98            0 :                     if let Some(resp_tx) = resp_tx {
      99            0 :                         resp_tx
     100            0 :                             .send(client.hdel(key, field).await)
     101            0 :                             .inspect_err(|e| {
     102            0 :                                 tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
     103            0 :                             })
     104            0 :                             .ok();
     105            0 :                     } else {
     106            0 :                         drop(client.hdel(key, field).await);
     107              :                     }
     108              :                 }
     109              :             }
     110            0 :         }
     111              :     }
     112              : }
     113              : 
/// Enables serving `CancelRequest`s.
///
/// Cancellation state is stored in and fetched from Redis through the KV client task behind
/// `tx`, and incoming cancellation requests are rate-limited per client subnet. When `tx` is
/// `None`, every Redis-backed operation fails with [`CancelError::InternalError`].
pub struct CancellationHandler {
    // passed to `CancelClosure::try_cancel_query` (its TLS settings are used for the compute connection)
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests, keyed by the client's subnet
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
}
     123              : 
/// Errors produced while serving a cancellation request or registering a cancellation key.
#[derive(Debug, Error)]
pub(crate) enum CancelError {
    /// I/O failure, e.g. connecting to the compute node or setting up TLS for it.
    #[error("{0}")]
    IO(#[from] std::io::Error),

    /// Error returned by the Postgres client while sending the cancel request.
    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    /// The client's subnet exceeded the cancellation rate limit.
    #[error("rate limit exceeded")]
    RateLimit,

    /// The client's address is not in the endpoint's IP allowlist.
    #[error("IP is not allowed")]
    IpNotAllowed,

    /// Fetching the IP allowlist from the authentication backend failed.
    #[error("Authentication backend error")]
    AuthError(#[from] AuthError),

    /// No cancellation state is stored for the requested key.
    #[error("key not found")]
    NotFound,

    /// Proxy-side failure: Redis task unavailable, Redis error, or malformed stored state.
    #[error("proxy service error")]
    InternalError,
}
     147              : 
     148              : impl ReportableError for CancelError {
     149            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     150            0 :         match self {
     151            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
     152            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
     153            0 :                 crate::error::ErrorKind::Postgres
     154              :             }
     155            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
     156            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
     157            0 :             CancelError::IpNotAllowed => crate::error::ErrorKind::User,
     158            0 :             CancelError::NotFound => crate::error::ErrorKind::User,
     159            0 :             CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
     160            0 :             CancelError::InternalError => crate::error::ErrorKind::Service,
     161              :         }
     162            0 :     }
     163              : }
     164              : 
     165              : impl CancellationHandler {
     166            0 :     pub fn new(
     167            0 :         compute_config: &'static ComputeConfig,
     168            0 :         tx: Option<mpsc::Sender<CancelKeyOp>>,
     169            0 :     ) -> Self {
     170            0 :         Self {
     171            0 :             compute_config,
     172            0 :             tx,
     173            0 :             limiter: Arc::new(std::sync::Mutex::new(
     174            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     175            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     176            0 :                     64,
     177            0 :                 ),
     178            0 :             )),
     179            0 :         }
     180            0 :     }
     181              : 
     182            0 :     pub(crate) fn get_key(self: &Arc<Self>) -> Session {
     183            0 :         // we intentionally generate a random "backend pid" and "secret key" here.
     184            0 :         // we use the corresponding u64 as an identifier for the
     185            0 :         // actual endpoint+pid+secret for postgres/pgbouncer.
     186            0 :         //
     187            0 :         // if we forwarded the backend_pid from postgres to the client, there would be a lot
     188            0 :         // of overlap between our computes as most pids are small (~100).
     189            0 : 
     190            0 :         let key: CancelKeyData = rand::random();
     191            0 : 
     192            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     193            0 :         let redis_key = prefix_key.build_redis_key();
     194            0 : 
     195            0 :         debug!("registered new query cancellation key {key}");
     196            0 :         Session {
     197            0 :             key,
     198            0 :             redis_key,
     199            0 :             cancellation_handler: Arc::clone(self),
     200            0 :         }
     201            0 :     }
     202              : 
     203            0 :     async fn get_cancel_key(
     204            0 :         &self,
     205            0 :         key: CancelKeyData,
     206            0 :     ) -> Result<Option<CancelClosure>, CancelError> {
     207            0 :         let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
     208            0 :         let redis_key = prefix_key.build_redis_key();
     209            0 : 
     210            0 :         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
     211            0 :         let op = CancelKeyOp::GetCancelData {
     212            0 :             key: redis_key,
     213            0 :             resp_tx,
     214            0 :             _guard: Metrics::get()
     215            0 :                 .proxy
     216            0 :                 .cancel_channel_size
     217            0 :                 .guard(RedisMsgKind::HGetAll),
     218            0 :         };
     219              : 
     220            0 :         let Some(tx) = &self.tx else {
     221            0 :             tracing::warn!("cancellation handler is not available");
     222            0 :             return Err(CancelError::InternalError);
     223              :         };
     224              : 
     225            0 :         tx.send_timeout(op, REDIS_SEND_TIMEOUT)
     226            0 :             .await
     227            0 :             .map_err(|e| {
     228            0 :                 tracing::warn!("failed to send GetCancelData for {key}: {e}");
     229            0 :             })
     230            0 :             .map_err(|()| CancelError::InternalError)?;
     231              : 
     232            0 :         let result = resp_rx.await.map_err(|e| {
     233            0 :             tracing::warn!("failed to receive GetCancelData response: {e}");
     234            0 :             CancelError::InternalError
     235            0 :         })?;
     236              : 
     237            0 :         let cancel_state_str: Option<String> = match result {
     238            0 :             Ok(mut state) => {
     239            0 :                 if state.len() == 1 {
     240            0 :                     Some(state.remove(0).1)
     241              :                 } else {
     242            0 :                     tracing::warn!("unexpected number of entries in cancel state: {state:?}");
     243            0 :                     return Err(CancelError::InternalError);
     244              :                 }
     245              :             }
     246            0 :             Err(e) => {
     247            0 :                 tracing::warn!("failed to receive cancel state from redis: {e}");
     248            0 :                 return Err(CancelError::InternalError);
     249              :             }
     250              :         };
     251              : 
     252            0 :         let cancel_state: Option<CancelClosure> = match cancel_state_str {
     253            0 :             Some(state) => {
     254            0 :                 let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
     255            0 :                     tracing::warn!("failed to deserialize cancel state: {e}");
     256            0 :                     CancelError::InternalError
     257            0 :                 })?;
     258            0 :                 Some(cancel_closure)
     259              :             }
     260            0 :             None => None,
     261              :         };
     262            0 :         Ok(cancel_state)
     263            0 :     }
     264              :     /// Try to cancel a running query for the corresponding connection.
     265              :     /// If the cancellation key is not found, it will be published to Redis.
     266              :     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     267              :     /// Will fetch IP allowlist internally.
     268              :     ///
     269              :     /// return Result primarily for tests
     270            0 :     pub(crate) async fn cancel_session<T: BackendIpAllowlist>(
     271            0 :         &self,
     272            0 :         key: CancelKeyData,
     273            0 :         ctx: RequestContext,
     274            0 :         check_allowed: bool,
     275            0 :         auth_backend: &T,
     276            0 :     ) -> Result<(), CancelError> {
     277            0 :         let subnet_key = match ctx.peer_addr() {
     278            0 :             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
     279            0 :             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
     280              :         };
     281            0 :         if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
     282              :             // log only the subnet part of the IP address to know which subnet is rate limited
     283            0 :             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
     284            0 :             Metrics::get()
     285            0 :                 .proxy
     286            0 :                 .cancellation_requests_total
     287            0 :                 .inc(CancellationRequest {
     288            0 :                     kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     289            0 :                 });
     290            0 :             return Err(CancelError::RateLimit);
     291            0 :         }
     292              : 
     293            0 :         let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
     294            0 :             tracing::warn!("failed to receive RedisOp response: {e}");
     295            0 :             CancelError::InternalError
     296            0 :         })?;
     297              : 
     298            0 :         let Some(cancel_closure) = cancel_state else {
     299            0 :             tracing::warn!("query cancellation key not found: {key}");
     300            0 :             Metrics::get()
     301            0 :                 .proxy
     302            0 :                 .cancellation_requests_total
     303            0 :                 .inc(CancellationRequest {
     304            0 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     305            0 :                 });
     306            0 :             return Err(CancelError::NotFound);
     307              :         };
     308              : 
     309            0 :         if check_allowed {
     310            0 :             let ip_allowlist = auth_backend
     311            0 :                 .get_allowed_ips(&ctx, &cancel_closure.user_info)
     312            0 :                 .await
     313            0 :                 .map_err(CancelError::AuthError)?;
     314              : 
     315            0 :             if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
     316              :                 // log it here since cancel_session could be spawned in a task
     317            0 :                 tracing::warn!(
     318            0 :                     "IP is not allowed to cancel the query: {key}, address: {}",
     319            0 :                     ctx.peer_addr()
     320              :                 );
     321            0 :                 return Err(CancelError::IpNotAllowed);
     322            0 :             }
     323            0 :         }
     324              : 
     325            0 :         Metrics::get()
     326            0 :             .proxy
     327            0 :             .cancellation_requests_total
     328            0 :             .inc(CancellationRequest {
     329            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     330            0 :             });
     331            0 :         info!("cancelling query per user's request using key {key}");
     332            0 :         cancel_closure.try_cancel_query(self.compute_config).await
     333            0 :     }
     334              : }
     335              : 
/// Everything needed to cancel a query running on a compute node.
///
/// Serialized to JSON and stored in Redis under the session's cancellation key, and
/// deserialized again when a matching `CancelRequest` is served.
///
/// Historical note: this should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
    // compute address the cancel request is sent to
    socket_addr: SocketAddr,
    // token used to issue the Postgres cancel request
    cancel_token: CancelToken,
    hostname: String, // for pg_sni router; also passed as the TLS domain when connecting
    // passed to `get_allowed_ips` when checking the requester's IP allowlist
    user_info: ComputeUserInfo,
}
     346              : 
     347              : impl CancelClosure {
     348            0 :     pub(crate) fn new(
     349            0 :         socket_addr: SocketAddr,
     350            0 :         cancel_token: CancelToken,
     351            0 :         hostname: String,
     352            0 :         user_info: ComputeUserInfo,
     353            0 :     ) -> Self {
     354            0 :         Self {
     355            0 :             socket_addr,
     356            0 :             cancel_token,
     357            0 :             hostname,
     358            0 :             user_info,
     359            0 :         }
     360            0 :     }
     361              :     /// Cancels the query running on user's compute node.
     362            0 :     pub(crate) async fn try_cancel_query(
     363            0 :         self,
     364            0 :         compute_config: &ComputeConfig,
     365            0 :     ) -> Result<(), CancelError> {
     366            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     367              : 
     368            0 :         let mut mk_tls =
     369            0 :             crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
     370            0 :         let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
     371            0 :             &mut mk_tls,
     372            0 :             &self.hostname,
     373            0 :         )
     374            0 :         .map_err(|e| {
     375            0 :             CancelError::IO(std::io::Error::new(
     376            0 :                 std::io::ErrorKind::Other,
     377            0 :                 e.to_string(),
     378            0 :             ))
     379            0 :         })?;
     380              : 
     381            0 :         self.cancel_token.cancel_query_raw(socket, tls).await?;
     382            0 :         debug!("query was cancelled");
     383            0 :         Ok(())
     384            0 :     }
     385              : }
     386              : 
/// Helper for registering query cancellation tokens.
///
/// Created by `CancellationHandler::get_key`. It stores or removes this session's
/// cancellation state in Redis through the handler's channel.
pub(crate) struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// Redis key derived from `key` via `KeyPrefix::Cancel(..).build_redis_key()`.
    redis_key: String,
    /// Handler whose channel to the Redis KV task is used by this session.
    cancellation_handler: Arc<CancellationHandler>,
}
     394              : 
     395              : impl Session {
     396            0 :     pub(crate) fn key(&self) -> &CancelKeyData {
     397            0 :         &self.key
     398            0 :     }
     399              : 
     400              :     // Send the store key op to the cancellation handler
     401            0 :     pub(crate) async fn write_cancel_key(
     402            0 :         &self,
     403            0 :         cancel_closure: CancelClosure,
     404            0 :     ) -> Result<(), CancelError> {
     405            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     406            0 :             tracing::warn!("cancellation handler is not available");
     407            0 :             return Err(CancelError::InternalError);
     408              :         };
     409              : 
     410            0 :         let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
     411            0 :             tracing::warn!("failed to serialize cancel closure: {e}");
     412            0 :             CancelError::InternalError
     413            0 :         })?;
     414              : 
     415            0 :         let op = CancelKeyOp::StoreCancelKey {
     416            0 :             key: self.redis_key.clone(),
     417            0 :             field: "data".to_string(),
     418            0 :             value: closure_json,
     419            0 :             resp_tx: None,
     420            0 :             _guard: Metrics::get()
     421            0 :                 .proxy
     422            0 :                 .cancel_channel_size
     423            0 :                 .guard(RedisMsgKind::HSet),
     424            0 :             expire: CANCEL_KEY_TTL,
     425            0 :         };
     426            0 : 
     427            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     428            0 :             let key = self.key;
     429            0 :             tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
     430            0 :         });
     431            0 :         Ok(())
     432            0 :     }
     433              : 
     434            0 :     pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
     435            0 :         let Some(tx) = &self.cancellation_handler.tx else {
     436            0 :             tracing::warn!("cancellation handler is not available");
     437            0 :             return Err(CancelError::InternalError);
     438              :         };
     439              : 
     440            0 :         let op = CancelKeyOp::RemoveCancelKey {
     441            0 :             key: self.redis_key.clone(),
     442            0 :             field: "data".to_string(),
     443            0 :             resp_tx: None,
     444            0 :             _guard: Metrics::get()
     445            0 :                 .proxy
     446            0 :                 .cancel_channel_size
     447            0 :                 .guard(RedisMsgKind::HSet),
     448            0 :         };
     449            0 : 
     450            0 :         let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
     451            0 :             let key = self.key;
     452            0 :             tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
     453            0 :         });
     454            0 :         Ok(())
     455            0 :     }
     456              : }
        

Generated by: LCOV version 2.1-beta