LCOV - code coverage report
Current view: top level - proxy/src/cache - project_info.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 60.5 % 311 188
Test Date: 2025-07-16 12:29:03 Functions: 50.0 % 36 18

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet, hash_map};
       2              : use std::convert::Infallible;
       3              : use std::sync::atomic::AtomicU64;
       4              : use std::time::Duration;
       5              : 
       6              : use async_trait::async_trait;
       7              : use clashmap::ClashMap;
       8              : use clashmap::mapref::one::Ref;
       9              : use rand::{Rng, thread_rng};
      10              : use tokio::sync::Mutex;
      11              : use tokio::time::Instant;
      12              : use tracing::{debug, info};
      13              : 
      14              : use crate::config::ProjectInfoCacheOptions;
      15              : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
      16              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      17              : use crate::types::{EndpointId, RoleName};
      18              : 
      19              : #[async_trait]
      20              : pub(crate) trait ProjectInfoCache {
      21              :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
      22              :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
      23              :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
      24              :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
      25              :     async fn decrement_active_listeners(&self);
      26              :     async fn increment_active_listeners(&self);
      27              : }
      28              : 
      29              : struct Entry<T> {
      30              :     created_at: Instant,
      31              :     value: T,
      32              : }
      33              : 
      34              : impl<T> Entry<T> {
      35            6 :     pub(crate) fn new(value: T) -> Self {
      36            6 :         Self {
      37            6 :             created_at: Instant::now(),
      38            6 :             value,
      39            6 :         }
      40            6 :     }
      41              : 
      42            6 :     pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
      43            6 :         (valid_since < self.created_at).then_some(&self.value)
      44            6 :     }
      45              : }
      46              : 
      47              : impl<T> From<T> for Entry<T> {
      48            6 :     fn from(value: T) -> Self {
      49            6 :         Self::new(value)
      50            6 :     }
      51              : }
      52              : 
      53              : struct EndpointInfo {
      54              :     role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
      55              :     controls: Option<Entry<EndpointAccessControl>>,
      56              : }
      57              : 
      58              : impl EndpointInfo {
      59            5 :     pub(crate) fn get_role_secret(
      60            5 :         &self,
      61            5 :         role_name: RoleNameInt,
      62            5 :         valid_since: Instant,
      63            5 :     ) -> Option<RoleAccessControl> {
      64            5 :         let controls = self.role_controls.get(&role_name)?;
      65            4 :         controls.get(valid_since).cloned()
      66            5 :     }
      67              : 
      68            2 :     pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
      69            2 :         let controls = self.controls.as_ref()?;
      70            2 :         controls.get(valid_since).cloned()
      71            2 :     }
      72              : 
      73            0 :     pub(crate) fn invalidate_endpoint(&mut self) {
      74            0 :         self.controls = None;
      75            0 :     }
      76              : 
      77            0 :     pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
      78            0 :         self.role_controls.remove(&role_name);
      79            0 :     }
      80              : }
      81              : 
      82              : /// Cache for project info.
      83              : /// This is used to cache auth data for endpoints.
      84              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
      85              : ///
      86              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
      87              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
      88              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
      89              : pub struct ProjectInfoCacheImpl {
      90              :     cache: ClashMap<EndpointIdInt, EndpointInfo>,
      91              : 
      92              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
      93              :     // FIXME(stefan): we need a way to GC the account2ep map.
      94              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
      95              :     config: ProjectInfoCacheOptions,
      96              : 
      97              :     start_time: Instant,
      98              :     ttl_disabled_since_us: AtomicU64,
      99              :     active_listeners_lock: Mutex<usize>,
     100              : }
     101              : 
     102              : #[async_trait]
     103              : impl ProjectInfoCache for ProjectInfoCacheImpl {
     104            0 :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
     105            0 :         info!("invalidating endpoint access for `{endpoint_id}`");
     106            0 :         if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     107            0 :             endpoint_info.invalidate_endpoint();
     108            0 :         }
     109            0 :     }
     110              : 
     111            0 :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
     112            0 :         info!("invalidating endpoint access for project `{project_id}`");
     113            0 :         let endpoints = self
     114            0 :             .project2ep
     115            0 :             .get(&project_id)
     116            0 :             .map(|kv| kv.value().clone())
     117            0 :             .unwrap_or_default();
     118            0 :         for endpoint_id in endpoints {
     119            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     120            0 :                 endpoint_info.invalidate_endpoint();
     121            0 :             }
     122              :         }
     123            0 :     }
     124              : 
     125            0 :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
     126            0 :         info!("invalidating endpoint access for org `{account_id}`");
     127            0 :         let endpoints = self
     128            0 :             .account2ep
     129            0 :             .get(&account_id)
     130            0 :             .map(|kv| kv.value().clone())
     131            0 :             .unwrap_or_default();
     132            0 :         for endpoint_id in endpoints {
     133            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     134            0 :                 endpoint_info.invalidate_endpoint();
     135            0 :             }
     136              :         }
     137            0 :     }
     138              : 
     139            0 :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
     140            0 :         info!(
     141            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
     142              :             project_id, role_name,
     143              :         );
     144            0 :         let endpoints = self
     145            0 :             .project2ep
     146            0 :             .get(&project_id)
     147            0 :             .map(|kv| kv.value().clone())
     148            0 :             .unwrap_or_default();
     149            0 :         for endpoint_id in endpoints {
     150            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     151            0 :                 endpoint_info.invalidate_role_secret(role_name);
     152            0 :             }
     153              :         }
     154            0 :     }
     155              : 
     156            0 :     async fn decrement_active_listeners(&self) {
     157            0 :         let mut listeners_guard = self.active_listeners_lock.lock().await;
     158            0 :         if *listeners_guard == 0 {
     159            0 :             tracing::error!("active_listeners count is already 0, something is broken");
     160            0 :             return;
     161            0 :         }
     162            0 :         *listeners_guard -= 1;
     163            0 :         if *listeners_guard == 0 {
     164            0 :             self.ttl_disabled_since_us
     165            0 :                 .store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
     166            0 :         }
     167            0 :     }
     168              : 
     169            0 :     async fn increment_active_listeners(&self) {
     170            0 :         let mut listeners_guard = self.active_listeners_lock.lock().await;
     171            0 :         *listeners_guard += 1;
     172            0 :         if *listeners_guard == 1 {
     173            0 :             let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
     174            0 :             self.ttl_disabled_since_us
     175            0 :                 .store(new_ttl, std::sync::atomic::Ordering::SeqCst);
     176            0 :         }
     177            0 :     }
     178              : }
     179              : 
     180              : impl ProjectInfoCacheImpl {
     181            1 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
     182            1 :         Self {
     183            1 :             cache: ClashMap::new(),
     184            1 :             project2ep: ClashMap::new(),
     185            1 :             account2ep: ClashMap::new(),
     186            1 :             config,
     187            1 :             ttl_disabled_since_us: AtomicU64::new(u64::MAX),
     188            1 :             start_time: Instant::now(),
     189            1 :             active_listeners_lock: Mutex::new(0),
     190            1 :         }
     191            1 :     }
     192              : 
     193            7 :     fn get_endpoint_cache(
     194            7 :         &self,
     195            7 :         endpoint_id: &EndpointId,
     196            7 :     ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
     197            7 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     198            7 :         self.cache.get(&endpoint_id)
     199            7 :     }
     200              : 
     201            5 :     pub(crate) fn get_role_secret(
     202            5 :         &self,
     203            5 :         endpoint_id: &EndpointId,
     204            5 :         role_name: &RoleName,
     205            5 :     ) -> Option<RoleAccessControl> {
     206            5 :         let valid_since = self.get_cache_times();
     207            5 :         let role_name = RoleNameInt::get(role_name)?;
     208            5 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     209            5 :         endpoint_info.get_role_secret(role_name, valid_since)
     210            5 :     }
     211              : 
     212            2 :     pub(crate) fn get_endpoint_access(
     213            2 :         &self,
     214            2 :         endpoint_id: &EndpointId,
     215            2 :     ) -> Option<EndpointAccessControl> {
     216            2 :         let valid_since = self.get_cache_times();
     217            2 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     218            2 :         endpoint_info.get_controls(valid_since)
     219            2 :     }
     220              : 
     221            3 :     pub(crate) fn insert_endpoint_access(
     222            3 :         &self,
     223            3 :         account_id: Option<AccountIdInt>,
     224            3 :         project_id: ProjectIdInt,
     225            3 :         endpoint_id: EndpointIdInt,
     226            3 :         role_name: RoleNameInt,
     227            3 :         controls: EndpointAccessControl,
     228            3 :         role_controls: RoleAccessControl,
     229            3 :     ) {
     230            3 :         if let Some(account_id) = account_id {
     231            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     232            3 :         }
     233            3 :         self.insert_project2endpoint(project_id, endpoint_id);
     234              : 
     235            3 :         if self.cache.len() >= self.config.size {
     236              :             // If there are too many entries, wait until the next gc cycle.
     237            0 :             return;
     238            3 :         }
     239              : 
     240            3 :         let controls = Entry::from(controls);
     241            3 :         let role_controls = Entry::from(role_controls);
     242              : 
     243            3 :         match self.cache.entry(endpoint_id) {
     244            1 :             clashmap::Entry::Vacant(e) => {
     245            1 :                 e.insert(EndpointInfo {
     246            1 :                     role_controls: HashMap::from_iter([(role_name, role_controls)]),
     247            1 :                     controls: Some(controls),
     248            1 :                 });
     249            1 :             }
     250            2 :             clashmap::Entry::Occupied(mut e) => {
     251            2 :                 let ep = e.get_mut();
     252            2 :                 ep.controls = Some(controls);
     253            2 :                 if ep.role_controls.len() < self.config.max_roles {
     254            1 :                     ep.role_controls.insert(role_name, role_controls);
     255            1 :                 }
     256              :             }
     257              :         }
     258            3 :     }
     259              : 
     260            3 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     261            3 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     262            2 :             endpoints.insert(endpoint_id);
     263            2 :         } else {
     264            1 :             self.project2ep
     265            1 :                 .insert(project_id, HashSet::from([endpoint_id]));
     266            1 :         }
     267            3 :     }
     268              : 
     269            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     270            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     271            0 :             endpoints.insert(endpoint_id);
     272            0 :         } else {
     273            0 :             self.account2ep
     274            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     275            0 :         }
     276            0 :     }
     277              : 
     278            7 :     fn ignore_ttl_since(&self) -> Option<Instant> {
     279            7 :         let ttl_disabled_since_us = self
     280            7 :             .ttl_disabled_since_us
     281            7 :             .load(std::sync::atomic::Ordering::Relaxed);
     282              : 
     283            7 :         if ttl_disabled_since_us == u64::MAX {
     284            7 :             return None;
     285            0 :         }
     286              : 
     287            0 :         Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
     288            7 :     }
     289              : 
     290            7 :     fn get_cache_times(&self) -> Instant {
     291            7 :         let mut valid_since = Instant::now() - self.config.ttl;
     292            7 :         if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
     293            0 :             // We are fine if entry is not older than ttl or was added before we are getting notifications.
     294            0 :             valid_since = valid_since.min(ignore_ttl_since);
     295            7 :         }
     296            7 :         valid_since
     297            7 :     }
     298              : 
     299            0 :     pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
     300            0 :         let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
     301            0 :             return;
     302              :         };
     303            0 :         let Some(role_name) = RoleNameInt::get(role_name) else {
     304            0 :             return;
     305              :         };
     306              : 
     307            0 :         let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
     308            0 :             return;
     309              :         };
     310              : 
     311            0 :         let entry = endpoint_info.role_controls.entry(role_name);
     312            0 :         let hash_map::Entry::Occupied(role_controls) = entry else {
     313            0 :             return;
     314              :         };
     315              : 
     316            0 :         let created_at = role_controls.get().created_at;
     317            0 :         let expire = match self.ignore_ttl_since() {
     318              :             // if ignoring TTL, we should still try and roll the password if it's old
     319              :             // and we the client gave an incorrect password. There could be some lag on the redis channel.
     320            0 :             Some(_) => created_at + self.config.ttl < Instant::now(),
     321              :             // edge case: redis is down, let's be generous and invalidate the cache immediately.
     322            0 :             None => true,
     323              :         };
     324              : 
     325            0 :         if expire {
     326            0 :             role_controls.remove();
     327            0 :         }
     328            0 :     }
     329              : 
     330            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     331            0 :         let mut interval =
     332            0 :             tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
     333              :         loop {
     334            0 :             interval.tick().await;
     335            0 :             if self.cache.len() < self.config.size {
     336              :                 // If there are not too many entries, wait until the next gc cycle.
     337            0 :                 continue;
     338            0 :             }
     339            0 :             self.gc();
     340              :         }
     341              :     }
     342              : 
     343            0 :     fn gc(&self) {
     344            0 :         let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
     345            0 :         debug!(shard, "project_info_cache: performing epoch reclamation");
     346              : 
     347              :         // acquire a random shard lock
     348            0 :         let mut removed = 0;
     349            0 :         let shard = self.project2ep.shards()[shard].write();
     350            0 :         for (_, endpoints) in shard.iter() {
     351            0 :             for endpoint in endpoints {
     352            0 :                 self.cache.remove(endpoint);
     353            0 :                 removed += 1;
     354            0 :             }
     355              :         }
     356              :         // We can drop this shard only after making sure that all endpoints are removed.
     357            0 :         drop(shard);
     358            0 :         info!("project_info_cache: removed {removed} endpoints");
     359            0 :     }
     360              : }
     361              : 
     362              : #[cfg(test)]
     363              : mod tests {
     364              :     use std::sync::Arc;
     365              : 
     366              :     use super::*;
     367              :     use crate::control_plane::messages::EndpointRateLimitConfig;
     368              :     use crate::control_plane::{AccessBlockerFlags, AuthSecret};
     369              :     use crate::scram::ServerSecret;
     370              :     use crate::types::ProjectId;
     371              : 
     372              :     #[tokio::test]
     373            1 :     async fn test_project_info_cache_settings() {
     374            1 :         tokio::time::pause();
     375            1 :         let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     376            1 :             size: 2,
     377            1 :             max_roles: 2,
     378            1 :             ttl: Duration::from_secs(1),
     379            1 :             gc_interval: Duration::from_secs(600),
     380            1 :         });
     381            1 :         let project_id: ProjectId = "project".into();
     382            1 :         let endpoint_id: EndpointId = "endpoint".into();
     383            1 :         let account_id: Option<AccountIdInt> = None;
     384              : 
     385            1 :         let user1: RoleName = "user1".into();
     386            1 :         let user2: RoleName = "user2".into();
     387            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     388            1 :         let secret2 = None;
     389            1 :         let allowed_ips = Arc::new(vec![
     390            1 :             "127.0.0.1".parse().unwrap(),
     391            1 :             "127.0.0.2".parse().unwrap(),
     392              :         ]);
     393              : 
     394            1 :         cache.insert_endpoint_access(
     395            1 :             account_id,
     396            1 :             (&project_id).into(),
     397            1 :             (&endpoint_id).into(),
     398            1 :             (&user1).into(),
     399            1 :             EndpointAccessControl {
     400            1 :                 allowed_ips: allowed_ips.clone(),
     401            1 :                 allowed_vpce: Arc::new(vec![]),
     402            1 :                 flags: AccessBlockerFlags::default(),
     403            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     404            1 :             },
     405            1 :             RoleAccessControl {
     406            1 :                 secret: secret1.clone(),
     407            1 :             },
     408              :         );
     409              : 
     410            1 :         cache.insert_endpoint_access(
     411            1 :             account_id,
     412            1 :             (&project_id).into(),
     413            1 :             (&endpoint_id).into(),
     414            1 :             (&user2).into(),
     415            1 :             EndpointAccessControl {
     416            1 :                 allowed_ips: allowed_ips.clone(),
     417            1 :                 allowed_vpce: Arc::new(vec![]),
     418            1 :                 flags: AccessBlockerFlags::default(),
     419            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     420            1 :             },
     421            1 :             RoleAccessControl {
     422            1 :                 secret: secret2.clone(),
     423            1 :             },
     424              :         );
     425              : 
     426            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     427            1 :         assert_eq!(cached.secret, secret1);
     428              : 
     429            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     430            1 :         assert_eq!(cached.secret, secret2);
     431              : 
     432              :         // Shouldn't add more than 2 roles.
     433            1 :         let user3: RoleName = "user3".into();
     434            1 :         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
     435              : 
     436            1 :         cache.insert_endpoint_access(
     437            1 :             account_id,
     438            1 :             (&project_id).into(),
     439            1 :             (&endpoint_id).into(),
     440            1 :             (&user3).into(),
     441            1 :             EndpointAccessControl {
     442            1 :                 allowed_ips: allowed_ips.clone(),
     443            1 :                 allowed_vpce: Arc::new(vec![]),
     444            1 :                 flags: AccessBlockerFlags::default(),
     445            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     446            1 :             },
     447            1 :             RoleAccessControl {
     448            1 :                 secret: secret3.clone(),
     449            1 :             },
     450              :         );
     451              : 
     452            1 :         assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
     453              : 
     454            1 :         let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
     455            1 :         assert_eq!(cached.allowed_ips, allowed_ips);
     456              : 
     457            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     458            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1);
     459            1 :         assert!(cached.is_none());
     460            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2);
     461            1 :         assert!(cached.is_none());
     462            1 :         let cached = cache.get_endpoint_access(&endpoint_id);
     463            1 :         assert!(cached.is_none());
     464            1 :     }
     465              : }
        

Generated by: LCOV version 2.1-beta