Line data Source code
1 : pub mod cplane_proxy_v1;
2 : #[cfg(any(test, feature = "testing"))]
3 : pub mod mock;
4 :
5 : use std::hash::Hash;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use clashmap::ClashMap;
10 : use tokio::time::Instant;
11 : use tracing::{debug, info};
12 :
13 : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
14 : use crate::auth::backend::ComputeUserInfo;
15 : use crate::cache::endpoints::EndpointsCache;
16 : use crate::cache::project_info::ProjectInfoCacheImpl;
17 : use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
18 : use crate::context::RequestContext;
19 : use crate::control_plane::{
20 : errors, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds,
21 : CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
22 : };
23 : use crate::error::ReportableError;
24 : use crate::metrics::ApiLockMetrics;
25 : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
26 : use crate::types::EndpointId;
27 :
28 : #[non_exhaustive]
29 : #[derive(Clone)]
30 : pub enum ControlPlaneClient {
31 : /// Proxy V1 control plane API
32 : ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
33 : /// Local mock control plane.
34 : #[cfg(any(test, feature = "testing"))]
35 : PostgresMock(mock::MockControlPlane),
36 : /// Internal testing
37 : #[cfg(test)]
38 : #[allow(private_interfaces)]
39 : Test(Box<dyn TestControlPlaneClient>),
40 : }
41 :
42 : impl ControlPlaneApi for ControlPlaneClient {
43 0 : async fn get_role_secret(
44 0 : &self,
45 0 : ctx: &RequestContext,
46 0 : user_info: &ComputeUserInfo,
47 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
48 0 : match self {
49 0 : Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
50 : #[cfg(any(test, feature = "testing"))]
51 0 : Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
52 : #[cfg(test)]
53 : Self::Test(_) => {
54 0 : unreachable!("this function should never be called in the test backend")
55 : }
56 : }
57 0 : }
58 :
59 0 : async fn get_allowed_ips(
60 0 : &self,
61 0 : ctx: &RequestContext,
62 0 : user_info: &ComputeUserInfo,
63 0 : ) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
64 0 : match self {
65 0 : Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
66 : #[cfg(any(test, feature = "testing"))]
67 0 : Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
68 : #[cfg(test)]
69 0 : Self::Test(api) => api.get_allowed_ips(),
70 : }
71 0 : }
72 :
73 0 : async fn get_allowed_vpc_endpoint_ids(
74 0 : &self,
75 0 : ctx: &RequestContext,
76 0 : user_info: &ComputeUserInfo,
77 0 : ) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError> {
78 0 : match self {
79 0 : Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
80 : #[cfg(any(test, feature = "testing"))]
81 0 : Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
82 : #[cfg(test)]
83 0 : Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
84 : }
85 0 : }
86 :
87 0 : async fn get_block_public_or_vpc_access(
88 0 : &self,
89 0 : ctx: &RequestContext,
90 0 : user_info: &ComputeUserInfo,
91 0 : ) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError> {
92 0 : match self {
93 0 : Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
94 : #[cfg(any(test, feature = "testing"))]
95 0 : Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
96 : #[cfg(test)]
97 0 : Self::Test(api) => api.get_block_public_or_vpc_access(),
98 : }
99 0 : }
100 :
101 0 : async fn get_endpoint_jwks(
102 0 : &self,
103 0 : ctx: &RequestContext,
104 0 : endpoint: EndpointId,
105 0 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
106 0 : match self {
107 0 : Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
108 : #[cfg(any(test, feature = "testing"))]
109 0 : Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
110 : #[cfg(test)]
111 0 : Self::Test(_api) => Ok(vec![]),
112 : }
113 0 : }
114 :
115 13 : async fn wake_compute(
116 13 : &self,
117 13 : ctx: &RequestContext,
118 13 : user_info: &ComputeUserInfo,
119 13 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
120 13 : match self {
121 0 : Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
122 : #[cfg(any(test, feature = "testing"))]
123 0 : Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
124 : #[cfg(test)]
125 13 : Self::Test(api) => api.wake_compute(),
126 : }
127 13 : }
128 : }
129 :
/// Backend plugged into `ControlPlaneClient::Test` for unit tests.
///
/// Only compiled under `cfg(test)`; each method supplies the result returned by the
/// corresponding `ControlPlaneApi` call on the `Test` variant.
#[cfg(test)]
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
    /// Result for `wake_compute` on the `Test` backend.
    fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;

    /// Result for `get_allowed_ips` on the `Test` backend.
    fn get_allowed_ips(&self) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;

    /// Result for `get_allowed_vpc_endpoint_ids` on the `Test` backend.
    fn get_allowed_vpc_endpoint_ids(
        &self,
    ) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;

    /// Result for `get_block_public_or_vpc_access` on the `Test` backend.
    fn get_block_public_or_vpc_access(
        &self,
    ) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;

    /// Object-safe clone, used to implement `Clone` for `Box<dyn TestControlPlaneClient>`.
    fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
