Line data Source code
1 : use async_trait::async_trait;
2 : use dashmap::DashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::{net::SocketAddr, sync::Arc};
5 : use thiserror::Error;
6 : use tokio::net::TcpStream;
7 : use tokio::sync::Mutex;
8 : use tokio_postgres::{CancelToken, NoTls};
9 : use tracing::info;
10 : use uuid::Uuid;
11 :
12 : use crate::{
13 : error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
14 : redis::publisher::RedisPublisherClient,
15 : };
16 :
17 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
18 :
19 : /// Enables serving `CancelRequest`s.
20 : ///
21 : /// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
22 : pub struct CancellationHandler {
23 : map: CancelMap,
24 : redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
25 : }
26 :
27 0 : #[derive(Debug, Error)]
28 : pub enum CancelError {
29 : #[error("{0}")]
30 : IO(#[from] std::io::Error),
31 : #[error("{0}")]
32 : Postgres(#[from] tokio_postgres::Error),
33 : }
34 :
35 : impl ReportableError for CancelError {
36 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
37 0 : match self {
38 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
39 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
40 0 : crate::error::ErrorKind::Postgres
41 : }
42 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
43 : }
44 0 : }
45 : }
46 :
47 : impl CancellationHandler {
48 25 : pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
49 25 : Self { map, redis_client }
50 25 : }
51 : /// Cancel a running query for the corresponding connection.
52 0 : pub async fn cancel_session(
53 0 : &self,
54 0 : key: CancelKeyData,
55 0 : session_id: Uuid,
56 0 : ) -> Result<(), CancelError> {
57 0 : let from = "from_client";
58 : // NB: we should immediately release the lock after cloning the token.
59 0 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
60 0 : tracing::warn!("query cancellation key not found: {key}");
61 0 : if let Some(redis_client) = &self.redis_client {
62 0 : NUM_CANCELLATION_REQUESTS
63 0 : .with_label_values(&[from, "not_found"])
64 0 : .inc();
65 0 : info!("publishing cancellation key to Redis");
66 0 : match redis_client.lock().await.try_publish(key, session_id).await {
67 : Ok(()) => {
68 0 : info!("cancellation key successfuly published to Redis");
69 : }
70 0 : Err(e) => {
71 0 : tracing::error!("failed to publish a message: {e}");
72 0 : return Err(CancelError::IO(std::io::Error::new(
73 0 : std::io::ErrorKind::Other,
74 0 : e.to_string(),
75 0 : )));
76 : }
77 : }
78 0 : }
79 0 : return Ok(());
80 : };
81 0 : NUM_CANCELLATION_REQUESTS
82 0 : .with_label_values(&[from, "found"])
83 0 : .inc();
84 0 : info!("cancelling query per user's request using key {key}");
85 0 : cancel_closure.try_cancel_query().await
86 0 : }
87 :
88 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
89 43 : pub fn get_session(self: Arc<Self>) -> Session {
90 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
91 : // expose it and we don't want to do another roundtrip to query
92 : // for it. The client will be able to notice that this is not the
93 : // actual backend_pid, but backend_pid is not used for anything
94 : // so it doesn't matter.
95 43 : let key = loop {
96 43 : let key = rand::random();
97 43 :
98 43 : // Random key collisions are unlikely to happen here, but they're still possible,
99 43 : // which is why we have to take care not to rewrite an existing key.
100 43 : match self.map.entry(key) {
101 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
102 43 : dashmap::mapref::entry::Entry::Vacant(e) => {
103 43 : e.insert(None);
104 43 : }
105 43 : }
106 43 : break key;
107 43 : };
108 43 :
109 43 : info!("registered new query cancellation key {key}");
110 43 : Session {
111 43 : key,
112 43 : cancellation_handler: self,
113 43 : }
114 43 : }
115 :
116 : #[cfg(test)]
117 2 : fn contains(&self, session: &Session) -> bool {
118 2 : self.map.contains_key(&session.key)
119 2 : }
120 :
121 : #[cfg(test)]
122 2 : fn is_empty(&self) -> bool {
123 2 : self.map.is_empty()
124 2 : }
125 : }
126 :
127 : #[async_trait]
128 : pub trait NotificationsCancellationHandler {
129 : async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
130 : }
131 :
132 : #[async_trait]
133 : impl NotificationsCancellationHandler for CancellationHandler {
134 0 : async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
135 0 : let from = "from_redis";
136 0 : let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
137 0 : match cancel_closure {
138 0 : Some(cancel_closure) => {
139 0 : NUM_CANCELLATION_REQUESTS
140 0 : .with_label_values(&[from, "found"])
141 0 : .inc();
142 0 : cancel_closure.try_cancel_query().await
143 : }
144 : None => {
145 0 : NUM_CANCELLATION_REQUESTS
146 0 : .with_label_values(&[from, "not_found"])
147 0 : .inc();
148 0 : tracing::warn!("query cancellation key not found: {key}");
149 0 : Ok(())
150 : }
151 : }
152 0 : }
153 : }
154 :
155 : /// This should've been a [`std::future::Future`], but
156 : /// it's impossible to name a type of an unboxed future
157 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
158 41 : #[derive(Clone)]
159 : pub struct CancelClosure {
160 : socket_addr: SocketAddr,
161 : cancel_token: CancelToken,
162 : }
163 :
164 : impl CancelClosure {
165 41 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
166 41 : Self {
167 41 : socket_addr,
168 41 : cancel_token,
169 41 : }
170 41 : }
171 :
172 : /// Cancels the query running on user's compute node.
173 0 : async fn try_cancel_query(self) -> Result<(), CancelError> {
174 0 : let socket = TcpStream::connect(self.socket_addr).await?;
175 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
176 :
177 0 : Ok(())
178 0 : }
179 : }
180 :
181 : /// Helper for registering query cancellation tokens.
182 : pub struct Session {
183 : /// The user-facing key identifying this session.
184 : key: CancelKeyData,
185 : /// The [`CancelMap`] this session belongs to.
186 : cancellation_handler: Arc<CancellationHandler>,
187 : }
188 :
189 : impl Session {
190 : /// Store the cancel token for the given session.
191 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
192 41 : pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
193 41 : info!("enabling query cancellation for this session");
194 41 : self.cancellation_handler
195 41 : .map
196 41 : .insert(self.key, Some(cancel_closure));
197 41 :
198 41 : self.key
199 41 : }
200 : }
201 :
202 : impl Drop for Session {
203 43 : fn drop(&mut self) {
204 43 : self.cancellation_handler.map.remove(&self.key);
205 43 : info!("dropped query cancellation key {}", &self.key);
206 43 : }
207 : }
208 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A session's key is registered while the session lives and removed when it drops.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::new(CancelMap::default(), None));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Dropping the session must unregister its key, leaving the map empty.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }
}
|