Line data Source code
1 : use dashmap::DashMap;
2 : use pq_proto::CancelKeyData;
3 : use std::{net::SocketAddr, sync::Arc};
4 : use thiserror::Error;
5 : use tokio::net::TcpStream;
6 : use tokio::sync::Mutex;
7 : use tokio_postgres::{CancelToken, NoTls};
8 : use tracing::info;
9 : use uuid::Uuid;
10 :
11 : use crate::{
12 : error::ReportableError,
13 : metrics::NUM_CANCELLATION_REQUESTS,
14 : redis::cancellation_publisher::{
15 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
16 : },
17 : };
18 :
19 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
20 : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
21 : pub type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
22 :
23 : /// Enables serving `CancelRequest`s.
24 : ///
25 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
26 : pub struct CancellationHandler<P> {
27 : map: CancelMap,
28 : client: P,
29 : /// This field used for the monitoring purposes.
30 : /// Represents the source of the cancellation request.
31 : from: &'static str,
32 : }
33 :
34 0 : #[derive(Debug, Error)]
35 : pub enum CancelError {
36 : #[error("{0}")]
37 : IO(#[from] std::io::Error),
38 : #[error("{0}")]
39 : Postgres(#[from] tokio_postgres::Error),
40 : }
41 :
42 : impl ReportableError for CancelError {
43 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
44 0 : match self {
45 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
46 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
47 0 : crate::error::ErrorKind::Postgres
48 : }
49 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
50 : }
51 0 : }
52 : }
53 :
54 : impl<P: CancellationPublisher> CancellationHandler<P> {
55 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
56 2 : pub fn get_session(self: Arc<Self>) -> Session<P> {
57 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
58 : // expose it and we don't want to do another roundtrip to query
59 : // for it. The client will be able to notice that this is not the
60 : // actual backend_pid, but backend_pid is not used for anything
61 : // so it doesn't matter.
62 2 : let key = loop {
63 2 : let key = rand::random();
64 2 :
65 2 : // Random key collisions are unlikely to happen here, but they're still possible,
66 2 : // which is why we have to take care not to rewrite an existing key.
67 2 : match self.map.entry(key) {
68 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
69 2 : dashmap::mapref::entry::Entry::Vacant(e) => {
70 2 : e.insert(None);
71 2 : }
72 2 : }
73 2 : break key;
74 2 : };
75 2 :
76 2 : info!("registered new query cancellation key {key}");
77 2 : Session {
78 2 : key,
79 2 : cancellation_handler: self,
80 2 : }
81 2 : }
82 : /// Try to cancel a running query for the corresponding connection.
83 : /// If the cancellation key is not found, it will be published to Redis.
84 2 : pub async fn cancel_session(
85 2 : &self,
86 2 : key: CancelKeyData,
87 2 : session_id: Uuid,
88 2 : ) -> Result<(), CancelError> {
89 : // NB: we should immediately release the lock after cloning the token.
90 2 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
91 2 : tracing::warn!("query cancellation key not found: {key}");
92 2 : NUM_CANCELLATION_REQUESTS
93 2 : .with_label_values(&[self.from, "not_found"])
94 2 : .inc();
95 2 : match self.client.try_publish(key, session_id).await {
96 2 : Ok(()) => {} // do nothing
97 0 : Err(e) => {
98 0 : return Err(CancelError::IO(std::io::Error::new(
99 0 : std::io::ErrorKind::Other,
100 0 : e.to_string(),
101 0 : )));
102 : }
103 : }
104 2 : return Ok(());
105 : };
106 0 : NUM_CANCELLATION_REQUESTS
107 0 : .with_label_values(&[self.from, "found"])
108 0 : .inc();
109 0 : info!("cancelling query per user's request using key {key}");
110 0 : cancel_closure.try_cancel_query().await
111 2 : }
112 :
113 : #[cfg(test)]
114 2 : fn contains(&self, session: &Session<P>) -> bool {
115 2 : self.map.contains_key(&session.key)
116 2 : }
117 :
118 : #[cfg(test)]
119 2 : fn is_empty(&self) -> bool {
120 2 : self.map.is_empty()
121 2 : }
122 : }
123 :
124 : impl CancellationHandler<()> {
125 4 : pub fn new(map: CancelMap, from: &'static str) -> Self {
126 4 : Self {
127 4 : map,
128 4 : client: (),
129 4 : from,
130 4 : }
131 4 : }
132 : }
133 :
134 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
135 0 : pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: &'static str) -> Self {
136 0 : Self { map, client, from }
137 0 : }
138 : }
139 :
140 : /// This should've been a [`std::future::Future`], but
141 : /// it's impossible to name a type of an unboxed future
142 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
143 : #[derive(Clone)]
144 : pub struct CancelClosure {
145 : socket_addr: SocketAddr,
146 : cancel_token: CancelToken,
147 : }
148 :
149 : impl CancelClosure {
150 0 : pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
151 0 : Self {
152 0 : socket_addr,
153 0 : cancel_token,
154 0 : }
155 0 : }
156 : /// Cancels the query running on user's compute node.
157 0 : pub async fn try_cancel_query(self) -> Result<(), CancelError> {
158 0 : let socket = TcpStream::connect(self.socket_addr).await?;
159 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
160 0 : info!("query was cancelled");
161 0 : Ok(())
162 0 : }
163 : }
164 :
165 : /// Helper for registering query cancellation tokens.
166 : pub struct Session<P> {
167 : /// The user-facing key identifying this session.
168 : key: CancelKeyData,
169 : /// The [`CancelMap`] this session belongs to.
170 : cancellation_handler: Arc<CancellationHandler<P>>,
171 : }
172 :
173 : impl<P> Session<P> {
174 : /// Store the cancel token for the given session.
175 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
176 0 : pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
177 0 : info!("enabling query cancellation for this session");
178 0 : self.cancellation_handler
179 0 : .map
180 0 : .insert(self.key, Some(cancel_closure));
181 0 :
182 0 : self.key
183 0 : }
184 : }
185 :
186 : impl<P> Drop for Session<P> {
187 2 : fn drop(&mut self) {
188 2 : self.cancellation_handler.map.remove(&self.key);
189 2 : info!("dropped query cancellation key {}", &self.key);
190 2 : }
191 : }
192 :
#[cfg(test)]
mod tests {
    use super::*;

    use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS;

    /// A session's key is registered while the session is alive and removed
    /// again once the session is dropped.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            CancelMap::default(),
            NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Dropping the session must unregister its key.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling an unknown key with the no-op `()` publisher succeeds.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let handler = CancellationHandler::<()>::new(Default::default(), "local");
        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };

        let outcome = handler.cancel_session(unknown_key, Uuid::new_v4()).await;
        outcome.unwrap();
    }
}
|