Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::tls::MakeTlsConnect;
7 : use postgres_client::CancelToken;
8 : use pq_proto::CancelKeyData;
9 : use thiserror::Error;
10 : use tokio::net::TcpStream;
11 : use tokio::sync::Mutex;
12 : use tracing::{debug, info};
13 : use uuid::Uuid;
14 :
15 : use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
16 : use crate::auth::{check_peer_addr_is_in_list, AuthError, IpPattern};
17 : use crate::config::ComputeConfig;
18 : use crate::context::RequestContext;
19 : use crate::error::ReportableError;
20 : use crate::ext::LockExt;
21 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
22 : use crate::rate_limiter::LeakyBucketRateLimiter;
23 : use crate::redis::cancellation_publisher::{
24 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
25 : };
26 : use crate::tls::postgres_rustls::MakeRustlsConnect;
27 :
28 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
29 : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
30 : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
31 :
32 : type IpSubnetKey = IpNet;
33 :
34 : /// Enables serving `CancelRequest`s.
35 : ///
36 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
37 : pub struct CancellationHandler<P> {
38 : compute_config: &'static ComputeConfig,
39 : map: CancelMap,
40 : client: P,
41 : /// This field used for the monitoring purposes.
42 : /// Represents the source of the cancellation request.
43 : from: CancellationSource,
44 : // rate limiter of cancellation requests
45 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
46 : }
47 :
48 : #[derive(Debug, Error)]
49 : pub(crate) enum CancelError {
50 : #[error("{0}")]
51 : IO(#[from] std::io::Error),
52 :
53 : #[error("{0}")]
54 : Postgres(#[from] postgres_client::Error),
55 :
56 : #[error("rate limit exceeded")]
57 : RateLimit,
58 :
59 : #[error("IP is not allowed")]
60 : IpNotAllowed,
61 :
62 : #[error("Authentication backend error")]
63 : AuthError(#[from] AuthError),
64 : }
65 :
66 : impl ReportableError for CancelError {
67 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
68 0 : match self {
69 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
70 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
71 0 : crate::error::ErrorKind::Postgres
72 : }
73 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
74 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
75 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
76 0 : CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
77 : }
78 0 : }
79 : }
80 :
81 : impl<P: CancellationPublisher> CancellationHandler<P> {
82 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
83 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
84 : // we intentionally generate a random "backend pid" and "secret key" here.
85 : // we use the corresponding u64 as an identifier for the
86 : // actual endpoint+pid+secret for postgres/pgbouncer.
87 : //
88 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
89 : // of overlap between our computes as most pids are small (~100).
90 1 : let key = loop {
91 1 : let key = rand::random();
92 1 :
93 1 : // Random key collisions are unlikely to happen here, but they're still possible,
94 1 : // which is why we have to take care not to rewrite an existing key.
95 1 : match self.map.entry(key) {
96 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
97 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
98 1 : e.insert(None);
99 1 : }
100 1 : }
101 1 : break key;
102 1 : };
103 1 :
104 1 : debug!("registered new query cancellation key {key}");
105 1 : Session {
106 1 : key,
107 1 : cancellation_handler: self,
108 1 : }
109 1 : }
110 :
111 : /// Cancelling only in notification, will be removed
112 1 : pub(crate) async fn cancel_session(
113 1 : &self,
114 1 : key: CancelKeyData,
115 1 : session_id: Uuid,
116 1 : peer_addr: IpAddr,
117 1 : check_allowed: bool,
118 1 : ) -> Result<(), CancelError> {
119 1 : // TODO: check for unspecified address is only for backward compatibility, should be removed
120 1 : if !peer_addr.is_unspecified() {
121 1 : let subnet_key = match peer_addr {
122 1 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
123 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
124 : };
125 1 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
126 : // log only the subnet part of the IP address to know which subnet is rate limited
127 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
128 0 : Metrics::get()
129 0 : .proxy
130 0 : .cancellation_requests_total
131 0 : .inc(CancellationRequest {
132 0 : source: self.from,
133 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
134 0 : });
135 0 : return Err(CancelError::RateLimit);
136 1 : }
137 0 : }
138 :
139 : // NB: we should immediately release the lock after cloning the token.
140 1 : let cancel_state = self.map.get(&key).and_then(|x| x.clone());
141 1 : let Some(cancel_closure) = cancel_state else {
142 1 : tracing::warn!("query cancellation key not found: {key}");
143 1 : Metrics::get()
144 1 : .proxy
145 1 : .cancellation_requests_total
146 1 : .inc(CancellationRequest {
147 1 : source: self.from,
148 1 : kind: crate::metrics::CancellationOutcome::NotFound,
149 1 : });
150 1 :
151 1 : if session_id == Uuid::nil() {
152 : // was already published, do not publish it again
153 0 : return Ok(());
154 1 : }
155 1 :
156 1 : match self.client.try_publish(key, session_id, peer_addr).await {
157 1 : Ok(()) => {} // do nothing
158 0 : Err(e) => {
159 0 : // log it here since cancel_session could be spawned in a task
160 0 : tracing::error!("failed to publish cancellation key: {key}, error: {e}");
161 0 : return Err(CancelError::IO(std::io::Error::new(
162 0 : std::io::ErrorKind::Other,
163 0 : e.to_string(),
164 0 : )));
165 : }
166 : }
167 1 : return Ok(());
168 : };
169 :
170 0 : if check_allowed
171 0 : && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
172 : {
173 : // log it here since cancel_session could be spawned in a task
174 0 : tracing::warn!("IP is not allowed to cancel the query: {key}");
175 0 : return Err(CancelError::IpNotAllowed);
176 0 : }
177 0 :
178 0 : Metrics::get()
179 0 : .proxy
180 0 : .cancellation_requests_total
181 0 : .inc(CancellationRequest {
182 0 : source: self.from,
183 0 : kind: crate::metrics::CancellationOutcome::Found,
184 0 : });
185 0 : info!(
186 0 : "cancelling query per user's request using key {key}, hostname {}, address: {}",
187 : cancel_closure.hostname, cancel_closure.socket_addr
188 : );
189 0 : cancel_closure.try_cancel_query(self.compute_config).await
190 1 : }
191 :
192 : /// Try to cancel a running query for the corresponding connection.
193 : /// If the cancellation key is not found, it will be published to Redis.
194 : /// check_allowed - if true, check if the IP is allowed to cancel the query.
195 : /// Will fetch IP allowlist internally.
196 : ///
197 : /// return Result primarily for tests
198 0 : pub(crate) async fn cancel_session_auth<T: BackendIpAllowlist>(
199 0 : &self,
200 0 : key: CancelKeyData,
201 0 : ctx: RequestContext,
202 0 : check_allowed: bool,
203 0 : auth_backend: &T,
204 0 : ) -> Result<(), CancelError> {
205 0 : // TODO: check for unspecified address is only for backward compatibility, should be removed
206 0 : if !ctx.peer_addr().is_unspecified() {
207 0 : let subnet_key = match ctx.peer_addr() {
208 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
209 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
210 : };
211 0 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
212 : // log only the subnet part of the IP address to know which subnet is rate limited
213 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
214 0 : Metrics::get()
215 0 : .proxy
216 0 : .cancellation_requests_total
217 0 : .inc(CancellationRequest {
218 0 : source: self.from,
219 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
220 0 : });
221 0 : return Err(CancelError::RateLimit);
222 0 : }
223 0 : }
224 :
225 : // NB: we should immediately release the lock after cloning the token.
226 0 : let cancel_state = self.map.get(&key).and_then(|x| x.clone());
227 0 : let Some(cancel_closure) = cancel_state else {
228 0 : tracing::warn!("query cancellation key not found: {key}");
229 0 : Metrics::get()
230 0 : .proxy
231 0 : .cancellation_requests_total
232 0 : .inc(CancellationRequest {
233 0 : source: self.from,
234 0 : kind: crate::metrics::CancellationOutcome::NotFound,
235 0 : });
236 0 :
237 0 : if ctx.session_id() == Uuid::nil() {
238 : // was already published, do not publish it again
239 0 : return Ok(());
240 0 : }
241 0 :
242 0 : match self
243 0 : .client
244 0 : .try_publish(key, ctx.session_id(), ctx.peer_addr())
245 0 : .await
246 : {
247 0 : Ok(()) => {} // do nothing
248 0 : Err(e) => {
249 0 : // log it here since cancel_session could be spawned in a task
250 0 : tracing::error!("failed to publish cancellation key: {key}, error: {e}");
251 0 : return Err(CancelError::IO(std::io::Error::new(
252 0 : std::io::ErrorKind::Other,
253 0 : e.to_string(),
254 0 : )));
255 : }
256 : }
257 0 : return Ok(());
258 : };
259 :
260 0 : let ip_allowlist = auth_backend
261 0 : .get_allowed_ips(&ctx, &cancel_closure.user_info)
262 0 : .await
263 0 : .map_err(CancelError::AuthError)?;
264 :
265 0 : if check_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
266 : // log it here since cancel_session could be spawned in a task
267 0 : tracing::warn!("IP is not allowed to cancel the query: {key}");
268 0 : return Err(CancelError::IpNotAllowed);
269 0 : }
270 0 :
271 0 : Metrics::get()
272 0 : .proxy
273 0 : .cancellation_requests_total
274 0 : .inc(CancellationRequest {
275 0 : source: self.from,
276 0 : kind: crate::metrics::CancellationOutcome::Found,
277 0 : });
278 0 : info!("cancelling query per user's request using key {key}");
279 0 : cancel_closure.try_cancel_query(self.compute_config).await
280 0 : }
281 :
282 : #[cfg(test)]
283 1 : fn contains(&self, session: &Session<P>) -> bool {
284 1 : self.map.contains_key(&session.key)
285 1 : }
286 :
287 : #[cfg(test)]
288 1 : fn is_empty(&self) -> bool {
289 1 : self.map.is_empty()
290 1 : }
291 : }
292 :
293 : impl CancellationHandler<()> {
294 2 : pub fn new(
295 2 : compute_config: &'static ComputeConfig,
296 2 : map: CancelMap,
297 2 : from: CancellationSource,
298 2 : ) -> Self {
299 2 : Self {
300 2 : compute_config,
301 2 : map,
302 2 : client: (),
303 2 : from,
304 2 : limiter: Arc::new(std::sync::Mutex::new(
305 2 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
306 2 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
307 2 : 64,
308 2 : ),
309 2 : )),
310 2 : }
311 2 : }
312 : }
313 :
314 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
315 0 : pub fn new(
316 0 : compute_config: &'static ComputeConfig,
317 0 : map: CancelMap,
318 0 : client: Option<Arc<Mutex<P>>>,
319 0 : from: CancellationSource,
320 0 : ) -> Self {
321 0 : Self {
322 0 : compute_config,
323 0 : map,
324 0 : client,
325 0 : from,
326 0 : limiter: Arc::new(std::sync::Mutex::new(
327 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
328 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
329 0 : 64,
330 0 : ),
331 0 : )),
332 0 : }
333 0 : }
334 : }
335 :
336 : /// This should've been a [`std::future::Future`], but
337 : /// it's impossible to name a type of an unboxed future
338 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
339 : #[derive(Clone)]
340 : pub struct CancelClosure {
341 : socket_addr: SocketAddr,
342 : cancel_token: CancelToken,
343 : ip_allowlist: Vec<IpPattern>,
344 : hostname: String, // for pg_sni router
345 : user_info: ComputeUserInfo,
346 : }
347 :
348 : impl CancelClosure {
349 0 : pub(crate) fn new(
350 0 : socket_addr: SocketAddr,
351 0 : cancel_token: CancelToken,
352 0 : ip_allowlist: Vec<IpPattern>,
353 0 : hostname: String,
354 0 : user_info: ComputeUserInfo,
355 0 : ) -> Self {
356 0 : Self {
357 0 : socket_addr,
358 0 : cancel_token,
359 0 : ip_allowlist,
360 0 : hostname,
361 0 : user_info,
362 0 : }
363 0 : }
364 : /// Cancels the query running on user's compute node.
365 0 : pub(crate) async fn try_cancel_query(
366 0 : self,
367 0 : compute_config: &ComputeConfig,
368 0 : ) -> Result<(), CancelError> {
369 0 : let socket = TcpStream::connect(self.socket_addr).await?;
370 :
371 0 : let mut mk_tls =
372 0 : crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
373 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
374 0 : &mut mk_tls,
375 0 : &self.hostname,
376 0 : )
377 0 : .map_err(|e| {
378 0 : CancelError::IO(std::io::Error::new(
379 0 : std::io::ErrorKind::Other,
380 0 : e.to_string(),
381 0 : ))
382 0 : })?;
383 :
384 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
385 0 : debug!("query was cancelled");
386 0 : Ok(())
387 0 : }
388 :
389 : /// Obsolete (will be removed after moving CancelMap to Redis), only for notifications
390 0 : pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
391 0 : self.ip_allowlist = ip_allowlist;
392 0 : }
393 : }
394 :
395 : /// Helper for registering query cancellation tokens.
396 : pub(crate) struct Session<P> {
397 : /// The user-facing key identifying this session.
398 : key: CancelKeyData,
399 : /// The [`CancelMap`] this session belongs to.
400 : cancellation_handler: Arc<CancellationHandler<P>>,
401 : }
402 :
403 : impl<P> Session<P> {
404 : /// Store the cancel token for the given session.
405 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
406 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
407 0 : debug!("enabling query cancellation for this session");
408 0 : self.cancellation_handler
409 0 : .map
410 0 : .insert(self.key, Some(cancel_closure));
411 0 :
412 0 : self.key
413 0 : }
414 : }
415 :
416 : impl<P> Drop for Session<P> {
417 1 : fn drop(&mut self) {
418 1 : self.cancellation_handler.map.remove(&self.key);
419 1 : debug!("dropped query cancellation key {}", &self.key);
420 1 : }
421 : }
422 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::config::RetryConfig;
    use crate::tls::client_config::compute_client_config_with_certs;

    /// Builds the compute configuration shared by the tests below.
    fn config() -> ComputeConfig {
        ComputeConfig {
            retry: RetryConfig {
                base_delay: Duration::from_secs(1),
                max_retries: 5,
                backoff_factor: 2.0,
            },
            tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
            timeout: Duration::from_secs(2),
        }
    }

    /// Dropping a session must remove its key from the handler's map.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let compute_config: &'static ComputeConfig = Box::leak(Box::new(config()));
        let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
            compute_config,
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = cancellation_handler.clone().get_session();
        assert!(cancellation_handler.contains(&session));

        drop(session);
        // Check that the session has been dropped.
        assert!(cancellation_handler.is_empty());

        Ok(())
    }

    /// Cancelling with an unknown key must succeed without doing anything.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let compute_config: &'static ComputeConfig = Box::leak(Box::new(config()));
        let handler = CancellationHandler::<()>::new(
            compute_config,
            CancelMap::default(),
            CancellationSource::Local,
        );

        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };
        let peer_addr: IpAddr = "127.0.0.1".parse().unwrap();

        handler
            .cancel_session(unknown_key, Uuid::new_v4(), peer_addr, true)
            .await
            .unwrap();
    }
}
|