Line data Source code
1 : use std::collections::{HashMap, HashSet, hash_map};
2 : use std::convert::Infallible;
3 : use std::sync::atomic::AtomicU64;
4 : use std::time::Duration;
5 :
6 : use async_trait::async_trait;
7 : use clashmap::ClashMap;
8 : use clashmap::mapref::one::Ref;
9 : use rand::{Rng, thread_rng};
10 : use tokio::sync::Mutex;
11 : use tokio::time::Instant;
12 : use tracing::{debug, info};
13 :
14 : use crate::config::ProjectInfoCacheOptions;
15 : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
16 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
17 : use crate::types::{EndpointId, RoleName};
18 :
19 : #[async_trait]
20 : pub(crate) trait ProjectInfoCache {
21 : fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
22 : fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
23 : fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
24 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
25 : async fn decrement_active_listeners(&self);
26 : async fn increment_active_listeners(&self);
27 : }
28 :
/// A cached value stamped with the moment it was stored.
struct Entry<T> {
    /// When this entry was constructed.
    created_at: Instant,
    /// The cached payload.
    value: T,
}

impl<T> Entry<T> {
    /// Wraps `value`, stamping it with the current time.
    pub(crate) fn new(value: T) -> Self {
        let created_at = Instant::now();
        Self { created_at, value }
    }

    /// Returns the payload only if the entry was created strictly after `valid_since`.
    pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
        if valid_since < self.created_at {
            Some(&self.value)
        } else {
            None
        }
    }
}

