LCOV - differential code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 79.8 % 84 67 17 67
Current Date: 2024-01-09 02:06:09 Functions: 50.0 % 38 19 19 19
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use anyhow::{bail, Context};
       2                 : use dashmap::DashMap;
       3                 : use pq_proto::CancelKeyData;
       4                 : use std::net::SocketAddr;
       5                 : use tokio::net::TcpStream;
       6                 : use tokio_postgres::{CancelToken, NoTls};
       7                 : use tracing::info;
       8                 : 
       9                 : /// Enables serving `CancelRequest`s.
      10 CBC          90 : #[derive(Default)]
      11                 : pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
      12                 : 
      13                 : impl CancelMap {
      14                 :     /// Cancel a running query for the corresponding connection.
      15 UBC           0 :     pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
      16                 :         // NB: we should immediately release the lock after cloning the token.
      17               0 :         let cancel_closure = self
      18               0 :             .0
      19               0 :             .get(&key)
      20               0 :             .and_then(|x| x.clone())
      21               0 :             .with_context(|| format!("query cancellation key not found: {key}"))?;
      22                 : 
      23               0 :         info!("cancelling query per user's request using key {key}");
      24               0 :         cancel_closure.try_cancel_query().await
      25               0 :     }
      26                 : 
      27                 :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      28 CBC          50 :     pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
      29              50 :     where
      30              50 :         F: FnOnce(Session<'a>) -> R,
      31              50 :         R: std::future::Future<Output = anyhow::Result<V>>,
      32              50 :     {
      33              50 :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      34              50 :         // expose it and we don't want to do another roundtrip to query
      35              50 :         // for it. The client will be able to notice that this is not the
      36              50 :         // actual backend_pid, but backend_pid is not used for anything
      37              50 :         // so it doesn't matter.
      38              50 :         let key = rand::random();
      39              50 : 
      40              50 :         // Random key collisions are unlikely to happen here, but they're still possible,
      41              50 :         // which is why we have to take care not to rewrite an existing key.
      42              50 :         match self.0.entry(key) {
      43                 :             dashmap::mapref::entry::Entry::Occupied(_) => {
      44 UBC           0 :                 bail!("query cancellation key already exists: {key}")
      45                 :             }
      46 CBC          50 :             dashmap::mapref::entry::Entry::Vacant(e) => {
      47              50 :                 e.insert(None);
      48              50 :             }
      49                 :         }
      50                 : 
      51                 :         // This will guarantee that the session gets dropped
      52                 :         // as soon as the future is finished.
      53              50 :         scopeguard::defer! {
      54              50 :             self.0.remove(&key);
      55              50 :             info!("dropped query cancellation key {key}");
      56                 :         }
      57                 : 
      58              49 :         info!("registered new query cancellation key {key}");
      59              50 :         let session = Session::new(key, self);
      60             877 :         f(session).await
      61              49 :     }
      62                 : 
      63                 :     #[cfg(test)]
      64               1 :     fn contains(&self, session: &Session) -> bool {
      65               1 :         self.0.contains_key(&session.key)
      66               1 :     }
      67                 : 
      68                 :     #[cfg(test)]
      69               1 :     fn is_empty(&self) -> bool {
      70               1 :         self.0.is_empty()
      71               1 :     }
      72                 : }
      73                 : 
      74                 : /// This should've been a [`std::future::Future`], but
      75                 : /// it's impossible to name a type of an unboxed future
      76                 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
      77              38 : #[derive(Clone)]
      78                 : pub struct CancelClosure {
      79                 :     socket_addr: SocketAddr,
      80                 :     cancel_token: CancelToken,
      81                 : }
      82                 : 
      83                 : impl CancelClosure {
      84              38 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
      85              38 :         Self {
      86              38 :             socket_addr,
      87              38 :             cancel_token,
      88              38 :         }
      89              38 :     }
      90                 : 
      91                 :     /// Cancels the query running on user's compute node.
      92 UBC           0 :     pub async fn try_cancel_query(self) -> anyhow::Result<()> {
      93               0 :         let socket = TcpStream::connect(self.socket_addr).await?;
      94               0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
      95                 : 
      96               0 :         Ok(())
      97               0 :     }
      98                 : }
      99                 : 
     100                 : /// Helper for registering query cancellation tokens.
     101                 : pub struct Session<'a> {
     102                 :     /// The user-facing key identifying this session.
     103                 :     key: CancelKeyData,
     104                 :     /// The [`CancelMap`] this session belongs to.
     105                 :     cancel_map: &'a CancelMap,
     106                 : }
     107                 : 
     108                 : impl<'a> Session<'a> {
     109 CBC          50 :     fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
     110              50 :         Self { key, cancel_map }
     111              50 :     }
     112                 : }
     113                 : 
     114                 : impl Session<'_> {
     115                 :     /// Store the cancel token for the given session.
     116                 :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     117              38 :     pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
     118              38 :         info!("enabling query cancellation for this session");
     119              38 :         self.cancel_map.0.insert(self.key, Some(cancel_closure));
     120              38 : 
     121              38 :         self.key
     122              38 :     }
     123                 : }
     124                 : 
     125                 : #[cfg(test)]
     126                 : mod tests {
     127                 :     use super::*;
     128                 :     use once_cell::sync::Lazy;
     129                 : 
     130               1 :     #[tokio::test]
     131               1 :     async fn check_session_drop() -> anyhow::Result<()> {
     132               1 :         static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
     133               1 : 
     134               1 :         let (tx, rx) = tokio::sync::oneshot::channel();
     135               1 :         let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
     136               1 :             assert!(CANCEL_MAP.contains(&session));
     137                 : 
     138               1 :             tx.send(()).expect("failed to send");
     139               1 :             futures::future::pending::<()>().await; // sleep forever
     140                 : 
     141 UBC           0 :             Ok(())
     142 CBC           1 :         }));
     143               1 : 
     144               1 :         // Wait until the task has been spawned.
     145               1 :         rx.await.context("failed to hear from the task")?;
     146                 : 
     147                 :         // Drop the session's entry by cancelling the task.
     148               1 :         task.abort();
     149               1 :         let error = task.await.expect_err("task should have failed");
     150               1 :         if !error.is_cancelled() {
     151 UBC           0 :             anyhow::bail!(error);
     152 CBC           1 :         }
     153                 : 
     154                 :         // Check that the session has been dropped.
     155               1 :         assert!(CANCEL_MAP.is_empty());
     156                 : 
     157               1 :         Ok(())
     158                 :     }
     159                 : }
        

Generated by: LCOV version 2.1-beta