Line data Source code
1 : use std::collections::HashSet;
2 : use std::convert::Infallible;
3 : use std::sync::Arc;
4 : use std::sync::atomic::AtomicU64;
5 : use std::time::Duration;
6 :
7 : use async_trait::async_trait;
8 : use clashmap::ClashMap;
9 : use rand::{Rng, thread_rng};
10 : use smol_str::SmolStr;
11 : use tokio::sync::Mutex;
12 : use tokio::time::Instant;
13 : use tracing::{debug, info};
14 :
15 : use super::{Cache, Cached};
16 : use crate::auth::IpPattern;
17 : use crate::config::ProjectInfoCacheOptions;
18 : use crate::control_plane::{AccessBlockerFlags, AuthSecret};
19 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
20 : use crate::types::{EndpointId, RoleName};
21 :
22 : #[async_trait]
23 : pub(crate) trait ProjectInfoCache {
24 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
25 : fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
26 : fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
27 : fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
28 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
29 : async fn decrement_active_listeners(&self);
30 : async fn increment_active_listeners(&self);
31 : }
32 :
33 : struct Entry<T> {
34 : created_at: Instant,
35 : value: T,
36 : }
37 :
38 : impl<T> Entry<T> {
39 9 : pub(crate) fn new(value: T) -> Self {
40 9 : Self {
41 9 : created_at: Instant::now(),
42 9 : value,
43 9 : }
44 9 : }
45 : }
46 :
47 : impl<T> From<T> for Entry<T> {
48 9 : fn from(value: T) -> Self {
49 9 : Self::new(value)
50 9 : }
51 : }
52 :
53 : #[derive(Default)]
54 : struct EndpointInfo {
55 : secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
56 : allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
57 : block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
58 : allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
59 : }
60 :
61 : impl EndpointInfo {
62 11 : fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
63 11 : match ignore_cache_since {
64 3 : None => false,
65 8 : Some(t) => t < created_at,
66 : }
67 11 : }
68 15 : pub(crate) fn get_role_secret(
69 15 : &self,
70 15 : role_name: RoleNameInt,
71 15 : valid_since: Instant,
72 15 : ignore_cache_since: Option<Instant>,
73 15 : ) -> Option<(Option<AuthSecret>, bool)> {
74 15 : if let Some(secret) = self.secret.get(&role_name) {
75 13 : if valid_since < secret.created_at {
76 7 : return Some((
77 7 : secret.value.clone(),
78 7 : Self::check_ignore_cache(ignore_cache_since, secret.created_at),
79 7 : ));
80 6 : }
81 2 : }
82 8 : None
83 15 : }
84 :
85 5 : pub(crate) fn get_allowed_ips(
86 5 : &self,
87 5 : valid_since: Instant,
88 5 : ignore_cache_since: Option<Instant>,
89 5 : ) -> Option<(Arc<Vec<IpPattern>>, bool)> {
90 5 : if let Some(allowed_ips) = &self.allowed_ips {
91 5 : if valid_since < allowed_ips.created_at {
92 4 : return Some((
93 4 : allowed_ips.value.clone(),
94 4 : Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
95 4 : ));
96 1 : }
97 0 : }
98 1 : None
99 5 : }
100 0 : pub(crate) fn get_allowed_vpc_endpoint_ids(
101 0 : &self,
102 0 : valid_since: Instant,
103 0 : ignore_cache_since: Option<Instant>,
104 0 : ) -> Option<(Arc<Vec<String>>, bool)> {
105 0 : if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
106 0 : if valid_since < allowed_vpc_endpoint_ids.created_at {
107 0 : return Some((
108 0 : allowed_vpc_endpoint_ids.value.clone(),
109 0 : Self::check_ignore_cache(
110 0 : ignore_cache_since,
111 0 : allowed_vpc_endpoint_ids.created_at,
112 0 : ),
113 0 : ));
114 0 : }
115 0 : }
116 0 : None
117 0 : }
118 0 : pub(crate) fn get_block_public_or_vpc_access(
119 0 : &self,
120 0 : valid_since: Instant,
121 0 : ignore_cache_since: Option<Instant>,
122 0 : ) -> Option<(AccessBlockerFlags, bool)> {
123 0 : if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
124 0 : if valid_since < block_public_or_vpc_access.created_at {
125 0 : return Some((
126 0 : block_public_or_vpc_access.value.clone(),
127 0 : Self::check_ignore_cache(
128 0 : ignore_cache_since,
129 0 : block_public_or_vpc_access.created_at,
130 0 : ),
131 0 : ));
132 0 : }
133 0 : }
134 0 : None
135 0 : }
136 :
137 0 : pub(crate) fn invalidate_allowed_ips(&mut self) {
138 0 : self.allowed_ips = None;
139 0 : }
140 0 : pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
141 0 : self.allowed_vpc_endpoint_ids = None;
142 0 : }
143 0 : pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
144 0 : self.block_public_or_vpc_access = None;
145 0 : }
146 1 : pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
147 1 : self.secret.remove(&role_name);
148 1 : }
149 : }
150 :
151 : /// Cache for project info.
152 : /// This is used to cache auth data for endpoints.
153 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
154 : ///
155 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
156 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
157 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
158 : pub struct ProjectInfoCacheImpl {
159 : cache: ClashMap<EndpointIdInt, EndpointInfo>,
160 :
161 : project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
162 : // FIXME(stefan): we need a way to GC the account2ep map.
163 : account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
164 : config: ProjectInfoCacheOptions,
165 :
166 : start_time: Instant,
167 : ttl_disabled_since_us: AtomicU64,
168 : active_listeners_lock: Mutex<usize>,
169 : }
170 :
171 : #[async_trait]
172 : impl ProjectInfoCache for ProjectInfoCacheImpl {
173 0 : fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
174 0 : info!(
175 0 : "invalidating allowed vpc endpoint ids for projects `{}`",
176 0 : project_ids
177 0 : .iter()
178 0 : .map(|id| id.to_string())
179 0 : .collect::<Vec<_>>()
180 0 : .join(", ")
181 : );
182 0 : for project_id in project_ids {
183 0 : let endpoints = self
184 0 : .project2ep
185 0 : .get(&project_id)
186 0 : .map(|kv| kv.value().clone())
187 0 : .unwrap_or_default();
188 0 : for endpoint_id in endpoints {
189 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
190 0 : endpoint_info.invalidate_allowed_vpc_endpoint_ids();
191 0 : }
192 : }
193 : }
194 0 : }
195 :
196 0 : fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
197 0 : info!(
198 0 : "invalidating allowed vpc endpoint ids for org `{}`",
199 : account_id
200 : );
201 0 : let endpoints = self
202 0 : .account2ep
203 0 : .get(&account_id)
204 0 : .map(|kv| kv.value().clone())
205 0 : .unwrap_or_default();
206 0 : for endpoint_id in endpoints {
207 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
208 0 : endpoint_info.invalidate_allowed_vpc_endpoint_ids();
209 0 : }
210 : }
211 0 : }
212 :
213 0 : fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
214 0 : info!(
215 0 : "invalidating block public or vpc access for project `{}`",
216 : project_id
217 : );
218 0 : let endpoints = self
219 0 : .project2ep
220 0 : .get(&project_id)
221 0 : .map(|kv| kv.value().clone())
222 0 : .unwrap_or_default();
223 0 : for endpoint_id in endpoints {
224 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
225 0 : endpoint_info.invalidate_block_public_or_vpc_access();
226 0 : }
227 : }
228 0 : }
229 :
230 0 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
231 0 : info!("invalidating allowed ips for project `{}`", project_id);
232 0 : let endpoints = self
233 0 : .project2ep
234 0 : .get(&project_id)
235 0 : .map(|kv| kv.value().clone())
236 0 : .unwrap_or_default();
237 0 : for endpoint_id in endpoints {
238 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
239 0 : endpoint_info.invalidate_allowed_ips();
240 0 : }
241 : }
242 0 : }
243 1 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
244 1 : info!(
245 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
246 : project_id, role_name,
247 : );
248 1 : let endpoints = self
249 1 : .project2ep
250 1 : .get(&project_id)
251 1 : .map(|kv| kv.value().clone())
252 1 : .unwrap_or_default();
253 2 : for endpoint_id in endpoints {
254 1 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
255 1 : endpoint_info.invalidate_role_secret(role_name);
256 1 : }
257 : }
258 1 : }
259 0 : async fn decrement_active_listeners(&self) {
260 0 : let mut listeners_guard = self.active_listeners_lock.lock().await;
261 0 : if *listeners_guard == 0 {
262 0 : tracing::error!("active_listeners count is already 0, something is broken");
263 0 : return;
264 0 : }
265 0 : *listeners_guard -= 1;
266 0 : if *listeners_guard == 0 {
267 0 : self.ttl_disabled_since_us
268 0 : .store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
269 0 : }
270 0 : }
271 :
272 2 : async fn increment_active_listeners(&self) {
273 2 : let mut listeners_guard = self.active_listeners_lock.lock().await;
274 2 : *listeners_guard += 1;
275 2 : if *listeners_guard == 1 {
276 2 : let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
277 2 : self.ttl_disabled_since_us
278 2 : .store(new_ttl, std::sync::atomic::Ordering::SeqCst);
279 2 : }
280 4 : }
281 : }
282 :
283 : impl ProjectInfoCacheImpl {
284 3 : pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
285 3 : Self {
286 3 : cache: ClashMap::new(),
287 3 : project2ep: ClashMap::new(),
288 3 : account2ep: ClashMap::new(),
289 3 : config,
290 3 : ttl_disabled_since_us: AtomicU64::new(u64::MAX),
291 3 : start_time: Instant::now(),
292 3 : active_listeners_lock: Mutex::new(0),
293 3 : }
294 3 : }
295 :
296 15 : pub(crate) fn get_role_secret(
297 15 : &self,
298 15 : endpoint_id: &EndpointId,
299 15 : role_name: &RoleName,
300 15 : ) -> Option<Cached<&Self, Option<AuthSecret>>> {
301 15 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
302 15 : let role_name = RoleNameInt::get(role_name)?;
303 15 : let (valid_since, ignore_cache_since) = self.get_cache_times();
304 15 : let endpoint_info = self.cache.get(&endpoint_id)?;
305 7 : let (value, ignore_cache) =
306 15 : endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
307 7 : if !ignore_cache {
308 4 : let cached = Cached {
309 4 : token: Some((
310 4 : self,
311 4 : CachedLookupInfo::new_role_secret(endpoint_id, role_name),
312 4 : )),
313 4 : value,
314 4 : };
315 4 : return Some(cached);
316 3 : }
317 3 : Some(Cached::new_uncached(value))
318 15 : }
319 5 : pub(crate) fn get_allowed_ips(
320 5 : &self,
321 5 : endpoint_id: &EndpointId,
322 5 : ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
323 5 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
324 5 : let (valid_since, ignore_cache_since) = self.get_cache_times();
325 5 : let endpoint_info = self.cache.get(&endpoint_id)?;
326 5 : let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
327 5 : let (value, ignore_cache) = value?;
328 4 : if !ignore_cache {
329 1 : let cached = Cached {
330 1 : token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
331 1 : value,
332 1 : };
333 1 : return Some(cached);
334 3 : }
335 3 : Some(Cached::new_uncached(value))
336 5 : }
337 0 : pub(crate) fn get_allowed_vpc_endpoint_ids(
338 0 : &self,
339 0 : endpoint_id: &EndpointId,
340 0 : ) -> Option<Cached<&Self, Arc<Vec<String>>>> {
341 0 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
342 0 : let (valid_since, ignore_cache_since) = self.get_cache_times();
343 0 : let endpoint_info = self.cache.get(&endpoint_id)?;
344 0 : let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
345 0 : let (value, ignore_cache) = value?;
346 0 : if !ignore_cache {
347 0 : let cached = Cached {
348 0 : token: Some((
349 0 : self,
350 0 : CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
351 0 : )),
352 0 : value,
353 0 : };
354 0 : return Some(cached);
355 0 : }
356 0 : Some(Cached::new_uncached(value))
357 0 : }
358 0 : pub(crate) fn get_block_public_or_vpc_access(
359 0 : &self,
360 0 : endpoint_id: &EndpointId,
361 0 : ) -> Option<Cached<&Self, AccessBlockerFlags>> {
362 0 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
363 0 : let (valid_since, ignore_cache_since) = self.get_cache_times();
364 0 : let endpoint_info = self.cache.get(&endpoint_id)?;
365 0 : let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
366 0 : let (value, ignore_cache) = value?;
367 0 : if !ignore_cache {
368 0 : let cached = Cached {
369 0 : token: Some((
370 0 : self,
371 0 : CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
372 0 : )),
373 0 : value,
374 0 : };
375 0 : return Some(cached);
376 0 : }
377 0 : Some(Cached::new_uncached(value))
378 0 : }
379 :
380 7 : pub(crate) fn insert_role_secret(
381 7 : &self,
382 7 : project_id: ProjectIdInt,
383 7 : endpoint_id: EndpointIdInt,
384 7 : role_name: RoleNameInt,
385 7 : secret: Option<AuthSecret>,
386 7 : ) {
387 7 : if self.cache.len() >= self.config.size {
388 : // If there are too many entries, wait until the next gc cycle.
389 0 : return;
390 7 : }
391 7 : self.insert_project2endpoint(project_id, endpoint_id);
392 7 : let mut entry = self.cache.entry(endpoint_id).or_default();
393 7 : if entry.secret.len() < self.config.max_roles {
394 6 : entry.secret.insert(role_name, secret.into());
395 6 : }
396 7 : }
397 3 : pub(crate) fn insert_allowed_ips(
398 3 : &self,
399 3 : project_id: ProjectIdInt,
400 3 : endpoint_id: EndpointIdInt,
401 3 : allowed_ips: Arc<Vec<IpPattern>>,
402 3 : ) {
403 3 : if self.cache.len() >= self.config.size {
404 : // If there are too many entries, wait until the next gc cycle.
405 0 : return;
406 3 : }
407 3 : self.insert_project2endpoint(project_id, endpoint_id);
408 3 : self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
409 3 : }
410 0 : pub(crate) fn insert_allowed_vpc_endpoint_ids(
411 0 : &self,
412 0 : account_id: Option<AccountIdInt>,
413 0 : project_id: ProjectIdInt,
414 0 : endpoint_id: EndpointIdInt,
415 0 : allowed_vpc_endpoint_ids: Arc<Vec<String>>,
416 0 : ) {
417 0 : if self.cache.len() >= self.config.size {
418 : // If there are too many entries, wait until the next gc cycle.
419 0 : return;
420 0 : }
421 0 : if let Some(account_id) = account_id {
422 0 : self.insert_account2endpoint(account_id, endpoint_id);
423 0 : }
424 0 : self.insert_project2endpoint(project_id, endpoint_id);
425 0 : self.cache
426 0 : .entry(endpoint_id)
427 0 : .or_default()
428 0 : .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
429 0 : }
430 0 : pub(crate) fn insert_block_public_or_vpc_access(
431 0 : &self,
432 0 : project_id: ProjectIdInt,
433 0 : endpoint_id: EndpointIdInt,
434 0 : access_blockers: AccessBlockerFlags,
435 0 : ) {
436 0 : if self.cache.len() >= self.config.size {
437 : // If there are too many entries, wait until the next gc cycle.
438 0 : return;
439 0 : }
440 0 : self.insert_project2endpoint(project_id, endpoint_id);
441 0 : self.cache
442 0 : .entry(endpoint_id)
443 0 : .or_default()
444 0 : .block_public_or_vpc_access = Some(access_blockers.into());
445 0 : }
446 :
447 10 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
448 10 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
449 7 : endpoints.insert(endpoint_id);
450 7 : } else {
451 3 : self.project2ep
452 3 : .insert(project_id, HashSet::from([endpoint_id]));
453 3 : }
454 10 : }
455 0 : fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
456 0 : if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
457 0 : endpoints.insert(endpoint_id);
458 0 : } else {
459 0 : self.account2ep
460 0 : .insert(account_id, HashSet::from([endpoint_id]));
461 0 : }
462 0 : }
463 20 : fn get_cache_times(&self) -> (Instant, Option<Instant>) {
464 20 : let mut valid_since = Instant::now() - self.config.ttl;
465 20 : // Only ignore cache if ttl is disabled.
466 20 : let ttl_disabled_since_us = self
467 20 : .ttl_disabled_since_us
468 20 : .load(std::sync::atomic::Ordering::Relaxed);
469 20 : let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
470 7 : None
471 : } else {
472 13 : let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
473 13 : // We are fine if entry is not older than ttl or was added before we are getting notifications.
474 13 : valid_since = valid_since.min(ignore_cache_since);
475 13 : Some(ignore_cache_since)
476 : };
477 20 : (valid_since, ignore_cache_since)
478 20 : }
479 :
480 0 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
481 0 : let mut interval =
482 0 : tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
483 : loop {
484 0 : interval.tick().await;
485 0 : if self.cache.len() < self.config.size {
486 : // If there are not too many entries, wait until the next gc cycle.
487 0 : continue;
488 0 : }
489 0 : self.gc();
490 : }
491 : }
492 :
493 0 : fn gc(&self) {
494 0 : let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
495 0 : debug!(shard, "project_info_cache: performing epoch reclamation");
496 :
497 : // acquire a random shard lock
498 0 : let mut removed = 0;
499 0 : let shard = self.project2ep.shards()[shard].write();
500 0 : for (_, endpoints) in shard.iter() {
501 0 : for endpoint in endpoints {
502 0 : self.cache.remove(endpoint);
503 0 : removed += 1;
504 0 : }
505 : }
506 : // We can drop this shard only after making sure that all endpoints are removed.
507 0 : drop(shard);
508 0 : info!("project_info_cache: removed {removed} endpoints");
509 0 : }
510 : }
511 :
512 : /// Lookup info for project info cache.
513 : /// This is used to invalidate cache entries.
514 : pub(crate) struct CachedLookupInfo {
515 : /// Search by this key.
516 : endpoint_id: EndpointIdInt,
517 : lookup_type: LookupType,
518 : }
519 :
520 : impl CachedLookupInfo {
521 4 : pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
522 4 : Self {
523 4 : endpoint_id,
524 4 : lookup_type: LookupType::RoleSecret(role_name),
525 4 : }
526 4 : }
527 1 : pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
528 1 : Self {
529 1 : endpoint_id,
530 1 : lookup_type: LookupType::AllowedIps,
531 1 : }
532 1 : }
533 0 : pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
534 0 : Self {
535 0 : endpoint_id,
536 0 : lookup_type: LookupType::AllowedVpcEndpointIds,
537 0 : }
538 0 : }
539 0 : pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
540 0 : Self {
541 0 : endpoint_id,
542 0 : lookup_type: LookupType::BlockPublicOrVpcAccess,
543 0 : }
544 0 : }
545 : }
546 :
547 : enum LookupType {
548 : RoleSecret(RoleNameInt),
549 : AllowedIps,
550 : AllowedVpcEndpointIds,
551 : BlockPublicOrVpcAccess,
552 : }
553 :
554 : impl Cache for ProjectInfoCacheImpl {
555 : type Key = SmolStr;
556 : // Value is not really used here, but we need to specify it.
557 : type Value = SmolStr;
558 :
559 : type LookupInfo<Key> = CachedLookupInfo;
560 :
561 0 : fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
562 0 : match &key.lookup_type {
563 0 : LookupType::RoleSecret(role_name) => {
564 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
565 0 : endpoint_info.invalidate_role_secret(*role_name);
566 0 : }
567 : }
568 : LookupType::AllowedIps => {
569 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
570 0 : endpoint_info.invalidate_allowed_ips();
571 0 : }
572 : }
573 : LookupType::AllowedVpcEndpointIds => {
574 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
575 0 : endpoint_info.invalidate_allowed_vpc_endpoint_ids();
576 0 : }
577 : }
578 : LookupType::BlockPublicOrVpcAccess => {
579 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
580 0 : endpoint_info.invalidate_block_public_or_vpc_access();
581 0 : }
582 : }
583 : }
584 0 : }
585 : }
586 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::scram::ServerSecret;
    use crate::types::ProjectId;

    /// Cache limited to 2 endpoints and 2 roles per endpoint, with a 1 second TTL.
    fn small_cache() -> ProjectInfoCacheImpl {
        ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        })
    }

    /// Mock SCRAM secret whose 32 bytes are all `fill`.
    fn scram_secret(fill: u8) -> Option<AuthSecret> {
        Some(AuthSecret::Scram(ServerSecret::mock([fill; 32])))
    }

    /// The two loopback patterns used as the allowed-IP fixture.
    fn loopback_ips() -> Arc<Vec<IpPattern>> {
        Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ])
    }

    /// Stores `secret` for `role`, interning the ids on the way in.
    fn put_secret(
        cache: &ProjectInfoCacheImpl,
        project: &ProjectId,
        endpoint: &EndpointId,
        role: &RoleName,
        secret: Option<AuthSecret>,
    ) {
        cache.insert_role_secret(project.into(), endpoint.into(), role.into(), secret);
    }

    /// Stores `ips` for `endpoint`, interning the ids on the way in.
    fn put_ips(
        cache: &ProjectInfoCacheImpl,
        project: &ProjectId,
        endpoint: &EndpointId,
        ips: Arc<Vec<IpPattern>>,
    ) {
        cache.insert_allowed_ips(project.into(), endpoint.into(), ips);
    }

    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let cache = small_cache();
        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = scram_secret(1);
        let secret2: Option<AuthSecret> = None;
        let allowed_ips = loopback_ips();
        put_secret(&cache, &project_id, &endpoint_id, &user1, secret1.clone());
        put_secret(&cache, &project_id, &endpoint_id, &user2, secret2.clone());
        put_ips(&cache, &project_id, &endpoint_id, allowed_ips.clone());

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret1);
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret2);

        // Shouldn't add more than 2 roles.
        let user3: RoleName = "user3".into();
        let secret3 = scram_secret(3);
        put_secret(&cache, &project_id, &endpoint_id, &user3, secret3.clone());
        assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, allowed_ips);

        // Everything expires once the TTL has passed.
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        assert!(cache.get_allowed_ips(&endpoint_id).is_none());
    }

    #[tokio::test]
    async fn test_project_info_cache_invalidations() {
        tokio::time::pause();
        let cache = Arc::new(small_cache());
        cache.clone().increment_active_listeners().await;
        tokio::time::advance(Duration::from_secs(2)).await;

        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = scram_secret(1);
        let secret2 = scram_secret(2);
        let allowed_ips = loopback_ips();
        put_secret(&cache, &project_id, &endpoint_id, &user1, secret1.clone());
        put_secret(&cache, &project_id, &endpoint_id, &user2, secret2.clone());
        put_ips(&cache, &project_id, &endpoint_id, allowed_ips.clone());

        tokio::time::advance(Duration::from_secs(2)).await;
        // Nothing should be invalidated.

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        // TTL is disabled, so it should be impossible to invalidate this value.
        assert!(!cached.cached());
        assert_eq!(cached.value, secret1);

        cached.invalidate(); // Shouldn't do anything.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.value, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, secret2);

        // The only way to invalidate this value is to invalidate via the api.
        cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }

    #[tokio::test]
    async fn test_increment_active_listeners_invalidate_added_before() {
        tokio::time::pause();
        let cache = Arc::new(small_cache());

        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = scram_secret(1);
        let secret2 = scram_secret(2);
        let allowed_ips = loopback_ips();
        put_secret(&cache, &project_id, &endpoint_id, &user1, secret1.clone());
        cache.clone().increment_active_listeners().await;
        tokio::time::advance(Duration::from_millis(100)).await;
        put_secret(&cache, &project_id, &endpoint_id, &user2, secret2.clone());

        // Added before ttl was disabled + ttl should be still cached.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        // Added after ttl was disabled + ttl should not be cached.
        put_ips(&cache, &project_id, &endpoint_id, allowed_ips.clone());
        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl still should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        // Shouldn't be invalidated.

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }
}
|