LCOV - code coverage report
Current view: top level - proxy/src/console - provider.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 58.4 % 231 135
Test Date: 2024-02-12 20:26:03 Functions: 51.0 % 51 26

            Line data    Source code
       1              : #[cfg(any(test, feature = "testing"))]
       2              : pub mod mock;
       3              : pub mod neon;
       4              : 
       5              : use super::messages::MetricsAuxInfo;
       6              : use crate::{
       7              :     auth::{backend::ComputeUserInfo, IpPattern},
       8              :     cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
       9              :     compute,
      10              :     config::{CacheOptions, ProjectInfoCacheOptions},
      11              :     context::RequestMonitoring,
      12              :     scram, EndpointCacheKey, ProjectId,
      13              : };
      14              : use async_trait::async_trait;
      15              : use dashmap::DashMap;
      16              : use std::{sync::Arc, time::Duration};
      17              : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
      18              : use tokio::time::Instant;
      19              : use tracing::info;
      20              : 
      21              : pub mod errors {
      22              :     use crate::{
      23              :         error::{io_error, ReportableError, UserFacingError},
      24              :         http,
      25              :         proxy::retry::ShouldRetry,
      26              :     };
      27              :     use thiserror::Error;
      28              : 
      29              :     /// A go-to error message which doesn't leak any detail.
      30              :     const REQUEST_FAILED: &str = "Console request failed";
      31              : 
      32              :     /// Common console API error.
      33           57 :     #[derive(Debug, Error)]
      34              :     pub enum ApiError {
      35              :         /// Error returned by the console itself.
      36              :         #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
      37              :         Console {
      38              :             status: http::StatusCode,
      39              :             text: Box<str>,
      40              :         },
      41              : 
      42              :         /// Various IO errors like broken pipe or malformed payload.
      43              :         #[error("{REQUEST_FAILED}: {0}")]
      44              :         Transport(#[from] std::io::Error),
      45              :     }
      46              : 
      47              :     impl ApiError {
      48              :         /// Returns HTTP status code if it's the reason for failure.
      49            3 :         pub fn http_status_code(&self) -> Option<http::StatusCode> {
      50            3 :             use ApiError::*;
      51            3 :             match self {
      52            2 :                 Console { status, .. } => Some(*status),
      53            1 :                 _ => None,
      54              :             }
      55            3 :         }
      56              :     }
      57              : 
      58              :     impl UserFacingError for ApiError {
      59            4 :         fn to_string_client(&self) -> String {
      60            4 :             use ApiError::*;
      61            4 :             match self {
      62              :                 // To minimize risks, only select errors are forwarded to users.
      63              :                 // Ask @neondatabase/control-plane for review before adding more.
      64            2 :                 Console { status, .. } => match *status {
      65              :                     http::StatusCode::NOT_FOUND => {
      66              :                         // Status 404: failed to get a project-related resource.
      67            0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      68              :                     }
      69              :                     http::StatusCode::NOT_ACCEPTABLE => {
      70              :                         // Status 406: endpoint is disabled (we don't allow connections).
      71            0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      72              :                     }
      73              :                     http::StatusCode::LOCKED => {
      74              :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      75            0 :                         format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
      76              :                     }
      77            2 :                     _ => REQUEST_FAILED.to_owned(),
      78              :                 },
      79            2 :                 _ => REQUEST_FAILED.to_owned(),
      80              :             }
      81            4 :         }
      82              :     }
      83              : 
      84              :     impl ReportableError for ApiError {
      85            0 :         fn get_error_kind(&self) -> crate::error::ErrorKind {
      86            0 :             match self {
      87            0 :                 ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
      88            0 :                 ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
      89              :             }
      90            0 :         }
      91              :     }
      92              : 
      93              :     impl ShouldRetry for ApiError {
      94            8 :         fn could_retry(&self) -> bool {
      95            8 :             match self {
      96              :                 // retry some transport errors
      97            0 :                 Self::Transport(io) => io.could_retry(),
      98              :                 // retry some temporary failures because the compute was in a bad state
      99              :                 // (bad request can be returned when the endpoint was in transition)
     100              :                 Self::Console {
     101              :                     status: http::StatusCode::BAD_REQUEST,
     102              :                     ..
     103            4 :                 } => true,
     104              :                 // locked can be returned when the endpoint was in transition
     105              :                 // or when quotas are exceeded. don't retry when quotas are exceeded
     106              :                 Self::Console {
     107              :                     status: http::StatusCode::LOCKED,
     108            0 :                     ref text,
     109            0 :                 } => {
     110            0 :                     // written data quota exceeded
     111            0 :                     // data transfer quota exceeded
     112            0 :                     // compute time quota exceeded
     113            0 :                     // logical size quota exceeded
     114            0 :                     !text.contains("quota exceeded")
     115            0 :                         && !text.contains("the limit for current plan reached")
     116              :                 }
     117            4 :                 _ => false,
     118              :             }
     119            8 :         }
     120              :     }
     121              : 
     122              :     impl From<reqwest::Error> for ApiError {
     123            1 :         fn from(e: reqwest::Error) -> Self {
     124            1 :             io_error(e).into()
     125            1 :         }
     126              :     }
     127              : 
     128              :     impl From<reqwest_middleware::Error> for ApiError {
     129            1 :         fn from(e: reqwest_middleware::Error) -> Self {
     130            1 :             io_error(e).into()
     131            1 :         }
     132              :     }
     133              : 
     134           40 :     #[derive(Debug, Error)]
     135              :     pub enum GetAuthInfoError {
     136              :         // We shouldn't include the actual secret here.
     137              :         #[error("Console responded with a malformed auth secret")]
     138              :         BadSecret,
     139              : 
     140              :         #[error(transparent)]
     141              :         ApiError(ApiError),
     142              :     }
     143              : 
     144              :     // This allows more useful interactions than `#[from]`.
     145              :     impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
     146            4 :         fn from(e: E) -> Self {
     147            4 :             Self::ApiError(e.into())
     148            4 :         }
     149              :     }
     150              : 
     151              :     impl UserFacingError for GetAuthInfoError {
     152            4 :         fn to_string_client(&self) -> String {
     153            4 :             use GetAuthInfoError::*;
     154            4 :             match self {
     155              :                 // We absolutely should not leak any secrets!
     156            0 :                 BadSecret => REQUEST_FAILED.to_owned(),
     157              :                 // However, API might return a meaningful error.
     158            4 :                 ApiError(e) => e.to_string_client(),
     159              :             }
     160            4 :         }
     161              :     }
     162              : 
     163              :     impl ReportableError for GetAuthInfoError {
     164            4 :         fn get_error_kind(&self) -> crate::error::ErrorKind {
     165            4 :             match self {
     166            0 :                 GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
     167            4 :                 GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
     168              :             }
     169            4 :         }
     170              :     }
     171              : 
     172            0 :     #[derive(Debug, Error)]
     173              :     pub enum WakeComputeError {
     174              :         #[error("Console responded with a malformed compute address: {0}")]
     175              :         BadComputeAddress(Box<str>),
     176              : 
     177              :         #[error(transparent)]
     178              :         ApiError(ApiError),
     179              : 
     180              :         #[error("Timeout waiting to acquire wake compute lock")]
     181              :         TimeoutError,
     182              :     }
     183              : 
     184              :     // This allows more useful interactions than `#[from]`.
     185              :     impl<E: Into<ApiError>> From<E> for WakeComputeError {
     186            0 :         fn from(e: E) -> Self {
     187            0 :             Self::ApiError(e.into())
     188            0 :         }
     189              :     }
     190              : 
     191              :     impl From<tokio::sync::AcquireError> for WakeComputeError {
     192            0 :         fn from(_: tokio::sync::AcquireError) -> Self {
     193            0 :             WakeComputeError::TimeoutError
     194            0 :         }
     195              :     }
     196              :     impl From<tokio::time::error::Elapsed> for WakeComputeError {
     197            0 :         fn from(_: tokio::time::error::Elapsed) -> Self {
     198            0 :             WakeComputeError::TimeoutError
     199            0 :         }
     200              :     }
     201              : 
     202              :     impl UserFacingError for WakeComputeError {
     203            0 :         fn to_string_client(&self) -> String {
     204            0 :             use WakeComputeError::*;
     205            0 :             match self {
     206              :                 // We shouldn't show the user the address even if it's broken.
     207              :                 // Besides, user is unlikely to care about this detail.
     208            0 :                 BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
     209              :                 // However, API might return a meaningful error.
     210            0 :                 ApiError(e) => e.to_string_client(),
     211              : 
     212            0 :                 TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
     213              :             }
     214            0 :         }
     215              :     }
     216              : 
     217              :     impl ReportableError for WakeComputeError {
     218            0 :         fn get_error_kind(&self) -> crate::error::ErrorKind {
     219            0 :             match self {
     220            0 :                 WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
     221            0 :                 WakeComputeError::ApiError(e) => e.get_error_kind(),
     222            0 :                 WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
     223              :             }
     224            0 :         }
     225              :     }
     226              : }
     227              : 
     228              : /// Auth secret which is managed by the cloud.
     229          109 : #[derive(Clone, Eq, PartialEq, Debug)]
     230              : pub enum AuthSecret {
     231              :     #[cfg(any(test, feature = "testing"))]
     232              :     /// Md5 hash of user's password.
     233              :     Md5([u8; 16]),
     234              : 
     235              :     /// [SCRAM](crate::scram) authentication info.
     236              :     Scram(scram::ServerSecret),
     237              : }
     238              : 
     239            0 : #[derive(Default)]
     240              : pub struct AuthInfo {
     241              :     pub secret: Option<AuthSecret>,
     242              :     /// List of IP addresses allowed for authorization.
     243              :     pub allowed_ips: Vec<IpPattern>,
     244              :     /// Project ID. This is used for cache invalidation.
     245              :     pub project_id: Option<ProjectId>,
     246              : }
     247              : 
     248              : /// Info for establishing a connection to a compute node.
     249              : /// This is what we get after auth succeeded, but not before!
     250            0 : #[derive(Clone)]
     251              : pub struct NodeInfo {
     252              :     /// Compute node connection params.
     253              :     /// It's sad that we have to clone this, but this will improve
     254              :     /// once we migrate to a bespoke connection logic.
     255              :     pub config: compute::ConnCfg,
     256              : 
     257              :     /// Labels for proxy's metrics.
     258              :     pub aux: MetricsAuxInfo,
     259              : 
     260              :     /// Whether we should accept self-signed certificates (for testing)
     261              :     pub allow_self_signed_compute: bool,
     262              : }
     263              : 
     264              : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
     265              : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
     266              : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
     267              : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
     268              : 
     269              : /// This will allocate on each call, but the HTTP requests alone
     270              : /// already require a few allocations, so it should be fine.
     271              : #[async_trait]
     272              : pub trait Api {
     273              :     /// Get the client's auth secret for authentication.
     274              :     /// Returns an `Option` because the user-not-found situation is special:
     275              :     /// we still have to mock the scram to avoid leaking information that the user doesn't exist.
     276              :     async fn get_role_secret(
     277              :         &self,
     278              :         ctx: &mut RequestMonitoring,
     279              :         user_info: &ComputeUserInfo,
     280              :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
     281              : 
     282              :     async fn get_allowed_ips_and_secret(
     283              :         &self,
     284              :         ctx: &mut RequestMonitoring,
     285              :         user_info: &ComputeUserInfo,
     286              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
     287              : 
     288              :     /// Wake up the compute node and return the corresponding connection info.
     289              :     async fn wake_compute(
     290              :         &self,
     291              :         ctx: &mut RequestMonitoring,
     292              :         user_info: &ComputeUserInfo,
     293              :     ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
     294              : }
     295              : 
     296              : #[non_exhaustive]
     297              : pub enum ConsoleBackend {
     298              :     /// Current Cloud API (V2).
     299              :     Console(neon::Api),
     300              :     /// Local mock of Cloud API (V2).
     301              :     #[cfg(any(test, feature = "testing"))]
     302              :     Postgres(mock::Api),
     303              :     /// Internal testing
     304              :     #[cfg(test)]
     305              :     Test(Box<dyn crate::auth::backend::TestBackend>),
     306              : }
     307              : 
     308              : #[async_trait]
     309              : impl Api for ConsoleBackend {
     310           87 :     async fn get_role_secret(
     311           87 :         &self,
     312           87 :         ctx: &mut RequestMonitoring,
     313           87 :         user_info: &ComputeUserInfo,
     314           87 :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
     315              :         use ConsoleBackend::*;
     316           87 :         match self {
     317            0 :             Console(api) => api.get_role_secret(ctx, user_info).await,
     318              :             #[cfg(any(test, feature = "testing"))]
     319          602 :             Postgres(api) => api.get_role_secret(ctx, user_info).await,
     320              :             #[cfg(test)]
     321            0 :             Test(_) => unreachable!("this function should never be called in the test backend"),
     322              :         }
     323          261 :     }
     324              : 
     325           95 :     async fn get_allowed_ips_and_secret(
     326           95 :         &self,
     327           95 :         ctx: &mut RequestMonitoring,
     328           95 :         user_info: &ComputeUserInfo,
     329           95 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
     330              :         use ConsoleBackend::*;
     331           95 :         match self {
     332           13 :             Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     333              :             #[cfg(any(test, feature = "testing"))]
     334          650 :             Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     335              :             #[cfg(test)]
     336            0 :             Test(api) => api.get_allowed_ips_and_secret(),
     337              :         }
     338          285 :     }
     339              : 
     340           92 :     async fn wake_compute(
     341           92 :         &self,
     342           92 :         ctx: &mut RequestMonitoring,
     343           92 :         user_info: &ComputeUserInfo,
     344           92 :     ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
     345              :         use ConsoleBackend::*;
     346              : 
     347           92 :         match self {
     348            0 :             Console(api) => api.wake_compute(ctx, user_info).await,
     349              :             #[cfg(any(test, feature = "testing"))]
     350           78 :             Postgres(api) => api.wake_compute(ctx, user_info).await,
     351              :             #[cfg(test)]
     352           14 :             Test(api) => api.wake_compute(),
     353              :         }
     354          276 :     }
     355              : }
     356              : 
     357              : /// Various caches for [`console`](super).
     358              : pub struct ApiCaches {
     359              :     /// Cache for the `wake_compute` API method.
     360              :     pub node_info: NodeInfoCache,
     361              :     /// Cache which stores project_id -> endpoint_ids mapping.
     362              :     pub project_info: Arc<ProjectInfoCacheImpl>,
     363              : }
     364              : 
     365              : impl ApiCaches {
     366            1 :     pub fn new(
     367            1 :         wake_compute_cache_config: CacheOptions,
     368            1 :         project_info_cache_config: ProjectInfoCacheOptions,
     369            1 :     ) -> Self {
     370            1 :         Self {
     371            1 :             node_info: NodeInfoCache::new(
     372            1 :                 "node_info_cache",
     373            1 :                 wake_compute_cache_config.size,
     374            1 :                 wake_compute_cache_config.ttl,
     375            1 :                 true,
     376            1 :             ),
     377            1 :             project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
     378            1 :         }
     379            1 :     }
     380              : }
     381              : 
     382              : /// Per-endpoint semaphores limiting concurrent [`console`](super) API calls.
     383              : pub struct ApiLocks {
     384              :     name: &'static str,
     385              :     node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
     386              :     permits: usize,
     387              :     timeout: Duration,
     388              :     registered: prometheus::IntCounter,
     389              :     unregistered: prometheus::IntCounter,
     390              :     reclamation_lag: prometheus::Histogram,
     391              :     lock_acquire_lag: prometheus::Histogram,
     392              : }
     393              : 
     394              : impl ApiLocks {
     395            1 :     pub fn new(
     396            1 :         name: &'static str,
     397            1 :         permits: usize,
     398            1 :         shards: usize,
     399            1 :         timeout: Duration,
     400            1 :     ) -> prometheus::Result<Self> {
     401            1 :         let registered = prometheus::IntCounter::with_opts(
     402            1 :             prometheus::Opts::new(
     403            1 :                 "semaphores_registered",
     404            1 :                 "Number of semaphores registered in this api lock",
     405            1 :             )
     406            1 :             .namespace(name),
     407            1 :         )?;
     408            1 :         prometheus::register(Box::new(registered.clone()))?;
     409            1 :         let unregistered = prometheus::IntCounter::with_opts(
     410            1 :             prometheus::Opts::new(
     411            1 :                 "semaphores_unregistered",
     412            1 :                 "Number of semaphores unregistered in this api lock",
     413            1 :             )
     414            1 :             .namespace(name),
     415            1 :         )?;
     416            1 :         prometheus::register(Box::new(unregistered.clone()))?;
     417            1 :         let reclamation_lag = prometheus::Histogram::with_opts(
     418            1 :             prometheus::HistogramOpts::new(
     419            1 :                 "reclamation_lag_seconds",
     420            1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     421            1 :             )
     422            1 :             .namespace(name)
     423            1 :             // 1us -> 65ms
     424            1 :             // benchmarks on my mac indicate it's usually in the range of 256us and 512us
     425            1 :             .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
     426            0 :         )?;
     427            1 :         prometheus::register(Box::new(reclamation_lag.clone()))?;
     428            1 :         let lock_acquire_lag = prometheus::Histogram::with_opts(
     429            1 :             prometheus::HistogramOpts::new(
     430            1 :                 "semaphore_acquire_seconds",
     431            1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     432            1 :             )
     433            1 :             .namespace(name)
     434            1 :             // 0.1ms -> 6s
     435            1 :             .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
     436            0 :         )?;
     437            1 :         prometheus::register(Box::new(lock_acquire_lag.clone()))?;
     438              : 
     439            1 :         Ok(Self {
     440            1 :             name,
     441            1 :             node_locks: DashMap::with_shard_amount(shards),
     442            1 :             permits,
     443            1 :             timeout,
     444            1 :             lock_acquire_lag,
     445            1 :             registered,
     446            1 :             unregistered,
     447            1 :             reclamation_lag,
     448            1 :         })
     449            1 :     }
     450              : 
     451            0 :     pub async fn get_wake_compute_permit(
     452            0 :         &self,
     453            0 :         key: &EndpointCacheKey,
     454            0 :     ) -> Result<WakeComputePermit, errors::WakeComputeError> {
     455            0 :         if self.permits == 0 {
     456            0 :             return Ok(WakeComputePermit { permit: None });
     457            0 :         }
     458            0 :         let now = Instant::now();
     459            0 :         let semaphore = {
     460              :             // get fast path
     461            0 :             if let Some(semaphore) = self.node_locks.get(key) {
     462            0 :                 semaphore.clone()
     463              :             } else {
     464            0 :                 self.node_locks
     465            0 :                     .entry(key.clone())
     466            0 :                     .or_insert_with(|| {
     467            0 :                         self.registered.inc();
     468            0 :                         Arc::new(Semaphore::new(self.permits))
     469            0 :                     })
     470            0 :                     .clone()
     471              :             }
     472              :         };
     473            0 :         let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
     474              : 
     475            0 :         self.lock_acquire_lag
     476            0 :             .observe((Instant::now() - now).as_secs_f64());
     477            0 : 
     478            0 :         Ok(WakeComputePermit {
     479            0 :             permit: Some(permit??),
     480              :         })
     481            0 :     }
     482              : 
     483            1 :     pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
     484            1 :         if self.permits == 0 {
     485            1 :             return;
     486            0 :         }
     487            0 : 
     488            0 :         let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
     489              :         loop {
     490            0 :             for (i, shard) in self.node_locks.shards().iter().enumerate() {
     491            0 :                 interval.tick().await;
     492              :                 // temporarily lock a single shard and then clear any semaphores that aren't currently checked out
     493              :                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
     494              :                 // therefore releasing it is safe from race conditions
     495            0 :                 info!(
     496            0 :                     name = self.name,
     497            0 :                     shard = i,
     498            0 :                     "performing epoch reclamation on api lock"
     499            0 :                 );
     500            0 :                 let mut lock = shard.write();
     501            0 :                 let timer = self.reclamation_lag.start_timer();
     502            0 :                 let count = lock
     503            0 :                     .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
     504            0 :                     .count();
     505            0 :                 drop(lock);
     506            0 :                 self.unregistered.inc_by(count as u64);
     507            0 :                 timer.observe_duration()
     508              :             }
     509              :         }
     510            1 :     }
     511              : }
     512              : 
     513              : pub struct WakeComputePermit {
     514              :     // None if the lock is disabled
     515              :     permit: Option<OwnedSemaphorePermit>,
     516              : }
     517              : 
     518              : impl WakeComputePermit {
     519            0 :     pub fn should_check_cache(&self) -> bool {
     520            0 :         self.permit.is_some()
     521            0 :     }
     522              : }
        

Generated by: LCOV version 2.1-beta