LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 09e7485004805bd42b53a0c369170b3228136512.info Lines: 64.3 % 129 83
Test Date: 2024-11-21 18:36:18 Functions: 32.4 % 34 11

            Line data    Source code
       1              : use std::net::SocketAddr;
       2              : use std::sync::Arc;
       3              : 
       4              : use dashmap::DashMap;
       5              : use pq_proto::CancelKeyData;
       6              : use thiserror::Error;
       7              : use tokio::net::TcpStream;
       8              : use tokio::sync::Mutex;
       9              : use tokio_postgres::{CancelToken, NoTls};
      10              : use tracing::{debug, info};
      11              : use uuid::Uuid;
      12              : 
      13              : use crate::error::ReportableError;
      14              : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
      15              : use crate::redis::cancellation_publisher::{
      16              :     CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
      17              : };
      18              : 
      19              : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
      20              : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
      21              : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
      22              : 
      23              : /// Enables serving `CancelRequest`s.
      24              : ///
      25              : /// If a `CancellationPublisher` is available, a cancel request whose key is not found locally is published to other proxy instances.
      26              : pub struct CancellationHandler<P> {
      27              :     map: CancelMap,
      28              :     client: P,
      29              :     /// This field is used for monitoring purposes.
      30              :     /// Represents the source of the cancellation request.
      31              :     from: CancellationSource,
      32              : }
      33              : 
      34            0 : #[derive(Debug, Error)]
      35              : pub(crate) enum CancelError {
      36              :     #[error("{0}")]
      37              :     IO(#[from] std::io::Error),
      38              :     #[error("{0}")]
      39              :     Postgres(#[from] tokio_postgres::Error),
      40              : }
      41              : 
      42              : impl ReportableError for CancelError {
      43            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      44            0 :         match self {
      45            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
      46            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
      47            0 :                 crate::error::ErrorKind::Postgres
      48              :             }
      49            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
      50              :         }
      51            0 :     }
      52              : }
      53              : 
      54              : impl<P: CancellationPublisher> CancellationHandler<P> {
      55              :     /// Register a fresh random [`CancelKeyData`] and return the ephemeral [`Session`] identified by it.
      56            1 :     pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
      57              :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      58              :         // expose it and we don't want to do another roundtrip to query
      59              :         // for it. The client will be able to notice that this is not the
      60              :         // actual backend_pid, but backend_pid is not used for anything
      61              :         // so it doesn't matter.
      62            1 :         let key = loop {
      63            1 :             let key = rand::random();
      64            1 : 
      65            1 :             // Random key collisions are unlikely to happen here, but they're still possible,
      66            1 :             // which is why we have to take care not to rewrite an existing key.
      67            1 :             match self.map.entry(key) {
      68            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue,
      69            1 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
      70            1 :                     e.insert(None);
      71            1 :                 }
      72            1 :             }
      73            1 :             break key;
      74            1 :         };
      75            1 : 
      76            1 :         debug!("registered new query cancellation key {key}");
      77            1 :         Session {
      78            1 :             key,
      79            1 :             cancellation_handler: self,
      80            1 :         }
      81            1 :     }
      82              :     /// Try to cancel a running query for the corresponding connection.
      83              :     /// If the cancellation key is not found, it will be published to Redis.
      84            1 :     pub(crate) async fn cancel_session(
      85            1 :         &self,
      86            1 :         key: CancelKeyData,
      87            1 :         session_id: Uuid,
      88            1 :     ) -> Result<(), CancelError> {
      89              :         // NB: we should immediately release the lock after cloning the token.
      90            1 :         let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
      91            1 :             tracing::warn!("query cancellation key not found: {key}");
      92            1 :             Metrics::get()
      93            1 :                 .proxy
      94            1 :                 .cancellation_requests_total
      95            1 :                 .inc(CancellationRequest {
      96            1 :                     source: self.from,
      97            1 :                     kind: crate::metrics::CancellationOutcome::NotFound,
      98            1 :                 });
      99            1 :             match self.client.try_publish(key, session_id).await {
     100            1 :                 Ok(()) => {} // do nothing
     101            0 :                 Err(e) => {
     102            0 :                     return Err(CancelError::IO(std::io::Error::new(
     103            0 :                         std::io::ErrorKind::Other,
     104            0 :                         e.to_string(),
     105            0 :                     )));
     106              :                 }
     107              :             }
     108            1 :             return Ok(());
     109              :         };
     110            0 :         Metrics::get()
     111            0 :             .proxy
     112            0 :             .cancellation_requests_total
     113            0 :             .inc(CancellationRequest {
     114            0 :                 source: self.from,
     115            0 :                 kind: crate::metrics::CancellationOutcome::Found,
     116            0 :             });
     117            0 :         info!("cancelling query per user's request using key {key}");
     118            0 :         cancel_closure.try_cancel_query().await
     119            1 :     }
     120              : 
     121              :     #[cfg(test)]
     122            1 :     fn contains(&self, session: &Session<P>) -> bool {
     123            1 :         self.map.contains_key(&session.key)
     124            1 :     }
     125              : 
     126              :     #[cfg(test)]
     127            1 :     fn is_empty(&self) -> bool {
     128            1 :         self.map.is_empty()
     129            1 :     }
     130              : }
     131              : 
     132              : impl CancellationHandler<()> {
     133            2 :     pub fn new(map: CancelMap, from: CancellationSource) -> Self {
     134            2 :         Self {
     135            2 :             map,
     136            2 :             client: (),
     137            2 :             from,
     138            2 :         }
     139            2 :     }
     140              : }
     141              : 
     142              : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
     143            0 :     pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
     144            0 :         Self { map, client, from }
     145            0 :     }
     146              : }
     147              : 
     148              : /// This should've been a [`std::future::Future`], but
     149              : /// it's impossible to name a type of an unboxed future
     150              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     151              : #[derive(Clone)]
     152              : pub struct CancelClosure {
     153              :     socket_addr: SocketAddr,
     154              :     cancel_token: CancelToken,
     155              : }
     156              : 
     157              : impl CancelClosure {
     158            0 :     pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
     159            0 :         Self {
     160            0 :             socket_addr,
     161            0 :             cancel_token,
     162            0 :         }
     163            0 :     }
     164              :     /// Cancels the query running on user's compute node.
     165            0 :     pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
     166            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     167            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
     168            0 :         debug!("query was cancelled");
     169            0 :         Ok(())
     170            0 :     }
     171              : }
     172              : 
     173              : /// Helper for registering query cancellation tokens.
     174              : pub(crate) struct Session<P> {
     175              :     /// The user-facing key identifying this session.
     176              :     key: CancelKeyData,
     177              :     /// The [`CancellationHandler`] whose [`CancelMap`] this session's key is registered in.
     178              :     cancellation_handler: Arc<CancellationHandler<P>>,
     179              : }
     180              : 
     181              : impl<P> Session<P> {
     182              :     /// Store the cancel token for the given session.
     183              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     184            0 :     pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     185            0 :         debug!("enabling query cancellation for this session");
     186            0 :         self.cancellation_handler
     187            0 :             .map
     188            0 :             .insert(self.key, Some(cancel_closure));
     189            0 : 
     190            0 :         self.key
     191            0 :     }
     192              : }
     193              : 
     194              : impl<P> Drop for Session<P> {
     195            1 :     fn drop(&mut self) {
     196            1 :         self.cancellation_handler.map.remove(&self.key);
     197            1 :         debug!("dropped query cancellation key {}", &self.key);
     198            1 :     }
     199              : }
     200              : 
     201              : #[cfg(test)]
     202              : mod tests {
     203              :     use super::*;
     204              : 
     205              :     #[tokio::test]
     206            1 :     async fn check_session_drop() -> anyhow::Result<()> {
     207            1 :         let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
     208            1 :             CancelMap::default(),
     209            1 :             CancellationSource::FromRedis,
     210            1 :         ));
     211            1 : 
     212            1 :         let session = cancellation_handler.clone().get_session();
     213            1 :         assert!(cancellation_handler.contains(&session));
     214            1 :         drop(session);
     215            1 :         // Check that the session has been dropped.
     216            1 :         assert!(cancellation_handler.is_empty());
     217            1 : 
     218            1 :         Ok(())
     219            1 :     }
     220              : 
     221              :     #[tokio::test]
     222            1 :     async fn cancel_session_noop_regression() {
     223            1 :         let handler =
     224            1 :             CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
     225            1 :         handler
     226            1 :             .cancel_session(
     227            1 :                 CancelKeyData {
     228            1 :                     backend_pid: 0,
     229            1 :                     cancel_key: 0,
     230            1 :                 },
     231            1 :                 Uuid::new_v4(),
     232            1 :             )
     233            1 :             .await
     234            1 :             .unwrap();
     235            1 :     }
     236              : }
        

Generated by: LCOV version 2.1-beta