LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - mod.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 5.6 % 142 8
Test Date: 2025-07-26 17:20:05 Functions: 6.5 % 31 2

            Line data    Source code
       1              : pub mod cplane_proxy_v1;
       2              : #[cfg(any(test, feature = "testing"))]
       3              : pub mod mock;
       4              : 
       5              : use std::hash::Hash;
       6              : use std::sync::Arc;
       7              : use std::time::Duration;
       8              : 
       9              : use clashmap::ClashMap;
      10              : use tokio::time::Instant;
      11              : use tracing::{debug, info};
      12              : 
      13              : use super::{EndpointAccessControl, RoleAccessControl};
      14              : use crate::auth::backend::ComputeUserInfo;
      15              : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
      16              : use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
      17              : use crate::cache::project_info::ProjectInfoCache;
      18              : use crate::config::{CacheOptions, ProjectInfoCacheOptions};
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::{ControlPlaneApi, errors};
      21              : use crate::error::ReportableError;
      22              : use crate::metrics::ApiLockMetrics;
      23              : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
      24              : use crate::types::EndpointId;
      25              : 
      26              : #[non_exhaustive]
      27              : #[derive(Clone)]
      28              : pub enum ControlPlaneClient {
      29              :     /// Proxy V1 control plane API
      30              :     ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
      31              :     /// Local mock control plane.
      32              :     #[cfg(any(test, feature = "testing"))]
      33              :     PostgresMock(mock::MockControlPlane),
      34              :     /// Internal testing
      35              :     #[cfg(test)]
      36              :     #[allow(private_interfaces)]
      37              :     Test(Box<dyn TestControlPlaneClient>),
      38              : }
      39              : 
      40              : impl ControlPlaneApi for ControlPlaneClient {
      41            0 :     async fn get_role_access_control(
      42            0 :         &self,
      43            0 :         ctx: &RequestContext,
      44            0 :         endpoint: &EndpointId,
      45            0 :         role: &crate::types::RoleName,
      46            0 :     ) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
      47            0 :         match self {
      48            0 :             Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
      49              :             #[cfg(any(test, feature = "testing"))]
      50            0 :             Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
      51              :             #[cfg(test)]
      52            0 :             Self::Test(_api) => {
      53            0 :                 unreachable!("this function should never be called in the test backend")
      54              :             }
      55              :         }
      56            0 :     }
      57              : 
      58            0 :     async fn get_endpoint_access_control(
      59            0 :         &self,
      60            0 :         ctx: &RequestContext,
      61            0 :         endpoint: &EndpointId,
      62            0 :         role: &crate::types::RoleName,
      63            0 :     ) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
      64            0 :         match self {
      65            0 :             Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
      66              :             #[cfg(any(test, feature = "testing"))]
      67            0 :             Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
      68              :             #[cfg(test)]
      69            0 :             Self::Test(api) => api.get_access_control(),
      70              :         }
      71            0 :     }
      72              : 
      73            0 :     async fn get_endpoint_jwks(
      74            0 :         &self,
      75            0 :         ctx: &RequestContext,
      76            0 :         endpoint: &EndpointId,
      77            0 :     ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
      78            0 :         match self {
      79            0 :             Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
      80              :             #[cfg(any(test, feature = "testing"))]
      81            0 :             Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
      82              :             #[cfg(test)]
      83            0 :             Self::Test(_api) => Ok(vec![]),
      84              :         }
      85            0 :     }
      86              : 
      87           21 :     async fn wake_compute(
      88           21 :         &self,
      89           21 :         ctx: &RequestContext,
      90           21 :         user_info: &ComputeUserInfo,
      91           21 :     ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
      92           21 :         match self {
      93            0 :             Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
      94              :             #[cfg(any(test, feature = "testing"))]
      95            0 :             Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
      96              :             #[cfg(test)]
      97           21 :             Self::Test(api) => api.wake_compute(),
      98              :         }
      99           21 :     }
     100              : }
     101              : 
     102              : #[cfg(test)]
     103              : pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
     104              :     fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
     105              : 
     106              :     fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
     107              : 
     108              :     fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
     109              : }
     110              : 
     111              : #[cfg(test)]
     112              : impl Clone for Box<dyn TestControlPlaneClient> {
     113            0 :     fn clone(&self) -> Self {
     114            0 :         TestControlPlaneClient::dyn_clone(&**self)
     115            0 :     }
     116              : }
     117              : 
     118              : /// Various caches for [`control_plane`](super).
     119              : pub struct ApiCaches {
     120              :     /// Cache for the `wake_compute` API method.
     121              :     pub(crate) node_info: NodeInfoCache,
     122              :     /// Cache which stores project_id -> endpoint_ids mapping.
     123              :     pub project_info: Arc<ProjectInfoCache>,
     124              : }
     125              : 
     126              : impl ApiCaches {
     127            0 :     pub fn new(
     128            0 :         wake_compute_cache_config: CacheOptions,
     129            0 :         project_info_cache_config: ProjectInfoCacheOptions,
     130            0 :     ) -> Self {
     131            0 :         Self {
     132            0 :             node_info: NodeInfoCache::new(wake_compute_cache_config),
     133            0 :             project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)),
     134            0 :         }
     135            0 :     }
     136              : }
     137              : 
     138              : /// Various caches for [`control_plane`](super).
     139              : pub struct ApiLocks<K> {
     140              :     name: &'static str,
     141              :     node_locks: ClashMap<K, Arc<DynamicLimiter>>,
     142              :     config: RateLimiterConfig,
     143              :     timeout: Duration,
     144              :     epoch: std::time::Duration,
     145              :     metrics: &'static ApiLockMetrics,
     146              : }
     147              : 
     148              : #[derive(Debug, thiserror::Error)]
     149              : pub(crate) enum ApiLockError {
     150              :     #[error("timeout acquiring resource permit")]
     151              :     TimeoutError(#[from] tokio::time::error::Elapsed),
     152              : }
     153              : 
     154              : impl ReportableError for ApiLockError {
     155            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     156            0 :         match self {
     157            0 :             ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
     158              :         }
     159            0 :     }
     160              : }
     161              : 
     162              : impl<K: Hash + Eq + Clone> ApiLocks<K> {
     163            0 :     pub fn new(
     164            0 :         name: &'static str,
     165            0 :         config: RateLimiterConfig,
     166            0 :         shards: usize,
     167            0 :         timeout: Duration,
     168            0 :         epoch: std::time::Duration,
     169            0 :         metrics: &'static ApiLockMetrics,
     170            0 :     ) -> Self {
     171            0 :         Self {
     172            0 :             name,
     173            0 :             node_locks: ClashMap::with_shard_amount(shards),
     174            0 :             config,
     175            0 :             timeout,
     176            0 :             epoch,
     177            0 :             metrics,
     178            0 :         }
     179            0 :     }
     180              : 
     181            0 :     pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
     182            0 :         if self.config.initial_limit == 0 {
     183            0 :             return Ok(WakeComputePermit {
     184            0 :                 permit: Token::disabled(),
     185            0 :             });
     186            0 :         }
     187            0 :         let now = Instant::now();
     188            0 :         let semaphore = {
     189              :             // get fast path
     190            0 :             if let Some(semaphore) = self.node_locks.get(key) {
     191            0 :                 semaphore.clone()
     192              :             } else {
     193            0 :                 self.node_locks
     194            0 :                     .entry(key.clone())
     195            0 :                     .or_insert_with(|| {
     196            0 :                         self.metrics.semaphores_registered.inc();
     197            0 :                         DynamicLimiter::new(self.config)
     198            0 :                     })
     199            0 :                     .clone()
     200              :             }
     201              :         };
     202            0 :         let permit = semaphore.acquire_timeout(self.timeout).await;
     203              : 
     204            0 :         self.metrics
     205            0 :             .semaphore_acquire_seconds
     206            0 :             .observe(now.elapsed().as_secs_f64());
     207              : 
     208            0 :         if permit.is_ok() {
     209            0 :             debug!(elapsed = ?now.elapsed(), "acquired permit");
     210              :         } else {
     211            0 :             debug!(elapsed = ?now.elapsed(), "timed out acquiring permit");
     212              :         }
     213            0 :         Ok(WakeComputePermit { permit: permit? })
     214            0 :     }
     215              : 
     216            0 :     pub async fn garbage_collect_worker(&self) {
     217            0 :         if self.config.initial_limit == 0 {
     218            0 :             return;
     219            0 :         }
     220            0 :         let mut interval =
     221            0 :             tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
     222              :         loop {
     223            0 :             for (i, shard) in self.node_locks.shards().iter().enumerate() {
     224            0 :                 interval.tick().await;
     225              :                 // temporary lock a single shard and then clear any semaphores that aren't currently checked out
     226              :                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
     227              :                 // therefore releasing it is safe from race conditions
     228            0 :                 info!(
     229              :                     name = self.name,
     230              :                     shard = i,
     231            0 :                     "performing epoch reclamation on api lock"
     232              :                 );
     233            0 :                 let mut lock = shard.write();
     234            0 :                 let timer = self.metrics.reclamation_lag_seconds.start_timer();
     235            0 :                 let count = lock
     236            0 :                     .extract_if(|(_, semaphore)| Arc::strong_count(semaphore) == 1)
     237            0 :                     .count();
     238            0 :                 drop(lock);
     239            0 :                 self.metrics.semaphores_unregistered.inc_by(count as u64);
     240            0 :                 timer.observe();
     241              :             }
     242              :         }
     243            0 :     }
     244              : }
     245              : 
     246              : pub(crate) struct WakeComputePermit {
     247              :     permit: Token,
     248              : }
     249              : 
     250              : impl WakeComputePermit {
     251            0 :     pub(crate) fn should_check_cache(&self) -> bool {
     252            0 :         !self.permit.is_disabled()
     253            0 :     }
     254            0 :     pub(crate) fn release(self, outcome: Outcome) {
     255            0 :         self.permit.release(outcome);
     256            0 :     }
     257            0 :     pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
     258            0 :         match res {
     259            0 :             Ok(_) => self.release(Outcome::Success),
     260            0 :             Err(_) => self.release(Outcome::Overload),
     261              :         }
     262            0 :         res
     263            0 :     }
     264              : }
     265              : 
     266              : impl FetchAuthRules for ControlPlaneClient {
     267            0 :     async fn fetch_auth_rules(
     268            0 :         &self,
     269            0 :         ctx: &RequestContext,
     270            0 :         endpoint: EndpointId,
     271            0 :     ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
     272            0 :         self.get_endpoint_jwks(ctx, &endpoint)
     273            0 :             .await
     274            0 :             .map_err(FetchAuthRulesError::GetEndpointJwks)
     275            0 :     }
     276              : }
        

Generated by: LCOV version 2.1-beta