TLA Line data Source code
1 : use anyhow::{bail, Context};
2 : use dashmap::DashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::net::SocketAddr;
5 : use tokio::net::TcpStream;
6 : use tokio_postgres::{CancelToken, NoTls};
7 : use tracing::info;
8 :
9 : /// Enables serving `CancelRequest`s.
10 CBC 90 : #[derive(Default)]
11 : pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
12 :
13 : impl CancelMap {
14 : /// Cancel a running query for the corresponding connection.
15 UBC 0 : pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
16 : // NB: we should immediately release the lock after cloning the token.
17 0 : let cancel_closure = self
18 0 : .0
19 0 : .get(&key)
20 0 : .and_then(|x| x.clone())
21 0 : .with_context(|| format!("query cancellation key not found: {key}"))?;
22 :
23 0 : info!("cancelling query per user's request using key {key}");
24 0 : cancel_closure.try_cancel_query().await
25 0 : }
26 :
27 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
28 CBC 50 : pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
29 50 : where
30 50 : F: FnOnce(Session<'a>) -> R,
31 50 : R: std::future::Future<Output = anyhow::Result<V>>,
32 50 : {
33 50 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
34 50 : // expose it and we don't want to do another roundtrip to query
35 50 : // for it. The client will be able to notice that this is not the
36 50 : // actual backend_pid, but backend_pid is not used for anything
37 50 : // so it doesn't matter.
38 50 : let key = rand::random();
39 50 :
40 50 : // Random key collisions are unlikely to happen here, but they're still possible,
41 50 : // which is why we have to take care not to rewrite an existing key.
42 50 : match self.0.entry(key) {
43 : dashmap::mapref::entry::Entry::Occupied(_) => {
44 UBC 0 : bail!("query cancellation key already exists: {key}")
45 : }
46 CBC 50 : dashmap::mapref::entry::Entry::Vacant(e) => {
47 50 : e.insert(None);
48 50 : }
49 : }
50 :
51 : // This will guarantee that the session gets dropped
52 : // as soon as the future is finished.
53 50 : scopeguard::defer! {
54 50 : self.0.remove(&key);
55 50 : info!("dropped query cancellation key {key}");
56 : }
57 :
58 49 : info!("registered new query cancellation key {key}");
59 50 : let session = Session::new(key, self);
60 877 : f(session).await
61 49 : }
62 :
63 : #[cfg(test)]
64 1 : fn contains(&self, session: &Session) -> bool {
65 1 : self.0.contains_key(&session.key)
66 1 : }
67 :
68 : #[cfg(test)]
69 1 : fn is_empty(&self) -> bool {
70 1 : self.0.is_empty()
71 1 : }
72 : }
73 :
74 : /// This should've been a [`std::future::Future`], but
75 : /// it's impossible to name a type of an unboxed future
76 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
77 38 : #[derive(Clone)]
78 : pub struct CancelClosure {
79 : socket_addr: SocketAddr,
80 : cancel_token: CancelToken,
81 : }
82 :
83 : impl CancelClosure {
84 38 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
85 38 : Self {
86 38 : socket_addr,
87 38 : cancel_token,
88 38 : }
89 38 : }
90 :
91 : /// Cancels the query running on user's compute node.
92 UBC 0 : pub async fn try_cancel_query(self) -> anyhow::Result<()> {
93 0 : let socket = TcpStream::connect(self.socket_addr).await?;
94 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
95 :
96 0 : Ok(())
97 0 : }
98 : }
99 :
100 : /// Helper for registering query cancellation tokens.
101 : pub struct Session<'a> {
102 : /// The user-facing key identifying this session.
103 : key: CancelKeyData,
104 : /// The [`CancelMap`] this session belongs to.
105 : cancel_map: &'a CancelMap,
106 : }
107 :
108 : impl<'a> Session<'a> {
109 CBC 50 : fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
110 50 : Self { key, cancel_map }
111 50 : }
112 : }
113 :
114 : impl Session<'_> {
115 : /// Store the cancel token for the given session.
116 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
117 38 : pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
118 38 : info!("enabling query cancellation for this session");
119 38 : self.cancel_map.0.insert(self.key, Some(cancel_closure));
120 38 :
121 38 : self.key
122 38 : }
123 : }
124 :
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;

    /// Aborting the task that runs `with_session` must still unregister the
    /// session's key.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(CancelMap::default);

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
        let handle = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
            assert!(CANCEL_MAP.contains(&session));

            ready_tx.send(()).expect("failed to send");
            // Park forever; only the `abort()` below ends this task.
            std::future::pending::<()>().await;

            Ok(())
        }));

        // Don't abort until the task has registered its session.
        ready_rx.await.context("failed to hear from the task")?;

        // Aborting the task drops the `with_session` future, which must
        // remove the session's entry.
        handle.abort();
        let join_error = handle.await.expect_err("task should have failed");
        if !join_error.is_cancelled() {
            anyhow::bail!(join_error);
        }

        // Nothing may be left registered after the drop.
        assert!(CANCEL_MAP.is_empty());

        Ok(())
    }
}
|