Line data Source code
1 : pub mod cplane_proxy_v1;
2 : #[cfg(any(test, feature = "testing"))]
3 : pub mod mock;
4 : pub mod neon;
5 :
6 : use std::hash::Hash;
7 : use std::sync::Arc;
8 : use std::time::Duration;
9 :
10 : use dashmap::DashMap;
11 : use tokio::time::Instant;
12 : use tracing::{debug, info};
13 :
14 : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::cache::endpoints::EndpointsCache;
17 : use crate::cache::project_info::ProjectInfoCacheImpl;
18 : use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
19 : use crate::context::RequestContext;
20 : use crate::control_plane::{
21 : errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
22 : };
23 : use crate::error::ReportableError;
24 : use crate::metrics::ApiLockMetrics;
25 : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
26 : use crate::types::EndpointId;
27 :
28 : #[non_exhaustive]
29 : #[derive(Clone)]
30 : pub enum ControlPlaneClient {
31 : /// New Proxy V1 control plane API
32 : ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
33 : /// Current Management API (V2).
34 : Neon(neon::NeonControlPlaneClient),
35 : /// Local mock control plane.
36 : #[cfg(any(test, feature = "testing"))]
37 : PostgresMock(mock::MockControlPlane),
38 : /// Internal testing
39 : #[cfg(test)]
40 : #[allow(private_interfaces)]
41 : Test(Box<dyn TestControlPlaneClient>),
42 : }
43 :
44 : impl ControlPlaneApi for ControlPlaneClient {
45 0 : async fn get_role_secret(
46 0 : &self,
47 0 : ctx: &RequestContext,
48 0 : user_info: &ComputeUserInfo,
49 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
50 0 : match self {
51 0 : Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
52 0 : Self::Neon(api) => api.get_role_secret(ctx, user_info).await,
53 : #[cfg(any(test, feature = "testing"))]
54 0 : Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
55 : #[cfg(test)]
56 : Self::Test(_) => {
57 0 : unreachable!("this function should never be called in the test backend")
58 : }
59 : }
60 0 : }
61 :
62 0 : async fn get_allowed_ips_and_secret(
63 0 : &self,
64 0 : ctx: &RequestContext,
65 0 : user_info: &ComputeUserInfo,
66 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
67 0 : match self {
68 0 : Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
69 0 : Self::Neon(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
70 : #[cfg(any(test, feature = "testing"))]
71 0 : Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
72 : #[cfg(test)]
73 0 : Self::Test(api) => api.get_allowed_ips_and_secret(),
74 : }
75 0 : }
76 :
77 0 : async fn get_endpoint_jwks(
78 0 : &self,
79 0 : ctx: &RequestContext,
80 0 : endpoint: EndpointId,
81 0 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
82 0 : match self {
83 0 : Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
84 0 : Self::Neon(api) => api.get_endpoint_jwks(ctx, endpoint).await,
85 : #[cfg(any(test, feature = "testing"))]
86 0 : Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
87 : #[cfg(test)]
88 0 : Self::Test(_api) => Ok(vec![]),
89 : }
90 0 : }
91 :
92 13 : async fn wake_compute(
93 13 : &self,
94 13 : ctx: &RequestContext,
95 13 : user_info: &ComputeUserInfo,
96 13 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
97 13 : match self {
98 0 : Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
99 0 : Self::Neon(api) => api.wake_compute(ctx, user_info).await,
100 : #[cfg(any(test, feature = "testing"))]
101 0 : Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
102 : #[cfg(test)]
103 13 : Self::Test(api) => api.wake_compute(),
104 : }
105 13 : }
106 : }
107 :
/// Synchronous stand-in for a control plane, wrapped by `ControlPlaneClient::Test`.
#[cfg(test)]
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
    /// Answers `wake_compute` calls routed to the `Test` variant.
    fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;

    /// Answers `get_allowed_ips_and_secret` calls routed to the `Test` variant.
    fn get_allowed_ips_and_secret(
        &self,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

    /// Object-safe clone hook backing `Clone for Box<dyn TestControlPlaneClient>`.
    fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
118 :
#[cfg(test)]
impl Clone for Box<dyn TestControlPlaneClient> {
    /// Clones the boxed test client via its object-safe `dyn_clone` method.
    fn clone(&self) -> Self {
        // Deref through the `Box` so the call dispatches on the trait object itself.
        let inner: &dyn TestControlPlaneClient = &**self;
        inner.dyn_clone()
    }
}
125 :
126 : /// Various caches for [`control_plane`](super).
127 : pub struct ApiCaches {
128 : /// Cache for the `wake_compute` API method.
129 : pub(crate) node_info: NodeInfoCache,
130 : /// Cache which stores project_id -> endpoint_ids mapping.
131 : pub project_info: Arc<ProjectInfoCacheImpl>,
132 : /// List of all valid endpoints.
133 : pub endpoints_cache: Arc<EndpointsCache>,
134 : }
135 :
136 : impl ApiCaches {
137 0 : pub fn new(
138 0 : wake_compute_cache_config: CacheOptions,
139 0 : project_info_cache_config: ProjectInfoCacheOptions,
140 0 : endpoint_cache_config: EndpointCacheConfig,
141 0 : ) -> Self {
142 0 : Self {
143 0 : node_info: NodeInfoCache::new(
144 0 : "node_info_cache",
145 0 : wake_compute_cache_config.size,
146 0 : wake_compute_cache_config.ttl,
147 0 : true,
148 0 : ),
149 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
150 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
151 0 : }
152 0 : }
153 : }
154 :
155 : /// Various caches for [`control_plane`](super).
156 : pub struct ApiLocks<K> {
157 : name: &'static str,
158 : node_locks: DashMap<K, Arc<DynamicLimiter>>,
159 : config: RateLimiterConfig,
160 : timeout: Duration,
161 : epoch: std::time::Duration,
162 : metrics: &'static ApiLockMetrics,
163 : }
164 :
165 : #[derive(Debug, thiserror::Error)]
166 : pub(crate) enum ApiLockError {
167 : #[error("timeout acquiring resource permit")]
168 : TimeoutError(#[from] tokio::time::error::Elapsed),
169 : }
170 :
171 : impl ReportableError for ApiLockError {
172 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
173 0 : match self {
174 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
175 0 : }
176 0 : }
177 : }
178 :
179 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
180 0 : pub fn new(
181 0 : name: &'static str,
182 0 : config: RateLimiterConfig,
183 0 : shards: usize,
184 0 : timeout: Duration,
185 0 : epoch: std::time::Duration,
186 0 : metrics: &'static ApiLockMetrics,
187 0 : ) -> prometheus::Result<Self> {
188 0 : Ok(Self {
189 0 : name,
190 0 : node_locks: DashMap::with_shard_amount(shards),
191 0 : config,
192 0 : timeout,
193 0 : epoch,
194 0 : metrics,
195 0 : })
196 0 : }
197 :
198 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
199 0 : if self.config.initial_limit == 0 {
200 0 : return Ok(WakeComputePermit {
201 0 : permit: Token::disabled(),
202 0 : });
203 0 : }
204 0 : let now = Instant::now();
205 0 : let semaphore = {
206 : // get fast path
207 0 : if let Some(semaphore) = self.node_locks.get(key) {
208 0 : semaphore.clone()
209 : } else {
210 0 : self.node_locks
211 0 : .entry(key.clone())
212 0 : .or_insert_with(|| {
213 0 : self.metrics.semaphores_registered.inc();
214 0 : DynamicLimiter::new(self.config)
215 0 : })
216 0 : .clone()
217 : }
218 : };
219 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
220 :
221 0 : self.metrics
222 0 : .semaphore_acquire_seconds
223 0 : .observe(now.elapsed().as_secs_f64());
224 0 : debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
225 0 : Ok(WakeComputePermit { permit: permit? })
226 0 : }
227 :
228 0 : pub async fn garbage_collect_worker(&self) {
229 0 : if self.config.initial_limit == 0 {
230 0 : return;
231 0 : }
232 0 : let mut interval =
233 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
234 : loop {
235 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
236 0 : interval.tick().await;
237 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
238 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
239 : // therefore releasing it is safe from race conditions
240 0 : info!(
241 : name = self.name,
242 : shard = i,
243 0 : "performing epoch reclamation on api lock"
244 : );
245 0 : let mut lock = shard.write();
246 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
247 0 : let count = lock
248 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
249 0 : .count();
250 0 : drop(lock);
251 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
252 0 : timer.observe();
253 0 : }
254 : }
255 0 : }
256 : }
257 :
258 : pub(crate) struct WakeComputePermit {
259 : permit: Token,
260 : }
261 :
262 : impl WakeComputePermit {
263 0 : pub(crate) fn should_check_cache(&self) -> bool {
264 0 : !self.permit.is_disabled()
265 0 : }
266 0 : pub(crate) fn release(self, outcome: Outcome) {
267 0 : self.permit.release(outcome);
268 0 : }
269 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
270 0 : match res {
271 0 : Ok(_) => self.release(Outcome::Success),
272 0 : Err(_) => self.release(Outcome::Overload),
273 : }
274 0 : res
275 0 : }
276 : }
277 :
278 : impl FetchAuthRules for ControlPlaneClient {
279 0 : async fn fetch_auth_rules(
280 0 : &self,
281 0 : ctx: &RequestContext,
282 0 : endpoint: EndpointId,
283 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
284 0 : self.get_endpoint_jwks(ctx, endpoint)
285 0 : .await
286 0 : .map_err(FetchAuthRulesError::GetEndpointJwks)
287 0 : }
288 : }
|