LCOV - code coverage report
Current view: top level - proxy/src/cache - project_info.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 80.1 % 311 249
Test Date: 2025-07-31 15:59:03 Functions: 57.7 % 26 15

            Line data    Source code
       1              : use std::collections::HashSet;
       2              : use std::convert::Infallible;
       3              : 
       4              : use clashmap::ClashMap;
       5              : use moka::sync::Cache;
       6              : use tracing::{debug, info};
       7              : 
       8              : use crate::cache::common::{
       9              :     ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener,
      10              : };
      11              : use crate::config::ProjectInfoCacheOptions;
      12              : use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
      13              : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
      14              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      15              : use crate::metrics::{CacheKind, Metrics};
      16              : use crate::types::{EndpointId, RoleName};
      17              : 
      18              : /// Cache for project info.
      19              : /// This is used to cache auth data for endpoints.
      20              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
      21              : ///
      22              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
      23              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
      24              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
      25              : pub struct ProjectInfoCache {
      26              :     role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
      27              :     ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,
      28              : 
      29              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
      30              :     // FIXME(stefan): we need a way to GC the account2ep map.
      31              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
      32              : 
      33              :     config: ProjectInfoCacheOptions,
      34              : }
      35              : 
      36              : impl ProjectInfoCache {
      37            0 :     pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
      38            0 :         info!("invalidating endpoint access for `{endpoint_id}`");
      39            0 :         self.ep_controls.invalidate(&endpoint_id);
      40            0 :     }
      41              : 
      42            0 :     pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
      43            0 :         info!("invalidating endpoint access for project `{project_id}`");
      44            0 :         let endpoints = self
      45            0 :             .project2ep
      46            0 :             .get(&project_id)
      47            0 :             .map(|kv| kv.value().clone())
      48            0 :             .unwrap_or_default();
      49            0 :         for endpoint_id in endpoints {
      50            0 :             self.ep_controls.invalidate(&endpoint_id);
      51            0 :         }
      52            0 :     }
      53              : 
      54            0 :     pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
      55            0 :         info!("invalidating endpoint access for org `{account_id}`");
      56            0 :         let endpoints = self
      57            0 :             .account2ep
      58            0 :             .get(&account_id)
      59            0 :             .map(|kv| kv.value().clone())
      60            0 :             .unwrap_or_default();
      61            0 :         for endpoint_id in endpoints {
      62            0 :             self.ep_controls.invalidate(&endpoint_id);
      63            0 :         }
      64            0 :     }
      65              : 
      66            0 :     pub fn invalidate_role_secret_for_project(
      67            0 :         &self,
      68            0 :         project_id: ProjectIdInt,
      69            0 :         role_name: RoleNameInt,
      70            0 :     ) {
      71            0 :         info!(
      72            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
      73              :             project_id, role_name,
      74              :         );
      75            0 :         let endpoints = self
      76            0 :             .project2ep
      77            0 :             .get(&project_id)
      78            0 :             .map(|kv| kv.value().clone())
      79            0 :             .unwrap_or_default();
      80            0 :         for endpoint_id in endpoints {
      81            0 :             self.role_controls.invalidate(&(endpoint_id, role_name));
      82            0 :         }
      83            0 :     }
      84              : }
      85              : 
      86              : impl ProjectInfoCache {
      87            2 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
      88            2 :         Metrics::get().cache.capacity.set(
      89            2 :             CacheKind::ProjectInfoRoles,
      90            2 :             (config.size * config.max_roles) as i64,
      91              :         );
      92            2 :         Metrics::get()
      93            2 :             .cache
      94            2 :             .capacity
      95            2 :             .set(CacheKind::ProjectInfoEndpoints, config.size as i64);
      96              : 
      97              :         // we cache errors for 30 seconds, unless retry_at is set.
      98            2 :         let expiry = CplaneExpiry::default();
      99              :         Self {
     100            2 :             role_controls: Cache::builder()
     101            2 :                 .name("project_info_roles")
     102            5 :                 .eviction_listener(|_k, _v, cause| {
     103            5 :                     eviction_listener(CacheKind::ProjectInfoRoles, cause);
     104            5 :                 })
     105            2 :                 .max_capacity(config.size * config.max_roles)
     106            2 :                 .time_to_live(config.ttl)
     107            2 :                 .expire_after(expiry)
     108            2 :                 .build(),
     109            2 :             ep_controls: Cache::builder()
     110            2 :                 .name("project_info_endpoints")
     111            3 :                 .eviction_listener(|_k, _v, cause| {
     112            3 :                     eviction_listener(CacheKind::ProjectInfoEndpoints, cause);
     113            3 :                 })
     114            2 :                 .max_capacity(config.size)
     115            2 :                 .time_to_live(config.ttl)
     116            2 :                 .expire_after(expiry)
     117            2 :                 .build(),
     118            2 :             project2ep: ClashMap::new(),
     119            2 :             account2ep: ClashMap::new(),
     120            2 :             config,
     121              :         }
     122            2 :     }
     123              : 
     124            8 :     pub(crate) fn get_role_secret(
     125            8 :         &self,
     126            8 :         endpoint_id: &EndpointId,
     127            8 :         role_name: &RoleName,
     128            8 :     ) -> Option<ControlPlaneResult<RoleAccessControl>> {
     129            8 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     130            8 :         let role_name = RoleNameInt::get(role_name)?;
     131              : 
     132            7 :         count_cache_outcome(
     133            7 :             CacheKind::ProjectInfoRoles,
     134            7 :             self.role_controls.get(&(endpoint_id, role_name)),
     135              :         )
     136            8 :     }
     137              : 
     138            4 :     pub(crate) fn get_endpoint_access(
     139            4 :         &self,
     140            4 :         endpoint_id: &EndpointId,
     141            4 :     ) -> Option<ControlPlaneResult<EndpointAccessControl>> {
     142            4 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     143              : 
     144            4 :         count_cache_outcome(
     145            4 :             CacheKind::ProjectInfoEndpoints,
     146            4 :             self.ep_controls.get(&endpoint_id),
     147              :         )
     148            4 :     }
     149              : 
     150            4 :     pub(crate) fn insert_endpoint_access(
     151            4 :         &self,
     152            4 :         account_id: Option<AccountIdInt>,
     153            4 :         project_id: Option<ProjectIdInt>,
     154            4 :         endpoint_id: EndpointIdInt,
     155            4 :         role_name: RoleNameInt,
     156            4 :         controls: EndpointAccessControl,
     157            4 :         role_controls: RoleAccessControl,
     158            4 :     ) {
     159            4 :         if let Some(account_id) = account_id {
     160            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     161            4 :         }
     162            4 :         if let Some(project_id) = project_id {
     163            4 :             self.insert_project2endpoint(project_id, endpoint_id);
     164            4 :         }
     165              : 
     166            4 :         debug!(
     167            0 :             key = &*endpoint_id,
     168            0 :             "created a cache entry for endpoint access"
     169              :         );
     170              : 
     171            4 :         count_cache_insert(CacheKind::ProjectInfoEndpoints);
     172            4 :         count_cache_insert(CacheKind::ProjectInfoRoles);
     173              : 
     174            4 :         self.ep_controls.insert(endpoint_id, Ok(controls));
     175            4 :         self.role_controls
     176            4 :             .insert((endpoint_id, role_name), Ok(role_controls));
     177            4 :     }
     178              : 
     179            3 :     pub(crate) fn insert_endpoint_access_err(
     180            3 :         &self,
     181            3 :         endpoint_id: EndpointIdInt,
     182            3 :         role_name: RoleNameInt,
     183            3 :         msg: Box<ControlPlaneErrorMessage>,
     184            3 :     ) {
     185            3 :         debug!(
     186            0 :             key = &*endpoint_id,
     187            0 :             "created a cache entry for an endpoint access error"
     188              :         );
     189              : 
     190              :         // RoleProtected is the only role-specific error that control plane can give us.
     191              :         // If a given role name does not exist, it still returns a successful response,
     192              :         // just with an empty secret.
     193            3 :         if msg.get_reason() != Reason::RoleProtected {
     194              :             // We can cache all the other errors in ep_controls because they don't
     195              :             // depend on what role name we pass to control plane.
     196            2 :             self.ep_controls
     197            2 :                 .entry(endpoint_id)
     198            2 :                 .and_compute_with(|entry| match entry {
     199              :                     // leave the entry alone if it's already Ok
     200            1 :                     Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
     201              :                     // replace the entry
     202              :                     _ => {
     203            1 :                         count_cache_insert(CacheKind::ProjectInfoEndpoints);
     204            1 :                         moka::ops::compute::Op::Put(Err(msg.clone()))
     205              :                     }
     206            2 :                 });
     207            1 :         }
     208              : 
     209            3 :         count_cache_insert(CacheKind::ProjectInfoRoles);
     210            3 :         self.role_controls
     211            3 :             .insert((endpoint_id, role_name), Err(msg));
     212            3 :     }
     213              : 
     214            4 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     215            4 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     216            2 :             endpoints.insert(endpoint_id);
     217            2 :         } else {
     218            2 :             self.project2ep
     219            2 :                 .insert(project_id, HashSet::from([endpoint_id]));
     220            2 :         }
     221            4 :     }
     222              : 
     223            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     224            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     225            0 :             endpoints.insert(endpoint_id);
     226            0 :         } else {
     227            0 :             self.account2ep
     228            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     229            0 :         }
     230            0 :     }
     231              : 
     232            0 :     pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
     233              :         // TODO: Expire the value early if the key is idle.
     234              :         // Currently not an issue as we would just use the TTL to decide, which is what already happens.
     235            0 :     }
     236              : 
     237            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     238            0 :         let mut interval = tokio::time::interval(self.config.gc_interval);
     239              :         loop {
     240            0 :             interval.tick().await;
     241            0 :             self.ep_controls.run_pending_tasks();
     242            0 :             self.role_controls.run_pending_tasks();
     243              :         }
     244              :     }
     245              : }
     246              : 
     247              : #[cfg(test)]
     248              : mod tests {
     249              :     use std::sync::Arc;
     250              :     use std::time::Duration;
     251              : 
     252              :     use super::*;
     253              :     use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
     254              :     use crate::control_plane::{AccessBlockerFlags, AuthSecret};
     255              :     use crate::scram::ServerSecret;
     256              : 
     257              :     #[tokio::test]
     258            1 :     async fn test_project_info_cache_settings() {
     259            1 :         let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
     260            1 :             size: 1,
     261            1 :             max_roles: 2,
     262            1 :             ttl: Duration::from_secs(1),
     263            1 :             gc_interval: Duration::from_secs(600),
     264            1 :         });
     265            1 :         let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
     266            1 :         let endpoint_id: EndpointId = "endpoint".into();
     267            1 :         let account_id = None;
     268              : 
     269            1 :         let user1: RoleName = "user1".into();
     270            1 :         let user2: RoleName = "user2".into();
     271            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     272            1 :         let secret2 = None;
     273            1 :         let allowed_ips = Arc::new(vec![
     274            1 :             "127.0.0.1".parse().unwrap(),
     275            1 :             "127.0.0.2".parse().unwrap(),
     276              :         ]);
     277              : 
     278            1 :         cache.insert_endpoint_access(
     279            1 :             account_id,
     280            1 :             project_id,
     281            1 :             (&endpoint_id).into(),
     282            1 :             (&user1).into(),
     283            1 :             EndpointAccessControl {
     284            1 :                 allowed_ips: allowed_ips.clone(),
     285            1 :                 allowed_vpce: Arc::new(vec![]),
     286            1 :                 flags: AccessBlockerFlags::default(),
     287            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     288            1 :             },
     289            1 :             RoleAccessControl {
     290            1 :                 secret: secret1.clone(),
     291            1 :             },
     292              :         );
     293              : 
     294            1 :         cache.insert_endpoint_access(
     295            1 :             account_id,
     296            1 :             project_id,
     297            1 :             (&endpoint_id).into(),
     298            1 :             (&user2).into(),
     299            1 :             EndpointAccessControl {
     300            1 :                 allowed_ips: allowed_ips.clone(),
     301            1 :                 allowed_vpce: Arc::new(vec![]),
     302            1 :                 flags: AccessBlockerFlags::default(),
     303            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     304            1 :             },
     305            1 :             RoleAccessControl {
     306            1 :                 secret: secret2.clone(),
     307            1 :             },
     308              :         );
     309              : 
     310            1 :         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
     311            1 :         assert_eq!(cached.unwrap().secret, secret1);
     312              : 
     313            1 :         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
     314            1 :         assert_eq!(cached.unwrap().secret, secret2);
     315              : 
     316              :         // Shouldn't add more than 2 roles.
     317            1 :         let user3: RoleName = "user3".into();
     318            1 :         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
     319              : 
     320            1 :         cache.role_controls.run_pending_tasks();
     321            1 :         cache.insert_endpoint_access(
     322            1 :             account_id,
     323            1 :             project_id,
     324            1 :             (&endpoint_id).into(),
     325            1 :             (&user3).into(),
     326            1 :             EndpointAccessControl {
     327            1 :                 allowed_ips: allowed_ips.clone(),
     328            1 :                 allowed_vpce: Arc::new(vec![]),
     329            1 :                 flags: AccessBlockerFlags::default(),
     330            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     331            1 :             },
     332            1 :             RoleAccessControl {
     333            1 :                 secret: secret3.clone(),
     334            1 :             },
     335              :         );
     336              : 
     337            1 :         cache.role_controls.run_pending_tasks();
     338            1 :         assert_eq!(cache.role_controls.entry_count(), 2);
     339              : 
     340            1 :         tokio::time::sleep(Duration::from_secs(2)).await;
     341              : 
     342            1 :         cache.role_controls.run_pending_tasks();
     343            1 :         assert_eq!(cache.role_controls.entry_count(), 0);
     344            1 :     }
     345              : 
     346              :     #[tokio::test]
     347            1 :     async fn test_caching_project_info_errors() {
     348            1 :         let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
     349            1 :             size: 10,
     350            1 :             max_roles: 10,
     351            1 :             ttl: Duration::from_secs(1),
     352            1 :             gc_interval: Duration::from_secs(600),
     353            1 :         });
     354            1 :         let project_id = Some(ProjectIdInt::from(&"project".into()));
     355            1 :         let endpoint_id: EndpointId = "endpoint".into();
     356            1 :         let account_id = None;
     357              : 
     358            1 :         let user1: RoleName = "user1".into();
     359            1 :         let user2: RoleName = "user2".into();
     360            1 :         let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     361              : 
     362            1 :         let role_msg = Box::new(ControlPlaneErrorMessage {
     363            1 :             error: "role is protected and cannot be used for password-based authentication"
     364            1 :                 .to_owned()
     365            1 :                 .into_boxed_str(),
     366            1 :             http_status_code: http::StatusCode::NOT_FOUND,
     367            1 :             status: Some(Status {
     368            1 :                 code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
     369            1 :                 message: "role is protected and cannot be used for password-based authentication"
     370            1 :                     .to_owned()
     371            1 :                     .into_boxed_str(),
     372            1 :                 details: Details {
     373            1 :                     error_info: Some(ErrorInfo {
     374            1 :                         reason: Reason::RoleProtected,
     375            1 :                     }),
     376            1 :                     retry_info: None,
     377            1 :                     user_facing_message: None,
     378            1 :                 },
     379            1 :             }),
     380            1 :         });
     381              : 
     382            1 :         let generic_msg = Box::new(ControlPlaneErrorMessage {
     383            1 :             error: "oh noes".to_owned().into_boxed_str(),
     384            1 :             http_status_code: http::StatusCode::NOT_FOUND,
     385            1 :             status: None,
     386            1 :         });
     387              : 
     388            1 :         let get_role_secret =
     389            5 :             |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
     390            3 :         let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();
     391              : 
     392              :         // stores role-specific errors only for get_role_secret
     393            1 :         cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
     394            1 :         assert_eq!(
     395            1 :             get_role_secret(&endpoint_id, &user1).unwrap_err().error,
     396              :             role_msg.error
     397              :         );
     398            1 :         assert!(cache.get_endpoint_access(&endpoint_id).is_none());
     399              : 
     400              :         // stores non-role specific errors for both get_role_secret and get_endpoint_access
     401            1 :         cache.insert_endpoint_access_err(
     402            1 :             (&endpoint_id).into(),
     403            1 :             (&user1).into(),
     404            1 :             generic_msg.clone(),
     405              :         );
     406            1 :         assert_eq!(
     407            1 :             get_role_secret(&endpoint_id, &user1).unwrap_err().error,
     408              :             generic_msg.error
     409              :         );
     410            1 :         assert_eq!(
     411            1 :             get_endpoint_access(&endpoint_id).unwrap_err().error,
     412              :             generic_msg.error
     413              :         );
     414              : 
     415              :         // error isn't returned for other roles in the same endpoint
     416            1 :         assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
     417              : 
     418              :         // success for a role does not overwrite errors for other roles
     419            1 :         cache.insert_endpoint_access(
     420            1 :             account_id,
     421            1 :             project_id,
     422            1 :             (&endpoint_id).into(),
     423            1 :             (&user2).into(),
     424            1 :             EndpointAccessControl {
     425            1 :                 allowed_ips: Arc::new(vec![]),
     426            1 :                 allowed_vpce: Arc::new(vec![]),
     427            1 :                 flags: AccessBlockerFlags::default(),
     428            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     429            1 :             },
     430            1 :             RoleAccessControl {
     431            1 :                 secret: secret.clone(),
     432            1 :             },
     433              :         );
     434            1 :         assert!(get_role_secret(&endpoint_id, &user1).is_err());
     435            1 :         assert!(get_role_secret(&endpoint_id, &user2).is_ok());
     436              :         // ...but does clear the access control error
     437            1 :         assert!(get_endpoint_access(&endpoint_id).is_ok());
     438              : 
     439              :         // storing an error does not overwrite successful access control response
     440            1 :         cache.insert_endpoint_access_err(
     441            1 :             (&endpoint_id).into(),
     442            1 :             (&user2).into(),
     443            1 :             generic_msg.clone(),
     444              :         );
     445            1 :         assert!(get_role_secret(&endpoint_id, &user2).is_err());
     446            1 :         assert!(get_endpoint_access(&endpoint_id).is_ok());
     447            1 :     }
     448              : }
        

Generated by: LCOV version 2.1-beta