LCOV - code coverage report
Current view: top level - proxy/src/cache - project_info.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 76.1 % 415 316
Test Date: 2025-07-22 17:50:06 Functions: 60.0 % 35 21

            Line data    Source code
       1              : use std::collections::{HashMap, HashSet, hash_map};
       2              : use std::convert::Infallible;
       3              : use std::time::Duration;
       4              : 
       5              : use async_trait::async_trait;
       6              : use clashmap::ClashMap;
       7              : use clashmap::mapref::one::Ref;
       8              : use rand::Rng;
       9              : use tokio::time::Instant;
      10              : use tracing::{debug, info};
      11              : 
      12              : use crate::config::ProjectInfoCacheOptions;
      13              : use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
      14              : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
      15              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      16              : use crate::types::{EndpointId, RoleName};
      17              : 
      18              : #[async_trait]
      19              : pub(crate) trait ProjectInfoCache {
      20              :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
      21              :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
      22              :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
      23              :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
      24              : }
      25              : 
      26              : struct Entry<T> {
      27              :     expires_at: Instant,
      28              :     value: T,
      29              : }
      30              : 
      31              : impl<T> Entry<T> {
      32           13 :     pub(crate) fn new(value: T, ttl: Duration) -> Self {
      33           13 :         Self {
      34           13 :             expires_at: Instant::now() + ttl,
      35           13 :             value,
      36           13 :         }
      37           13 :     }
      38              : 
      39           14 :     pub(crate) fn get(&self) -> Option<&T> {
      40           14 :         (!self.is_expired()).then_some(&self.value)
      41           14 :     }
      42              : 
      43           15 :     fn is_expired(&self) -> bool {
      44           15 :         self.expires_at <= Instant::now()
      45           15 :     }
      46              : }
      47              : 
      48              : struct EndpointInfo {
      49              :     role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
      50              :     controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
      51              : }
      52              : 
      53              : type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
      54              : 
      55              : impl EndpointInfo {
      56           10 :     pub(crate) fn get_role_secret_with_ttl(
      57           10 :         &self,
      58           10 :         role_name: RoleNameInt,
      59           10 :     ) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
      60           10 :         let entry = self.role_controls.get(&role_name)?;
      61            9 :         let ttl = entry.expires_at - Instant::now();
      62            9 :         Some((entry.get()?.clone(), ttl))
      63           10 :     }
      64              : 
      65            6 :     pub(crate) fn get_controls_with_ttl(
      66            6 :         &self,
      67            6 :     ) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
      68            6 :         let entry = self.controls.as_ref()?;
      69            5 :         let ttl = entry.expires_at - Instant::now();
      70            5 :         Some((entry.get()?.clone(), ttl))
      71            6 :     }
      72              : 
      73            0 :     pub(crate) fn invalidate_endpoint(&mut self) {
      74            0 :         self.controls = None;
      75            0 :     }
      76              : 
      77            0 :     pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
      78            0 :         self.role_controls.remove(&role_name);
      79            0 :     }
      80              : }
      81              : 
      82              : /// Cache for project info.
      83              : /// This is used to cache auth data for endpoints.
      84              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
      85              : ///
      86              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
      87              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
      88              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
      89              : pub struct ProjectInfoCacheImpl {
      90              :     cache: ClashMap<EndpointIdInt, EndpointInfo>,
      91              : 
      92              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
      93              :     // FIXME(stefan): we need a way to GC the account2ep map.
      94              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
      95              : 
      96              :     config: ProjectInfoCacheOptions,
      97              : }
      98              : 
      99              : #[async_trait]
     100              : impl ProjectInfoCache for ProjectInfoCacheImpl {
     101            0 :     fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
     102            0 :         info!("invalidating endpoint access for `{endpoint_id}`");
     103            0 :         if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     104            0 :             endpoint_info.invalidate_endpoint();
     105            0 :         }
     106            0 :     }
     107              : 
     108            0 :     fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
     109            0 :         info!("invalidating endpoint access for project `{project_id}`");
     110            0 :         let endpoints = self
     111            0 :             .project2ep
     112            0 :             .get(&project_id)
     113            0 :             .map(|kv| kv.value().clone())
     114            0 :             .unwrap_or_default();
     115            0 :         for endpoint_id in endpoints {
     116            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     117            0 :                 endpoint_info.invalidate_endpoint();
     118            0 :             }
     119              :         }
     120            0 :     }
     121              : 
     122            0 :     fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
     123            0 :         info!("invalidating endpoint access for org `{account_id}`");
     124            0 :         let endpoints = self
     125            0 :             .account2ep
     126            0 :             .get(&account_id)
     127            0 :             .map(|kv| kv.value().clone())
     128            0 :             .unwrap_or_default();
     129            0 :         for endpoint_id in endpoints {
     130            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     131            0 :                 endpoint_info.invalidate_endpoint();
     132            0 :             }
     133              :         }
     134            0 :     }
     135              : 
     136            0 :     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
     137            0 :         info!(
     138            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
     139              :             project_id, role_name,
     140              :         );
     141            0 :         let endpoints = self
     142            0 :             .project2ep
     143            0 :             .get(&project_id)
     144            0 :             .map(|kv| kv.value().clone())
     145            0 :             .unwrap_or_default();
     146            0 :         for endpoint_id in endpoints {
     147            0 :             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
     148            0 :                 endpoint_info.invalidate_role_secret(role_name);
     149            0 :             }
     150              :         }
     151            0 :     }
     152              : }
     153              : 
     154              : impl ProjectInfoCacheImpl {
     155            2 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
     156            2 :         Self {
     157            2 :             cache: ClashMap::new(),
     158            2 :             project2ep: ClashMap::new(),
     159            2 :             account2ep: ClashMap::new(),
     160            2 :             config,
     161            2 :         }
     162            2 :     }
     163              : 
     164           16 :     fn get_endpoint_cache(
     165           16 :         &self,
     166           16 :         endpoint_id: &EndpointId,
     167           16 :     ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
     168           16 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     169           16 :         self.cache.get(&endpoint_id)
     170           16 :     }
     171              : 
     172           11 :     pub(crate) fn get_role_secret_with_ttl(
     173           11 :         &self,
     174           11 :         endpoint_id: &EndpointId,
     175           11 :         role_name: &RoleName,
     176           11 :     ) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
     177           11 :         let role_name = RoleNameInt::get(role_name)?;
     178           10 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     179           10 :         endpoint_info.get_role_secret_with_ttl(role_name)
     180           11 :     }
     181              : 
     182            6 :     pub(crate) fn get_endpoint_access_with_ttl(
     183            6 :         &self,
     184            6 :         endpoint_id: &EndpointId,
     185            6 :     ) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
     186            6 :         let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
     187            6 :         endpoint_info.get_controls_with_ttl()
     188            6 :     }
     189              : 
     190            4 :     pub(crate) fn insert_endpoint_access(
     191            4 :         &self,
     192            4 :         account_id: Option<AccountIdInt>,
     193            4 :         project_id: Option<ProjectIdInt>,
     194            4 :         endpoint_id: EndpointIdInt,
     195            4 :         role_name: RoleNameInt,
     196            4 :         controls: EndpointAccessControl,
     197            4 :         role_controls: RoleAccessControl,
     198            4 :     ) {
     199            4 :         if let Some(account_id) = account_id {
     200            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     201            4 :         }
     202            4 :         if let Some(project_id) = project_id {
     203            4 :             self.insert_project2endpoint(project_id, endpoint_id);
     204            4 :         }
     205              : 
     206            4 :         if self.cache.len() >= self.config.size {
     207              :             // If there are too many entries, wait until the next gc cycle.
     208            0 :             return;
     209            4 :         }
     210              : 
     211            4 :         debug!(
     212            0 :             key = &*endpoint_id,
     213            0 :             "created a cache entry for endpoint access"
     214              :         );
     215              : 
     216            4 :         let controls = Some(Entry::new(Ok(controls), self.config.ttl));
     217            4 :         let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
     218              : 
     219            4 :         match self.cache.entry(endpoint_id) {
     220            1 :             clashmap::Entry::Vacant(e) => {
     221            1 :                 e.insert(EndpointInfo {
     222            1 :                     role_controls: HashMap::from_iter([(role_name, role_controls)]),
     223            1 :                     controls,
     224            1 :                 });
     225            1 :             }
     226            3 :             clashmap::Entry::Occupied(mut e) => {
     227            3 :                 let ep = e.get_mut();
     228            3 :                 ep.controls = controls;
     229            3 :                 if ep.role_controls.len() < self.config.max_roles {
     230            2 :                     ep.role_controls.insert(role_name, role_controls);
     231            2 :                 }
     232              :             }
     233              :         }
     234            4 :     }
     235              : 
     236            3 :     pub(crate) fn insert_endpoint_access_err(
     237            3 :         &self,
     238            3 :         endpoint_id: EndpointIdInt,
     239            3 :         role_name: RoleNameInt,
     240            3 :         msg: Box<ControlPlaneErrorMessage>,
     241            3 :         ttl: Option<Duration>,
     242            3 :     ) {
     243            3 :         if self.cache.len() >= self.config.size {
     244              :             // If there are too many entries, wait until the next gc cycle.
     245            0 :             return;
     246            3 :         }
     247              : 
     248            3 :         debug!(
     249            0 :             key = &*endpoint_id,
     250            0 :             "created a cache entry for an endpoint access error"
     251              :         );
     252              : 
     253            3 :         let ttl = ttl.unwrap_or(self.config.ttl);
     254              : 
     255            3 :         let controls = if msg.get_reason() == Reason::RoleProtected {
     256              :             // RoleProtected is the only role-specific error that control plane can give us.
     257              :             // If a given role name does not exist, it still returns a successful response,
     258              :             // just with an empty secret.
     259            1 :             None
     260              :         } else {
     261              :             // We can cache all the other errors in EndpointInfo.controls,
     262              :             // because they don't depend on what role name we pass to control plane.
     263            2 :             Some(Entry::new(Err(msg.clone()), ttl))
     264              :         };
     265              : 
     266            3 :         let role_controls = Entry::new(Err(msg), ttl);
     267              : 
     268            3 :         match self.cache.entry(endpoint_id) {
     269            1 :             clashmap::Entry::Vacant(e) => {
     270            1 :                 e.insert(EndpointInfo {
     271            1 :                     role_controls: HashMap::from_iter([(role_name, role_controls)]),
     272            1 :                     controls,
     273            1 :                 });
     274            1 :             }
     275            2 :             clashmap::Entry::Occupied(mut e) => {
     276            2 :                 let ep = e.get_mut();
     277            2 :                 if let Some(entry) = &ep.controls
     278            1 :                     && !entry.is_expired()
     279            1 :                     && entry.value.is_ok()
     280            1 :                 {
     281            1 :                     // If we have cached non-expired, non-error controls, keep them.
     282            1 :                 } else {
     283            1 :                     ep.controls = controls;
     284            1 :                 }
     285            2 :                 if ep.role_controls.len() < self.config.max_roles {
     286            2 :                     ep.role_controls.insert(role_name, role_controls);
     287            2 :                 }
     288              :             }
     289              :         }
     290            3 :     }
     291              : 
     292            4 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     293            4 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     294            2 :             endpoints.insert(endpoint_id);
     295            2 :         } else {
     296            2 :             self.project2ep
     297            2 :                 .insert(project_id, HashSet::from([endpoint_id]));
     298            2 :         }
     299            4 :     }
     300              : 
     301            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     302            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     303            0 :             endpoints.insert(endpoint_id);
     304            0 :         } else {
     305            0 :             self.account2ep
     306            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     307            0 :         }
     308            0 :     }
     309              : 
     310            0 :     pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
     311            0 :         let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
     312            0 :             return;
     313              :         };
     314            0 :         let Some(role_name) = RoleNameInt::get(role_name) else {
     315            0 :             return;
     316              :         };
     317              : 
     318            0 :         let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
     319            0 :             return;
     320              :         };
     321              : 
     322            0 :         let entry = endpoint_info.role_controls.entry(role_name);
     323            0 :         let hash_map::Entry::Occupied(role_controls) = entry else {
     324            0 :             return;
     325              :         };
     326              : 
     327            0 :         if role_controls.get().is_expired() {
     328            0 :             role_controls.remove();
     329            0 :         }
     330            0 :     }
     331              : 
     332            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     333            0 :         let mut interval =
     334            0 :             tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
     335              :         loop {
     336            0 :             interval.tick().await;
     337            0 :             if self.cache.len() < self.config.size {
     338              :                 // If there are not too many entries, wait until the next gc cycle.
     339            0 :                 continue;
     340            0 :             }
     341            0 :             self.gc();
     342              :         }
     343              :     }
     344              : 
     345            0 :     fn gc(&self) {
     346            0 :         let shard = rand::rng().random_range(0..self.project2ep.shards().len());
     347            0 :         debug!(shard, "project_info_cache: performing epoch reclamation");
     348              : 
     349              :         // acquire a random shard lock
     350            0 :         let mut removed = 0;
     351            0 :         let shard = self.project2ep.shards()[shard].write();
     352            0 :         for (_, endpoints) in shard.iter() {
     353            0 :             for endpoint in endpoints {
     354            0 :                 self.cache.remove(endpoint);
     355            0 :                 removed += 1;
     356            0 :             }
     357              :         }
     358              :         // We can drop this shard only after making sure that all endpoints are removed.
     359            0 :         drop(shard);
     360            0 :         info!("project_info_cache: removed {removed} endpoints");
     361            0 :     }
     362              : }
     363              : 
     364              : #[cfg(test)]
     365              : mod tests {
     366              :     use super::*;
     367              :     use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
     368              :     use crate::control_plane::{AccessBlockerFlags, AuthSecret};
     369              :     use crate::scram::ServerSecret;
     370              :     use std::sync::Arc;
     371              : 
     372              :     #[tokio::test]
     373            1 :     async fn test_project_info_cache_settings() {
     374            1 :         tokio::time::pause();
     375            1 :         let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     376            1 :             size: 2,
     377            1 :             max_roles: 2,
     378            1 :             ttl: Duration::from_secs(1),
     379            1 :             gc_interval: Duration::from_secs(600),
     380            1 :         });
     381            1 :         let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
     382            1 :         let endpoint_id: EndpointId = "endpoint".into();
     383            1 :         let account_id = None;
     384              : 
     385            1 :         let user1: RoleName = "user1".into();
     386            1 :         let user2: RoleName = "user2".into();
     387            1 :         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     388            1 :         let secret2 = None;
     389            1 :         let allowed_ips = Arc::new(vec![
     390            1 :             "127.0.0.1".parse().unwrap(),
     391            1 :             "127.0.0.2".parse().unwrap(),
     392              :         ]);
     393              : 
     394            1 :         cache.insert_endpoint_access(
     395            1 :             account_id,
     396            1 :             project_id,
     397            1 :             (&endpoint_id).into(),
     398            1 :             (&user1).into(),
     399            1 :             EndpointAccessControl {
     400            1 :                 allowed_ips: allowed_ips.clone(),
     401            1 :                 allowed_vpce: Arc::new(vec![]),
     402            1 :                 flags: AccessBlockerFlags::default(),
     403            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     404            1 :             },
     405            1 :             RoleAccessControl {
     406            1 :                 secret: secret1.clone(),
     407            1 :             },
     408              :         );
     409              : 
     410            1 :         cache.insert_endpoint_access(
     411            1 :             account_id,
     412            1 :             project_id,
     413            1 :             (&endpoint_id).into(),
     414            1 :             (&user2).into(),
     415            1 :             EndpointAccessControl {
     416            1 :                 allowed_ips: allowed_ips.clone(),
     417            1 :                 allowed_vpce: Arc::new(vec![]),
     418            1 :                 flags: AccessBlockerFlags::default(),
     419            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     420            1 :             },
     421            1 :             RoleAccessControl {
     422            1 :                 secret: secret2.clone(),
     423            1 :             },
     424              :         );
     425              : 
     426            1 :         let (cached, ttl) = cache
     427            1 :             .get_role_secret_with_ttl(&endpoint_id, &user1)
     428            1 :             .unwrap();
     429            1 :         assert_eq!(cached.unwrap().secret, secret1);
     430            1 :         assert_eq!(ttl, cache.config.ttl);
     431              : 
     432            1 :         let (cached, ttl) = cache
     433            1 :             .get_role_secret_with_ttl(&endpoint_id, &user2)
     434            1 :             .unwrap();
     435            1 :         assert_eq!(cached.unwrap().secret, secret2);
     436            1 :         assert_eq!(ttl, cache.config.ttl);
     437              : 
     438              :         // Shouldn't add more than 2 roles.
     439            1 :         let user3: RoleName = "user3".into();
     440            1 :         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
     441              : 
     442            1 :         cache.insert_endpoint_access(
     443            1 :             account_id,
     444            1 :             project_id,
     445            1 :             (&endpoint_id).into(),
     446            1 :             (&user3).into(),
     447            1 :             EndpointAccessControl {
     448            1 :                 allowed_ips: allowed_ips.clone(),
     449            1 :                 allowed_vpce: Arc::new(vec![]),
     450            1 :                 flags: AccessBlockerFlags::default(),
     451            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     452            1 :             },
     453            1 :             RoleAccessControl {
     454            1 :                 secret: secret3.clone(),
     455            1 :             },
     456              :         );
     457              : 
     458            1 :         assert!(
     459            1 :             cache
     460            1 :                 .get_role_secret_with_ttl(&endpoint_id, &user3)
     461            1 :                 .is_none()
     462              :         );
     463              : 
     464            1 :         let cached = cache
     465            1 :             .get_endpoint_access_with_ttl(&endpoint_id)
     466            1 :             .unwrap()
     467            1 :             .0
     468            1 :             .unwrap();
     469            1 :         assert_eq!(cached.allowed_ips, allowed_ips);
     470              : 
     471            1 :         tokio::time::advance(Duration::from_secs(2)).await;
     472            1 :         let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
     473            1 :         assert!(cached.is_none());
     474            1 :         let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
     475            1 :         assert!(cached.is_none());
     476            1 :         let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
     477            1 :         assert!(cached.is_none());
     478            1 :     }
     479              : 
     480              :     #[tokio::test]
     481            1 :     async fn test_caching_project_info_errors() {
     482            1 :         let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
     483            1 :             size: 10,
     484            1 :             max_roles: 10,
     485            1 :             ttl: Duration::from_secs(1),
     486            1 :             gc_interval: Duration::from_secs(600),
     487            1 :         });
     488            1 :         let project_id = Some(ProjectIdInt::from(&"project".into()));
     489            1 :         let endpoint_id: EndpointId = "endpoint".into();
     490            1 :         let account_id = None;
     491              : 
     492            1 :         let user1: RoleName = "user1".into();
     493            1 :         let user2: RoleName = "user2".into();
     494            1 :         let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
     495              : 
     496            1 :         let role_msg = Box::new(ControlPlaneErrorMessage {
     497            1 :             error: "role is protected and cannot be used for password-based authentication"
     498            1 :                 .to_owned()
     499            1 :                 .into_boxed_str(),
     500            1 :             http_status_code: http::StatusCode::NOT_FOUND,
     501            1 :             status: Some(Status {
     502            1 :                 code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
     503            1 :                 message: "role is protected and cannot be used for password-based authentication"
     504            1 :                     .to_owned()
     505            1 :                     .into_boxed_str(),
     506            1 :                 details: Details {
     507            1 :                     error_info: Some(ErrorInfo {
     508            1 :                         reason: Reason::RoleProtected,
     509            1 :                     }),
     510            1 :                     retry_info: None,
     511            1 :                     user_facing_message: None,
     512            1 :                 },
     513            1 :             }),
     514            1 :         });
     515              : 
     516            1 :         let generic_msg = Box::new(ControlPlaneErrorMessage {
     517            1 :             error: "oh noes".to_owned().into_boxed_str(),
     518            1 :             http_status_code: http::StatusCode::NOT_FOUND,
     519            1 :             status: None,
     520            1 :         });
     521              : 
     522            5 :         let get_role_secret = |endpoint_id, role_name| {
     523            5 :             cache
     524            5 :                 .get_role_secret_with_ttl(endpoint_id, role_name)
     525            5 :                 .unwrap()
     526            5 :                 .0
     527            5 :         };
     528            1 :         let get_endpoint_access =
     529            3 :             |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
     530              : 
     531              :         // stores role-specific errors only for get_role_secret
     532            1 :         cache.insert_endpoint_access_err(
     533            1 :             (&endpoint_id).into(),
     534            1 :             (&user1).into(),
     535            1 :             role_msg.clone(),
     536            1 :             None,
     537              :         );
     538            1 :         assert_eq!(
     539            1 :             get_role_secret(&endpoint_id, &user1).unwrap_err().error,
     540              :             role_msg.error
     541              :         );
     542            1 :         assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
     543              : 
     544              :         // stores non-role specific errors for both get_role_secret and get_endpoint_access
     545            1 :         cache.insert_endpoint_access_err(
     546            1 :             (&endpoint_id).into(),
     547            1 :             (&user1).into(),
     548            1 :             generic_msg.clone(),
     549            1 :             None,
     550              :         );
     551            1 :         assert_eq!(
     552            1 :             get_role_secret(&endpoint_id, &user1).unwrap_err().error,
     553              :             generic_msg.error
     554              :         );
     555            1 :         assert_eq!(
     556            1 :             get_endpoint_access(&endpoint_id).unwrap_err().error,
     557              :             generic_msg.error
     558              :         );
     559              : 
     560              :         // error isn't returned for other roles in the same endpoint
     561            1 :         assert!(
     562            1 :             cache
     563            1 :                 .get_role_secret_with_ttl(&endpoint_id, &user2)
     564            1 :                 .is_none()
     565              :         );
     566              : 
     567              :         // success for a role does not overwrite errors for other roles
     568            1 :         cache.insert_endpoint_access(
     569            1 :             account_id,
     570            1 :             project_id,
     571            1 :             (&endpoint_id).into(),
     572            1 :             (&user2).into(),
     573            1 :             EndpointAccessControl {
     574            1 :                 allowed_ips: Arc::new(vec![]),
     575            1 :                 allowed_vpce: Arc::new(vec![]),
     576            1 :                 flags: AccessBlockerFlags::default(),
     577            1 :                 rate_limits: EndpointRateLimitConfig::default(),
     578            1 :             },
     579            1 :             RoleAccessControl {
     580            1 :                 secret: secret.clone(),
     581            1 :             },
     582              :         );
     583            1 :         assert!(get_role_secret(&endpoint_id, &user1).is_err());
     584            1 :         assert!(get_role_secret(&endpoint_id, &user2).is_ok());
     585              :         // ...but does clear the access control error
     586            1 :         assert!(get_endpoint_access(&endpoint_id).is_ok());
     587              : 
     588              :         // storing an error does not overwrite successful access control response
     589            1 :         cache.insert_endpoint_access_err(
     590            1 :             (&endpoint_id).into(),
     591            1 :             (&user2).into(),
     592            1 :             generic_msg.clone(),
     593            1 :             None,
     594              :         );
     595            1 :         assert!(get_role_secret(&endpoint_id, &user2).is_err());
     596            1 :         assert!(get_endpoint_access(&endpoint_id).is_ok());
     597            1 :     }
     598              : }
        

Generated by: LCOV version 2.1-beta