LCOV - code coverage report
Current view: top level - proxy/src/cache - project_info.rs (source / functions) Coverage Total Hit
Test: 9ea6db2ee0fa9d49a3b75230199d5aaf7b855d49.info Lines: 63.1 % 255 161
Test Date: 2025-07-18 08:04:09 Functions: 50.0 % 28 14

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet, hash_map};
       2              : use std::convert::Infallible;
       3              : use std::time::Duration;
       4              : 
       5              : use async_trait::async_trait;
       6              : use clashmap::ClashMap;
       7              : use clashmap::mapref::one::Ref;
       8              : use rand::{Rng, thread_rng};
       9              : use tokio::time::Instant;
      10              : use tracing::{debug, info};
      11              : 
      12              : use crate::config::ProjectInfoCacheOptions;
      13              : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
      14              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      15              : use crate::types::{EndpointId, RoleName};
      16              : 
      17              : #[async_trait]
      18              : pub(crate) trait ProjectInfoCache {
      19              :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
      20              :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
      21              :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
      22              :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
      23              : }
      24              : 
      25              : struct Entry<T> {
      26              :     expires_at: Instant,
      27              :     value: T,
      28              : }
      29              : 
      30              : impl<T> Entry<T> {
      31            6 :     pub(crate) fn new(value: T, ttl: Duration) -> Self {
      32            6 :         Self {
      33            6 :             expires_at: Instant::now() + ttl,
      34            6 :             value,
      35            6 :         }
      36            6 :     }
      37              : 
      38            6 :     pub(crate) fn get(&self) -> Option<&T> {
      39            6 :         (self.expires_at > Instant::now()).then_some(&self.value)
      40            6 :     }
      41              : }
      42              : 
      43              : struct EndpointInfo {
      44              :     role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
      45              :     controls: Option<Entry<EndpointAccessControl>>,
      46              : }
      47              : 
      48              : impl EndpointInfo {
      49            5 :     pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
      50            5 :         self.role_controls.get(&role_name)?.get().cloned()
      51            5 :     }
      52              : 
      53            2 :     pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
      54            2 :         self.controls.as_ref()?.get().cloned()
      55            2 :     }
      56              : 
      57            0 :     pub(crate) fn invalidate_endpoint(&mut self) {
      58            0 :         self.controls = None;
      59            0 :     }
      60              : 
      61            0 :     pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
      62            0 :         self.role_controls.remove(&role_name);
      63            0 :     }
      64              : }
      65              : 
      66              : /// Cache for project info.
      67              : /// This is used to cache auth data for endpoints.
      68              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
      69              : ///
      70              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
      71              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
      72              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
      73              : pub struct ProjectInfoCacheImpl {
      74              :     cache: ClashMap<EndpointIdInt, EndpointInfo>,
      75              : 
      76              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
      77              :     // FIXME(stefan): we need a way to GC the account2ep map.
      78              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
      79              : 
      80              :     config: ProjectInfoCacheOptions,
      81              : }
      82              : 
      83              : #[async_trait]
      84              : impl ProjectInfoCache for ProjectInfoCacheImpl {
      85            0 :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
      86            0 :         info!("invalidating endpoint access for `{endpoint_id}`");
      87            0 :         if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
      88            0 :             endpoint_info.invalidate_endpoint();
      89            0 :         }
      90            0 :     }
      91              : 
      92            0 :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
      93            0 :         info!("invalidating endpoint access for project `{project_id}`");
      94            0 :         let endpoints = self
      95            0 :             .project2ep
      96            0 :             .get(&project_id)
      97            0 :             .map(|kv| kv.value().clone())
      98            0 :             .unwrap_or_default();
      99            0 :         for endpoint_id in endpoints {
     100            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     101            0 :                 endpoint_info.invalidate_endpoint();
     102            0 :             }
     103              :         }
     104            0 :     }
     105              : 
     106            0 :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
     107            0 :         info!("invalidating endpoint access for org `{account_id}`");
     108            0 :         let endpoints = self
     109            0 :             .account2ep
     110            0 :             .get(&account_id)
     111            0 :             .map(|kv| kv.value().clone())
     112            0 :             .unwrap_or_default();
     113            0 :         for endpoint_id in endpoints {
     114            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     115            0 :                 endpoint_info.invalidate_endpoint();
     116            0 :             }
     117              :         }
     118            0 :     }
     119              : 
     120            0 :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
     121            0 :         info!(
     122            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
     123              :             project_id, role_name,
     124              :         );
     125            0 :         let endpoints = self
     126            0 :             .project2ep
     127            0 :             .get(&project_id)
     128            0 :             .map(|kv| kv.value().clone())
     129            0 :             .unwrap_or_default();
     130            0 :         for endpoint_id in endpoints {
     131            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     132            0 :                 endpoint_info.invalidate_role_secret(role_name);
     133            0 :             }
     134              :         }
     135            0 :     }
     136              : }
     137              : 
     138              : impl ProjectInfoCacheImpl {
     139            1 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
     140            1 :         Self {
     141            1 :             cache: ClashMap::new(),
     142            1 :             project2ep: ClashMap::new(),
     143            1 :             account2ep: ClashMap::new(),
     144            1 :             config,
     145            1 :         }
     146            1 :     }
     147              : 
     148            7 :     fn get_endpoint_cache(
     149            7 :         &self,
     150            7 :         endpoint_id: &EndpointId,
     151            7 :     ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
     152            7 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     153            7 :         self.cache.get(&endpoint_id)
     154            7 :     }
     155              : 
     156            5 :     pub(crate) fn get_role_secret(
     157            5 :         &self,
     158            5 :         endpoint_id: &EndpointId,
     159            5 :         role_name: &RoleName,
     160            5 :     ) -> Option<RoleAccessControl> {
     161            5 :         let role_name = RoleNameInt::get(role_name)?;
     162            5 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     163            5 :         endpoint_info.get_role_secret(role_name)
     164            5 :     }
     165              : 
     166            2 :     pub(crate) fn get_endpoint_access(
     167            2 :         &self,
     168            2 :         endpoint_id: &EndpointId,
     169            2 :     ) -> Option<EndpointAccessControl> {
     170            2 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     171            2 :         endpoint_info.get_controls()
     172            2 :     }
     173              : 
     174            3 :     pub(crate) fn insert_endpoint_access(
     175            3 :         &self,
     176            3 :         account_id: Option<AccountIdInt>,
     177            3 :         project_id: ProjectIdInt,
     178            3 :         endpoint_id: EndpointIdInt,
     179            3 :         role_name: RoleNameInt,
     180            3 :         controls: EndpointAccessControl,
     181            3 :         role_controls: RoleAccessControl,
     182            3 :     ) {
     183            3 :         if let Some(account_id) = account_id {
     184            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     185            3 :         }
     186            3 :         self.insert_project2endpoint(project_id, endpoint_id);
     187              : 
     188            3 :         if self.cache.len() >= self.config.size {
     189              :             // If there are too many entries, wait until the next gc cycle.
     190            0 :             return;
     191            3 :         }
     192              : 
     193            3 :         let controls = Entry::new(controls, self.config.ttl);
     194            3 :         let role_controls = Entry::new(role_controls, self.config.ttl);
     195              : 
     196            3 :         match self.cache.entry(endpoint_id) {
     197            1 :             clashmap::Entry::Vacant(e) => {
     198            1 :                 e.insert(EndpointInfo {
     199            1 :                     role_controls: HashMap::from_iter([(role_name, role_controls)]),
     200            1 :                     controls: Some(controls),
     201            1 :                 });
     202            1 :             }
     203            2 :             clashmap::Entry::Occupied(mut e) => {
     204            2 :                 let ep = e.get_mut();
     205            2 :                 ep.controls = Some(controls);
     206            2 :                 if ep.role_controls.len() < self.config.max_roles {
     207            1 :                     ep.role_controls.insert(role_name, role_controls);
     208            1 :                 }
     209              :             }
     210              :         }
     211            3 :     }
     212              : 
     213            3 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     214            3 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     215            2 :             endpoints.insert(endpoint_id);
     216            2 :         } else {
     217            1 :             self.project2ep
     218            1 :                 .insert(project_id, HashSet::from([endpoint_id]));
     219            1 :         }
     220            3 :     }
     221              : 
     222            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     223            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     224            0 :             endpoints.insert(endpoint_id);
     225            0 :         } else {
     226            0 :             self.account2ep
     227            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     228            0 :         }
     229            0 :     }
     230              : 
     231            0 :     pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
     232            0 :         let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
     233            0 :             return;
     234              :         };
     235            0 :         let Some(role_name) = RoleNameInt::get(role_name) else {
     236            0 :             return;
     237              :         };
     238              : 
     239            0 :         let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
     240            0 :             return;
     241              :         };
     242              : 
     243            0 :         let entry = endpoint_info.role_controls.entry(role_name);
     244            0 :         let hash_map::Entry::Occupied(role_controls) = entry else {
     245            0 :             return;
     246              :         };
     247              : 
     248            0 :         if role_controls.get().expires_at <= Instant::now() {
     249            0 :             role_controls.remove();
     250            0 :         }
     251            0 :     }
     252              : 
     253            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     254            0 :         let mut interval =
     255            0 :             tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
     256              :         loop {
     257            0 :             interval.tick().await;
     258            0 :             if self.cache.len() < self.config.size {
     259              :                 // If there are not too many entries, wait until the next gc cycle.
     260            0 :                 continue;
     261            0 :             }
     262            0 :             self.gc();
     263              :         }
     264              :     }
     265              : 
     266            0 :     fn gc(&self) {
     267            0 :         let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
     268            0 :         debug!(shard, "project_info_cache: performing epoch reclamation");
     269              : 
     270              :         // acquire a random shard lock
     271            0 :         let mut removed = 0;
     272            0 :         let shard = self.project2ep.shards()[shard].write();
     273            0 :         for (_, endpoints) in shard.iter() {
     274            0 :             for endpoint in endpoints {
     275            0 :                 self.cache.remove(endpoint);
     276            0 :                 removed += 1;
     277            0 :             }
     278              :         }
     279              :         // We can drop this shard only after making sure that all endpoints are removed.
     280            0 :         drop(shard);
     281            0 :         info!("project_info_cache: removed {removed} endpoints");
     282            0 :     }
     283              : }
     284              : 
     285              : #[cfg(test)]
     286              : mod tests {
     287              :     use std::sync::Arc;
     288              : 
     289              :     use super::*;
     290              :     use crate::control_plane::messages::EndpointRateLimitConfig;
     291              :     use crate::control_plane::{AccessBlockerFlags, AuthSecret};
     292              :     use crate::scram::ServerSecret;
     293              :     use crate::types::ProjectId;
     294              : 
     295              :     #[tokio::test]
     296            1 :     async fn test_project_info_cache_settings() {
     297            1 :         tokio::time::pause();
     298            1 :         let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     299            1 :             size: 2,
     300            1 :             max_roles: 2,
     301            1 :             ttl: Duration::from_secs(1),
     302            1 :             gc_interval: Duration::from_secs(600),
     303            1 :         });
     304            1 :         let project_id: ProjectId = "project".into();
     305            1 :         let endpoint_id: EndpointId = "endpoint".into();
     306            1 :         let account_id: Option<AccountIdInt> = None;
     307              : 
     308            1 :         let user1: RoleName = "user1".into();
     309            1 :         let user2: RoleName = "user2".into();
     310            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     311            1 :         let secret2 = None;
     312            1 :         let allowed_ips = Arc::new(vec![
     313            1 :             "127.0.0.1".parse().unwrap(),
     314            1 :             "127.0.0.2".parse().unwrap(),
     315              :         ]);
     316              : 
     317            1 :         cache.insert_endpoint_access(
     318            1 :             account_id,
     319            1 :             (&project_id).into(),
     320            1 :             (&endpoint_id).into(),
     321            1 :             (&user1).into(),
     322            1 :             EndpointAccessControl {
     323            1 :                 allowed_ips: allowed_ips.clone(),
     324            1 :                 allowed_vpce: Arc::new(vec![]),
     325            1 :                 flags: AccessBlockerFlags::default(),
     326            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     327            1 :             },
     328            1 :             RoleAccessControl {
     329            1 :                 secret: secret1.clone(),
     330            1 :             },
     331              :         );
     332              : 
     333            1 :         cache.insert_endpoint_access(
     334            1 :             account_id,
     335            1 :             (&project_id).into(),
     336            1 :             (&endpoint_id).into(),
     337            1 :             (&user2).into(),
     338            1 :             EndpointAccessControl {
     339            1 :                 allowed_ips: allowed_ips.clone(),
     340            1 :                 allowed_vpce: Arc::new(vec![]),
     341            1 :                 flags: AccessBlockerFlags::default(),
     342            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     343            1 :             },
     344            1 :             RoleAccessControl {
     345            1 :                 secret: secret2.clone(),
     346            1 :             },
     347              :         );
     348              : 
     349            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     350            1 :         assert_eq!(cached.secret, secret1);
     351              : 
     352            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     353            1 :         assert_eq!(cached.secret, secret2);
     354              : 
     355              :         // Shouldn't add more than 2 roles.
     356            1 :         let user3: RoleName = "user3".into();
     357            1 :         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
     358              : 
     359            1 :         cache.insert_endpoint_access(
     360            1 :             account_id,
     361            1 :             (&project_id).into(),
     362            1 :             (&endpoint_id).into(),
     363            1 :             (&user3).into(),
     364            1 :             EndpointAccessControl {
     365            1 :                 allowed_ips: allowed_ips.clone(),
     366            1 :                 allowed_vpce: Arc::new(vec![]),
     367            1 :                 flags: AccessBlockerFlags::default(),
     368            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     369            1 :             },
     370            1 :             RoleAccessControl {
     371            1 :                 secret: secret3.clone(),
     372            1 :             },
     373              :         );
     374              : 
     375            1 :         assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
     376              : 
     377            1 :         let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
     378            1 :         assert_eq!(cached.allowed_ips, allowed_ips);
     379              : 
     380            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     381            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1);
     382            1 :         assert!(cached.is_none());
     383            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2);
     384            1 :         assert!(cached.is_none());
     385            1 :         let cached = cache.get_endpoint_access(&endpoint_id);
     386            1 :         assert!(cached.is_none());
     387            1 :     }
     388              : }
        

Generated by: LCOV version 2.1-beta