Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::{ConsoleError, MetricsAuxInfo};
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : error::ReportableError,
16 : intern::ProjectIdInt,
17 : metrics::ApiLockMetrics,
18 : rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
19 : scram, EndpointCacheKey,
20 : };
21 : use dashmap::DashMap;
22 : use std::{hash::Hash, sync::Arc, time::Duration};
23 : use tokio::time::Instant;
24 : use tracing::info;
25 :
26 : pub(crate) mod errors {
27 : use crate::{
28 : console::messages::{self, ConsoleError, Reason},
29 : error::{io_error, ErrorKind, ReportableError, UserFacingError},
30 : proxy::retry::CouldRetry,
31 : };
32 : use thiserror::Error;
33 :
34 : use super::ApiLockError;
35 :
36 : /// A go-to error message which doesn't leak any detail.
37 : pub(crate) const REQUEST_FAILED: &str = "Console request failed";
38 :
39 : /// Common console API error.
40 0 : #[derive(Debug, Error)]
41 : pub(crate) enum ApiError {
42 : /// Error returned by the console itself.
43 : #[error("{REQUEST_FAILED} with {0}")]
44 : Console(ConsoleError),
45 :
46 : /// Various IO errors like broken pipe or malformed payload.
47 : #[error("{REQUEST_FAILED}: {0}")]
48 : Transport(#[from] std::io::Error),
49 : }
50 :
51 : impl ApiError {
52 : /// Returns HTTP status code if it's the reason for failure.
53 0 : pub(crate) fn get_reason(&self) -> messages::Reason {
54 0 : match self {
55 0 : ApiError::Console(e) => e.get_reason(),
56 0 : ApiError::Transport(_) => messages::Reason::Unknown,
57 : }
58 0 : }
59 : }
60 :
61 : impl UserFacingError for ApiError {
62 0 : fn to_string_client(&self) -> String {
63 0 : match self {
64 : // To minimize risks, only select errors are forwarded to users.
65 0 : ApiError::Console(c) => c.get_user_facing_message(),
66 0 : ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
67 : }
68 0 : }
69 : }
70 :
71 : impl ReportableError for ApiError {
72 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
73 0 : match self {
74 0 : ApiError::Console(e) => match e.get_reason() {
75 0 : Reason::RoleProtected => ErrorKind::User,
76 0 : Reason::ResourceNotFound => ErrorKind::User,
77 0 : Reason::ProjectNotFound => ErrorKind::User,
78 0 : Reason::EndpointNotFound => ErrorKind::User,
79 0 : Reason::BranchNotFound => ErrorKind::User,
80 0 : Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
81 0 : Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
82 0 : Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
83 0 : Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
84 0 : Reason::WrittenDataQuotaExceeded => ErrorKind::User,
85 0 : Reason::DataTransferQuotaExceeded => ErrorKind::User,
86 0 : Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
87 0 : Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
88 0 : Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
89 0 : Reason::RunningOperations => ErrorKind::ControlPlane,
90 0 : Reason::Unknown => match &e {
91 : ConsoleError {
92 : http_status_code:
93 : http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
94 : ..
95 0 : } => crate::error::ErrorKind::User,
96 : ConsoleError {
97 : http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
98 0 : error,
99 0 : ..
100 0 : } if error
101 0 : .contains("compute time quota of non-primary branches is exceeded") =>
102 0 : {
103 0 : crate::error::ErrorKind::User
104 : }
105 : ConsoleError {
106 : http_status_code: http::StatusCode::LOCKED,
107 0 : error,
108 0 : ..
109 0 : } if error.contains("quota exceeded")
110 0 : || error.contains("the limit for current plan reached") =>
111 0 : {
112 0 : crate::error::ErrorKind::User
113 : }
114 : ConsoleError {
115 : http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
116 : ..
117 0 : } => crate::error::ErrorKind::ServiceRateLimit,
118 0 : ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
119 : },
120 : },
121 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
122 : }
123 0 : }
124 : }
125 :
126 : impl CouldRetry for ApiError {
127 6 : fn could_retry(&self) -> bool {
128 6 : match self {
129 : // retry some transport errors
130 0 : Self::Transport(io) => io.could_retry(),
131 6 : Self::Console(e) => e.could_retry(),
132 : }
133 6 : }
134 : }
135 :
136 : impl From<reqwest::Error> for ApiError {
137 0 : fn from(e: reqwest::Error) -> Self {
138 0 : io_error(e).into()
139 0 : }
140 : }
141 :
142 : impl From<reqwest_middleware::Error> for ApiError {
143 0 : fn from(e: reqwest_middleware::Error) -> Self {
144 0 : io_error(e).into()
145 0 : }
146 : }
147 :
148 0 : #[derive(Debug, Error)]
149 : pub(crate) enum GetAuthInfoError {
150 : // We shouldn't include the actual secret here.
151 : #[error("Console responded with a malformed auth secret")]
152 : BadSecret,
153 :
154 : #[error(transparent)]
155 : ApiError(ApiError),
156 : }
157 :
158 : // This allows more useful interactions than `#[from]`.
159 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
160 0 : fn from(e: E) -> Self {
161 0 : Self::ApiError(e.into())
162 0 : }
163 : }
164 :
165 : impl UserFacingError for GetAuthInfoError {
166 0 : fn to_string_client(&self) -> String {
167 0 : match self {
168 : // We absolutely should not leak any secrets!
169 0 : Self::BadSecret => REQUEST_FAILED.to_owned(),
170 : // However, API might return a meaningful error.
171 0 : Self::ApiError(e) => e.to_string_client(),
172 : }
173 0 : }
174 : }
175 :
176 : impl ReportableError for GetAuthInfoError {
177 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
178 0 : match self {
179 0 : Self::BadSecret => crate::error::ErrorKind::ControlPlane,
180 0 : Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
181 : }
182 0 : }
183 : }
184 :
185 0 : #[derive(Debug, Error)]
186 : pub(crate) enum WakeComputeError {
187 : #[error("Console responded with a malformed compute address: {0}")]
188 : BadComputeAddress(Box<str>),
189 :
190 : #[error(transparent)]
191 : ApiError(ApiError),
192 :
193 : #[error("Too many connections attempts")]
194 : TooManyConnections,
195 :
196 : #[error("error acquiring resource permit: {0}")]
197 : TooManyConnectionAttempts(#[from] ApiLockError),
198 : }
199 :
200 : // This allows more useful interactions than `#[from]`.
201 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
202 0 : fn from(e: E) -> Self {
203 0 : Self::ApiError(e.into())
204 0 : }
205 : }
206 :
207 : impl UserFacingError for WakeComputeError {
208 0 : fn to_string_client(&self) -> String {
209 0 : match self {
210 : // We shouldn't show user the address even if it's broken.
211 : // Besides, user is unlikely to care about this detail.
212 0 : Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
213 : // However, API might return a meaningful error.
214 0 : Self::ApiError(e) => e.to_string_client(),
215 :
216 0 : Self::TooManyConnections => self.to_string(),
217 :
218 : Self::TooManyConnectionAttempts(_) => {
219 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
220 : }
221 : }
222 0 : }
223 : }
224 :
225 : impl ReportableError for WakeComputeError {
226 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
227 0 : match self {
228 0 : Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
229 0 : Self::ApiError(e) => e.get_error_kind(),
230 0 : Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
231 0 : Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
232 : }
233 0 : }
234 : }
235 :
236 : impl CouldRetry for WakeComputeError {
237 3 : fn could_retry(&self) -> bool {
238 3 : match self {
239 0 : Self::BadComputeAddress(_) => false,
240 3 : Self::ApiError(e) => e.could_retry(),
241 0 : Self::TooManyConnections => false,
242 0 : Self::TooManyConnectionAttempts(_) => false,
243 : }
244 3 : }
245 : }
246 : }
247 :
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum AuthSecret {
    #[cfg(any(test, feature = "testing"))]
    /// Md5 hash of user's password.
    ///
    /// Only compiled in for tests or with the `testing` feature enabled.
    Md5([u8; 16]),

    /// [SCRAM](crate::scram) authentication info.
    Scram(scram::ServerSecret),
}
257 : }
258 :
259 : #[derive(Default)]
260 : pub(crate) struct AuthInfo {
261 : pub(crate) secret: Option<AuthSecret>,
262 : /// List of IP addresses allowed for the autorization.
263 : pub(crate) allowed_ips: Vec<IpPattern>,
264 : /// Project ID. This is used for cache invalidation.
265 : pub(crate) project_id: Option<ProjectIdInt>,
266 : }
267 :
268 : /// Info for establishing a connection to a compute node.
269 : /// This is what we get after auth succeeded, but not before!
270 : #[derive(Clone)]
271 : pub(crate) struct NodeInfo {
272 : /// Compute node connection params.
273 : /// It's sad that we have to clone this, but this will improve
274 : /// once we migrate to a bespoke connection logic.
275 : pub(crate) config: compute::ConnCfg,
276 :
277 : /// Labels for proxy's metrics.
278 : pub(crate) aux: MetricsAuxInfo,
279 :
280 : /// Whether we should accept self-signed certificates (for testing)
281 : pub(crate) allow_self_signed_compute: bool,
282 : }
283 :
284 : impl NodeInfo {
285 0 : pub(crate) async fn connect(
286 0 : &self,
287 0 : ctx: &RequestMonitoring,
288 0 : timeout: Duration,
289 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
290 0 : self.config
291 0 : .connect(
292 0 : ctx,
293 0 : self.allow_self_signed_compute,
294 0 : self.aux.clone(),
295 0 : timeout,
296 0 : )
297 0 : .await
298 0 : }
299 4 : pub(crate) fn reuse_settings(&mut self, other: Self) {
300 4 : self.allow_self_signed_compute = other.allow_self_signed_compute;
301 4 : self.config.reuse_password(other.config);
302 4 : }
303 :
304 6 : pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
305 6 : match keys {
306 : #[cfg(any(test, feature = "testing"))]
307 6 : ComputeCredentialKeys::Password(password) => self.config.password(password),
308 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
309 0 : ComputeCredentialKeys::None => &mut self.config,
310 : };
311 6 : }
312 : }
313 :
314 : pub(crate) type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
315 : pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
316 : pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
317 : pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
318 :
319 : /// This will allocate per each call, but the http requests alone
320 : /// already require a few allocations, so it should be fine.
321 : pub(crate) trait Api {
322 : /// Get the client's auth secret for authentication.
323 : /// Returns option because user not found situation is special.
324 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
325 : async fn get_role_secret(
326 : &self,
327 : ctx: &RequestMonitoring,
328 : user_info: &ComputeUserInfo,
329 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
330 :
331 : async fn get_allowed_ips_and_secret(
332 : &self,
333 : ctx: &RequestMonitoring,
334 : user_info: &ComputeUserInfo,
335 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
336 :
337 : /// Wake up the compute node and return the corresponding connection info.
338 : async fn wake_compute(
339 : &self,
340 : ctx: &RequestMonitoring,
341 : user_info: &ComputeUserInfo,
342 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
343 : }
344 :
345 : #[non_exhaustive]
346 : pub enum ConsoleBackend {
347 : /// Current Cloud API (V2).
348 : Console(neon::Api),
349 : /// Local mock of Cloud API (V2).
350 : #[cfg(any(test, feature = "testing"))]
351 : Postgres(mock::Api),
352 : /// Internal testing
353 : #[cfg(test)]
354 : #[allow(private_interfaces)]
355 : Test(Box<dyn crate::auth::backend::TestBackend>),
356 : }
357 :
358 : impl Api for ConsoleBackend {
359 0 : async fn get_role_secret(
360 0 : &self,
361 0 : ctx: &RequestMonitoring,
362 0 : user_info: &ComputeUserInfo,
363 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
364 0 : match self {
365 0 : Self::Console(api) => api.get_role_secret(ctx, user_info).await,
366 : #[cfg(any(test, feature = "testing"))]
367 0 : Self::Postgres(api) => api.get_role_secret(ctx, user_info).await,
368 : #[cfg(test)]
369 : Self::Test(_) => {
370 0 : unreachable!("this function should never be called in the test backend")
371 : }
372 : }
373 0 : }
374 :
375 0 : async fn get_allowed_ips_and_secret(
376 0 : &self,
377 0 : ctx: &RequestMonitoring,
378 0 : user_info: &ComputeUserInfo,
379 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
380 0 : match self {
381 0 : Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
382 : #[cfg(any(test, feature = "testing"))]
383 0 : Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
384 : #[cfg(test)]
385 0 : Self::Test(api) => api.get_allowed_ips_and_secret(),
386 : }
387 0 : }
388 :
389 13 : async fn wake_compute(
390 13 : &self,
391 13 : ctx: &RequestMonitoring,
392 13 : user_info: &ComputeUserInfo,
393 13 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
394 13 : match self {
395 0 : Self::Console(api) => api.wake_compute(ctx, user_info).await,
396 : #[cfg(any(test, feature = "testing"))]
397 0 : Self::Postgres(api) => api.wake_compute(ctx, user_info).await,
398 : #[cfg(test)]
399 13 : Self::Test(api) => api.wake_compute(),
400 : }
401 13 : }
402 : }
403 :
404 : /// Various caches for [`console`](super).
405 : pub struct ApiCaches {
406 : /// Cache for the `wake_compute` API method.
407 : pub(crate) node_info: NodeInfoCache,
408 : /// Cache which stores project_id -> endpoint_ids mapping.
409 : pub project_info: Arc<ProjectInfoCacheImpl>,
410 : /// List of all valid endpoints.
411 : pub endpoints_cache: Arc<EndpointsCache>,
412 : }
413 :
414 : impl ApiCaches {
415 0 : pub fn new(
416 0 : wake_compute_cache_config: CacheOptions,
417 0 : project_info_cache_config: ProjectInfoCacheOptions,
418 0 : endpoint_cache_config: EndpointCacheConfig,
419 0 : ) -> Self {
420 0 : Self {
421 0 : node_info: NodeInfoCache::new(
422 0 : "node_info_cache",
423 0 : wake_compute_cache_config.size,
424 0 : wake_compute_cache_config.ttl,
425 0 : true,
426 0 : ),
427 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
428 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
429 0 : }
430 0 : }
431 : }
432 :
433 : /// Various caches for [`console`](super).
434 : pub struct ApiLocks<K> {
435 : name: &'static str,
436 : node_locks: DashMap<K, Arc<DynamicLimiter>>,
437 : config: RateLimiterConfig,
438 : timeout: Duration,
439 : epoch: std::time::Duration,
440 : metrics: &'static ApiLockMetrics,
441 : }
442 :
443 0 : #[derive(Debug, thiserror::Error)]
444 : pub(crate) enum ApiLockError {
445 : #[error("timeout acquiring resource permit")]
446 : TimeoutError(#[from] tokio::time::error::Elapsed),
447 : }
448 :
449 : impl ReportableError for ApiLockError {
450 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
451 0 : match self {
452 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
453 0 : }
454 0 : }
455 : }
456 :
457 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
458 0 : pub fn new(
459 0 : name: &'static str,
460 0 : config: RateLimiterConfig,
461 0 : shards: usize,
462 0 : timeout: Duration,
463 0 : epoch: std::time::Duration,
464 0 : metrics: &'static ApiLockMetrics,
465 0 : ) -> prometheus::Result<Self> {
466 0 : Ok(Self {
467 0 : name,
468 0 : node_locks: DashMap::with_shard_amount(shards),
469 0 : config,
470 0 : timeout,
471 0 : epoch,
472 0 : metrics,
473 0 : })
474 0 : }
475 :
476 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
477 0 : if self.config.initial_limit == 0 {
478 0 : return Ok(WakeComputePermit {
479 0 : permit: Token::disabled(),
480 0 : });
481 0 : }
482 0 : let now = Instant::now();
483 0 : let semaphore = {
484 : // get fast path
485 0 : if let Some(semaphore) = self.node_locks.get(key) {
486 0 : semaphore.clone()
487 : } else {
488 0 : self.node_locks
489 0 : .entry(key.clone())
490 0 : .or_insert_with(|| {
491 0 : self.metrics.semaphores_registered.inc();
492 0 : DynamicLimiter::new(self.config)
493 0 : })
494 0 : .clone()
495 : }
496 : };
497 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
498 :
499 0 : self.metrics
500 0 : .semaphore_acquire_seconds
501 0 : .observe(now.elapsed().as_secs_f64());
502 0 : info!("acquired permit {:?}", now.elapsed().as_secs_f64());
503 0 : Ok(WakeComputePermit { permit: permit? })
504 0 : }
505 :
506 0 : pub async fn garbage_collect_worker(&self) {
507 0 : if self.config.initial_limit == 0 {
508 0 : return;
509 0 : }
510 0 : let mut interval =
511 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
512 : loop {
513 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
514 0 : interval.tick().await;
515 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
516 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
517 : // therefore releasing it is safe from race conditions
518 0 : info!(
519 : name = self.name,
520 : shard = i,
521 0 : "performing epoch reclamation on api lock"
522 : );
523 0 : let mut lock = shard.write();
524 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
525 0 : let count = lock
526 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
527 0 : .count();
528 0 : drop(lock);
529 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
530 0 : timer.observe();
531 0 : }
532 : }
533 0 : }
534 : }
535 :
536 : pub(crate) struct WakeComputePermit {
537 : permit: Token,
538 : }
539 :
540 : impl WakeComputePermit {
541 0 : pub(crate) fn should_check_cache(&self) -> bool {
542 0 : !self.permit.is_disabled()
543 0 : }
544 0 : pub(crate) fn release(self, outcome: Outcome) {
545 0 : self.permit.release(outcome);
546 0 : }
547 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
548 0 : match res {
549 0 : Ok(_) => self.release(Outcome::Success),
550 0 : Err(_) => self.release(Outcome::Overload),
551 : }
552 0 : res
553 0 : }
554 : }
|