Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use std::hash::Hash;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use dashmap::DashMap;
10 : use tokio::time::Instant;
11 : use tracing::{debug, info};
12 :
13 : use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
14 : use crate::auth::backend::ComputeUserInfo;
15 : use crate::cache::endpoints::EndpointsCache;
16 : use crate::cache::project_info::ProjectInfoCacheImpl;
17 : use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
18 : use crate::context::RequestMonitoring;
19 : use crate::control_plane::{
20 : errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
21 : };
22 : use crate::error::ReportableError;
23 : use crate::metrics::ApiLockMetrics;
24 : use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
25 : use crate::types::EndpointId;
26 :
/// The control-plane backend used by the proxy; every [`ControlPlaneApi`] call
/// is dispatched to the variant's inner client.
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {
    /// Current Management API (V2).
    Neon(neon::NeonControlPlaneClient),
    /// Local mock control plane.
    #[cfg(any(test, feature = "testing"))]
    PostgresMock(mock::MockControlPlane),
    /// Internal testing
    #[cfg(test)]
    // `TestControlPlaneClient` is `pub(crate)` while this enum is `pub`, so this
    // variant exposes a crate-private trait; the lint is silenced because the
    // variant only exists in `cfg(test)` builds.
    #[allow(private_interfaces)]
    Test(Box<dyn TestControlPlaneClient>),
}
40 :
41 : impl ControlPlaneApi for ControlPlaneClient {
42 0 : async fn get_role_secret(
43 0 : &self,
44 0 : ctx: &RequestMonitoring,
45 0 : user_info: &ComputeUserInfo,
46 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
47 0 : match self {
48 0 : Self::Neon(api) => api.get_role_secret(ctx, user_info).await,
49 : #[cfg(any(test, feature = "testing"))]
50 0 : Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
51 : #[cfg(test)]
52 : Self::Test(_) => {
53 0 : unreachable!("this function should never be called in the test backend")
54 : }
55 : }
56 0 : }
57 :
58 0 : async fn get_allowed_ips_and_secret(
59 0 : &self,
60 0 : ctx: &RequestMonitoring,
61 0 : user_info: &ComputeUserInfo,
62 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
63 0 : match self {
64 0 : Self::Neon(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
65 : #[cfg(any(test, feature = "testing"))]
66 0 : Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
67 : #[cfg(test)]
68 0 : Self::Test(api) => api.get_allowed_ips_and_secret(),
69 : }
70 0 : }
71 :
72 0 : async fn get_endpoint_jwks(
73 0 : &self,
74 0 : ctx: &RequestMonitoring,
75 0 : endpoint: EndpointId,
76 0 : ) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
77 0 : match self {
78 0 : Self::Neon(api) => api.get_endpoint_jwks(ctx, endpoint).await,
79 : #[cfg(any(test, feature = "testing"))]
80 0 : Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
81 : #[cfg(test)]
82 0 : Self::Test(_api) => Ok(vec![]),
83 : }
84 0 : }
85 :
86 13 : async fn wake_compute(
87 13 : &self,
88 13 : ctx: &RequestMonitoring,
89 13 : user_info: &ComputeUserInfo,
90 13 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
91 13 : match self {
92 0 : Self::Neon(api) => api.wake_compute(ctx, user_info).await,
93 : #[cfg(any(test, feature = "testing"))]
94 0 : Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
95 : #[cfg(test)]
96 13 : Self::Test(api) => api.wake_compute(),
97 : }
98 13 : }
99 : }
100 :
/// Synchronous stand-in for a control-plane backend, used by the
/// `ControlPlaneClient::Test` variant in unit tests.
#[cfg(test)]
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
    /// Stubbed result for `ControlPlaneApi::wake_compute`.
    fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;

    /// Stubbed result for `ControlPlaneApi::get_allowed_ips_and_secret`.
    fn get_allowed_ips_and_secret(
        &self,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

    /// Object-safe clone hook; lets `Box<dyn TestControlPlaneClient>` implement `Clone`
    /// (which `ControlPlaneClient`'s `#[derive(Clone)]` requires).
    fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
111 :
#[cfg(test)]
impl Clone for Box<dyn TestControlPlaneClient> {
    /// Clones the boxed test client through its object-safe `dyn_clone` hook.
    fn clone(&self) -> Self {
        // Reborrow the trait object explicitly so the call cannot resolve back to `Box::clone`.
        let inner: &dyn TestControlPlaneClient = &**self;
        inner.dyn_clone()
    }
}
118 :
/// Various caches for [`control_plane`](super).
///
/// Built by [`ApiCaches::new`] from the respective cache configurations.
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub(crate) node_info: NodeInfoCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    pub project_info: Arc<ProjectInfoCacheImpl>,
    /// List of all valid endpoints.
    pub endpoints_cache: Arc<EndpointsCache>,
}
128 :
129 : impl ApiCaches {
130 0 : pub fn new(
131 0 : wake_compute_cache_config: CacheOptions,
132 0 : project_info_cache_config: ProjectInfoCacheOptions,
133 0 : endpoint_cache_config: EndpointCacheConfig,
134 0 : ) -> Self {
135 0 : Self {
136 0 : node_info: NodeInfoCache::new(
137 0 : "node_info_cache",
138 0 : wake_compute_cache_config.size,
139 0 : wake_compute_cache_config.ttl,
140 0 : true,
141 0 : ),
142 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
143 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
144 0 : }
145 0 : }
146 : }
147 :
148 : /// Various caches for [`control_plane`](super).
149 : pub struct ApiLocks<K> {
150 : name: &'static str,
151 : node_locks: DashMap<K, Arc<DynamicLimiter>>,
152 : config: RateLimiterConfig,
153 : timeout: Duration,
154 : epoch: std::time::Duration,
155 : metrics: &'static ApiLockMetrics,
156 : }
157 :
158 0 : #[derive(Debug, thiserror::Error)]
159 : pub(crate) enum ApiLockError {
160 : #[error("timeout acquiring resource permit")]
161 : TimeoutError(#[from] tokio::time::error::Elapsed),
162 : }
163 :
164 : impl ReportableError for ApiLockError {
165 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
166 0 : match self {
167 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
168 0 : }
169 0 : }
170 : }
171 :
172 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
173 0 : pub fn new(
174 0 : name: &'static str,
175 0 : config: RateLimiterConfig,
176 0 : shards: usize,
177 0 : timeout: Duration,
178 0 : epoch: std::time::Duration,
179 0 : metrics: &'static ApiLockMetrics,
180 0 : ) -> prometheus::Result<Self> {
181 0 : Ok(Self {
182 0 : name,
183 0 : node_locks: DashMap::with_shard_amount(shards),
184 0 : config,
185 0 : timeout,
186 0 : epoch,
187 0 : metrics,
188 0 : })
189 0 : }
190 :
191 0 : pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
192 0 : if self.config.initial_limit == 0 {
193 0 : return Ok(WakeComputePermit {
194 0 : permit: Token::disabled(),
195 0 : });
196 0 : }
197 0 : let now = Instant::now();
198 0 : let semaphore = {
199 : // get fast path
200 0 : if let Some(semaphore) = self.node_locks.get(key) {
201 0 : semaphore.clone()
202 : } else {
203 0 : self.node_locks
204 0 : .entry(key.clone())
205 0 : .or_insert_with(|| {
206 0 : self.metrics.semaphores_registered.inc();
207 0 : DynamicLimiter::new(self.config)
208 0 : })
209 0 : .clone()
210 : }
211 : };
212 0 : let permit = semaphore.acquire_timeout(self.timeout).await;
213 :
214 0 : self.metrics
215 0 : .semaphore_acquire_seconds
216 0 : .observe(now.elapsed().as_secs_f64());
217 0 : debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
218 0 : Ok(WakeComputePermit { permit: permit? })
219 0 : }
220 :
221 0 : pub async fn garbage_collect_worker(&self) {
222 0 : if self.config.initial_limit == 0 {
223 0 : return;
224 0 : }
225 0 : let mut interval =
226 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
227 : loop {
228 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
229 0 : interval.tick().await;
230 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
231 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
232 : // therefore releasing it is safe from race conditions
233 0 : info!(
234 : name = self.name,
235 : shard = i,
236 0 : "performing epoch reclamation on api lock"
237 : );
238 0 : let mut lock = shard.write();
239 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
240 0 : let count = lock
241 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
242 0 : .count();
243 0 : drop(lock);
244 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
245 0 : timer.observe();
246 0 : }
247 : }
248 0 : }
249 : }
250 :
/// A permit obtained from [`ApiLocks::get_permit`]; it must be released with an
/// [`Outcome`] via `release` or `release_result`.
pub(crate) struct WakeComputePermit {
    /// Limiter token; a disabled token when limiting is turned off.
    permit: Token,
}
254 :
255 : impl WakeComputePermit {
256 0 : pub(crate) fn should_check_cache(&self) -> bool {
257 0 : !self.permit.is_disabled()
258 0 : }
259 0 : pub(crate) fn release(self, outcome: Outcome) {
260 0 : self.permit.release(outcome);
261 0 : }
262 0 : pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
263 0 : match res {
264 0 : Ok(_) => self.release(Outcome::Success),
265 0 : Err(_) => self.release(Outcome::Overload),
266 : }
267 0 : res
268 0 : }
269 : }
270 :
271 : impl FetchAuthRules for ControlPlaneClient {
272 0 : async fn fetch_auth_rules(
273 0 : &self,
274 0 : ctx: &RequestMonitoring,
275 0 : endpoint: EndpointId,
276 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
277 0 : self.get_endpoint_jwks(ctx, endpoint)
278 0 : .await
279 0 : .map_err(FetchAuthRulesError::GetEndpointJwks)
280 0 : }
281 : }
|