LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 77.9 % 68 53
Test Date: 2024-02-07 07:37:29 Functions: 65.0 % 20 13

            Line data    Source code
       1              : use anyhow::Context;
       2              : use dashmap::DashMap;
       3              : use pq_proto::CancelKeyData;
       4              : use std::{net::SocketAddr, sync::Arc};
       5              : use tokio::net::TcpStream;
       6              : use tokio_postgres::{CancelToken, NoTls};
       7              : use tracing::info;
       8              : 
       9              : /// Enables serving `CancelRequest`s.
      10          115 : #[derive(Default)]
      11              : pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
      12              : 
      13              : impl CancelMap {
      14              :     /// Cancel a running query for the corresponding connection.
      15            0 :     pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
      16              :         // NB: we should immediately release the lock after cloning the token.
      17            0 :         let cancel_closure = self
      18            0 :             .0
      19            0 :             .get(&key)
      20            0 :             .and_then(|x| x.clone())
      21            0 :             .with_context(|| format!("query cancellation key not found: {key}"))?;
      22              : 
      23            0 :         info!("cancelling query per user's request using key {key}");
      24            0 :         cancel_closure.try_cancel_query().await
      25            0 :     }
      26              : 
      27              :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      28           41 :     pub fn get_session(self: Arc<Self>) -> Session {
      29              :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      30              :         // expose it and we don't want to do another roundtrip to query
      31              :         // for it. The client will be able to notice that this is not the
      32              :         // actual backend_pid, but backend_pid is not used for anything
      33              :         // so it doesn't matter.
      34           41 :         let key = loop {
      35           41 :             let key = rand::random();
      36           41 : 
      37           41 :             // Random key collisions are unlikely to happen here, but they're still possible,
      38           41 :             // which is why we have to take care not to rewrite an existing key.
      39           41 :             match self.0.entry(key) {
      40            0 :                 dashmap::mapref::entry::Entry::Occupied(_) => continue,
      41           41 :                 dashmap::mapref::entry::Entry::Vacant(e) => {
      42           41 :                     e.insert(None);
      43           41 :                 }
      44           41 :             }
      45           41 :             break key;
      46           41 :         };
      47           41 : 
      48           41 :         info!("registered new query cancellation key {key}");
      49           41 :         Session {
      50           41 :             key,
      51           41 :             cancel_map: self,
      52           41 :         }
      53           41 :     }
      54              : 
      55              :     #[cfg(test)]
      56            2 :     fn contains(&self, session: &Session) -> bool {
      57            2 :         self.0.contains_key(&session.key)
      58            2 :     }
      59              : 
      60              :     #[cfg(test)]
      61            2 :     fn is_empty(&self) -> bool {
      62            2 :         self.0.is_empty()
      63            2 :     }
      64              : }
      65              : 
      66              : /// This should've been a [`std::future::Future`], but
      67              : /// it's impossible to name a type of an unboxed future
      68              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
      69           39 : #[derive(Clone)]
      70              : pub struct CancelClosure {
      71              :     socket_addr: SocketAddr,
      72              :     cancel_token: CancelToken,
      73              : }
      74              : 
      75              : impl CancelClosure {
      76           39 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
      77           39 :         Self {
      78           39 :             socket_addr,
      79           39 :             cancel_token,
      80           39 :         }
      81           39 :     }
      82              : 
      83              :     /// Cancels the query running on user's compute node.
      84            0 :     pub async fn try_cancel_query(self) -> anyhow::Result<()> {
      85            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
      86            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
      87              : 
      88            0 :         Ok(())
      89            0 :     }
      90              : }
      91              : 
      92              : /// Helper for registering query cancellation tokens.
      93              : pub struct Session {
      94              :     /// The user-facing key identifying this session.
      95              :     key: CancelKeyData,
      96              :     /// The [`CancelMap`] this session belongs to.
      97              :     cancel_map: Arc<CancelMap>,
      98              : }
      99              : 
     100              : impl Session {
     101              :     /// Store the cancel token for the given session.
     102              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     103           39 :     pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
     104           39 :         info!("enabling query cancellation for this session");
     105           39 :         self.cancel_map.0.insert(self.key, Some(cancel_closure));
     106           39 : 
     107           39 :         self.key
     108           39 :     }
     109              : }
     110              : 
     111              : impl Drop for Session {
     112           41 :     fn drop(&mut self) {
     113           41 :         self.cancel_map.0.remove(&self.key);
     114           41 :         info!("dropped query cancellation key {}", &self.key);
     115           41 :     }
     116              : }
     117              : 
     118              : #[cfg(test)]
     119              : mod tests {
     120              :     use super::*;
     121              : 
     122            2 :     #[tokio::test]
     123            2 :     async fn check_session_drop() -> anyhow::Result<()> {
     124            2 :         let cancel_map: Arc<CancelMap> = Default::default();
     125            2 : 
     126            2 :         let session = cancel_map.clone().get_session();
     127            2 :         assert!(cancel_map.contains(&session));
     128            2 :         drop(session);
     129              :         // Check that the session has been dropped.
     130            2 :         assert!(cancel_map.is_empty());
     131              : 
     132            2 :         Ok(())
     133              :     }
     134              : }
        

Generated by: LCOV version 2.1-beta