impl<T> From<T> for Entry<T> {
    /// Equivalent to [`Entry::new`].
    fn from(value: T) -> Self {
        Entry::new(value)
    }
}
52 :
53 : struct EndpointInfo {
54 : role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
55 : controls: Option<Entry<EndpointAccessControl>>,
56 : }
57 :
58 : impl EndpointInfo {
59 5 : pub(crate) fn get_role_secret(
60 5 : &self,
61 5 : role_name: RoleNameInt,
62 5 : valid_since: Instant,
63 5 : ) -> Option<RoleAccessControl> {
64 5 : let controls = self.role_controls.get(&role_name)?;
65 4 : controls.get(valid_since).cloned()
66 5 : }
67 :
68 2 : pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
69 2 : let controls = self.controls.as_ref()?;
70 2 : controls.get(valid_since).cloned()
71 2 : }
72 :
73 0 : pub(crate) fn invalidate_endpoint(&mut self) {
74 0 : self.controls = None;
75 0 : }
76 :
77 0 : pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
78 0 : self.role_controls.remove(&role_name);
79 0 : }
80 : }
81 :
82 : /// Cache for project info.
83 : /// This is used to cache auth data for endpoints.
84 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
85 : ///
86 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
87 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
88 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
89 : pub struct ProjectInfoCacheImpl {
90 : cache: ClashMap<EndpointIdInt, EndpointInfo>,
91 :
92 : project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
93 : // FIXME(stefan): we need a way to GC the account2ep map.
94 : account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
95 : config: ProjectInfoCacheOptions,
96 :
97 : start_time: Instant,
98 : ttl_disabled_since_us: AtomicU64,
99 : active_listeners_lock: Mutex<usize>,
100 : }
101 :
102 : #[async_trait]
103 : impl ProjectInfoCache for ProjectInfoCacheImpl {
104 0 : fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
105 0 : info!("invalidating endpoint access for `{endpoint_id}`");
106 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
107 0 : endpoint_info.invalidate_endpoint();
108 0 : }
109 0 : }
110 :
111 0 : fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
112 0 : info!("invalidating endpoint access for project `{project_id}`");
113 0 : let endpoints = self
114 0 : .project2ep
115 0 : .get(&project_id)
116 0 : .map(|kv| kv.value().clone())
117 0 : .unwrap_or_default();
118 0 : for endpoint_id in endpoints {
119 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
120 0 : endpoint_info.invalidate_endpoint();
121 0 : }
122 : }
123 0 : }
124 :
125 0 : fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
126 0 : info!("invalidating endpoint access for org `{account_id}`");
127 0 : let endpoints = self
128 0 : .account2ep
129 0 : .get(&account_id)
130 0 : .map(|kv| kv.value().clone())
131 0 : .unwrap_or_default();
132 0 : for endpoint_id in endpoints {
133 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
134 0 : endpoint_info.invalidate_endpoint();
135 0 : }
136 : }
137 0 : }
138 :
139 0 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
140 0 : info!(
141 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
142 : project_id, role_name,
143 : );
144 0 : let endpoints = self
145 0 : .project2ep
146 0 : .get(&project_id)
147 0 : .map(|kv| kv.value().clone())
148 0 : .unwrap_or_default();
149 0 : for endpoint_id in endpoints {
150 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
151 0 : endpoint_info.invalidate_role_secret(role_name);
152 0 : }
153 : }
154 0 : }
155 :
156 0 : async fn decrement_active_listeners(&self) {
157 0 : let mut listeners_guard = self.active_listeners_lock.lock().await;
158 0 : if *listeners_guard == 0 {
159 0 : tracing::error!("active_listeners count is already 0, something is broken");
160 0 : return;
161 0 : }
162 0 : *listeners_guard -= 1;
163 0 : if *listeners_guard == 0 {
164 0 : self.ttl_disabled_since_us
165 0 : .store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
166 0 : }
167 0 : }
168 :
169 0 : async fn increment_active_listeners(&self) {
170 0 : let mut listeners_guard = self.active_listeners_lock.lock().await;
171 0 : *listeners_guard += 1;
172 0 : if *listeners_guard == 1 {
173 0 : let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
174 0 : self.ttl_disabled_since_us
175 0 : .store(new_ttl, std::sync::atomic::Ordering::SeqCst);
176 0 : }
177 0 : }
178 : }
179 :
180 : impl ProjectInfoCacheImpl {
181 1 : pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
182 1 : Self {
183 1 : cache: ClashMap::new(),
184 1 : project2ep: ClashMap::new(),
185 1 : account2ep: ClashMap::new(),
186 1 : config,
187 1 : ttl_disabled_since_us: AtomicU64::new(u64::MAX),
188 1 : start_time: Instant::now(),
189 1 : active_listeners_lock: Mutex::new(0),
190 1 : }
191 1 : }
192 :
193 7 : fn get_endpoint_cache(
194 7 : &self,
195 7 : endpoint_id: &EndpointId,
196 7 : ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
197 7 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
198 7 : self.cache.get(&endpoint_id)
199 7 : }
200 :
201 5 : pub(crate) fn get_role_secret(
202 5 : &self,
203 5 : endpoint_id: &EndpointId,
204 5 : role_name: &RoleName,
205 5 : ) -> Option<RoleAccessControl> {
206 5 : let valid_since = self.get_cache_times();
207 5 : let role_name = RoleNameInt::get(role_name)?;
208 5 : let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
209 5 : endpoint_info.get_role_secret(role_name, valid_since)
210 5 : }
211 :
212 2 : pub(crate) fn get_endpoint_access(
213 2 : &self,
214 2 : endpoint_id: &EndpointId,
215 2 : ) -> Option<EndpointAccessControl> {
216 2 : let valid_since = self.get_cache_times();
217 2 : let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
218 2 : endpoint_info.get_controls(valid_since)
219 2 : }
220 :
221 3 : pub(crate) fn insert_endpoint_access(
222 3 : &self,
223 3 : account_id: Option<AccountIdInt>,
224 3 : project_id: ProjectIdInt,
225 3 : endpoint_id: EndpointIdInt,
226 3 : role_name: RoleNameInt,
227 3 : controls: EndpointAccessControl,
228 3 : role_controls: RoleAccessControl,
229 3 : ) {
230 3 : if let Some(account_id) = account_id {
231 0 : self.insert_account2endpoint(account_id, endpoint_id);
232 3 : }
233 3 : self.insert_project2endpoint(project_id, endpoint_id);
234 :
235 3 : if self.cache.len() >= self.config.size {
236 : // If there are too many entries, wait until the next gc cycle.
237 0 : return;
238 3 : }
239 :
240 3 : let controls = Entry::from(controls);
241 3 : let role_controls = Entry::from(role_controls);
242 :
243 3 : match self.cache.entry(endpoint_id) {
244 1 : clashmap::Entry::Vacant(e) => {
245 1 : e.insert(EndpointInfo {
246 1 : role_controls: HashMap::from_iter([(role_name, role_controls)]),
247 1 : controls: Some(controls),
248 1 : });
249 1 : }
250 2 : clashmap::Entry::Occupied(mut e) => {
251 2 : let ep = e.get_mut();
252 2 : ep.controls = Some(controls);
253 2 : if ep.role_controls.len() < self.config.max_roles {
254 1 : ep.role_controls.insert(role_name, role_controls);
255 1 : }
256 : }
257 : }
258 3 : }
259 :
260 3 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
261 3 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
262 2 : endpoints.insert(endpoint_id);
263 2 : } else {
264 1 : self.project2ep
265 1 : .insert(project_id, HashSet::from([endpoint_id]));
266 1 : }
267 3 : }
268 :
269 0 : fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
270 0 : if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
271 0 : endpoints.insert(endpoint_id);
272 0 : } else {
273 0 : self.account2ep
274 0 : .insert(account_id, HashSet::from([endpoint_id]));
275 0 : }
276 0 : }
277 :
278 7 : fn ignore_ttl_since(&self) -> Option<Instant> {
279 7 : let ttl_disabled_since_us = self
280 7 : .ttl_disabled_since_us
281 7 : .load(std::sync::atomic::Ordering::Relaxed);
282 :
283 7 : if ttl_disabled_since_us == u64::MAX {
284 7 : return None;
285 0 : }
286 :
287 0 : Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
288 7 : }
289 :
290 7 : fn get_cache_times(&self) -> Instant {
291 7 : let mut valid_since = Instant::now() - self.config.ttl;
292 7 : if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
293 0 : // We are fine if entry is not older than ttl or was added before we are getting notifications.
294 0 : valid_since = valid_since.min(ignore_ttl_since);
295 7 : }
296 7 : valid_since
297 7 : }
298 :
299 0 : pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
300 0 : let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
301 0 : return;
302 : };
303 0 : let Some(role_name) = RoleNameInt::get(role_name) else {
304 0 : return;
305 : };
306 :
307 0 : let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
308 0 : return;
309 : };
310 :
311 0 : let entry = endpoint_info.role_controls.entry(role_name);
312 0 : let hash_map::Entry::Occupied(role_controls) = entry else {
313 0 : return;
314 : };
315 :
316 0 : let created_at = role_controls.get().created_at;
317 0 : let expire = match self.ignore_ttl_since() {
318 : // if ignoring TTL, we should still try and roll the password if it's old
319 : // and we the client gave an incorrect password. There could be some lag on the redis channel.
320 0 : Some(_) => created_at + self.config.ttl < Instant::now(),
321 : // edge case: redis is down, let's be generous and invalidate the cache immediately.
322 0 : None => true,
323 : };
324 :
325 0 : if expire {
326 0 : role_controls.remove();
327 0 : }
328 0 : }
329 :
330 0 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
331 0 : let mut interval =
332 0 : tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
333 : loop {
334 0 : interval.tick().await;
335 0 : if self.cache.len() < self.config.size {
336 : // If there are not too many entries, wait until the next gc cycle.
337 0 : continue;
338 0 : }
339 0 : self.gc();
340 : }
341 : }
342 :
343 0 : fn gc(&self) {
344 0 : let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
345 0 : debug!(shard, "project_info_cache: performing epoch reclamation");
346 :
347 : // acquire a random shard lock
348 0 : let mut removed = 0;
349 0 : let shard = self.project2ep.shards()[shard].write();
350 0 : for (_, endpoints) in shard.iter() {
351 0 : for endpoint in endpoints {
352 0 : self.cache.remove(endpoint);
353 0 : removed += 1;
354 0 : }
355 : }
356 : // We can drop this shard only after making sure that all endpoints are removed.
357 0 : drop(shard);
358 0 : info!("project_info_cache: removed {removed} endpoints");
359 0 : }
360 : }
361 :
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::control_plane::messages::EndpointRateLimitConfig;
    use crate::control_plane::{AccessBlockerFlags, AuthSecret};
    use crate::scram::ServerSecret;
    use crate::types::ProjectId;

    /// Exercises the `max_roles` limit and TTL expiry with tokio's clock paused.
    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let options = ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        };
        let cache = ProjectInfoCacheImpl::new(options);
        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let account_id: Option<AccountIdInt> = None;

        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        let secret2: Option<AuthSecret> = None;
        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);

        // Every insertion in this test differs only in the role and its secret.
        let insert_role = |role: &RoleName, secret: Option<AuthSecret>| {
            cache.insert_endpoint_access(
                account_id,
                (&project_id).into(),
                (&endpoint_id).into(),
                role.into(),
                EndpointAccessControl {
                    allowed_ips: allowed_ips.clone(),
                    allowed_vpce: Arc::new(vec![]),
                    flags: AccessBlockerFlags::default(),
                    rate_limits: EndpointRateLimitConfig::default(),
                },
                RoleAccessControl { secret },
            );
        };

        insert_role(&user1, secret1.clone());
        insert_role(&user2, secret2.clone());

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.secret, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert_eq!(cached.secret, secret2);

        // Shouldn't add more than 2 roles.
        let user3: RoleName = "user3".into();
        let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));

        insert_role(&user3, secret3.clone());

        assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

        let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
        assert_eq!(cached.allowed_ips, allowed_ips);

        // Move past the 1s TTL: everything must read as expired.
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        assert!(cache.get_endpoint_access(&endpoint_id).is_none());
    }
}
|