LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 20b6afc7b7f34578dcaab2b3acdaecfe91cd8bf1.info Lines: 55.4 % 186 103
Test Date: 2024-11-25 17:48:16 Functions: 31.4 % 35 11

            Line data    Source code
       1              : use std::net::SocketAddr;
       2              : use std::sync::Arc;
       3              : 
       4              : use dashmap::DashMap;
       5              : use pq_proto::CancelKeyData;
       6              : use thiserror::Error;
       7              : use tokio::net::TcpStream;
       8              : use tokio::sync::Mutex;
       9              : use tokio_postgres::{CancelToken, NoTls};
      10              : use tracing::{debug, info};
      11              : use uuid::Uuid;
      12              : 
      13              : use crate::auth::{check_peer_addr_is_in_list, IpPattern};
      14              : use crate::error::ReportableError;
      15              : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
      16              : use crate::rate_limiter::LeakyBucketRateLimiter;
      17              : use crate::redis::cancellation_publisher::{
      18              :     CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
      19              : };
      20              : use std::net::IpAddr;
      21              : 
      22              : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
      23              : 
      24              : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>; // value stays `None` until the session enables cancellation
      25              : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
      26              : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
      27              : 
      28              : type IpSubnetKey = IpNet; // key type of the per-subnet cancellation rate limiter
      29              : 
      30              : /// Enables serving `CancelRequest`s.
      31              : ///
      32              : /// A cancel request whose key is not in the local map is passed on to the `CancellationPublisher` client via `try_publish`.
      33              : pub struct CancellationHandler<P> {
      34              :     map: CancelMap,
      35              :     client: P,
      36              :     /// This field is used for monitoring purposes.
      37              :     /// Represents the source of the cancellation request.
      38              :     from: CancellationSource,
      39              :     // Rate limiter for cancellation requests, keyed by the peer's subnet.
      40              :     limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
      41              : }
      42              : 
      43            0 : #[derive(Debug, Error)]
      44              : pub(crate) enum CancelError { // ways a cancel request can fail
      45              :     #[error("{0}")]
      46              :     IO(#[from] std::io::Error), // also wraps publish failures in `cancel_session`
      47              : 
      48              :     #[error("{0}")]
      49              :     Postgres(#[from] tokio_postgres::Error),
      50              : 
      51              :     #[error("rate limit exceeded")]
      52              :     RateLimit, // the limiter refused the peer's subnet
      53              : 
      54              :     #[error("IP is not allowed")]
      55              :     IpNotAllowed, // the peer is not in the closure's IP allowlist
      56              : }
      57              : 
      58              : impl ReportableError for CancelError { // maps each cancel failure to an `ErrorKind`
      59            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      60            0 :         match self {
      61            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
      62            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => { // carries a database error
      63            0 :                 crate::error::ErrorKind::Postgres
      64              :             }
      65            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute, // any other tokio_postgres error
      66            0 :             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
      67            0 :             CancelError::IpNotAllowed => crate::error::ErrorKind::User,
      68              :         }
      69            0 :     }
      70              : }
      71              : 
      72              : impl<P: CancellationPublisher> CancellationHandler<P> {
      73              :     /// Register a fresh random [`CancelKeyData`] in the map and return the [`Session`] that owns it.
      74            1 :     pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
      75              :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      76              :         // expose it and we don't want to do another roundtrip to query
      77              :         // for it. The client will be able to notice that this is not the
      78              :         // actual backend_pid, but backend_pid is not used for anything
      79              :         // so it doesn't matter.
      80            1 :         let key = loop {
      81            1 :             let key = rand::random();
      82            1 : 
      83            1 :             // Random key collisions are unlikely to happen here, but they're still possible,
      84            1 :             // which is why we have to take care not to rewrite an existing key.
      85            1 :             match self.map.entry(key) {
      86            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue, // collision: draw another key
      87            1 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
      88            1 :                     e.insert(None); // reserve the key; no cancel closure is stored yet
      89            1 :                 }
      90            1 :             }
      91            1 :             break key;
      92            1 :         };
      93            1 : 
      94            1 :         debug!("registered new query cancellation key {key}");
      95            1 :         Session {
      96            1 :             key,
      97            1 :             cancellation_handler: self,
      98            1 :         }
      99            1 :     }
     100              : 
     101              :     /// Try to cancel a running query for the corresponding connection.
     102              :     /// If the key is not in the local map, it is published through the `CancellationPublisher` client, unless `session_id` is nil.
     103              :     /// `check_allowed`: if true, reject peers that are not in the closure's IP allowlist.
     104            1 :     pub(crate) async fn cancel_session(
     105            1 :         &self,
     106            1 :         key: CancelKeyData,
     107            1 :         session_id: Uuid,
     108            1 :         peer_addr: &IpAddr,
     109            1 :         check_allowed: bool,
     110            1 :     ) -> Result<(), CancelError> {
     111            1 :         // TODO: skipping the rate limiter for unspecified addresses is only for backward compatibility, should be removed
     112            1 :         if !peer_addr.is_unspecified() {
     113            1 :             let subnet_key = match *peer_addr {
     114            1 :                 IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // IPv4 peers are rate-limited per /24 subnet
     115            0 :                 IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), // IPv6 peers per /64 subnet
     116              :             };
     117            1 :             if !self.limiter.lock().unwrap().check(subnet_key, 1) { // NOTE(review): panics if the limiter mutex is poisoned
     118            0 :                 tracing::debug!("Rate limit exceeded. Skipping cancellation message");
     119            0 :                 Metrics::get()
     120            0 :                     .proxy
     121            0 :                     .cancellation_requests_total
     122            0 :                     .inc(CancellationRequest {
     123            0 :                         source: self.from,
     124            0 :                         kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
     125            0 :                     });
     126            0 :                 return Err(CancelError::RateLimit);
     127            1 :             }
     128            0 :         }
     129              : 
     130              :         // NB: clone the closure out of the map entry so that the map lock can be released right away.
     131            1 :         let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
     132            1 :             tracing::warn!("query cancellation key not found: {key}");
     133            1 :             Metrics::get()
     134            1 :                 .proxy
     135            1 :                 .cancellation_requests_total
     136            1 :                 .inc(CancellationRequest {
     137            1 :                     source: self.from,
     138            1 :                     kind: crate::metrics::CancellationOutcome::NotFound,
     139            1 :                 });
     140            1 : 
     141            1 :             if session_id == Uuid::nil() {
     142              :                 // nil session id: the key was already published, do not publish it again
     143            0 :                 return Ok(());
     144            1 :             }
     145            1 : 
     146            1 :             match self.client.try_publish(key, session_id, *peer_addr).await {
     147            1 :                 Ok(()) => {} // published; nothing more to do
     148            0 :                 Err(e) => {
     149            0 :                     return Err(CancelError::IO(std::io::Error::new(
     150            0 :                         std::io::ErrorKind::Other,
     151            0 :                         e.to_string(),
     152            0 :                     )));
     153              :                 }
     154              :             }
     155            1 :             return Ok(());
     156              :         };
     157              : 
     158            0 :         if check_allowed
     159            0 :             && !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
     160              :         {
     161            0 :             return Err(CancelError::IpNotAllowed); // NOTE(review): no metric or log is recorded for this outcome, unlike the others
     162            0 :         }
     163            0 : 
     164            0 :         Metrics::get()
     165            0 :             .proxy
     166            0 :             .cancellation_requests_total
     167            0 :             .inc(CancellationRequest {
     168            0 :                 source: self.from,
     169            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     170            0 :             });
     171            0 :         info!("cancelling query per user's request using key {key}");
     172            0 :         cancel_closure.try_cancel_query().await
     173            1 :     }
     174              : 
     175              :     #[cfg(test)]
     176            1 :     fn contains(&self, session: &Session<P>) -> bool { // test helper: is the session's key registered?
     177            1 :         self.map.contains_key(&session.key)
     178            1 :     }
     179              : 
     180              :     #[cfg(test)]
     181            1 :     fn is_empty(&self) -> bool { // test helper: no keys registered at all
     182            1 :         self.map.is_empty()
     183            1 :     }
     184              : }
     185              : 
     186              : impl CancellationHandler<()> { // handler whose client is the unit type
     187            2 :     pub fn new(map: CancelMap, from: CancellationSource) -> Self {
     188            2 :         Self {
     189            2 :             map,
     190            2 :             client: (),
     191            2 :             from,
     192            2 :             limiter: Arc::new(std::sync::Mutex::new(
     193            2 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     194            2 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     195            2 :                     64, // second argument of `new_with_shards`; presumably the shard count, confirm
     196            2 :                 ),
     197            2 :             )),
     198            2 :         }
     199            2 :     }
     200              : }
     201              : 
     202              : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> { // handler with an optional shared publisher client
     203            0 :     pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
     204            0 :         Self {
     205            0 :             map,
     206            0 :             client,
     207            0 :             from,
     208            0 :             limiter: Arc::new(std::sync::Mutex::new(
     209            0 :                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
     210            0 :                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
     211            0 :                     64, // second argument of `new_with_shards`; presumably the shard count, confirm
     212            0 :                 ),
     213            0 :             )),
     214            0 :         }
     215            0 :     }
     216              : }
     217              : 
     218              : /// This should've been a [`std::future::Future`], but
     219              : /// it's impossible to name a type of an unboxed future
     220              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     221              : #[derive(Clone)]
     222              : pub struct CancelClosure {
     223              :     socket_addr: SocketAddr, // address `try_cancel_query` connects to
     224              :     cancel_token: CancelToken, // token used to send the cancel request over that connection
     225              :     ip_allowlist: Vec<IpPattern>, // checked by `cancel_session` when `check_allowed` is true
     226              : }
     227              : 
     228              : impl CancelClosure {
     229            0 :     pub(crate) fn new(
     230            0 :         socket_addr: SocketAddr,
     231            0 :         cancel_token: CancelToken,
     232            0 :         ip_allowlist: Vec<IpPattern>,
     233            0 :     ) -> Self {
     234            0 :         Self {
     235            0 :             socket_addr,
     236            0 :             cancel_token,
     237            0 :             ip_allowlist,
     238            0 :         }
     239            0 :     }
     240              :     /// Cancels the query running on user's compute node.
     241            0 :     pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
     242            0 :         let socket = TcpStream::connect(self.socket_addr).await?; // a fresh TCP connection per cancel attempt
     243            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?; // sent without TLS (`NoTls`)
     244            0 :         debug!("query was cancelled");
     245            0 :         Ok(())
     246            0 :     }
     247            0 :     pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) { // replaces the whole allowlist
     248            0 :         self.ip_allowlist = ip_allowlist;
     249            0 :     }
     250              : }
     251              : 
     252              : /// Helper for registering query cancellation tokens.
     253              : pub(crate) struct Session<P> {
     254              :     /// The user-facing key identifying this session.
     255              :     key: CancelKeyData,
     256              :     /// The handler whose [`CancelMap`] holds this session's key.
     257              :     cancellation_handler: Arc<CancellationHandler<P>>,
     258              : }
     259              : 
     260              : impl<P> Session<P> {
     261              :     /// Store the cancel closure under this session's key and return the key.
     262              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     263            0 :     pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     264            0 :         debug!("enabling query cancellation for this session");
     265            0 :         self.cancellation_handler
     266            0 :             .map
     267            0 :             .insert(self.key, Some(cancel_closure)); // overwrites the `None` placeholder set by `get_session`
     268            0 : 
     269            0 :         self.key
     270            0 :     }
     271              : }
     272              : 
     273              : impl<P> Drop for Session<P> {
     274            1 :     fn drop(&mut self) {
     275            1 :         self.cancellation_handler.map.remove(&self.key); // unregister this session's key from the map
     276            1 :         debug!("dropped query cancellation key {}", &self.key);
     277            1 :     }
     278              : }
     279              : 
     280              : #[cfg(test)]
     281              : mod tests {
     282              :     use super::*;
     283              : 
     284              :     #[tokio::test]
     285            1 :     async fn check_session_drop() -> anyhow::Result<()> {
     286            1 :         let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
     287            1 :             CancelMap::default(),
     288            1 :             CancellationSource::FromRedis,
     289            1 :         ));
     290            1 : 
     291            1 :         let session = cancellation_handler.clone().get_session();
     292            1 :         assert!(cancellation_handler.contains(&session));
     293            1 :         drop(session);
     294            1 :         // Dropping the session must remove its key from the map.
     295            1 :         assert!(cancellation_handler.is_empty());
     296            1 : 
     297            1 :         Ok(())
     298            1 :     }
     299              : 
     300              :     #[tokio::test]
     301            1 :     async fn cancel_session_noop_regression() { // an unknown key with the `()` client must still return Ok
     302            1 :         let handler =
     303            1 :             CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
     304            1 :         handler
     305            1 :             .cancel_session(
     306            1 :                 CancelKeyData {
     307            1 :                     backend_pid: 0,
     308            1 :                     cancel_key: 0,
     309            1 :                 },
     310            1 :                 Uuid::new_v4(),
     311            1 :                 &("127.0.0.1".parse().unwrap()),
     312            1 :                 true,
     313            1 :             )
     314            1 :             .await
     315            1 :             .unwrap();
     316            1 :     }
     317              : }
        

Generated by: LCOV version 2.1-beta