Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use anyhow::{Context, anyhow};
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::CancelToken;
7 : use postgres_client::tls::MakeTlsConnect;
8 : use redis::{Cmd, FromRedisValue, Value};
9 : use serde::{Deserialize, Serialize};
10 : use thiserror::Error;
11 : use tokio::net::TcpStream;
12 : use tokio::sync::{mpsc, oneshot};
13 : use tracing::{debug, error, info, warn};
14 :
15 : use crate::auth::AuthError;
16 : use crate::auth::backend::ComputeUserInfo;
17 : use crate::config::ComputeConfig;
18 : use crate::context::RequestContext;
19 : use crate::control_plane::ControlPlaneApi;
20 : use crate::error::ReportableError;
21 : use crate::ext::LockExt;
22 : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
23 : use crate::pqproto::CancelKeyData;
24 : use crate::rate_limiter::LeakyBucketRateLimiter;
25 : use crate::redis::keys::KeyPrefix;
26 : use crate::redis::kv_ops::RedisKVClient;
27 :
/// Key used by the cancellation rate limiter: the requesting client's IP address
/// truncated to a subnet (see `CancellationHandler::cancel_session`), so every
/// address in the same subnet shares one leaky bucket.
type IpSubnetKey = IpNet;
29 :
/// Expiry, in seconds, applied to every stored cancellation key: two weeks.
const CANCEL_KEY_TTL: i64 = 14 * 24 * 60 * 60;
31 :
/// Operations sent through the mpsc channel to the Redis task, which turns each
/// one into pipelined Redis commands (see `CancelKeyOp::register`).
pub enum CancelKeyOp {
    /// HSET `field` = `value` on the hash at `key`, followed by EXPIRE `key` `expire`.
    StoreCancelKey {
        key: String,
        field: String,
        value: String,
        // `None` = fire-and-forget: nobody is told whether the HSET succeeded.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        // Metrics guard (`cancel_channel_size`); presumably tracks in-flight ops until dropped.
        _guard: CancelChannelSizeGuard<'static>,
        expire: i64, // TTL for key, in seconds (passed to EXPIRE)
    },
    /// HGETALL on the hash at `key`; the field/value pairs are always sent back.
    GetCancelData {
        key: String,
        resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
    /// HDEL `field` from the hash at `key`.
    RemoveCancelKey {
        key: String,
        field: String,
        // `None` = fire-and-forget.
        resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
}
54 :
/// A Redis pipeline plus the reply handles for the commands in it that expect an
/// answer. Commands queued without a reply are marked `.ignore()`, so the i-th
/// value Redis returns belongs to `replies[i]`.
pub struct Pipeline {
    inner: redis::Pipeline,
    replies: Vec<CancelReplyOp>,
}
59 :
60 : impl Pipeline {
61 0 : fn with_capacity(n: usize) -> Self {
62 0 : Self {
63 0 : inner: redis::Pipeline::with_capacity(n),
64 0 : replies: Vec::with_capacity(n),
65 0 : }
66 0 : }
67 :
68 0 : async fn execute(&mut self, client: &mut RedisKVClient) {
69 0 : let responses = self.replies.len();
70 0 : let batch_size = self.inner.len();
71 0 :
72 0 : match client.query(&self.inner).await {
73 : // for each reply, we expect that many values.
74 0 : Ok(Value::Array(values)) if values.len() == responses => {
75 0 : debug!(
76 : batch_size,
77 0 : responses, "successfully completed cancellation jobs",
78 : );
79 0 : for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
80 0 : reply.send_value(value);
81 0 : }
82 : }
83 0 : Ok(value) => {
84 0 : error!(batch_size, ?value, "unexpected redis return value");
85 0 : for reply in self.replies.drain(..) {
86 0 : reply.send_err(anyhow!("incorrect response type from redis"));
87 0 : }
88 : }
89 0 : Err(err) => {
90 0 : for reply in self.replies.drain(..) {
91 0 : reply.send_err(anyhow!("could not send cmd to redis: {err}"));
92 0 : }
93 : }
94 : }
95 :
96 0 : self.inner.clear();
97 0 : self.replies.clear();
98 0 : }
99 :
100 0 : fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
101 0 : self.inner.add_command(cmd);
102 0 : self.replies.push(reply);
103 0 : }
104 :
105 0 : fn add_command_no_reply(&mut self, cmd: Cmd) {
106 0 : self.inner.add_command(cmd).ignore();
107 0 : }
108 :
109 0 : fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
110 0 : match reply {
111 0 : Some(reply) => self.add_command_with_reply(cmd, reply),
112 0 : None => self.add_command_no_reply(cmd),
113 : }
114 0 : }
115 : }
116 :
117 : impl CancelKeyOp {
118 0 : fn register(self, pipe: &mut Pipeline) {
119 0 : #[allow(clippy::used_underscore_binding)]
120 0 : match self {
121 : CancelKeyOp::StoreCancelKey {
122 0 : key,
123 0 : field,
124 0 : value,
125 0 : resp_tx,
126 0 : _guard,
127 0 : expire,
128 0 : } => {
129 0 : let reply =
130 0 : resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
131 0 : pipe.add_command(Cmd::hset(&key, field, value), reply);
132 0 : pipe.add_command_no_reply(Cmd::expire(key, expire));
133 0 : }
134 : CancelKeyOp::GetCancelData {
135 0 : key,
136 0 : resp_tx,
137 0 : _guard,
138 0 : } => {
139 0 : let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
140 0 : pipe.add_command_with_reply(Cmd::hgetall(key), reply);
141 0 : }
142 : CancelKeyOp::RemoveCancelKey {
143 0 : key,
144 0 : field,
145 0 : resp_tx,
146 0 : _guard,
147 0 : } => {
148 0 : let reply =
149 0 : resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
150 0 : pipe.add_command(Cmd::hdel(key, field), reply);
151 0 : }
152 : }
153 0 : }
154 : }
155 :
/// Reply handles queued in a `Pipeline`: one per Redis command whose result a
/// caller is waiting for. Each variant carries the oneshot sender that receives
/// the parsed result, plus the metrics guard taken over from the originating
/// `CancelKeyOp`.
pub enum CancelReplyOp {
    /// Result of the HSET issued for `CancelKeyOp::StoreCancelKey`.
    StoreCancelKey {
        resp_tx: oneshot::Sender<anyhow::Result<()>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
    /// Field/value pairs returned by HGETALL for `CancelKeyOp::GetCancelData`.
    GetCancelData {
        resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
    /// Result of the HDEL issued for `CancelKeyOp::RemoveCancelKey`.
    RemoveCancelKey {
        resp_tx: oneshot::Sender<anyhow::Result<()>>,
        _guard: CancelChannelSizeGuard<'static>,
    },
}
171 :
172 : impl CancelReplyOp {
173 0 : fn send_err(self, e: anyhow::Error) {
174 0 : match self {
175 0 : CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
176 0 : resp_tx
177 0 : .send(Err(e))
178 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
179 0 : .ok();
180 0 : }
181 0 : CancelReplyOp::GetCancelData { resp_tx, _guard } => {
182 0 : resp_tx
183 0 : .send(Err(e))
184 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
185 0 : .ok();
186 0 : }
187 0 : CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
188 0 : resp_tx
189 0 : .send(Err(e))
190 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
191 0 : .ok();
192 0 : }
193 : }
194 0 : }
195 :
196 0 : fn send_value(self, v: redis::Value) {
197 0 : match self {
198 0 : CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
199 0 : let send =
200 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
201 0 : resp_tx
202 0 : .send(send)
203 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
204 0 : .ok();
205 0 : }
206 0 : CancelReplyOp::GetCancelData { resp_tx, _guard } => {
207 0 : let send =
208 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
209 0 : resp_tx
210 0 : .send(send)
211 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
212 0 : .ok();
213 0 : }
214 0 : CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
215 0 : let send =
216 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
217 0 : resp_tx
218 0 : .send(send)
219 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
220 0 : .ok();
221 0 : }
222 : }
223 0 : }
224 : }
225 :
226 : // Running as a separate task to accept messages through the rx channel
227 0 : pub async fn handle_cancel_messages(
228 0 : client: &mut RedisKVClient,
229 0 : mut rx: mpsc::Receiver<CancelKeyOp>,
230 0 : batch_size: usize,
231 0 : ) -> anyhow::Result<()> {
232 0 : let mut batch = Vec::with_capacity(batch_size);
233 0 : let mut pipeline = Pipeline::with_capacity(batch_size);
234 :
235 : loop {
236 0 : if rx.recv_many(&mut batch, batch_size).await == 0 {
237 0 : warn!("shutting down cancellation queue");
238 0 : break Ok(());
239 0 : }
240 0 :
241 0 : let batch_size = batch.len();
242 0 : debug!(batch_size, "running cancellation jobs");
243 :
244 0 : for msg in batch.drain(..) {
245 0 : msg.register(&mut pipeline);
246 0 : }
247 :
248 0 : pipeline.execute(client).await;
249 : }
250 0 : }
251 :
/// Enables serving `CancelRequest`s.
///
/// Cancellation state is stored in and looked up from Redis by sending
/// `CancelKeyOp`s through `tx`. The receiving side is presumably the task running
/// `handle_cancel_messages` (wiring not visible here). Incoming cancel requests
/// are rate limited per client IP subnet.
pub struct CancellationHandler {
    compute_config: &'static ComputeConfig,
    // rate limiter of cancellation requests, keyed by client IP subnet
    limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
    // send messages to the redis KV client task; when `None`, every key
    // store/lookup/removal fails with `CancelError::InternalError`
    tx: Option<mpsc::Sender<CancelKeyOp>>,
}
261 :
/// Errors produced while registering cancellation keys or serving a cancel request.
#[derive(Debug, Error)]
pub(crate) enum CancelError {
    /// I/O failure towards the compute, e.g. connecting or building the TLS
    /// connector in `CancelClosure::try_cancel_query`.
    #[error("{0}")]
    IO(#[from] std::io::Error),

    /// Error from the Postgres client, e.g. while sending the cancel request.
    #[error("{0}")]
    Postgres(#[from] postgres_client::Error),

    /// The client's IP subnet exceeded the cancellation rate limit.
    #[error("rate limit exceeded")]
    RateLimit,

    /// Fetching or checking the endpoint's access controls failed.
    #[error("Authentication error")]
    AuthError(#[from] AuthError),

    /// No cancellation state exists for the given key.
    #[error("key not found")]
    NotFound,

    /// Proxy-side failure: no Redis task, a request/response could not be
    /// delivered, or stored state could not be (de)serialized.
    #[error("proxy service error")]
    InternalError,
}
282 :
283 : impl ReportableError for CancelError {
284 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
285 0 : match self {
286 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
287 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
288 0 : crate::error::ErrorKind::Postgres
289 : }
290 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
291 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
292 0 : CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
293 0 : CancelError::InternalError => crate::error::ErrorKind::Service,
294 : }
295 0 : }
296 : }
297 :
298 : impl CancellationHandler {
299 0 : pub fn new(
300 0 : compute_config: &'static ComputeConfig,
301 0 : tx: Option<mpsc::Sender<CancelKeyOp>>,
302 0 : ) -> Self {
303 0 : Self {
304 0 : compute_config,
305 0 : tx,
306 0 : limiter: Arc::new(std::sync::Mutex::new(
307 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
308 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
309 0 : 64,
310 0 : ),
311 0 : )),
312 0 : }
313 0 : }
314 :
315 0 : pub(crate) fn get_key(self: &Arc<Self>) -> Session {
316 0 : // we intentionally generate a random "backend pid" and "secret key" here.
317 0 : // we use the corresponding u64 as an identifier for the
318 0 : // actual endpoint+pid+secret for postgres/pgbouncer.
319 0 : //
320 0 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
321 0 : // of overlap between our computes as most pids are small (~100).
322 0 :
323 0 : let key: CancelKeyData = rand::random();
324 0 :
325 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
326 0 : let redis_key = prefix_key.build_redis_key();
327 0 :
328 0 : debug!("registered new query cancellation key {key}");
329 0 : Session {
330 0 : key,
331 0 : redis_key,
332 0 : cancellation_handler: Arc::clone(self),
333 0 : }
334 0 : }
335 :
336 0 : async fn get_cancel_key(
337 0 : &self,
338 0 : key: CancelKeyData,
339 0 : ) -> Result<Option<CancelClosure>, CancelError> {
340 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
341 0 : let redis_key = prefix_key.build_redis_key();
342 0 :
343 0 : let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
344 0 : let op = CancelKeyOp::GetCancelData {
345 0 : key: redis_key,
346 0 : resp_tx,
347 0 : _guard: Metrics::get()
348 0 : .proxy
349 0 : .cancel_channel_size
350 0 : .guard(RedisMsgKind::HGetAll),
351 0 : };
352 :
353 0 : let Some(tx) = &self.tx else {
354 0 : tracing::warn!("cancellation handler is not available");
355 0 : return Err(CancelError::InternalError);
356 : };
357 :
358 0 : tx.try_send(op)
359 0 : .map_err(|e| {
360 0 : tracing::warn!("failed to send GetCancelData for {key}: {e}");
361 0 : })
362 0 : .map_err(|()| CancelError::InternalError)?;
363 :
364 0 : let result = resp_rx.await.map_err(|e| {
365 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
366 0 : CancelError::InternalError
367 0 : })?;
368 :
369 0 : let cancel_state_str: Option<String> = match result {
370 0 : Ok(mut state) => {
371 0 : if state.len() == 1 {
372 0 : Some(state.remove(0).1)
373 : } else {
374 0 : tracing::warn!("unexpected number of entries in cancel state: {state:?}");
375 0 : return Err(CancelError::InternalError);
376 : }
377 : }
378 0 : Err(e) => {
379 0 : tracing::warn!("failed to receive cancel state from redis: {e}");
380 0 : return Err(CancelError::InternalError);
381 : }
382 : };
383 :
384 0 : let cancel_state: Option<CancelClosure> = match cancel_state_str {
385 0 : Some(state) => {
386 0 : let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
387 0 : tracing::warn!("failed to deserialize cancel state: {e}");
388 0 : CancelError::InternalError
389 0 : })?;
390 0 : Some(cancel_closure)
391 : }
392 0 : None => None,
393 : };
394 0 : Ok(cancel_state)
395 0 : }
396 : /// Try to cancel a running query for the corresponding connection.
397 : /// If the cancellation key is not found, it will be published to Redis.
398 : /// check_allowed - if true, check if the IP is allowed to cancel the query.
399 : /// Will fetch IP allowlist internally.
400 : ///
401 : /// return Result primarily for tests
402 0 : pub(crate) async fn cancel_session<T: ControlPlaneApi>(
403 0 : &self,
404 0 : key: CancelKeyData,
405 0 : ctx: RequestContext,
406 0 : check_ip_allowed: bool,
407 0 : check_vpc_allowed: bool,
408 0 : auth_backend: &T,
409 0 : ) -> Result<(), CancelError> {
410 0 : let subnet_key = match ctx.peer_addr() {
411 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
412 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
413 : };
414 :
415 0 : let allowed = {
416 0 : let rate_limit_config = None;
417 0 : let limiter = self.limiter.lock_propagate_poison();
418 0 : limiter.check(subnet_key, rate_limit_config, 1)
419 0 : };
420 0 : if !allowed {
421 : // log only the subnet part of the IP address to know which subnet is rate limited
422 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
423 0 : Metrics::get()
424 0 : .proxy
425 0 : .cancellation_requests_total
426 0 : .inc(CancellationRequest {
427 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
428 0 : });
429 0 : return Err(CancelError::RateLimit);
430 0 : }
431 :
432 0 : let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
433 0 : tracing::warn!("failed to receive RedisOp response: {e}");
434 0 : CancelError::InternalError
435 0 : })?;
436 :
437 0 : let Some(cancel_closure) = cancel_state else {
438 0 : tracing::warn!("query cancellation key not found: {key}");
439 0 : Metrics::get()
440 0 : .proxy
441 0 : .cancellation_requests_total
442 0 : .inc(CancellationRequest {
443 0 : kind: crate::metrics::CancellationOutcome::NotFound,
444 0 : });
445 0 : return Err(CancelError::NotFound);
446 : };
447 :
448 0 : let info = &cancel_closure.user_info;
449 0 : let access_controls = auth_backend
450 0 : .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
451 0 : .await
452 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
453 :
454 0 : access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
455 :
456 0 : Metrics::get()
457 0 : .proxy
458 0 : .cancellation_requests_total
459 0 : .inc(CancellationRequest {
460 0 : kind: crate::metrics::CancellationOutcome::Found,
461 0 : });
462 0 : info!("cancelling query per user's request using key {key}");
463 0 : cancel_closure.try_cancel_query(self.compute_config).await
464 0 : }
465 : }
466 :
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
///
/// Holds everything needed to cancel a query on a compute node later. It is
/// serialized to JSON and stored in Redis when a session's key is written, then
/// read back when a cancel request arrives.
///
/// NOTE: the serde-derived field names define the stored JSON; renaming a field
/// makes previously stored entries undecodable.
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
    // address the cancel request connects to
    socket_addr: SocketAddr,
    cancel_token: CancelToken,
    hostname: String, // for pg_sni router; also the name the TLS connector is built for
    // endpoint/user used to fetch access controls before cancelling
    user_info: ComputeUserInfo,
}
477 :
478 : impl CancelClosure {
479 0 : pub(crate) fn new(
480 0 : socket_addr: SocketAddr,
481 0 : cancel_token: CancelToken,
482 0 : hostname: String,
483 0 : user_info: ComputeUserInfo,
484 0 : ) -> Self {
485 0 : Self {
486 0 : socket_addr,
487 0 : cancel_token,
488 0 : hostname,
489 0 : user_info,
490 0 : }
491 0 : }
492 : /// Cancels the query running on user's compute node.
493 0 : pub(crate) async fn try_cancel_query(
494 0 : self,
495 0 : compute_config: &ComputeConfig,
496 0 : ) -> Result<(), CancelError> {
497 0 : let socket = TcpStream::connect(self.socket_addr).await?;
498 :
499 0 : let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
500 0 : compute_config,
501 0 : &self.hostname,
502 0 : )
503 0 : .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
504 :
505 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
506 0 : debug!("query was cancelled");
507 0 : Ok(())
508 0 : }
509 : }
510 :
/// Helper for registering query cancellation tokens.
///
/// Created by `CancellationHandler::get_key`; stores and removes the session's
/// cancel closure in Redis through the handler's channel.
pub(crate) struct Session {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// Redis key built from `key` (via `KeyPrefix::Cancel`) that holds the session's cancel state.
    redis_key: String,
    cancellation_handler: Arc<CancellationHandler>,
}
518 :
519 : impl Session {
520 0 : pub(crate) fn key(&self) -> &CancelKeyData {
521 0 : &self.key
522 0 : }
523 :
524 : // Send the store key op to the cancellation handler and set TTL for the key
525 0 : pub(crate) fn write_cancel_key(
526 0 : &self,
527 0 : cancel_closure: CancelClosure,
528 0 : ) -> Result<(), CancelError> {
529 0 : let Some(tx) = &self.cancellation_handler.tx else {
530 0 : tracing::warn!("cancellation handler is not available");
531 0 : return Err(CancelError::InternalError);
532 : };
533 :
534 0 : let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
535 0 : tracing::warn!("failed to serialize cancel closure: {e}");
536 0 : CancelError::InternalError
537 0 : })?;
538 :
539 0 : let op = CancelKeyOp::StoreCancelKey {
540 0 : key: self.redis_key.clone(),
541 0 : field: "data".to_string(),
542 0 : value: closure_json,
543 0 : resp_tx: None,
544 0 : _guard: Metrics::get()
545 0 : .proxy
546 0 : .cancel_channel_size
547 0 : .guard(RedisMsgKind::HSet),
548 0 : expire: CANCEL_KEY_TTL,
549 0 : };
550 0 :
551 0 : let _ = tx.try_send(op).map_err(|e| {
552 0 : let key = self.key;
553 0 : tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
554 0 : });
555 0 : Ok(())
556 0 : }
557 :
558 0 : pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
559 0 : let Some(tx) = &self.cancellation_handler.tx else {
560 0 : tracing::warn!("cancellation handler is not available");
561 0 : return Err(CancelError::InternalError);
562 : };
563 :
564 0 : let op = CancelKeyOp::RemoveCancelKey {
565 0 : key: self.redis_key.clone(),
566 0 : field: "data".to_string(),
567 0 : resp_tx: None,
568 0 : _guard: Metrics::get()
569 0 : .proxy
570 0 : .cancel_channel_size
571 0 : .guard(RedisMsgKind::HDel),
572 0 : };
573 0 :
574 0 : let _ = tx.try_send(op).map_err(|e| {
575 0 : let key = self.key;
576 0 : tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
577 0 : });
578 0 : Ok(())
579 0 : }
580 : }
|