LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 52.1 % 121 63
Test Date: 2024-02-14 18:05:35 Functions: 39.4 % 33 13

            Line data    Source code
       1              : use async_trait::async_trait;
       2              : use dashmap::DashMap;
       3              : use pq_proto::CancelKeyData;
       4              : use std::{net::SocketAddr, sync::Arc};
       5              : use thiserror::Error;
       6              : use tokio::net::TcpStream;
       7              : use tokio::sync::Mutex;
       8              : use tokio_postgres::{CancelToken, NoTls};
       9              : use tracing::info;
      10              : use uuid::Uuid;
      11              : 
      12              : use crate::{
      13              :     error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
      14              :     redis::publisher::RedisPublisherClient,
      15              : };
      16              : 
      17              : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
      18              : 
      19              : /// Enables serving `CancelRequest`s.
      20              : ///
      21              : /// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
      22              : pub struct CancellationHandler {
      23              :     map: CancelMap,
      24              :     redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
      25              : }
      26              : 
      27            0 : #[derive(Debug, Error)]
      28              : pub enum CancelError {
      29              :     #[error("{0}")]
      30              :     IO(#[from] std::io::Error),
      31              :     #[error("{0}")]
      32              :     Postgres(#[from] tokio_postgres::Error),
      33              : }
      34              : 
      35              : impl ReportableError for CancelError {
      36            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      37            0 :         match self {
      38            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
      39            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
      40            0 :                 crate::error::ErrorKind::Postgres
      41              :             }
      42            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
      43              :         }
      44            0 :     }
      45              : }
      46              : 
      47              : impl CancellationHandler {
      48           25 :     pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
      49           25 :         Self { map, redis_client }
      50           25 :     }
      51              :     /// Cancel a running query for the corresponding connection.
      52            0 :     pub async fn cancel_session(
      53            0 :         &self,
      54            0 :         key: CancelKeyData,
      55            0 :         session_id: Uuid,
      56            0 :     ) -> Result<(), CancelError> {
      57            0 :         let from = "from_client";
      58              :         // NB: we should immediately release the lock after cloning the token.
      59            0 :         let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
      60            0 :             tracing::warn!("query cancellation key not found: {key}");
      61            0 :             if let Some(redis_client) = &self.redis_client {
      62            0 :                 NUM_CANCELLATION_REQUESTS
      63            0 :                     .with_label_values(&[from, "not_found"])
      64            0 :                     .inc();
      65            0 :                 info!("publishing cancellation key to Redis");
      66            0 :                 match redis_client.lock().await.try_publish(key, session_id).await {
      67              :                     Ok(()) => {
      68            0 :                         info!("cancellation key successfuly published to Redis");
      69              :                     }
      70            0 :                     Err(e) => {
      71            0 :                         tracing::error!("failed to publish a message: {e}");
      72            0 :                         return Err(CancelError::IO(std::io::Error::new(
      73            0 :                             std::io::ErrorKind::Other,
      74            0 :                             e.to_string(),
      75            0 :                         )));
      76              :                     }
      77              :                 }
      78            0 :             }
      79            0 :             return Ok(());
      80              :         };
      81            0 :         NUM_CANCELLATION_REQUESTS
      82            0 :             .with_label_values(&[from, "found"])
      83            0 :             .inc();
      84            0 :         info!("cancelling query per user's request using key {key}");
      85            0 :         cancel_closure.try_cancel_query().await
      86            0 :     }
      87              : 
      88              :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      89           43 :     pub fn get_session(self: Arc<Self>) -> Session {
      90              :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      91              :         // expose it and we don't want to do another roundtrip to query
      92              :         // for it. The client will be able to notice that this is not the
      93              :         // actual backend_pid, but backend_pid is not used for anything
      94              :         // so it doesn't matter.
      95           43 :         let key = loop {
      96           43 :             let key = rand::random();
      97           43 : 
      98           43 :             // Random key collisions are unlikely to happen here, but they're still possible,
      99           43 :             // which is why we have to take care not to rewrite an existing key.
     100           43 :             match self.map.entry(key) {
     101            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue,
     102           43 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
     103           43 :                     e.insert(None);
     104           43 :                 }
     105           43 :             }
     106           43 :             break key;
     107           43 :         };
     108           43 : 
     109           43 :         info!("registered new query cancellation key {key}");
     110           43 :         Session {
     111           43 :             key,
     112           43 :             cancellation_handler: self,
     113           43 :         }
     114           43 :     }
     115              : 
     116              :     #[cfg(test)]
     117            2 :     fn contains(&self, session: &Session) -> bool {
     118            2 :         self.map.contains_key(&session.key)
     119            2 :     }
     120              : 
     121              :     #[cfg(test)]
     122            2 :     fn is_empty(&self) -> bool {
     123            2 :         self.map.is_empty()
     124            2 :     }
     125              : }
     126              : 
     127              : #[async_trait]
     128              : pub trait NotificationsCancellationHandler {
     129              :     async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
     130              : }
     131              : 
     132              : #[async_trait]
     133              : impl NotificationsCancellationHandler for CancellationHandler {
     134            0 :     async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
     135            0 :         let from = "from_redis";
     136            0 :         let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
     137            0 :         match cancel_closure {
     138            0 :             Some(cancel_closure) => {
     139            0 :                 NUM_CANCELLATION_REQUESTS
     140            0 :                     .with_label_values(&[from, "found"])
     141            0 :                     .inc();
     142            0 :                 cancel_closure.try_cancel_query().await
     143              :             }
     144              :             None => {
     145            0 :                 NUM_CANCELLATION_REQUESTS
     146            0 :                     .with_label_values(&[from, "not_found"])
     147            0 :                     .inc();
     148            0 :                 tracing::warn!("query cancellation key not found: {key}");
     149            0 :                 Ok(())
     150              :             }
     151              :         }
     152            0 :     }
     153              : }
     154              : 
     155              : /// This should've been a [`std::future::Future`], but
     156              : /// it's impossible to name a type of an unboxed future
     157              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
     158           41 : #[derive(Clone)]
     159              : pub struct CancelClosure {
     160              :     socket_addr: SocketAddr,
     161              :     cancel_token: CancelToken,
     162              : }
     163              : 
     164              : impl CancelClosure {
     165           41 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
     166           41 :         Self {
     167           41 :             socket_addr,
     168           41 :             cancel_token,
     169           41 :         }
     170           41 :     }
     171              : 
     172              :     /// Cancels the query running on user's compute node.
     173            0 :     async fn try_cancel_query(self) -> Result<(), CancelError> {
     174            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     175            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
     176              : 
     177            0 :         Ok(())
     178            0 :     }
     179              : }
     180              : 
     181              : /// Helper for registering query cancellation tokens.
     182              : pub struct Session {
     183              :     /// The user-facing key identifying this session.
     184              :     key: CancelKeyData,
     185              :     /// The [`CancelMap`] this session belongs to.
     186              :     cancellation_handler: Arc<CancellationHandler>,
     187              : }
     188              : 
     189              : impl Session {
     190              :     /// Store the cancel token for the given session.
     191              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     192           41 :     pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     193           41 :         info!("enabling query cancellation for this session");
     194           41 :         self.cancellation_handler
     195           41 :             .map
     196           41 :             .insert(self.key, Some(cancel_closure));
     197           41 : 
     198           41 :         self.key
     199           41 :     }
     200              : }
     201              : 
     202              : impl Drop for Session {
     203           43 :     fn drop(&mut self) {
     204           43 :         self.cancellation_handler.map.remove(&self.key);
     205           43 :         info!("dropped query cancellation key {}", &self.key);
     206           43 :     }
     207              : }
     208              : 
     209              : #[cfg(test)]
     210              : mod tests {
     211              :     use super::*;
     212              : 
     213            2 :     #[tokio::test]
     214            2 :     async fn check_session_drop() -> anyhow::Result<()> {
     215            2 :         let cancellation_handler = Arc::new(CancellationHandler {
     216            2 :             map: CancelMap::default(),
     217            2 :             redis_client: None,
     218            2 :         });
     219            2 : 
     220            2 :         let session = cancellation_handler.clone().get_session();
     221            2 :         assert!(cancellation_handler.contains(&session));
     222            2 :         drop(session);
     223            2 :         // Check that the session has been dropped.
     224            2 :         assert!(cancellation_handler.is_empty());
     225            2 : 
     226            2 :         Ok(())
     227            2 :     }
     228              : }
        

Generated by: LCOV version 2.1-beta