Test: Lines: 72.7 % 77 56
Test Date: 2024-02-12 20:26:03 Functions: 50.0 % 26 13

            Line data    Source code
       1              : use dashmap::DashMap;
       2              : use pq_proto::CancelKeyData;
       3              : use std::{net::SocketAddr, sync::Arc};
       4              : use thiserror::Error;
       5              : use tokio::net::TcpStream;
       6              : use tokio_postgres::{CancelToken, NoTls};
       7              : use tracing::info;
       8              : 
       9              : use crate::error::ReportableError;
      10              : 
/// Enables serving `CancelRequest`s.
///
/// Maps every issued [`CancelKeyData`] to the closure able to cancel the
/// corresponding query. A `None` value means the key has been reserved by
/// [`CancelMap::get_session`] but no [`CancelClosure`] has been attached yet via
/// [`Session::enable_query_cancellation`].
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
      14              : 
/// Errors returned while forwarding a query cancellation request
/// (see [`CancelMap::cancel_session`]).
// NOTE(review): with `#[from]` the wrapped error is also exposed as `source()`,
// so reporters that print the whole error chain may show the same message twice;
// `#[error(transparent)]` would avoid that — confirm before changing.
#[derive(Debug, Error)]
pub enum CancelError {
    /// I/O failure, e.g. the TCP connect to the compute node failed.
    #[error("{0}")]
    IO(#[from] std::io::Error),
    /// Failure reported by `tokio_postgres` while sending the cancel request.
    #[error("{0}")]
    Postgres(#[from] tokio_postgres::Error),
}
      22              : 
      23              : impl ReportableError for CancelError {
      24            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      25            0 :         match self {
      26            0 :             CancelError::IO(_) => crate::error::ErrorKind::Compute,
      27            0 :             CancelError::Postgres(e) if e.as_db_error().is_some() => {
      28            0 :                 crate::error::ErrorKind::Postgres
      29              :             }
      30            0 :             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
      31              :         }
      32            0 :     }
      33              : }
      34              : 
      35              : impl CancelMap {
      36              :     /// Cancel a running query for the corresponding connection.
      37            0 :     pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
      38              :         // NB: we should immediately release the lock after cloning the token.
      39            0 :         let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
      40            0 :             tracing::warn!("query cancellation key not found: {key}");
      41            0 :             return Ok(());
      42              :         };
      43              : 
      44            0 :         info!("cancelling query per user's request using key {key}");
      45            0 :         cancel_closure.try_cancel_query().await
      46            0 :     }
      47              : 
      48              :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      49           43 :     pub fn get_session(self: Arc<Self>) -> Session {
      50              :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      51              :         // expose it and we don't want to do another roundtrip to query
      52              :         // for it. The client will be able to notice that this is not the
      53              :         // actual backend_pid, but backend_pid is not used for anything
      54              :         // so it doesn't matter.
      55           43 :         let key = loop {
      56           43 :             let key = rand::random();
      57           43 : 
      58           43 :             // Random key collisions are unlikely to happen here, but they're still possible,
      59           43 :             // which is why we have to take care not to rewrite an existing key.
      60           43 :             match self.0.entry(key) {
      61            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue,
      62           43 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
      63           43 :                     e.insert(None);
      64           43 :                 }
      65           43 :             }
      66           43 :             break key;
      67           43 :         };
      68           43 : 
      69           43 :         info!("registered new query cancellation key {key}");
      70           43 :         Session {
      71           43 :             key,
      72           43 :             cancel_map: self,
      73           43 :         }
      74           43 :     }
      75              : 
      76              :     #[cfg(test)]
      77            2 :     fn contains(&self, session: &Session) -> bool {
      78            2 :         self.0.contains_key(&session.key)
      79            2 :     }
      80              : 
      81              :     #[cfg(test)]
      82            2 :     fn is_empty(&self) -> bool {
      83            2 :         self.0.is_empty()
      84            2 :     }
      85              : }
      86              : 
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
///
/// Holds what is needed to send a cancel request to the user's compute node.
/// A copy is cloned out of [`CancelMap`] and consumed by `try_cancel_query`.
#[derive(Clone)]
pub struct CancelClosure {
    /// Address that `try_cancel_query` opens a new TCP connection to.
    socket_addr: SocketAddr,
    /// Token whose `cancel_query_raw` sends the cancel request over that connection.
    cancel_token: CancelToken,
}
      95              : 
      96              : impl CancelClosure {
      97           41 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
      98           41 :         Self {
      99           41 :             socket_addr,
     100           41 :             cancel_token,
     101           41 :         }
     102           41 :     }
     103              : 
     104              :     /// Cancels the query running on user's compute node.
     105            0 :     async fn try_cancel_query(self) -> Result<(), CancelError> {
     106            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
     107            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
     108              : 
     109            0 :         Ok(())
     110            0 :     }
     111              : }
     112              : 
/// Helper for registering query cancellation tokens.
///
/// Owns a single key reserved in a [`CancelMap`] by [`CancelMap::get_session`];
/// dropping the `Session` removes that key from the map.
pub struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// The [`CancelMap`] this session belongs to.
    cancel_map: Arc<CancelMap>,
}
     120              : 
     121              : impl Session {
     122              :     /// Store the cancel token for the given session.
     123              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     124           41 :     pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     125           41 :         info!("enabling query cancellation for this session");
     126           41 :         self.cancel_map.0.insert(self.key, Some(cancel_closure));
     127           41 : 
     128           41 :         self.key
     129           41 :     }
     130              : }
     131              : 
     132              : impl Drop for Session {
     133           43 :     fn drop(&mut self) {
     134           43 :         self.cancel_map.0.remove(&self.key);
     135           43 :         info!("dropped query cancellation key {}", &self.key);
     136           43 :     }
     137              : }
     138              : 
     139              : #[cfg(test)]
     140              : mod tests {
     141              :     use super::*;
     142              : 
     143            2 :     #[tokio::test]
     144            2 :     async fn check_session_drop() -> anyhow::Result<()> {
     145            2 :         let cancel_map: Arc<CancelMap> = Default::default();
     146            2 : 
     147            2 :         let session = cancel_map.clone().get_session();
     148            2 :         assert!(cancel_map.contains(&session));
     149            2 :         drop(session);
     150            2 :         // Check that the session has been dropped.
     151            2 :         assert!(cancel_map.is_empty());
     152            2 : 
     153            2 :         Ok(())
     154            2 :     }
     155              : }

