Line data Source code
1 : use dashmap::DashMap;
2 : use pq_proto::CancelKeyData;
3 : use std::{net::SocketAddr, sync::Arc};
4 : use thiserror::Error;
5 : use tokio::net::TcpStream;
6 : use tokio_postgres::{CancelToken, NoTls};
7 : use tracing::info;
8 :
9 : use crate::error::ReportableError;
10 :
11 : /// Enables serving `CancelRequest`s.
12 74 : #[derive(Default)]
13 : pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
14 :
15 0 : #[derive(Debug, Error)]
16 : pub enum CancelError {
17 : #[error("{0}")]
18 : IO(#[from] std::io::Error),
19 : #[error("{0}")]
20 : Postgres(#[from] tokio_postgres::Error),
21 : }
22 :
23 : impl ReportableError for CancelError {
24 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
25 0 : match self {
26 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
27 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
28 0 : crate::error::ErrorKind::Postgres
29 : }
30 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
31 : }
32 0 : }
33 : }
34 :
35 : impl CancelMap {
36 : /// Cancel a running query for the corresponding connection.
37 0 : pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
38 : // NB: we should immediately release the lock after cloning the token.
39 0 : let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
40 0 : tracing::warn!("query cancellation key not found: {key}");
41 0 : return Ok(());
42 : };
43 :
44 0 : info!("cancelling query per user's request using key {key}");
45 0 : cancel_closure.try_cancel_query().await
46 0 : }
47 :
48 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
49 43 : pub fn get_session(self: Arc<Self>) -> Session {
50 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
51 : // expose it and we don't want to do another roundtrip to query
52 : // for it. The client will be able to notice that this is not the
53 : // actual backend_pid, but backend_pid is not used for anything
54 : // so it doesn't matter.
55 43 : let key = loop {
56 43 : let key = rand::random();
57 43 :
58 43 : // Random key collisions are unlikely to happen here, but they're still possible,
59 43 : // which is why we have to take care not to rewrite an existing key.
60 43 : match self.0.entry(key) {
61 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
62 43 : dashmap::mapref::entry::Entry::Vacant(e) => {
63 43 : e.insert(None);
64 43 : }
65 43 : }
66 43 : break key;
67 43 : };
68 43 :
69 43 : info!("registered new query cancellation key {key}");
70 43 : Session {
71 43 : key,
72 43 : cancel_map: self,
73 43 : }
74 43 : }
75 :
76 : #[cfg(test)]
77 2 : fn contains(&self, session: &Session) -> bool {
78 2 : self.0.contains_key(&session.key)
79 2 : }
80 :
81 : #[cfg(test)]
82 2 : fn is_empty(&self) -> bool {
83 2 : self.0.is_empty()
84 2 : }
85 : }
86 :
87 : /// This should've been a [`std::future::Future`], but
88 : /// it's impossible to name a type of an unboxed future
89 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
90 41 : #[derive(Clone)]
91 : pub struct CancelClosure {
92 : socket_addr: SocketAddr,
93 : cancel_token: CancelToken,
94 : }
95 :
96 : impl CancelClosure {
97 41 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
98 41 : Self {
99 41 : socket_addr,
100 41 : cancel_token,
101 41 : }
102 41 : }
103 :
104 : /// Cancels the query running on user's compute node.
105 0 : async fn try_cancel_query(self) -> Result<(), CancelError> {
106 0 : let socket = TcpStream::connect(self.socket_addr).await?;
107 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
108 :
109 0 : Ok(())
110 0 : }
111 : }
112 :
113 : /// Helper for registering query cancellation tokens.
114 : pub struct Session {
115 : /// The user-facing key identifying this session.
116 : key: CancelKeyData,
117 : /// The [`CancelMap`] this session belongs to.
118 : cancel_map: Arc<CancelMap>,
119 : }
120 :
121 : impl Session {
122 : /// Store the cancel token for the given session.
123 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
124 41 : pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
125 41 : info!("enabling query cancellation for this session");
126 41 : self.cancel_map.0.insert(self.key, Some(cancel_closure));
127 41 :
128 41 : self.key
129 41 : }
130 : }
131 :
132 : impl Drop for Session {
133 43 : fn drop(&mut self) {
134 43 : self.cancel_map.0.remove(&self.key);
135 43 : info!("dropped query cancellation key {}", &self.key);
136 43 : }
137 : }
138 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping a `Session` must remove its key from the `CancelMap`.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let cancel_map = Arc::new(CancelMap::default());

        {
            let session = Arc::clone(&cancel_map).get_session();
            assert!(cancel_map.contains(&session));
            // `session` goes out of scope here, which runs its `Drop` impl.
        }

        // Check that the session has been dropped.
        assert!(cancel_map.is_empty());

        Ok(())
    }
}
|