Line data Source code
1 : use std::convert::Infallible;
2 : use std::net::{IpAddr, SocketAddr};
3 : use std::pin::pin;
4 : use std::sync::{Arc, OnceLock};
5 : use std::time::Duration;
6 :
7 : use anyhow::anyhow;
8 : use futures::FutureExt;
9 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
10 : use postgres_client::RawCancelToken;
11 : use postgres_client::tls::MakeTlsConnect;
12 : use redis::{Cmd, FromRedisValue, Value};
13 : use serde::{Deserialize, Serialize};
14 : use thiserror::Error;
15 : use tokio::net::TcpStream;
16 : use tokio::time::timeout;
17 : use tracing::{debug, error, info};
18 :
19 : use crate::auth::AuthError;
20 : use crate::auth::backend::ComputeUserInfo;
21 : use crate::batch::{BatchQueue, QueueProcessing};
22 : use crate::config::ComputeConfig;
23 : use crate::context::RequestContext;
24 : use crate::control_plane::ControlPlaneApi;
25 : use crate::error::ReportableError;
26 : use crate::ext::LockExt;
27 : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
28 : use crate::pqproto::CancelKeyData;
29 : use crate::rate_limiter::LeakyBucketRateLimiter;
30 : use crate::redis::keys::KeyPrefix;
31 : use crate::redis::kv_ops::RedisKVClient;
32 :
33 : type IpSubnetKey = IpNet;
34 :
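// Cancellation keys are written to Redis with a TTL of CANCEL_KEY_TTL and re-written
// every CANCEL_KEY_REFRESH by `Session::maintain_cancel_key`, leaving a 30-second
// margin so a key does not expire between refreshes.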
35 : const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
36 : const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
37 :
38 : // Message types sent over the mpsc channel to the Redis KV client task.
39 : pub enum CancelKeyOp {
40 : StoreCancelKey {
41 : key: CancelKeyData,
42 : value: Box<str>,
43 : expire: std::time::Duration,
44 : },
45 : GetCancelData {
46 : key: CancelKeyData,
47 : },
48 : }
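// Each op maps onto Redis hash commands built in `CancelKeyOp::register`:
// `StoreCancelKey` issues an HSET on the prefixed cancel key (field "data") followed
// by an EXPIRE, and `GetCancelData` issues the matching HGET. The stored value is the
// JSON-serialized `CancelClosure` for the session.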
49 :
50 : pub struct Pipeline {
51 : inner: redis::Pipeline,
52 : replies: usize,
53 : }
54 :
55 : impl Pipeline {
56 0 : fn with_capacity(n: usize) -> Self {
57 0 : Self {
58 0 : inner: redis::Pipeline::with_capacity(n),
59 0 : replies: 0,
60 0 : }
61 0 : }
62 :
63 0 : async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
64 0 : let responses = self.replies;
65 0 : let batch_size = self.inner.len();
66 :
67 0 : if !client.credentials_refreshed() {
68 0 : tracing::debug!(
69 0 : "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
70 : );
71 0 : tokio::time::sleep(Duration::from_secs(5)).await;
72 0 : }
73 :
74 0 : match client.query(&self.inner).await {
75 : // for each reply, we expect that many values.
76 0 : Ok(Value::Array(values)) if values.len() == responses => {
77 0 : debug!(
78 : batch_size,
79 0 : responses, "successfully completed cancellation jobs",
80 : );
81 0 : values.into_iter().map(Ok).collect()
82 : }
83 0 : Ok(value) => {
84 0 : error!(batch_size, ?value, "unexpected redis return value");
85 0 : std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
86 0 : .take(responses)
87 0 : .collect()
88 : }
89 0 : Err(err) => {
90 0 : std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
91 0 : .take(responses)
92 0 : .collect()
93 : }
94 : }
95 0 : }
96 :
97 0 : fn add_command_with_reply(&mut self, cmd: Cmd) {
98 0 : self.inner.add_command(cmd);
99 0 : self.replies += 1;
100 0 : }
101 :
102 0 : fn add_command_no_reply(&mut self, cmd: Cmd) {
103 0 : self.inner.add_command(cmd).ignore();
104 0 : }
105 : }
106 :
107 : impl CancelKeyOp {
108 0 : fn register(&self, pipe: &mut Pipeline) {
109 0 : match self {
110 0 : CancelKeyOp::StoreCancelKey { key, value, expire } => {
111 0 : let key = KeyPrefix::Cancel(*key).build_redis_key();
112 0 : pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
113 0 : pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
114 0 : }
115 0 : CancelKeyOp::GetCancelData { key } => {
116 0 : let key = KeyPrefix::Cancel(*key).build_redis_key();
117 0 : pipe.add_command_with_reply(Cmd::hget(key, "data"));
118 0 : }
119 : }
120 0 : }
121 : }
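// A minimal sketch of the reply-counting contract above, assuming `CancelKeyData`
// can be generated with `rand::random()` (as `CancellationHandler::get_key` does
// below): `StoreCancelKey` queues two commands (HSET + EXPIRE) but expects only
// one reply, because the EXPIRE response is ignored.
#[cfg(test)]
mod pipeline_reply_counting {
    use super::*;

    #[test]
    fn replies_track_only_non_ignored_commands() {
        let mut pipe = Pipeline::with_capacity(2);

        CancelKeyOp::StoreCancelKey {
            key: rand::random(),
            value: "{}".into(),
            expire: CANCEL_KEY_TTL,
        }
        .register(&mut pipe);

        // HSET and EXPIRE are both queued...
        assert_eq!(pipe.inner.len(), 2);
        // ...but only the HSET reply is awaited by `Pipeline::execute`.
        assert_eq!(pipe.replies, 1);

        CancelKeyOp::GetCancelData { key: rand::random() }.register(&mut pipe);

        // HGET adds one more command and one more expected reply.
        assert_eq!(pipe.inner.len(), 3);
        assert_eq!(pipe.replies, 2);
    }
}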
122 :
123 : pub struct CancellationProcessor {
124 : pub client: RedisKVClient,
125 : pub batch_size: usize,
126 : }
127 :
128 : impl QueueProcessing for CancellationProcessor {
129 : type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
130 : type Res = anyhow::Result<redis::Value>;
131 :
132 0 : fn batch_size(&self, _queue_size: usize) -> usize {
133 0 : self.batch_size
134 0 : }
135 :
136 0 : async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
137 0 : if !self.client.credentials_refreshed() {
138 : // this will cause a timeout for cancellation operations
139 0 : tracing::debug!(
140 0 : "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
141 : );
142 0 : tokio::time::sleep(Duration::from_secs(5)).await;
143 0 : }
144 :
145 0 : let mut pipeline = Pipeline::with_capacity(batch.len());
146 :
147 0 : let batch_size = batch.len();
148 0 : debug!(batch_size, "running cancellation jobs");
149 :
150 0 : for (_, op) in &batch {
151 0 : op.register(&mut pipeline);
152 0 : }
153 :
154 0 : pipeline.execute(&mut self.client).await
155 0 : }
156 : }
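// A hedged wiring sketch (the `BatchQueue` constructor name below is an assumption,
// not shown in this file): at startup the proxy wraps its `RedisKVClient` in a
// `CancellationProcessor` and registers it once via `CancellationHandler::init_tx`;
// from then on every `(guard, op)` pair queued by sessions is flushed by `apply`
// as a single Redis pipeline per batch.
//
//     let processor = CancellationProcessor { client, batch_size };
//     cancellation_handler.init_tx(BatchQueue::new(processor)); // hypothetical constructor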
157 :
158 : /// Enables serving `CancelRequest`s.
159 : ///
160 : /// If the `CancellationProcessor` queue has been initialized via [`CancellationHandler::init_tx`], the cancellation key is published to Redis so that any proxy instance can serve the cancel request.
161 : pub struct CancellationHandler {
162 : compute_config: &'static ComputeConfig,
163 : // rate limiter for cancellation requests
164 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
165 : tx: OnceLock<BatchQueue<CancellationProcessor>>, // used to send messages to the Redis KV client task
166 : }
167 :
168 : #[derive(Debug, Error)]
169 : pub(crate) enum CancelError {
170 : #[error("{0}")]
171 : IO(#[from] std::io::Error),
172 :
173 : #[error("{0}")]
174 : Postgres(#[from] postgres_client::Error),
175 :
176 : #[error("rate limit exceeded")]
177 : RateLimit,
178 :
179 : #[error("Authentication error")]
180 : AuthError(#[from] AuthError),
181 :
182 : #[error("key not found")]
183 : NotFound,
184 :
185 : #[error("proxy service error")]
186 : InternalError,
187 : }
188 :
189 : impl ReportableError for CancelError {
190 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
191 0 : match self {
192 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
193 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
194 0 : crate::error::ErrorKind::Postgres
195 : }
196 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
197 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
198 0 : CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
199 0 : CancelError::InternalError => crate::error::ErrorKind::Service,
200 : }
201 0 : }
202 : }
203 :
204 : impl CancellationHandler {
205 0 : pub fn new(compute_config: &'static ComputeConfig) -> Self {
206 0 : Self {
207 0 : compute_config,
208 0 : tx: OnceLock::new(),
209 0 : limiter: Arc::new(std::sync::Mutex::new(
210 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
211 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
212 0 : 64,
213 0 : ),
214 0 : )),
215 0 : }
216 0 : }
217 :
218 0 : pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
219 0 : self.tx
220 0 : .set(queue)
221 0 : .map_err(|_| {})
222 0 : .expect("cancellation queue should be registered once");
223 0 : }
224 :
225 0 : pub(crate) fn get_key(self: Arc<Self>) -> Session {
226 : // we intentionally generate a random "backend pid" and "secret key" here.
227 : // we use the corresponding u64 as an identifier for the
228 : // actual endpoint+pid+secret for postgres/pgbouncer.
229 : //
230 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
231 : // of overlap between our computes as most pids are small (~100).
232 :
233 0 : let key: CancelKeyData = rand::random();
234 :
235 0 : debug!("registered new query cancellation key {key}");
236 0 : Session {
237 0 : key,
238 0 : cancellation_handler: self,
239 0 : }
240 0 : }
241 :
242 : /// This is not cancel safe
243 0 : async fn get_cancel_key(
244 0 : &self,
245 0 : key: CancelKeyData,
246 0 : ) -> Result<Option<CancelClosure>, CancelError> {
247 0 : let guard = Metrics::get()
248 0 : .proxy
249 0 : .cancel_channel_size
250 0 : .guard(RedisMsgKind::HGet);
251 0 : let op = CancelKeyOp::GetCancelData { key };
252 :
253 0 : let Some(tx) = self.tx.get() else {
254 0 : tracing::warn!("cancellation handler is not available");
255 0 : return Err(CancelError::InternalError);
256 : };
257 :
258 : const TIMEOUT: Duration = Duration::from_secs(5);
259 0 : let result = timeout(
260 0 : TIMEOUT,
261 0 : tx.call((guard, op), std::future::pending::<Infallible>()),
262 0 : )
263 0 : .await
264 0 : .map_err(|_| {
265 0 : tracing::warn!("timed out waiting to receive GetCancelData response");
266 0 : CancelError::RateLimit
267 0 : })?
268 : // cannot be cancelled
269 0 : .unwrap_or_else(|x| match x {})
270 0 : .map_err(|e| {
271 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
272 0 : CancelError::InternalError
273 0 : })?;
274 :
275 0 : let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
276 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
277 0 : CancelError::InternalError
278 0 : })?;
279 :
280 0 : let cancel_closure: CancelClosure =
281 0 : serde_json::from_str(&cancel_state_str).map_err(|e| {
282 0 : tracing::warn!("failed to deserialize cancel state: {e}");
283 0 : CancelError::InternalError
284 0 : })?;
285 :
286 0 : Ok(Some(cancel_closure))
287 0 : }
288 :
289 : /// Try to cancel a running query for the corresponding connection.
290 : /// The key is looked up in Redis; if it is not found, [`CancelError::NotFound`] is returned.
291 : /// `check_ip_allowed` - if true, check whether the peer IP is allowed to cancel the query.
292 : /// The IP allowlist is fetched internally.
293 : ///
294 : /// Returns a `Result` primarily for tests.
295 : ///
296 : /// This is not cancel safe
297 0 : pub(crate) async fn cancel_session<T: ControlPlaneApi>(
298 0 : &self,
299 0 : key: CancelKeyData,
300 0 : ctx: RequestContext,
301 0 : check_ip_allowed: bool,
302 0 : check_vpc_allowed: bool,
303 0 : auth_backend: &T,
304 0 : ) -> Result<(), CancelError> {
305 0 : let subnet_key = match ctx.peer_addr() {
306 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use the default /24 mask here
307 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
308 : };
309 :
310 0 : let allowed = {
311 0 : let rate_limit_config = None;
312 0 : let limiter = self.limiter.lock_propagate_poison();
313 0 : limiter.check(subnet_key, rate_limit_config, 1)
314 : };
315 0 : if !allowed {
316 : // log only the subnet part of the IP address to know which subnet is rate limited
317 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
318 0 : Metrics::get()
319 0 : .proxy
320 0 : .cancellation_requests_total
321 0 : .inc(CancellationRequest {
322 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
323 0 : });
324 0 : return Err(CancelError::RateLimit);
325 0 : }
326 :
327 0 : let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
328 0 : tracing::warn!("failed to receive RedisOp response: {e}");
329 0 : CancelError::InternalError
330 0 : })?;
331 :
332 0 : let Some(cancel_closure) = cancel_state else {
333 0 : tracing::warn!("query cancellation key not found: {key}");
334 0 : Metrics::get()
335 0 : .proxy
336 0 : .cancellation_requests_total
337 0 : .inc(CancellationRequest {
338 0 : kind: crate::metrics::CancellationOutcome::NotFound,
339 0 : });
340 0 : return Err(CancelError::NotFound);
341 : };
342 :
343 0 : let info = &cancel_closure.user_info;
344 0 : let access_controls = auth_backend
345 0 : .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
346 0 : .await
347 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
348 :
349 0 : access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
350 :
351 0 : Metrics::get()
352 0 : .proxy
353 0 : .cancellation_requests_total
354 0 : .inc(CancellationRequest {
355 0 : kind: crate::metrics::CancellationOutcome::Found,
356 0 : });
357 0 : info!("cancelling query per user's request using key {key}");
358 0 : cancel_closure.try_cancel_query(self.compute_config).await
359 0 : }
360 : }
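// End-to-end, as implemented above: the client later sends a CancelRequest carrying
// the key issued by `get_key`; whichever proxy instance receives it rate-limits by
// peer subnet, fetches the serialized `CancelClosure` from Redis, re-checks the
// endpoint's IP/VPC access rules, and only then forwards the cancel token to the
// compute node.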
361 :
362 : /// This should've been a [`std::future::Future`], but
363 : /// it's impossible to name the type of an unboxed future
364 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
365 0 : #[derive(Debug, Clone, Serialize, Deserialize)]
366 : pub struct CancelClosure {
367 : socket_addr: SocketAddr,
368 : cancel_token: RawCancelToken,
369 : hostname: String, // for pg_sni router
370 : user_info: ComputeUserInfo,
371 : }
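// This is the payload serialized to JSON and stored in Redis under the cancel key:
// everything needed to reach the compute later (socket address and raw cancel token),
// the hostname used for TLS/SNI through pg_sni_router, and the user info required for
// the access-control check in `cancel_session`.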
372 :
373 : impl CancelClosure {
374 0 : pub(crate) fn new(
375 0 : socket_addr: SocketAddr,
376 0 : cancel_token: RawCancelToken,
377 0 : hostname: String,
378 0 : user_info: ComputeUserInfo,
379 0 : ) -> Self {
380 0 : Self {
381 0 : socket_addr,
382 0 : cancel_token,
383 0 : hostname,
384 0 : user_info,
385 0 : }
386 0 : }
387 : /// Cancels the query running on the user's compute node.
388 0 : pub(crate) async fn try_cancel_query(
389 0 : &self,
390 0 : compute_config: &ComputeConfig,
391 0 : ) -> Result<(), CancelError> {
392 0 : let socket = TcpStream::connect(self.socket_addr).await?;
393 :
394 0 : let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
395 0 : compute_config,
396 0 : &self.hostname,
397 : )
398 0 : .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
399 :
400 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
401 0 : debug!("query was cancelled");
402 0 : Ok(())
403 0 : }
404 : }
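// `try_cancel_query` is reached from two places: `cancel_session` (a cancel requested
// by the user) and the teardown path of `Session::maintain_cancel_key` (best-effort
// cleanup once the client connection is gone). Both open a fresh TCP connection to the
// compute (with TLS as configured) and send the raw cancel token.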
405 :
406 : /// Helper for registering query cancellation tokens.
407 : pub(crate) struct Session {
408 : /// The user-facing key identifying this session.
409 : key: CancelKeyData,
410 : cancellation_handler: Arc<CancellationHandler>,
411 : }
412 :
413 : impl Session {
414 0 : pub(crate) fn key(&self) -> &CancelKeyData {
415 0 : &self.key
416 0 : }
417 :
418 : /// Ensure the cancel key is continously refreshed,
419 : /// Ensure the cancel key is continuously refreshed,
420 : ///
421 : /// This is not cancel safe
422 0 : pub(crate) async fn maintain_cancel_key(
423 0 : &self,
424 0 : session_id: uuid::Uuid,
425 0 : cancel: tokio::sync::oneshot::Receiver<Infallible>,
426 0 : cancel_closure: &CancelClosure,
427 0 : compute_config: &ComputeConfig,
428 0 : ) {
429 0 : let Some(tx) = self.cancellation_handler.tx.get() else {
430 0 : tracing::warn!("cancellation handler is not available");
431 : // don't exit, as we only want to exit if cancelled externally.
432 0 : std::future::pending().await
433 : };
434 :
435 0 : let closure_json = serde_json::to_string(&cancel_closure)
436 0 : .expect("serialising to json string should not fail")
437 0 : .into_boxed_str();
438 :
439 0 : let mut cancel = pin!(cancel);
440 :
441 : loop {
442 0 : let guard = Metrics::get()
443 0 : .proxy
444 0 : .cancel_channel_size
445 0 : .guard(RedisMsgKind::HSet);
446 0 : let op = CancelKeyOp::StoreCancelKey {
447 0 : key: self.key,
448 0 : value: closure_json.clone(),
449 0 : expire: CANCEL_KEY_TTL,
450 0 : };
451 :
452 0 : tracing::debug!(
453 : src=%self.key,
454 : dest=?cancel_closure.cancel_token,
455 0 : "registering cancellation key"
456 : );
457 :
458 0 : match tx.call((guard, op), cancel.as_mut()).await {
459 : Ok(Ok(_)) => {
460 0 : tracing::debug!(
461 : src=%self.key,
462 : dest=?cancel_closure.cancel_token,
463 0 : "registered cancellation key"
464 : );
465 :
466 : // wait before continuing.
467 0 : tokio::time::sleep(CANCEL_KEY_REFRESH).await;
468 : }
469 : // retry immediately.
470 0 : Ok(Err(error)) => {
471 0 : tracing::warn!(?error, "error registering cancellation key");
472 : }
473 0 : Err(Err(_cancelled)) => break,
474 : }
475 : }
476 :
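        // The cancel channel was dropped, i.e. the client session is over: make a
        // best-effort attempt to cancel any query still running on the compute node.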
477 0 : if let Err(err) = cancel_closure
478 0 : .try_cancel_query(compute_config)
479 0 : .boxed()
480 0 : .await
481 : {
482 0 : tracing::warn!(
483 : ?session_id,
484 : ?err,
485 0 : "could not cancel the query in the database"
486 : );
487 0 : }
488 0 : }
489 : }
|