Line data Source code
1 : use std::net::SocketAddr;
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use pq_proto::CancelKeyData;
6 : use thiserror::Error;
7 : use tokio::net::TcpStream;
8 : use tokio::sync::Mutex;
9 : use tokio_postgres::{CancelToken, NoTls};
10 : use tracing::{debug, info};
11 : use uuid::Uuid;
12 :
13 : use crate::auth::{check_peer_addr_is_in_list, IpPattern};
14 : use crate::error::ReportableError;
15 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
16 : use crate::rate_limiter::LeakyBucketRateLimiter;
17 : use crate::redis::cancellation_publisher::{
18 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
19 : };
20 : use std::net::IpAddr;
21 :
22 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
23 :
/// Shared concurrent map from a cancellation key to the cancel closure registered for it.
///
/// The value is `None` while a session has reserved the key but has not yet
/// enabled query cancellation for it.
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
/// [`CancellationHandler`] whose publisher is an optional, mutex-guarded [`RedisPublisherClient`].
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
/// The publisher type parameter used by [`CancellationHandlerMain`].
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;

/// Key of the cancellation rate limiter: the IP subnet the request came from.
type IpSubnetKey = IpNet;
29 :
/// Enables serving `CancelRequest`s.
///
/// If a `CancellationPublisher` is available, a cancel request is used to publish
/// the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
    /// Cancellation keys registered by sessions, with their optional cancel closures.
    map: CancelMap,
    /// Publisher the cancel request is handed to when the key is not found in `map`.
    client: P,
    /// This field is used for monitoring purposes only.
    /// Represents the source of the cancellation request.
    from: CancellationSource,
    /// Rate limiter of cancellation requests, keyed by the peer's IP subnet.
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
}
42 :
/// Errors that can be returned while serving a query cancellation request.
#[derive(Debug, Error)]
pub(crate) enum CancelError {
    /// An I/O failure; also used to wrap a failed attempt to publish the cancel request.
    #[error("{0}")]
    IO(#[from] std::io::Error),

    /// An error reported by `tokio_postgres` while cancelling the query.
    #[error("{0}")]
    Postgres(#[from] tokio_postgres::Error),

    /// The rate limiter rejected the request.
    #[error("rate limit exceeded")]
    RateLimit,

    /// The peer address is not in the IP allowlist of the target session.
    #[error("IP is not allowed")]
    IpNotAllowed,
}
57 :
58 : impl ReportableError for CancelError {
59 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
60 0 : match self {
61 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
62 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
63 0 : crate::error::ErrorKind::Postgres
64 : }
65 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
66 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
67 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
68 : }
69 0 : }
70 : }
71 :
72 : impl<P: CancellationPublisher> CancellationHandler<P> {
73 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
74 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
75 : // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
76 : // expose it and we don't want to do another roundtrip to query
77 : // for it. The client will be able to notice that this is not the
78 : // actual backend_pid, but backend_pid is not used for anything
79 : // so it doesn't matter.
80 1 : let key = loop {
81 1 : let key = rand::random();
82 1 :
83 1 : // Random key collisions are unlikely to happen here, but they're still possible,
84 1 : // which is why we have to take care not to rewrite an existing key.
85 1 : match self.map.entry(key) {
86 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
87 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
88 1 : e.insert(None);
89 1 : }
90 1 : }
91 1 : break key;
92 1 : };
93 1 :
94 1 : debug!("registered new query cancellation key {key}");
95 1 : Session {
96 1 : key,
97 1 : cancellation_handler: self,
98 1 : }
99 1 : }
100 :
101 : /// Try to cancel a running query for the corresponding connection.
102 : /// If the cancellation key is not found, it will be published to Redis.
103 : /// check_allowed - if true, check if the IP is allowed to cancel the query
104 1 : pub(crate) async fn cancel_session(
105 1 : &self,
106 1 : key: CancelKeyData,
107 1 : session_id: Uuid,
108 1 : peer_addr: &IpAddr,
109 1 : check_allowed: bool,
110 1 : ) -> Result<(), CancelError> {
111 1 : // TODO: check for unspecified address is only for backward compatibility, should be removed
112 1 : if !peer_addr.is_unspecified() {
113 1 : let subnet_key = match *peer_addr {
114 1 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
115 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
116 : };
117 1 : if !self.limiter.lock().unwrap().check(subnet_key, 1) {
118 0 : tracing::debug!("Rate limit exceeded. Skipping cancellation message");
119 0 : Metrics::get()
120 0 : .proxy
121 0 : .cancellation_requests_total
122 0 : .inc(CancellationRequest {
123 0 : source: self.from,
124 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
125 0 : });
126 0 : return Err(CancelError::RateLimit);
127 1 : }
128 0 : }
129 :
130 : // NB: we should immediately release the lock after cloning the token.
131 1 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
132 1 : tracing::warn!("query cancellation key not found: {key}");
133 1 : Metrics::get()
134 1 : .proxy
135 1 : .cancellation_requests_total
136 1 : .inc(CancellationRequest {
137 1 : source: self.from,
138 1 : kind: crate::metrics::CancellationOutcome::NotFound,
139 1 : });
140 1 :
141 1 : if session_id == Uuid::nil() {
142 : // was already published, do not publish it again
143 0 : return Ok(());
144 1 : }
145 1 :
146 1 : match self.client.try_publish(key, session_id, *peer_addr).await {
147 1 : Ok(()) => {} // do nothing
148 0 : Err(e) => {
149 0 : return Err(CancelError::IO(std::io::Error::new(
150 0 : std::io::ErrorKind::Other,
151 0 : e.to_string(),
152 0 : )));
153 : }
154 : }
155 1 : return Ok(());
156 : };
157 :
158 0 : if check_allowed
159 0 : && !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
160 : {
161 0 : return Err(CancelError::IpNotAllowed);
162 0 : }
163 0 :
164 0 : Metrics::get()
165 0 : .proxy
166 0 : .cancellation_requests_total
167 0 : .inc(CancellationRequest {
168 0 : source: self.from,
169 0 : kind: crate::metrics::CancellationOutcome::Found,
170 0 : });
171 0 : info!("cancelling query per user's request using key {key}");
172 0 : cancel_closure.try_cancel_query().await
173 1 : }
174 :
175 : #[cfg(test)]
176 1 : fn contains(&self, session: &Session<P>) -> bool {
177 1 : self.map.contains_key(&session.key)
178 1 : }
179 :
180 : #[cfg(test)]
181 1 : fn is_empty(&self) -> bool {
182 1 : self.map.is_empty()
183 1 : }
184 : }
185 :
186 : impl CancellationHandler<()> {
187 2 : pub fn new(map: CancelMap, from: CancellationSource) -> Self {
188 2 : Self {
189 2 : map,
190 2 : client: (),
191 2 : from,
192 2 : limiter: Arc::new(std::sync::Mutex::new(
193 2 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
194 2 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
195 2 : 64,
196 2 : ),
197 2 : )),
198 2 : }
199 2 : }
200 : }
201 :
202 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
203 0 : pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
204 0 : Self {
205 0 : map,
206 0 : client,
207 0 : from,
208 0 : limiter: Arc::new(std::sync::Mutex::new(
209 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
210 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
211 0 : 64,
212 0 : ),
213 0 : )),
214 0 : }
215 0 : }
216 : }
217 :
/// Everything needed to cancel a query running on a compute node.
///
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
    /// Address a TCP connection is opened to when cancelling the query.
    socket_addr: SocketAddr,
    /// Token used to send the cancel request over that connection.
    cancel_token: CancelToken,
    /// IP patterns a peer address is checked against before cancellation is allowed.
    ip_allowlist: Vec<IpPattern>,
}
227 :
228 : impl CancelClosure {
229 0 : pub(crate) fn new(
230 0 : socket_addr: SocketAddr,
231 0 : cancel_token: CancelToken,
232 0 : ip_allowlist: Vec<IpPattern>,
233 0 : ) -> Self {
234 0 : Self {
235 0 : socket_addr,
236 0 : cancel_token,
237 0 : ip_allowlist,
238 0 : }
239 0 : }
240 : /// Cancels the query running on user's compute node.
241 0 : pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
242 0 : let socket = TcpStream::connect(self.socket_addr).await?;
243 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
244 0 : debug!("query was cancelled");
245 0 : Ok(())
246 0 : }
247 0 : pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
248 0 : self.ip_allowlist = ip_allowlist;
249 0 : }
250 : }
251 :
/// Helper for registering query cancellation tokens.
///
/// The session's key is removed from the handler's map when the session is dropped.
pub(crate) struct Session<P> {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// The [`CancellationHandler`] whose [`CancelMap`] holds this session's key.
    cancellation_handler: Arc<CancellationHandler<P>>,
}
259 :
260 : impl<P> Session<P> {
261 : /// Store the cancel token for the given session.
262 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
263 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
264 0 : debug!("enabling query cancellation for this session");
265 0 : self.cancellation_handler
266 0 : .map
267 0 : .insert(self.key, Some(cancel_closure));
268 0 :
269 0 : self.key
270 0 : }
271 : }
272 :
273 : impl<P> Drop for Session<P> {
274 1 : fn drop(&mut self) {
275 1 : self.cancellation_handler.map.remove(&self.key);
276 1 : debug!("dropped query cancellation key {}", &self.key);
277 1 : }
278 : }
279 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A session registers its key on creation and unregisters it when dropped.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Check that the session has been dropped.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling an unknown key with the `()` publisher must succeed.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let handler =
            CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);

        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };
        let peer: IpAddr = "127.0.0.1".parse().unwrap();

        let outcome = handler
            .cancel_session(unknown_key, Uuid::new_v4(), &peer, true)
            .await;
        outcome.unwrap();
    }
}
|