#[cfg(any(test, feature = "testing"))]
pub mod mock;
pub mod neon;

use super::messages::MetricsAuxInfo;
use crate::{
    auth::{backend::ComputeUserInfo, IpPattern},
    cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
    compute,
    config::{CacheOptions, ProjectInfoCacheOptions},
    context::RequestMonitoring,
    scram, EndpointCacheKey, ProjectId,
};
use async_trait::async_trait;
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
use tracing::info;
pub mod errors {
    use crate::{
        error::{io_error, UserFacingError},
        http,
        proxy::retry::ShouldRetry,
    };
    use thiserror::Error;

    /// A go-to error message which doesn't leak any detail.
    const REQUEST_FAILED: &str = "Console request failed";

    /// Common console API error.
    #[derive(Debug, Error)]
    pub enum ApiError {
        /// Error returned by the console itself.
        #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
        Console {
            status: http::StatusCode,
            text: Box<str>,
        },

        /// Various IO errors like broken pipe or malformed payload.
        #[error("{REQUEST_FAILED}: {0}")]
        Transport(#[from] std::io::Error),
    }

    impl ApiError {
        /// Returns HTTP status code if it's the reason for failure.
        pub fn http_status_code(&self) -> Option<http::StatusCode> {
            use ApiError::*;
            match self {
                Console { status, .. } => Some(*status),
                _ => None,
            }
        }
    }

    impl UserFacingError for ApiError {
        fn to_string_client(&self) -> String {
            use ApiError::*;
            match self {
                // To minimize risks, only select errors are forwarded to users.
                // Ask @neondatabase/control-plane for review before adding more.
                Console { status, .. } => match *status {
                    http::StatusCode::NOT_FOUND => {
                        // Status 404: failed to get a project-related resource.
                        format!("{REQUEST_FAILED}: endpoint cannot be found")
                    }
                    http::StatusCode::NOT_ACCEPTABLE => {
                        // Status 406: endpoint is disabled (we don't allow connections).
                        format!("{REQUEST_FAILED}: endpoint is disabled")
                    }
                    http::StatusCode::LOCKED => {
                        // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
                        format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support")
                    }
                    _ => REQUEST_FAILED.to_owned(),
                },
                _ => REQUEST_FAILED.to_owned(),
            }
        }
    }

    impl ShouldRetry for ApiError {
        fn could_retry(&self) -> bool {
            match self {
                // Retry some transport errors.
                Self::Transport(io) => io.could_retry(),
                // Retry some temporary failures because the compute was in a bad state
                // (a bad request can be returned when the endpoint was in transition).
                Self::Console {
                    status: http::StatusCode::BAD_REQUEST,
                    ..
                } => true,
                // LOCKED can be returned when the endpoint was in transition
                // or when quotas are exceeded; don't retry when quotas are exceeded.
                Self::Console {
                    status: http::StatusCode::LOCKED,
                    ref text,
                } => {
                    // Quota-exceeded messages take the form:
                    //   "written data quota exceeded"
                    //   "data transfer quota exceeded"
                    //   "compute time quota exceeded"
                    //   "logical size quota exceeded"
                    !text.contains("quota exceeded")
                        && !text.contains("the limit for current plan reached")
                }
                _ => false,
            }
        }
    }

    impl From<reqwest::Error> for ApiError {
        fn from(e: reqwest::Error) -> Self {
            io_error(e).into()
        }
    }

    impl From<reqwest_middleware::Error> for ApiError {
        fn from(e: reqwest_middleware::Error) -> Self {
            io_error(e).into()
        }
    }

    #[derive(Debug, Error)]
    pub enum GetAuthInfoError {
        // We shouldn't include the actual secret here.
        #[error("Console responded with a malformed auth secret")]
        BadSecret,

        #[error(transparent)]
        ApiError(ApiError),
    }

    // This allows more useful interactions than `#[from]`.
    impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
        fn from(e: E) -> Self {
            Self::ApiError(e.into())
        }
    }

    impl UserFacingError for GetAuthInfoError {
        fn to_string_client(&self) -> String {
            use GetAuthInfoError::*;
            match self {
                // We absolutely should not leak any secrets!
                BadSecret => REQUEST_FAILED.to_owned(),
                // However, the API might return a meaningful error.
                ApiError(e) => e.to_string_client(),
            }
        }
    }

    #[derive(Debug, Error)]
    pub enum WakeComputeError {
        #[error("Console responded with a malformed compute address: {0}")]
        BadComputeAddress(Box<str>),

        #[error(transparent)]
        ApiError(ApiError),

        #[error("Timeout waiting to acquire wake compute lock")]
        TimeoutError,
    }

    // This allows more useful interactions than `#[from]`.
    impl<E: Into<ApiError>> From<E> for WakeComputeError {
        fn from(e: E) -> Self {
            Self::ApiError(e.into())
        }
    }

    impl From<tokio::sync::AcquireError> for WakeComputeError {
        fn from(_: tokio::sync::AcquireError) -> Self {
            WakeComputeError::TimeoutError
        }
    }

    impl From<tokio::time::error::Elapsed> for WakeComputeError {
        fn from(_: tokio::time::error::Elapsed) -> Self {
            WakeComputeError::TimeoutError
        }
    }

    impl UserFacingError for WakeComputeError {
        fn to_string_client(&self) -> String {
            use WakeComputeError::*;
            match self {
                // We shouldn't show the user the address even if it's broken.
                // Besides, the user is unlikely to care about this detail.
                BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
                // However, the API might return a meaningful error.
                ApiError(e) => e.to_string_client(),

                TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
            }
        }
    }
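
    // A small sanity check of the retry policy above, added for illustration
    // (a sketch; not part of the original control-plane test suite).
    #[cfg(test)]
    mod retry_policy_tests {
        use super::*;
        use crate::proxy::retry::ShouldRetry;

        #[test]
        fn bad_request_retries_but_exceeded_quota_does_not() {
            // BAD_REQUEST can be returned while the endpoint is in transition,
            // so it is treated as transient.
            let transition = ApiError::Console {
                status: http::StatusCode::BAD_REQUEST,
                text: "endpoint is in transition".into(),
            };
            assert!(transition.could_retry());

            // LOCKED with a quota message means retrying cannot help.
            let quota = ApiError::Console {
                status: http::StatusCode::LOCKED,
                text: "written data quota exceeded".into(),
            };
            assert!(!quota.could_retry());
        }
    }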
}

/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
    #[cfg(any(test, feature = "testing"))]
    /// Md5 hash of user's password.
    Md5([u8; 16]),

    /// [SCRAM](crate::scram) authentication info.
    Scram(scram::ServerSecret),
}

#[derive(Default)]
pub struct AuthInfo {
    pub secret: Option<AuthSecret>,
    /// List of IP addresses allowed for authorization.
    pub allowed_ips: Vec<IpPattern>,
    /// Project ID. This is used for cache invalidation.
    pub project_id: Option<ProjectId>,
}

/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub struct NodeInfo {
    /// Compute node connection params.
    /// It's sad that we have to clone this, but this will improve
    /// once we migrate to a bespoke connection logic.
    pub config: compute::ConnCfg,

    /// Labels for proxy's metrics.
    pub aux: MetricsAuxInfo,

    /// Whether we should accept self-signed certificates (for testing)
    pub allow_self_signed_compute: bool,
}

pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;

/// This will allocate on each call, but the HTTP requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
    /// Get the client's auth secret for authentication.
    /// Returns an option because the user-not-found case is special:
    /// we still have to mock the SCRAM exchange to avoid leaking the fact
    /// that the user doesn't exist.
    async fn get_role_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;

    async fn get_allowed_ips_and_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

    /// Wake up the compute node and return the corresponding connection info.
    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
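
// A minimal usage sketch of the trait above (hypothetical; compiled only
// under `cfg(test)` and not called anywhere): callers fetch the role secret
// first and wake the compute only after authentication has succeeded.
#[cfg(test)]
#[allow(dead_code)]
async fn example_auth_flow(
    api: &impl Api,
    ctx: &mut RequestMonitoring,
    user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
    // A `None` secret means the role does not exist; callers still run a
    // mock SCRAM exchange so that the absence is not observable externally.
    let _secret = api.get_role_secret(ctx, user_info).await.ok();
    // Only after authentication succeeds: wake the compute node and obtain
    // its (possibly cached) connection info.
    api.wake_compute(ctx, user_info).await
}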

#[non_exhaustive]
pub enum ConsoleBackend {
    /// Current Cloud API (V2).
    Console(neon::Api),
    /// Local mock of Cloud API (V2).
    #[cfg(any(test, feature = "testing"))]
    Postgres(mock::Api),
    /// Internal testing
    #[cfg(test)]
    Test(Box<dyn crate::auth::backend::TestBackend>),
}

#[async_trait]
impl Api for ConsoleBackend {
    async fn get_role_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
        use ConsoleBackend::*;
        match self {
            Console(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(test)]
            Test(_) => unreachable!("this function should never be called in the test backend"),
        }
    }

    async fn get_allowed_ips_and_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
        use ConsoleBackend::*;
        match self {
            Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(test)]
            Test(api) => api.get_allowed_ips_and_secret(),
        }
    }

    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
        use ConsoleBackend::*;

        match self {
            Console(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(test)]
            Test(api) => api.wake_compute(),
        }
    }
}

/// Various caches for [`console`](super).
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub node_info: NodeInfoCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    pub project_info: Arc<ProjectInfoCacheImpl>,
}

impl ApiCaches {
    pub fn new(
        wake_compute_cache_config: CacheOptions,
        project_info_cache_config: ProjectInfoCacheOptions,
    ) -> Self {
        Self {
            node_info: NodeInfoCache::new(
                "node_info_cache",
                wake_compute_cache_config.size,
                wake_compute_cache_config.ttl,
                true,
            ),
            project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
        }
    }
}
353 : /// Various caches for [`console`](super).
354 : pub struct ApiLocks {
355 : name: &'static str,
356 : node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
357 : permits: usize,
358 : timeout: Duration,
359 : registered: prometheus::IntCounter,
360 : unregistered: prometheus::IntCounter,
361 : reclamation_lag: prometheus::Histogram,
362 : lock_acquire_lag: prometheus::Histogram,
363 : }
364 :
365 : impl ApiLocks {
366 1 : pub fn new(
367 1 : name: &'static str,
368 1 : permits: usize,
369 1 : shards: usize,
370 1 : timeout: Duration,
371 1 : ) -> prometheus::Result<Self> {
372 1 : let registered = prometheus::IntCounter::with_opts(
373 1 : prometheus::Opts::new(
374 1 : "semaphores_registered",
375 1 : "Number of semaphores registered in this api lock",
376 1 : )
377 1 : .namespace(name),
378 1 : )?;
379 1 : prometheus::register(Box::new(registered.clone()))?;
380 1 : let unregistered = prometheus::IntCounter::with_opts(
381 1 : prometheus::Opts::new(
382 1 : "semaphores_unregistered",
383 1 : "Number of semaphores unregistered in this api lock",
384 1 : )
385 1 : .namespace(name),
386 1 : )?;
387 1 : prometheus::register(Box::new(unregistered.clone()))?;
388 1 : let reclamation_lag = prometheus::Histogram::with_opts(
389 1 : prometheus::HistogramOpts::new(
390 1 : "reclamation_lag_seconds",
391 1 : "Time it takes to reclaim unused semaphores in the api lock",
392 1 : )
393 1 : .namespace(name)
394 1 : // 1us -> 65ms
395 1 : // benchmarks on my mac indicate it's usually in the range of 256us and 512us
396 1 : .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
397 0 : )?;
398 1 : prometheus::register(Box::new(reclamation_lag.clone()))?;
        let lock_acquire_lag = prometheus::Histogram::with_opts(
            prometheus::HistogramOpts::new(
                "semaphore_acquire_seconds",
                "Time it takes to acquire a semaphore in the api lock",
            )
            .namespace(name)
            // 0.1ms -> 6s
            .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
        )?;
        prometheus::register(Box::new(lock_acquire_lag.clone()))?;

        Ok(Self {
            name,
            node_locks: DashMap::with_shard_amount(shards),
            permits,
            timeout,
            lock_acquire_lag,
            registered,
            unregistered,
            reclamation_lag,
        })
    }

    pub async fn get_wake_compute_permit(
        &self,
        key: &EndpointCacheKey,
    ) -> Result<WakeComputePermit, errors::WakeComputeError> {
        if self.permits == 0 {
            return Ok(WakeComputePermit { permit: None });
        }
        let now = Instant::now();
        let semaphore = {
            // fast path: the semaphore for this endpoint already exists
            if let Some(semaphore) = self.node_locks.get(key) {
                semaphore.clone()
            } else {
                self.node_locks
                    .entry(key.clone())
                    .or_insert_with(|| {
                        self.registered.inc();
                        Arc::new(Semaphore::new(self.permits))
                    })
                    .clone()
            }
        };
        let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;

        self.lock_acquire_lag
            .observe((Instant::now() - now).as_secs_f64());

        Ok(WakeComputePermit {
            permit: Some(permit??),
        })
    }

    pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
        if self.permits == 0 {
            return;
        }

        let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
        loop {
            for (i, shard) in self.node_locks.shards().iter().enumerate() {
                interval.tick().await;
                // Temporarily lock a single shard, then clear any semaphores that
                // aren't currently checked out. Race conditions: if strong_count == 1,
                // there's no way it can increase while the shard is locked, so
                // releasing the semaphore is safe.
                info!(
                    name = self.name,
                    shard = i,
                    "performing epoch reclamation on api lock"
                );
                let mut lock = shard.write();
                let timer = self.reclamation_lag.start_timer();
                let count = lock
                    .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
                    .count();
                drop(lock);
                self.unregistered.inc_by(count as u64);
                timer.observe_duration()
            }
        }
    }
}

pub struct WakeComputePermit {
    // None if the lock is disabled
    permit: Option<OwnedSemaphorePermit>,
}

impl WakeComputePermit {
    pub fn should_check_cache(&self) -> bool {
        self.permit.is_some()
    }
}
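
// Intended use (a sketch; `locks`, `caches`, and `key` are hypothetical
// variables, not items defined in this file):
//
//     let permit = locks.get_wake_compute_permit(&key).await?;
//     if permit.should_check_cache() {
//         // A real permit is held, so at most `permits` tasks wake this
//         // endpoint concurrently; another task may have already populated
//         // `caches.node_info` while we were waiting, so consult the cache
//         // before calling the console API.
//     }
//     // With the lock disabled (permits == 0), skip the cache re-check and
//     // call wake_compute directly.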