Line data Source code
1 : use std::{
2 : collections::HashSet,
3 : convert::Infallible,
4 : sync::{atomic::AtomicU64, Arc},
5 : time::Duration,
6 : };
7 :
8 : use dashmap::DashMap;
9 : use rand::{thread_rng, Rng};
10 : use smol_str::SmolStr;
11 : use tokio::time::Instant;
12 : use tracing::{debug, info};
13 :
14 : use crate::{
15 : auth::IpPattern,
16 : config::ProjectInfoCacheOptions,
17 : console::AuthSecret,
18 : intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
19 : EndpointId, RoleName,
20 : };
21 :
22 : use super::{Cache, Cached};
23 :
24 : pub trait ProjectInfoCache {
25 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
26 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
27 : fn enable_ttl(&self);
28 : fn disable_ttl(&self);
29 : }
30 :
31 : struct Entry<T> {
32 : created_at: Instant,
33 : value: T,
34 : }
35 :
36 : impl<T> Entry<T> {
37 18 : pub fn new(value: T) -> Self {
38 18 : Self {
39 18 : created_at: Instant::now(),
40 18 : value,
41 18 : }
42 18 : }
43 : }
44 :
45 : impl<T> From<T> for Entry<T> {
46 18 : fn from(value: T) -> Self {
47 18 : Self::new(value)
48 18 : }
49 : }
50 :
51 : #[derive(Default)]
52 : struct EndpointInfo {
53 : secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
54 : allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
55 : }
56 :
57 : impl EndpointInfo {
58 22 : fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
59 22 : match ignore_cache_since {
60 6 : None => false,
61 16 : Some(t) => t < created_at,
62 : }
63 22 : }
64 30 : pub fn get_role_secret(
65 30 : &self,
66 30 : role_name: RoleNameInt,
67 30 : valid_since: Instant,
68 30 : ignore_cache_since: Option<Instant>,
69 30 : ) -> Option<(Option<AuthSecret>, bool)> {
70 30 : if let Some(secret) = self.secret.get(&role_name) {
71 26 : if valid_since < secret.created_at {
72 14 : return Some((
73 14 : secret.value.clone(),
74 14 : Self::check_ignore_cache(ignore_cache_since, secret.created_at),
75 14 : ));
76 12 : }
77 4 : }
78 16 : None
79 30 : }
80 :
81 10 : pub fn get_allowed_ips(
82 10 : &self,
83 10 : valid_since: Instant,
84 10 : ignore_cache_since: Option<Instant>,
85 10 : ) -> Option<(Arc<Vec<IpPattern>>, bool)> {
86 10 : if let Some(allowed_ips) = &self.allowed_ips {
87 10 : if valid_since < allowed_ips.created_at {
88 8 : return Some((
89 8 : allowed_ips.value.clone(),
90 8 : Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
91 8 : ));
92 2 : }
93 0 : }
94 2 : None
95 10 : }
96 0 : pub fn invalidate_allowed_ips(&mut self) {
97 0 : self.allowed_ips = None;
98 0 : }
99 2 : pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
100 2 : self.secret.remove(&role_name);
101 2 : }
102 : }
103 :
104 : /// Cache for project info.
105 : /// This is used to cache auth data for endpoints.
106 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
107 : ///
108 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
109 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
110 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
111 : pub struct ProjectInfoCacheImpl {
112 : cache: DashMap<EndpointIdInt, EndpointInfo>,
113 :
114 : project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
115 : config: ProjectInfoCacheOptions,
116 :
117 : start_time: Instant,
118 : ttl_disabled_since_us: AtomicU64,
119 : }
120 :
121 : impl ProjectInfoCache for ProjectInfoCacheImpl {
122 0 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
123 0 : info!("invalidating allowed ips for project `{}`", project_id);
124 0 : let endpoints = self
125 0 : .project2ep
126 0 : .get(&project_id)
127 0 : .map(|kv| kv.value().clone())
128 0 : .unwrap_or_default();
129 0 : for endpoint_id in endpoints {
130 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
131 0 : endpoint_info.invalidate_allowed_ips();
132 0 : }
133 : }
134 0 : }
135 2 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
136 2 : info!(
137 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
138 0 : project_id, role_name,
139 0 : );
140 2 : let endpoints = self
141 2 : .project2ep
142 2 : .get(&project_id)
143 2 : .map(|kv| kv.value().clone())
144 2 : .unwrap_or_default();
145 4 : for endpoint_id in endpoints {
146 2 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
147 2 : endpoint_info.invalidate_role_secret(role_name);
148 2 : }
149 : }
150 2 : }
151 0 : fn enable_ttl(&self) {
152 0 : self.ttl_disabled_since_us
153 0 : .store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
154 0 : }
155 :
156 4 : fn disable_ttl(&self) {
157 4 : let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
158 4 : self.ttl_disabled_since_us
159 4 : .store(new_ttl, std::sync::atomic::Ordering::Relaxed);
160 4 : }
161 : }
162 :
163 : impl ProjectInfoCacheImpl {
164 6 : pub fn new(config: ProjectInfoCacheOptions) -> Self {
165 6 : Self {
166 6 : cache: DashMap::new(),
167 6 : project2ep: DashMap::new(),
168 6 : config,
169 6 : ttl_disabled_since_us: AtomicU64::new(u64::MAX),
170 6 : start_time: Instant::now(),
171 6 : }
172 6 : }
173 :
174 30 : pub fn get_role_secret(
175 30 : &self,
176 30 : endpoint_id: &EndpointId,
177 30 : role_name: &RoleName,
178 30 : ) -> Option<Cached<&Self, Option<AuthSecret>>> {
179 30 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
180 30 : let role_name = RoleNameInt::get(role_name)?;
181 30 : let (valid_since, ignore_cache_since) = self.get_cache_times();
182 30 : let endpoint_info = self.cache.get(&endpoint_id)?;
183 14 : let (value, ignore_cache) =
184 30 : endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
185 14 : if !ignore_cache {
186 8 : let cached = Cached {
187 8 : token: Some((
188 8 : self,
189 8 : CachedLookupInfo::new_role_secret(endpoint_id, role_name),
190 8 : )),
191 8 : value,
192 8 : };
193 8 : return Some(cached);
194 6 : }
195 6 : Some(Cached::new_uncached(value))
196 30 : }
197 10 : pub fn get_allowed_ips(
198 10 : &self,
199 10 : endpoint_id: &EndpointId,
200 10 : ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
201 10 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
202 10 : let (valid_since, ignore_cache_since) = self.get_cache_times();
203 10 : let endpoint_info = self.cache.get(&endpoint_id)?;
204 10 : let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
205 10 : let (value, ignore_cache) = value?;
206 8 : if !ignore_cache {
207 2 : let cached = Cached {
208 2 : token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
209 2 : value,
210 2 : };
211 2 : return Some(cached);
212 6 : }
213 6 : Some(Cached::new_uncached(value))
214 10 : }
215 14 : pub fn insert_role_secret(
216 14 : &self,
217 14 : project_id: ProjectIdInt,
218 14 : endpoint_id: EndpointIdInt,
219 14 : role_name: RoleNameInt,
220 14 : secret: Option<AuthSecret>,
221 14 : ) {
222 14 : if self.cache.len() >= self.config.size {
223 : // If there are too many entries, wait until the next gc cycle.
224 0 : return;
225 14 : }
226 14 : self.insert_project2endpoint(project_id, endpoint_id);
227 14 : let mut entry = self.cache.entry(endpoint_id).or_default();
228 14 : if entry.secret.len() < self.config.max_roles {
229 12 : entry.secret.insert(role_name, secret.into());
230 12 : }
231 14 : }
232 6 : pub fn insert_allowed_ips(
233 6 : &self,
234 6 : project_id: ProjectIdInt,
235 6 : endpoint_id: EndpointIdInt,
236 6 : allowed_ips: Arc<Vec<IpPattern>>,
237 6 : ) {
238 6 : if self.cache.len() >= self.config.size {
239 : // If there are too many entries, wait until the next gc cycle.
240 0 : return;
241 6 : }
242 6 : self.insert_project2endpoint(project_id, endpoint_id);
243 6 : self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
244 6 : }
245 20 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
246 20 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
247 14 : endpoints.insert(endpoint_id);
248 14 : } else {
249 6 : self.project2ep
250 6 : .insert(project_id, HashSet::from([endpoint_id]));
251 6 : }
252 20 : }
253 40 : fn get_cache_times(&self) -> (Instant, Option<Instant>) {
254 40 : let mut valid_since = Instant::now() - self.config.ttl;
255 40 : // Only ignore cache if ttl is disabled.
256 40 : let ttl_disabled_since_us = self
257 40 : .ttl_disabled_since_us
258 40 : .load(std::sync::atomic::Ordering::Relaxed);
259 40 : let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
260 26 : let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
261 26 : // We are fine if entry is not older than ttl or was added before we are getting notifications.
262 26 : valid_since = valid_since.min(ignore_cache_since);
263 26 : Some(ignore_cache_since)
264 : } else {
265 14 : None
266 : };
267 40 : (valid_since, ignore_cache_since)
268 40 : }
269 :
270 0 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
271 0 : let mut interval =
272 0 : tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
273 : loop {
274 0 : interval.tick().await;
275 0 : if self.cache.len() < self.config.size {
276 : // If there are not too many entries, wait until the next gc cycle.
277 0 : continue;
278 0 : }
279 0 : self.gc();
280 : }
281 : }
282 :
283 0 : fn gc(&self) {
284 0 : let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
285 0 : debug!(shard, "project_info_cache: performing epoch reclamation");
286 :
287 : // acquire a random shard lock
288 0 : let mut removed = 0;
289 0 : let shard = self.project2ep.shards()[shard].write();
290 0 : for (_, endpoints) in shard.iter() {
291 0 : for endpoint in endpoints.get().iter() {
292 0 : self.cache.remove(endpoint);
293 0 : removed += 1;
294 0 : }
295 : }
296 : // We can drop this shard only after making sure that all endpoints are removed.
297 0 : drop(shard);
298 0 : info!("project_info_cache: removed {removed} endpoints");
299 0 : }
300 : }
301 :
302 : /// Lookup info for project info cache.
303 : /// This is used to invalidate cache entries.
304 : pub struct CachedLookupInfo {
305 : /// Search by this key.
306 : endpoint_id: EndpointIdInt,
307 : lookup_type: LookupType,
308 : }
309 :
310 : impl CachedLookupInfo {
311 8 : pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
312 8 : Self {
313 8 : endpoint_id,
314 8 : lookup_type: LookupType::RoleSecret(role_name),
315 8 : }
316 8 : }
317 2 : pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
318 2 : Self {
319 2 : endpoint_id,
320 2 : lookup_type: LookupType::AllowedIps,
321 2 : }
322 2 : }
323 : }
324 :
325 : enum LookupType {
326 : RoleSecret(RoleNameInt),
327 : AllowedIps,
328 : }
329 :
330 : impl Cache for ProjectInfoCacheImpl {
331 : type Key = SmolStr;
332 : // Value is not really used here, but we need to specify it.
333 : type Value = SmolStr;
334 :
335 : type LookupInfo<Key> = CachedLookupInfo;
336 :
337 0 : fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
338 0 : match &key.lookup_type {
339 0 : LookupType::RoleSecret(role_name) => {
340 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
341 0 : endpoint_info.invalidate_role_secret(*role_name);
342 0 : }
343 : }
344 : LookupType::AllowedIps => {
345 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
346 0 : endpoint_info.invalidate_allowed_ips();
347 0 : }
348 : }
349 : }
350 0 : }
351 : }
352 :
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{scram::ServerSecret, ProjectId};

    // Every test pauses tokio's clock and moves it with `tokio::time::advance`, so the
    // TTL cutoffs are deterministic. All tests use ttl = 1s and max_roles = 2.

    // TTL never disabled: checks the role limit and plain TTL expiry.
    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        });
        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        // A `None` secret is cached like any other value.
        let secret2 = None;
        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user1).into(),
            secret1.clone(),
        );
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user2).into(),
            secret2.clone(),
        );
        cache.insert_allowed_ips(
            (&project_id).into(),
            (&endpoint_id).into(),
            allowed_ips.clone(),
        );

        // No ignore-cache cutoff is set, so hits come back with `cached()` true.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret1);
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret2);

        // Shouldn't add more than 2 roles.
        let user3: RoleName = "user3".into();
        let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user3).into(),
            secret3.clone(),
        );
        assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, allowed_ips);

        // Moving 2s forward (beyond the 1s ttl) makes every entry expire.
        tokio::time::advance(Duration::from_secs(2)).await;
        let cached = cache.get_role_secret(&endpoint_id, &user1);
        assert!(cached.is_none());
        let cached = cache.get_role_secret(&endpoint_id, &user2);
        assert!(cached.is_none());
        let cached = cache.get_allowed_ips(&endpoint_id);
        assert!(cached.is_none());
    }

    // TTL disabled before any insert: entries are created after the ignore-cache cutoff.
    // They are therefore returned uncached and do not expire by TTL; only the
    // project-wide invalidation API removes them.
    #[tokio::test]
    async fn test_project_info_cache_invalidations() {
        tokio::time::pause();
        let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        }));
        // Cutoff becomes t0 + ttl = t0 + 1s; inserts below happen at t0 + 2s.
        cache.clone().disable_ttl();
        tokio::time::advance(Duration::from_secs(2)).await;

        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user1).into(),
            secret1.clone(),
        );
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user2).into(),
            secret2.clone(),
        );
        cache.insert_allowed_ips(
            (&project_id).into(),
            (&endpoint_id).into(),
            allowed_ips.clone(),
        );

        tokio::time::advance(Duration::from_secs(2)).await;
        // Nothing should be invalidated.

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        // TTL is disabled, so it should be impossible to invalidate this value.
        assert!(!cached.cached());
        assert_eq!(cached.value, secret1);

        cached.invalidate(); // Shouldn't do anything.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.value, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, secret2);

        // The only way to invalidate this value is to invalidate via the api.
        cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }

    // Mixes entries created at or before the ignore-cache cutoff (cached, still TTL-bound)
    // with entries created after it (uncached, not TTL-bound).
    #[tokio::test]
    async fn test_disable_ttl_invalidate_added_before() {
        tokio::time::pause();
        let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        }));

        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);
        // user1 is inserted at t0.
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user1).into(),
            secret1.clone(),
        );
        // Cutoff becomes t0 + 1s; user2 is inserted at t0 + 100ms, still before it.
        cache.clone().disable_ttl();
        tokio::time::advance(Duration::from_millis(100)).await;
        cache.insert_role_secret(
            (&project_id).into(),
            (&endpoint_id).into(),
            (&user2).into(),
            secret2.clone(),
        );

        // Added before ttl was disabled + ttl should be still cached.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        // Added after ttl was disabled + ttl should not be cached.
        cache.insert_allowed_ips(
            (&project_id).into(),
            (&endpoint_id).into(),
            allowed_ips.clone(),
        );
        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl still should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        // Shouldn't be invalidated.

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }
}
|