Line data Source code
1 : use std::{
2 : collections::HashSet,
3 : convert::Infallible,
4 : sync::{atomic::AtomicU64, Arc},
5 : time::Duration,
6 : };
7 :
8 : use dashmap::DashMap;
9 : use rand::{thread_rng, Rng};
10 : use smol_str::SmolStr;
11 : use tokio::time::Instant;
12 : use tracing::{debug, info};
13 :
14 : use crate::{
15 : auth::IpPattern,
16 : config::ProjectInfoCacheOptions,
17 : console::AuthSecret,
18 : intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
19 : EndpointId, ProjectId, RoleName,
20 : };
21 :
22 : use super::{Cache, Cached};
23 :
24 : pub trait ProjectInfoCache {
25 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
26 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
27 : fn enable_ttl(&self);
28 : fn disable_ttl(&self);
29 : }
30 :
31 : struct Entry<T> {
32 : created_at: Instant,
33 : value: T,
34 : }
35 :
36 : impl<T> Entry<T> {
37 18 : pub fn new(value: T) -> Self {
38 18 : Self {
39 18 : created_at: Instant::now(),
40 18 : value,
41 18 : }
42 18 : }
43 : }
44 :
45 : impl<T> From<T> for Entry<T> {
46 18 : fn from(value: T) -> Self {
47 18 : Self::new(value)
48 18 : }
49 : }
50 :
51 6 : #[derive(Default)]
52 : struct EndpointInfo {
53 : secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
54 : allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
55 : }
56 :
57 : impl EndpointInfo {
58 22 : fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
59 22 : match ignore_cache_since {
60 6 : None => false,
61 16 : Some(t) => t < created_at,
62 : }
63 22 : }
64 30 : pub fn get_role_secret(
65 30 : &self,
66 30 : role_name: RoleNameInt,
67 30 : valid_since: Instant,
68 30 : ignore_cache_since: Option<Instant>,
69 30 : ) -> Option<(Option<AuthSecret>, bool)> {
70 30 : if let Some(secret) = self.secret.get(&role_name) {
71 26 : if valid_since < secret.created_at {
72 14 : return Some((
73 14 : secret.value.clone(),
74 14 : Self::check_ignore_cache(ignore_cache_since, secret.created_at),
75 14 : ));
76 12 : }
77 4 : }
78 16 : None
79 30 : }
80 :
81 10 : pub fn get_allowed_ips(
82 10 : &self,
83 10 : valid_since: Instant,
84 10 : ignore_cache_since: Option<Instant>,
85 10 : ) -> Option<(Arc<Vec<IpPattern>>, bool)> {
86 10 : if let Some(allowed_ips) = &self.allowed_ips {
87 10 : if valid_since < allowed_ips.created_at {
88 8 : return Some((
89 8 : allowed_ips.value.clone(),
90 8 : Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
91 8 : ));
92 2 : }
93 0 : }
94 2 : None
95 10 : }
96 0 : pub fn invalidate_allowed_ips(&mut self) {
97 0 : self.allowed_ips = None;
98 0 : }
99 2 : pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
100 2 : self.secret.remove(&role_name);
101 2 : }
102 : }
103 :
104 : /// Cache for project info.
105 : /// This is used to cache auth data for endpoints.
106 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
107 : ///
108 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
109 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
110 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
111 : pub struct ProjectInfoCacheImpl {
112 : cache: DashMap<EndpointIdInt, EndpointInfo>,
113 :
114 : project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
115 : config: ProjectInfoCacheOptions,
116 :
117 : start_time: Instant,
118 : ttl_disabled_since_us: AtomicU64,
119 : }
120 :
121 : impl ProjectInfoCache for ProjectInfoCacheImpl {
122 0 : fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
123 0 : info!("invalidating allowed ips for project `{}`", project_id);
124 0 : let endpoints = self
125 0 : .project2ep
126 0 : .get(&project_id)
127 0 : .map(|kv| kv.value().clone())
128 0 : .unwrap_or_default();
129 0 : for endpoint_id in endpoints {
130 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
131 0 : endpoint_info.invalidate_allowed_ips();
132 0 : }
133 : }
134 0 : }
135 2 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
136 2 : info!(
137 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
138 0 : project_id, role_name,
139 0 : );
140 2 : let endpoints = self
141 2 : .project2ep
142 2 : .get(&project_id)
143 2 : .map(|kv| kv.value().clone())
144 2 : .unwrap_or_default();
145 4 : for endpoint_id in endpoints {
146 2 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
147 2 : endpoint_info.invalidate_role_secret(role_name);
148 2 : }
149 : }
150 2 : }
151 0 : fn enable_ttl(&self) {
152 0 : self.ttl_disabled_since_us
153 0 : .store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
154 0 : }
155 :
156 4 : fn disable_ttl(&self) {
157 4 : let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
158 4 : self.ttl_disabled_since_us
159 4 : .store(new_ttl, std::sync::atomic::Ordering::Relaxed);
160 4 : }
161 : }
162 :
163 : impl ProjectInfoCacheImpl {
164 7 : pub fn new(config: ProjectInfoCacheOptions) -> Self {
165 7 : Self {
166 7 : cache: DashMap::new(),
167 7 : project2ep: DashMap::new(),
168 7 : config,
169 7 : ttl_disabled_since_us: AtomicU64::new(u64::MAX),
170 7 : start_time: Instant::now(),
171 7 : }
172 7 : }
173 :
174 30 : pub fn get_role_secret(
175 30 : &self,
176 30 : endpoint_id: &EndpointId,
177 30 : role_name: &RoleName,
178 30 : ) -> Option<Cached<&Self, Option<AuthSecret>>> {
179 30 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
180 30 : let role_name = RoleNameInt::get(role_name)?;
181 30 : let (valid_since, ignore_cache_since) = self.get_cache_times();
182 30 : let endpoint_info = self.cache.get(&endpoint_id)?;
183 14 : let (value, ignore_cache) =
184 30 : endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
185 14 : if !ignore_cache {
186 8 : let cached = Cached {
187 8 : token: Some((
188 8 : self,
189 8 : CachedLookupInfo::new_role_secret(endpoint_id, role_name),
190 8 : )),
191 8 : value,
192 8 : };
193 8 : return Some(cached);
194 6 : }
195 6 : Some(Cached::new_uncached(value))
196 30 : }
197 14 : pub fn get_allowed_ips(
198 14 : &self,
199 14 : endpoint_id: &EndpointId,
200 14 : ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
201 14 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
202 10 : let (valid_since, ignore_cache_since) = self.get_cache_times();
203 10 : let endpoint_info = self.cache.get(&endpoint_id)?;
204 10 : let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
205 10 : let (value, ignore_cache) = value?;
206 8 : if !ignore_cache {
207 2 : let cached = Cached {
208 2 : token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
209 2 : value,
210 2 : };
211 2 : return Some(cached);
212 6 : }
213 6 : Some(Cached::new_uncached(value))
214 14 : }
215 14 : pub fn insert_role_secret(
216 14 : &self,
217 14 : project_id: &ProjectId,
218 14 : endpoint_id: &EndpointId,
219 14 : role_name: &RoleName,
220 14 : secret: Option<AuthSecret>,
221 14 : ) {
222 14 : let project_id = ProjectIdInt::from(project_id);
223 14 : let endpoint_id = EndpointIdInt::from(endpoint_id);
224 14 : let role_name = RoleNameInt::from(role_name);
225 14 : if self.cache.len() >= self.config.size {
226 : // If there are too many entries, wait until the next gc cycle.
227 0 : return;
228 14 : }
229 14 : self.insert_project2endpoint(project_id, endpoint_id);
230 14 : let mut entry = self.cache.entry(endpoint_id).or_default();
231 14 : if entry.secret.len() < self.config.max_roles {
232 12 : entry.secret.insert(role_name, secret.into());
233 12 : }
234 14 : }
235 6 : pub fn insert_allowed_ips(
236 6 : &self,
237 6 : project_id: &ProjectId,
238 6 : endpoint_id: &EndpointId,
239 6 : allowed_ips: Arc<Vec<IpPattern>>,
240 6 : ) {
241 6 : let project_id = ProjectIdInt::from(project_id);
242 6 : let endpoint_id = EndpointIdInt::from(endpoint_id);
243 6 : if self.cache.len() >= self.config.size {
244 : // If there are too many entries, wait until the next gc cycle.
245 0 : return;
246 6 : }
247 6 : self.insert_project2endpoint(project_id, endpoint_id);
248 6 : self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
249 6 : }
250 20 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
251 20 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
252 14 : endpoints.insert(endpoint_id);
253 14 : } else {
254 6 : self.project2ep
255 6 : .insert(project_id, HashSet::from([endpoint_id]));
256 6 : }
257 20 : }
258 40 : fn get_cache_times(&self) -> (Instant, Option<Instant>) {
259 40 : let mut valid_since = Instant::now() - self.config.ttl;
260 40 : // Only ignore cache if ttl is disabled.
261 40 : let ttl_disabled_since_us = self
262 40 : .ttl_disabled_since_us
263 40 : .load(std::sync::atomic::Ordering::Relaxed);
264 40 : let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
265 26 : let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
266 26 : // We are fine if entry is not older than ttl or was added before we are getting notifications.
267 26 : valid_since = valid_since.min(ignore_cache_since);
268 26 : Some(ignore_cache_since)
269 : } else {
270 14 : None
271 : };
272 40 : (valid_since, ignore_cache_since)
273 40 : }
274 :
275 1 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
276 1 : let mut interval =
277 1 : tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
278 : loop {
279 2 : interval.tick().await;
280 1 : if self.cache.len() < self.config.size {
281 : // If there are not too many entries, wait until the next gc cycle.
282 1 : continue;
283 0 : }
284 0 : self.gc();
285 : }
286 : }
287 :
288 0 : fn gc(&self) {
289 0 : let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
290 0 : debug!(shard, "project_info_cache: performing epoch reclamation");
291 :
292 : // acquire a random shard lock
293 0 : let mut removed = 0;
294 0 : let shard = self.project2ep.shards()[shard].write();
295 0 : for (_, endpoints) in shard.iter() {
296 0 : for endpoint in endpoints.get().iter() {
297 0 : self.cache.remove(endpoint);
298 0 : removed += 1;
299 0 : }
300 : }
301 : // We can drop this shard only after making sure that all endpoints are removed.
302 0 : drop(shard);
303 0 : info!("project_info_cache: removed {removed} endpoints");
304 0 : }
305 : }
306 :
307 : /// Lookup info for project info cache.
308 : /// This is used to invalidate cache entries.
309 : pub struct CachedLookupInfo {
310 : /// Search by this key.
311 : endpoint_id: EndpointIdInt,
312 : lookup_type: LookupType,
313 : }
314 :
315 : impl CachedLookupInfo {
316 8 : pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
317 8 : Self {
318 8 : endpoint_id,
319 8 : lookup_type: LookupType::RoleSecret(role_name),
320 8 : }
321 8 : }
322 2 : pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
323 2 : Self {
324 2 : endpoint_id,
325 2 : lookup_type: LookupType::AllowedIps,
326 2 : }
327 2 : }
328 : }
329 :
330 : enum LookupType {
331 : RoleSecret(RoleNameInt),
332 : AllowedIps,
333 : }
334 :
335 : impl Cache for ProjectInfoCacheImpl {
336 : type Key = SmolStr;
337 : // Value is not really used here, but we need to specify it.
338 : type Value = SmolStr;
339 :
340 : type LookupInfo<Key> = CachedLookupInfo;
341 :
342 0 : fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
343 0 : match &key.lookup_type {
344 0 : LookupType::RoleSecret(role_name) => {
345 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
346 0 : endpoint_info.invalidate_role_secret(*role_name);
347 0 : }
348 : }
349 : LookupType::AllowedIps => {
350 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
351 0 : endpoint_info.invalidate_allowed_ips();
352 0 : }
353 : }
354 : }
355 0 : }
356 : }
357 :
358 : #[cfg(test)]
359 : mod tests {
360 : use super::*;
361 : use crate::{console::AuthSecret, scram::ServerSecret};
362 : use std::{sync::Arc, time::Duration};
363 :
364 2 : #[tokio::test]
365 2 : async fn test_project_info_cache_settings() {
366 2 : tokio::time::pause();
367 2 : let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
368 2 : size: 2,
369 2 : max_roles: 2,
370 2 : ttl: Duration::from_secs(1),
371 2 : gc_interval: Duration::from_secs(600),
372 2 : });
373 2 : let project_id = "project".into();
374 2 : let endpoint_id = "endpoint".into();
375 2 : let user1: RoleName = "user1".into();
376 2 : let user2: RoleName = "user2".into();
377 2 : let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
378 2 : user1.as_str(),
379 2 : [1; 32],
380 2 : )));
381 2 : let secret2 = None;
382 2 : let allowed_ips = Arc::new(vec![
383 2 : "127.0.0.1".parse().unwrap(),
384 2 : "127.0.0.2".parse().unwrap(),
385 2 : ]);
386 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
387 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
388 2 : cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
389 2 :
390 2 : let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
391 2 : assert!(cached.cached());
392 2 : assert_eq!(cached.value, secret1);
393 2 : let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
394 2 : assert!(cached.cached());
395 2 : assert_eq!(cached.value, secret2);
396 :
397 : // Shouldn't add more than 2 roles.
398 2 : let user3: RoleName = "user3".into();
399 2 : let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
400 2 : user3.as_str(),
401 2 : [3; 32],
402 2 : )));
403 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
404 2 : assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
405 :
406 2 : let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
407 2 : assert!(cached.cached());
408 2 : assert_eq!(cached.value, allowed_ips);
409 :
410 2 : tokio::time::advance(Duration::from_secs(2)).await;
411 2 : let cached = cache.get_role_secret(&endpoint_id, &user1);
412 2 : assert!(cached.is_none());
413 2 : let cached = cache.get_role_secret(&endpoint_id, &user2);
414 2 : assert!(cached.is_none());
415 2 : let cached = cache.get_allowed_ips(&endpoint_id);
416 2 : assert!(cached.is_none());
417 : }
418 :
419 2 : #[tokio::test]
420 2 : async fn test_project_info_cache_invalidations() {
421 2 : tokio::time::pause();
422 2 : let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
423 2 : size: 2,
424 2 : max_roles: 2,
425 2 : ttl: Duration::from_secs(1),
426 2 : gc_interval: Duration::from_secs(600),
427 2 : }));
428 2 : cache.clone().disable_ttl();
429 2 : tokio::time::advance(Duration::from_secs(2)).await;
430 :
431 2 : let project_id = "project".into();
432 2 : let endpoint_id = "endpoint".into();
433 2 : let user1: RoleName = "user1".into();
434 2 : let user2: RoleName = "user2".into();
435 2 : let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
436 2 : user1.as_str(),
437 2 : [1; 32],
438 2 : )));
439 2 : let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
440 2 : user2.as_str(),
441 2 : [2; 32],
442 2 : )));
443 2 : let allowed_ips = Arc::new(vec![
444 2 : "127.0.0.1".parse().unwrap(),
445 2 : "127.0.0.2".parse().unwrap(),
446 2 : ]);
447 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
448 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
449 2 : cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
450 2 :
451 2 : tokio::time::advance(Duration::from_secs(2)).await;
452 : // Nothing should be invalidated.
453 :
454 2 : let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
455 : // TTL is disabled, so it should be impossible to invalidate this value.
456 2 : assert!(!cached.cached());
457 2 : assert_eq!(cached.value, secret1);
458 :
459 2 : cached.invalidate(); // Shouldn't do anything.
460 2 : let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
461 2 : assert_eq!(cached.value, secret1);
462 :
463 2 : let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
464 2 : assert!(!cached.cached());
465 2 : assert_eq!(cached.value, secret2);
466 :
467 : // The only way to invalidate this value is to invalidate via the api.
468 2 : cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
469 2 : assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
470 :
471 2 : let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
472 2 : assert!(!cached.cached());
473 2 : assert_eq!(cached.value, allowed_ips);
474 : }
475 :
476 2 : #[tokio::test]
477 2 : async fn test_disable_ttl_invalidate_added_before() {
478 2 : tokio::time::pause();
479 2 : let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
480 2 : size: 2,
481 2 : max_roles: 2,
482 2 : ttl: Duration::from_secs(1),
483 2 : gc_interval: Duration::from_secs(600),
484 2 : }));
485 2 :
486 2 : let project_id = "project".into();
487 2 : let endpoint_id = "endpoint".into();
488 2 : let user1: RoleName = "user1".into();
489 2 : let user2: RoleName = "user2".into();
490 2 : let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
491 2 : user1.as_str(),
492 2 : [1; 32],
493 2 : )));
494 2 : let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
495 2 : user2.as_str(),
496 2 : [2; 32],
497 2 : )));
498 2 : let allowed_ips = Arc::new(vec![
499 2 : "127.0.0.1".parse().unwrap(),
500 2 : "127.0.0.2".parse().unwrap(),
501 2 : ]);
502 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
503 2 : cache.clone().disable_ttl();
504 2 : tokio::time::advance(Duration::from_millis(100)).await;
505 2 : cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
506 2 :
507 2 : // Added before ttl was disabled + ttl should be still cached.
508 2 : let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
509 2 : assert!(cached.cached());
510 2 : let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
511 2 : assert!(cached.cached());
512 :
513 2 : tokio::time::advance(Duration::from_secs(1)).await;
514 : // Added before ttl was disabled + ttl should expire.
515 2 : assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
516 2 : assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
517 :
518 : // Added after ttl was disabled + ttl should not be cached.
519 2 : cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
520 2 : let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
521 2 : assert!(!cached.cached());
522 :
523 2 : tokio::time::advance(Duration::from_secs(1)).await;
524 : // Added before ttl was disabled + ttl still should expire.
525 2 : assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
526 2 : assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
527 : // Shouldn't be invalidated.
528 :
529 2 : let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
530 2 : assert!(!cached.cached());
531 2 : assert_eq!(cached.value, allowed_ips);
532 : }
533 : }
|