LCOV - differential code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 80.9 % 89 72 17 72
Current Date: 2023-10-19 02:04:12 Functions: 45.2 % 42 19 23 19
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use anyhow::{anyhow, Context};
       2                 : use hashbrown::HashMap;
       3                 : use pq_proto::CancelKeyData;
       4                 : use std::net::SocketAddr;
       5                 : use tokio::net::TcpStream;
       6                 : use tokio_postgres::{CancelToken, NoTls};
       7                 : use tracing::info;
       8                 : 
       9                 : /// Enables serving `CancelRequest`s.
      10 CBC          54 : #[derive(Default)]
      11                 : pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
      12                 : 
      13                 : impl CancelMap {
      14                 :     /// Cancel a running query for the corresponding connection.
      15 UBC           0 :     pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
      16                 :         // NB: we should immediately release the lock after cloning the token.
      17               0 :         let cancel_closure = self
      18               0 :             .0
      19               0 :             .read()
      20               0 :             .get(&key)
      21               0 :             .and_then(|x| x.clone())
      22               0 :             .with_context(|| format!("query cancellation key not found: {key}"))?;
      23                 : 
      24               0 :         info!("cancelling query per user's request using key {key}");
      25               0 :         cancel_closure.try_cancel_query().await
      26               0 :     }
      27                 : 
      28                 :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      29 CBC          34 :     pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
      30              34 :     where
      31              34 :         F: FnOnce(Session<'a>) -> R,
      32              34 :         R: std::future::Future<Output = anyhow::Result<V>>,
      33              34 :     {
      34              34 :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      35              34 :         // expose it and we don't want to do another roundtrip to query
      36              34 :         // for it. The client will be able to notice that this is not the
      37              34 :         // actual backend_pid, but backend_pid is not used for anything
      38              34 :         // so it doesn't matter.
      39              34 :         let key = rand::random();
      40              34 : 
      41              34 :         // Random key collisions are unlikely to happen here, but they're still possible,
      42              34 :         // which is why we have to take care not to rewrite an existing key.
      43              34 :         self.0
      44              34 :             .write()
      45              34 :             .try_insert(key, None)
      46              34 :             .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
      47                 : 
      48                 :         // This will guarantee that the session gets dropped
      49                 :         // as soon as the future is finished.
      50              34 :         scopeguard::defer! {
      51              34 :             self.0.write().remove(&key);
      52              34 :             info!("dropped query cancellation key {key}");
      53                 :         }
      54                 : 
      55              33 :         info!("registered new query cancellation key {key}");
      56              34 :         let session = Session::new(key, self);
      57             397 :         f(session).await
      58              33 :     }
      59                 : 
      60                 :     #[cfg(test)]
      61               1 :     fn contains(&self, session: &Session) -> bool {
      62               1 :         self.0.read().contains_key(&session.key)
      63               1 :     }
      64                 : 
      65                 :     #[cfg(test)]
      66               1 :     fn is_empty(&self) -> bool {
      67               1 :         self.0.read().is_empty()
      68               1 :     }
      69                 : }
      70                 : 
      71                 : /// This should've been a [`std::future::Future`], but
      72                 : /// it's impossible to name a type of an unboxed future
      73                 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
      74              29 : #[derive(Clone)]
      75                 : pub struct CancelClosure {
      76                 :     socket_addr: SocketAddr,
      77                 :     cancel_token: CancelToken,
      78                 : }
      79                 : 
      80                 : impl CancelClosure {
      81              29 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
      82              29 :         Self {
      83              29 :             socket_addr,
      84              29 :             cancel_token,
      85              29 :         }
      86              29 :     }
      87                 : 
      88                 :     /// Cancels the query running on user's compute node.
      89 UBC           0 :     pub async fn try_cancel_query(self) -> anyhow::Result<()> {
      90               0 :         let socket = TcpStream::connect(self.socket_addr).await?;
      91               0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
      92                 : 
      93               0 :         Ok(())
      94               0 :     }
      95                 : }
      96                 : 
      97                 : /// Helper for registering query cancellation tokens.
      98                 : pub struct Session<'a> {
      99                 :     /// The user-facing key identifying this session.
     100                 :     key: CancelKeyData,
     101                 :     /// The [`CancelMap`] this session belongs to.
     102                 :     cancel_map: &'a CancelMap,
     103                 : }
     104                 : 
     105                 : impl<'a> Session<'a> {
     106 CBC          34 :     fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
     107              34 :         Self { key, cancel_map }
     108              34 :     }
     109                 : }
     110                 : 
     111                 : impl Session<'_> {
     112                 :     /// Store the cancel token for the given session.
     113                 :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     114              29 :     pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
     115              29 :         info!("enabling query cancellation for this session");
     116              29 :         self.cancel_map
     117              29 :             .0
     118              29 :             .write()
     119              29 :             .insert(self.key, Some(cancel_closure));
     120              29 : 
     121              29 :         self.key
     122              29 :     }
     123                 : }
     124                 : 
     125                 : #[cfg(test)]
     126                 : mod tests {
     127                 :     use super::*;
     128                 :     use once_cell::sync::Lazy;
     129                 : 
     130               1 :     #[tokio::test]
     131               1 :     async fn check_session_drop() -> anyhow::Result<()> {
     132               1 :         static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
     133               1 : 
     134               1 :         let (tx, rx) = tokio::sync::oneshot::channel();
     135               1 :         let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
     136               1 :             assert!(CANCEL_MAP.contains(&session));
     137                 : 
     138               1 :             tx.send(()).expect("failed to send");
     139               1 :             futures::future::pending::<()>().await; // sleep forever
     140                 : 
     141 UBC           0 :             Ok(())
     142 CBC           1 :         }));
     143               1 : 
     144               1 :         // Wait until the task has been spawned.
     145               1 :         rx.await.context("failed to hear from the task")?;
     146                 : 
     147                 :         // Drop the session's entry by cancelling the task.
     148               1 :         task.abort();
     149               1 :         let error = task.await.expect_err("task should have failed");
     150               1 :         if !error.is_cancelled() {
     151 UBC           0 :             anyhow::bail!(error);
     152 CBC           1 :         }
     153               1 : 
     154               1 :         // Check that the session has been dropped.
     155               1 :         assert!(CANCEL_MAP.is_empty());
     156                 : 
     157               1 :         Ok(())
     158                 :     }
     159                 : }
        

Generated by: LCOV version 2.1-beta