Line data Source code
1 : pub mod cplane_proxy_v1;
2 : #[cfg(any(test, feature = "testing"))]
3 : pub mod mock;
4 :
5 : use std::hash::Hash;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use clashmap::ClashMap;
10 : use tokio::time::Instant;
11 : use tracing::{debug, info};
12 :
13 : use super::{EndpointAccessControl, RoleAccessControl};
14 : use crate::auth::backend::ComputeUserInfo;
15 : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
16 : use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
17 : use crate::cache::project_info::ProjectInfoCache;
18 : use crate::config::{CacheOptions, ProjectInfoCacheOptions};
19 : use crate::context::RequestContext;
20 : use crate::control_plane::{ControlPlaneApi, errors};
21 : use crate::error::ReportableError;
22 : use crate::metrics::ApiLockMetrics;
23 : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
24 : use crate::types::EndpointId;
25 :
26 : #[non_exhaustive]
27 : #[derive(Clone)]
28 : pub enum ControlPlaneClient {
29 : /// Proxy V1 control plane API
30 : ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
31 : /// Local mock control plane.
32 : #[cfg(any(test, feature = "testing"))]
33 : PostgresMock(mock::MockControlPlane),
34 : /// Internal testing
35 : #[cfg(test)]
36 : #[allow(private_interfaces)]
37 : Test(Box<dyn TestControlPlaneClient>),
38 : }
39 :
40 : impl ControlPlaneApi for ControlPlaneClient {
41 0 : async fn get_role_access_control(
42 0 : &self,
43 0 : ctx: &RequestContext,
44 0 : endpoint: &EndpointId,
45 0 : role: &crate::types::RoleName,
46 0 : ) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
47 0 : match self {
48 0 : Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
49 : #[cfg(any(test, feature = "testing"))]
50 0 : Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
51 : #[cfg(test)]
52 0 : Self::Test(_api) => {
53 0 : unreachable!("this function should never be called in the test backend")
54 : }
55 : }
56 0 : }
57 :
58 0 : async fn get_endpoint_access_control(
59 0 : &self,
60 0 : ctx: &RequestContext,
61 0 : endpoint: &EndpointId,
62 0 : role: &crate::types::RoleName,
63 0 : ) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
64 0 : match self {
65 0 : Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
66 : #[cfg(any(test, feature = "testing"))]
67 0 : Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
68 : #[cfg(test)]
69 0 : Self::Test(api) => api.get_access_control(),
70 : }
71 0 : }
72 :
73 0 : async fn get_endpoint_jwks(
74 0 : &self,
75 0 : ctx: &RequestContext,
76 0 : endpoint: &EndpointId,
77 0 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
78 0 : match self {
79 0 : Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
80 : #[cfg(any(test, feature = "testing"))]
81 0 : Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
82 : #[cfg(test)]
83 0 : Self::Test(_api) => Ok(vec![]),
84 : }
85 0 : }
86 :
87 21 : async fn wake_compute(
88 21 : &self,
89 21 : ctx: &RequestContext,
90 21 : user_info: &ComputeUserInfo,
91 21 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
92 21 : match self {
93 0 : Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
94 : #[cfg(any(test, feature = "testing"))]
95 0 : Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
96 : #[cfg(test)]
97 21 : Self::Test(api) => api.wake_compute(),
98 : }
99 21 : }
100 : }
101 :
/// Synchronous stand-in for a control plane, used by the `Test` client variant.
#[cfg(test)]
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
    /// Result handed back from `ControlPlaneClient::get_endpoint_access_control`.
    fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;

    /// Result handed back from `ControlPlaneClient::wake_compute`.
    fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;

    /// Clones `self` behind a fresh box; backs `Clone` for `Box<dyn TestControlPlaneClient>`.
    fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
