Line data Source code
1 : use anyhow::Context;
2 : use dashmap::DashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::{net::SocketAddr, sync::Arc};
5 : use tokio::net::TcpStream;
6 : use tokio_postgres::{CancelToken, NoTls};
7 : use tracing::info;
8 :
9 : /// Enables serving `CancelRequest`s.
10 115 : #[derive(Default)]
11 : pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
12 :
13 : impl CancelMap {
14 : /// Cancel a running query for the corresponding connection.
15 0 : pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
16 : // NB: we should immediately release the lock after cloning the token.
17 0 : let cancel_closure = self
18 0 : .0
19 0 : .get(&key)
20 0 : .and_then(|x| x.clone())
21 0 : .with_context(|| format!("query cancellation key not found: {key}"))?;
22 :
23 0 : info!("cancelling query per user's request using key {key}");
24 0 : cancel_closure.try_cancel_query().await
25 0 : }
26 :
27 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
28 41 : pub fn get_session(self: Arc<Self>) -> Session {
29 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
30 : // expose it and we don't want to do another roundtrip to query
31 : // for it. The client will be able to notice that this is not the
32 : // actual backend_pid, but backend_pid is not used for anything
33 : // so it doesn't matter.
34 41 : let key = loop {
35 41 : let key = rand::random();
36 41 :
37 41 : // Random key collisions are unlikely to happen here, but they're still possible,
38 41 : // which is why we have to take care not to rewrite an existing key.
39 41 : match self.0.entry(key) {
40 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
41 41 : dashmap::mapref::entry::Entry::Vacant(e) => {
42 41 : e.insert(None);
43 41 : }
44 41 : }
45 41 : break key;
46 41 : };
47 41 :
48 41 : info!("registered new query cancellation key {key}");
49 41 : Session {
50 41 : key,
51 41 : cancel_map: self,
52 41 : }
53 41 : }
54 :
55 : #[cfg(test)]
56 2 : fn contains(&self, session: &Session) -> bool {
57 2 : self.0.contains_key(&session.key)
58 2 : }
59 :
60 : #[cfg(test)]
61 2 : fn is_empty(&self) -> bool {
62 2 : self.0.is_empty()
63 2 : }
64 : }
65 :
66 : /// This should've been a [`std::future::Future`], but
67 : /// it's impossible to name a type of an unboxed future
68 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
69 39 : #[derive(Clone)]
70 : pub struct CancelClosure {
71 : socket_addr: SocketAddr,
72 : cancel_token: CancelToken,
73 : }
74 :
75 : impl CancelClosure {
76 39 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
77 39 : Self {
78 39 : socket_addr,
79 39 : cancel_token,
80 39 : }
81 39 : }
82 :
83 : /// Cancels the query running on user's compute node.
84 0 : pub async fn try_cancel_query(self) -> anyhow::Result<()> {
85 0 : let socket = TcpStream::connect(self.socket_addr).await?;
86 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
87 :
88 0 : Ok(())
89 0 : }
90 : }
91 :
92 : /// Helper for registering query cancellation tokens.
93 : pub struct Session {
94 : /// The user-facing key identifying this session.
95 : key: CancelKeyData,
96 : /// The [`CancelMap`] this session belongs to.
97 : cancel_map: Arc<CancelMap>,
98 : }
99 :
100 : impl Session {
101 : /// Store the cancel token for the given session.
102 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
103 39 : pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
104 39 : info!("enabling query cancellation for this session");
105 39 : self.cancel_map.0.insert(self.key, Some(cancel_closure));
106 39 :
107 39 : self.key
108 39 : }
109 : }
110 :
111 : impl Drop for Session {
112 41 : fn drop(&mut self) {
113 41 : self.cancel_map.0.remove(&self.key);
114 41 : info!("dropped query cancellation key {}", &self.key);
115 41 : }
116 : }
117 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A session's key is registered while the session is alive and is
    /// removed from the map once the session is dropped.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let cancel_map = Arc::new(CancelMap::default());

        let session = Arc::clone(&cancel_map).get_session();
        assert!(cancel_map.contains(&session));

        // Dropping the session must unregister its key.
        drop(session);
        assert!(cancel_map.is_empty());

        Ok(())
    }
}
|