Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : error::ReportableError,
16 : intern::ProjectIdInt,
17 : metrics::ApiLockMetrics,
18 : rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
19 : scram, EndpointCacheKey,
20 : };
21 : use dashmap::DashMap;
22 : use std::{hash::Hash, sync::Arc, time::Duration};
23 : use tokio::time::Instant;
24 : use tracing::info;
25 :
26 : pub mod errors {
27 : use crate::{
28 : console::messages::{self, ConsoleError, Reason},
29 : error::{io_error, ReportableError, UserFacingError},
30 : proxy::retry::CouldRetry,
31 : };
32 : use thiserror::Error;
33 :
34 : use super::ApiLockError;
35 :
36 : /// A go-to error message which doesn't leak any detail.
37 : pub const REQUEST_FAILED: &str = "Console request failed";
38 :
39 : /// Common console API error.
40 0 : #[derive(Debug, Error)]
41 : pub enum ApiError {
42 : /// Error returned by the console itself.
43 : #[error("{REQUEST_FAILED} with {0}")]
44 : Console(ConsoleError),
45 :
46 : /// Various IO errors like broken pipe or malformed payload.
47 : #[error("{REQUEST_FAILED}: {0}")]
48 : Transport(#[from] std::io::Error),
49 : }
50 :
51 : impl ApiError {
52 : /// Returns HTTP status code if it's the reason for failure.
53 0 : pub fn get_reason(&self) -> messages::Reason {
54 0 : use ApiError::*;
55 0 : match self {
56 0 : Console(e) => e.get_reason(),
57 0 : _ => messages::Reason::Unknown,
58 : }
59 0 : }
60 : }
61 :
62 : impl UserFacingError for ApiError {
63 0 : fn to_string_client(&self) -> String {
64 0 : use ApiError::*;
65 0 : match self {
66 : // To minimize risks, only select errors are forwarded to users.
67 0 : Console(c) => c.get_user_facing_message(),
68 0 : _ => REQUEST_FAILED.to_owned(),
69 : }
70 0 : }
71 : }
72 :
73 : impl ReportableError for ApiError {
74 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
75 0 : match self {
76 0 : ApiError::Console(e) => {
77 0 : use crate::error::ErrorKind::*;
78 0 : match e.get_reason() {
79 0 : Reason::RoleProtected => User,
80 0 : Reason::ResourceNotFound => User,
81 0 : Reason::ProjectNotFound => User,
82 0 : Reason::EndpointNotFound => User,
83 0 : Reason::BranchNotFound => User,
84 0 : Reason::RateLimitExceeded => ServiceRateLimit,
85 0 : Reason::NonDefaultBranchComputeTimeExceeded => User,
86 0 : Reason::ActiveTimeQuotaExceeded => User,
87 0 : Reason::ComputeTimeQuotaExceeded => User,
88 0 : Reason::WrittenDataQuotaExceeded => User,
89 0 : Reason::DataTransferQuotaExceeded => User,
90 0 : Reason::LogicalSizeQuotaExceeded => User,
91 0 : Reason::ConcurrencyLimitReached => ControlPlane,
92 0 : Reason::LockAlreadyTaken => ControlPlane,
93 0 : Reason::RunningOperations => ControlPlane,
94 0 : Reason::Unknown => match &e {
95 : ConsoleError {
96 : http_status_code:
97 : http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
98 : ..
99 0 : } => crate::error::ErrorKind::User,
100 : ConsoleError {
101 : http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
102 0 : error,
103 0 : ..
104 0 : } if error.contains(
105 : "compute time quota of non-primary branches is exceeded",
106 0 : ) =>
107 0 : {
108 0 : crate::error::ErrorKind::User
109 : }
110 : ConsoleError {
111 : http_status_code: http::StatusCode::LOCKED,
112 0 : error,
113 0 : ..
114 0 : } if error.contains("quota exceeded")
115 0 : || error.contains("the limit for current plan reached") =>
116 0 : {
117 0 : crate::error::ErrorKind::User
118 : }
119 : ConsoleError {
120 : http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
121 : ..
122 0 : } => crate::error::ErrorKind::ServiceRateLimit,
123 0 : ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
124 : },
125 : }
126 : }
127 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
128 : }
129 0 : }
130 : }
131 :
132 : impl CouldRetry for ApiError {
133 12 : fn could_retry(&self) -> bool {
134 12 : match self {
135 : // retry some transport errors
136 0 : Self::Transport(io) => io.could_retry(),
137 12 : Self::Console(e) => e.could_retry(),
138 : }
139 12 : }
140 : }
141 :
142 : impl From<reqwest::Error> for ApiError {
143 0 : fn from(e: reqwest::Error) -> Self {
144 0 : io_error(e).into()
145 0 : }
146 : }
147 :
148 : impl From<reqwest_middleware::Error> for ApiError {
149 0 : fn from(e: reqwest_middleware::Error) -> Self {
150 0 : io_error(e).into()
151 0 : }
152 : }
153 :
154 0 : #[derive(Debug, Error)]
155 : pub enum GetAuthInfoError {
156 : // We shouldn't include the actual secret here.
157 : #[error("Console responded with a malformed auth secret")]
158 : BadSecret,
159 :
160 : #[error(transparent)]
161 : ApiError(ApiError),
162 : }
163 :
164 : // This allows more useful interactions than `#[from]`.
165 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
166 0 : fn from(e: E) -> Self {
167 0 : Self::ApiError(e.into())
168 0 : }
169 : }
170 :
171 : impl UserFacingError for GetAuthInfoError {
172 0 : fn to_string_client(&self) -> String {
173 0 : use GetAuthInfoError::*;
174 0 : match self {
175 : // We absolutely should not leak any secrets!
176 0 : BadSecret => REQUEST_FAILED.to_owned(),
177 : // However, API might return a meaningful error.
178 0 : ApiError(e) => e.to_string_client(),
179 : }
180 0 : }
181 : }
182 :
183 : impl ReportableError for GetAuthInfoError {
184 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
185 0 : match self {
186 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
187 0 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
188 : }
189 0 : }
190 : }
191 :
192 0 : #[derive(Debug, Error)]
193 : pub enum WakeComputeError {
194 : #[error("Console responded with a malformed compute address: {0}")]
195 : BadComputeAddress(Box<str>),
196 :
197 : #[error(transparent)]
198 : ApiError(ApiError),
199 :
200 : #[error("Too many connections attempts")]
201 : TooManyConnections,
202 :
203 : #[error("error acquiring resource permit: {0}")]
204 : TooManyConnectionAttempts(#[from] ApiLockError),
205 : }
206 :
207 : // This allows more useful interactions than `#[from]`.
208 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
209 0 : fn from(e: E) -> Self {
210 0 : Self::ApiError(e.into())
211 0 : }
212 : }
213 :
214 : impl UserFacingError for WakeComputeError {
215 0 : fn to_string_client(&self) -> String {
216 0 : use WakeComputeError::*;
217 0 : match self {
218 : // We shouldn't show user the address even if it's broken.
219 : // Besides, user is unlikely to care about this detail.
220 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
221 : // However, API might return a meaningful error.
222 0 : ApiError(e) => e.to_string_client(),
223 :
224 0 : TooManyConnections => self.to_string(),
225 :
226 : TooManyConnectionAttempts(_) => {
227 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
228 : }
229 : }
230 0 : }
231 : }
232 :
233 : impl ReportableError for WakeComputeError {
234 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
235 0 : match self {
236 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
237 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
238 0 : WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
239 0 : WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(),
240 : }
241 0 : }
242 : }
243 :
244 : impl CouldRetry for WakeComputeError {
245 6 : fn could_retry(&self) -> bool {
246 6 : match self {
247 0 : WakeComputeError::BadComputeAddress(_) => false,
248 6 : WakeComputeError::ApiError(e) => e.could_retry(),
249 0 : WakeComputeError::TooManyConnections => false,
250 0 : WakeComputeError::TooManyConnectionAttempts(_) => false,
251 : }
252 6 : }
253 : }
254 : }
255 :
256 : /// Auth secret which is managed by the cloud.
257 : #[derive(Clone, Eq, PartialEq, Debug)]
258 : pub enum AuthSecret {
259 : #[cfg(any(test, feature = "testing"))]
260 : /// Md5 hash of user's password.
261 : Md5([u8; 16]),
262 :
263 : /// [SCRAM](crate::scram) authentication info.
264 : Scram(scram::ServerSecret),
265 : }
266 :
267 : #[derive(Default)]
268 : pub struct AuthInfo {
269 : pub secret: Option<AuthSecret>,
270 : /// List of IP addresses allowed for the autorization.
271 : pub allowed_ips: Vec<IpPattern>,
272 : /// Project ID. This is used for cache invalidation.
273 : pub project_id: Option<ProjectIdInt>,
274 : }
275 :
276 : /// Info for establishing a connection to a compute node.
277 : /// This is what we get after auth succeeded, but not before!
278 : #[derive(Clone)]
279 : pub struct NodeInfo {
280 : /// Compute node connection params.
281 : /// It's sad that we have to clone this, but this will improve
282 : /// once we migrate to a bespoke connection logic.
283 : pub config: compute::ConnCfg,
284 :
285 : /// Labels for proxy's metrics.
286 : pub aux: MetricsAuxInfo,
287 :
288 : /// Whether we should accept self-signed certificates (for testing)
289 : pub allow_self_signed_compute: bool,
290 : }
291 :
292 : impl NodeInfo {
293 0 : pub async fn connect(
294 0 : &self,
295 0 : ctx: &mut RequestMonitoring,
296 0 : timeout: Duration,
297 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
298 0 : self.config
299 0 : .connect(
300 0 : ctx,
301 0 : self.allow_self_signed_compute,
302 0 : self.aux.clone(),
303 0 : timeout,
304 0 : )
305 0 : .await
306 0 : }
307 8 : pub fn reuse_settings(&mut self, other: Self) {
308 8 : self.allow_self_signed_compute = other.allow_self_signed_compute;
309 8 : self.config.reuse_password(other.config);
310 8 : }
311 :
312 12 : pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
313 12 : match keys {
314 12 : ComputeCredentialKeys::Password(password) => self.config.password(password),
315 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
316 : };
317 12 : }
318 : }
319 :
320 : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
321 : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
322 : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
323 : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
324 :
325 : /// This will allocate per each call, but the http requests alone
326 : /// already require a few allocations, so it should be fine.
327 : pub(crate) trait Api {
328 : /// Get the client's auth secret for authentication.
329 : /// Returns option because user not found situation is special.
330 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
331 : async fn get_role_secret(
332 : &self,
333 : ctx: &mut RequestMonitoring,
334 : user_info: &ComputeUserInfo,
335 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
336 :
337 : async fn get_allowed_ips_and_secret(
338 : &self,
339 : ctx: &mut RequestMonitoring,
340 : user_info: &ComputeUserInfo,
341 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
342 :
343 : /// Wake up the compute node and return the corresponding connection info.
344 : async fn wake_compute(
345 : &self,
346 : ctx: &mut RequestMonitoring,
347 : user_info: &ComputeUserInfo,
348 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
349 : }
350 :
351 : #[non_exhaustive]
352 : pub enum ConsoleBackend {
353 : /// Current Cloud API (V2).
354 : Console(neon::Api),
355 : /// Local mock of Cloud API (V2).
356 : #[cfg(any(test, feature = "testing"))]
357 : Postgres(mock::Api),
358 : /// Internal testing
359 : #[cfg(test)]
360 : Test(Box<dyn crate::auth::backend::TestBackend>),
361 : }
362 :
363 : impl Api for ConsoleBackend {
364 0 : async fn get_role_secret(
365 0 : &self,
366 0 : ctx: &mut RequestMonitoring,
367 0 : user_info: &ComputeUserInfo,
368 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
369 0 : use ConsoleBackend::*;
370 0 : match self {
371 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
372 : #[cfg(any(test, feature = "testing"))]
373 0 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
374 : #[cfg(test)]
375 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
376 : }
377 0 : }
378 :
379 0 : async fn get_allowed_ips_and_secret(
380 0 : &self,
381 0 : ctx: &mut RequestMonitoring,
382 0 : user_info: &ComputeUserInfo,
383 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
384 0 : use ConsoleBackend::*;
385 0 : match self {
386 0 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
387 : #[cfg(any(test, feature = "testing"))]
388 0 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
389 : #[cfg(test)]
390 0 : Test(api) => api.get_allowed_ips_and_secret(),
391 : }
392 0 : }
393 :
394 26 : async fn wake_compute(
395 26 : &self,
396 26 : ctx: &mut RequestMonitoring,
397 26 : user_info: &ComputeUserInfo,
398 26 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
399 26 : use ConsoleBackend::*;
400 26 :
401 26 : match self {
402 0 : Console(api) => api.wake_compute(ctx, user_info).await,
403 : #[cfg(any(test, feature = "testing"))]
404 0 : Postgres(api) => api.wake_compute(ctx, user_info).await,
405 : #[cfg(test)]
406 26 : Test(api) => api.wake_compute(),
407 : }
408 26 : }
409 : }
410 :
411 : /// Various caches for [`console`](super).
412 : pub struct ApiCaches {
413 : /// Cache for the `wake_compute` API method.
414 : pub node_info: NodeInfoCache,
415 : /// Cache which stores project_id -> endpoint_ids mapping.
416 : pub project_info: Arc<ProjectInfoCacheImpl>,
417 : /// List of all valid endpoints.
418 : pub endpoints_cache: Arc<EndpointsCache>,
419 : }
420 :
421 : impl ApiCaches {
422 0 : pub fn new(
423 0 : wake_compute_cache_config: CacheOptions,
424 0 : project_info_cache_config: ProjectInfoCacheOptions,
425 0 : endpoint_cache_config: EndpointCacheConfig,
426 0 : ) -> Self {
427 0 : Self {
428 0 : node_info: NodeInfoCache::new(
429 0 : "node_info_cache",
430 0 : wake_compute_cache_config.size,
431 0 : wake_compute_cache_config.ttl,
432 0 : true,
433 0 : ),
434 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
435 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
436 0 : }
437 0 : }
438 : }
439 :
440 : /// Various caches for [`console`](super).
441 : pub struct ApiLocks<K> {
442 : name: &'static str,
443 : node_locks: DashMap<K, Arc<DynamicLimiter>>,
444 : config: RateLimiterConfig,
445 : timeout: Duration,
446 : epoch: std::time::Duration,
447 : metrics: &'static ApiLockMetrics,
448 : }
449 :
450 0 : #[derive(Debug, thiserror::Error)]
451 : pub enum ApiLockError {
452 : #[error("timeout acquiring resource permit")]
453 : TimeoutError(#[from] tokio::time::error::Elapsed),
454 : }
455 :
456 : impl ReportableError for ApiLockError {
457 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
458 0 : match self {
459 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
460 0 : }
461 0 : }
462 : }
463 :
464 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
465 0 : pub fn new(
466 0 : name: &'static str,
467 0 : config: RateLimiterConfig,
468 0 : shards: usize,
469 0 : timeout: Duration,
470 0 : epoch: std::time::Duration,
471 0 : metrics: &'static ApiLockMetrics,
472 0 : ) -> prometheus::Result<Self> {
473 0 : Ok(Self {
474 0 : name,
475 0 : node_locks: DashMap::with_shard_amount(shards),
476 0 : config,
477 0 : timeout,
478 0 : epoch,
479 0 : metrics,
480 0 : })
481 0 : }
482 :
483 0 : pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
484 0 : if self.config.initial_limit == 0 {
485 0 : return Ok(WakeComputePermit {
486 0 : permit: Token::disabled(),
487 0 : });
488 0 : }
489 0 : let now = Instant::now();
490 0 : let semaphore = {
491 : // get fast path
492 0 : if let Some(semaphore) = self.node_locks.get(key) {
493 0 : semaphore.clone()
494 : } else {
495 0 : self.node_locks
496 0 : .entry(key.clone())
497 0 : .or_insert_with(|| {
498 0 : self.metrics.semaphores_registered.inc();
499 0 : DynamicLimiter::new(self.config)
500 0 : })
501 0 : .clone()
502 : }
503 : };
504 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
505 :
506 0 : self.metrics
507 0 : .semaphore_acquire_seconds
508 0 : .observe(now.elapsed().as_secs_f64());
509 0 : info!("acquired permit {:?}", now.elapsed().as_secs_f64());
510 0 : Ok(WakeComputePermit { permit: permit? })
511 0 : }
512 :
513 0 : pub async fn garbage_collect_worker(&self) {
514 0 : if self.config.initial_limit == 0 {
515 0 : return;
516 0 : }
517 0 : let mut interval =
518 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
519 : loop {
520 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
521 0 : interval.tick().await;
522 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
523 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
524 : // therefore releasing it is safe from race conditions
525 0 : info!(
526 : name = self.name,
527 : shard = i,
528 0 : "performing epoch reclamation on api lock"
529 : );
530 0 : let mut lock = shard.write();
531 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
532 0 : let count = lock
533 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
534 0 : .count();
535 0 : drop(lock);
536 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
537 0 : timer.observe();
538 : }
539 : }
540 0 : }
541 : }
542 :
543 : pub struct WakeComputePermit {
544 : permit: Token,
545 : }
546 :
547 : impl WakeComputePermit {
548 0 : pub fn should_check_cache(&self) -> bool {
549 0 : !self.permit.is_disabled()
550 0 : }
551 0 : pub fn release(self, outcome: Outcome) {
552 0 : self.permit.release(outcome)
553 0 : }
554 0 : pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
555 0 : match res {
556 0 : Ok(_) => self.release(Outcome::Success),
557 0 : Err(_) => self.release(Outcome::Overload),
558 : }
559 0 : res
560 0 : }
561 : }
|