Line data Source code
1 : use std::convert::Infallible;
2 : use std::net::{IpAddr, SocketAddr};
3 : use std::sync::Arc;
4 :
5 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
6 : use postgres_client::CancelToken;
7 : use postgres_client::tls::MakeTlsConnect;
8 : use pq_proto::CancelKeyData;
9 : use serde::{Deserialize, Serialize};
10 : use thiserror::Error;
11 : use tokio::net::TcpStream;
12 : use tokio::sync::{mpsc, oneshot};
13 : use tracing::{debug, info};
14 :
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::auth::{AuthError, check_peer_addr_is_in_list};
17 : use crate::config::ComputeConfig;
18 : use crate::context::RequestContext;
19 : use crate::control_plane::ControlPlaneApi;
20 : use crate::error::ReportableError;
21 : use crate::ext::LockExt;
22 : use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
23 : use crate::protocol2::ConnectionInfoExtra;
24 : use crate::rate_limiter::LeakyBucketRateLimiter;
25 : use crate::redis::keys::KeyPrefix;
26 : use crate::redis::kv_ops::RedisKVClient;
27 : use crate::tls::postgres_rustls::MakeRustlsConnect;
28 :
29 : type IpSubnetKey = IpNet;
30 :
/// Expiry applied to every stored cancellation key: 2 weeks (1_209_600).
const CANCEL_KEY_TTL: i64 = 14 * 24 * 60 * 60;
/// Upper bound on how long a caller waits to enqueue an operation for the Redis task.
const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
33 :
34 : // Message types for sending through mpsc channel
35 : pub enum CancelKeyOp {
36 : StoreCancelKey {
37 : key: String,
38 : field: String,
39 : value: String,
40 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
41 : _guard: CancelChannelSizeGuard<'static>,
42 : expire: i64, // TTL for key
43 : },
44 : GetCancelData {
45 : key: String,
46 : resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
47 : _guard: CancelChannelSizeGuard<'static>,
48 : },
49 : RemoveCancelKey {
50 : key: String,
51 : field: String,
52 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
53 : _guard: CancelChannelSizeGuard<'static>,
54 : },
55 : }
56 :
57 : // Running as a separate task to accept messages through the rx channel
58 : // In case of problems with RTT: switch to recv_many() + redis pipeline
59 0 : pub async fn handle_cancel_messages(
60 0 : client: &mut RedisKVClient,
61 0 : mut rx: mpsc::Receiver<CancelKeyOp>,
62 0 : ) -> anyhow::Result<Infallible> {
63 : loop {
64 0 : if let Some(msg) = rx.recv().await {
65 0 : match msg {
66 : CancelKeyOp::StoreCancelKey {
67 0 : key,
68 0 : field,
69 0 : value,
70 0 : resp_tx,
71 0 : _guard,
72 0 : expire,
73 : } => {
74 0 : let res = client.hset(&key, field, value).await;
75 0 : if let Some(resp_tx) = resp_tx {
76 0 : if res.is_ok() {
77 0 : resp_tx
78 0 : .send(client.expire(key, expire).await)
79 0 : .inspect_err(|e| {
80 0 : tracing::debug!(
81 0 : "failed to send StoreCancelKey response: {:?}",
82 : e
83 : );
84 0 : })
85 0 : .ok();
86 0 : } else {
87 0 : resp_tx
88 0 : .send(res)
89 0 : .inspect_err(|e| {
90 0 : tracing::debug!(
91 0 : "failed to send StoreCancelKey response: {:?}",
92 : e
93 : );
94 0 : })
95 0 : .ok();
96 0 : }
97 0 : } else if res.is_ok() {
98 0 : drop(client.expire(key, expire).await);
99 : } else {
100 0 : tracing::warn!("failed to store cancel key: {:?}", res);
101 : }
102 : }
103 : CancelKeyOp::GetCancelData {
104 0 : key,
105 0 : resp_tx,
106 0 : _guard,
107 0 : } => {
108 0 : drop(resp_tx.send(client.hget_all(key).await));
109 : }
110 : CancelKeyOp::RemoveCancelKey {
111 0 : key,
112 0 : field,
113 0 : resp_tx,
114 0 : _guard,
115 : } => {
116 0 : if let Some(resp_tx) = resp_tx {
117 0 : resp_tx
118 0 : .send(client.hdel(key, field).await)
119 0 : .inspect_err(|e| {
120 0 : tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
121 0 : })
122 0 : .ok();
123 0 : } else {
124 0 : drop(client.hdel(key, field).await);
125 : }
126 : }
127 : }
128 0 : }
129 : }
130 : }
131 :
132 : /// Enables serving `CancelRequest`s.
133 : ///
134 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
135 : pub struct CancellationHandler {
136 : compute_config: &'static ComputeConfig,
137 : // rate limiter of cancellation requests
138 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
139 : tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
140 : }
141 :
142 : #[derive(Debug, Error)]
143 : pub(crate) enum CancelError {
144 : #[error("{0}")]
145 : IO(#[from] std::io::Error),
146 :
147 : #[error("{0}")]
148 : Postgres(#[from] postgres_client::Error),
149 :
150 : #[error("rate limit exceeded")]
151 : RateLimit,
152 :
153 : #[error("IP is not allowed")]
154 : IpNotAllowed,
155 :
156 : #[error("VPC endpoint id is not allowed to connect")]
157 : VpcEndpointIdNotAllowed,
158 :
159 : #[error("Authentication backend error")]
160 : AuthError(#[from] AuthError),
161 :
162 : #[error("key not found")]
163 : NotFound,
164 :
165 : #[error("proxy service error")]
166 : InternalError,
167 : }
168 :
169 : impl ReportableError for CancelError {
170 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
171 0 : match self {
172 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
173 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
174 0 : crate::error::ErrorKind::Postgres
175 : }
176 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
177 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
178 : CancelError::IpNotAllowed
179 : | CancelError::VpcEndpointIdNotAllowed
180 0 : | CancelError::NotFound => crate::error::ErrorKind::User,
181 0 : CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
182 0 : CancelError::InternalError => crate::error::ErrorKind::Service,
183 : }
184 0 : }
185 : }
186 :
187 : impl CancellationHandler {
188 0 : pub fn new(
189 0 : compute_config: &'static ComputeConfig,
190 0 : tx: Option<mpsc::Sender<CancelKeyOp>>,
191 0 : ) -> Self {
192 0 : Self {
193 0 : compute_config,
194 0 : tx,
195 0 : limiter: Arc::new(std::sync::Mutex::new(
196 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
197 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
198 0 : 64,
199 0 : ),
200 0 : )),
201 0 : }
202 0 : }
203 :
204 0 : pub(crate) fn get_key(self: &Arc<Self>) -> Session {
205 0 : // we intentionally generate a random "backend pid" and "secret key" here.
206 0 : // we use the corresponding u64 as an identifier for the
207 0 : // actual endpoint+pid+secret for postgres/pgbouncer.
208 0 : //
209 0 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
210 0 : // of overlap between our computes as most pids are small (~100).
211 0 :
212 0 : let key: CancelKeyData = rand::random();
213 0 :
214 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
215 0 : let redis_key = prefix_key.build_redis_key();
216 0 :
217 0 : debug!("registered new query cancellation key {key}");
218 0 : Session {
219 0 : key,
220 0 : redis_key,
221 0 : cancellation_handler: Arc::clone(self),
222 0 : }
223 0 : }
224 :
225 0 : async fn get_cancel_key(
226 0 : &self,
227 0 : key: CancelKeyData,
228 0 : ) -> Result<Option<CancelClosure>, CancelError> {
229 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
230 0 : let redis_key = prefix_key.build_redis_key();
231 0 :
232 0 : let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
233 0 : let op = CancelKeyOp::GetCancelData {
234 0 : key: redis_key,
235 0 : resp_tx,
236 0 : _guard: Metrics::get()
237 0 : .proxy
238 0 : .cancel_channel_size
239 0 : .guard(RedisMsgKind::HGetAll),
240 0 : };
241 :
242 0 : let Some(tx) = &self.tx else {
243 0 : tracing::warn!("cancellation handler is not available");
244 0 : return Err(CancelError::InternalError);
245 : };
246 :
247 0 : tx.send_timeout(op, REDIS_SEND_TIMEOUT)
248 0 : .await
249 0 : .map_err(|e| {
250 0 : tracing::warn!("failed to send GetCancelData for {key}: {e}");
251 0 : })
252 0 : .map_err(|()| CancelError::InternalError)?;
253 :
254 0 : let result = resp_rx.await.map_err(|e| {
255 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
256 0 : CancelError::InternalError
257 0 : })?;
258 :
259 0 : let cancel_state_str: Option<String> = match result {
260 0 : Ok(mut state) => {
261 0 : if state.len() == 1 {
262 0 : Some(state.remove(0).1)
263 : } else {
264 0 : tracing::warn!("unexpected number of entries in cancel state: {state:?}");
265 0 : return Err(CancelError::InternalError);
266 : }
267 : }
268 0 : Err(e) => {
269 0 : tracing::warn!("failed to receive cancel state from redis: {e}");
270 0 : return Err(CancelError::InternalError);
271 : }
272 : };
273 :
274 0 : let cancel_state: Option<CancelClosure> = match cancel_state_str {
275 0 : Some(state) => {
276 0 : let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
277 0 : tracing::warn!("failed to deserialize cancel state: {e}");
278 0 : CancelError::InternalError
279 0 : })?;
280 0 : Some(cancel_closure)
281 : }
282 0 : None => None,
283 : };
284 0 : Ok(cancel_state)
285 0 : }
286 : /// Try to cancel a running query for the corresponding connection.
287 : /// If the cancellation key is not found, it will be published to Redis.
288 : /// check_allowed - if true, check if the IP is allowed to cancel the query.
289 : /// Will fetch IP allowlist internally.
290 : ///
291 : /// return Result primarily for tests
292 0 : pub(crate) async fn cancel_session<T: ControlPlaneApi>(
293 0 : &self,
294 0 : key: CancelKeyData,
295 0 : ctx: RequestContext,
296 0 : check_ip_allowed: bool,
297 0 : check_vpc_allowed: bool,
298 0 : auth_backend: &T,
299 0 : ) -> Result<(), CancelError> {
300 0 : let subnet_key = match ctx.peer_addr() {
301 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
302 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
303 : };
304 0 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
305 : // log only the subnet part of the IP address to know which subnet is rate limited
306 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
307 0 : Metrics::get()
308 0 : .proxy
309 0 : .cancellation_requests_total
310 0 : .inc(CancellationRequest {
311 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
312 0 : });
313 0 : return Err(CancelError::RateLimit);
314 0 : }
315 :
316 0 : let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
317 0 : tracing::warn!("failed to receive RedisOp response: {e}");
318 0 : CancelError::InternalError
319 0 : })?;
320 :
321 0 : let Some(cancel_closure) = cancel_state else {
322 0 : tracing::warn!("query cancellation key not found: {key}");
323 0 : Metrics::get()
324 0 : .proxy
325 0 : .cancellation_requests_total
326 0 : .inc(CancellationRequest {
327 0 : kind: crate::metrics::CancellationOutcome::NotFound,
328 0 : });
329 0 : return Err(CancelError::NotFound);
330 : };
331 :
332 0 : if check_ip_allowed {
333 0 : let ip_allowlist = auth_backend
334 0 : .get_allowed_ips(&ctx, &cancel_closure.user_info)
335 0 : .await
336 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
337 :
338 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
339 : // log it here since cancel_session could be spawned in a task
340 0 : tracing::warn!(
341 0 : "IP is not allowed to cancel the query: {key}, address: {}",
342 0 : ctx.peer_addr()
343 : );
344 0 : return Err(CancelError::IpNotAllowed);
345 0 : }
346 0 : }
347 :
348 : // check if a VPC endpoint ID is coming in and if yes, if it's allowed
349 0 : let access_blocks = auth_backend
350 0 : .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
351 0 : .await
352 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
353 :
354 0 : if check_vpc_allowed {
355 0 : if access_blocks.vpc_access_blocked {
356 0 : return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
357 0 : }
358 :
359 0 : let incoming_vpc_endpoint_id = match ctx.extra() {
360 0 : None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
361 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
362 0 : Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
363 : };
364 :
365 0 : let allowed_vpc_endpoint_ids = auth_backend
366 0 : .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
367 0 : .await
368 0 : .map_err(|e| CancelError::AuthError(e.into()))?;
369 : // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
370 0 : if !allowed_vpc_endpoint_ids.is_empty()
371 0 : && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
372 : {
373 0 : return Err(CancelError::VpcEndpointIdNotAllowed);
374 0 : }
375 0 : } else if access_blocks.public_access_blocked {
376 0 : return Err(CancelError::VpcEndpointIdNotAllowed);
377 0 : }
378 :
379 0 : Metrics::get()
380 0 : .proxy
381 0 : .cancellation_requests_total
382 0 : .inc(CancellationRequest {
383 0 : kind: crate::metrics::CancellationOutcome::Found,
384 0 : });
385 0 : info!("cancelling query per user's request using key {key}");
386 0 : cancel_closure.try_cancel_query(self.compute_config).await
387 0 : }
388 : }
389 :
390 : /// This should've been a [`std::future::Future`], but
391 : /// it's impossible to name a type of an unboxed future
392 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
393 0 : #[derive(Clone, Serialize, Deserialize)]
394 : pub struct CancelClosure {
395 : socket_addr: SocketAddr,
396 : cancel_token: CancelToken,
397 : hostname: String, // for pg_sni router
398 : user_info: ComputeUserInfo,
399 : }
400 :
401 : impl CancelClosure {
402 0 : pub(crate) fn new(
403 0 : socket_addr: SocketAddr,
404 0 : cancel_token: CancelToken,
405 0 : hostname: String,
406 0 : user_info: ComputeUserInfo,
407 0 : ) -> Self {
408 0 : Self {
409 0 : socket_addr,
410 0 : cancel_token,
411 0 : hostname,
412 0 : user_info,
413 0 : }
414 0 : }
415 : /// Cancels the query running on user's compute node.
416 0 : pub(crate) async fn try_cancel_query(
417 0 : self,
418 0 : compute_config: &ComputeConfig,
419 0 : ) -> Result<(), CancelError> {
420 0 : let socket = TcpStream::connect(self.socket_addr).await?;
421 :
422 0 : let mut mk_tls =
423 0 : crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
424 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
425 0 : &mut mk_tls,
426 0 : &self.hostname,
427 0 : )
428 0 : .map_err(|e| {
429 0 : CancelError::IO(std::io::Error::new(
430 0 : std::io::ErrorKind::Other,
431 0 : e.to_string(),
432 0 : ))
433 0 : })?;
434 :
435 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
436 0 : debug!("query was cancelled");
437 0 : Ok(())
438 0 : }
439 : }
440 :
441 : /// Helper for registering query cancellation tokens.
442 : pub(crate) struct Session {
443 : /// The user-facing key identifying this session.
444 : key: CancelKeyData,
445 : redis_key: String,
446 : cancellation_handler: Arc<CancellationHandler>,
447 : }
448 :
449 : impl Session {
450 0 : pub(crate) fn key(&self) -> &CancelKeyData {
451 0 : &self.key
452 0 : }
453 :
454 : // Send the store key op to the cancellation handler and set TTL for the key
455 0 : pub(crate) async fn write_cancel_key(
456 0 : &self,
457 0 : cancel_closure: CancelClosure,
458 0 : ) -> Result<(), CancelError> {
459 0 : let Some(tx) = &self.cancellation_handler.tx else {
460 0 : tracing::warn!("cancellation handler is not available");
461 0 : return Err(CancelError::InternalError);
462 : };
463 :
464 0 : let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
465 0 : tracing::warn!("failed to serialize cancel closure: {e}");
466 0 : CancelError::InternalError
467 0 : })?;
468 :
469 0 : let op = CancelKeyOp::StoreCancelKey {
470 0 : key: self.redis_key.clone(),
471 0 : field: "data".to_string(),
472 0 : value: closure_json,
473 0 : resp_tx: None,
474 0 : _guard: Metrics::get()
475 0 : .proxy
476 0 : .cancel_channel_size
477 0 : .guard(RedisMsgKind::HSet),
478 0 : expire: CANCEL_KEY_TTL,
479 0 : };
480 0 :
481 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
482 0 : let key = self.key;
483 0 : tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
484 0 : });
485 0 : Ok(())
486 0 : }
487 :
488 0 : pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
489 0 : let Some(tx) = &self.cancellation_handler.tx else {
490 0 : tracing::warn!("cancellation handler is not available");
491 0 : return Err(CancelError::InternalError);
492 : };
493 :
494 0 : let op = CancelKeyOp::RemoveCancelKey {
495 0 : key: self.redis_key.clone(),
496 0 : field: "data".to_string(),
497 0 : resp_tx: None,
498 0 : _guard: Metrics::get()
499 0 : .proxy
500 0 : .cancel_channel_size
501 0 : .guard(RedisMsgKind::HDel),
502 0 : };
503 0 :
504 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
505 0 : let key = self.key;
506 0 : tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
507 0 : });
508 0 : Ok(())
509 0 : }
510 : }
|