Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use once_cell::sync::OnceCell;
7 : use postgres_client::tls::MakeTlsConnect;
8 : use postgres_client::CancelToken;
9 : use pq_proto::CancelKeyData;
10 : use rustls::crypto::ring;
11 : use thiserror::Error;
12 : use tokio::net::TcpStream;
13 : use tokio::sync::Mutex;
14 : use tracing::{debug, info};
15 : use uuid::Uuid;
16 :
17 : use crate::auth::{check_peer_addr_is_in_list, IpPattern};
18 : use crate::compute::load_certs;
19 : use crate::error::ReportableError;
20 : use crate::ext::LockExt;
21 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
22 : use crate::postgres_rustls::MakeRustlsConnect;
23 : use crate::rate_limiter::LeakyBucketRateLimiter;
24 : use crate::redis::cancellation_publisher::{
25 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
26 : };
27 :
/// Shared map from the client-facing cancellation key to the closure able to
/// cancel the corresponding query. A `None` value means the key has been
/// reserved by `CancellationHandler::get_session` but query cancellation has
/// not been enabled for it yet (see `Session::enable_query_cancellation`).
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
/// The cancellation handler used by the proxy, with an optional Redis publisher.
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
/// The publisher type parameter of [`CancellationHandlerMain`].
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;

/// Key of the cancellation rate limiter: the peer address truncated to a subnet
/// (see `CancellationHandler::cancel_session`).
type IpSubnetKey = IpNet;
33 :
/// Enables serving `CancelRequest`s.
///
/// If a `CancellationPublisher` is available, cancel requests whose key is not
/// registered locally are published so that other proxy instances can serve them.
pub struct CancellationHandler<P> {
    /// Keys registered by sessions on this proxy instance.
    map: CancelMap,
    /// Publisher used to forward cancel requests for keys not found in `map`.
    client: P,
    /// This field is used for monitoring purposes.
    /// Represents the source of the cancellation request.
    from: CancellationSource,
    /// Rate limiter of cancellation requests, keyed by the peer's subnet.
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
}
46 :
/// Errors that can occur while serving a cancellation request.
#[derive(Debug, Error)]
pub(crate) enum CancelError {
    /// I/O failure, e.g. connecting to the compute node, TLS setup, or
    /// publishing the cancellation key.
    #[error("{0}")]
    IO(#[from] std::io::Error),

    /// Error returned by the `postgres_client` crate.
    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    /// The peer's subnet exceeded the cancellation rate limit.
    #[error("rate limit exceeded")]
    RateLimit,

    /// The peer address is not in the session's IP allowlist.
    #[error("IP is not allowed")]
    IpNotAllowed,
}
61 :
62 : impl ReportableError for CancelError {
63 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
64 0 : match self {
65 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
66 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
67 0 : crate::error::ErrorKind::Postgres
68 : }
69 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
70 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
71 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
72 : }
73 0 : }
74 : }
75 :
76 : impl<P: CancellationPublisher> CancellationHandler<P> {
77 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
78 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
79 : // we intentionally generate a random "backend pid" and "secret key" here.
80 : // we use the corresponding u64 as an identifier for the
81 : // actual endpoint+pid+secret for postgres/pgbouncer.
82 : //
83 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
84 : // of overlap between our computes as most pids are small (~100).
85 1 : let key = loop {
86 1 : let key = rand::random();
87 1 :
88 1 : // Random key collisions are unlikely to happen here, but they're still possible,
89 1 : // which is why we have to take care not to rewrite an existing key.
90 1 : match self.map.entry(key) {
91 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
92 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
93 1 : e.insert(None);
94 1 : }
95 1 : }
96 1 : break key;
97 1 : };
98 1 :
99 1 : debug!("registered new query cancellation key {key}");
100 1 : Session {
101 1 : key,
102 1 : cancellation_handler: self,
103 1 : }
104 1 : }
105 :
106 : /// Try to cancel a running query for the corresponding connection.
107 : /// If the cancellation key is not found, it will be published to Redis.
108 : /// check_allowed - if true, check if the IP is allowed to cancel the query
109 : /// return Result primarily for tests
110 1 : pub(crate) async fn cancel_session(
111 1 : &self,
112 1 : key: CancelKeyData,
113 1 : session_id: Uuid,
114 1 : peer_addr: IpAddr,
115 1 : check_allowed: bool,
116 1 : ) -> Result<(), CancelError> {
117 1 : // TODO: check for unspecified address is only for backward compatibility, should be removed
118 1 : if !peer_addr.is_unspecified() {
119 1 : let subnet_key = match peer_addr {
120 1 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
121 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
122 : };
123 1 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
124 : // log only the subnet part of the IP address to know which subnet is rate limited
125 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
126 0 : Metrics::get()
127 0 : .proxy
128 0 : .cancellation_requests_total
129 0 : .inc(CancellationRequest {
130 0 : source: self.from,
131 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
132 0 : });
133 0 : return Err(CancelError::RateLimit);
134 1 : }
135 0 : }
136 :
137 : // NB: we should immediately release the lock after cloning the token.
138 1 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
139 1 : tracing::warn!("query cancellation key not found: {key}");
140 1 : Metrics::get()
141 1 : .proxy
142 1 : .cancellation_requests_total
143 1 : .inc(CancellationRequest {
144 1 : source: self.from,
145 1 : kind: crate::metrics::CancellationOutcome::NotFound,
146 1 : });
147 1 :
148 1 : if session_id == Uuid::nil() {
149 : // was already published, do not publish it again
150 0 : return Ok(());
151 1 : }
152 1 :
153 1 : match self.client.try_publish(key, session_id, peer_addr).await {
154 1 : Ok(()) => {} // do nothing
155 0 : Err(e) => {
156 0 : // log it here since cancel_session could be spawned in a task
157 0 : tracing::error!("failed to publish cancellation key: {key}, error: {e}");
158 0 : return Err(CancelError::IO(std::io::Error::new(
159 0 : std::io::ErrorKind::Other,
160 0 : e.to_string(),
161 0 : )));
162 : }
163 : }
164 1 : return Ok(());
165 : };
166 :
167 0 : if check_allowed
168 0 : && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
169 : {
170 : // log it here since cancel_session could be spawned in a task
171 0 : tracing::warn!("IP is not allowed to cancel the query: {key}");
172 0 : return Err(CancelError::IpNotAllowed);
173 0 : }
174 0 :
175 0 : Metrics::get()
176 0 : .proxy
177 0 : .cancellation_requests_total
178 0 : .inc(CancellationRequest {
179 0 : source: self.from,
180 0 : kind: crate::metrics::CancellationOutcome::Found,
181 0 : });
182 0 : info!(
183 0 : "cancelling query per user's request using key {key}, hostname {}, address: {}",
184 : cancel_closure.hostname, cancel_closure.socket_addr
185 : );
186 0 : cancel_closure.try_cancel_query().await
187 1 : }
188 :
189 : #[cfg(test)]
190 1 : fn contains(&self, session: &Session<P>) -> bool {
191 1 : self.map.contains_key(&session.key)
192 1 : }
193 :
194 : #[cfg(test)]
195 1 : fn is_empty(&self) -> bool {
196 1 : self.map.is_empty()
197 1 : }
198 : }
199 :
200 : impl CancellationHandler<()> {
201 2 : pub fn new(map: CancelMap, from: CancellationSource) -> Self {
202 2 : Self {
203 2 : map,
204 2 : client: (),
205 2 : from,
206 2 : limiter: Arc::new(std::sync::Mutex::new(
207 2 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
208 2 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
209 2 : 64,
210 2 : ),
211 2 : )),
212 2 : }
213 2 : }
214 : }
215 :
216 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
217 0 : pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
218 0 : Self {
219 0 : map,
220 0 : client,
221 0 : from,
222 0 : limiter: Arc::new(std::sync::Mutex::new(
223 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
224 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
225 0 : 64,
226 0 : ),
227 0 : )),
228 0 : }
229 0 : }
230 : }
231 :
/// TLS root certificates, loaded once via `load_certs` on first use by
/// `CancelClosure::try_cancel_query` and shared afterwards.
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
233 :
/// Everything needed to send a cancel request to a user's compute node.
///
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
    /// Address of the compute node the cancel request is sent to.
    socket_addr: SocketAddr,
    /// Token passed to `cancel_query_raw` to issue the cancel request.
    cancel_token: CancelToken,
    /// Peer addresses allowed to cancel, checked when `check_allowed` is set.
    ip_allowlist: Vec<IpPattern>,
    /// Hostname passed to the TLS connector when connecting to the compute.
    hostname: String, // for pg_sni router
}
244 :
245 : impl CancelClosure {
246 0 : pub(crate) fn new(
247 0 : socket_addr: SocketAddr,
248 0 : cancel_token: CancelToken,
249 0 : ip_allowlist: Vec<IpPattern>,
250 0 : hostname: String,
251 0 : ) -> Self {
252 0 : Self {
253 0 : socket_addr,
254 0 : cancel_token,
255 0 : ip_allowlist,
256 0 : hostname,
257 0 : }
258 0 : }
259 : /// Cancels the query running on user's compute node.
260 0 : pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
261 0 : let socket = TcpStream::connect(self.socket_addr).await?;
262 :
263 0 : let root_store = TLS_ROOTS
264 0 : .get_or_try_init(load_certs)
265 0 : .map_err(|_e| {
266 0 : CancelError::IO(std::io::Error::new(
267 0 : std::io::ErrorKind::Other,
268 0 : "TLS root store initialization failed".to_string(),
269 0 : ))
270 0 : })?
271 0 : .clone();
272 0 :
273 0 : let client_config =
274 0 : rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
275 0 : .with_safe_default_protocol_versions()
276 0 : .expect("ring should support the default protocol versions")
277 0 : .with_root_certificates(root_store)
278 0 : .with_no_client_auth();
279 0 :
280 0 : let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
281 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
282 0 : &mut mk_tls,
283 0 : &self.hostname,
284 0 : )
285 0 : .map_err(|e| {
286 0 : CancelError::IO(std::io::Error::new(
287 0 : std::io::ErrorKind::Other,
288 0 : e.to_string(),
289 0 : ))
290 0 : })?;
291 :
292 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
293 0 : debug!("query was cancelled");
294 0 : Ok(())
295 0 : }
296 0 : pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
297 0 : self.ip_allowlist = ip_allowlist;
298 0 : }
299 : }
300 :
/// Helper for registering query cancellation tokens.
///
/// Owns a key reserved in the handler's [`CancelMap`]; the key is removed from
/// the map when the `Session` is dropped.
pub(crate) struct Session<P> {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// The handler whose [`CancelMap`] this session's key is registered in.
    cancellation_handler: Arc<CancellationHandler<P>>,
}
308 :
309 : impl<P> Session<P> {
310 : /// Store the cancel token for the given session.
311 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
312 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
313 0 : debug!("enabling query cancellation for this session");
314 0 : self.cancellation_handler
315 0 : .map
316 0 : .insert(self.key, Some(cancel_closure));
317 0 :
318 0 : self.key
319 0 : }
320 : }
321 :
322 : impl<P> Drop for Session<P> {
323 1 : fn drop(&mut self) {
324 1 : self.cancellation_handler.map.remove(&self.key);
325 1 : debug!("dropped query cancellation key {}", &self.key);
326 1 : }
327 : }
328 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Dropping a `Session` must unregister its key.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        drop(session);
        // The key must be gone once its session is dropped.
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling an unregistered key with the unit publisher must succeed.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let handler =
            CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };
        let peer: IpAddr = "127.0.0.1".parse().unwrap();

        handler
            .cancel_session(unknown_key, Uuid::new_v4(), peer, true)
            .await
            .unwrap();
    }
}
|