use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, error, info};

use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::batch::{BatchQueue, BatchQueueError, QueueProcessing};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::{RedisKVClient, RedisKVClientError};
use crate::util::run_until;

type IpSubnetKey = IpNet;

/// The initial period and TTL are shorter so that keys of short-lived connections are cleared faster.
const CANCEL_KEY_INITIAL_PERIOD: Duration = Duration::from_secs(60);
const CANCEL_KEY_REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60);
/// `CANCEL_KEY_TTL_SLACK` is added to the periods to determine the actual TTL.
const CANCEL_KEY_TTL_SLACK: Duration = Duration::from_secs(30);
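// For illustration: with the values above, a key is first stored with a 90s TTL
// (60s initial period + 30s slack); once registered, its TTL is refreshed to
// 630s (600s refresh period + 30s slack) every 10 minutes.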

/// Operations on cancellation keys, sent to the Redis KV client task through the batch queue.
pub enum CancelKeyOp {
    Store {
        key: CancelKeyData,
        value: Box<str>,
        expire: Duration,
    },
    Refresh {
        key: CancelKeyData,
        expire: Duration,
    },
    Get {
        key: CancelKeyData,
    },
    GetOld {
        key: CancelKeyData,
    },
}
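
// How each op maps to a Redis command, with the key built by
// `KeyPrefix::Cancel(key).build_redis_key()` (see `register` below):
//   Store   -> SET <key> <json> EX <ttl>
//   Refresh -> EXPIRE <key> <ttl>
//   Get     -> GET <key>
//   GetOld  -> HGET <key> data   (legacy HSET-based layout)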

impl CancelKeyOp {
    const fn redis_msg_kind(&self) -> RedisMsgKind {
        match self {
            CancelKeyOp::Store { .. } => RedisMsgKind::Set,
            CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire,
            CancelKeyOp::Get { .. } => RedisMsgKind::Get,
            CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet,
        }
    }

    fn cancel_channel_metric_guard(&self) -> CancelChannelSizeGuard<'static> {
        Metrics::get()
            .proxy
            .cancel_channel_size
            .guard(self.redis_msg_kind())
    }
}

#[derive(thiserror::Error, Debug, Clone)]
pub enum PipelineError {
    #[error("could not send cmd to redis: {0}")]
    RedisKVClient(Arc<RedisKVClientError>),
    #[error("incorrect number of responses from redis")]
    IncorrectNumberOfResponses,
}

pub struct Pipeline {
    inner: redis::Pipeline,
    replies: usize,
}

impl Pipeline {
    fn with_capacity(n: usize) -> Self {
        Self {
            inner: redis::Pipeline::with_capacity(n),
            replies: 0,
        }
    }

    async fn execute(self, client: &mut RedisKVClient) -> Result<Vec<Value>, PipelineError> {
        let responses = self.replies;
        let batch_size = self.inner.len();

        if !client.credentials_refreshed() {
            tracing::debug!(
                "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
            );
            tokio::time::sleep(Duration::from_secs(5)).await;
        }

        match client.query(&self.inner).await {
            // we expect exactly one value per queued command.
            Ok(Value::Array(values)) if values.len() == responses => {
                debug!(
                    batch_size,
                    responses, "successfully completed cancellation jobs",
                );
                Ok(values.into_iter().collect())
            }
            Ok(value) => {
                error!(batch_size, ?value, "unexpected redis return value");
                Err(PipelineError::IncorrectNumberOfResponses)
            }
            Err(err) => Err(PipelineError::RedisKVClient(Arc::new(err))),
        }
    }

    fn add_command(&mut self, cmd: Cmd) {
        self.inner.add_command(cmd);
        self.replies += 1;
    }
}
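
// Usage sketch: callers queue ops into one `Pipeline` and execute them as a
// single Redis round-trip; this is exactly what `CancellationProcessor::apply`
// does below.
//
//     let mut pipe = Pipeline::with_capacity(ops.len());
//     for op in &ops {
//         op.register(&mut pipe);
//     }
//     let values = pipe.execute(&mut client).await?; // one `Value` per op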

impl CancelKeyOp {
    fn register(&self, pipe: &mut Pipeline) {
        match self {
            CancelKeyOp::Store { key, value, expire } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::set_options(
                    &key,
                    &**value,
                    SetOptions::default().with_expiration(SetExpiry::EX(expire.as_secs())),
                ));
            }
            CancelKeyOp::Refresh { key, expire } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::expire(&key, expire.as_secs() as i64));
            }
            CancelKeyOp::GetOld { key } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::hget(key, "data"));
            }
            CancelKeyOp::Get { key } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::get(key));
            }
        }
    }
}

pub struct CancellationProcessor {
    pub client: RedisKVClient,
    pub batch_size: usize,
}

impl QueueProcessing for CancellationProcessor {
    type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
    type Res = redis::Value;
    type Err = PipelineError;

    fn batch_size(&self, _queue_size: usize) -> usize {
        self.batch_size
    }

    async fn apply(&mut self, batch: Vec<Self::Req>) -> Result<Vec<Self::Res>, Self::Err> {
        if !self.client.credentials_refreshed() {
            // this will cause a timeout for cancellation operations
            tracing::debug!(
                "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
            );
            tokio::time::sleep(Duration::from_secs(5)).await;
        }

        let mut pipeline = Pipeline::with_capacity(batch.len());

        let batch_size = batch.len();
        debug!(batch_size, "running cancellation jobs");

        for (_, op) in &batch {
            op.register(&mut pipeline);
        }

        pipeline.execute(&mut self.client).await
    }
}
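
// Note: `Pipeline::execute` verifies that Redis returns exactly one `Value`
// per queued command, so the resulting `Vec<Self::Res>` lines up
// index-for-index with `batch`, letting the batch queue route each reply back
// to its caller.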

/// Enables serving `CancelRequest`s.
///
/// If a `CancellationPublisher` is available, the cancel request is used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
}
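
// Wiring sketch (the `BatchQueue` constructor shown here is hypothetical; the
// real one lives in `crate::batch`): the handler is created once with a static
// config, and the queue is attached later, once the Redis client task is up.
//
//     let handler = Arc::new(CancellationHandler::new(compute_config));
//     handler.init_tx(BatchQueue::new(CancellationProcessor { client, batch_size }));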

#[derive(Debug, Error)]
pub(crate) enum CancelError {
    #[error("{0}")]
    IO(#[from] std::io::Error),

    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    #[error("rate limit exceeded")]
    RateLimit,

    #[error("Authentication error")]
    AuthError(#[from] AuthError),

    #[error("key not found")]
    NotFound,

    #[error("proxy service error")]
    InternalError,
}

impl ReportableError for CancelError {
    fn get_error_kind(&self) -> crate::error::ErrorKind {
        match self {
            CancelError::IO(_) => crate::error::ErrorKind::Compute,
            CancelError::Postgres(e) if e.as_db_error().is_some() => {
                crate::error::ErrorKind::Postgres
            }
            CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
            CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
            CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
            CancelError::InternalError => crate::error::ErrorKind::Service,
        }
    }
}

impl CancellationHandler {
    pub fn new(compute_config: &'static ComputeConfig) -> Self {
        Self {
            compute_config,
            tx: OnceLock::new(),
            limiter: Arc::new(std::sync::Mutex::new(
                LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
                    LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
                    64,
                ),
            )),
        }
    }

    pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
        self.tx
            .set(queue)
            .map_err(|_| {})
            .expect("cancellation queue should be registered once");
    }

    pub(crate) fn get_key(self: Arc<Self>) -> Session {
        // we intentionally generate a random "backend pid" and "secret key" here.
        // we use the corresponding u64 as an identifier for the
        // actual endpoint+pid+secret for postgres/pgbouncer.
        //
        // if we forwarded the backend_pid from postgres to the client, there would be a lot
        // of overlap between our computes as most pids are small (~100).

        let key: CancelKeyData = rand::random();

        debug!("registered new query cancellation key {key}");
        Session {
            key,
            cancellation_handler: self,
        }
    }

    /// This is not cancel safe
    async fn get_cancel_key(
        &self,
        key: CancelKeyData,
    ) -> Result<Option<CancelClosure>, CancelError> {
        const TIMEOUT: Duration = Duration::from_secs(5);

        let Some(tx) = self.tx.get() else {
            tracing::warn!("cancellation handler is not available");
            return Err(CancelError::InternalError);
        };

        let guard = Metrics::get()
            .proxy
            .cancel_channel_size
            .guard(RedisMsgKind::Get);
        let op = CancelKeyOp::Get { key };
        let result = timeout(
            TIMEOUT,
            tx.call((guard, op), std::future::pending::<Infallible>()),
        )
        .await
        .map_err(|_| {
            tracing::warn!("timed out waiting to receive GetCancelData response");
            CancelError::RateLimit
        })?;

        // We may still have cancel keys set with HSET <key> "data".
        // Check error type and retry with HGET.
        // TODO: remove code after HSET is not used anymore.
        let result = if let Err(err) = result.as_ref()
            && let BatchQueueError::Result(err) = err
            && let PipelineError::RedisKVClient(err) = err
            && let RedisKVClientError::Redis(err) = &**err
            && let Some(errcode) = err.code()
            && errcode == "WRONGTYPE"
        {
            let guard = Metrics::get()
                .proxy
                .cancel_channel_size
                .guard(RedisMsgKind::HGet);
            let op = CancelKeyOp::GetOld { key };
            timeout(
                TIMEOUT,
                tx.call((guard, op), std::future::pending::<Infallible>()),
            )
            .await
            .map_err(|_| {
                tracing::warn!("timed out waiting to receive GetCancelData response");
                CancelError::RateLimit
            })?
        } else {
            result
        };

        let result = result.map_err(|e| {
            tracing::warn!("failed to receive GetCancelData response: {e}");
            CancelError::InternalError
        })?;

        let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
            tracing::warn!("failed to receive GetCancelData response: {e}");
            CancelError::InternalError
        })?;

        let cancel_closure: CancelClosure =
            serde_json::from_str(&cancel_state_str).map_err(|e| {
                tracing::warn!("failed to deserialize cancel state: {e}");
                CancelError::InternalError
            })?;

        Ok(Some(cancel_closure))
    }

    /// Try to cancel a running query for the corresponding connection.
    /// If the cancellation key is not found, `CancelError::NotFound` is returned.
    /// `check_ip_allowed`/`check_vpc_allowed` - if true, check that the client's
    /// IP address / VPC endpoint is allowed to cancel the query.
    /// The access controls are fetched internally.
    ///
    /// Returns a `Result` primarily for tests.
    ///
    /// This is not cancel safe
    pub(crate) async fn cancel_session<T: ControlPlaneApi>(
        &self,
        key: CancelKeyData,
        ctx: RequestContext,
        check_ip_allowed: bool,
        check_vpc_allowed: bool,
        auth_backend: &T,
    ) -> Result<(), CancelError> {
        let subnet_key = match ctx.peer_addr() {
            IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
            IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
        };

        let allowed = {
            let rate_limit_config = None;
            let limiter = self.limiter.lock_propagate_poison();
            limiter.check(subnet_key, rate_limit_config, 1)
        };
        if !allowed {
            // log only the subnet part of the IP address to know which subnet is rate limited
            tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
            Metrics::get()
                .proxy
                .cancellation_requests_total
                .inc(CancellationRequest {
                    kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
                });
            return Err(CancelError::RateLimit);
        }

        let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
            tracing::warn!("failed to receive RedisOp response: {e}");
            CancelError::InternalError
        })?;

        let Some(cancel_closure) = cancel_state else {
            tracing::warn!("query cancellation key not found: {key}");
            Metrics::get()
                .proxy
                .cancellation_requests_total
                .inc(CancellationRequest {
                    kind: crate::metrics::CancellationOutcome::NotFound,
                });
            return Err(CancelError::NotFound);
        };

        let info = &cancel_closure.user_info;
        let access_controls = auth_backend
            .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
            .await
            .map_err(|e| CancelError::AuthError(e.into()))?;

        access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;

        Metrics::get()
            .proxy
            .cancellation_requests_total
            .inc(CancellationRequest {
                kind: crate::metrics::CancellationOutcome::Found,
            });
        info!("cancelling query per user's request using key {key}");
        cancel_closure.try_cancel_query(self.compute_config).await
    }
}

/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
    socket_addr: SocketAddr,
    cancel_token: RawCancelToken,
    hostname: String, // for pg_sni router
    user_info: ComputeUserInfo,
}
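
// The closure is stored in Redis as a JSON string under the cancel key.
// Illustrative shape only (the exact field encodings come from the `Serialize`
// impls of the nested types):
//
//     {"socket_addr":"10.11.12.13:5432","cancel_token":{...},
//      "hostname":"compute.example.internal","user_info":{...}}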

impl CancelClosure {
    pub(crate) fn new(
        socket_addr: SocketAddr,
        cancel_token: RawCancelToken,
        hostname: String,
        user_info: ComputeUserInfo,
    ) -> Self {
        Self {
            socket_addr,
            cancel_token,
            hostname,
            user_info,
        }
    }

    /// Cancels the query running on user's compute node.
    pub(crate) async fn try_cancel_query(
        &self,
        compute_config: &ComputeConfig,
    ) -> Result<(), CancelError> {
        let socket = TcpStream::connect(self.socket_addr).await?;

        let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
            compute_config,
            &self.hostname,
        )
        .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;

        self.cancel_token.cancel_query_raw(socket, tls).await?;
        debug!("query was cancelled");
        Ok(())
    }
}

/// Helper for registering query cancellation tokens.
pub(crate) struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    cancellation_handler: Arc<CancellationHandler>,
}

impl Session {
    pub(crate) fn key(&self) -> &CancelKeyData {
        &self.key
    }

    /// Ensure the cancel key is continuously refreshed,
    /// but stop when the channel is dropped.
    ///
    /// This is not cancel safe
    pub(crate) async fn maintain_cancel_key(
        &self,
        session_id: uuid::Uuid,
        cancel: tokio::sync::oneshot::Receiver<Infallible>,
        cancel_closure: &CancelClosure,
        compute_config: &ComputeConfig,
    ) {
        let Some(tx) = self.cancellation_handler.tx.get() else {
            tracing::warn!("cancellation handler is not available");
            // don't exit, as we only want to exit if cancelled externally.
            std::future::pending().await
        };

        let closure_json = serde_json::to_string(&cancel_closure)
            .expect("serialising to json string should not fail")
            .into_boxed_str();

        let mut cancel = pin!(cancel);

        enum State {
            Init,
            Refresh,
        }
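
        // Lifecycle sketch: `Init` stores the full closure with the short
        // initial TTL; after a successful SET we switch to `Refresh`, which
        // only bumps the TTL every `CANCEL_KEY_REFRESH_PERIOD`. Any unexpected
        // reply means the key likely expired, so we fall back to `Init` and
        // re-store the data immediately.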

        let mut state = State::Init;
        loop {
            let (op, mut wait_interval) = match state {
                State::Init => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "registering cancellation key"
                    );
                    (
                        CancelKeyOp::Store {
                            key: self.key,
                            value: closure_json.clone(),
                            expire: CANCEL_KEY_INITIAL_PERIOD + CANCEL_KEY_TTL_SLACK,
                        },
                        CANCEL_KEY_INITIAL_PERIOD,
                    )
                }

                State::Refresh => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "refreshing cancellation key"
                    );
                    (
                        CancelKeyOp::Refresh {
                            key: self.key,
                            expire: CANCEL_KEY_REFRESH_PERIOD + CANCEL_KEY_TTL_SLACK,
                        },
                        CANCEL_KEY_REFRESH_PERIOD,
                    )
                }
            };

            match tx
                .call((op.cancel_channel_metric_guard(), op), cancel.as_mut())
                .await
            {
                // SET returns OK
                Ok(Value::Okay) => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "registered cancellation key"
                    );
                    state = State::Refresh;
                }

                // EXPIRE returns 1
                Ok(Value::Int(1)) => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "refreshed cancellation key"
                    );
                }

                Ok(_) => {
                    // Any other response likely means the key expired.
                    tracing::warn!(src=%self.key, "refreshing cancellation key failed");
                    // Re-enter the SET loop quickly to repush full data.
                    state = State::Init;
                    wait_interval = Duration::ZERO;
                }

                Err(BatchQueueError::Result(error)) => {
                    tracing::warn!(?error, "error refreshing cancellation key");
                    // Retry after a small delay to prevent a busy loop with high CPU and log spam.
                    wait_interval = Duration::from_millis(10);
                }

                Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
            }

            // wait before continuing. break immediately if cancelled.
            if run_until(tokio::time::sleep(wait_interval), cancel.as_mut())
                .await
                .is_err()
            {
                break;
            }
        }

        // The session is over; make a best-effort attempt to cancel any query
        // that may still be running on the compute.
        if let Err(err) = cancel_closure
            .try_cancel_query(compute_config)
            .boxed()
            .await
        {
            tracing::warn!(
                ?session_id,
                ?err,
                "could not cancel the query in the database"
            );
        }
    }
}