|             Line data    Source code 
       1              : use std::collections::HashSet;
       2              : use std::convert::Infallible;
       3              : 
       4              : use clashmap::ClashMap;
       5              : use moka::sync::Cache;
       6              : use tracing::{debug, info};
       7              : 
       8              : use crate::cache::common::{ControlPlaneResult, CplaneExpiry};
       9              : use crate::config::ProjectInfoCacheOptions;
      10              : use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
      11              : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
      12              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      13              : use crate::types::{EndpointId, RoleName};
      14              : 
      15              : /// Cache for project info.
      16              : /// This is used to cache auth data for endpoints.
      17              : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
      18              : ///
      19              : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
      20              : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
      21              : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
      22              : pub struct ProjectInfoCache {
      23              :     role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
      24              :     ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,
      25              : 
      26              :     project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
      27              :     // FIXME(stefan): we need a way to GC the account2ep map.
      28              :     account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
      29              : 
      30              :     config: ProjectInfoCacheOptions,
      31              : }
      32              : 
      33              : impl ProjectInfoCache {
      34            0 :     pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
      35            0 :         info!("invalidating endpoint access for `{endpoint_id}`");
      36            0 :         self.ep_controls.invalidate(&endpoint_id);
      37            0 :     }
      38              : 
      39            0 :     pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
      40            0 :         info!("invalidating endpoint access for project `{project_id}`");
      41            0 :         let endpoints = self
      42            0 :             .project2ep
      43            0 :             .get(&project_id)
      44            0 :             .map(|kv| kv.value().clone())
      45            0 :             .unwrap_or_default();
      46            0 :         for endpoint_id in endpoints {
      47            0 :             self.ep_controls.invalidate(&endpoint_id);
      48            0 :         }
      49            0 :     }
      50              : 
      51            0 :     pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
      52            0 :         info!("invalidating endpoint access for org `{account_id}`");
      53            0 :         let endpoints = self
      54            0 :             .account2ep
      55            0 :             .get(&account_id)
      56            0 :             .map(|kv| kv.value().clone())
      57            0 :             .unwrap_or_default();
      58            0 :         for endpoint_id in endpoints {
      59            0 :             self.ep_controls.invalidate(&endpoint_id);
      60            0 :         }
      61            0 :     }
      62              : 
      63            0 :     pub fn invalidate_role_secret_for_project(
      64            0 :         &self,
      65            0 :         project_id: ProjectIdInt,
      66            0 :         role_name: RoleNameInt,
      67            0 :     ) {
      68            0 :         info!(
      69            0 :             "invalidating role secret for project_id `{}` and role_name `{}`",
      70              :             project_id, role_name,
      71              :         );
      72            0 :         let endpoints = self
      73            0 :             .project2ep
      74            0 :             .get(&project_id)
      75            0 :             .map(|kv| kv.value().clone())
      76            0 :             .unwrap_or_default();
      77            0 :         for endpoint_id in endpoints {
      78            0 :             self.role_controls.invalidate(&(endpoint_id, role_name));
      79            0 :         }
      80            0 :     }
      81              : }
      82              : 
      83              : impl ProjectInfoCache {
      84            2 :     pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
      85              :         // we cache errors for 30 seconds, unless retry_at is set.
      86            2 :         let expiry = CplaneExpiry::default();
      87            2 :         Self {
      88            2 :             role_controls: Cache::builder()
      89            2 :                 .name("role_access_controls")
      90            2 :                 .max_capacity(config.size * config.max_roles)
      91            2 :                 .time_to_live(config.ttl)
      92            2 :                 .expire_after(expiry)
      93            2 :                 .build(),
      94            2 :             ep_controls: Cache::builder()
      95            2 :                 .name("endpoint_access_controls")
      96            2 :                 .max_capacity(config.size)
      97            2 :                 .time_to_live(config.ttl)
      98            2 :                 .expire_after(expiry)
      99            2 :                 .build(),
     100            2 :             project2ep: ClashMap::new(),
     101            2 :             account2ep: ClashMap::new(),
     102            2 :             config,
     103            2 :         }
     104            2 :     }
     105              : 
     106            8 :     pub(crate) fn get_role_secret(
     107            8 :         &self,
     108            8 :         endpoint_id: &EndpointId,
     109            8 :         role_name: &RoleName,
     110            8 :     ) -> Option<ControlPlaneResult<RoleAccessControl>> {
     111            8 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     112            8 :         let role_name = RoleNameInt::get(role_name)?;
     113              : 
     114            7 :         self.role_controls.get(&(endpoint_id, role_name))
     115            8 :     }
     116              : 
     117            4 :     pub(crate) fn get_endpoint_access(
     118            4 :         &self,
     119            4 :         endpoint_id: &EndpointId,
     120            4 :     ) -> Option<ControlPlaneResult<EndpointAccessControl>> {
     121            4 :         let endpoint_id = EndpointIdInt::get(endpoint_id)?;
     122              : 
     123            4 :         self.ep_controls.get(&endpoint_id)
     124            4 :     }
     125              : 
     126            4 :     pub(crate) fn insert_endpoint_access(
     127            4 :         &self,
     128            4 :         account_id: Option<AccountIdInt>,
     129            4 :         project_id: Option<ProjectIdInt>,
     130            4 :         endpoint_id: EndpointIdInt,
     131            4 :         role_name: RoleNameInt,
     132            4 :         controls: EndpointAccessControl,
     133            4 :         role_controls: RoleAccessControl,
     134            4 :     ) {
     135            4 :         if let Some(account_id) = account_id {
     136            0 :             self.insert_account2endpoint(account_id, endpoint_id);
     137            4 :         }
     138            4 :         if let Some(project_id) = project_id {
     139            4 :             self.insert_project2endpoint(project_id, endpoint_id);
     140            4 :         }
     141              : 
     142            4 :         debug!(
     143            0 :             key = &*endpoint_id,
     144            0 :             "created a cache entry for endpoint access"
     145              :         );
     146              : 
     147            4 :         self.ep_controls.insert(endpoint_id, Ok(controls));
     148            4 :         self.role_controls
     149            4 :             .insert((endpoint_id, role_name), Ok(role_controls));
     150            4 :     }
     151              : 
     152            3 :     pub(crate) fn insert_endpoint_access_err(
     153            3 :         &self,
     154            3 :         endpoint_id: EndpointIdInt,
     155            3 :         role_name: RoleNameInt,
     156            3 :         msg: Box<ControlPlaneErrorMessage>,
     157            3 :     ) {
     158            3 :         debug!(
     159            0 :             key = &*endpoint_id,
     160            0 :             "created a cache entry for an endpoint access error"
     161              :         );
     162              : 
     163              :         // RoleProtected is the only role-specific error that control plane can give us.
     164              :         // If a given role name does not exist, it still returns a successful response,
     165              :         // just with an empty secret.
     166            3 :         if msg.get_reason() != Reason::RoleProtected {
     167              :             // We can cache all the other errors in ep_controls because they don't
     168              :             // depend on what role name we pass to control plane.
     169            2 :             self.ep_controls
     170            2 :                 .entry(endpoint_id)
     171            2 :                 .and_compute_with(|entry| match entry {
     172              :                     // leave the entry alone if it's already Ok
     173            1 :                     Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
     174              :                     // replace the entry
     175            1 :                     _ => moka::ops::compute::Op::Put(Err(msg.clone())),
     176            2 :                 });
     177            1 :         }
     178              : 
     179            3 :         self.role_controls
     180            3 :             .insert((endpoint_id, role_name), Err(msg));
     181            3 :     }
     182              : 
     183            4 :     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
     184            4 :         if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
     185            2 :             endpoints.insert(endpoint_id);
     186            2 :         } else {
     187            2 :             self.project2ep
     188            2 :                 .insert(project_id, HashSet::from([endpoint_id]));
     189            2 :         }
     190            4 :     }
     191              : 
     192            0 :     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
     193            0 :         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
     194            0 :             endpoints.insert(endpoint_id);
     195            0 :         } else {
     196            0 :             self.account2ep
     197            0 :                 .insert(account_id, HashSet::from([endpoint_id]));
     198            0 :         }
     199            0 :     }
     200              : 
     201            0 :     pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
     202              :         // TODO: Expire the value early if the key is idle.
     203              :         // Currently not an issue as we would just use the TTL to decide, which is what already happens.
     204            0 :     }
     205              : 
     206            0 :     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
     207            0 :         let mut interval = tokio::time::interval(self.config.gc_interval);
     208              :         loop {
     209            0 :             interval.tick().await;
     210            0 :             self.ep_controls.run_pending_tasks();
     211            0 :             self.role_controls.run_pending_tasks();
     212              :         }
     213              :     }
     214              : }
     215              : 
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use super::*;
    use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
    use crate::control_plane::{AccessBlockerFlags, AuthSecret};
    use crate::scram::ServerSecret;

    /// Checks that:
    /// - successful inserts are readable per role;
    /// - `role_controls` holds at most `size * max_roles` (here 1 * 2) entries;
    /// - all entries are gone after the 1s TTL.
    #[tokio::test]
    async fn test_project_info_cache_settings() {
        let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
            size: 1,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        });
        let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
        let endpoint_id: EndpointId = "endpoint".into();
        let account_id = None;

        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        // user2 has no secret; a `None` secret must round-trip through the cache as well.
        let secret2 = None;
        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);

        cache.insert_endpoint_access(
            account_id,
            project_id,
            (&endpoint_id).into(),
            (&user1).into(),
            EndpointAccessControl {
                allowed_ips: allowed_ips.clone(),
                allowed_vpce: Arc::new(vec![]),
                flags: AccessBlockerFlags::default(),
                rate_limits: EndpointRateLimitConfig::default(),
            },
            RoleAccessControl {
                secret: secret1.clone(),
            },
        );

        cache.insert_endpoint_access(
            account_id,
            project_id,
            (&endpoint_id).into(),
            (&user2).into(),
            EndpointAccessControl {
                allowed_ips: allowed_ips.clone(),
                allowed_vpce: Arc::new(vec![]),
                flags: AccessBlockerFlags::default(),
                rate_limits: EndpointRateLimitConfig::default(),
            },
            RoleAccessControl {
                secret: secret2.clone(),
            },
        );

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.unwrap().secret, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert_eq!(cached.unwrap().secret, secret2);

        // Shouldn't add more than 2 roles.
        let user3: RoleName = "user3".into();
        let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));

        // Flush moka's pending work so the two inserts above are accounted for before the
        // third one arrives.
        cache.role_controls.run_pending_tasks();
        cache.insert_endpoint_access(
            account_id,
            project_id,
            (&endpoint_id).into(),
            (&user3).into(),
            EndpointAccessControl {
                allowed_ips: allowed_ips.clone(),
                allowed_vpce: Arc::new(vec![]),
                flags: AccessBlockerFlags::default(),
                rate_limits: EndpointRateLimitConfig::default(),
            },
            RoleAccessControl {
                secret: secret3.clone(),
            },
        );

        // Flush again so `entry_count` reflects the capacity bound after the third insert.
        cache.role_controls.run_pending_tasks();
        assert_eq!(cache.role_controls.entry_count(), 2);

        // Real-time wait longer than the 1s TTL configured above.
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Flush so expired entries are actually removed from the count.
        cache.role_controls.run_pending_tasks();
        assert_eq!(cache.role_controls.entry_count(), 0);
    }

    /// Checks how cached errors interact with each other and with successful responses:
    /// - role-specific errors are cached per role only;
    /// - other errors are also cached per endpoint;
    /// - a success replaces an endpoint error but not other roles' errors;
    /// - a later error never replaces a cached successful endpoint response.
    #[tokio::test]
    async fn test_caching_project_info_errors() {
        let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
            size: 10,
            max_roles: 10,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        });
        let project_id = Some(ProjectIdInt::from(&"project".into()));
        let endpoint_id: EndpointId = "endpoint".into();
        let account_id = None;

        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));

        // An error carrying `Reason::RoleProtected`, i.e. the role-specific kind.
        let role_msg = Box::new(ControlPlaneErrorMessage {
            error: "role is protected and cannot be used for password-based authentication"
                .to_owned()
                .into_boxed_str(),
            http_status_code: http::StatusCode::NOT_FOUND,
            status: Some(Status {
                code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
                message: "role is protected and cannot be used for password-based authentication"
                    .to_owned()
                    .into_boxed_str(),
                details: Details {
                    error_info: Some(ErrorInfo {
                        reason: Reason::RoleProtected,
                    }),
                    retry_info: None,
                    user_facing_message: None,
                },
            }),
        });

        // An error with no status details, so not RoleProtected and therefore cached
        // per endpoint as well.
        let generic_msg = Box::new(ControlPlaneErrorMessage {
            error: "oh noes".to_owned().into_boxed_str(),
            http_status_code: http::StatusCode::NOT_FOUND,
            status: None,
        });

        // Helpers that assert a cache hit (panic on `None`) and return the cached result.
        let get_role_secret =
            |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
        let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();

        // stores role-specific errors only for get_role_secret
        cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
        assert_eq!(
            get_role_secret(&endpoint_id, &user1).unwrap_err().error,
            role_msg.error
        );
        assert!(cache.get_endpoint_access(&endpoint_id).is_none());

        // stores non-role specific errors for both get_role_secret and get_endpoint_access
        cache.insert_endpoint_access_err(
            (&endpoint_id).into(),
            (&user1).into(),
            generic_msg.clone(),
        );
        assert_eq!(
            get_role_secret(&endpoint_id, &user1).unwrap_err().error,
            generic_msg.error
        );
        assert_eq!(
            get_endpoint_access(&endpoint_id).unwrap_err().error,
            generic_msg.error
        );

        // error isn't returned for other roles in the same endpoint
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        // success for a role does not overwrite errors for other roles
        cache.insert_endpoint_access(
            account_id,
            project_id,
            (&endpoint_id).into(),
            (&user2).into(),
            EndpointAccessControl {
                allowed_ips: Arc::new(vec![]),
                allowed_vpce: Arc::new(vec![]),
                flags: AccessBlockerFlags::default(),
                rate_limits: EndpointRateLimitConfig::default(),
            },
            RoleAccessControl {
                secret: secret.clone(),
            },
        );
        assert!(get_role_secret(&endpoint_id, &user1).is_err());
        assert!(get_role_secret(&endpoint_id, &user2).is_ok());
        // ...but does clear the access control error
        assert!(get_endpoint_access(&endpoint_id).is_ok());

        // storing an error does not overwrite successful access control response
        cache.insert_endpoint_access_err(
            (&endpoint_id).into(),
            (&user2).into(),
            generic_msg.clone(),
        );
        assert!(get_role_secret(&endpoint_id, &user2).is_err());
        assert!(get_endpoint_access(&endpoint_id).is_ok());
    }
}
         |