Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::tls::MakeTlsConnect;
7 : use postgres_client::CancelToken;
8 : use pq_proto::CancelKeyData;
9 : use thiserror::Error;
10 : use tokio::net::TcpStream;
11 : use tokio::sync::Mutex;
12 : use tracing::{debug, info};
13 : use uuid::Uuid;
14 :
15 : use crate::auth::{check_peer_addr_is_in_list, IpPattern};
16 : use crate::config::ComputeConfig;
17 : use crate::error::ReportableError;
18 : use crate::ext::LockExt;
19 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
20 : use crate::rate_limiter::LeakyBucketRateLimiter;
21 : use crate::redis::cancellation_publisher::{
22 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
23 : };
24 : use crate::tls::postgres_rustls::MakeRustlsConnect;
25 :
26 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
27 : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
28 : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
29 :
30 : type IpSubnetKey = IpNet;
31 :
32 : /// Enables serving `CancelRequest`s.
33 : ///
34 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
35 : pub struct CancellationHandler<P> {
36 : compute_config: &'static ComputeConfig,
37 : map: CancelMap,
38 : client: P,
39 : /// This field used for the monitoring purposes.
40 : /// Represents the source of the cancellation request.
41 : from: CancellationSource,
42 : // rate limiter of cancellation requests
43 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
44 : }
45 :
46 : #[derive(Debug, Error)]
47 : pub(crate) enum CancelError {
48 : #[error("{0}")]
49 : IO(#[from] std::io::Error),
50 :
51 : #[error("{0}")]
52 : Postgres(#[from] postgres_client::Error),
53 :
54 : #[error("rate limit exceeded")]
55 : RateLimit,
56 :
57 : #[error("IP is not allowed")]
58 : IpNotAllowed,
59 : }
60 :
61 : impl ReportableError for CancelError {
62 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
63 0 : match self {
64 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
65 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
66 0 : crate::error::ErrorKind::Postgres
67 : }
68 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
69 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
70 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
71 : }
72 0 : }
73 : }
74 :
75 : impl<P: CancellationPublisher> CancellationHandler<P> {
76 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
77 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
78 : // we intentionally generate a random "backend pid" and "secret key" here.
79 : // we use the corresponding u64 as an identifier for the
80 : // actual endpoint+pid+secret for postgres/pgbouncer.
81 : //
82 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
83 : // of overlap between our computes as most pids are small (~100).
84 1 : let key = loop {
85 1 : let key = rand::random();
86 1 :
87 1 : // Random key collisions are unlikely to happen here, but they're still possible,
88 1 : // which is why we have to take care not to rewrite an existing key.
89 1 : match self.map.entry(key) {
90 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
91 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
92 1 : e.insert(None);
93 1 : }
94 1 : }
95 1 : break key;
96 1 : };
97 1 :
98 1 : debug!("registered new query cancellation key {key}");
99 1 : Session {
100 1 : key,
101 1 : cancellation_handler: self,
102 1 : }
103 1 : }
104 :
105 : /// Try to cancel a running query for the corresponding connection.
106 : /// If the cancellation key is not found, it will be published to Redis.
107 : /// check_allowed - if true, check if the IP is allowed to cancel the query
108 : /// return Result primarily for tests
109 1 : pub(crate) async fn cancel_session(
110 1 : &self,
111 1 : key: CancelKeyData,
112 1 : session_id: Uuid,
113 1 : peer_addr: IpAddr,
114 1 : check_allowed: bool,
115 1 : ) -> Result<(), CancelError> {
116 1 : // TODO: check for unspecified address is only for backward compatibility, should be removed
117 1 : if !peer_addr.is_unspecified() {
118 1 : let subnet_key = match peer_addr {
119 1 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
120 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
121 : };
122 1 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
123 : // log only the subnet part of the IP address to know which subnet is rate limited
124 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
125 0 : Metrics::get()
126 0 : .proxy
127 0 : .cancellation_requests_total
128 0 : .inc(CancellationRequest {
129 0 : source: self.from,
130 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
131 0 : });
132 0 : return Err(CancelError::RateLimit);
133 1 : }
134 0 : }
135 :
136 : // NB: we should immediately release the lock after cloning the token.
137 1 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
138 1 : tracing::warn!("query cancellation key not found: {key}");
139 1 : Metrics::get()
140 1 : .proxy
141 1 : .cancellation_requests_total
142 1 : .inc(CancellationRequest {
143 1 : source: self.from,
144 1 : kind: crate::metrics::CancellationOutcome::NotFound,
145 1 : });
146 1 :
147 1 : if session_id == Uuid::nil() {
148 : // was already published, do not publish it again
149 0 : return Ok(());
150 1 : }
151 1 :
152 1 : match self.client.try_publish(key, session_id, peer_addr).await {
153 1 : Ok(()) => {} // do nothing
154 0 : Err(e) => {
155 0 : // log it here since cancel_session could be spawned in a task
156 0 : tracing::error!("failed to publish cancellation key: {key}, error: {e}");
157 0 : return Err(CancelError::IO(std::io::Error::new(
158 0 : std::io::ErrorKind::Other,
159 0 : e.to_string(),
160 0 : )));
161 : }
162 : }
163 1 : return Ok(());
164 : };
165 :
166 0 : if check_allowed
167 0 : && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
168 : {
169 : // log it here since cancel_session could be spawned in a task
170 0 : tracing::warn!("IP is not allowed to cancel the query: {key}");
171 0 : return Err(CancelError::IpNotAllowed);
172 0 : }
173 0 :
174 0 : Metrics::get()
175 0 : .proxy
176 0 : .cancellation_requests_total
177 0 : .inc(CancellationRequest {
178 0 : source: self.from,
179 0 : kind: crate::metrics::CancellationOutcome::Found,
180 0 : });
181 0 : info!(
182 0 : "cancelling query per user's request using key {key}, hostname {}, address: {}",
183 : cancel_closure.hostname, cancel_closure.socket_addr
184 : );
185 0 : cancel_closure.try_cancel_query(self.compute_config).await
186 1 : }
187 :
188 : #[cfg(test)]
189 1 : fn contains(&self, session: &Session<P>) -> bool {
190 1 : self.map.contains_key(&session.key)
191 1 : }
192 :
193 : #[cfg(test)]
194 1 : fn is_empty(&self) -> bool {
195 1 : self.map.is_empty()
196 1 : }
197 : }
198 :
199 : impl CancellationHandler<()> {
200 2 : pub fn new(
201 2 : compute_config: &'static ComputeConfig,
202 2 : map: CancelMap,
203 2 : from: CancellationSource,
204 2 : ) -> Self {
205 2 : Self {
206 2 : compute_config,
207 2 : map,
208 2 : client: (),
209 2 : from,
210 2 : limiter: Arc::new(std::sync::Mutex::new(
211 2 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
212 2 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
213 2 : 64,
214 2 : ),
215 2 : )),
216 2 : }
217 2 : }
218 : }
219 :
220 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
221 0 : pub fn new(
222 0 : compute_config: &'static ComputeConfig,
223 0 : map: CancelMap,
224 0 : client: Option<Arc<Mutex<P>>>,
225 0 : from: CancellationSource,
226 0 : ) -> Self {
227 0 : Self {
228 0 : compute_config,
229 0 : map,
230 0 : client,
231 0 : from,
232 0 : limiter: Arc::new(std::sync::Mutex::new(
233 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
234 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
235 0 : 64,
236 0 : ),
237 0 : )),
238 0 : }
239 0 : }
240 : }
241 :
242 : /// This should've been a [`std::future::Future`], but
243 : /// it's impossible to name a type of an unboxed future
244 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
245 : #[derive(Clone)]
246 : pub struct CancelClosure {
247 : socket_addr: SocketAddr,
248 : cancel_token: CancelToken,
249 : ip_allowlist: Vec<IpPattern>,
250 : hostname: String, // for pg_sni router
251 : }
252 :
253 : impl CancelClosure {
254 0 : pub(crate) fn new(
255 0 : socket_addr: SocketAddr,
256 0 : cancel_token: CancelToken,
257 0 : ip_allowlist: Vec<IpPattern>,
258 0 : hostname: String,
259 0 : ) -> Self {
260 0 : Self {
261 0 : socket_addr,
262 0 : cancel_token,
263 0 : ip_allowlist,
264 0 : hostname,
265 0 : }
266 0 : }
267 : /// Cancels the query running on user's compute node.
268 0 : pub(crate) async fn try_cancel_query(
269 0 : self,
270 0 : compute_config: &ComputeConfig,
271 0 : ) -> Result<(), CancelError> {
272 0 : let socket = TcpStream::connect(self.socket_addr).await?;
273 :
274 0 : let mut mk_tls =
275 0 : crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
276 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
277 0 : &mut mk_tls,
278 0 : &self.hostname,
279 0 : )
280 0 : .map_err(|e| {
281 0 : CancelError::IO(std::io::Error::new(
282 0 : std::io::ErrorKind::Other,
283 0 : e.to_string(),
284 0 : ))
285 0 : })?;
286 :
287 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
288 0 : debug!("query was cancelled");
289 0 : Ok(())
290 0 : }
291 0 : pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
292 0 : self.ip_allowlist = ip_allowlist;
293 0 : }
294 : }
295 :
296 : /// Helper for registering query cancellation tokens.
297 : pub(crate) struct Session<P> {
298 : /// The user-facing key identifying this session.
299 : key: CancelKeyData,
300 : /// The [`CancelMap`] this session belongs to.
301 : cancellation_handler: Arc<CancellationHandler<P>>,
302 : }
303 :
304 : impl<P> Session<P> {
305 : /// Store the cancel token for the given session.
306 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
307 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
308 0 : debug!("enabling query cancellation for this session");
309 0 : self.cancellation_handler
310 0 : .map
311 0 : .insert(self.key, Some(cancel_closure));
312 0 :
313 0 : self.key
314 0 : }
315 : }
316 :
317 : impl<P> Drop for Session<P> {
318 1 : fn drop(&mut self) {
319 1 : self.cancellation_handler.map.remove(&self.key);
320 1 : debug!("dropped query cancellation key {}", &self.key);
321 1 : }
322 : }
323 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::config::RetryConfig;
    use crate::tls::client_config::compute_client_config_with_certs;

    /// Compute config for tests: TLS with no extra certificates and short timeouts.
    fn config() -> ComputeConfig {
        ComputeConfig {
            retry: RetryConfig {
                base_delay: Duration::from_secs(1),
                max_retries: 5,
                backoff_factor: 2.0,
            },
            tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
            timeout: Duration::from_secs(2),
        }
    }

    /// Leak a test config to obtain the `&'static` reference the handler requires.
    fn leaked_config() -> &'static ComputeConfig {
        Box::leak(Box::new(config()))
    }

    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            leaked_config(),
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Dropping the session must unregister its key.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling an unknown key on a handler with a `()` publisher must return `Ok`.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let handler = CancellationHandler::<()>::new(
            leaked_config(),
            CancelMap::default(),
            CancellationSource::Local,
        );
        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };
        let peer: IpAddr = "127.0.0.1".parse().unwrap();

        handler
            .cancel_session(unknown_key, Uuid::new_v4(), peer, true)
            .await
            .unwrap();
    }
}
|