Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use std::hash::Hash;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use dashmap::DashMap;
10 : use tokio::time::Instant;
11 : use tracing::info;
12 :
13 : use super::messages::{ControlPlaneError, MetricsAuxInfo};
14 : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
15 : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
16 : use crate::auth::IpPattern;
17 : use crate::cache::endpoints::EndpointsCache;
18 : use crate::cache::project_info::ProjectInfoCacheImpl;
19 : use crate::cache::{Cached, TimedLru};
20 : use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
21 : use crate::context::RequestMonitoring;
22 : use crate::error::ReportableError;
23 : use crate::intern::ProjectIdInt;
24 : use crate::metrics::ApiLockMetrics;
25 : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
26 : use crate::types::{EndpointCacheKey, EndpointId};
27 : use crate::{compute, scram};
28 :
29 : pub(crate) mod errors {
30 : use thiserror::Error;
31 :
32 : use super::ApiLockError;
33 : use crate::control_plane::messages::{self, ControlPlaneError, Reason};
34 : use crate::error::{io_error, ErrorKind, ReportableError, UserFacingError};
35 : use crate::proxy::retry::CouldRetry;
36 :
37 : /// A go-to error message which doesn't leak any detail.
38 : pub(crate) const REQUEST_FAILED: &str = "Console request failed";
39 :
40 : /// Common console API error.
41 0 : #[derive(Debug, Error)]
42 : pub(crate) enum ApiError {
43 : /// Error returned by the console itself.
44 : #[error("{REQUEST_FAILED} with {0}")]
45 : ControlPlane(Box<ControlPlaneError>),
46 :
47 : /// Various IO errors like broken pipe or malformed payload.
48 : #[error("{REQUEST_FAILED}: {0}")]
49 : Transport(#[from] std::io::Error),
50 : }
51 :
52 : impl ApiError {
53 : /// Returns HTTP status code if it's the reason for failure.
54 0 : pub(crate) fn get_reason(&self) -> messages::Reason {
55 0 : match self {
56 0 : ApiError::ControlPlane(e) => e.get_reason(),
57 0 : ApiError::Transport(_) => messages::Reason::Unknown,
58 : }
59 0 : }
60 : }
61 :
62 : impl UserFacingError for ApiError {
63 0 : fn to_string_client(&self) -> String {
64 0 : match self {
65 : // To minimize risks, only select errors are forwarded to users.
66 0 : ApiError::ControlPlane(c) => c.get_user_facing_message(),
67 0 : ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
68 : }
69 0 : }
70 : }
71 :
72 : impl ReportableError for ApiError {
73 3 : fn get_error_kind(&self) -> crate::error::ErrorKind {
74 3 : match self {
75 3 : ApiError::ControlPlane(e) => match e.get_reason() {
76 0 : Reason::RoleProtected => ErrorKind::User,
77 0 : Reason::ResourceNotFound => ErrorKind::User,
78 0 : Reason::ProjectNotFound => ErrorKind::User,
79 0 : Reason::EndpointNotFound => ErrorKind::User,
80 0 : Reason::BranchNotFound => ErrorKind::User,
81 0 : Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
82 0 : Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
83 0 : Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
84 0 : Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
85 0 : Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
86 0 : Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
87 0 : Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
88 0 : Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
89 0 : Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
90 0 : Reason::RunningOperations => ErrorKind::ControlPlane,
91 0 : Reason::ActiveEndpointsLimitExceeded => ErrorKind::ControlPlane,
92 3 : Reason::Unknown => ErrorKind::ControlPlane,
93 : },
94 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
95 : }
96 3 : }
97 : }
98 :
99 : impl CouldRetry for ApiError {
100 6 : fn could_retry(&self) -> bool {
101 6 : match self {
102 : // retry some transport errors
103 0 : Self::Transport(io) => io.could_retry(),
104 6 : Self::ControlPlane(e) => e.could_retry(),
105 : }
106 6 : }
107 : }
108 :
109 : impl From<reqwest::Error> for ApiError {
110 0 : fn from(e: reqwest::Error) -> Self {
111 0 : io_error(e).into()
112 0 : }
113 : }
114 :
115 : impl From<reqwest_middleware::Error> for ApiError {
116 0 : fn from(e: reqwest_middleware::Error) -> Self {
117 0 : io_error(e).into()
118 0 : }
119 : }
120 :
121 0 : #[derive(Debug, Error)]
122 : pub(crate) enum GetAuthInfoError {
123 : // We shouldn't include the actual secret here.
124 : #[error("Console responded with a malformed auth secret")]
125 : BadSecret,
126 :
127 : #[error(transparent)]
128 : ApiError(ApiError),
129 : }
130 :
131 : // This allows more useful interactions than `#[from]`.
132 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
133 0 : fn from(e: E) -> Self {
134 0 : Self::ApiError(e.into())
135 0 : }
136 : }
137 :
138 : impl UserFacingError for GetAuthInfoError {
139 0 : fn to_string_client(&self) -> String {
140 0 : match self {
141 : // We absolutely should not leak any secrets!
142 0 : Self::BadSecret => REQUEST_FAILED.to_owned(),
143 : // However, API might return a meaningful error.
144 0 : Self::ApiError(e) => e.to_string_client(),
145 : }
146 0 : }
147 : }
148 :
149 : impl ReportableError for GetAuthInfoError {
150 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
151 0 : match self {
152 0 : Self::BadSecret => crate::error::ErrorKind::ControlPlane,
153 0 : Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
154 : }
155 0 : }
156 : }
157 :
158 0 : #[derive(Debug, Error)]
159 : pub(crate) enum WakeComputeError {
160 : #[error("Console responded with a malformed compute address: {0}")]
161 : BadComputeAddress(Box<str>),
162 :
163 : #[error(transparent)]
164 : ApiError(ApiError),
165 :
166 : #[error("Too many connections attempts")]
167 : TooManyConnections,
168 :
169 : #[error("error acquiring resource permit: {0}")]
170 : TooManyConnectionAttempts(#[from] ApiLockError),
171 : }
172 :
173 : // This allows more useful interactions than `#[from]`.
174 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
175 0 : fn from(e: E) -> Self {
176 0 : Self::ApiError(e.into())
177 0 : }
178 : }
179 :
180 : impl UserFacingError for WakeComputeError {
181 0 : fn to_string_client(&self) -> String {
182 0 : match self {
183 : // We shouldn't show user the address even if it's broken.
184 : // Besides, user is unlikely to care about this detail.
185 0 : Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
186 : // However, API might return a meaningful error.
187 0 : Self::ApiError(e) => e.to_string_client(),
188 :
189 0 : Self::TooManyConnections => self.to_string(),
190 :
191 : Self::TooManyConnectionAttempts(_) => {
192 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
193 : }
194 : }
195 0 : }
196 : }
197 :
198 : impl ReportableError for WakeComputeError {
199 3 : fn get_error_kind(&self) -> crate::error::ErrorKind {
200 3 : match self {
201 0 : Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
202 3 : Self::ApiError(e) => e.get_error_kind(),
203 0 : Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
204 0 : Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
205 : }
206 3 : }
207 : }
208 :
209 : impl CouldRetry for WakeComputeError {
210 3 : fn could_retry(&self) -> bool {
211 3 : match self {
212 0 : Self::BadComputeAddress(_) => false,
213 3 : Self::ApiError(e) => e.could_retry(),
214 0 : Self::TooManyConnections => false,
215 0 : Self::TooManyConnectionAttempts(_) => false,
216 : }
217 3 : }
218 : }
219 :
220 0 : #[derive(Debug, Error)]
221 : pub enum GetEndpointJwksError {
222 : #[error("endpoint not found")]
223 : EndpointNotFound,
224 :
225 : #[error("failed to build control plane request: {0}")]
226 : RequestBuild(#[source] reqwest::Error),
227 :
228 : #[error("failed to send control plane request: {0}")]
229 : RequestExecute(#[source] reqwest_middleware::Error),
230 :
231 : #[error(transparent)]
232 : ControlPlane(#[from] ApiError),
233 :
234 : #[cfg(any(test, feature = "testing"))]
235 : #[error(transparent)]
236 : TokioPostgres(#[from] tokio_postgres::Error),
237 :
238 : #[cfg(any(test, feature = "testing"))]
239 : #[error(transparent)]
240 : ParseUrl(#[from] url::ParseError),
241 :
242 : #[cfg(any(test, feature = "testing"))]
243 : #[error(transparent)]
244 : TaskJoin(#[from] tokio::task::JoinError),
245 : }
246 : }
247 :
248 : /// Auth secret which is managed by the cloud.
249 : #[derive(Clone, Eq, PartialEq, Debug)]
250 : pub(crate) enum AuthSecret {
251 : #[cfg(any(test, feature = "testing"))]
252 : /// Md5 hash of user's password.
253 : Md5([u8; 16]),
254 :
255 : /// [SCRAM](crate::scram) authentication info.
256 : Scram(scram::ServerSecret),
257 : }
258 :
259 : #[derive(Default)]
260 : pub(crate) struct AuthInfo {
261 : pub(crate) secret: Option<AuthSecret>,
262 : /// List of IP addresses allowed for the autorization.
263 : pub(crate) allowed_ips: Vec<IpPattern>,
264 : /// Project ID. This is used for cache invalidation.
265 : pub(crate) project_id: Option<ProjectIdInt>,
266 : }
267 :
268 : /// Info for establishing a connection to a compute node.
269 : /// This is what we get after auth succeeded, but not before!
270 : #[derive(Clone)]
271 : pub(crate) struct NodeInfo {
272 : /// Compute node connection params.
273 : /// It's sad that we have to clone this, but this will improve
274 : /// once we migrate to a bespoke connection logic.
275 : pub(crate) config: compute::ConnCfg,
276 :
277 : /// Labels for proxy's metrics.
278 : pub(crate) aux: MetricsAuxInfo,
279 :
280 : /// Whether we should accept self-signed certificates (for testing)
281 : pub(crate) allow_self_signed_compute: bool,
282 : }
283 :
284 : impl NodeInfo {
285 0 : pub(crate) async fn connect(
286 0 : &self,
287 0 : ctx: &RequestMonitoring,
288 0 : timeout: Duration,
289 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
290 0 : self.config
291 0 : .connect(
292 0 : ctx,
293 0 : self.allow_self_signed_compute,
294 0 : self.aux.clone(),
295 0 : timeout,
296 0 : )
297 0 : .await
298 0 : }
299 4 : pub(crate) fn reuse_settings(&mut self, other: Self) {
300 4 : self.allow_self_signed_compute = other.allow_self_signed_compute;
301 4 : self.config.reuse_password(other.config);
302 4 : }
303 :
304 6 : pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
305 6 : match keys {
306 : #[cfg(any(test, feature = "testing"))]
307 6 : ComputeCredentialKeys::Password(password) => self.config.password(password),
308 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
309 0 : ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
310 : };
311 6 : }
312 : }
313 :
314 : pub(crate) type NodeInfoCache =
315 : TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneError>>>;
316 : pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
317 : pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
318 : pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
319 :
320 : /// This will allocate per each call, but the http requests alone
321 : /// already require a few allocations, so it should be fine.
322 : pub(crate) trait Api {
323 : /// Get the client's auth secret for authentication.
324 : /// Returns option because user not found situation is special.
325 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
326 : async fn get_role_secret(
327 : &self,
328 : ctx: &RequestMonitoring,
329 : user_info: &ComputeUserInfo,
330 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
331 :
332 : async fn get_allowed_ips_and_secret(
333 : &self,
334 : ctx: &RequestMonitoring,
335 : user_info: &ComputeUserInfo,
336 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
337 :
338 : async fn get_endpoint_jwks(
339 : &self,
340 : ctx: &RequestMonitoring,
341 : endpoint: EndpointId,
342 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError>;
343 :
344 : /// Wake up the compute node and return the corresponding connection info.
345 : async fn wake_compute(
346 : &self,
347 : ctx: &RequestMonitoring,
348 : user_info: &ComputeUserInfo,
349 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
350 : }
351 :
352 : #[non_exhaustive]
353 : #[derive(Clone)]
354 : pub enum ControlPlaneBackend {
355 : /// Current Management API (V2).
356 : Management(neon::Api),
357 : /// Local mock control plane.
358 : #[cfg(any(test, feature = "testing"))]
359 : PostgresMock(mock::Api),
360 : /// Internal testing
361 : #[cfg(test)]
362 : #[allow(private_interfaces)]
363 : Test(Box<dyn crate::auth::backend::TestBackend>),
364 : }
365 :
366 : impl Api for ControlPlaneBackend {
367 0 : async fn get_role_secret(
368 0 : &self,
369 0 : ctx: &RequestMonitoring,
370 0 : user_info: &ComputeUserInfo,
371 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
372 0 : match self {
373 0 : Self::Management(api) => api.get_role_secret(ctx, user_info).await,
374 : #[cfg(any(test, feature = "testing"))]
375 0 : Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
376 : #[cfg(test)]
377 : Self::Test(_) => {
378 0 : unreachable!("this function should never be called in the test backend")
379 : }
380 : }
381 0 : }
382 :
383 0 : async fn get_allowed_ips_and_secret(
384 0 : &self,
385 0 : ctx: &RequestMonitoring,
386 0 : user_info: &ComputeUserInfo,
387 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
388 0 : match self {
389 0 : Self::Management(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
390 : #[cfg(any(test, feature = "testing"))]
391 0 : Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
392 : #[cfg(test)]
393 0 : Self::Test(api) => api.get_allowed_ips_and_secret(),
394 : }
395 0 : }
396 :
397 0 : async fn get_endpoint_jwks(
398 0 : &self,
399 0 : ctx: &RequestMonitoring,
400 0 : endpoint: EndpointId,
401 0 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
402 0 : match self {
403 0 : Self::Management(api) => api.get_endpoint_jwks(ctx, endpoint).await,
404 : #[cfg(any(test, feature = "testing"))]
405 0 : Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
406 : #[cfg(test)]
407 0 : Self::Test(_api) => Ok(vec![]),
408 : }
409 0 : }
410 :
411 13 : async fn wake_compute(
412 13 : &self,
413 13 : ctx: &RequestMonitoring,
414 13 : user_info: &ComputeUserInfo,
415 13 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
416 13 : match self {
417 0 : Self::Management(api) => api.wake_compute(ctx, user_info).await,
418 : #[cfg(any(test, feature = "testing"))]
419 0 : Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
420 : #[cfg(test)]
421 13 : Self::Test(api) => api.wake_compute(),
422 : }
423 13 : }
424 : }
425 :
426 : /// Various caches for [`control_plane`](super).
427 : pub struct ApiCaches {
428 : /// Cache for the `wake_compute` API method.
429 : pub(crate) node_info: NodeInfoCache,
430 : /// Cache which stores project_id -> endpoint_ids mapping.
431 : pub project_info: Arc<ProjectInfoCacheImpl>,
432 : /// List of all valid endpoints.
433 : pub endpoints_cache: Arc<EndpointsCache>,
434 : }
435 :
436 : impl ApiCaches {
437 0 : pub fn new(
438 0 : wake_compute_cache_config: CacheOptions,
439 0 : project_info_cache_config: ProjectInfoCacheOptions,
440 0 : endpoint_cache_config: EndpointCacheConfig,
441 0 : ) -> Self {
442 0 : Self {
443 0 : node_info: NodeInfoCache::new(
444 0 : "node_info_cache",
445 0 : wake_compute_cache_config.size,
446 0 : wake_compute_cache_config.ttl,
447 0 : true,
448 0 : ),
449 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
450 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
451 0 : }
452 0 : }
453 : }
454 :
455 : /// Various caches for [`control_plane`](super).
456 : pub struct ApiLocks<K> {
457 : name: &'static str,
458 : node_locks: DashMap<K, Arc<DynamicLimiter>>,
459 : config: RateLimiterConfig,
460 : timeout: Duration,
461 : epoch: std::time::Duration,
462 : metrics: &'static ApiLockMetrics,
463 : }
464 :
465 0 : #[derive(Debug, thiserror::Error)]
466 : pub(crate) enum ApiLockError {
467 : #[error("timeout acquiring resource permit")]
468 : TimeoutError(#[from] tokio::time::error::Elapsed),
469 : }
470 :
471 : impl ReportableError for ApiLockError {
472 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
473 0 : match self {
474 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
475 0 : }
476 0 : }
477 : }
478 :
479 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
480 0 : pub fn new(
481 0 : name: &'static str,
482 0 : config: RateLimiterConfig,
483 0 : shards: usize,
484 0 : timeout: Duration,
485 0 : epoch: std::time::Duration,
486 0 : metrics: &'static ApiLockMetrics,
487 0 : ) -> prometheus::Result<Self> {
488 0 : Ok(Self {
489 0 : name,
490 0 : node_locks: DashMap::with_shard_amount(shards),
491 0 : config,
492 0 : timeout,
493 0 : epoch,
494 0 : metrics,
495 0 : })
496 0 : }
497 :
498 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
499 0 : if self.config.initial_limit == 0 {
500 0 : return Ok(WakeComputePermit {
501 0 : permit: Token::disabled(),
502 0 : });
503 0 : }
504 0 : let now = Instant::now();
505 0 : let semaphore = {
506 : // get fast path
507 0 : if let Some(semaphore) = self.node_locks.get(key) {
508 0 : semaphore.clone()
509 : } else {
510 0 : self.node_locks
511 0 : .entry(key.clone())
512 0 : .or_insert_with(|| {
513 0 : self.metrics.semaphores_registered.inc();
514 0 : DynamicLimiter::new(self.config)
515 0 : })
516 0 : .clone()
517 : }
518 : };
519 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
520 :
521 0 : self.metrics
522 0 : .semaphore_acquire_seconds
523 0 : .observe(now.elapsed().as_secs_f64());
524 0 : info!("acquired permit {:?}", now.elapsed().as_secs_f64());
525 0 : Ok(WakeComputePermit { permit: permit? })
526 0 : }
527 :
528 0 : pub async fn garbage_collect_worker(&self) {
529 0 : if self.config.initial_limit == 0 {
530 0 : return;
531 0 : }
532 0 : let mut interval =
533 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
534 : loop {
535 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
536 0 : interval.tick().await;
537 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
538 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
539 : // therefore releasing it is safe from race conditions
540 0 : info!(
541 : name = self.name,
542 : shard = i,
543 0 : "performing epoch reclamation on api lock"
544 : );
545 0 : let mut lock = shard.write();
546 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
547 0 : let count = lock
548 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
549 0 : .count();
550 0 : drop(lock);
551 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
552 0 : timer.observe();
553 0 : }
554 : }
555 0 : }
556 : }
557 :
558 : pub(crate) struct WakeComputePermit {
559 : permit: Token,
560 : }
561 :
562 : impl WakeComputePermit {
563 0 : pub(crate) fn should_check_cache(&self) -> bool {
564 0 : !self.permit.is_disabled()
565 0 : }
566 0 : pub(crate) fn release(self, outcome: Outcome) {
567 0 : self.permit.release(outcome);
568 0 : }
569 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
570 0 : match res {
571 0 : Ok(_) => self.release(Outcome::Success),
572 0 : Err(_) => self.release(Outcome::Overload),
573 : }
574 0 : res
575 0 : }
576 : }
577 :
578 : impl FetchAuthRules for ControlPlaneBackend {
579 0 : async fn fetch_auth_rules(
580 0 : &self,
581 0 : ctx: &RequestMonitoring,
582 0 : endpoint: EndpointId,
583 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
584 0 : self.get_endpoint_jwks(ctx, endpoint)
585 0 : .await
586 0 : .map_err(FetchAuthRulesError::GetEndpointJwks)
587 0 : }
588 : }
|