Line data Source code
1 : use anyhow::{anyhow, Context};
2 : use hashbrown::HashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::net::SocketAddr;
5 : use tokio::net::TcpStream;
6 : use tokio_postgres::{CancelToken, NoTls};
7 : use tracing::info;
8 :
9 : /// Enables serving `CancelRequest`s.
10 44 : #[derive(Default)]
11 : pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
12 :
13 : impl CancelMap {
14 : /// Cancel a running query for the corresponding connection.
15 0 : pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
16 : // NB: we should immediately release the lock after cloning the token.
17 0 : let cancel_closure = self
18 0 : .0
19 0 : .read()
20 0 : .get(&key)
21 0 : .and_then(|x| x.clone())
22 0 : .with_context(|| format!("query cancellation key not found: {key}"))?;
23 :
24 0 : info!("cancelling query per user's request using key {key}");
25 0 : cancel_closure.try_cancel_query().await
26 0 : }
27 :
28 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
29 32 : pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
30 32 : where
31 32 : F: FnOnce(Session<'a>) -> R,
32 32 : R: std::future::Future<Output = anyhow::Result<V>>,
33 32 : {
34 32 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
35 32 : // expose it and we don't want to do another roundtrip to query
36 32 : // for it. The client will be able to notice that this is not the
37 32 : // actual backend_pid, but backend_pid is not used for anything
38 32 : // so it doesn't matter.
39 32 : let key = rand::random();
40 32 :
41 32 : // Random key collisions are unlikely to happen here, but they're still possible,
42 32 : // which is why we have to take care not to rewrite an existing key.
43 32 : self.0
44 32 : .write()
45 32 : .try_insert(key, None)
46 32 : .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
47 :
48 : // This will guarantee that the session gets dropped
49 : // as soon as the future is finished.
50 32 : scopeguard::defer! {
51 32 : self.0.write().remove(&key);
52 32 : info!("dropped query cancellation key {key}");
53 : }
54 :
55 31 : info!("registered new query cancellation key {key}");
56 32 : let session = Session::new(key, self);
57 358 : f(session).await
58 31 : }
59 :
60 : #[cfg(test)]
61 1 : fn contains(&self, session: &Session) -> bool {
62 1 : self.0.read().contains_key(&session.key)
63 1 : }
64 :
65 : #[cfg(test)]
66 1 : fn is_empty(&self) -> bool {
67 1 : self.0.read().is_empty()
68 1 : }
69 : }
70 :
71 : /// This should've been a [`std::future::Future`], but
72 : /// it's impossible to name a type of an unboxed future
73 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
74 27 : #[derive(Clone)]
75 : pub struct CancelClosure {
76 : socket_addr: SocketAddr,
77 : cancel_token: CancelToken,
78 : }
79 :
80 : impl CancelClosure {
81 27 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
82 27 : Self {
83 27 : socket_addr,
84 27 : cancel_token,
85 27 : }
86 27 : }
87 :
88 : /// Cancels the query running on user's compute node.
89 0 : pub async fn try_cancel_query(self) -> anyhow::Result<()> {
90 0 : let socket = TcpStream::connect(self.socket_addr).await?;
91 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
92 :
93 0 : Ok(())
94 0 : }
95 : }
96 :
97 : /// Helper for registering query cancellation tokens.
98 : pub struct Session<'a> {
99 : /// The user-facing key identifying this session.
100 : key: CancelKeyData,
101 : /// The [`CancelMap`] this session belongs to.
102 : cancel_map: &'a CancelMap,
103 : }
104 :
105 : impl<'a> Session<'a> {
106 32 : fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
107 32 : Self { key, cancel_map }
108 32 : }
109 : }
110 :
111 : impl Session<'_> {
112 : /// Store the cancel token for the given session.
113 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
114 27 : pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
115 27 : info!("enabling query cancellation for this session");
116 27 : self.cancel_map
117 27 : .0
118 27 : .write()
119 27 : .insert(self.key, Some(cancel_closure));
120 27 :
121 27 : self.key
122 27 : }
123 : }
124 :
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;

    /// Aborting a task that is parked inside `with_session` must remove
    /// the session's entry from the map.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
        let handle = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
            // The key must be visible in the map while the session is alive.
            assert!(CANCEL_MAP.contains(&session));

            ready_tx.send(()).expect("failed to send");
            // Park forever; the only way out is the task being aborted.
            futures::future::pending::<()>().await;

            Ok(())
        }));

        // Block until the task signals that it is running inside the session.
        ready_rx.await.context("failed to hear from the task")?;

        // Aborting the task drops its future and, with it, the session's entry.
        handle.abort();
        let join_error = handle.await.expect_err("task should have failed");
        // Anything other than a cancellation is a genuine test failure.
        anyhow::ensure!(join_error.is_cancelled(), join_error);

        // No entry may be left behind once the task is gone.
        assert!(CANCEL_MAP.is_empty());

        Ok(())
    }
}
|