Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use dashmap::DashMap;
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::{CancelToken, NoTls};
7 : use pq_proto::CancelKeyData;
8 : use thiserror::Error;
9 : use tokio::net::TcpStream;
10 : use tokio::sync::Mutex;
11 : use tracing::{debug, info};
12 : use uuid::Uuid;
13 :
14 : use crate::auth::{check_peer_addr_is_in_list, IpPattern};
15 : use crate::error::ReportableError;
16 : use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
17 : use crate::rate_limiter::LeakyBucketRateLimiter;
18 : use crate::redis::cancellation_publisher::{
19 : CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
20 : };
21 :
22 : pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
23 : pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
24 : pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
25 :
26 : type IpSubnetKey = IpNet;
27 :
28 : /// Enables serving `CancelRequest`s.
29 : ///
30 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
31 : pub struct CancellationHandler<P> {
32 : map: CancelMap,
33 : client: P,
34 : /// This field used for the monitoring purposes.
35 : /// Represents the source of the cancellation request.
36 : from: CancellationSource,
37 : // rate limiter of cancellation requests
38 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
39 : }
40 :
41 : #[derive(Debug, Error)]
42 : pub(crate) enum CancelError {
43 : #[error("{0}")]
44 : IO(#[from] std::io::Error),
45 :
46 : #[error("{0}")]
47 : Postgres(#[from] postgres_client::Error),
48 :
49 : #[error("rate limit exceeded")]
50 : RateLimit,
51 :
52 : #[error("IP is not allowed")]
53 : IpNotAllowed,
54 : }
55 :
56 : impl ReportableError for CancelError {
57 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
58 0 : match self {
59 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
60 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
61 0 : crate::error::ErrorKind::Postgres
62 : }
63 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
64 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
65 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
66 : }
67 0 : }
68 : }
69 :
70 : impl<P: CancellationPublisher> CancellationHandler<P> {
71 : /// Run async action within an ephemeral session identified by [`CancelKeyData`].
72 1 : pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
73 : // we intentionally generate a random "backend pid" and "secret key" here.
74 : // we use the corresponding u64 as an identifier for the
75 : // actual endpoint+pid+secret for postgres/pgbouncer.
76 : //
77 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
78 : // of overlap between our computes as most pids are small (~100).
79 1 : let key = loop {
80 1 : let key = rand::random();
81 1 :
82 1 : // Random key collisions are unlikely to happen here, but they're still possible,
83 1 : // which is why we have to take care not to rewrite an existing key.
84 1 : match self.map.entry(key) {
85 0 : dashmap::mapref::entry::Entry::Occupied(_) => continue,
86 1 : dashmap::mapref::entry::Entry::Vacant(e) => {
87 1 : e.insert(None);
88 1 : }
89 1 : }
90 1 : break key;
91 1 : };
92 1 :
93 1 : debug!("registered new query cancellation key {key}");
94 1 : Session {
95 1 : key,
96 1 : cancellation_handler: self,
97 1 : }
98 1 : }
99 :
100 : /// Try to cancel a running query for the corresponding connection.
101 : /// If the cancellation key is not found, it will be published to Redis.
102 : /// check_allowed - if true, check if the IP is allowed to cancel the query
103 : /// return Result primarily for tests
104 1 : pub(crate) async fn cancel_session(
105 1 : &self,
106 1 : key: CancelKeyData,
107 1 : session_id: Uuid,
108 1 : peer_addr: IpAddr,
109 1 : check_allowed: bool,
110 1 : ) -> Result<(), CancelError> {
111 1 : // TODO: check for unspecified address is only for backward compatibility, should be removed
112 1 : if !peer_addr.is_unspecified() {
113 1 : let subnet_key = match peer_addr {
114 1 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
115 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
116 : };
117 1 : if !self.limiter.lock().unwrap().check(subnet_key, 1) {
118 : // log only the subnet part of the IP address to know which subnet is rate limited
119 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
120 0 : Metrics::get()
121 0 : .proxy
122 0 : .cancellation_requests_total
123 0 : .inc(CancellationRequest {
124 0 : source: self.from,
125 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
126 0 : });
127 0 : return Err(CancelError::RateLimit);
128 1 : }
129 0 : }
130 :
131 : // NB: we should immediately release the lock after cloning the token.
132 1 : let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
133 1 : tracing::warn!("query cancellation key not found: {key}");
134 1 : Metrics::get()
135 1 : .proxy
136 1 : .cancellation_requests_total
137 1 : .inc(CancellationRequest {
138 1 : source: self.from,
139 1 : kind: crate::metrics::CancellationOutcome::NotFound,
140 1 : });
141 1 :
142 1 : if session_id == Uuid::nil() {
143 : // was already published, do not publish it again
144 0 : return Ok(());
145 1 : }
146 1 :
147 1 : match self.client.try_publish(key, session_id, peer_addr).await {
148 1 : Ok(()) => {} // do nothing
149 0 : Err(e) => {
150 0 : // log it here since cancel_session could be spawned in a task
151 0 : tracing::error!("failed to publish cancellation key: {key}, error: {e}");
152 0 : return Err(CancelError::IO(std::io::Error::new(
153 0 : std::io::ErrorKind::Other,
154 0 : e.to_string(),
155 0 : )));
156 : }
157 : }
158 1 : return Ok(());
159 : };
160 :
161 0 : if check_allowed
162 0 : && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
163 : {
164 : // log it here since cancel_session could be spawned in a task
165 0 : tracing::warn!("IP is not allowed to cancel the query: {key}");
166 0 : return Err(CancelError::IpNotAllowed);
167 0 : }
168 0 :
169 0 : Metrics::get()
170 0 : .proxy
171 0 : .cancellation_requests_total
172 0 : .inc(CancellationRequest {
173 0 : source: self.from,
174 0 : kind: crate::metrics::CancellationOutcome::Found,
175 0 : });
176 0 : info!("cancelling query per user's request using key {key}");
177 0 : cancel_closure.try_cancel_query().await
178 1 : }
179 :
180 : #[cfg(test)]
181 1 : fn contains(&self, session: &Session<P>) -> bool {
182 1 : self.map.contains_key(&session.key)
183 1 : }
184 :
185 : #[cfg(test)]
186 1 : fn is_empty(&self) -> bool {
187 1 : self.map.is_empty()
188 1 : }
189 : }
190 :
191 : impl CancellationHandler<()> {
192 2 : pub fn new(map: CancelMap, from: CancellationSource) -> Self {
193 2 : Self {
194 2 : map,
195 2 : client: (),
196 2 : from,
197 2 : limiter: Arc::new(std::sync::Mutex::new(
198 2 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
199 2 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
200 2 : 64,
201 2 : ),
202 2 : )),
203 2 : }
204 2 : }
205 : }
206 :
207 : impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
208 0 : pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
209 0 : Self {
210 0 : map,
211 0 : client,
212 0 : from,
213 0 : limiter: Arc::new(std::sync::Mutex::new(
214 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
215 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
216 0 : 64,
217 0 : ),
218 0 : )),
219 0 : }
220 0 : }
221 : }
222 :
223 : /// This should've been a [`std::future::Future`], but
224 : /// it's impossible to name a type of an unboxed future
225 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
226 : #[derive(Clone)]
227 : pub struct CancelClosure {
228 : socket_addr: SocketAddr,
229 : cancel_token: CancelToken,
230 : ip_allowlist: Vec<IpPattern>,
231 : }
232 :
233 : impl CancelClosure {
234 0 : pub(crate) fn new(
235 0 : socket_addr: SocketAddr,
236 0 : cancel_token: CancelToken,
237 0 : ip_allowlist: Vec<IpPattern>,
238 0 : ) -> Self {
239 0 : Self {
240 0 : socket_addr,
241 0 : cancel_token,
242 0 : ip_allowlist,
243 0 : }
244 0 : }
245 : /// Cancels the query running on user's compute node.
246 0 : pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
247 0 : let socket = TcpStream::connect(self.socket_addr).await?;
248 0 : self.cancel_token.cancel_query_raw(socket, NoTls).await?;
249 0 : debug!("query was cancelled");
250 0 : Ok(())
251 0 : }
252 0 : pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
253 0 : self.ip_allowlist = ip_allowlist;
254 0 : }
255 : }
256 :
257 : /// Helper for registering query cancellation tokens.
258 : pub(crate) struct Session<P> {
259 : /// The user-facing key identifying this session.
260 : key: CancelKeyData,
261 : /// The [`CancelMap`] this session belongs to.
262 : cancellation_handler: Arc<CancellationHandler<P>>,
263 : }
264 :
265 : impl<P> Session<P> {
266 : /// Store the cancel token for the given session.
267 : /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
268 0 : pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
269 0 : debug!("enabling query cancellation for this session");
270 0 : self.cancellation_handler
271 0 : .map
272 0 : .insert(self.key, Some(cancel_closure));
273 0 :
274 0 : self.key
275 0 : }
276 : }
277 :
278 : impl<P> Drop for Session<P> {
279 1 : fn drop(&mut self) {
280 1 : self.cancellation_handler.map.remove(&self.key);
281 1 : debug!("dropped query cancellation key {}", &self.key);
282 1 : }
283 : }
284 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping a session must remove its key from the handler's map.
    #[tokio::test]
    async fn check_session_drop() -> anyhow::Result<()> {
        let handler = Arc::new(CancellationHandler::<()>::new(
            CancelMap::default(),
            CancellationSource::FromRedis,
        ));

        let session = Arc::clone(&handler).get_session();
        assert!(handler.contains(&session));

        // Check that the session has been dropped.
        drop(session);
        assert!(handler.is_empty());

        Ok(())
    }

    /// Cancelling with a key that was never registered must succeed without doing anything.
    #[tokio::test]
    async fn cancel_session_noop_regression() {
        let unknown_key = CancelKeyData {
            backend_pid: 0,
            cancel_key: 0,
        };
        let peer: IpAddr = "127.0.0.1".parse().unwrap();

        let handler =
            CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
        let outcome = handler
            .cancel_session(unknown_key, Uuid::new_v4(), peer, true)
            .await;

        outcome.unwrap();
    }
}
|