110 :
/// Makes boxed test doubles cloneable so `ControlPlaneClient` can derive `Clone`.
#[cfg(test)]
impl Clone for Box<dyn TestControlPlaneClient> {
    fn clone(&self) -> Self {
        // Deref through the box to reach the trait object's own `dyn_clone`.
        let inner: &dyn TestControlPlaneClient = &**self;
        inner.dyn_clone()
    }
}
117 :
118 : /// Various caches for [`control_plane`](super).
119 : pub struct ApiCaches {
120 : /// Cache for the `wake_compute` API method.
121 : pub(crate) node_info: NodeInfoCache,
122 : /// Cache which stores project_id -> endpoint_ids mapping.
123 : pub project_info: Arc<ProjectInfoCache>,
124 : }
125 :
126 : impl ApiCaches {
127 0 : pub fn new(
128 0 : wake_compute_cache_config: CacheOptions,
129 0 : project_info_cache_config: ProjectInfoCacheOptions,
130 0 : ) -> Self {
131 0 : Self {
132 0 : node_info: NodeInfoCache::new(wake_compute_cache_config),
133 0 : project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)),
134 0 : }
135 0 : }
136 : }
137 :
138 : /// Various caches for [`control_plane`](super).
139 : pub struct ApiLocks<K> {
140 : name: &'static str,
141 : node_locks: ClashMap<K, Arc<DynamicLimiter>>,
142 : config: RateLimiterConfig,
143 : timeout: Duration,
144 : epoch: std::time::Duration,
145 : metrics: &'static ApiLockMetrics,
146 : }
147 :
148 : #[derive(Debug, thiserror::Error)]
149 : pub(crate) enum ApiLockError {
150 : #[error("timeout acquiring resource permit")]
151 : TimeoutError(#[from] tokio::time::error::Elapsed),
152 : }
153 :
154 : impl ReportableError for ApiLockError {
155 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
156 0 : match self {
157 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
158 : }
159 0 : }
160 : }
161 :
162 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
163 0 : pub fn new(
164 0 : name: &'static str,
165 0 : config: RateLimiterConfig,
166 0 : shards: usize,
167 0 : timeout: Duration,
168 0 : epoch: std::time::Duration,
169 0 : metrics: &'static ApiLockMetrics,
170 0 : ) -> Self {
171 0 : Self {
172 0 : name,
173 0 : node_locks: ClashMap::with_shard_amount(shards),
174 0 : config,
175 0 : timeout,
176 0 : epoch,
177 0 : metrics,
178 0 : }
179 0 : }
180 :
181 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
182 0 : if self.config.initial_limit == 0 {
183 0 : return Ok(WakeComputePermit {
184 0 : permit: Token::disabled(),
185 0 : });
186 0 : }
187 0 : let now = Instant::now();
188 0 : let semaphore = {
189 : // get fast path
190 0 : if let Some(semaphore) = self.node_locks.get(key) {
191 0 : semaphore.clone()
192 : } else {
193 0 : self.node_locks
194 0 : .entry(key.clone())
195 0 : .or_insert_with(|| {
196 0 : self.metrics.semaphores_registered.inc();
197 0 : DynamicLimiter::new(self.config)
198 0 : })
199 0 : .clone()
200 : }
201 : };
202 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
203 :
204 0 : self.metrics
205 0 : .semaphore_acquire_seconds
206 0 : .observe(now.elapsed().as_secs_f64());
207 :
208 0 : if permit.is_ok() {
209 0 : debug!(elapsed = ?now.elapsed(), "acquired permit");
210 : } else {
211 0 : debug!(elapsed = ?now.elapsed(), "timed out acquiring permit");
212 : }
213 0 : Ok(WakeComputePermit { permit: permit? })
214 0 : }
215 :
216 0 : pub async fn garbage_collect_worker(&self) {
217 0 : if self.config.initial_limit == 0 {
218 0 : return;
219 0 : }
220 0 : let mut interval =
221 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
222 : loop {
223 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
224 0 : interval.tick().await;
225 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
226 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
227 : // therefore releasing it is safe from race conditions
228 0 : info!(
229 : name = self.name,
230 : shard = i,
231 0 : "performing epoch reclamation on api lock"
232 : );
233 0 : let mut lock = shard.write();
234 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
235 0 : let count = lock
236 0 : .extract_if(|(_, semaphore)| Arc::strong_count(semaphore) == 1)
237 0 : .count();
238 0 : drop(lock);
239 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
240 0 : timer.observe();
241 : }
242 : }
243 0 : }
244 : }
245 :
246 : pub(crate) struct WakeComputePermit {
247 : permit: Token,
248 : }
249 :
250 : impl WakeComputePermit {
251 0 : pub(crate) fn should_check_cache(&self) -> bool {
252 0 : !self.permit.is_disabled()
253 0 : }
254 0 : pub(crate) fn release(self, outcome: Outcome) {
255 0 : self.permit.release(outcome);
256 0 : }
257 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
258 0 : match res {
259 0 : Ok(_) => self.release(Outcome::Success),
260 0 : Err(_) => self.release(Outcome::Overload),
261 : }
262 0 : res
263 0 : }
264 : }
265 :
266 : impl FetchAuthRules for ControlPlaneClient {
267 0 : async fn fetch_auth_rules(
268 0 : &self,
269 0 : ctx: &RequestContext,
270 0 : endpoint: EndpointId,
271 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
272 0 : self.get_endpoint_jwks(ctx, &endpoint)
273 0 : .await
274 0 : .map_err(FetchAuthRulesError::GetEndpointJwks)
275 0 : }
276 : }
|