use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, error, info};

use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::batch::{BatchQueue, BatchQueueError, QueueProcessing};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::{RedisKVClient, RedisKVClientError};
use crate::util::run_until;

type IpSubnetKey = IpNet;

const CANCEL_KEY_TTL: Duration = Duration::from_secs(600);
const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570);
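// Refreshing every 570s against a 600s TTL leaves a 30-second slack window,
// so a single delayed refresh does not immediately lose the key.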

// Message types for sending through mpsc channel
pub enum CancelKeyOp {
    Store {
        key: CancelKeyData,
        value: Box<str>,
        expire: Duration,
    },
    Refresh {
        key: CancelKeyData,
        expire: Duration,
    },
    Get {
        key: CancelKeyData,
    },
    GetOld {
        key: CancelKeyData,
    },
}

#[derive(thiserror::Error, Debug, Clone)]
pub enum PipelineError {
    #[error("could not send cmd to redis: {0}")]
    RedisKVClient(Arc<RedisKVClientError>),
    #[error("incorrect number of responses from redis")]
    IncorrectNumberOfResponses,
}

pub struct Pipeline {
    inner: redis::Pipeline,
    replies: usize,
}

impl Pipeline {
    fn with_capacity(n: usize) -> Self {
        Self {
            inner: redis::Pipeline::with_capacity(n),
            replies: 0,
        }
    }

    async fn execute(self, client: &mut RedisKVClient) -> Result<Vec<Value>, PipelineError> {
        let responses = self.replies;
        let batch_size = self.inner.len();

        if !client.credentials_refreshed() {
            tracing::debug!(
                "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
            );
            tokio::time::sleep(Duration::from_secs(5)).await;
        }

        match client.query(&self.inner).await {
            // we expect one response value per queued command.
            Ok(Value::Array(values)) if values.len() == responses => {
                debug!(
                    batch_size,
                    responses, "successfully completed cancellation jobs",
                );
                Ok(values.into_iter().collect())
            }
            Ok(value) => {
                error!(batch_size, ?value, "unexpected redis return value");
                Err(PipelineError::IncorrectNumberOfResponses)
            }
            Err(err) => Err(PipelineError::RedisKVClient(Arc::new(err))),
        }
    }

    fn add_command(&mut self, cmd: Cmd) {
        self.inner.add_command(cmd);
        self.replies += 1;
    }
}
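
// A minimal usage sketch (illustrative, not part of the original file):
// queue one op and execute the pipeline, assuming an already-connected
// `RedisKVClient` and a `key: CancelKeyData` in scope.
//
//     let mut pipe = Pipeline::with_capacity(1);
//     CancelKeyOp::Get { key }.register(&mut pipe);
//     let values = pipe.execute(&mut client).await?; // one `Value` per command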

impl CancelKeyOp {
    fn register(&self, pipe: &mut Pipeline) {
        match self {
            CancelKeyOp::Store { key, value, expire } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::set_options(
                    &key,
                    &**value,
                    SetOptions::default().with_expiration(SetExpiry::EX(expire.as_secs())),
                ));
            }
            CancelKeyOp::Refresh { key, expire } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::expire(&key, expire.as_secs() as i64));
            }
            CancelKeyOp::GetOld { key } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::hget(key, "data"));
            }
            CancelKeyOp::Get { key } => {
                let key = KeyPrefix::Cancel(*key).build_redis_key();
                pipe.add_command(Cmd::get(key));
            }
        }
    }
}
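
// Redis commands issued by each op (per `register` above):
//   Store   -> SET <cancel-key> <json> EX <ttl>
//   Refresh -> EXPIRE <cancel-key> <ttl>
//   Get     -> GET <cancel-key>
//   GetOld  -> HGET <cancel-key> data   (legacy hash layout)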

pub struct CancellationProcessor {
    pub client: RedisKVClient,
    pub batch_size: usize,
}

impl QueueProcessing for CancellationProcessor {
    type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
    type Res = redis::Value;
    type Err = PipelineError;

    fn batch_size(&self, _queue_size: usize) -> usize {
        self.batch_size
    }

    async fn apply(&mut self, batch: Vec<Self::Req>) -> Result<Vec<Self::Res>, Self::Err> {
        if !self.client.credentials_refreshed() {
            // this will cause a timeout for cancellation operations
            tracing::debug!(
                "Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
            );
            tokio::time::sleep(Duration::from_secs(5)).await;
        }

        let mut pipeline = Pipeline::with_capacity(batch.len());

        let batch_size = batch.len();
        debug!(batch_size, "running cancellation jobs");

        for (_, op) in &batch {
            op.register(&mut pipeline);
        }

        pipeline.execute(&mut self.client).await
    }
}
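
// Each `apply` call turns a whole batch of queued ops into a single pipelined
// Redis round-trip; the pipeline returns one response per op, in submission
// order, which `Pipeline::execute` verifies by count.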

/// Enables serving `CancelRequest`s.
///
/// If `CancellationPublisher` is available, a cancel request will be used to publish
/// the cancellation key to other proxy instances.
pub struct CancellationHandler {
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
}

#[derive(Debug, Error)]
pub(crate) enum CancelError {
    #[error("{0}")]
    IO(#[from] std::io::Error),

    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    #[error("rate limit exceeded")]
    RateLimit,

    #[error("Authentication error")]
    AuthError(#[from] AuthError),

    #[error("key not found")]
    NotFound,

    #[error("proxy service error")]
    InternalError,
}

impl ReportableError for CancelError {
    fn get_error_kind(&self) -> crate::error::ErrorKind {
        match self {
            CancelError::IO(_) => crate::error::ErrorKind::Compute,
            CancelError::Postgres(e) if e.as_db_error().is_some() => {
                crate::error::ErrorKind::Postgres
            }
            CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
            CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
            CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
            CancelError::InternalError => crate::error::ErrorKind::Service,
        }
    }
}

impl CancellationHandler {
    pub fn new(compute_config: &'static ComputeConfig) -> Self {
        Self {
            compute_config,
            tx: OnceLock::new(),
            limiter: Arc::new(std::sync::Mutex::new(
                LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
                    LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
                    64,
                ),
            )),
        }
    }

    pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
        self.tx
            .set(queue)
            .map_err(|_| {})
            .expect("cancellation queue should be registered once");
    }
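
    // Hedged wiring sketch (the `BatchQueue` constructor name is an
    // assumption, not taken from this file): create the handler once at
    // startup, then attach the batch queue that owns the Redis client.
    //
    //     let handler = CancellationHandler::new(compute_config);
    //     handler.init_tx(BatchQueue::new(CancellationProcessor {
    //         client: redis_kv_client,
    //         batch_size: 8,
    //     }));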

    pub(crate) fn get_key(self: Arc<Self>) -> Session {
        // we intentionally generate a random "backend pid" and "secret key" here.
        // we use the corresponding u64 as an identifier for the
        // actual endpoint+pid+secret for postgres/pgbouncer.
        //
        // if we forwarded the backend_pid from postgres to the client, there would be a lot
        // of overlap between our computes as most pids are small (~100).

        let key: CancelKeyData = rand::random();

        debug!("registered new query cancellation key {key}");
        Session {
            key,
            cancellation_handler: self,
        }
    }

    /// This is not cancel safe
    async fn get_cancel_key(
        &self,
        key: CancelKeyData,
    ) -> Result<Option<CancelClosure>, CancelError> {
        const TIMEOUT: Duration = Duration::from_secs(5);

        let Some(tx) = self.tx.get() else {
            tracing::warn!("cancellation handler is not available");
            return Err(CancelError::InternalError);
        };

        let guard = Metrics::get()
            .proxy
            .cancel_channel_size
            .guard(RedisMsgKind::Get);
        let op = CancelKeyOp::Get { key };
        let result = timeout(
            TIMEOUT,
            tx.call((guard, op), std::future::pending::<Infallible>()),
        )
        .await
        .map_err(|_| {
            tracing::warn!("timed out waiting to receive GetCancelData response");
            CancelError::RateLimit
        })?;

        // We may still have cancel keys that were set with HSET <key> "data".
        // Check the error type and retry with HGET.
        // TODO: remove this code once HSET is no longer used.
        let result = if let Err(err) = result.as_ref()
            && let BatchQueueError::Result(err) = err
            && let PipelineError::RedisKVClient(err) = err
            && let RedisKVClientError::Redis(err) = &**err
            && let Some(errcode) = err.code()
            && errcode == "WRONGTYPE"
        {
            let guard = Metrics::get()
                .proxy
                .cancel_channel_size
                .guard(RedisMsgKind::HGet);
            let op = CancelKeyOp::GetOld { key };
            timeout(
                TIMEOUT,
                tx.call((guard, op), std::future::pending::<Infallible>()),
            )
            .await
            .map_err(|_| {
                tracing::warn!("timed out waiting to receive GetCancelData response");
                CancelError::RateLimit
            })?
        } else {
            result
        };

        let result = result.map_err(|e| {
            tracing::warn!("failed to receive GetCancelData response: {e}");
            CancelError::InternalError
        })?;

        let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
            tracing::warn!("failed to receive GetCancelData response: {e}");
            CancelError::InternalError
        })?;

        let cancel_closure: CancelClosure =
            serde_json::from_str(&cancel_state_str).map_err(|e| {
                tracing::warn!("failed to deserialize cancel state: {e}");
                CancelError::InternalError
            })?;

        Ok(Some(cancel_closure))
    }

    /// Try to cancel a running query for the corresponding connection.
    /// If the cancellation key is not found, `CancelError::NotFound` is returned.
    /// `check_ip_allowed` - if true, check whether the IP is allowed to cancel
    /// the query (the IP allowlist is fetched internally).
    ///
    /// Returns a `Result` primarily for tests.
    ///
    /// This is not cancel safe
    pub(crate) async fn cancel_session<T: ControlPlaneApi>(
        &self,
        key: CancelKeyData,
        ctx: RequestContext,
        check_ip_allowed: bool,
        check_vpc_allowed: bool,
        auth_backend: &T,
    ) -> Result<(), CancelError> {
        let subnet_key = match ctx.peer_addr() {
            IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
            IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
        };

        let allowed = {
            let rate_limit_config = None;
            let limiter = self.limiter.lock_propagate_poison();
            limiter.check(subnet_key, rate_limit_config, 1)
        };
        if !allowed {
            // log only the subnet part of the IP address to know which subnet is rate limited
            tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
            Metrics::get()
                .proxy
                .cancellation_requests_total
                .inc(CancellationRequest {
                    kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
                });
            return Err(CancelError::RateLimit);
        }

        let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
            tracing::warn!("failed to receive RedisOp response: {e}");
            CancelError::InternalError
        })?;

        let Some(cancel_closure) = cancel_state else {
            tracing::warn!("query cancellation key not found: {key}");
            Metrics::get()
                .proxy
                .cancellation_requests_total
                .inc(CancellationRequest {
                    kind: crate::metrics::CancellationOutcome::NotFound,
                });
            return Err(CancelError::NotFound);
        };

        let info = &cancel_closure.user_info;
        let access_controls = auth_backend
            .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
            .await
            .map_err(|e| CancelError::AuthError(e.into()))?;

        access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;

        Metrics::get()
            .proxy
            .cancellation_requests_total
            .inc(CancellationRequest {
                kind: crate::metrics::CancellationOutcome::Found,
            });
        info!("cancelling query per user's request using key {key}");
        cancel_closure.try_cancel_query(self.compute_config).await
    }
}

/// This should've been a [`std::future::Future`], but
/// it's impossible to name the type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
    socket_addr: SocketAddr,
    cancel_token: RawCancelToken,
    hostname: String, // for pg_sni router
    user_info: ComputeUserInfo,
}

impl CancelClosure {
    pub(crate) fn new(
        socket_addr: SocketAddr,
        cancel_token: RawCancelToken,
        hostname: String,
        user_info: ComputeUserInfo,
    ) -> Self {
        Self {
            socket_addr,
            cancel_token,
            hostname,
            user_info,
        }
    }
    /// Cancels the query running on user's compute node.
    pub(crate) async fn try_cancel_query(
        &self,
        compute_config: &ComputeConfig,
    ) -> Result<(), CancelError> {
        let socket = TcpStream::connect(self.socket_addr).await?;

        let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
            compute_config,
            &self.hostname,
        )
        .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;

        self.cancel_token.cancel_query_raw(socket, tls).await?;
        debug!("query was cancelled");
        Ok(())
    }
}
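
// `CancelClosure` travels through Redis as JSON: `maintain_cancel_key` below
// serializes it under the cancel key, and `get_cancel_key` deserializes it,
// possibly in a different proxy instance, so field changes need to stay
// compatible with closures written by other versions.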

/// Helper for registering query cancellation tokens.
pub(crate) struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    cancellation_handler: Arc<CancellationHandler>,
}

impl Session {
    pub(crate) fn key(&self) -> &CancelKeyData {
        &self.key
    }

    /// Ensure the cancel key is continuously refreshed,
    /// but stop when the channel is dropped.
    ///
    /// This is not cancel safe
    pub(crate) async fn maintain_cancel_key(
        &self,
        session_id: uuid::Uuid,
        cancel: tokio::sync::oneshot::Receiver<Infallible>,
        cancel_closure: &CancelClosure,
        compute_config: &ComputeConfig,
    ) {
        let Some(tx) = self.cancellation_handler.tx.get() else {
            tracing::warn!("cancellation handler is not available");
            // don't exit, as we only want to exit if cancelled externally.
            std::future::pending().await
        };

        let closure_json = serde_json::to_string(&cancel_closure)
            .expect("serialising to json string should not fail")
            .into_boxed_str();

        let mut cancel = pin!(cancel);

        enum State {
            Set,
            Refresh,
        }
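
        // State machine: SET stores the full closure JSON with a TTL; once
        // that succeeds, cheaper EXPIRE refreshes keep the key alive. Any
        // unexpected reply (e.g. the key already expired) falls back to SET.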
        let mut state = State::Set;

        loop {
            let guard_op = match state {
                State::Set => {
                    let guard = Metrics::get()
                        .proxy
                        .cancel_channel_size
                        .guard(RedisMsgKind::Set);
                    let op = CancelKeyOp::Store {
                        key: self.key,
                        value: closure_json.clone(),
                        expire: CANCEL_KEY_TTL,
                    };
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "registering cancellation key"
                    );
                    (guard, op)
                }

                State::Refresh => {
                    let guard = Metrics::get()
                        .proxy
                        .cancel_channel_size
                        .guard(RedisMsgKind::Expire);
                    let op = CancelKeyOp::Refresh {
                        key: self.key,
                        expire: CANCEL_KEY_TTL,
                    };
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "refreshing cancellation key"
                    );
                    (guard, op)
                }
            };

            match tx.call(guard_op, cancel.as_mut()).await {
                // SET returns OK
                Ok(Value::Okay) => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "registered cancellation key"
                    );
                    state = State::Refresh;
                }

                // EXPIRE returns 1
                Ok(Value::Int(1)) => {
                    tracing::debug!(
                        src=%self.key,
                        dest=?cancel_closure.cancel_token,
                        "refreshed cancellation key"
                    );
                }

                Ok(_) => {
                    // Any other response likely means the key expired.
                    tracing::warn!(src=%self.key, "refreshing cancellation key failed");
                    // Re-enter the SET loop to repush full data.
                    state = State::Set;
                }

                // retry after a short delay.
                Err(BatchQueueError::Result(error)) => {
                    tracing::warn!(?error, "error refreshing cancellation key");
                    // Small delay to prevent busy loop with high cpu and logging.
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    continue;
                }

                Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
            }

            // wait before continuing. break immediately if cancelled.
            if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut())
                .await
                .is_err()
            {
                break;
            }
        }

        if let Err(err) = cancel_closure
            .try_cancel_query(compute_config)
            .boxed()
            .await
        {
            tracing::warn!(
                ?session_id,
                ?err,
                "could not cancel the query in the database"
            );
        }
    }
}
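
// Hedged end-to-end sketch (caller-side wiring assumed, not from this file):
// the caller registers a key, keeps it fresh for the connection's lifetime,
// and drops the oneshot sender on disconnect, which ends the refresh loop
// and triggers the final `try_cancel_query` above.
//
//     let session = handler.clone().get_key();
//     let (stop_tx, stop_rx) = tokio::sync::oneshot::channel::<Infallible>();
//     session
//         .maintain_cancel_key(session_id, stop_rx, &closure, compute_config)
//         .await;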