Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : error::ReportableError,
16 : intern::ProjectIdInt,
17 : metrics::ApiLockMetrics,
18 : scram, EndpointCacheKey,
19 : };
20 : use dashmap::DashMap;
21 : use std::{hash::Hash, sync::Arc, time::Duration};
22 : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
23 : use tokio::time::Instant;
24 : use tracing::info;
25 :
26 : pub mod errors {
27 : use crate::{
28 : error::{io_error, ReportableError, UserFacingError},
29 : http,
30 : proxy::retry::ShouldRetry,
31 : };
32 : use thiserror::Error;
33 :
34 : use super::ApiLockError;
35 :
36 : /// A go-to error message which doesn't leak any detail.
/// It is the common prefix of the `ApiError` display strings and the fallback
/// text returned to clients when nothing more specific may be revealed.
37 : const REQUEST_FAILED: &str = "Console request failed";
38 :
39 : /// Common console API error.
40 0 : #[derive(Debug, Error)]
41 : pub enum ApiError {
42 : /// Error returned by the console itself.
43 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
44 : Console {
/// HTTP status code the console answered with.
45 : status: http::StatusCode,
/// Message accompanying the status (presumably the response body — verify at the
/// construction site). The retry / error-kind logic searches it for quota phrases.
46 : text: Box<str>,
47 : },
48 :
49 : /// Various IO errors like broken pipe or malformed payload.
/// `#[from]` generates `From<std::io::Error>`, so `?` converts IO errors automatically.
50 : #[error("{REQUEST_FAILED}: {0}")]
51 : Transport(#[from] std::io::Error),
52 : }
53 :
54 : impl ApiError {
55 : /// Returns HTTP status code if it's the reason for failure.
56 0 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
57 0 : use ApiError::*;
58 0 : match self {
59 0 : Console { status, .. } => Some(*status),
60 0 : _ => None,
61 : }
62 0 : }
63 : }
64 :
65 : impl UserFacingError for ApiError {
66 0 : fn to_string_client(&self) -> String {
67 0 : use ApiError::*;
68 0 : match self {
69 : // To minimize risks, only select errors are forwarded to users.
70 : // Ask @neondatabase/control-plane for review before adding more.
71 0 : Console { status, .. } => match *status {
72 : http::StatusCode::NOT_FOUND => {
73 : // Status 404: failed to get a project-related resource.
74 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
75 : }
76 : http::StatusCode::NOT_ACCEPTABLE => {
77 : // Status 406: endpoint is disabled (we don't allow connections).
78 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
79 : }
80 : http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
81 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
82 0 : format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
83 : }
84 0 : _ => REQUEST_FAILED.to_owned(),
85 : },
86 0 : _ => REQUEST_FAILED.to_owned(),
87 : }
88 0 : }
89 : }
90 :
91 : impl ReportableError for ApiError {
92 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
93 0 : match self {
94 : ApiError::Console {
95 : status: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
96 : ..
97 0 : } => crate::error::ErrorKind::User,
98 : ApiError::Console {
99 : status: http::StatusCode::UNPROCESSABLE_ENTITY,
100 0 : text,
101 0 : } if text.contains("compute time quota of non-primary branches is exceeded") => {
102 0 : crate::error::ErrorKind::User
103 : }
104 : ApiError::Console {
105 : status: http::StatusCode::LOCKED,
106 0 : text,
107 0 : } if text.contains("quota exceeded")
108 0 : || text.contains("the limit for current plan reached") =>
109 0 : {
110 0 : crate::error::ErrorKind::User
111 : }
112 : ApiError::Console {
113 : status: http::StatusCode::TOO_MANY_REQUESTS,
114 : ..
115 0 : } => crate::error::ErrorKind::ServiceRateLimit,
116 0 : ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
117 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
118 : }
119 0 : }
120 : }
121 :
122 : impl ShouldRetry for ApiError {
123 12 : fn could_retry(&self) -> bool {
124 12 : match self {
125 : // retry some transport errors
126 0 : Self::Transport(io) => io.could_retry(),
127 : // retry some temporary failures because the compute was in a bad state
128 : // (bad request can be returned when the endpoint was in transition)
129 : Self::Console {
130 : status: http::StatusCode::BAD_REQUEST,
131 : ..
132 8 : } => true,
133 : // don't retry when quotas are exceeded
134 : Self::Console {
135 : status: http::StatusCode::UNPROCESSABLE_ENTITY,
136 0 : ref text,
137 0 : } => !text.contains("compute time quota of non-primary branches is exceeded"),
138 : // locked can be returned when the endpoint was in transition
139 : // or when quotas are exceeded. don't retry when quotas are exceeded
140 : Self::Console {
141 : status: http::StatusCode::LOCKED,
142 0 : ref text,
143 0 : } => {
144 0 : // written data quota exceeded
145 0 : // data transfer quota exceeded
146 0 : // compute time quota exceeded
147 0 : // logical size quota exceeded
148 0 : !text.contains("quota exceeded")
149 0 : && !text.contains("the limit for current plan reached")
150 : }
151 4 : _ => false,
152 : }
153 12 : }
154 : }
155 :
156 : impl From<reqwest::Error> for ApiError {
157 0 : fn from(e: reqwest::Error) -> Self {
158 0 : io_error(e).into()
159 0 : }
160 : }
161 :
162 : impl From<reqwest_middleware::Error> for ApiError {
163 0 : fn from(e: reqwest_middleware::Error) -> Self {
164 0 : io_error(e).into()
165 0 : }
166 : }
167 :
/// Failure modes of fetching a role's auth info from the console.
168 0 : #[derive(Debug, Error)]
169 : pub enum GetAuthInfoError {
170 : // We shouldn't include the actual secret here.
171 : #[error("Console responded with a malformed auth secret")]
172 : BadSecret,
173 :
/// Any underlying console API failure; its display output is forwarded unchanged.
174 : #[error(transparent)]
175 : ApiError(ApiError),
176 : }
177 :
178 : // This allows more useful interactions than `#[from]`.
179 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
180 0 : fn from(e: E) -> Self {
181 0 : Self::ApiError(e.into())
182 0 : }
183 : }
184 :
185 : impl UserFacingError for GetAuthInfoError {
186 0 : fn to_string_client(&self) -> String {
187 0 : use GetAuthInfoError::*;
188 0 : match self {
189 : // We absolutely should not leak any secrets!
190 0 : BadSecret => REQUEST_FAILED.to_owned(),
191 : // However, API might return a meaningful error.
192 0 : ApiError(e) => e.to_string_client(),
193 : }
194 0 : }
195 : }
196 :
197 : impl ReportableError for GetAuthInfoError {
198 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
199 0 : match self {
200 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
201 0 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
202 : }
203 0 : }
204 : }
205 :
/// Failure modes of waking a compute node through the console.
206 0 : #[derive(Debug, Error)]
207 : pub enum WakeComputeError {
/// The payload is included in the display output but never forwarded to clients.
208 : #[error("Console responded with a malformed compute address: {0}")]
209 : BadComputeAddress(Box<str>),
210 :
/// Any underlying console API failure; its display output is forwarded unchanged.
211 : #[error(transparent)]
212 : ApiError(ApiError),
213 :
/// Reported with the `RateLimit` error kind; the display text is shown to clients as is.
214 : #[error("Too many connections attempts")]
215 : TooManyConnections,
216 :
/// A permit from the API lock could not be obtained (lock closed or wait timed out).
217 : #[error("error acquiring resource permit: {0}")]
218 : TooManyConnectionAttempts(#[from] ApiLockError),
219 : }
220 :
221 : // This allows more useful interactions than `#[from]`.
222 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
223 0 : fn from(e: E) -> Self {
224 0 : Self::ApiError(e.into())
225 0 : }
226 : }
227 :
228 : impl UserFacingError for WakeComputeError {
229 0 : fn to_string_client(&self) -> String {
230 0 : use WakeComputeError::*;
231 0 : match self {
232 : // We shouldn't show user the address even if it's broken.
233 : // Besides, user is unlikely to care about this detail.
234 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
235 : // However, API might return a meaningful error.
236 0 : ApiError(e) => e.to_string_client(),
237 :
238 0 : TooManyConnections => self.to_string(),
239 :
240 : TooManyConnectionAttempts(_) => {
241 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
242 : }
243 : }
244 0 : }
245 : }
246 :
247 : impl ReportableError for WakeComputeError {
248 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
249 0 : match self {
250 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
251 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
252 0 : WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
253 0 : WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(),
254 : }
255 0 : }
256 : }
257 : }
258 :
259 : /// Auth secret which is managed by the cloud.
260 : #[derive(Clone, Eq, PartialEq, Debug)]
261 : pub enum AuthSecret {
262 : #[cfg(any(test, feature = "testing"))]
263 : /// Md5 hash of user's password.
/// Only compiled in for tests or when the `testing` feature is enabled.
264 : Md5([u8; 16]),
265 :
266 : /// [SCRAM](crate::scram) authentication info.
267 : Scram(scram::ServerSecret),
268 : }
269 :
/// Authentication-related data for a role; `Default` yields no secret,
/// no allowed IPs and no project id.
270 : #[derive(Default)]
271 : pub struct AuthInfo {
/// The role's secret, if there is one.
272 : pub secret: Option<AuthSecret>,
273 : /// List of IP addresses allowed for the authorization.
274 : pub allowed_ips: Vec<IpPattern>,
275 : /// Project ID. This is used for cache invalidation.
276 : pub project_id: Option<ProjectIdInt>,
277 : }
278 :
279 : /// Info for establishing a connection to a compute node.
280 : /// This is what we get after auth succeeded, but not before!
281 : #[derive(Clone)]
282 : pub struct NodeInfo {
283 : /// Compute node connection params.
284 : /// It's sad that we have to clone this, but this will improve
285 : /// once we migrate to a bespoke connection logic.
286 : pub config: compute::ConnCfg,
287 :
288 : /// Labels for proxy's metrics.
/// A clone of these is passed along on every `connect` call.
289 : pub aux: MetricsAuxInfo,
290 :
291 : /// Whether we should accept self-signed certificates (for testing)
292 : pub allow_self_signed_compute: bool,
293 : }
294 :
295 : impl NodeInfo {
296 0 : pub async fn connect(
297 0 : &self,
298 0 : ctx: &mut RequestMonitoring,
299 0 : timeout: Duration,
300 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
301 0 : self.config
302 0 : .connect(
303 0 : ctx,
304 0 : self.allow_self_signed_compute,
305 0 : self.aux.clone(),
306 0 : timeout,
307 0 : )
308 0 : .await
309 0 : }
310 8 : pub fn reuse_settings(&mut self, other: Self) {
311 8 : self.allow_self_signed_compute = other.allow_self_signed_compute;
312 8 : self.config.reuse_password(other.config);
313 8 : }
314 :
315 12 : pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
316 12 : match keys {
317 12 : ComputeCredentialKeys::Password(password) => self.config.password(password),
318 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
319 : };
320 12 : }
321 : }
322 :
// Timed LRU cache mapping an endpoint cache key to its `NodeInfo`.
323 : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
// A value obtained from (or on behalf of) the static `NodeInfoCache`.
324 : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
// A role's secret (`None` when there is no secret) tied to the project info cache.
325 : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
// The shared list of allowed IP patterns tied to the project info cache.
326 : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
327 :
328 : /// This will allocate per each call, but the http requests alone
329 : /// already require a few allocations, so it should be fine.
/// NOTE(review): the allocation remark presumably dates from a boxed-future
/// (`async_trait`-style) version of this trait; the methods below are plain
/// `async fn`s — confirm whether the remark still applies.
330 : pub(crate) trait Api {
331 : /// Get the client's auth secret for authentication.
332 : /// Returns option because user not found situation is special.
333 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
334 : async fn get_role_secret(
335 : &self,
336 : ctx: &mut RequestMonitoring,
337 : user_info: &ComputeUserInfo,
338 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
339 :
/// Get the allowed IP patterns and, optionally, the role secret in one call.
340 : async fn get_allowed_ips_and_secret(
341 : &self,
342 : ctx: &mut RequestMonitoring,
343 : user_info: &ComputeUserInfo,
344 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
345 :
346 : /// Wake up the compute node and return the corresponding connection info.
347 : async fn wake_compute(
348 : &self,
349 : ctx: &mut RequestMonitoring,
350 : user_info: &ComputeUserInfo,
351 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
352 : }
353 :
/// The set of console implementations an `Api` call can be dispatched to.
/// Two of the variants exist only in test / `testing`-feature builds.
354 : #[non_exhaustive]
355 : pub enum ConsoleBackend {
356 : /// Current Cloud API (V2).
357 : Console(neon::Api),
358 : /// Local mock of Cloud API (V2).
359 : #[cfg(any(test, feature = "testing"))]
360 : Postgres(mock::Api),
361 : /// Internal testing
362 : #[cfg(test)]
363 : Test(Box<dyn crate::auth::backend::TestBackend>),
364 : }
365 :
366 : impl Api for ConsoleBackend {
367 0 : async fn get_role_secret(
368 0 : &self,
369 0 : ctx: &mut RequestMonitoring,
370 0 : user_info: &ComputeUserInfo,
371 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
372 0 : use ConsoleBackend::*;
373 0 : match self {
374 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
375 : #[cfg(any(test, feature = "testing"))]
376 0 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
377 : #[cfg(test)]
378 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
379 : }
380 0 : }
381 :
382 0 : async fn get_allowed_ips_and_secret(
383 0 : &self,
384 0 : ctx: &mut RequestMonitoring,
385 0 : user_info: &ComputeUserInfo,
386 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
387 0 : use ConsoleBackend::*;
388 0 : match self {
389 0 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
390 : #[cfg(any(test, feature = "testing"))]
391 0 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
392 : #[cfg(test)]
393 0 : Test(api) => api.get_allowed_ips_and_secret(),
394 : }
395 0 : }
396 :
397 26 : async fn wake_compute(
398 26 : &self,
399 26 : ctx: &mut RequestMonitoring,
400 26 : user_info: &ComputeUserInfo,
401 26 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
402 26 : use ConsoleBackend::*;
403 26 :
404 26 : match self {
405 0 : Console(api) => api.wake_compute(ctx, user_info).await,
406 : #[cfg(any(test, feature = "testing"))]
407 0 : Postgres(api) => api.wake_compute(ctx, user_info).await,
408 : #[cfg(test)]
409 26 : Test(api) => api.wake_compute(),
410 : }
411 26 : }
412 : }
413 :
414 : /// Various caches for [`console`](super).
415 : pub struct ApiCaches {
416 : /// Cache for the `wake_compute` API method.
417 : pub node_info: NodeInfoCache,
418 : /// Cache which stores project_id -> endpoint_ids mapping.
419 : pub project_info: Arc<ProjectInfoCacheImpl>,
420 : /// List of all valid endpoints.
421 : pub endpoints_cache: Arc<EndpointsCache>,
422 : }
423 :
424 : impl ApiCaches {
425 0 : pub fn new(
426 0 : wake_compute_cache_config: CacheOptions,
427 0 : project_info_cache_config: ProjectInfoCacheOptions,
428 0 : endpoint_cache_config: EndpointCacheConfig,
429 0 : ) -> Self {
430 0 : Self {
431 0 : node_info: NodeInfoCache::new(
432 0 : "node_info_cache",
433 0 : wake_compute_cache_config.size,
434 0 : wake_compute_cache_config.ttl,
435 0 : true,
436 0 : ),
437 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
438 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
439 0 : }
440 0 : }
441 : }
442 :
443 : /// Per-key semaphores that bound how many permits may be held concurrently for one key.
/// Semaphores are created lazily on first use and removed again by
/// `garbage_collect_worker` once nobody references them.
444 : pub struct ApiLocks<K> {
/// Label attached to the reclamation log messages.
445 : name: &'static str,
/// One semaphore per key.
446 : node_locks: DashMap<K, Arc<Semaphore>>,
/// Number of permits each semaphore starts with; `0` disables locking entirely.
447 : permits: usize,
/// Maximum time `get_permit` waits for a permit.
448 : timeout: Duration,
/// Time budget the garbage collector spreads over one pass across all shards.
449 : epoch: std::time::Duration,
/// Counters / histograms updated by the lock operations.
450 : metrics: &'static ApiLockMetrics,
451 : }
452 :
/// Reasons why a permit could not be obtained from `ApiLocks`.
453 0 : #[derive(Debug, thiserror::Error)]
454 : pub enum ApiLockError {
/// The semaphore was closed while waiting.
455 : #[error("lock was closed")]
456 : AcquireError(#[from] tokio::sync::AcquireError),
/// The configured timeout elapsed before a permit became available.
457 : #[error("permit could not be acquired")]
458 : TimeoutError(#[from] tokio::time::error::Elapsed),
459 : }
460 :
461 : impl ReportableError for ApiLockError {
462 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
463 0 : match self {
464 0 : ApiLockError::AcquireError(_) => crate::error::ErrorKind::Service,
465 0 : ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
466 : }
467 0 : }
468 : }
469 :
470 : impl<K: Hash + Eq + Clone> ApiLocks<K> {
471 0 : pub fn new(
472 0 : name: &'static str,
473 0 : permits: usize,
474 0 : shards: usize,
475 0 : timeout: Duration,
476 0 : epoch: std::time::Duration,
477 0 : metrics: &'static ApiLockMetrics,
478 0 : ) -> prometheus::Result<Self> {
479 0 : Ok(Self {
480 0 : name,
481 0 : node_locks: DashMap::with_shard_amount(shards),
482 0 : permits,
483 0 : timeout,
484 0 : epoch,
485 0 : metrics,
486 0 : })
487 0 : }
488 :
489 0 : pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
490 0 : if self.permits == 0 {
491 0 : return Ok(WakeComputePermit { permit: None });
492 0 : }
493 0 : let now = Instant::now();
494 0 : let semaphore = {
495 : // get fast path
496 0 : if let Some(semaphore) = self.node_locks.get(key) {
497 0 : semaphore.clone()
498 : } else {
499 0 : self.node_locks
500 0 : .entry(key.clone())
501 0 : .or_insert_with(|| {
502 0 : self.metrics.semaphores_registered.inc();
503 0 : Arc::new(Semaphore::new(self.permits))
504 0 : })
505 0 : .clone()
506 : }
507 : };
508 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
509 :
510 0 : self.metrics
511 0 : .semaphore_acquire_seconds
512 0 : .observe(now.elapsed().as_secs_f64());
513 0 :
514 0 : Ok(WakeComputePermit {
515 0 : permit: Some(permit??),
516 : })
517 0 : }
518 :
519 0 : pub async fn garbage_collect_worker(&self) {
520 0 : if self.permits == 0 {
521 0 : return;
522 0 : }
523 0 : let mut interval =
524 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
525 : loop {
526 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
527 0 : interval.tick().await;
528 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
529 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
530 : // therefore releasing it is safe from race conditions
531 0 : info!(
532 : name = self.name,
533 : shard = i,
534 0 : "performing epoch reclamation on api lock"
535 : );
536 0 : let mut lock = shard.write();
537 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
538 0 : let count = lock
539 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
540 0 : .count();
541 0 : drop(lock);
542 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
543 0 : timer.observe();
544 : }
545 : }
546 0 : }
547 : }
548 :
/// Permit handed out by `ApiLocks::get_permit`; the semaphore slot (if any)
/// stays reserved for as long as this value is alive.
549 : pub struct WakeComputePermit {
550 : // None if the lock is disabled
551 : permit: Option<OwnedSemaphorePermit>,
552 : }
553 :
554 : impl WakeComputePermit {
555 0 : pub fn should_check_cache(&self) -> bool {
556 0 : self.permit.is_some()
557 0 : }
558 : }
|