LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - mod.rs (source / functions) Coverage Total Hit
Test: bb45db3982713bfd5bec075773079136e362195e.info Lines: 5.3 % 150 8
Test Date: 2024-12-11 15:53:32 Functions: 5.4 % 37 2

            Line data    Source code
       1              : pub mod cplane_proxy_v1;
       2              : #[cfg(any(test, feature = "testing"))]
       3              : pub mod mock;
       4              : pub mod neon;
       5              : 
       6              : use std::hash::Hash;
       7              : use std::sync::Arc;
       8              : use std::time::Duration;
       9              : 
      10              : use dashmap::DashMap;
      11              : use tokio::time::Instant;
      12              : use tracing::{debug, info};
      13              : 
      14              : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
      15              : use crate::auth::backend::ComputeUserInfo;
      16              : use crate::cache::endpoints::EndpointsCache;
      17              : use crate::cache::project_info::ProjectInfoCacheImpl;
      18              : use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::{
      21              :     errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
      22              : };
      23              : use crate::error::ReportableError;
      24              : use crate::metrics::ApiLockMetrics;
      25              : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
      26              : use crate::types::EndpointId;
      27              : 
      28              : #[non_exhaustive]
      29              : #[derive(Clone)]
      30              : pub enum ControlPlaneClient {
      31              :     /// New Proxy V1 control plane API
      32              :     ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
      33              :     /// Current Management API (V2).
      34              :     Neon(neon::NeonControlPlaneClient),
      35              :     /// Local mock control plane.
      36              :     #[cfg(any(test, feature = "testing"))]
      37              :     PostgresMock(mock::MockControlPlane),
      38              :     /// Internal testing
      39              :     #[cfg(test)]
      40              :     #[allow(private_interfaces)]
      41              :     Test(Box<dyn TestControlPlaneClient>),
      42              : }
      43              : 
      44              : impl ControlPlaneApi for ControlPlaneClient {
      45            0 :     async fn get_role_secret(
      46            0 :         &self,
      47            0 :         ctx: &RequestContext,
      48            0 :         user_info: &ComputeUserInfo,
      49            0 :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
      50            0 :         match self {
      51            0 :             Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
      52            0 :             Self::Neon(api) => api.get_role_secret(ctx, user_info).await,
      53              :             #[cfg(any(test, feature = "testing"))]
      54            0 :             Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
      55              :             #[cfg(test)]
      56              :             Self::Test(_) => {
      57            0 :                 unreachable!("this function should never be called in the test backend")
      58              :             }
      59              :         }
      60            0 :     }
      61              : 
      62            0 :     async fn get_allowed_ips_and_secret(
      63            0 :         &self,
      64            0 :         ctx: &RequestContext,
      65            0 :         user_info: &ComputeUserInfo,
      66            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
      67            0 :         match self {
      68            0 :             Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
      69            0 :             Self::Neon(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
      70              :             #[cfg(any(test, feature = "testing"))]
      71            0 :             Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
      72              :             #[cfg(test)]
      73            0 :             Self::Test(api) => api.get_allowed_ips_and_secret(),
      74              :         }
      75            0 :     }
      76              : 
      77            0 :     async fn get_endpoint_jwks(
      78            0 :         &self,
      79            0 :         ctx: &RequestContext,
      80            0 :         endpoint: EndpointId,
      81            0 :     ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
      82            0 :         match self {
      83            0 :             Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
      84            0 :             Self::Neon(api) => api.get_endpoint_jwks(ctx, endpoint).await,
      85              :             #[cfg(any(test, feature = "testing"))]
      86            0 :             Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
      87              :             #[cfg(test)]
      88            0 :             Self::Test(_api) => Ok(vec![]),
      89              :         }
      90            0 :     }
      91              : 
      92           13 :     async fn wake_compute(
      93           13 :         &self,
      94           13 :         ctx: &RequestContext,
      95           13 :         user_info: &ComputeUserInfo,
      96           13 :     ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
      97           13 :         match self {
      98            0 :             Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
      99            0 :             Self::Neon(api) => api.wake_compute(ctx, user_info).await,
     100              :             #[cfg(any(test, feature = "testing"))]
     101            0 :             Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
     102              :             #[cfg(test)]
     103           13 :             Self::Test(api) => api.wake_compute(),
     104              :         }
     105           13 :     }
     106              : }
     107              : 
     108              : #[cfg(test)]
     109              : pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
     110              :     fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
     111              : 
     112              :     fn get_allowed_ips_and_secret(
     113              :         &self,
     114              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
     115              : 
     116              :     fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
     117              : }
     118              : 
     119              : #[cfg(test)]
     120              : impl Clone for Box<dyn TestControlPlaneClient> {
     121            0 :     fn clone(&self) -> Self {
     122            0 :         TestControlPlaneClient::dyn_clone(&**self)
     123            0 :     }
     124              : }
     125              : 
     126              : /// Various caches for [`control_plane`](super).
     127              : pub struct ApiCaches {
     128              :     /// Cache for the `wake_compute` API method.
     129              :     pub(crate) node_info: NodeInfoCache,
     130              :     /// Cache which stores project_id -> endpoint_ids mapping.
     131              :     pub project_info: Arc<ProjectInfoCacheImpl>,
     132              :     /// List of all valid endpoints.
     133              :     pub endpoints_cache: Arc<EndpointsCache>,
     134              : }
     135              : 
     136              : impl ApiCaches {
     137            0 :     pub fn new(
     138            0 :         wake_compute_cache_config: CacheOptions,
     139            0 :         project_info_cache_config: ProjectInfoCacheOptions,
     140            0 :         endpoint_cache_config: EndpointCacheConfig,
     141            0 :     ) -> Self {
     142            0 :         Self {
     143            0 :             node_info: NodeInfoCache::new(
     144            0 :                 "node_info_cache",
     145            0 :                 wake_compute_cache_config.size,
     146            0 :                 wake_compute_cache_config.ttl,
     147            0 :                 true,
     148            0 :             ),
     149            0 :             project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
     150            0 :             endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
     151            0 :         }
     152            0 :     }
     153              : }
     154              : 
     155              : /// Various caches for [`control_plane`](super).
     156              : pub struct ApiLocks<K> {
     157              :     name: &'static str,
     158              :     node_locks: DashMap<K, Arc<DynamicLimiter>>,
     159              :     config: RateLimiterConfig,
     160              :     timeout: Duration,
     161              :     epoch: std::time::Duration,
     162              :     metrics: &'static ApiLockMetrics,
     163              : }
     164              : 
     165              : #[derive(Debug, thiserror::Error)]
     166              : pub(crate) enum ApiLockError {
     167              :     #[error("timeout acquiring resource permit")]
     168              :     TimeoutError(#[from] tokio::time::error::Elapsed),
     169              : }
     170              : 
     171              : impl ReportableError for ApiLockError {
     172            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     173            0 :         match self {
     174            0 :             ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
     175            0 :         }
     176            0 :     }
     177              : }
     178              : 
     179              : impl<K: Hash + Eq + Clone> ApiLocks<K> {
     180            0 :     pub fn new(
     181            0 :         name: &'static str,
     182            0 :         config: RateLimiterConfig,
     183            0 :         shards: usize,
     184            0 :         timeout: Duration,
     185            0 :         epoch: std::time::Duration,
     186            0 :         metrics: &'static ApiLockMetrics,
     187            0 :     ) -> prometheus::Result<Self> {
     188            0 :         Ok(Self {
     189            0 :             name,
     190            0 :             node_locks: DashMap::with_shard_amount(shards),
     191            0 :             config,
     192            0 :             timeout,
     193            0 :             epoch,
     194            0 :             metrics,
     195            0 :         })
     196            0 :     }
     197              : 
     198            0 :     pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
     199            0 :         if self.config.initial_limit == 0 {
     200            0 :             return Ok(WakeComputePermit {
     201            0 :                 permit: Token::disabled(),
     202            0 :             });
     203            0 :         }
     204            0 :         let now = Instant::now();
     205            0 :         let semaphore = {
     206              :             // get fast path
     207            0 :             if let Some(semaphore) = self.node_locks.get(key) {
     208            0 :                 semaphore.clone()
     209              :             } else {
     210            0 :                 self.node_locks
     211            0 :                     .entry(key.clone())
     212            0 :                     .or_insert_with(|| {
     213            0 :                         self.metrics.semaphores_registered.inc();
     214            0 :                         DynamicLimiter::new(self.config)
     215            0 :                     })
     216            0 :                     .clone()
     217              :             }
     218              :         };
     219            0 :         let permit = semaphore.acquire_timeout(self.timeout).await;
     220              : 
     221            0 :         self.metrics
     222            0 :             .semaphore_acquire_seconds
     223            0 :             .observe(now.elapsed().as_secs_f64());
     224            0 :         debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
     225            0 :         Ok(WakeComputePermit { permit: permit? })
     226            0 :     }
     227              : 
     228            0 :     pub async fn garbage_collect_worker(&self) {
     229            0 :         if self.config.initial_limit == 0 {
     230            0 :             return;
     231            0 :         }
     232            0 :         let mut interval =
     233            0 :             tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
     234              :         loop {
     235            0 :             for (i, shard) in self.node_locks.shards().iter().enumerate() {
     236            0 :                 interval.tick().await;
     237              :                 // temporary lock a single shard and then clear any semaphores that aren't currently checked out
     238              :                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
     239              :                 // therefore releasing it is safe from race conditions
     240            0 :                 info!(
     241              :                     name = self.name,
     242              :                     shard = i,
     243            0 :                     "performing epoch reclamation on api lock"
     244              :                 );
     245            0 :                 let mut lock = shard.write();
     246            0 :                 let timer = self.metrics.reclamation_lag_seconds.start_timer();
     247            0 :                 let count = lock
     248            0 :                     .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
     249            0 :                     .count();
     250            0 :                 drop(lock);
     251            0 :                 self.metrics.semaphores_unregistered.inc_by(count as u64);
     252            0 :                 timer.observe();
     253            0 :             }
     254              :         }
     255            0 :     }
     256              : }
     257              : 
     258              : pub(crate) struct WakeComputePermit {
     259              :     permit: Token,
     260              : }
     261              : 
     262              : impl WakeComputePermit {
     263            0 :     pub(crate) fn should_check_cache(&self) -> bool {
     264            0 :         !self.permit.is_disabled()
     265            0 :     }
     266            0 :     pub(crate) fn release(self, outcome: Outcome) {
     267            0 :         self.permit.release(outcome);
     268            0 :     }
     269            0 :     pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
     270            0 :         match res {
     271            0 :             Ok(_) => self.release(Outcome::Success),
     272            0 :             Err(_) => self.release(Outcome::Overload),
     273              :         }
     274            0 :         res
     275            0 :     }
     276              : }
     277              : 
     278              : impl FetchAuthRules for ControlPlaneClient {
     279            0 :     async fn fetch_auth_rules(
     280            0 :         &self,
     281            0 :         ctx: &RequestContext,
     282            0 :         endpoint: EndpointId,
     283            0 :     ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     284            0 :         self.get_endpoint_jwks(ctx, endpoint)
     285            0 :             .await
     286            0 :             .map_err(FetchAuthRulesError::GetEndpointJwks)
     287            0 :     }
     288              : }
        

Generated by: LCOV version 2.1-beta