Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use anyhow::{Context, anyhow};
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::CancelToken;
7 : use postgres_client::tls::MakeTlsConnect;
8 : use pq_proto::CancelKeyData;
9 : use redis::{FromRedisValue, Pipeline, Value, pipe};
10 : use serde::{Deserialize, Serialize};
11 : use thiserror::Error;
12 : use tokio::net::TcpStream;
13 : use tokio::sync::{mpsc, oneshot};
14 : use tracing::{debug, info, warn};
15 :
16 : use crate::auth::backend::ComputeUserInfo;
17 : use crate::auth::{AuthError, check_peer_addr_is_in_list};
18 : use crate::config::ComputeConfig;
19 : use crate::context::RequestContext;
20 : use crate::control_plane::ControlPlaneApi;
21 : use crate::error::ReportableError;
22 : use crate::ext::LockExt;
23 : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
24 : use crate::protocol2::ConnectionInfoExtra;
25 : use crate::rate_limiter::LeakyBucketRateLimiter;
26 : use crate::redis::keys::KeyPrefix;
27 : use crate::redis::kv_ops::RedisKVClient;
28 : use crate::tls::postgres_rustls::MakeRustlsConnect;
29 :
30 : type IpSubnetKey = IpNet;
31 :
/// Expiry, in seconds, applied to a stored cancellation key: two weeks.
const CANCEL_KEY_TTL: i64 = 14 * 24 * 60 * 60;
/// How long to wait for room in the Redis operation queue before giving up on an op.
const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
/// Maximum number of queued operations sent to Redis in a single pipeline.
const BATCH_SIZE: usize = 8;
35 :
36 : // Message types for sending through mpsc channel
37 : pub enum CancelKeyOp {
38 : StoreCancelKey {
39 : key: String,
40 : field: String,
41 : value: String,
42 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
43 : _guard: CancelChannelSizeGuard<'static>,
44 : expire: i64, // TTL for key
45 : },
46 : GetCancelData {
47 : key: String,
48 : resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
49 : _guard: CancelChannelSizeGuard<'static>,
50 : },
51 : RemoveCancelKey {
52 : key: String,
53 : field: String,
54 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
55 : _guard: CancelChannelSizeGuard<'static>,
56 : },
57 : }
58 :
59 : impl CancelKeyOp {
60 0 : fn register(self, pipe: &mut Pipeline) -> Option<CancelReplyOp> {
61 0 : #[allow(clippy::used_underscore_binding)]
62 0 : match self {
63 : CancelKeyOp::StoreCancelKey {
64 0 : key,
65 0 : field,
66 0 : value,
67 0 : resp_tx,
68 0 : _guard,
69 0 : expire,
70 0 : } => {
71 0 : pipe.hset(&key, field, value);
72 0 : pipe.expire(key, expire);
73 0 : let resp_tx = resp_tx?;
74 0 : Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard })
75 : }
76 : CancelKeyOp::GetCancelData {
77 0 : key,
78 0 : resp_tx,
79 0 : _guard,
80 0 : } => {
81 0 : pipe.hgetall(key);
82 0 : Some(CancelReplyOp::GetCancelData { resp_tx, _guard })
83 : }
84 : CancelKeyOp::RemoveCancelKey {
85 0 : key,
86 0 : field,
87 0 : resp_tx,
88 0 : _guard,
89 0 : } => {
90 0 : pipe.hdel(key, field);
91 0 : let resp_tx = resp_tx?;
92 0 : Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard })
93 : }
94 : }
95 0 : }
96 : }
97 :
98 : // Message types for sending through mpsc channel
99 : pub enum CancelReplyOp {
100 : StoreCancelKey {
101 : resp_tx: oneshot::Sender<anyhow::Result<()>>,
102 : _guard: CancelChannelSizeGuard<'static>,
103 : },
104 : GetCancelData {
105 : resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
106 : _guard: CancelChannelSizeGuard<'static>,
107 : },
108 : RemoveCancelKey {
109 : resp_tx: oneshot::Sender<anyhow::Result<()>>,
110 : _guard: CancelChannelSizeGuard<'static>,
111 : },
112 : }
113 :
114 : impl CancelReplyOp {
115 0 : fn send_err(self, e: anyhow::Error) {
116 0 : match self {
117 0 : CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
118 0 : resp_tx
119 0 : .send(Err(e))
120 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
121 0 : .ok();
122 0 : }
123 0 : CancelReplyOp::GetCancelData { resp_tx, _guard } => {
124 0 : resp_tx
125 0 : .send(Err(e))
126 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
127 0 : .ok();
128 0 : }
129 0 : CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
130 0 : resp_tx
131 0 : .send(Err(e))
132 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
133 0 : .ok();
134 0 : }
135 : }
136 0 : }
137 :
138 0 : fn send_value(self, v: redis::Value) {
139 0 : match self {
140 0 : CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
141 0 : let send =
142 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
143 0 : resp_tx
144 0 : .send(send)
145 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
146 0 : .ok();
147 0 : }
148 0 : CancelReplyOp::GetCancelData { resp_tx, _guard } => {
149 0 : let send =
150 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
151 0 : resp_tx
152 0 : .send(send)
153 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
154 0 : .ok();
155 0 : }
156 0 : CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
157 0 : let send =
158 0 : FromRedisValue::from_owned_redis_value(v).context("could not parse value");
159 0 : resp_tx
160 0 : .send(send)
161 0 : .inspect_err(|_| tracing::debug!("could not send reply"))
162 0 : .ok();
163 0 : }
164 : }
165 0 : }
166 : }
167 :
168 : // Running as a separate task to accept messages through the rx channel
169 0 : pub async fn handle_cancel_messages(
170 0 : client: &mut RedisKVClient,
171 0 : mut rx: mpsc::Receiver<CancelKeyOp>,
172 0 : ) -> anyhow::Result<()> {
173 0 : let mut batch = Vec::new();
174 0 : let mut replies = vec![];
175 :
176 : loop {
177 0 : if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 {
178 0 : warn!("shutting down cancellation queue");
179 0 : break Ok(());
180 0 : }
181 0 :
182 0 : let batch_size = batch.len();
183 0 : debug!(batch_size, "running cancellation jobs");
184 :
185 0 : let mut pipe = pipe();
186 0 : for msg in batch.drain(..) {
187 0 : if let Some(reply) = msg.register(&mut pipe) {
188 0 : replies.push(reply);
189 0 : } else {
190 0 : pipe.ignore();
191 0 : }
192 : }
193 :
194 0 : let responses = replies.len();
195 0 :
196 0 : match client.query(pipe).await {
197 : // for each reply, we expect that many values.
198 0 : Ok(Value::Array(values)) if values.len() == responses => {
199 0 : debug!(
200 : batch_size,
201 0 : responses, "successfully completed cancellation jobs",
202 : );
203 0 : for (value, reply) in std::iter::zip(values, replies.drain(..)) {
204 0 : reply.send_value(value);
205 0 : }
206 : }
207 0 : Ok(value) => {
208 0 : debug!(?value, "unexpected redis return value");
209 0 : for reply in replies.drain(..) {
210 0 : reply.send_err(anyhow!("incorrect response type from redis"));
211 0 : }
212 : }
213 0 : Err(err) => {
214 0 : for reply in replies.drain(..) {
215 0 : reply.send_err(anyhow!("could not send cmd to redis: {err}"));
216 0 : }
217 : }
218 : }
219 :
220 0 : replies.clear();
221 : }
222 0 : }
223 :
224 : /// Enables serving `CancelRequest`s.
225 : ///
226 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
227 : pub struct CancellationHandler {
228 : compute_config: &'static ComputeConfig,
229 : // rate limiter of cancellation requests
230 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
231 : tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
232 : }
233 :
234 : #[derive(Debug, Error)]
235 : pub(crate) enum CancelError {
236 : #[error("{0}")]
237 : IO(#[from] std::io::Error),
238 :
239 : #[error("{0}")]
240 : Postgres(#[from] postgres_client::Error),
241 :
242 : #[error("rate limit exceeded")]
243 : RateLimit,
244 :
245 : #[error("IP is not allowed")]
246 : IpNotAllowed,
247 :
248 : #[error("VPC endpoint id is not allowed to connect")]
249 : VpcEndpointIdNotAllowed,
250 :
251 : #[error("Authentication backend error")]
252 : AuthError(#[from] AuthError),
253 :
254 : #[error("key not found")]
255 : NotFound,
256 :
257 : #[error("proxy service error")]
258 : InternalError,
259 : }
260 :
261 : impl ReportableError for CancelError {
262 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
263 0 : match self {
264 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
265 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
266 0 : crate::error::ErrorKind::Postgres
267 : }
268 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
269 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
270 : CancelError::IpNotAllowed
271 : | CancelError::VpcEndpointIdNotAllowed
272 0 : | CancelError::NotFound => crate::error::ErrorKind::User,
273 0 : CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
274 0 : CancelError::InternalError => crate::error::ErrorKind::Service,
275 : }
276 0 : }
277 : }
278 :
279 : impl CancellationHandler {
280 0 : pub fn new(
281 0 : compute_config: &'static ComputeConfig,
282 0 : tx: Option<mpsc::Sender<CancelKeyOp>>,
283 0 : ) -> Self {
284 0 : Self {
285 0 : compute_config,
286 0 : tx,
287 0 : limiter: Arc::new(std::sync::Mutex::new(
288 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
289 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
290 0 : 64,
291 0 : ),
292 0 : )),
293 0 : }
294 0 : }
295 :
296 0 : pub(crate) fn get_key(self: &Arc<Self>) -> Session {
297 0 : // we intentionally generate a random "backend pid" and "secret key" here.
298 0 : // we use the corresponding u64 as an identifier for the
299 0 : // actual endpoint+pid+secret for postgres/pgbouncer.
300 0 : //
301 0 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
302 0 : // of overlap between our computes as most pids are small (~100).
303 0 :
304 0 : let key: CancelKeyData = rand::random();
305 0 :
306 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
307 0 : let redis_key = prefix_key.build_redis_key();
308 0 :
309 0 : debug!("registered new query cancellation key {key}");
310 0 : Session {
311 0 : key,
312 0 : redis_key,
313 0 : cancellation_handler: Arc::clone(self),
314 0 : }
315 0 : }
316 :
317 0 : async fn get_cancel_key(
318 0 : &self,
319 0 : key: CancelKeyData,
320 0 : ) -> Result<Option<CancelClosure>, CancelError> {
321 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
322 0 : let redis_key = prefix_key.build_redis_key();
323 0 :
324 0 : let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
325 0 : let op = CancelKeyOp::GetCancelData {
326 0 : key: redis_key,
327 0 : resp_tx,
328 0 : _guard: Metrics::get()
329 0 : .proxy
330 0 : .cancel_channel_size
331 0 : .guard(RedisMsgKind::HGetAll),
332 0 : };
333 :
334 0 : let Some(tx) = &self.tx else {
335 0 : tracing::warn!("cancellation handler is not available");
336 0 : return Err(CancelError::InternalError);
337 : };
338 :
339 0 : tx.send_timeout(op, REDIS_SEND_TIMEOUT)
340 0 : .await
341 0 : .map_err(|e| {
342 0 : tracing::warn!("failed to send GetCancelData for {key}: {e}");
343 0 : })
344 0 : .map_err(|()| CancelError::InternalError)?;
345 :
346 0 : let result = resp_rx.await.map_err(|e| {
347 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
348 0 : CancelError::InternalError
349 0 : })?;
350 :
351 0 : let cancel_state_str: Option<String> = match result {
352 0 : Ok(mut state) => {
353 0 : if state.len() == 1 {
354 0 : Some(state.remove(0).1)
355 : } else {
356 0 : tracing::warn!("unexpected number of entries in cancel state: {state:?}");
357 0 : return Err(CancelError::InternalError);
358 : }
359 : }
360 0 : Err(e) => {
361 0 : tracing::warn!("failed to receive cancel state from redis: {e}");
362 0 : return Err(CancelError::InternalError);
363 : }
364 : };
365 :
366 0 : let cancel_state: Option<CancelClosure> = match cancel_state_str {
367 0 : Some(state) => {
368 0 : let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
369 0 : tracing::warn!("failed to deserialize cancel state: {e}");
370 0 : CancelError::InternalError
371 0 : })?;
372 0 : Some(cancel_closure)
373 : }
374 0 : None => None,
375 : };
376 0 : Ok(cancel_state)
377 0 : }
378 : /// Try to cancel a running query for the corresponding connection.
379 : /// If the cancellation key is not found, it will be published to Redis.
380 : /// check_allowed - if true, check if the IP is allowed to cancel the query.
381 : /// Will fetch IP allowlist internally.
382 : ///
383 : /// return Result primarily for tests
384 0 : pub(crate) async fn cancel_session<T: ControlPlaneApi>(
385 0 : &self,
386 0 : key: CancelKeyData,
387 0 : ctx: RequestContext,
388 0 : check_ip_allowed: bool,
389 0 : check_vpc_allowed: bool,
390 0 : auth_backend: &T,
391 0 : ) -> Result<(), CancelError> {
392 0 : let subnet_key = match ctx.peer_addr() {
393 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
394 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
395 : };
396 0 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
397 : // log only the subnet part of the IP address to know which subnet is rate limited
398 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
399 0 : Metrics::get()
400 0 : .proxy
401 0 : .cancellation_requests_total
402 0 : .inc(CancellationRequest {
403 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
404 0 : });
405 0 : return Err(CancelError::RateLimit);
406 0 : }
407 :
408 0 : let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
409 0 : tracing::warn!("failed to receive RedisOp response: {e}");
410 0 : CancelError::InternalError
411 0 : })?;
412 :
413 0 : let Some(cancel_closure) = cancel_state else {
414 0 : tracing::warn!("query cancellation key not found: {key}");
415 0 : Metrics::get()
416 0 : .proxy
417 0 : .cancellation_requests_total
418 0 : .inc(CancellationRequest {
419 0 : kind: crate::metrics::CancellationOutcome::NotFound,
420 0 : });
421 0 : return Err(CancelError::NotFound);
422 : };
423 :
424 0 : if check_ip_allowed {
425 0 : let ip_allowlist = auth_backend
426 0 : .get_allowed_ips(&ctx, &cancel_closure.user_info)
427 0 : .await
428 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
429 :
430 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
431 : // log it here since cancel_session could be spawned in a task
432 0 : tracing::warn!(
433 0 : "IP is not allowed to cancel the query: {key}, address: {}",
434 0 : ctx.peer_addr()
435 : );
436 0 : return Err(CancelError::IpNotAllowed);
437 0 : }
438 0 : }
439 :
440 : // check if a VPC endpoint ID is coming in and if yes, if it's allowed
441 0 : let access_blocks = auth_backend
442 0 : .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
443 0 : .await
444 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
445 :
446 0 : if check_vpc_allowed {
447 0 : if access_blocks.vpc_access_blocked {
448 0 : return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
449 0 : }
450 :
451 0 : let incoming_vpc_endpoint_id = match ctx.extra() {
452 0 : None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
453 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
454 0 : Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
455 : };
456 :
457 0 : let allowed_vpc_endpoint_ids = auth_backend
458 0 : .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
459 0 : .await
460 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
461 : // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
462 0 : if !allowed_vpc_endpoint_ids.is_empty()
463 0 : && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
464 : {
465 0 : return Err(CancelError::VpcEndpointIdNotAllowed);
466 0 : }
467 0 : } else if access_blocks.public_access_blocked {
468 0 : return Err(CancelError::VpcEndpointIdNotAllowed);
469 0 : }
470 :
471 0 : Metrics::get()
472 0 : .proxy
473 0 : .cancellation_requests_total
474 0 : .inc(CancellationRequest {
475 0 : kind: crate::metrics::CancellationOutcome::Found,
476 0 : });
477 0 : info!("cancelling query per user's request using key {key}");
478 0 : cancel_closure.try_cancel_query(self.compute_config).await
479 0 : }
480 : }
481 :
482 : /// This should've been a [`std::future::Future`], but
483 : /// it's impossible to name a type of an unboxed future
484 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
485 0 : #[derive(Clone, Serialize, Deserialize)]
486 : pub struct CancelClosure {
487 : socket_addr: SocketAddr,
488 : cancel_token: CancelToken,
489 : hostname: String, // for pg_sni router
490 : user_info: ComputeUserInfo,
491 : }
492 :
493 : impl CancelClosure {
494 0 : pub(crate) fn new(
495 0 : socket_addr: SocketAddr,
496 0 : cancel_token: CancelToken,
497 0 : hostname: String,
498 0 : user_info: ComputeUserInfo,
499 0 : ) -> Self {
500 0 : Self {
501 0 : socket_addr,
502 0 : cancel_token,
503 0 : hostname,
504 0 : user_info,
505 0 : }
506 0 : }
507 : /// Cancels the query running on user's compute node.
508 0 : pub(crate) async fn try_cancel_query(
509 0 : self,
510 0 : compute_config: &ComputeConfig,
511 0 : ) -> Result<(), CancelError> {
512 0 : let socket = TcpStream::connect(self.socket_addr).await?;
513 :
514 0 : let mut mk_tls =
515 0 : crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
516 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
517 0 : &mut mk_tls,
518 0 : &self.hostname,
519 0 : )
520 0 : .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
521 :
522 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
523 0 : debug!("query was cancelled");
524 0 : Ok(())
525 0 : }
526 : }
527 :
528 : /// Helper for registering query cancellation tokens.
529 : pub(crate) struct Session {
530 : /// The user-facing key identifying this session.
531 : key: CancelKeyData,
532 : redis_key: String,
533 : cancellation_handler: Arc<CancellationHandler>,
534 : }
535 :
536 : impl Session {
537 0 : pub(crate) fn key(&self) -> &CancelKeyData {
538 0 : &self.key
539 0 : }
540 :
541 : // Send the store key op to the cancellation handler and set TTL for the key
542 0 : pub(crate) async fn write_cancel_key(
543 0 : &self,
544 0 : cancel_closure: CancelClosure,
545 0 : ) -> Result<(), CancelError> {
546 0 : let Some(tx) = &self.cancellation_handler.tx else {
547 0 : tracing::warn!("cancellation handler is not available");
548 0 : return Err(CancelError::InternalError);
549 : };
550 :
551 0 : let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
552 0 : tracing::warn!("failed to serialize cancel closure: {e}");
553 0 : CancelError::InternalError
554 0 : })?;
555 :
556 0 : let op = CancelKeyOp::StoreCancelKey {
557 0 : key: self.redis_key.clone(),
558 0 : field: "data".to_string(),
559 0 : value: closure_json,
560 0 : resp_tx: None,
561 0 : _guard: Metrics::get()
562 0 : .proxy
563 0 : .cancel_channel_size
564 0 : .guard(RedisMsgKind::HSet),
565 0 : expire: CANCEL_KEY_TTL,
566 0 : };
567 0 :
568 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
569 0 : let key = self.key;
570 0 : tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
571 0 : });
572 0 : Ok(())
573 0 : }
574 :
575 0 : pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
576 0 : let Some(tx) = &self.cancellation_handler.tx else {
577 0 : tracing::warn!("cancellation handler is not available");
578 0 : return Err(CancelError::InternalError);
579 : };
580 :
581 0 : let op = CancelKeyOp::RemoveCancelKey {
582 0 : key: self.redis_key.clone(),
583 0 : field: "data".to_string(),
584 0 : resp_tx: None,
585 0 : _guard: Metrics::get()
586 0 : .proxy
587 0 : .cancel_channel_size
588 0 : .guard(RedisMsgKind::HDel),
589 0 : };
590 0 :
591 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
592 0 : let key = self.key;
593 0 : tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
594 0 : });
595 0 : Ok(())
596 0 : }
597 : }
|