LCOV - code coverage report
Current view: top level - proxy/src/cache - project_info.rs (source / functions) Coverage Total Hit
Test: 5fe7fa8d483b39476409aee736d6d5e32728bfac.info Lines: 59.8 % 624 373
Test Date: 2025-03-12 16:10:49 Functions: 45.8 % 59 27

            Line data    Source code
       1              : use std::collections::HashSet;
       2              : use std::convert::Infallible;
       3              : use std::sync::Arc;
       4              : use std::sync::atomic::AtomicU64;
       5              : use std::time::Duration;
       6              : 
       7              : use async_trait::async_trait;
       8              : use clashmap::ClashMap;
       9              : use rand::{Rng, thread_rng};
      10              : use smol_str::SmolStr;
      11              : use tokio::sync::Mutex;
      12              : use tokio::time::Instant;
      13              : use tracing::{debug, info};
      14              : 
      15              : use super::{Cache, Cached};
      16              : use crate::auth::IpPattern;
      17              : use crate::config::ProjectInfoCacheOptions;
      18              : use crate::control_plane::{AccessBlockerFlags, AuthSecret};
      19              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      20              : use crate::types::{EndpointId, RoleName};
      21              : 
      22              : #[async_trait]
      23              : pub(crate) trait ProjectInfoCache {
      24              :     fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
      25              :     fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
      26              :     fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
      27              :     fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
      28              :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
      29              :     async fn decrement_active_listeners(&self);
      30              :     async fn increment_active_listeners(&self);
      31              : }
      32              : 
      33              : struct Entry<T> {
      34              :     created_at: Instant,
      35              :     value: T,
      36              : }
      37              : 
      38              : impl<T> Entry<T> {
      39            9 :     pub(crate) fn new(value: T) -> Self {
      40            9 :         Self {
      41            9 :             created_at: Instant::now(),
      42            9 :             value,
      43            9 :         }
      44            9 :     }
      45              : }
      46              : 
      47              : impl<T> From<T> for Entry<T> {
      48            9 :     fn from(value: T) -> Self {
      49            9 :         Self::new(value)
      50            9 :     }
      51              : }
      52              : 
      53              : #[derive(Default)]
      54              : struct EndpointInfo {
      55              :     secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
      56              :     allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
      57              :     block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
      58              :     allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
      59              : }
      60              : 
      61              : impl EndpointInfo {
      62           11 :     fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
      63           11 :         match ignore_cache_since {
      64            3 :             None => false,
      65            8 :             Some(t) => t < created_at,
      66              :         }
      67           11 :     }
      68           15 :     pub(crate) fn get_role_secret(
      69           15 :         &self,
      70           15 :         role_name: RoleNameInt,
      71           15 :         valid_since: Instant,
      72           15 :         ignore_cache_since: Option<Instant>,
      73           15 :     ) -> Option<(Option<AuthSecret>, bool)> {
      74           15 :         if let Some(secret) = self.secret.get(&role_name) {
      75           13 :             if valid_since < secret.created_at {
      76            7 :                 return Some((
      77            7 :                     secret.value.clone(),
      78            7 :                     Self::check_ignore_cache(ignore_cache_since, secret.created_at),
      79            7 :                 ));
      80            6 :             }
      81            2 :         }
      82            8 :         None
      83           15 :     }
      84              : 
      85            5 :     pub(crate) fn get_allowed_ips(
      86            5 :         &self,
      87            5 :         valid_since: Instant,
      88            5 :         ignore_cache_since: Option<Instant>,
      89            5 :     ) -> Option<(Arc<Vec<IpPattern>>, bool)> {
      90            5 :         if let Some(allowed_ips) = &self.allowed_ips {
      91            5 :             if valid_since < allowed_ips.created_at {
      92            4 :                 return Some((
      93            4 :                     allowed_ips.value.clone(),
      94            4 :                     Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
      95            4 :                 ));
      96            1 :             }
      97            0 :         }
      98            1 :         None
      99            5 :     }
     100            0 :     pub(crate) fn get_allowed_vpc_endpoint_ids(
     101            0 :         &self,
     102            0 :         valid_since: Instant,
     103            0 :         ignore_cache_since: Option<Instant>,
     104            0 :     ) -> Option<(Arc<Vec<String>>, bool)> {
     105            0 :         if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
     106            0 :             if valid_since < allowed_vpc_endpoint_ids.created_at {
     107            0 :                 return Some((
     108            0 :                     allowed_vpc_endpoint_ids.value.clone(),
     109            0 :                     Self::check_ignore_cache(
     110            0 :                         ignore_cache_since,
     111            0 :                         allowed_vpc_endpoint_ids.created_at,
     112            0 :                     ),
     113            0 :                 ));
     114            0 :             }
     115            0 :         }
     116            0 :         None
     117            0 :     }
     118            0 :     pub(crate) fn get_block_public_or_vpc_access(
     119            0 :         &self,
     120            0 :         valid_since: Instant,
     121            0 :         ignore_cache_since: Option<Instant>,
     122            0 :     ) -> Option<(AccessBlockerFlags, bool)> {
     123            0 :         if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
     124            0 :             if valid_since < block_public_or_vpc_access.created_at {
     125            0 :                 return Some((
     126            0 :                     block_public_or_vpc_access.value.clone(),
     127            0 :                     Self::check_ignore_cache(
     128            0 :                         ignore_cache_since,
     129            0 :                         block_public_or_vpc_access.created_at,
     130            0 :                     ),
     131            0 :                 ));
     132            0 :             }
     133            0 :         }
     134            0 :         None
     135            0 :     }
     136              : 
     137            0 :     pub(crate) fn invalidate_allowed_ips(&mut self) {
     138            0 :         self.allowed_ips = None;
     139            0 :     }
     140            0 :     pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
     141            0 :         self.allowed_vpc_endpoint_ids = None;
     142            0 :     }
     143            0 :     pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
     144            0 :         self.block_public_or_vpc_access = None;
     145            0 :     }
     146            1 :     pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
     147            1 :         self.secret.remove(&role_name);
     148            1 :     }
     149              : }
     150              : 
     151              : /// Cache for project info.
     152              : /// This is used to cache auth data for endpoints.
     153              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
     154              : ///
     155              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
     156              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
     157              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
     158              : pub struct ProjectInfoCacheImpl {
     159              :     cache: ClashMap<EndpointIdInt, EndpointInfo>,
     160              : 
     161              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
     162              :     // FIXME(stefan): we need a way to GC the account2ep map.
     163              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
     164              :     config: ProjectInfoCacheOptions,
     165              : 
     166              :     start_time: Instant,
     167              :     ttl_disabled_since_us: AtomicU64,
     168              :     active_listeners_lock: Mutex<usize>,
     169              : }
     170              : 
     171              : #[async_trait]
     172              : impl ProjectInfoCache for ProjectInfoCacheImpl {
     173            0 :     fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
     174            0 :         info!(
     175            0 :             "invalidating allowed vpc endpoint ids for projects `{}`",
     176            0 :             project_ids
     177            0 :                 .iter()
     178            0 :                 .map(|id| id.to_string())
     179            0 :                 .collect::<Vec<_>>()
     180            0 :                 .join(", ")
     181              :         );
     182            0 :         for project_id in project_ids {
     183            0 :             let endpoints = self
     184            0 :                 .project2ep
     185            0 :                 .get(&project_id)
     186            0 :                 .map(|kv| kv.value().clone())
     187            0 :                 .unwrap_or_default();
     188            0 :             for endpoint_id in endpoints {
     189            0 :                 if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     190            0 :                     endpoint_info.invalidate_allowed_vpc_endpoint_ids();
     191            0 :                 }
     192              :             }
     193              :         }
     194            0 :     }
     195              : 
     196            0 :     fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
     197            0 :         info!(
     198            0 :             "invalidating allowed vpc endpoint ids for org `{}`",
     199              :             account_id
     200              :         );
     201            0 :         let endpoints = self
     202            0 :             .account2ep
     203            0 :             .get(&account_id)
     204            0 :             .map(|kv| kv.value().clone())
     205            0 :             .unwrap_or_default();
     206            0 :         for endpoint_id in endpoints {
     207            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     208            0 :                 endpoint_info.invalidate_allowed_vpc_endpoint_ids();
     209            0 :             }
     210              :         }
     211            0 :     }
     212              : 
     213            0 :     fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
     214            0 :         info!(
     215            0 :             "invalidating block public or vpc access for project `{}`",
     216              :             project_id
     217              :         );
     218            0 :         let endpoints = self
     219            0 :             .project2ep
     220            0 :             .get(&project_id)
     221            0 :             .map(|kv| kv.value().clone())
     222            0 :             .unwrap_or_default();
     223            0 :         for endpoint_id in endpoints {
     224            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     225            0 :                 endpoint_info.invalidate_block_public_or_vpc_access();
     226            0 :             }
     227              :         }
     228            0 :     }
     229              : 
     230            0 :     fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
     231            0 :         info!("invalidating allowed ips for project `{}`", project_id);
     232            0 :         let endpoints = self
     233            0 :             .project2ep
     234            0 :             .get(&project_id)
     235            0 :             .map(|kv| kv.value().clone())
     236            0 :             .unwrap_or_default();
     237            0 :         for endpoint_id in endpoints {
     238            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     239            0 :                 endpoint_info.invalidate_allowed_ips();
     240            0 :             }
     241              :         }
     242            0 :     }
     243            1 :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
     244            1 :         info!(
     245            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
     246              :             project_id, role_name,
     247              :         );
     248            1 :         let endpoints = self
     249            1 :             .project2ep
     250            1 :             .get(&project_id)
     251            1 :             .map(|kv| kv.value().clone())
     252            1 :             .unwrap_or_default();
     253            2 :         for endpoint_id in endpoints {
     254            1 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     255            1 :                 endpoint_info.invalidate_role_secret(role_name);
     256            1 :             }
     257              :         }
     258            1 :     }
     259            0 :     async fn decrement_active_listeners(&self) {
     260            0 :         let mut listeners_guard = self.active_listeners_lock.lock().await;
     261            0 :         if *listeners_guard == 0 {
     262            0 :             tracing::error!("active_listeners count is already 0, something is broken");
     263            0 :             return;
     264            0 :         }
     265            0 :         *listeners_guard -= 1;
     266            0 :         if *listeners_guard == 0 {
     267            0 :             self.ttl_disabled_since_us
     268            0 :                 .store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
     269            0 :         }
     270            0 :     }
     271              : 
     272            2 :     async fn increment_active_listeners(&self) {
     273            2 :         let mut listeners_guard = self.active_listeners_lock.lock().await;
     274            2 :         *listeners_guard += 1;
     275            2 :         if *listeners_guard == 1 {
     276            2 :             let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
     277            2 :             self.ttl_disabled_since_us
     278            2 :                 .store(new_ttl, std::sync::atomic::Ordering::SeqCst);
     279            2 :         }
     280            4 :     }
     281              : }
     282              : 
     283              : impl ProjectInfoCacheImpl {
     284            3 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
     285            3 :         Self {
     286            3 :             cache: ClashMap::new(),
     287            3 :             project2ep: ClashMap::new(),
     288            3 :             account2ep: ClashMap::new(),
     289            3 :             config,
     290            3 :             ttl_disabled_since_us: AtomicU64::new(u64::MAX),
     291            3 :             start_time: Instant::now(),
     292            3 :             active_listeners_lock: Mutex::new(0),
     293            3 :         }
     294            3 :     }
     295              : 
     296           15 :     pub(crate) fn get_role_secret(
     297           15 :         &self,
     298           15 :         endpoint_id: &EndpointId,
     299           15 :         role_name: &RoleName,
     300           15 :     ) -> Option<Cached<&Self, Option<AuthSecret>>> {
     301           15 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     302           15 :         let role_name = RoleNameInt::get(role_name)?;
     303           15 :         let (valid_since, ignore_cache_since) = self.get_cache_times();
     304           15 :         let endpoint_info = self.cache.get(&endpoint_id)?;
     305            7 :         let (value, ignore_cache) =
     306           15 :             endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
     307            7 :         if !ignore_cache {
     308            4 :             let cached = Cached {
     309            4 :                 token: Some((
     310            4 :                     self,
     311            4 :                     CachedLookupInfo::new_role_secret(endpoint_id, role_name),
     312            4 :                 )),
     313            4 :                 value,
     314            4 :             };
     315            4 :             return Some(cached);
     316            3 :         }
     317            3 :         Some(Cached::new_uncached(value))
     318           15 :     }
     319            5 :     pub(crate) fn get_allowed_ips(
     320            5 :         &self,
     321            5 :         endpoint_id: &EndpointId,
     322            5 :     ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
     323            5 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     324            5 :         let (valid_since, ignore_cache_since) = self.get_cache_times();
     325            5 :         let endpoint_info = self.cache.get(&endpoint_id)?;
     326            5 :         let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
     327            5 :         let (value, ignore_cache) = value?;
     328            4 :         if !ignore_cache {
     329            1 :             let cached = Cached {
     330            1 :                 token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
     331            1 :                 value,
     332            1 :             };
     333            1 :             return Some(cached);
     334            3 :         }
     335            3 :         Some(Cached::new_uncached(value))
     336            5 :     }
     337            0 :     pub(crate) fn get_allowed_vpc_endpoint_ids(
     338            0 :         &self,
     339            0 :         endpoint_id: &EndpointId,
     340            0 :     ) -> Option<Cached<&Self, Arc<Vec<String>>>> {
     341            0 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     342            0 :         let (valid_since, ignore_cache_since) = self.get_cache_times();
     343            0 :         let endpoint_info = self.cache.get(&endpoint_id)?;
     344            0 :         let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
     345            0 :         let (value, ignore_cache) = value?;
     346            0 :         if !ignore_cache {
     347            0 :             let cached = Cached {
     348            0 :                 token: Some((
     349            0 :                     self,
     350            0 :                     CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
     351            0 :                 )),
     352            0 :                 value,
     353            0 :             };
     354            0 :             return Some(cached);
     355            0 :         }
     356            0 :         Some(Cached::new_uncached(value))
     357            0 :     }
     358            0 :     pub(crate) fn get_block_public_or_vpc_access(
     359            0 :         &self,
     360            0 :         endpoint_id: &EndpointId,
     361            0 :     ) -> Option<Cached<&Self, AccessBlockerFlags>> {
     362            0 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     363            0 :         let (valid_since, ignore_cache_since) = self.get_cache_times();
     364            0 :         let endpoint_info = self.cache.get(&endpoint_id)?;
     365            0 :         let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
     366            0 :         let (value, ignore_cache) = value?;
     367            0 :         if !ignore_cache {
     368            0 :             let cached = Cached {
     369            0 :                 token: Some((
     370            0 :                     self,
     371            0 :                     CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
     372            0 :                 )),
     373            0 :                 value,
     374            0 :             };
     375            0 :             return Some(cached);
     376            0 :         }
     377            0 :         Some(Cached::new_uncached(value))
     378            0 :     }
     379              : 
     380            7 :     pub(crate) fn insert_role_secret(
     381            7 :         &self,
     382            7 :         project_id: ProjectIdInt,
     383            7 :         endpoint_id: EndpointIdInt,
     384            7 :         role_name: RoleNameInt,
     385            7 :         secret: Option<AuthSecret>,
     386            7 :     ) {
     387            7 :         if self.cache.len() >= self.config.size {
     388              :             // If there are too many entries, wait until the next gc cycle.
     389            0 :             return;
     390            7 :         }
     391            7 :         self.insert_project2endpoint(project_id, endpoint_id);
     392            7 :         let mut entry = self.cache.entry(endpoint_id).or_default();
     393            7 :         if entry.secret.len() < self.config.max_roles {
     394            6 :             entry.secret.insert(role_name, secret.into());
     395            6 :         }
     396            7 :     }
     397            3 :     pub(crate) fn insert_allowed_ips(
     398            3 :         &self,
     399            3 :         project_id: ProjectIdInt,
     400            3 :         endpoint_id: EndpointIdInt,
     401            3 :         allowed_ips: Arc<Vec<IpPattern>>,
     402            3 :     ) {
     403            3 :         if self.cache.len() >= self.config.size {
     404              :             // If there are too many entries, wait until the next gc cycle.
     405            0 :             return;
     406            3 :         }
     407            3 :         self.insert_project2endpoint(project_id, endpoint_id);
     408            3 :         self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
     409            3 :     }
     410            0 :     pub(crate) fn insert_allowed_vpc_endpoint_ids(
     411            0 :         &self,
     412            0 :         account_id: Option<AccountIdInt>,
     413            0 :         project_id: ProjectIdInt,
     414            0 :         endpoint_id: EndpointIdInt,
     415            0 :         allowed_vpc_endpoint_ids: Arc<Vec<String>>,
     416            0 :     ) {
     417            0 :         if self.cache.len() >= self.config.size {
     418              :             // If there are too many entries, wait until the next gc cycle.
     419            0 :             return;
     420            0 :         }
     421            0 :         if let Some(account_id) = account_id {
     422            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     423            0 :         }
     424            0 :         self.insert_project2endpoint(project_id, endpoint_id);
     425            0 :         self.cache
     426            0 :             .entry(endpoint_id)
     427            0 :             .or_default()
     428            0 :             .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
     429            0 :     }
     430            0 :     pub(crate) fn insert_block_public_or_vpc_access(
     431            0 :         &self,
     432            0 :         project_id: ProjectIdInt,
     433            0 :         endpoint_id: EndpointIdInt,
     434            0 :         access_blockers: AccessBlockerFlags,
     435            0 :     ) {
     436            0 :         if self.cache.len() >= self.config.size {
     437              :             // If there are too many entries, wait until the next gc cycle.
     438            0 :             return;
     439            0 :         }
     440            0 :         self.insert_project2endpoint(project_id, endpoint_id);
     441            0 :         self.cache
     442            0 :             .entry(endpoint_id)
     443            0 :             .or_default()
     444            0 :             .block_public_or_vpc_access = Some(access_blockers.into());
     445            0 :     }
     446              : 
     447           10 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     448           10 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     449            7 :             endpoints.insert(endpoint_id);
     450            7 :         } else {
     451            3 :             self.project2ep
     452            3 :                 .insert(project_id, HashSet::from([endpoint_id]));
     453            3 :         }
     454           10 :     }
     455            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     456            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     457            0 :             endpoints.insert(endpoint_id);
     458            0 :         } else {
     459            0 :             self.account2ep
     460            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     461            0 :         }
     462            0 :     }
     463           20 :     fn get_cache_times(&self) -> (Instant, Option<Instant>) {
     464           20 :         let mut valid_since = Instant::now() - self.config.ttl;
     465           20 :         // Only ignore cache if ttl is disabled.
     466           20 :         let ttl_disabled_since_us = self
     467           20 :             .ttl_disabled_since_us
     468           20 :             .load(std::sync::atomic::Ordering::Relaxed);
     469           20 :         let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
     470            7 :             None
     471              :         } else {
     472           13 :             let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
     473           13 :             // We are fine if entry is not older than ttl or was added before we are getting notifications.
     474           13 :             valid_since = valid_since.min(ignore_cache_since);
     475           13 :             Some(ignore_cache_since)
     476              :         };
     477           20 :         (valid_since, ignore_cache_since)
     478           20 :     }
     479              : 
     480            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     481            0 :         let mut interval =
     482            0 :             tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
     483              :         loop {
     484            0 :             interval.tick().await;
     485            0 :             if self.cache.len() < self.config.size {
     486              :                 // If there are not too many entries, wait until the next gc cycle.
     487            0 :                 continue;
     488            0 :             }
     489            0 :             self.gc();
     490              :         }
     491              :     }
     492              : 
     493            0 :     fn gc(&self) {
     494            0 :         let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
     495            0 :         debug!(shard, "project_info_cache: performing epoch reclamation");
     496              : 
     497              :         // acquire a random shard lock
     498            0 :         let mut removed = 0;
     499            0 :         let shard = self.project2ep.shards()[shard].write();
     500            0 :         for (_, endpoints) in shard.iter() {
     501            0 :             for endpoint in endpoints {
     502            0 :                 self.cache.remove(endpoint);
     503            0 :                 removed += 1;
     504            0 :             }
     505              :         }
     506              :         // We can drop this shard only after making sure that all endpoints are removed.
     507            0 :         drop(shard);
     508            0 :         info!("project_info_cache: removed {removed} endpoints");
     509            0 :     }
     510              : }
     511              : 
     512              : /// Lookup info for project info cache.
     513              : /// This is used to invalidate cache entries.
     514              : pub(crate) struct CachedLookupInfo {
     515              :     /// Search by this key.
     516              :     endpoint_id: EndpointIdInt,
     517              :     lookup_type: LookupType,
     518              : }
     519              : 
     520              : impl CachedLookupInfo {
     521            4 :     pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
     522            4 :         Self {
     523            4 :             endpoint_id,
     524            4 :             lookup_type: LookupType::RoleSecret(role_name),
     525            4 :         }
     526            4 :     }
     527            1 :     pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
     528            1 :         Self {
     529            1 :             endpoint_id,
     530            1 :             lookup_type: LookupType::AllowedIps,
     531            1 :         }
     532            1 :     }
     533            0 :     pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
     534            0 :         Self {
     535            0 :             endpoint_id,
     536            0 :             lookup_type: LookupType::AllowedVpcEndpointIds,
     537            0 :         }
     538            0 :     }
     539            0 :     pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
     540            0 :         Self {
     541            0 :             endpoint_id,
     542            0 :             lookup_type: LookupType::BlockPublicOrVpcAccess,
     543            0 :         }
     544            0 :     }
     545              : }
     546              : 
     547              : enum LookupType {
     548              :     RoleSecret(RoleNameInt),
     549              :     AllowedIps,
     550              :     AllowedVpcEndpointIds,
     551              :     BlockPublicOrVpcAccess,
     552              : }
     553              : 
     554              : impl Cache for ProjectInfoCacheImpl {
     555              :     type Key = SmolStr;
     556              :     // Value is not really used here, but we need to specify it.
     557              :     type Value = SmolStr;
     558              : 
     559              :     type LookupInfo<Key> = CachedLookupInfo;
     560              : 
     561            0 :     fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
     562            0 :         match &key.lookup_type {
     563            0 :             LookupType::RoleSecret(role_name) => {
     564            0 :                 if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
     565            0 :                     endpoint_info.invalidate_role_secret(*role_name);
     566            0 :                 }
     567              :             }
     568              :             LookupType::AllowedIps => {
     569            0 :                 if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
     570            0 :                     endpoint_info.invalidate_allowed_ips();
     571            0 :                 }
     572              :             }
     573              :             LookupType::AllowedVpcEndpointIds => {
     574            0 :                 if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
     575            0 :                     endpoint_info.invalidate_allowed_vpc_endpoint_ids();
     576            0 :                 }
     577              :             }
     578              :             LookupType::BlockPublicOrVpcAccess => {
     579            0 :                 if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
     580            0 :                     endpoint_info.invalidate_block_public_or_vpc_access();
     581            0 :                 }
     582              :             }
     583              :         }
     584            0 :     }
     585              : }
     586              : 
     587              : #[cfg(test)]
     588              : #[expect(clippy::unwrap_used)]
     589              : mod tests {
     590              :     use super::*;
     591              :     use crate::scram::ServerSecret;
     592              :     use crate::types::ProjectId;
     593              : 
     594              :     #[tokio::test]
     595            1 :     async fn test_project_info_cache_settings() {
     596            1 :         tokio::time::pause();
     597            1 :         let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     598            1 :             size: 2,
     599            1 :             max_roles: 2,
     600            1 :             ttl: Duration::from_secs(1),
     601            1 :             gc_interval: Duration::from_secs(600),
     602            1 :         });
     603            1 :         let project_id: ProjectId = "project".into();
     604            1 :         let endpoint_id: EndpointId = "endpoint".into();
     605            1 :         let user1: RoleName = "user1".into();
     606            1 :         let user2: RoleName = "user2".into();
     607            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     608            1 :         let secret2 = None;
     609            1 :         let allowed_ips = Arc::new(vec![
     610            1 :             "127.0.0.1".parse().unwrap(),
     611            1 :             "127.0.0.2".parse().unwrap(),
     612            1 :         ]);
     613            1 :         cache.insert_role_secret(
     614            1 :             (&project_id).into(),
     615            1 :             (&endpoint_id).into(),
     616            1 :             (&user1).into(),
     617            1 :             secret1.clone(),
     618            1 :         );
     619            1 :         cache.insert_role_secret(
     620            1 :             (&project_id).into(),
     621            1 :             (&endpoint_id).into(),
     622            1 :             (&user2).into(),
     623            1 :             secret2.clone(),
     624            1 :         );
     625            1 :         cache.insert_allowed_ips(
     626            1 :             (&project_id).into(),
     627            1 :             (&endpoint_id).into(),
     628            1 :             allowed_ips.clone(),
     629            1 :         );
     630            1 : 
     631            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     632            1 :         assert!(cached.cached());
     633            1 :         assert_eq!(cached.value, secret1);
     634            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     635            1 :         assert!(cached.cached());
     636            1 :         assert_eq!(cached.value, secret2);
     637            1 : 
     638            1 :         // Shouldn't add more than 2 roles.
     639            1 :         let user3: RoleName = "user3".into();
     640            1 :         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
     641            1 :         cache.insert_role_secret(
     642            1 :             (&project_id).into(),
     643            1 :             (&endpoint_id).into(),
     644            1 :             (&user3).into(),
     645            1 :             secret3.clone(),
     646            1 :         );
     647            1 :         assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
     648            1 : 
     649            1 :         let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
     650            1 :         assert!(cached.cached());
     651            1 :         assert_eq!(cached.value, allowed_ips);
     652            1 : 
     653            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     654            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1);
     655            1 :         assert!(cached.is_none());
     656            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2);
     657            1 :         assert!(cached.is_none());
     658            1 :         let cached = cache.get_allowed_ips(&endpoint_id);
     659            1 :         assert!(cached.is_none());
     660            1 :     }
     661              : 
     662              :     #[tokio::test]
     663            1 :     async fn test_project_info_cache_invalidations() {
     664            1 :         tokio::time::pause();
     665            1 :         let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     666            1 :             size: 2,
     667            1 :             max_roles: 2,
     668            1 :             ttl: Duration::from_secs(1),
     669            1 :             gc_interval: Duration::from_secs(600),
     670            1 :         }));
     671            1 :         cache.clone().increment_active_listeners().await;
     672            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     673            1 : 
     674            1 :         let project_id: ProjectId = "project".into();
     675            1 :         let endpoint_id: EndpointId = "endpoint".into();
     676            1 :         let user1: RoleName = "user1".into();
     677            1 :         let user2: RoleName = "user2".into();
     678            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     679            1 :         let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
     680            1 :         let allowed_ips = Arc::new(vec![
     681            1 :             "127.0.0.1".parse().unwrap(),
     682            1 :             "127.0.0.2".parse().unwrap(),
     683            1 :         ]);
     684            1 :         cache.insert_role_secret(
     685            1 :             (&project_id).into(),
     686            1 :             (&endpoint_id).into(),
     687            1 :             (&user1).into(),
     688            1 :             secret1.clone(),
     689            1 :         );
     690            1 :         cache.insert_role_secret(
     691            1 :             (&project_id).into(),
     692            1 :             (&endpoint_id).into(),
     693            1 :             (&user2).into(),
     694            1 :             secret2.clone(),
     695            1 :         );
     696            1 :         cache.insert_allowed_ips(
     697            1 :             (&project_id).into(),
     698            1 :             (&endpoint_id).into(),
     699            1 :             allowed_ips.clone(),
     700            1 :         );
     701            1 : 
     702            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     703            1 :         // Nothing should be invalidated.
     704            1 : 
     705            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     706            1 :         // TTL is disabled, so it should be impossible to invalidate this value.
     707            1 :         assert!(!cached.cached());
     708            1 :         assert_eq!(cached.value, secret1);
     709            1 : 
     710            1 :         cached.invalidate(); // Shouldn't do anything.
     711            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     712            1 :         assert_eq!(cached.value, secret1);
     713            1 : 
     714            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     715            1 :         assert!(!cached.cached());
     716            1 :         assert_eq!(cached.value, secret2);
     717            1 : 
     718            1 :         // The only way to invalidate this value is to invalidate via the api.
     719            1 :         cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
     720            1 :         assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
     721            1 : 
     722            1 :         let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
     723            1 :         assert!(!cached.cached());
     724            1 :         assert_eq!(cached.value, allowed_ips);
     725            1 :     }
     726              : 
     727              :     #[tokio::test]
     728            1 :     async fn test_increment_active_listeners_invalidate_added_before() {
     729            1 :         tokio::time::pause();
     730            1 :         let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     731            1 :             size: 2,
     732            1 :             max_roles: 2,
     733            1 :             ttl: Duration::from_secs(1),
     734            1 :             gc_interval: Duration::from_secs(600),
     735            1 :         }));
     736            1 : 
     737            1 :         let project_id: ProjectId = "project".into();
     738            1 :         let endpoint_id: EndpointId = "endpoint".into();
     739            1 :         let user1: RoleName = "user1".into();
     740            1 :         let user2: RoleName = "user2".into();
     741            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     742            1 :         let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
     743            1 :         let allowed_ips = Arc::new(vec![
     744            1 :             "127.0.0.1".parse().unwrap(),
     745            1 :             "127.0.0.2".parse().unwrap(),
     746            1 :         ]);
     747            1 :         cache.insert_role_secret(
     748            1 :             (&project_id).into(),
     749            1 :             (&endpoint_id).into(),
     750            1 :             (&user1).into(),
     751            1 :             secret1.clone(),
     752            1 :         );
     753            1 :         cache.clone().increment_active_listeners().await;
     754            1 :         tokio::time::advance(Duration::from_millis(100)).await;
     755            1 :         cache.insert_role_secret(
     756            1 :             (&project_id).into(),
     757            1 :             (&endpoint_id).into(),
     758            1 :             (&user2).into(),
     759            1 :             secret2.clone(),
     760            1 :         );
     761            1 : 
     762            1 :         // Added before ttl was disabled + ttl should be still cached.
     763            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     764            1 :         assert!(cached.cached());
     765            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     766            1 :         assert!(cached.cached());
     767            1 : 
     768            1 :         tokio::time::advance(Duration::from_secs(1)).await;
     769            1 :         // Added before ttl was disabled + ttl should expire.
     770            1 :         assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
     771            1 :         assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
     772            1 : 
     773            1 :         // Added after ttl was disabled + ttl should not be cached.
     774            1 :         cache.insert_allowed_ips(
     775            1 :             (&project_id).into(),
     776            1 :             (&endpoint_id).into(),
     777            1 :             allowed_ips.clone(),
     778            1 :         );
     779            1 :         let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
     780            1 :         assert!(!cached.cached());
     781            1 : 
     782            1 :         tokio::time::advance(Duration::from_secs(1)).await;
     783            1 :         // Added before ttl was disabled + ttl still should expire.
     784            1 :         assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
     785            1 :         assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
     786            1 :         // Shouldn't be invalidated.
     787            1 : 
     788            1 :         let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
     789            1 :         assert!(!cached.cached());
     790            1 :         assert_eq!(cached.value, allowed_ips);
     791            1 :     }
     792              : }
        

Generated by: LCOV version 2.1-beta