Line data Source code
1 : use async_trait::async_trait;
2 : use dashmap::DashMap;
3 : use pq_proto::CancelKeyData;
4 : use std::{net::SocketAddr, sync::Arc};
5 : use thiserror::Error;
6 : use tokio::net::TcpStream;
7 : use tokio::sync::Mutex;
8 : use tokio_postgres::{CancelToken, NoTls};
9 : use tracing::info;
10 : use uuid::Uuid;
11 :
12 : use crate::{
13 : error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
14 : redis::publisher::RedisPublisherClient,
15 : };
16 :
/// Shared concurrent map from a session's [`CancelKeyData`] to its optional [`CancelClosure`].
///
/// An entry holds `None` once a key has been reserved by
/// [`CancellationHandler::get_session`] and `Some(..)` after
/// [`Session::enable_query_cancellation`] stores a closure for that key.
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
18 :
/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
    /// Cancellation keys registered on this instance, with their cancel closures (if any yet).
    map: CancelMap,
    /// Optional publisher; `cancel_session` uses it when the requested key is not
    /// found in `map`. `None` disables publishing.
    redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}
26 :
/// Errors that can be returned while trying to cancel a query.
#[derive(Debug, Error)]
pub enum CancelError {
    /// An I/O failure, e.g. the TCP connection for the cancel request could not be
    /// opened. `cancel_session` also wraps a failed Redis publish in this variant.
    #[error("{0}")]
    IO(#[from] std::io::Error),
    /// An error reported by `tokio_postgres` while sending the cancel request.
    #[error("{0}")]
    Postgres(#[from] tokio_postgres::Error),
}
34 :
35 : impl ReportableError for CancelError {
36 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
37 0 : match self {
38 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
39 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
40 0 : crate::error::ErrorKind::Postgres
41 : }
42 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
43 : }
44 0 : }
45 : }
46 :
47 : impl CancellationHandler {
48 0 : pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
49 0 : Self { map, redis_client }
50 0 : }
51 : /// Cancel a running query for the corresponding connection.
52 0 : pub async fn cancel_session(
53 0 : &self,
54 0 : key: CancelKeyData,
55 0 : session_id: Uuid,
56 0 : ) -> Result<(), CancelError> {
57 0 : let from = "from_client";
58 : // NB: we should immediately release the lock after cloning the token.
59 0 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
60 0 : tracing::warn!("query cancellation key not found: {key}");
61 0 : if let Some(redis_client) = &self.redis_client {
62 0 : NUM_CANCELLATION_REQUESTS
63 0 : .with_label_values(&[from, "not_found"])
64 0 : .inc();
65 0 : info!("publishing cancellation key to Redis");
66 0 : match redis_client.lock().await.try_publish(key, session_id).await {
67 : Ok(()) => {
68 0 : info!("cancellation key successfuly published to Redis");
69 : }
70 0 : Err(e) => {
71 0 : tracing::error!("failed to publish a message: {e}");
72 0 : return Err(CancelError::IO(std::io::Error::new(
73 0 : std::io::ErrorKind::Other,
74 0 : e.to_string(),
75 0 : )));
76 : }
77 : }
78 0 : }
79 0 : return Ok(());
80 : };
81 0 : NUM_CANCELLATION_REQUESTS
82 0 : .with_label_values(&[from, "found"])
83 0 : .inc();
84 0 : info!("cancelling query per user's request using key {key}");
85 0 : cancel_closure.try_cancel_query().await
86 0 : }
87 :
88 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
89 2 : pub fn get_session(self: Arc<Self>) -> Session {
90 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
91 : // expose it and we don't want to do another roundtrip to query
92 : // for it. The client will be able to notice that this is not the
93 : // actual backend_pid, but backend_pid is not used for anything
94 : // so it doesn't matter.
95 2 : let key = loop {
96 2 : let key = rand::random();
97 2 :
98 2 : // Random key collisions are unlikely to happen here, but they're still possible,
99 2 : // which is why we have to take care not to rewrite an existing key.
100 2 : match self.map.entry(key) {
101 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
102 2 : dashmap::mapref::entry::Entry::Vacant(e) => {
103 2 : e.insert(None);
104 2 : }
105 2 : }
106 2 : break key;
107 2 : };
108 2 :
109 2 : info!("registered new query cancellation key {key}");
110 2 : Session {
111 2 : key,
112 2 : cancellation_handler: self,
113 2 : }
114 2 : }
115 :
116 : #[cfg(test)]
117 2 : fn contains(&self, session: &Session) -> bool {
118 2 : self.map.contains_key(&session.key)
119 2 : }
120 :
121 : #[cfg(test)]
122 2 : fn is_empty(&self) -> bool {
123 2 : self.map.is_empty()
124 2 : }
125 : }
126 :
/// Cancellation entry point that never publishes the request any further.
// NOTE(review): presumably used for cancellation requests received via Redis
// notifications (the implementation in this file labels its metric `from_redis`) — confirm against callers.
#[async_trait]
pub trait NotificationsCancellationHandler {
    /// Attempts to cancel the query of the session identified by `key` without publishing the key.
    async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}
131 :
132 : #[async_trait]
133 : impl NotificationsCancellationHandler for CancellationHandler {
134 0 : async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
135 0 : let from = "from_redis";
136 0 : let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
137 0 : match cancel_closure {
138 0 : Some(cancel_closure) => {
139 0 : NUM_CANCELLATION_REQUESTS
140 0 : .with_label_values(&[from, "found"])
141 0 : .inc();
142 0 : cancel_closure.try_cancel_query().await
143 : }
144 : None => {
145 0 : NUM_CANCELLATION_REQUESTS
146 0 : .with_label_values(&[from, "not_found"])
147 0 : .inc();
148 0 : tracing::warn!("query cancellation key not found: {key}");
149 0 : Ok(())
150 : }
151 : }
152 0 : }
153 : }
154 :
/// Everything needed to send a query cancellation request for one session.
///
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
    /// Address a new TCP connection is opened to when cancelling.
    socket_addr: SocketAddr,
    /// Token used to send the cancel request over that connection.
    cancel_token: CancelToken,
}
163 :
164 : impl CancelClosure {
165 0 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
166 0 : Self {
167 0 : socket_addr,
168 0 : cancel_token,
169 0 : }
170 0 : }
171 : /// Cancels the query running on user's compute node.
172 0 : pub async fn try_cancel_query(self) -> Result<(), CancelError> {
173 0 : let socket = TcpStream::connect(self.socket_addr).await?;
174 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
175 0 : info!("query was cancelled");
176 0 : Ok(())
177 0 : }
178 : }
179 :
/// Helper for registering query cancellation tokens.
///
/// The session's key is removed from the handler's map when the session is dropped.
pub struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// The [`CancellationHandler`] whose map this session's key is registered in.
    cancellation_handler: Arc<CancellationHandler>,
}
187 :
188 : impl Session {
189 : /// Store the cancel token for the given session.
190 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
191 0 : pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
192 0 : info!("enabling query cancellation for this session");
193 0 : self.cancellation_handler
194 0 : .map
195 0 : .insert(self.key, Some(cancel_closure));
196 0 :
197 0 : self.key
198 0 : }
199 : }
200 :
201 : impl Drop for Session {
202 2 : fn drop(&mut self) {
203 2 : self.cancellation_handler.map.remove(&self.key);
204 2 : info!("dropped query cancellation key {}", &self.key);
205 2 : }
206 : }
207 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A session's key must be registered while the session is alive and
    /// removed from the map once the session is dropped.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::new(CancelMap::default(), None));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Check that the session has been dropped.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }
}
|