Line data Source code
1 : use std::net::SocketAddr;
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use pq_proto::CancelKeyData;
6 : use thiserror::Error;
7 : use tokio::net::TcpStream;
8 : use tokio::sync::Mutex;
9 : use tokio_postgres::{CancelToken, NoTls};
10 : use tracing::{debug, info};
11 : use uuid::Uuid;
12 :
13 : use crate::error::ReportableError;
14 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
15 : use crate::redis::cancellation_publisher::{
16 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
17 : };
18 :
19 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
20 : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
21 : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
22 :
23 : /// Enables serving `CancelRequest`s.
24 : ///
25 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
26 : pub struct CancellationHandler<P> {
27 : map: CancelMap,
28 : client: P,
29 : /// This field used for the monitoring purposes.
30 : /// Represents the source of the cancellation request.
31 : from: CancellationSource,
32 : }
33 :
34 0 : #[derive(Debug, Error)]
35 : pub(crate) enum CancelError {
36 : #[error("{0}")]
37 : IO(#[from] std::io::Error),
38 : #[error("{0}")]
39 : Postgres(#[from] tokio_postgres::Error),
40 : }
41 :
42 : impl ReportableError for CancelError {
43 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
44 0 : match self {
45 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
46 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
47 0 : crate::error::ErrorKind::Postgres
48 : }
49 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
50 : }
51 0 : }
52 : }
53 :
54 : impl<P: CancellationPublisher> CancellationHandler<P> {
55 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
56 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
57 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
58 : // expose it and we don't want to do another roundtrip to query
59 : // for it. The client will be able to notice that this is not the
60 : // actual backend_pid, but backend_pid is not used for anything
61 : // so it doesn't matter.
62 1 : let key = loop {
63 1 : let key = rand::random();
64 1 :
65 1 : // Random key collisions are unlikely to happen here, but they're still possible,
66 1 : // which is why we have to take care not to rewrite an existing key.
67 1 : match self.map.entry(key) {
68 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
69 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
70 1 : e.insert(None);
71 1 : }
72 1 : }
73 1 : break key;
74 1 : };
75 1 :
76 1 : debug!("registered new query cancellation key {key}");
77 1 : Session {
78 1 : key,
79 1 : cancellation_handler: self,
80 1 : }
81 1 : }
82 : /// Try to cancel a running query for the corresponding connection.
83 : /// If the cancellation key is not found, it will be published to Redis.
84 1 : pub(crate) async fn cancel_session(
85 1 : &self,
86 1 : key: CancelKeyData,
87 1 : session_id: Uuid,
88 1 : ) -> Result<(), CancelError> {
89 : // NB: we should immediately release the lock after cloning the token.
90 1 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
91 1 : tracing::warn!("query cancellation key not found: {key}");
92 1 : Metrics::get()
93 1 : .proxy
94 1 : .cancellation_requests_total
95 1 : .inc(CancellationRequest {
96 1 : source: self.from,
97 1 : kind: crate::metrics::CancellationOutcome::NotFound,
98 1 : });
99 1 : match self.client.try_publish(key, session_id).await {
100 1 : Ok(()) => {} // do nothing
101 0 : Err(e) => {
102 0 : return Err(CancelError::IO(std::io::Error::new(
103 0 : std::io::ErrorKind::Other,
104 0 : e.to_string(),
105 0 : )));
106 : }
107 : }
108 1 : return Ok(());
109 : };
110 0 : Metrics::get()
111 0 : .proxy
112 0 : .cancellation_requests_total
113 0 : .inc(CancellationRequest {
114 0 : source: self.from,
115 0 : kind: crate::metrics::CancellationOutcome::Found,
116 0 : });
117 0 : info!("cancelling query per user's request using key {key}");
118 0 : cancel_closure.try_cancel_query().await
119 1 : }
120 :
121 : #[cfg(test)]
122 1 : fn contains(&self, session: &Session<P>) -> bool {
123 1 : self.map.contains_key(&session.key)
124 1 : }
125 :
126 : #[cfg(test)]
127 1 : fn is_empty(&self) -> bool {
128 1 : self.map.is_empty()
129 1 : }
130 : }
131 :
132 : impl CancellationHandler<()> {
133 2 : pub fn new(map: CancelMap, from: CancellationSource) -> Self {
134 2 : Self {
135 2 : map,
136 2 : client: (),
137 2 : from,
138 2 : }
139 2 : }
140 : }
141 :
142 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
143 0 : pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
144 0 : Self { map, client, from }
145 0 : }
146 : }
147 :
148 : /// This should've been a [`std::future::Future`], but
149 : /// it's impossible to name a type of an unboxed future
150 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
151 : #[derive(Clone)]
152 : pub struct CancelClosure {
153 : socket_addr: SocketAddr,
154 : cancel_token: CancelToken,
155 : }
156 :
157 : impl CancelClosure {
158 0 : pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
159 0 : Self {
160 0 : socket_addr,
161 0 : cancel_token,
162 0 : }
163 0 : }
164 : /// Cancels the query running on user's compute node.
165 0 : pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
166 0 : let socket = TcpStream::connect(self.socket_addr).await?;
167 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
168 0 : debug!("query was cancelled");
169 0 : Ok(())
170 0 : }
171 : }
172 :
173 : /// Helper for registering query cancellation tokens.
174 : pub(crate) struct Session<P> {
175 : /// The user-facing key identifying this session.
176 : key: CancelKeyData,
177 : /// The [`CancelMap`] this session belongs to.
178 : cancellation_handler: Arc<CancellationHandler<P>>,
179 : }
180 :
181 : impl<P> Session<P> {
182 : /// Store the cancel token for the given session.
183 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
184 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
185 0 : debug!("enabling query cancellation for this session");
186 0 : self.cancellation_handler
187 0 : .map
188 0 : .insert(self.key, Some(cancel_closure));
189 0 :
190 0 : self.key
191 0 : }
192 : }
193 :
194 : impl<P> Drop for Session<P> {
195 1 : fn drop(&mut self) {
196 1 : self.cancellation_handler.map.remove(&self.key);
197 1 : debug!("dropped query cancellation key {}", &self.key);
198 1 : }
199 : }
200 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping a session must remove its key from the shared map.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        drop(session);
        // Once the session is gone, no keys should remain registered.
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling a key that was never registered must not fail with the `()` publisher.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let handler =
            CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };

        handler
            .cancel_session(unknown_key, Uuid::new_v4())
            .await
            .unwrap();
    }
}
|