Line data Source code
1 : use std::collections::HashSet;
2 : use std::convert::Infallible;
3 :
4 : use clashmap::ClashMap;
5 : use moka::sync::Cache;
6 : use tracing::{debug, info};
7 :
8 : use crate::cache::common::{ControlPlaneResult, CplaneExpiry};
9 : use crate::config::ProjectInfoCacheOptions;
10 : use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
11 : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
12 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
13 : use crate::types::{EndpointId, RoleName};
14 :
15 : /// Cache for project info.
16 : /// This is used to cache auth data for endpoints.
17 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
18 : ///
19 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
20 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
21 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
22 : pub struct ProjectInfoCache {
23 : role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
24 : ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,
25 :
26 : project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
27 : // FIXME(stefan): we need a way to GC the account2ep map.
28 : account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
29 :
30 : config: ProjectInfoCacheOptions,
31 : }
32 :
33 : impl ProjectInfoCache {
34 0 : pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
35 0 : info!("invalidating endpoint access for `{endpoint_id}`");
36 0 : self.ep_controls.invalidate(&endpoint_id);
37 0 : }
38 :
39 0 : pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
40 0 : info!("invalidating endpoint access for project `{project_id}`");
41 0 : let endpoints = self
42 0 : .project2ep
43 0 : .get(&project_id)
44 0 : .map(|kv| kv.value().clone())
45 0 : .unwrap_or_default();
46 0 : for endpoint_id in endpoints {
47 0 : self.ep_controls.invalidate(&endpoint_id);
48 0 : }
49 0 : }
50 :
51 0 : pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
52 0 : info!("invalidating endpoint access for org `{account_id}`");
53 0 : let endpoints = self
54 0 : .account2ep
55 0 : .get(&account_id)
56 0 : .map(|kv| kv.value().clone())
57 0 : .unwrap_or_default();
58 0 : for endpoint_id in endpoints {
59 0 : self.ep_controls.invalidate(&endpoint_id);
60 0 : }
61 0 : }
62 :
63 0 : pub fn invalidate_role_secret_for_project(
64 0 : &self,
65 0 : project_id: ProjectIdInt,
66 0 : role_name: RoleNameInt,
67 0 : ) {
68 0 : info!(
69 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
70 : project_id, role_name,
71 : );
72 0 : let endpoints = self
73 0 : .project2ep
74 0 : .get(&project_id)
75 0 : .map(|kv| kv.value().clone())
76 0 : .unwrap_or_default();
77 0 : for endpoint_id in endpoints {
78 0 : self.role_controls.invalidate(&(endpoint_id, role_name));
79 0 : }
80 0 : }
81 : }
82 :
83 : impl ProjectInfoCache {
84 2 : pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
85 : // we cache errors for 30 seconds, unless retry_at is set.
86 2 : let expiry = CplaneExpiry::default();
87 2 : Self {
88 2 : role_controls: Cache::builder()
89 2 : .name("role_access_controls")
90 2 : .max_capacity(config.size * config.max_roles)
91 2 : .time_to_live(config.ttl)
92 2 : .expire_after(expiry)
93 2 : .build(),
94 2 : ep_controls: Cache::builder()
95 2 : .name("endpoint_access_controls")
96 2 : .max_capacity(config.size)
97 2 : .time_to_live(config.ttl)
98 2 : .expire_after(expiry)
99 2 : .build(),
100 2 : project2ep: ClashMap::new(),
101 2 : account2ep: ClashMap::new(),
102 2 : config,
103 2 : }
104 2 : }
105 :
106 8 : pub(crate) fn get_role_secret(
107 8 : &self,
108 8 : endpoint_id: &EndpointId,
109 8 : role_name: &RoleName,
110 8 : ) -> Option<ControlPlaneResult<RoleAccessControl>> {
111 8 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
112 8 : let role_name = RoleNameInt::get(role_name)?;
113 :
114 7 : self.role_controls.get(&(endpoint_id, role_name))
115 8 : }
116 :
117 4 : pub(crate) fn get_endpoint_access(
118 4 : &self,
119 4 : endpoint_id: &EndpointId,
120 4 : ) -> Option<ControlPlaneResult<EndpointAccessControl>> {
121 4 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
122 :
123 4 : self.ep_controls.get(&endpoint_id)
124 4 : }
125 :
126 4 : pub(crate) fn insert_endpoint_access(
127 4 : &self,
128 4 : account_id: Option<AccountIdInt>,
129 4 : project_id: Option<ProjectIdInt>,
130 4 : endpoint_id: EndpointIdInt,
131 4 : role_name: RoleNameInt,
132 4 : controls: EndpointAccessControl,
133 4 : role_controls: RoleAccessControl,
134 4 : ) {
135 4 : if let Some(account_id) = account_id {
136 0 : self.insert_account2endpoint(account_id, endpoint_id);
137 4 : }
138 4 : if let Some(project_id) = project_id {
139 4 : self.insert_project2endpoint(project_id, endpoint_id);
140 4 : }
141 :
142 4 : debug!(
143 0 : key = &*endpoint_id,
144 0 : "created a cache entry for endpoint access"
145 : );
146 :
147 4 : self.ep_controls.insert(endpoint_id, Ok(controls));
148 4 : self.role_controls
149 4 : .insert((endpoint_id, role_name), Ok(role_controls));
150 4 : }
151 :
152 3 : pub(crate) fn insert_endpoint_access_err(
153 3 : &self,
154 3 : endpoint_id: EndpointIdInt,
155 3 : role_name: RoleNameInt,
156 3 : msg: Box<ControlPlaneErrorMessage>,
157 3 : ) {
158 3 : debug!(
159 0 : key = &*endpoint_id,
160 0 : "created a cache entry for an endpoint access error"
161 : );
162 :
163 : // RoleProtected is the only role-specific error that control plane can give us.
164 : // If a given role name does not exist, it still returns a successful response,
165 : // just with an empty secret.
166 3 : if msg.get_reason() != Reason::RoleProtected {
167 : // We can cache all the other errors in ep_controls because they don't
168 : // depend on what role name we pass to control plane.
169 2 : self.ep_controls
170 2 : .entry(endpoint_id)
171 2 : .and_compute_with(|entry| match entry {
172 : // leave the entry alone if it's already Ok
173 1 : Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
174 : // replace the entry
175 1 : _ => moka::ops::compute::Op::Put(Err(msg.clone())),
176 2 : });
177 1 : }
178 :
179 3 : self.role_controls
180 3 : .insert((endpoint_id, role_name), Err(msg));
181 3 : }
182 :
183 4 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
184 4 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
185 2 : endpoints.insert(endpoint_id);
186 2 : } else {
187 2 : self.project2ep
188 2 : .insert(project_id, HashSet::from([endpoint_id]));
189 2 : }
190 4 : }
191 :
192 0 : fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
193 0 : if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
194 0 : endpoints.insert(endpoint_id);
195 0 : } else {
196 0 : self.account2ep
197 0 : .insert(account_id, HashSet::from([endpoint_id]));
198 0 : }
199 0 : }
200 :
201 0 : pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
202 : // TODO: Expire the value early if the key is idle.
203 : // Currently not an issue as we would just use the TTL to decide, which is what already happens.
204 0 : }
205 :
206 0 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
207 0 : let mut interval = tokio::time::interval(self.config.gc_interval);
208 : loop {
209 0 : interval.tick().await;
210 0 : self.ep_controls.run_pending_tasks();
211 0 : self.role_controls.run_pending_tasks();
212 : }
213 : }
214 : }
215 :
216 : #[cfg(test)]
217 : mod tests {
218 : use std::sync::Arc;
219 : use std::time::Duration;
220 :
221 : use super::*;
222 : use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
223 : use crate::control_plane::{AccessBlockerFlags, AuthSecret};
224 : use crate::scram::ServerSecret;
225 :
226 : #[tokio::test]
227 1 : async fn test_project_info_cache_settings() {
228 1 : let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
229 1 : size: 1,
230 1 : max_roles: 2,
231 1 : ttl: Duration::from_secs(1),
232 1 : gc_interval: Duration::from_secs(600),
233 1 : });
234 1 : let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
235 1 : let endpoint_id: EndpointId = "endpoint".into();
236 1 : let account_id = None;
237 :
238 1 : let user1: RoleName = "user1".into();
239 1 : let user2: RoleName = "user2".into();
240 1 : let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
241 1 : let secret2 = None;
242 1 : let allowed_ips = Arc::new(vec![
243 1 : "127.0.0.1".parse().unwrap(),
244 1 : "127.0.0.2".parse().unwrap(),
245 : ]);
246 :
247 1 : cache.insert_endpoint_access(
248 1 : account_id,
249 1 : project_id,
250 1 : (&endpoint_id).into(),
251 1 : (&user1).into(),
252 1 : EndpointAccessControl {
253 1 : allowed_ips: allowed_ips.clone(),
254 1 : allowed_vpce: Arc::new(vec![]),
255 1 : flags: AccessBlockerFlags::default(),
256 1 : rate_limits: EndpointRateLimitConfig::default(),
257 1 : },
258 1 : RoleAccessControl {
259 1 : secret: secret1.clone(),
260 1 : },
261 : );
262 :
263 1 : cache.insert_endpoint_access(
264 1 : account_id,
265 1 : project_id,
266 1 : (&endpoint_id).into(),
267 1 : (&user2).into(),
268 1 : EndpointAccessControl {
269 1 : allowed_ips: allowed_ips.clone(),
270 1 : allowed_vpce: Arc::new(vec![]),
271 1 : flags: AccessBlockerFlags::default(),
272 1 : rate_limits: EndpointRateLimitConfig::default(),
273 1 : },
274 1 : RoleAccessControl {
275 1 : secret: secret2.clone(),
276 1 : },
277 : );
278 :
279 1 : let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
280 1 : assert_eq!(cached.unwrap().secret, secret1);
281 :
282 1 : let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
283 1 : assert_eq!(cached.unwrap().secret, secret2);
284 :
285 : // Shouldn't add more than 2 roles.
286 1 : let user3: RoleName = "user3".into();
287 1 : let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
288 :
289 1 : cache.role_controls.run_pending_tasks();
290 1 : cache.insert_endpoint_access(
291 1 : account_id,
292 1 : project_id,
293 1 : (&endpoint_id).into(),
294 1 : (&user3).into(),
295 1 : EndpointAccessControl {
296 1 : allowed_ips: allowed_ips.clone(),
297 1 : allowed_vpce: Arc::new(vec![]),
298 1 : flags: AccessBlockerFlags::default(),
299 1 : rate_limits: EndpointRateLimitConfig::default(),
300 1 : },
301 1 : RoleAccessControl {
302 1 : secret: secret3.clone(),
303 1 : },
304 : );
305 :
306 1 : cache.role_controls.run_pending_tasks();
307 1 : assert_eq!(cache.role_controls.entry_count(), 2);
308 :
309 1 : tokio::time::sleep(Duration::from_secs(2)).await;
310 :
311 1 : cache.role_controls.run_pending_tasks();
312 1 : assert_eq!(cache.role_controls.entry_count(), 0);
313 1 : }
314 :
315 : #[tokio::test]
316 1 : async fn test_caching_project_info_errors() {
317 1 : let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
318 1 : size: 10,
319 1 : max_roles: 10,
320 1 : ttl: Duration::from_secs(1),
321 1 : gc_interval: Duration::from_secs(600),
322 1 : });
323 1 : let project_id = Some(ProjectIdInt::from(&"project".into()));
324 1 : let endpoint_id: EndpointId = "endpoint".into();
325 1 : let account_id = None;
326 :
327 1 : let user1: RoleName = "user1".into();
328 1 : let user2: RoleName = "user2".into();
329 1 : let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
330 :
331 1 : let role_msg = Box::new(ControlPlaneErrorMessage {
332 1 : error: "role is protected and cannot be used for password-based authentication"
333 1 : .to_owned()
334 1 : .into_boxed_str(),
335 1 : http_status_code: http::StatusCode::NOT_FOUND,
336 1 : status: Some(Status {
337 1 : code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
338 1 : message: "role is protected and cannot be used for password-based authentication"
339 1 : .to_owned()
340 1 : .into_boxed_str(),
341 1 : details: Details {
342 1 : error_info: Some(ErrorInfo {
343 1 : reason: Reason::RoleProtected,
344 1 : }),
345 1 : retry_info: None,
346 1 : user_facing_message: None,
347 1 : },
348 1 : }),
349 1 : });
350 :
351 1 : let generic_msg = Box::new(ControlPlaneErrorMessage {
352 1 : error: "oh noes".to_owned().into_boxed_str(),
353 1 : http_status_code: http::StatusCode::NOT_FOUND,
354 1 : status: None,
355 1 : });
356 :
357 1 : let get_role_secret =
358 5 : |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
359 3 : let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();
360 :
361 : // stores role-specific errors only for get_role_secret
362 1 : cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
363 1 : assert_eq!(
364 1 : get_role_secret(&endpoint_id, &user1).unwrap_err().error,
365 : role_msg.error
366 : );
367 1 : assert!(cache.get_endpoint_access(&endpoint_id).is_none());
368 :
369 : // stores non-role specific errors for both get_role_secret and get_endpoint_access
370 1 : cache.insert_endpoint_access_err(
371 1 : (&endpoint_id).into(),
372 1 : (&user1).into(),
373 1 : generic_msg.clone(),
374 : );
375 1 : assert_eq!(
376 1 : get_role_secret(&endpoint_id, &user1).unwrap_err().error,
377 : generic_msg.error
378 : );
379 1 : assert_eq!(
380 1 : get_endpoint_access(&endpoint_id).unwrap_err().error,
381 : generic_msg.error
382 : );
383 :
384 : // error isn't returned for other roles in the same endpoint
385 1 : assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
386 :
387 : // success for a role does not overwrite errors for other roles
388 1 : cache.insert_endpoint_access(
389 1 : account_id,
390 1 : project_id,
391 1 : (&endpoint_id).into(),
392 1 : (&user2).into(),
393 1 : EndpointAccessControl {
394 1 : allowed_ips: Arc::new(vec![]),
395 1 : allowed_vpce: Arc::new(vec![]),
396 1 : flags: AccessBlockerFlags::default(),
397 1 : rate_limits: EndpointRateLimitConfig::default(),
398 1 : },
399 1 : RoleAccessControl {
400 1 : secret: secret.clone(),
401 1 : },
402 : );
403 1 : assert!(get_role_secret(&endpoint_id, &user1).is_err());
404 1 : assert!(get_role_secret(&endpoint_id, &user2).is_ok());
405 : // ...but does clear the access control error
406 1 : assert!(get_endpoint_access(&endpoint_id).is_ok());
407 :
408 : // storing an error does not overwrite successful access control response
409 1 : cache.insert_endpoint_access_err(
410 1 : (&endpoint_id).into(),
411 1 : (&user2).into(),
412 1 : generic_msg.clone(),
413 : );
414 1 : assert!(get_role_secret(&endpoint_id, &user2).is_err());
415 1 : assert!(get_endpoint_access(&endpoint_id).is_ok());
416 1 : }
417 : }
|