LCOV - code coverage report
Current view: top level - proxy/src/console - provider.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 60.9 % 215 131
Test Date: 2024-02-07 07:37:29 Functions: 52.1 % 48 25

            Line data    Source code
       1              : #[cfg(any(test, feature = "testing"))]
       2              : pub mod mock;
       3              : pub mod neon;
       4              : 
       5              : use super::messages::MetricsAuxInfo;
       6              : use crate::{
       7              :     auth::{backend::ComputeUserInfo, IpPattern},
       8              :     cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
       9              :     compute,
      10              :     config::{CacheOptions, ProjectInfoCacheOptions},
      11              :     context::RequestMonitoring,
      12              :     scram, EndpointCacheKey, ProjectId,
      13              : };
      14              : use async_trait::async_trait;
      15              : use dashmap::DashMap;
      16              : use std::{sync::Arc, time::Duration};
      17              : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
      18              : use tokio::time::Instant;
      19              : use tracing::info;
      20              : 
      21              : pub mod errors {
      22              :     use crate::{
      23              :         error::{io_error, UserFacingError},
      24              :         http,
      25              :         proxy::retry::ShouldRetry,
      26              :     };
      27              :     use thiserror::Error;
      28              : 
      29              :     /// A go-to error message which doesn't leak any detail.
      30              :     const REQUEST_FAILED: &str = "Console request failed";
      31              : 
      32              :     /// Common console API error.
      33           37 :     #[derive(Debug, Error)]
      34              :     pub enum ApiError {
      35              :         /// Error returned by the console itself.
      36              :         #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
      37              :         Console {
      38              :             status: http::StatusCode,
      39              :             text: Box<str>,
      40              :         },
      41              : 
      42              :         /// Various IO errors like broken pipe or malformed payload.
      43              :         #[error("{REQUEST_FAILED}: {0}")]
      44              :         Transport(#[from] std::io::Error),
      45              :     }
      46              : 
      47              :     impl ApiError {
      48              :         /// Returns HTTP status code if it's the reason for failure.
      49            3 :         pub fn http_status_code(&self) -> Option<http::StatusCode> {
      50            3 :             use ApiError::*;
      51            3 :             match self {
      52            2 :                 Console { status, .. } => Some(*status),
      53            1 :                 _ => None,
      54              :             }
      55            3 :         }
      56              :     }
      57              : 
      58              :     impl UserFacingError for ApiError {
      59            4 :         fn to_string_client(&self) -> String {
      60            4 :             use ApiError::*;
      61            4 :             match self {
      62              :                 // To minimize risks, only select errors are forwarded to users.
      63              :                 // Ask @neondatabase/control-plane for review before adding more.
      64            2 :                 Console { status, .. } => match *status {
      65              :                     http::StatusCode::NOT_FOUND => {
      66              :                         // Status 404: failed to get a project-related resource.
      67            0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      68              :                     }
      69              :                     http::StatusCode::NOT_ACCEPTABLE => {
      70              :                         // Status 406: endpoint is disabled (we don't allow connections).
      71            0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      72              :                     }
      73              :                     http::StatusCode::LOCKED => {
      74              :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      75            0 :                         format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
      76              :                     }
      77            2 :                     _ => REQUEST_FAILED.to_owned(),
      78              :                 },
      79            2 :                 _ => REQUEST_FAILED.to_owned(),
      80              :             }
      81            4 :         }
      82              :     }
      83              : 
      84              :     impl ShouldRetry for ApiError {
      85            8 :         fn could_retry(&self) -> bool {
      86            8 :             match self {
      87              :                 // retry some transport errors
      88            0 :                 Self::Transport(io) => io.could_retry(),
      89              :                 // retry some temporary failures because the compute was in a bad state
      90              :                 // (bad request can be returned when the endpoint was in transition)
      91              :                 Self::Console {
      92              :                     status: http::StatusCode::BAD_REQUEST,
      93              :                     ..
      94            4 :                 } => true,
      95              :                 // locked can be returned when the endpoint was in transition
      96              :                 // or when quotas are exceeded. don't retry when quotas are exceeded
      97              :                 Self::Console {
      98              :                     status: http::StatusCode::LOCKED,
      99            0 :                     ref text,
     100            0 :                 } => {
     101            0 :                     // written data quota exceeded
     102            0 :                     // data transfer quota exceeded
     103            0 :                     // compute time quota exceeded
     104            0 :                     // logical size quota exceeded
     105            0 :                     !text.contains("quota exceeded")
     106            0 :                         && !text.contains("the limit for current plan reached")
     107              :                 }
     108            4 :                 _ => false,
     109              :             }
     110            8 :         }
     111              :     }
     112              : 
     113              :     impl From<reqwest::Error> for ApiError {
     114            1 :         fn from(e: reqwest::Error) -> Self {
     115            1 :             io_error(e).into()
     116            1 :         }
     117              :     }
     118              : 
     119              :     impl From<reqwest_middleware::Error> for ApiError {
     120            1 :         fn from(e: reqwest_middleware::Error) -> Self {
     121            1 :             io_error(e).into()
     122            1 :         }
     123              :     }
     124              : 
     125           24 :     #[derive(Debug, Error)]
     126              :     pub enum GetAuthInfoError {
     127              :         // We shouldn't include the actual secret here.
     128              :         #[error("Console responded with a malformed auth secret")]
     129              :         BadSecret,
     130              : 
     131              :         #[error(transparent)]
     132              :         ApiError(ApiError),
     133              :     }
     134              : 
     135              :     // This allows more useful interactions than `#[from]`.
     136              :     impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
     137            4 :         fn from(e: E) -> Self {
     138            4 :             Self::ApiError(e.into())
     139            4 :         }
     140              :     }
     141              : 
     142              :     impl UserFacingError for GetAuthInfoError {
     143            4 :         fn to_string_client(&self) -> String {
     144            4 :             use GetAuthInfoError::*;
     145            4 :             match self {
     146              :                 // We absolutely should not leak any secrets!
     147            0 :                 BadSecret => REQUEST_FAILED.to_owned(),
     148              :                 // However, API might return a meaningful error.
     149            4 :                 ApiError(e) => e.to_string_client(),
     150              :             }
     151            4 :         }
     152              :     }
     153            0 :     #[derive(Debug, Error)]
     154              :     pub enum WakeComputeError {
     155              :         #[error("Console responded with a malformed compute address: {0}")]
     156              :         BadComputeAddress(Box<str>),
     157              : 
     158              :         #[error(transparent)]
     159              :         ApiError(ApiError),
     160              : 
     161              :         #[error("Timeout waiting to acquire wake compute lock")]
     162              :         TimeoutError,
     163              :     }
     164              : 
     165              :     // This allows more useful interactions than `#[from]`.
     166              :     impl<E: Into<ApiError>> From<E> for WakeComputeError {
     167            0 :         fn from(e: E) -> Self {
     168            0 :             Self::ApiError(e.into())
     169            0 :         }
     170              :     }
     171              : 
     172              :     impl From<tokio::sync::AcquireError> for WakeComputeError {
     173            0 :         fn from(_: tokio::sync::AcquireError) -> Self {
     174            0 :             WakeComputeError::TimeoutError
     175            0 :         }
     176              :     }
     177              :     impl From<tokio::time::error::Elapsed> for WakeComputeError {
     178            0 :         fn from(_: tokio::time::error::Elapsed) -> Self {
     179            0 :             WakeComputeError::TimeoutError
     180            0 :         }
     181              :     }
     182              : 
     183              :     impl UserFacingError for WakeComputeError {
     184            0 :         fn to_string_client(&self) -> String {
     185            0 :             use WakeComputeError::*;
     186            0 :             match self {
     187              :                 // We shouldn't show user the address even if it's broken.
     188              :                 // Besides, user is unlikely to care about this detail.
     189            0 :                 BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
     190              :                 // However, API might return a meaningful error.
     191            0 :                 ApiError(e) => e.to_string_client(),
     192              : 
     193            0 :                 TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
     194              :             }
     195            0 :         }
     196              :     }
     197              : }
     198              : 
     199              : /// Auth secret which is managed by the cloud.
     200           62 : #[derive(Clone, Eq, PartialEq, Debug)]
     201              : pub enum AuthSecret {
     202              :     #[cfg(any(test, feature = "testing"))]
     203              :     /// Md5 hash of user's password.
     204              :     Md5([u8; 16]),
     205              : 
     206              :     /// [SCRAM](crate::scram) authentication info.
     207              :     Scram(scram::ServerSecret),
     208              : }
     209              : 
     210            0 : #[derive(Default)]
     211              : pub struct AuthInfo {
     212              :     pub secret: Option<AuthSecret>,
     213              :     /// List of IP addresses allowed for the authorization.
     214              :     pub allowed_ips: Vec<IpPattern>,
     215              :     /// Project ID. This is used for cache invalidation.
     216              :     pub project_id: Option<ProjectId>,
     217              : }
     218              : 
     219              : /// Info for establishing a connection to a compute node.
     220              : /// This is what we get after auth succeeded, but not before!
     221            0 : #[derive(Clone)]
     222              : pub struct NodeInfo {
     223              :     /// Compute node connection params.
     224              :     /// It's sad that we have to clone this, but this will improve
     225              :     /// once we migrate to a bespoke connection logic.
     226              :     pub config: compute::ConnCfg,
     227              : 
     228              :     /// Labels for proxy's metrics.
     229              :     pub aux: MetricsAuxInfo,
     230              : 
     231              :     /// Whether we should accept self-signed certificates (for testing)
     232              :     pub allow_self_signed_compute: bool,
     233              : }
     234              : 
     235              : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
     236              : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
     237              : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
     238              : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
     239              : 
     240              : /// This will allocate on each call, but the HTTP requests alone
     241              : /// already require a few allocations, so it should be fine.
     242              : #[async_trait]
     243              : pub trait Api {
     244              :     /// Get the client's auth secret for authentication.
     245              :     /// Returns option because user not found situation is special.
     246              :     /// We still have to mock the scram to avoid leaking information that user doesn't exist.
     247              :     async fn get_role_secret(
     248              :         &self,
     249              :         ctx: &mut RequestMonitoring,
     250              :         user_info: &ComputeUserInfo,
     251              :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
     252              : 
     253              :     async fn get_allowed_ips_and_secret(
     254              :         &self,
     255              :         ctx: &mut RequestMonitoring,
     256              :         user_info: &ComputeUserInfo,
     257              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
     258              : 
     259              :     /// Wake up the compute node and return the corresponding connection info.
     260              :     async fn wake_compute(
     261              :         &self,
     262              :         ctx: &mut RequestMonitoring,
     263              :         user_info: &ComputeUserInfo,
     264              :     ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
     265              : }
     266              : 
     267              : #[non_exhaustive]
     268              : pub enum ConsoleBackend {
     269              :     /// Current Cloud API (V2).
     270              :     Console(neon::Api),
     271              :     /// Local mock of Cloud API (V2).
     272              :     #[cfg(any(test, feature = "testing"))]
     273              :     Postgres(mock::Api),
     274              :     /// Internal testing
     275              :     #[cfg(test)]
     276              :     Test(Box<dyn crate::auth::backend::TestBackend>),
     277              : }
     278              : 
     279              : #[async_trait]
     280              : impl Api for ConsoleBackend {
     281           39 :     async fn get_role_secret(
     282           39 :         &self,
     283           39 :         ctx: &mut RequestMonitoring,
     284           39 :         user_info: &ComputeUserInfo,
     285           39 :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
     286              :         use ConsoleBackend::*;
     287           39 :         match self {
     288            0 :             Console(api) => api.get_role_secret(ctx, user_info).await,
     289              :             #[cfg(any(test, feature = "testing"))]
     290          267 :             Postgres(api) => api.get_role_secret(ctx, user_info).await,
     291              :             #[cfg(test)]
     292            0 :             Test(_) => unreachable!("this function should never be called in the test backend"),
     293              :         }
     294           78 :     }
     295              : 
     296           88 :     async fn get_allowed_ips_and_secret(
     297           88 :         &self,
     298           88 :         ctx: &mut RequestMonitoring,
     299           88 :         user_info: &ComputeUserInfo,
     300           88 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
     301              :         use ConsoleBackend::*;
     302           88 :         match self {
     303           12 :             Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     304              :             #[cfg(any(test, feature = "testing"))]
     305          598 :             Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     306              :             #[cfg(test)]
     307            0 :             Test(api) => api.get_allowed_ips_and_secret(),
     308              :         }
     309          176 :     }
     310              : 
     311           93 :     async fn wake_compute(
     312           93 :         &self,
     313           93 :         ctx: &mut RequestMonitoring,
     314           93 :         user_info: &ComputeUserInfo,
     315           93 :     ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
     316              :         use ConsoleBackend::*;
     317              : 
     318           93 :         match self {
     319            0 :             Console(api) => api.wake_compute(ctx, user_info).await,
     320              :             #[cfg(any(test, feature = "testing"))]
     321           79 :             Postgres(api) => api.wake_compute(ctx, user_info).await,
     322              :             #[cfg(test)]
     323           14 :             Test(api) => api.wake_compute(),
     324              :         }
     325          186 :     }
     326              : }
     327              : 
     328              : /// Various caches for [`console`](super).
     329              : pub struct ApiCaches {
     330              :     /// Cache for the `wake_compute` API method.
     331              :     pub node_info: NodeInfoCache,
     332              :     /// Cache which stores project_id -> endpoint_ids mapping.
     333              :     pub project_info: Arc<ProjectInfoCacheImpl>,
     334              : }
     335              : 
     336              : impl ApiCaches {
     337            1 :     pub fn new(
     338            1 :         wake_compute_cache_config: CacheOptions,
     339            1 :         project_info_cache_config: ProjectInfoCacheOptions,
     340            1 :     ) -> Self {
     341            1 :         Self {
     342            1 :             node_info: NodeInfoCache::new(
     343            1 :                 "node_info_cache",
     344            1 :                 wake_compute_cache_config.size,
     345            1 :                 wake_compute_cache_config.ttl,
     346            1 :                 true,
     347            1 :             ),
     348            1 :             project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
     349            1 :         }
     350            1 :     }
     351              : }
     352              : 
     353              : /// Per-endpoint semaphores (with metrics) that limit concurrent [`console`](super) API calls.
     354              : pub struct ApiLocks {
     355              :     name: &'static str,
     356              :     node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
     357              :     permits: usize,
     358              :     timeout: Duration,
     359              :     registered: prometheus::IntCounter,
     360              :     unregistered: prometheus::IntCounter,
     361              :     reclamation_lag: prometheus::Histogram,
     362              :     lock_acquire_lag: prometheus::Histogram,
     363              : }
     364              : 
     365              : impl ApiLocks {
     366            1 :     pub fn new(
     367            1 :         name: &'static str,
     368            1 :         permits: usize,
     369            1 :         shards: usize,
     370            1 :         timeout: Duration,
     371            1 :     ) -> prometheus::Result<Self> {
     372            1 :         let registered = prometheus::IntCounter::with_opts(
     373            1 :             prometheus::Opts::new(
     374            1 :                 "semaphores_registered",
     375            1 :                 "Number of semaphores registered in this api lock",
     376            1 :             )
     377            1 :             .namespace(name),
     378            1 :         )?;
     379            1 :         prometheus::register(Box::new(registered.clone()))?;
     380            1 :         let unregistered = prometheus::IntCounter::with_opts(
     381            1 :             prometheus::Opts::new(
     382            1 :                 "semaphores_unregistered",
     383            1 :                 "Number of semaphores unregistered in this api lock",
     384            1 :             )
     385            1 :             .namespace(name),
     386            1 :         )?;
     387            1 :         prometheus::register(Box::new(unregistered.clone()))?;
     388            1 :         let reclamation_lag = prometheus::Histogram::with_opts(
     389            1 :             prometheus::HistogramOpts::new(
     390            1 :                 "reclamation_lag_seconds",
     391            1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     392            1 :             )
     393            1 :             .namespace(name)
     394            1 :             // 1us -> 65ms
     395            1 :             // benchmarks on my mac indicate it's usually in the range of 256us and 512us
     396            1 :             .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
     397            0 :         )?;
     398            1 :         prometheus::register(Box::new(reclamation_lag.clone()))?;
     399            1 :         let lock_acquire_lag = prometheus::Histogram::with_opts(
     400            1 :             prometheus::HistogramOpts::new(
     401            1 :                 "semaphore_acquire_seconds",
     402            1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     403            1 :             )
     404            1 :             .namespace(name)
     405            1 :             // 0.1ms -> 6s
     406            1 :             .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
     407            0 :         )?;
     408            1 :         prometheus::register(Box::new(lock_acquire_lag.clone()))?;
     409              : 
     410            1 :         Ok(Self {
     411            1 :             name,
     412            1 :             node_locks: DashMap::with_shard_amount(shards),
     413            1 :             permits,
     414            1 :             timeout,
     415            1 :             lock_acquire_lag,
     416            1 :             registered,
     417            1 :             unregistered,
     418            1 :             reclamation_lag,
     419            1 :         })
     420            1 :     }
     421              : 
     422            0 :     pub async fn get_wake_compute_permit(
     423            0 :         &self,
     424            0 :         key: &EndpointCacheKey,
     425            0 :     ) -> Result<WakeComputePermit, errors::WakeComputeError> {
     426            0 :         if self.permits == 0 {
     427            0 :             return Ok(WakeComputePermit { permit: None });
     428            0 :         }
     429            0 :         let now = Instant::now();
     430            0 :         let semaphore = {
     431              :             // get fast path
     432            0 :             if let Some(semaphore) = self.node_locks.get(key) {
     433            0 :                 semaphore.clone()
     434              :             } else {
     435            0 :                 self.node_locks
     436            0 :                     .entry(key.clone())
     437            0 :                     .or_insert_with(|| {
     438            0 :                         self.registered.inc();
     439            0 :                         Arc::new(Semaphore::new(self.permits))
     440            0 :                     })
     441            0 :                     .clone()
     442              :             }
     443              :         };
     444            0 :         let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
     445              : 
     446            0 :         self.lock_acquire_lag
     447            0 :             .observe((Instant::now() - now).as_secs_f64());
     448            0 : 
     449            0 :         Ok(WakeComputePermit {
     450            0 :             permit: Some(permit??),
     451              :         })
     452            0 :     }
     453              : 
     454            1 :     pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
     455            1 :         if self.permits == 0 {
     456            1 :             return;
     457            0 :         }
     458            0 : 
     459            0 :         let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
     460              :         loop {
     461            0 :             for (i, shard) in self.node_locks.shards().iter().enumerate() {
     462            0 :                 interval.tick().await;
     463              :                 // temporarily lock a single shard and then clear any semaphores that aren't currently checked out
     464              :                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
     465              :                 // therefore releasing it is safe from race conditions
     466            0 :                 info!(
     467            0 :                     name = self.name,
     468            0 :                     shard = i,
     469            0 :                     "performing epoch reclamation on api lock"
     470            0 :                 );
     471            0 :                 let mut lock = shard.write();
     472            0 :                 let timer = self.reclamation_lag.start_timer();
     473            0 :                 let count = lock
     474            0 :                     .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
     475            0 :                     .count();
     476            0 :                 drop(lock);
     477            0 :                 self.unregistered.inc_by(count as u64);
     478            0 :                 timer.observe_duration()
     479              :             }
     480              :         }
     481            1 :     }
     482              : }
     483              : 
     484              : pub struct WakeComputePermit {
     485              :     // None if the lock is disabled
     486              :     permit: Option<OwnedSemaphorePermit>,
     487              : }
     488              : 
     489              : impl WakeComputePermit {
     490            0 :     pub fn should_check_cache(&self) -> bool {
     491            0 :         self.permit.is_some()
     492            0 :     }
     493              : }
        

Generated by: LCOV version 2.1-beta