146 :
/// Lets `ControlPlaneClient` derive `Clone` while holding a boxed test backend.
#[cfg(test)]
impl Clone for Box<dyn TestControlPlaneClient> {
    fn clone(&self) -> Self {
        // Dereference through the box to the trait object and use its object-safe clone.
        (**self).dyn_clone()
    }
}
153 :
154 : /// Various caches for [`control_plane`](super).
155 : pub struct ApiCaches {
156 : /// Cache for the `wake_compute` API method.
157 : pub(crate) node_info: NodeInfoCache,
158 : /// Cache which stores project_id -> endpoint_ids mapping.
159 : pub project_info: Arc<ProjectInfoCacheImpl>,
160 : /// List of all valid endpoints.
161 : pub endpoints_cache: Arc<EndpointsCache>,
162 : }
163 :
164 : impl ApiCaches {
165 0 : pub fn new(
166 0 : wake_compute_cache_config: CacheOptions,
167 0 : project_info_cache_config: ProjectInfoCacheOptions,
168 0 : endpoint_cache_config: EndpointCacheConfig,
169 0 : ) -> Self {
170 0 : Self {
171 0 : node_info: NodeInfoCache::new(
172 0 : "node_info_cache",
173 0 : wake_compute_cache_config.size,
174 0 : wake_compute_cache_config.ttl,
175 0 : true,
176 0 : ),
177 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
178 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
179 0 : }
180 0 : }
181 : }
182 :
183 : /// Various caches for [`control_plane`](super).
184 : pub struct ApiLocks<K> {
185 : name: &'static str,
186 : node_locks: ClashMap<K, Arc<DynamicLimiter>>,
187 : config: RateLimiterConfig,
188 : timeout: Duration,
189 : epoch: std::time::Duration,
190 : metrics: &'static ApiLockMetrics,
191 : }
192 :
193 : #[derive(Debug, thiserror::Error)]
194 : pub(crate) enum ApiLockError {
195 : #[error("timeout acquiring resource permit")]
196 : TimeoutError(#[from] tokio::time::error::Elapsed),
197 : }
198 :
199 : impl ReportableError for ApiLockError {
200 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
201 0 : match self {
202 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
203 0 : }
204 0 : }
205 : }
206 :
207 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
208 0 : pub fn new(
209 0 : name: &'static str,
210 0 : config: RateLimiterConfig,
211 0 : shards: usize,
212 0 : timeout: Duration,
213 0 : epoch: std::time::Duration,
214 0 : metrics: &'static ApiLockMetrics,
215 0 : ) -> Self {
216 0 : Self {
217 0 : name,
218 0 : node_locks: ClashMap::with_shard_amount(shards),
219 0 : config,
220 0 : timeout,
221 0 : epoch,
222 0 : metrics,
223 0 : }
224 0 : }
225 :
226 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
227 0 : if self.config.initial_limit == 0 {
228 0 : return Ok(WakeComputePermit {
229 0 : permit: Token::disabled(),
230 0 : });
231 0 : }
232 0 : let now = Instant::now();
233 0 : let semaphore = {
234 : // get fast path
235 0 : if let Some(semaphore) = self.node_locks.get(key) {
236 0 : semaphore.clone()
237 : } else {
238 0 : self.node_locks
239 0 : .entry(key.clone())
240 0 : .or_insert_with(|| {
241 0 : self.metrics.semaphores_registered.inc();
242 0 : DynamicLimiter::new(self.config)
243 0 : })
244 0 : .clone()
245 : }
246 : };
247 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
248 :
249 0 : self.metrics
250 0 : .semaphore_acquire_seconds
251 0 : .observe(now.elapsed().as_secs_f64());
252 0 : debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
253 0 : Ok(WakeComputePermit { permit: permit? })
254 0 : }
255 :
256 0 : pub async fn garbage_collect_worker(&self) {
257 0 : if self.config.initial_limit == 0 {
258 0 : return;
259 0 : }
260 0 : let mut interval =
261 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
262 : loop {
263 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
264 0 : interval.tick().await;
265 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
266 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
267 : // therefore releasing it is safe from race conditions
268 0 : info!(
269 : name = self.name,
270 : shard = i,
271 0 : "performing epoch reclamation on api lock"
272 : );
273 0 : let mut lock = shard.write();
274 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
275 0 : let count = lock
276 0 : .extract_if(|(_, semaphore)| Arc::strong_count(semaphore) == 1)
277 0 : .count();
278 0 : drop(lock);
279 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
280 0 : timer.observe();
281 0 : }
282 : }
283 0 : }
284 : }
285 :
286 : pub(crate) struct WakeComputePermit {
287 : permit: Token,
288 : }
289 :
290 : impl WakeComputePermit {
291 0 : pub(crate) fn should_check_cache(&self) -> bool {
292 0 : !self.permit.is_disabled()
293 0 : }
294 0 : pub(crate) fn release(self, outcome: Outcome) {
295 0 : self.permit.release(outcome);
296 0 : }
297 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
298 0 : match res {
299 0 : Ok(_) => self.release(Outcome::Success),
300 0 : Err(_) => self.release(Outcome::Overload),
301 : }
302 0 : res
303 0 : }
304 : }
305 :
306 : impl FetchAuthRules for ControlPlaneClient {
307 0 : async fn fetch_auth_rules(
308 0 : &self,
309 0 : ctx: &RequestContext,
310 0 : endpoint: EndpointId,
311 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
312 0 : self.get_endpoint_jwks(ctx, endpoint)
313 0 : .await
314 0 : .map_err(FetchAuthRulesError::GetEndpointJwks)
315 0 : }
316 : }
|