Line data Source code
1 : use std::collections::{HashMap, HashSet, hash_map};
2 : use std::convert::Infallible;
3 : use std::time::Duration;
4 :
5 : use async_trait::async_trait;
6 : use clashmap::ClashMap;
7 : use clashmap::mapref::one::Ref;
8 : use rand::{Rng, thread_rng};
9 : use tokio::time::Instant;
10 : use tracing::{debug, info};
11 :
12 : use crate::config::ProjectInfoCacheOptions;
13 : use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
14 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
15 : use crate::types::{EndpointId, RoleName};
16 :
17 : #[async_trait]
18 : pub(crate) trait ProjectInfoCache {
19 : fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
20 : fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
21 : fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
22 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
23 : }
24 :
25 : struct Entry<T> {
26 : expires_at: Instant,
27 : value: T,
28 : }
29 :
30 : impl<T> Entry<T> {
31 6 : pub(crate) fn new(value: T, ttl: Duration) -> Self {
32 6 : Self {
33 6 : expires_at: Instant::now() + ttl,
34 6 : value,
35 6 : }
36 6 : }
37 :
38 6 : pub(crate) fn get(&self) -> Option<&T> {
39 6 : (self.expires_at > Instant::now()).then_some(&self.value)
40 6 : }
41 : }
42 :
43 : struct EndpointInfo {
44 : role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
45 : controls: Option<Entry<EndpointAccessControl>>,
46 : }
47 :
48 : impl EndpointInfo {
49 5 : pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
50 5 : self.role_controls.get(&role_name)?.get().cloned()
51 5 : }
52 :
53 2 : pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
54 2 : self.controls.as_ref()?.get().cloned()
55 2 : }
56 :
57 0 : pub(crate) fn invalidate_endpoint(&mut self) {
58 0 : self.controls = None;
59 0 : }
60 :
61 0 : pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
62 0 : self.role_controls.remove(&role_name);
63 0 : }
64 : }
65 :
66 : /// Cache for project info.
67 : /// This is used to cache auth data for endpoints.
68 : /// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
69 : ///
70 : /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
71 : /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
72 : /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
73 : pub struct ProjectInfoCacheImpl {
74 : cache: ClashMap<EndpointIdInt, EndpointInfo>,
75 :
76 : project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
77 : // FIXME(stefan): we need a way to GC the account2ep map.
78 : account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
79 :
80 : config: ProjectInfoCacheOptions,
81 : }
82 :
83 : #[async_trait]
84 : impl ProjectInfoCache for ProjectInfoCacheImpl {
85 0 : fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
86 0 : info!("invalidating endpoint access for `{endpoint_id}`");
87 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
88 0 : endpoint_info.invalidate_endpoint();
89 0 : }
90 0 : }
91 :
92 0 : fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
93 0 : info!("invalidating endpoint access for project `{project_id}`");
94 0 : let endpoints = self
95 0 : .project2ep
96 0 : .get(&project_id)
97 0 : .map(|kv| kv.value().clone())
98 0 : .unwrap_or_default();
99 0 : for endpoint_id in endpoints {
100 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
101 0 : endpoint_info.invalidate_endpoint();
102 0 : }
103 : }
104 0 : }
105 :
106 0 : fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
107 0 : info!("invalidating endpoint access for org `{account_id}`");
108 0 : let endpoints = self
109 0 : .account2ep
110 0 : .get(&account_id)
111 0 : .map(|kv| kv.value().clone())
112 0 : .unwrap_or_default();
113 0 : for endpoint_id in endpoints {
114 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
115 0 : endpoint_info.invalidate_endpoint();
116 0 : }
117 : }
118 0 : }
119 :
120 0 : fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
121 0 : info!(
122 0 : "invalidating role secret for project_id `{}` and role_name `{}`",
123 : project_id, role_name,
124 : );
125 0 : let endpoints = self
126 0 : .project2ep
127 0 : .get(&project_id)
128 0 : .map(|kv| kv.value().clone())
129 0 : .unwrap_or_default();
130 0 : for endpoint_id in endpoints {
131 0 : if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
132 0 : endpoint_info.invalidate_role_secret(role_name);
133 0 : }
134 : }
135 0 : }
136 : }
137 :
138 : impl ProjectInfoCacheImpl {
139 1 : pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
140 1 : Self {
141 1 : cache: ClashMap::new(),
142 1 : project2ep: ClashMap::new(),
143 1 : account2ep: ClashMap::new(),
144 1 : config,
145 1 : }
146 1 : }
147 :
148 7 : fn get_endpoint_cache(
149 7 : &self,
150 7 : endpoint_id: &EndpointId,
151 7 : ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
152 7 : let endpoint_id = EndpointIdInt::get(endpoint_id)?;
153 7 : self.cache.get(&endpoint_id)
154 7 : }
155 :
156 5 : pub(crate) fn get_role_secret(
157 5 : &self,
158 5 : endpoint_id: &EndpointId,
159 5 : role_name: &RoleName,
160 5 : ) -> Option<RoleAccessControl> {
161 5 : let role_name = RoleNameInt::get(role_name)?;
162 5 : let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
163 5 : endpoint_info.get_role_secret(role_name)
164 5 : }
165 :
166 2 : pub(crate) fn get_endpoint_access(
167 2 : &self,
168 2 : endpoint_id: &EndpointId,
169 2 : ) -> Option<EndpointAccessControl> {
170 2 : let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
171 2 : endpoint_info.get_controls()
172 2 : }
173 :
174 3 : pub(crate) fn insert_endpoint_access(
175 3 : &self,
176 3 : account_id: Option<AccountIdInt>,
177 3 : project_id: ProjectIdInt,
178 3 : endpoint_id: EndpointIdInt,
179 3 : role_name: RoleNameInt,
180 3 : controls: EndpointAccessControl,
181 3 : role_controls: RoleAccessControl,
182 3 : ) {
183 3 : if let Some(account_id) = account_id {
184 0 : self.insert_account2endpoint(account_id, endpoint_id);
185 3 : }
186 3 : self.insert_project2endpoint(project_id, endpoint_id);
187 :
188 3 : if self.cache.len() >= self.config.size {
189 : // If there are too many entries, wait until the next gc cycle.
190 0 : return;
191 3 : }
192 :
193 3 : let controls = Entry::new(controls, self.config.ttl);
194 3 : let role_controls = Entry::new(role_controls, self.config.ttl);
195 :
196 3 : match self.cache.entry(endpoint_id) {
197 1 : clashmap::Entry::Vacant(e) => {
198 1 : e.insert(EndpointInfo {
199 1 : role_controls: HashMap::from_iter([(role_name, role_controls)]),
200 1 : controls: Some(controls),
201 1 : });
202 1 : }
203 2 : clashmap::Entry::Occupied(mut e) => {
204 2 : let ep = e.get_mut();
205 2 : ep.controls = Some(controls);
206 2 : if ep.role_controls.len() < self.config.max_roles {
207 1 : ep.role_controls.insert(role_name, role_controls);
208 1 : }
209 : }
210 : }
211 3 : }
212 :
213 3 : fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
214 3 : if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
215 2 : endpoints.insert(endpoint_id);
216 2 : } else {
217 1 : self.project2ep
218 1 : .insert(project_id, HashSet::from([endpoint_id]));
219 1 : }
220 3 : }
221 :
222 0 : fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
223 0 : if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
224 0 : endpoints.insert(endpoint_id);
225 0 : } else {
226 0 : self.account2ep
227 0 : .insert(account_id, HashSet::from([endpoint_id]));
228 0 : }
229 0 : }
230 :
231 0 : pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
232 0 : let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
233 0 : return;
234 : };
235 0 : let Some(role_name) = RoleNameInt::get(role_name) else {
236 0 : return;
237 : };
238 :
239 0 : let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
240 0 : return;
241 : };
242 :
243 0 : let entry = endpoint_info.role_controls.entry(role_name);
244 0 : let hash_map::Entry::Occupied(role_controls) = entry else {
245 0 : return;
246 : };
247 :
248 0 : if role_controls.get().expires_at <= Instant::now() {
249 0 : role_controls.remove();
250 0 : }
251 0 : }
252 :
253 0 : pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
254 0 : let mut interval =
255 0 : tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
256 : loop {
257 0 : interval.tick().await;
258 0 : if self.cache.len() < self.config.size {
259 : // If there are not too many entries, wait until the next gc cycle.
260 0 : continue;
261 0 : }
262 0 : self.gc();
263 : }
264 : }
265 :
266 0 : fn gc(&self) {
267 0 : let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
268 0 : debug!(shard, "project_info_cache: performing epoch reclamation");
269 :
270 : // acquire a random shard lock
271 0 : let mut removed = 0;
272 0 : let shard = self.project2ep.shards()[shard].write();
273 0 : for (_, endpoints) in shard.iter() {
274 0 : for endpoint in endpoints {
275 0 : self.cache.remove(endpoint);
276 0 : removed += 1;
277 0 : }
278 : }
279 : // We can drop this shard only after making sure that all endpoints are removed.
280 0 : drop(shard);
281 0 : info!("project_info_cache: removed {removed} endpoints");
282 0 : }
283 : }
284 :
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::control_plane::messages::EndpointRateLimitConfig;
    use crate::control_plane::{AccessBlockerFlags, AuthSecret};
    use crate::scram::ServerSecret;
    use crate::types::ProjectId;

    /// Checks the `max_roles` limit and the TTL expiry of role and endpoint entries.
    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        });
        let project_id: ProjectId = "project".into();
        let endpoint_id: EndpointId = "endpoint".into();
        let account_id: Option<AccountIdInt> = None;

        let allowed_ips = Arc::new(vec![
            "127.0.0.1".parse().unwrap(),
            "127.0.0.2".parse().unwrap(),
        ]);

        // Every insert in this test uses the same endpoint-level controls.
        let endpoint_controls = || EndpointAccessControl {
            allowed_ips: allowed_ips.clone(),
            allowed_vpce: Arc::new(vec![]),
            flags: AccessBlockerFlags::default(),
            rate_limits: EndpointRateLimitConfig::default(),
        };
        let insert = |role: &RoleName, secret: Option<AuthSecret>| {
            cache.insert_endpoint_access(
                account_id,
                (&project_id).into(),
                (&endpoint_id).into(),
                role.into(),
                endpoint_controls(),
                RoleAccessControl { secret },
            );
        };

        let user1: RoleName = "user1".into();
        let user2: RoleName = "user2".into();
        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
        let secret2: Option<AuthSecret> = None;

        insert(&user1, secret1.clone());
        insert(&user2, secret2.clone());

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.secret, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert_eq!(cached.secret, secret2);

        // The endpoint already holds `max_roles` roles, so a third one is not stored.
        let user3: RoleName = "user3".into();
        let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
        insert(&user3, secret3.clone());
        assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

        let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
        assert_eq!(cached.allowed_ips, allowed_ips);

        // After the 1s TTL has elapsed nothing may be served from the cache.
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        assert!(cache.get_endpoint_access(&endpoint_id).is_none());
    }
}
|