TLA Line data Source code
1 : use anyhow::{anyhow, Context};
2 : use hashbrown::HashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::net::SocketAddr;
5 : use tokio::net::TcpStream;
6 : use tokio_postgres::{CancelToken, NoTls};
7 : use tracing::info;
8 :
9 : /// Enables serving `CancelRequest`s.
10 CBC 54 : #[derive(Default)]
11 : pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
12 :
13 : impl CancelMap {
14 : /// Cancel a running query for the corresponding connection.
15 UBC 0 : pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
16 : // NB: we should immediately release the lock after cloning the token.
17 0 : let cancel_closure = self
18 0 : .0
19 0 : .read()
20 0 : .get(&key)
21 0 : .and_then(|x| x.clone())
22 0 : .with_context(|| format!("query cancellation key not found: {key}"))?;
23 :
24 0 : info!("cancelling query per user's request using key {key}");
25 0 : cancel_closure.try_cancel_query().await
26 0 : }
27 :
28 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
29 CBC 34 : pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
30 34 : where
31 34 : F: FnOnce(Session<'a>) -> R,
32 34 : R: std::future::Future<Output = anyhow::Result<V>>,
33 34 : {
34 34 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
35 34 : // expose it and we don't want to do another roundtrip to query
36 34 : // for it. The client will be able to notice that this is not the
37 34 : // actual backend_pid, but backend_pid is not used for anything
38 34 : // so it doesn't matter.
39 34 : let key = rand::random();
40 34 :
41 34 : // Random key collisions are unlikely to happen here, but they're still possible,
42 34 : // which is why we have to take care not to rewrite an existing key.
43 34 : self.0
44 34 : .write()
45 34 : .try_insert(key, None)
46 34 : .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
47 :
48 : // This will guarantee that the session gets dropped
49 : // as soon as the future is finished.
50 34 : scopeguard::defer! {
51 34 : self.0.write().remove(&key);
52 34 : info!("dropped query cancellation key {key}");
53 : }
54 :
55 33 : info!("registered new query cancellation key {key}");
56 34 : let session = Session::new(key, self);
57 397 : f(session).await
58 33 : }
59 :
60 : #[cfg(test)]
61 1 : fn contains(&self, session: &Session) -> bool {
62 1 : self.0.read().contains_key(&session.key)
63 1 : }
64 :
65 : #[cfg(test)]
66 1 : fn is_empty(&self) -> bool {
67 1 : self.0.read().is_empty()
68 1 : }
69 : }
70 :
71 : /// This should've been a [`std::future::Future`], but
72 : /// it's impossible to name a type of an unboxed future
73 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
74 29 : #[derive(Clone)]
75 : pub struct CancelClosure {
76 : socket_addr: SocketAddr,
77 : cancel_token: CancelToken,
78 : }
79 :
80 : impl CancelClosure {
81 29 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
82 29 : Self {
83 29 : socket_addr,
84 29 : cancel_token,
85 29 : }
86 29 : }
87 :
88 : /// Cancels the query running on user's compute node.
89 UBC 0 : pub async fn try_cancel_query(self) -> anyhow::Result<()> {
90 0 : let socket = TcpStream::connect(self.socket_addr).await?;
91 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
92 :
93 0 : Ok(())
94 0 : }
95 : }
96 :
97 : /// Helper for registering query cancellation tokens.
98 : pub struct Session<'a> {
99 : /// The user-facing key identifying this session.
100 : key: CancelKeyData,
101 : /// The [`CancelMap`] this session belongs to.
102 : cancel_map: &'a CancelMap,
103 : }
104 :
105 : impl<'a> Session<'a> {
106 CBC 34 : fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
107 34 : Self { key, cancel_map }
108 34 : }
109 : }
110 :
111 : impl Session<'_> {
112 : /// Store the cancel token for the given session.
113 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
114 29 : pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
115 29 : info!("enabling query cancellation for this session");
116 29 : self.cancel_map
117 29 : .0
118 29 : .write()
119 29 : .insert(self.key, Some(cancel_closure));
120 29 :
121 29 : self.key
122 29 : }
123 : }
124 :
125 : #[cfg(test)]
126 : mod tests {
127 : use super::*;
128 : use once_cell::sync::Lazy;
129 :
130 1 : #[tokio::test]
131 1 : async fn check_session_drop() -> anyhow::Result<()> {
132 1 : static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
133 1 :
134 1 : let (tx, rx) = tokio::sync::oneshot::channel();
135 1 : let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
136 1 : assert!(CANCEL_MAP.contains(&session));
137 :
138 1 : tx.send(()).expect("failed to send");
139 1 : futures::future::pending::<()>().await; // sleep forever
140 :
141 UBC 0 : Ok(())
142 CBC 1 : }));
143 1 :
144 1 : // Wait until the task has been spawned.
145 1 : rx.await.context("failed to hear from the task")?;
146 :
147 : // Drop the session's entry by cancelling the task.
148 1 : task.abort();
149 1 : let error = task.await.expect_err("task should have failed");
150 1 : if !error.is_cancelled() {
151 UBC 0 : anyhow::bail!(error);
152 CBC 1 : }
153 1 :
154 1 : // Check that the session has been dropped.
155 1 : assert!(CANCEL_MAP.is_empty());
156 :
157 1 : Ok(())
158 : }
159 : }
|