LCOV - code coverage report
Current view: top level - proxy/src - cancellation.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 80.9 % 89 72
Test Date: 2023-09-06 10:18:01 Functions: 45.2 % 42 19

            Line data    Source code
       1              : use anyhow::{anyhow, Context};
       2              : use hashbrown::HashMap;
       3              : use pq_proto::CancelKeyData;
       4              : use std::net::SocketAddr;
       5              : use tokio::net::TcpStream;
       6              : use tokio_postgres::{CancelToken, NoTls};
       7              : use tracing::info;
       8              : 
       9              : /// Enables serving `CancelRequest`s.
      10           44 : #[derive(Default)]
      11              : pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
      12              : 
      13              : impl CancelMap {
      14              :     /// Cancel a running query for the corresponding connection.
      15            0 :     pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
      16              :         // NB: we should immediately release the lock after cloning the token.
      17            0 :         let cancel_closure = self
      18            0 :             .0
      19            0 :             .read()
      20            0 :             .get(&key)
      21            0 :             .and_then(|x| x.clone())
      22            0 :             .with_context(|| format!("query cancellation key not found: {key}"))?;
      23              : 
      24            0 :         info!("cancelling query per user's request using key {key}");
      25            0 :         cancel_closure.try_cancel_query().await
      26            0 :     }
      27              : 
      28              :     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
      29           32 :     pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
      30           32 :     where
      31           32 :         F: FnOnce(Session<'a>) -> R,
      32           32 :         R: std::future::Future<Output = anyhow::Result<V>>,
      33           32 :     {
      34           32 :         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
      35           32 :         // expose it and we don't want to do another roundtrip to query
      36           32 :         // for it. The client will be able to notice that this is not the
      37           32 :         // actual backend_pid, but backend_pid is not used for anything
      38           32 :         // so it doesn't matter.
      39           32 :         let key = rand::random();
      40           32 : 
      41           32 :         // Random key collisions are unlikely to happen here, but they're still possible,
      42           32 :         // which is why we have to take care not to rewrite an existing key.
      43           32 :         self.0
      44           32 :             .write()
      45           32 :             .try_insert(key, None)
      46           32 :             .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
      47              : 
      48              :         // This will guarantee that the session gets dropped
      49              :         // as soon as the future is finished.
      50           32 :         scopeguard::defer! {
      51           32 :             self.0.write().remove(&key);
      52           32 :             info!("dropped query cancellation key {key}");
      53              :         }
      54              : 
      55           31 :         info!("registered new query cancellation key {key}");
      56           32 :         let session = Session::new(key, self);
      57          358 :         f(session).await
      58           31 :     }
      59              : 
      60              :     #[cfg(test)]
      61            1 :     fn contains(&self, session: &Session) -> bool {
      62            1 :         self.0.read().contains_key(&session.key)
      63            1 :     }
      64              : 
      65              :     #[cfg(test)]
      66            1 :     fn is_empty(&self) -> bool {
      67            1 :         self.0.read().is_empty()
      68            1 :     }
      69              : }
      70              : 
      71              : /// This should've been a [`std::future::Future`], but
      72              : /// it's impossible to name a type of an unboxed future
      73              : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
      74           27 : #[derive(Clone)]
      75              : pub struct CancelClosure {
      76              :     socket_addr: SocketAddr,
      77              :     cancel_token: CancelToken,
      78              : }
      79              : 
      80              : impl CancelClosure {
      81           27 :     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
      82           27 :         Self {
      83           27 :             socket_addr,
      84           27 :             cancel_token,
      85           27 :         }
      86           27 :     }
      87              : 
      88              :     /// Cancels the query running on user's compute node.
      89            0 :     pub async fn try_cancel_query(self) -> anyhow::Result<()> {
      90            0 :         let socket = TcpStream::connect(self.socket_addr).await?;
      91            0 :         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
      92              : 
      93            0 :         Ok(())
      94            0 :     }
      95              : }
      96              : 
      97              : /// Helper for registering query cancellation tokens.
      98              : pub struct Session<'a> {
      99              :     /// The user-facing key identifying this session.
     100              :     key: CancelKeyData,
     101              :     /// The [`CancelMap`] this session belongs to.
     102              :     cancel_map: &'a CancelMap,
     103              : }
     104              : 
     105              : impl<'a> Session<'a> {
     106           32 :     fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
     107           32 :         Self { key, cancel_map }
     108           32 :     }
     109              : }
     110              : 
     111              : impl Session<'_> {
     112              :     /// Store the cancel token for the given session.
     113              :     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
     114           27 :     pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
     115           27 :         info!("enabling query cancellation for this session");
     116           27 :         self.cancel_map
     117           27 :             .0
     118           27 :             .write()
     119           27 :             .insert(self.key, Some(cancel_closure));
     120           27 : 
     121           27 :         self.key
     122           27 :     }
     123              : }
     124              : 
     125              : #[cfg(test)]
     126              : mod tests {
     127              :     use super::*;
     128              :     use once_cell::sync::Lazy;
     129              : 
     130            1 :     #[tokio::test]
     131            1 :     async fn check_session_drop() -> anyhow::Result<()> {
     132            1 :         static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
     133            1 : 
     134            1 :         let (tx, rx) = tokio::sync::oneshot::channel();
     135            1 :         let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
     136            1 :             assert!(CANCEL_MAP.contains(&session));
     137              : 
     138            1 :             tx.send(()).expect("failed to send");
     139            1 :             futures::future::pending::<()>().await; // sleep forever
     140              : 
     141            0 :             Ok(())
     142            1 :         }));
     143            1 : 
     144            1 :         // Wait until the task has been spawned.
     145            1 :         rx.await.context("failed to hear from the task")?;
     146              : 
     147              :         // Drop the session's entry by cancelling the task.
     148            1 :         task.abort();
     149            1 :         let error = task.await.expect_err("task should have failed");
     150            1 :         if !error.is_cancelled() {
     151            0 :             anyhow::bail!(error);
     152            1 :         }
     153            1 : 
     154            1 :         // Check that the session has been dropped.
     155            1 :         assert!(CANCEL_MAP.is_empty());
     156              : 
     157            1 :         Ok(())
     158              :     }
     159              : }
        

Generated by: LCOV version 2.1-beta