Line data Source code
1 : use std::net::{IpAddr, SocketAddr};
2 : use std::sync::Arc;
3 :
4 : use ipnet::{IpNet, Ipv4Net, Ipv6Net};
5 : use postgres_client::tls::MakeTlsConnect;
6 : use postgres_client::CancelToken;
7 : use pq_proto::CancelKeyData;
8 : use serde::{Deserialize, Serialize};
9 : use thiserror::Error;
10 : use tokio::net::TcpStream;
11 : use tokio::sync::mpsc;
12 : use tracing::{debug, info};
13 :
14 : use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
15 : use crate::auth::{check_peer_addr_is_in_list, AuthError};
16 : use crate::config::ComputeConfig;
17 : use crate::context::RequestContext;
18 : use crate::error::ReportableError;
19 : use crate::ext::LockExt;
20 : use crate::metrics::CancelChannelSizeGuard;
21 : use crate::metrics::{CancellationRequest, Metrics, RedisMsgKind};
22 : use crate::rate_limiter::LeakyBucketRateLimiter;
23 : use crate::redis::keys::KeyPrefix;
24 : use crate::redis::kv_ops::RedisKVClient;
25 : use crate::tls::postgres_rustls::MakeRustlsConnect;
26 : use std::convert::Infallible;
27 : use tokio::sync::oneshot;
28 :
29 : type IpSubnetKey = IpNet;
30 :
/// Expiry time requested for a cancellation key: two weeks, expressed in seconds.
const CANCEL_KEY_TTL: i64 = 14 * 24 * 60 * 60;
/// Maximum time a caller waits for room in the channel to the Redis task.
const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
33 :
34 : // Message types for sending through mpsc channel
35 : pub enum CancelKeyOp {
36 : StoreCancelKey {
37 : key: String,
38 : field: String,
39 : value: String,
40 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
41 : _guard: CancelChannelSizeGuard<'static>,
42 : expire: i64, // TTL for key
43 : },
44 : GetCancelData {
45 : key: String,
46 : resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
47 : _guard: CancelChannelSizeGuard<'static>,
48 : },
49 : RemoveCancelKey {
50 : key: String,
51 : field: String,
52 : resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
53 : _guard: CancelChannelSizeGuard<'static>,
54 : },
55 : }
56 :
57 : // Running as a separate task to accept messages through the rx channel
58 : // In case of problems with RTT: switch to recv_many() + redis pipeline
59 0 : pub async fn handle_cancel_messages(
60 0 : client: &mut RedisKVClient,
61 0 : mut rx: mpsc::Receiver<CancelKeyOp>,
62 0 : ) -> anyhow::Result<Infallible> {
63 : loop {
64 0 : if let Some(msg) = rx.recv().await {
65 0 : match msg {
66 : CancelKeyOp::StoreCancelKey {
67 0 : key,
68 0 : field,
69 0 : value,
70 0 : resp_tx,
71 0 : _guard,
72 : expire: _,
73 : } => {
74 0 : if let Some(resp_tx) = resp_tx {
75 0 : resp_tx
76 0 : .send(client.hset(key, field, value).await)
77 0 : .inspect_err(|e| {
78 0 : tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
79 0 : })
80 0 : .ok();
81 0 : } else {
82 0 : drop(client.hset(key, field, value).await);
83 : }
84 : }
85 : CancelKeyOp::GetCancelData {
86 0 : key,
87 0 : resp_tx,
88 0 : _guard,
89 0 : } => {
90 0 : drop(resp_tx.send(client.hget_all(key).await));
91 : }
92 : CancelKeyOp::RemoveCancelKey {
93 0 : key,
94 0 : field,
95 0 : resp_tx,
96 0 : _guard,
97 : } => {
98 0 : if let Some(resp_tx) = resp_tx {
99 0 : resp_tx
100 0 : .send(client.hdel(key, field).await)
101 0 : .inspect_err(|e| {
102 0 : tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
103 0 : })
104 0 : .ok();
105 0 : } else {
106 0 : drop(client.hdel(key, field).await);
107 : }
108 : }
109 : }
110 0 : }
111 : }
112 : }
113 :
114 : /// Enables serving `CancelRequest`s.
115 : ///
116 : /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
117 : pub struct CancellationHandler {
118 : compute_config: &'static ComputeConfig,
119 : // rate limiter of cancellation requests
120 : limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
121 : tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
122 : }
123 :
124 : #[derive(Debug, Error)]
125 : pub(crate) enum CancelError {
126 : #[error("{0}")]
127 : IO(#[from] std::io::Error),
128 :
129 : #[error("{0}")]
130 : Postgres(#[from] postgres_client::Error),
131 :
132 : #[error("rate limit exceeded")]
133 : RateLimit,
134 :
135 : #[error("IP is not allowed")]
136 : IpNotAllowed,
137 :
138 : #[error("Authentication backend error")]
139 : AuthError(#[from] AuthError),
140 :
141 : #[error("key not found")]
142 : NotFound,
143 :
144 : #[error("proxy service error")]
145 : InternalError,
146 : }
147 :
148 : impl ReportableError for CancelError {
149 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
150 0 : match self {
151 0 : CancelError::IO(_) => crate::error::ErrorKind::Compute,
152 0 : CancelError::Postgres(e) if e.as_db_error().is_some() => {
153 0 : crate::error::ErrorKind::Postgres
154 : }
155 0 : CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
156 0 : CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
157 0 : CancelError::IpNotAllowed => crate::error::ErrorKind::User,
158 0 : CancelError::NotFound => crate::error::ErrorKind::User,
159 0 : CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
160 0 : CancelError::InternalError => crate::error::ErrorKind::Service,
161 : }
162 0 : }
163 : }
164 :
165 : impl CancellationHandler {
166 0 : pub fn new(
167 0 : compute_config: &'static ComputeConfig,
168 0 : tx: Option<mpsc::Sender<CancelKeyOp>>,
169 0 : ) -> Self {
170 0 : Self {
171 0 : compute_config,
172 0 : tx,
173 0 : limiter: Arc::new(std::sync::Mutex::new(
174 0 : LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
175 0 : LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
176 0 : 64,
177 0 : ),
178 0 : )),
179 0 : }
180 0 : }
181 :
182 0 : pub(crate) fn get_key(self: &Arc<Self>) -> Session {
183 0 : // we intentionally generate a random "backend pid" and "secret key" here.
184 0 : // we use the corresponding u64 as an identifier for the
185 0 : // actual endpoint+pid+secret for postgres/pgbouncer.
186 0 : //
187 0 : // if we forwarded the backend_pid from postgres to the client, there would be a lot
188 0 : // of overlap between our computes as most pids are small (~100).
189 0 :
190 0 : let key: CancelKeyData = rand::random();
191 0 :
192 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
193 0 : let redis_key = prefix_key.build_redis_key();
194 0 :
195 0 : debug!("registered new query cancellation key {key}");
196 0 : Session {
197 0 : key,
198 0 : redis_key,
199 0 : cancellation_handler: Arc::clone(self),
200 0 : }
201 0 : }
202 :
203 0 : async fn get_cancel_key(
204 0 : &self,
205 0 : key: CancelKeyData,
206 0 : ) -> Result<Option<CancelClosure>, CancelError> {
207 0 : let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
208 0 : let redis_key = prefix_key.build_redis_key();
209 0 :
210 0 : let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
211 0 : let op = CancelKeyOp::GetCancelData {
212 0 : key: redis_key,
213 0 : resp_tx,
214 0 : _guard: Metrics::get()
215 0 : .proxy
216 0 : .cancel_channel_size
217 0 : .guard(RedisMsgKind::HGetAll),
218 0 : };
219 :
220 0 : let Some(tx) = &self.tx else {
221 0 : tracing::warn!("cancellation handler is not available");
222 0 : return Err(CancelError::InternalError);
223 : };
224 :
225 0 : tx.send_timeout(op, REDIS_SEND_TIMEOUT)
226 0 : .await
227 0 : .map_err(|e| {
228 0 : tracing::warn!("failed to send GetCancelData for {key}: {e}");
229 0 : })
230 0 : .map_err(|()| CancelError::InternalError)?;
231 :
232 0 : let result = resp_rx.await.map_err(|e| {
233 0 : tracing::warn!("failed to receive GetCancelData response: {e}");
234 0 : CancelError::InternalError
235 0 : })?;
236 :
237 0 : let cancel_state_str: Option<String> = match result {
238 0 : Ok(mut state) => {
239 0 : if state.len() == 1 {
240 0 : Some(state.remove(0).1)
241 : } else {
242 0 : tracing::warn!("unexpected number of entries in cancel state: {state:?}");
243 0 : return Err(CancelError::InternalError);
244 : }
245 : }
246 0 : Err(e) => {
247 0 : tracing::warn!("failed to receive cancel state from redis: {e}");
248 0 : return Err(CancelError::InternalError);
249 : }
250 : };
251 :
252 0 : let cancel_state: Option<CancelClosure> = match cancel_state_str {
253 0 : Some(state) => {
254 0 : let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
255 0 : tracing::warn!("failed to deserialize cancel state: {e}");
256 0 : CancelError::InternalError
257 0 : })?;
258 0 : Some(cancel_closure)
259 : }
260 0 : None => None,
261 : };
262 0 : Ok(cancel_state)
263 0 : }
264 : /// Try to cancel a running query for the corresponding connection.
265 : /// If the cancellation key is not found, it will be published to Redis.
266 : /// check_allowed - if true, check if the IP is allowed to cancel the query.
267 : /// Will fetch IP allowlist internally.
268 : ///
269 : /// return Result primarily for tests
270 0 : pub(crate) async fn cancel_session<T: BackendIpAllowlist>(
271 0 : &self,
272 0 : key: CancelKeyData,
273 0 : ctx: RequestContext,
274 0 : check_allowed: bool,
275 0 : auth_backend: &T,
276 0 : ) -> Result<(), CancelError> {
277 0 : let subnet_key = match ctx.peer_addr() {
278 0 : IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
279 0 : IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
280 : };
281 0 : if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
282 : // log only the subnet part of the IP address to know which subnet is rate limited
283 0 : tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
284 0 : Metrics::get()
285 0 : .proxy
286 0 : .cancellation_requests_total
287 0 : .inc(CancellationRequest {
288 0 : kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
289 0 : });
290 0 : return Err(CancelError::RateLimit);
291 0 : }
292 :
293 0 : let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
294 0 : tracing::warn!("failed to receive RedisOp response: {e}");
295 0 : CancelError::InternalError
296 0 : })?;
297 :
298 0 : let Some(cancel_closure) = cancel_state else {
299 0 : tracing::warn!("query cancellation key not found: {key}");
300 0 : Metrics::get()
301 0 : .proxy
302 0 : .cancellation_requests_total
303 0 : .inc(CancellationRequest {
304 0 : kind: crate::metrics::CancellationOutcome::NotFound,
305 0 : });
306 0 : return Err(CancelError::NotFound);
307 : };
308 :
309 0 : if check_allowed {
310 0 : let ip_allowlist = auth_backend
311 0 : .get_allowed_ips(&ctx, &cancel_closure.user_info)
312 0 : .await
313 0 : .map_err(CancelError::AuthError)?;
314 :
315 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
316 : // log it here since cancel_session could be spawned in a task
317 0 : tracing::warn!(
318 0 : "IP is not allowed to cancel the query: {key}, address: {}",
319 0 : ctx.peer_addr()
320 : );
321 0 : return Err(CancelError::IpNotAllowed);
322 0 : }
323 0 : }
324 :
325 0 : Metrics::get()
326 0 : .proxy
327 0 : .cancellation_requests_total
328 0 : .inc(CancellationRequest {
329 0 : kind: crate::metrics::CancellationOutcome::Found,
330 0 : });
331 0 : info!("cancelling query per user's request using key {key}");
332 0 : cancel_closure.try_cancel_query(self.compute_config).await
333 0 : }
334 : }
335 :
336 : /// This should've been a [`std::future::Future`], but
337 : /// it's impossible to name a type of an unboxed future
338 : /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
339 0 : #[derive(Clone, Serialize, Deserialize)]
340 : pub struct CancelClosure {
341 : socket_addr: SocketAddr,
342 : cancel_token: CancelToken,
343 : hostname: String, // for pg_sni router
344 : user_info: ComputeUserInfo,
345 : }
346 :
347 : impl CancelClosure {
348 0 : pub(crate) fn new(
349 0 : socket_addr: SocketAddr,
350 0 : cancel_token: CancelToken,
351 0 : hostname: String,
352 0 : user_info: ComputeUserInfo,
353 0 : ) -> Self {
354 0 : Self {
355 0 : socket_addr,
356 0 : cancel_token,
357 0 : hostname,
358 0 : user_info,
359 0 : }
360 0 : }
361 : /// Cancels the query running on user's compute node.
362 0 : pub(crate) async fn try_cancel_query(
363 0 : self,
364 0 : compute_config: &ComputeConfig,
365 0 : ) -> Result<(), CancelError> {
366 0 : let socket = TcpStream::connect(self.socket_addr).await?;
367 :
368 0 : let mut mk_tls =
369 0 : crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
370 0 : let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
371 0 : &mut mk_tls,
372 0 : &self.hostname,
373 0 : )
374 0 : .map_err(|e| {
375 0 : CancelError::IO(std::io::Error::new(
376 0 : std::io::ErrorKind::Other,
377 0 : e.to_string(),
378 0 : ))
379 0 : })?;
380 :
381 0 : self.cancel_token.cancel_query_raw(socket, tls).await?;
382 0 : debug!("query was cancelled");
383 0 : Ok(())
384 0 : }
385 : }
386 :
387 : /// Helper for registering query cancellation tokens.
388 : pub(crate) struct Session {
389 : /// The user-facing key identifying this session.
390 : key: CancelKeyData,
391 : redis_key: String,
392 : cancellation_handler: Arc<CancellationHandler>,
393 : }
394 :
395 : impl Session {
396 0 : pub(crate) fn key(&self) -> &CancelKeyData {
397 0 : &self.key
398 0 : }
399 :
400 : // Send the store key op to the cancellation handler
401 0 : pub(crate) async fn write_cancel_key(
402 0 : &self,
403 0 : cancel_closure: CancelClosure,
404 0 : ) -> Result<(), CancelError> {
405 0 : let Some(tx) = &self.cancellation_handler.tx else {
406 0 : tracing::warn!("cancellation handler is not available");
407 0 : return Err(CancelError::InternalError);
408 : };
409 :
410 0 : let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
411 0 : tracing::warn!("failed to serialize cancel closure: {e}");
412 0 : CancelError::InternalError
413 0 : })?;
414 :
415 0 : let op = CancelKeyOp::StoreCancelKey {
416 0 : key: self.redis_key.clone(),
417 0 : field: "data".to_string(),
418 0 : value: closure_json,
419 0 : resp_tx: None,
420 0 : _guard: Metrics::get()
421 0 : .proxy
422 0 : .cancel_channel_size
423 0 : .guard(RedisMsgKind::HSet),
424 0 : expire: CANCEL_KEY_TTL,
425 0 : };
426 0 :
427 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
428 0 : let key = self.key;
429 0 : tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
430 0 : });
431 0 : Ok(())
432 0 : }
433 :
434 0 : pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
435 0 : let Some(tx) = &self.cancellation_handler.tx else {
436 0 : tracing::warn!("cancellation handler is not available");
437 0 : return Err(CancelError::InternalError);
438 : };
439 :
440 0 : let op = CancelKeyOp::RemoveCancelKey {
441 0 : key: self.redis_key.clone(),
442 0 : field: "data".to_string(),
443 0 : resp_tx: None,
444 0 : _guard: Metrics::get()
445 0 : .proxy
446 0 : .cancel_channel_size
447 0 : .guard(RedisMsgKind::HSet),
448 0 : };
449 0 :
450 0 : let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
451 0 : let key = self.key;
452 0 : tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
453 0 : });
454 0 : Ok(())
455 0 : }
456 : }
|