Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : intern::ProjectIdInt,
16 : metrics::ApiLockMetrics,
17 : scram, EndpointCacheKey,
18 : };
19 : use dashmap::DashMap;
20 : use std::{sync::Arc, time::Duration};
21 : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
22 : use tokio::time::Instant;
23 : use tracing::info;
24 :
25 : pub mod errors {
26 : use crate::{
27 : error::{io_error, ReportableError, UserFacingError},
28 : http,
29 : proxy::retry::ShouldRetry,
30 : };
31 : use thiserror::Error;
32 :
33 : /// A go-to error message which doesn't leak any detail.
34 : const REQUEST_FAILED: &str = "Console request failed";
35 :
36 : /// Common console API error.
37 0 : #[derive(Debug, Error)]
38 : pub enum ApiError {
39 : /// Error returned by the console itself.
40 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
41 : Console {
42 : status: http::StatusCode,
43 : text: Box<str>,
44 : },
45 :
46 : /// Various IO errors like broken pipe or malformed payload.
47 : #[error("{REQUEST_FAILED}: {0}")]
48 : Transport(#[from] std::io::Error),
49 : }
50 :
51 : impl ApiError {
52 : /// Returns HTTP status code if it's the reason for failure.
53 0 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
54 0 : use ApiError::*;
55 0 : match self {
56 0 : Console { status, .. } => Some(*status),
57 0 : _ => None,
58 : }
59 0 : }
60 : }
61 :
62 : impl UserFacingError for ApiError {
63 0 : fn to_string_client(&self) -> String {
64 0 : use ApiError::*;
65 0 : match self {
66 : // To minimize risks, only select errors are forwarded to users.
67 : // Ask @neondatabase/control-plane for review before adding more.
68 0 : Console { status, .. } => match *status {
69 : http::StatusCode::NOT_FOUND => {
70 : // Status 404: failed to get a project-related resource.
71 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
72 : }
73 : http::StatusCode::NOT_ACCEPTABLE => {
74 : // Status 406: endpoint is disabled (we don't allow connections).
75 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
76 : }
77 : http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
78 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
79 0 : format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
80 : }
81 0 : _ => REQUEST_FAILED.to_owned(),
82 : },
83 0 : _ => REQUEST_FAILED.to_owned(),
84 : }
85 0 : }
86 : }
87 :
88 : impl ReportableError for ApiError {
89 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
90 0 : match self {
91 : ApiError::Console {
92 : status: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
93 : ..
94 0 : } => crate::error::ErrorKind::User,
95 : ApiError::Console {
96 : status: http::StatusCode::UNPROCESSABLE_ENTITY,
97 0 : text,
98 0 : } if text.contains("compute time quota of non-primary branches is exceeded") => {
99 0 : crate::error::ErrorKind::User
100 : }
101 : ApiError::Console {
102 : status: http::StatusCode::LOCKED,
103 0 : text,
104 0 : } if text.contains("quota exceeded")
105 0 : || text.contains("the limit for current plan reached") =>
106 0 : {
107 0 : crate::error::ErrorKind::User
108 : }
109 : ApiError::Console {
110 : status: http::StatusCode::TOO_MANY_REQUESTS,
111 : ..
112 0 : } => crate::error::ErrorKind::ServiceRateLimit,
113 0 : ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
114 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
115 : }
116 0 : }
117 : }
118 :
119 : impl ShouldRetry for ApiError {
120 12 : fn could_retry(&self) -> bool {
121 12 : match self {
122 : // retry some transport errors
123 0 : Self::Transport(io) => io.could_retry(),
124 : // retry some temporary failures because the compute was in a bad state
125 : // (bad request can be returned when the endpoint was in transition)
126 : Self::Console {
127 : status: http::StatusCode::BAD_REQUEST,
128 : ..
129 8 : } => true,
130 : // don't retry when quotas are exceeded
131 : Self::Console {
132 : status: http::StatusCode::UNPROCESSABLE_ENTITY,
133 0 : ref text,
134 0 : } => !text.contains("compute time quota of non-primary branches is exceeded"),
135 : // locked can be returned when the endpoint was in transition
136 : // or when quotas are exceeded. don't retry when quotas are exceeded
137 : Self::Console {
138 : status: http::StatusCode::LOCKED,
139 0 : ref text,
140 0 : } => {
141 0 : // written data quota exceeded
142 0 : // data transfer quota exceeded
143 0 : // compute time quota exceeded
144 0 : // logical size quota exceeded
145 0 : !text.contains("quota exceeded")
146 0 : && !text.contains("the limit for current plan reached")
147 : }
148 4 : _ => false,
149 : }
150 12 : }
151 : }
152 :
153 : impl From<reqwest::Error> for ApiError {
154 0 : fn from(e: reqwest::Error) -> Self {
155 0 : io_error(e).into()
156 0 : }
157 : }
158 :
159 : impl From<reqwest_middleware::Error> for ApiError {
160 0 : fn from(e: reqwest_middleware::Error) -> Self {
161 0 : io_error(e).into()
162 0 : }
163 : }
164 :
165 0 : #[derive(Debug, Error)]
166 : pub enum GetAuthInfoError {
167 : // We shouldn't include the actual secret here.
168 : #[error("Console responded with a malformed auth secret")]
169 : BadSecret,
170 :
171 : #[error(transparent)]
172 : ApiError(ApiError),
173 : }
174 :
175 : // This allows more useful interactions than `#[from]`.
176 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
177 0 : fn from(e: E) -> Self {
178 0 : Self::ApiError(e.into())
179 0 : }
180 : }
181 :
182 : impl UserFacingError for GetAuthInfoError {
183 0 : fn to_string_client(&self) -> String {
184 0 : use GetAuthInfoError::*;
185 0 : match self {
186 : // We absolutely should not leak any secrets!
187 0 : BadSecret => REQUEST_FAILED.to_owned(),
188 : // However, API might return a meaningful error.
189 0 : ApiError(e) => e.to_string_client(),
190 : }
191 0 : }
192 : }
193 :
194 : impl ReportableError for GetAuthInfoError {
195 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
196 0 : match self {
197 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
198 0 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
199 : }
200 0 : }
201 : }
202 :
203 0 : #[derive(Debug, Error)]
204 : pub enum WakeComputeError {
205 : #[error("Console responded with a malformed compute address: {0}")]
206 : BadComputeAddress(Box<str>),
207 :
208 : #[error(transparent)]
209 : ApiError(ApiError),
210 :
211 : #[error("Too many connections attempts")]
212 : TooManyConnections,
213 :
214 : #[error("Timeout waiting to acquire wake compute lock")]
215 : TimeoutError,
216 : }
217 :
218 : // This allows more useful interactions than `#[from]`.
219 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
220 0 : fn from(e: E) -> Self {
221 0 : Self::ApiError(e.into())
222 0 : }
223 : }
224 :
225 : impl From<tokio::sync::AcquireError> for WakeComputeError {
226 0 : fn from(_: tokio::sync::AcquireError) -> Self {
227 0 : WakeComputeError::TimeoutError
228 0 : }
229 : }
230 : impl From<tokio::time::error::Elapsed> for WakeComputeError {
231 0 : fn from(_: tokio::time::error::Elapsed) -> Self {
232 0 : WakeComputeError::TimeoutError
233 0 : }
234 : }
235 :
236 : impl UserFacingError for WakeComputeError {
237 0 : fn to_string_client(&self) -> String {
238 0 : use WakeComputeError::*;
239 0 : match self {
240 : // We shouldn't show user the address even if it's broken.
241 : // Besides, user is unlikely to care about this detail.
242 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
243 : // However, API might return a meaningful error.
244 0 : ApiError(e) => e.to_string_client(),
245 :
246 0 : TooManyConnections => self.to_string(),
247 :
248 0 : TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
249 : }
250 0 : }
251 : }
252 :
253 : impl ReportableError for WakeComputeError {
254 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
255 0 : match self {
256 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
257 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
258 0 : WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
259 0 : WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit,
260 : }
261 0 : }
262 : }
263 : }
264 :
265 : /// Auth secret which is managed by the cloud.
266 : #[derive(Clone, Eq, PartialEq, Debug)]
267 : pub enum AuthSecret {
268 : #[cfg(any(test, feature = "testing"))]
269 : /// Md5 hash of user's password.
270 : Md5([u8; 16]),
271 :
272 : /// [SCRAM](crate::scram) authentication info.
273 : Scram(scram::ServerSecret),
274 : }
275 :
276 : #[derive(Default)]
277 : pub struct AuthInfo {
278 : pub secret: Option<AuthSecret>,
279 : /// List of IP addresses allowed for the autorization.
280 : pub allowed_ips: Vec<IpPattern>,
281 : /// Project ID. This is used for cache invalidation.
282 : pub project_id: Option<ProjectIdInt>,
283 : }
284 :
285 : /// Info for establishing a connection to a compute node.
286 : /// This is what we get after auth succeeded, but not before!
287 : #[derive(Clone)]
288 : pub struct NodeInfo {
289 : /// Compute node connection params.
290 : /// It's sad that we have to clone this, but this will improve
291 : /// once we migrate to a bespoke connection logic.
292 : pub config: compute::ConnCfg,
293 :
294 : /// Labels for proxy's metrics.
295 : pub aux: MetricsAuxInfo,
296 :
297 : /// Whether we should accept self-signed certificates (for testing)
298 : pub allow_self_signed_compute: bool,
299 : }
300 :
301 : impl NodeInfo {
302 0 : pub async fn connect(
303 0 : &self,
304 0 : ctx: &mut RequestMonitoring,
305 0 : timeout: Duration,
306 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
307 0 : self.config
308 0 : .connect(
309 0 : ctx,
310 0 : self.allow_self_signed_compute,
311 0 : self.aux.clone(),
312 0 : timeout,
313 0 : )
314 0 : .await
315 0 : }
316 8 : pub fn reuse_settings(&mut self, other: Self) {
317 8 : self.allow_self_signed_compute = other.allow_self_signed_compute;
318 8 : self.config.reuse_password(other.config);
319 8 : }
320 :
321 12 : pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
322 12 : match keys {
323 12 : ComputeCredentialKeys::Password(password) => self.config.password(password),
324 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
325 : };
326 12 : }
327 : }
328 :
329 : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
330 : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
331 : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
332 : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
333 :
334 : /// This will allocate per each call, but the http requests alone
335 : /// already require a few allocations, so it should be fine.
336 : pub(crate) trait Api {
337 : /// Get the client's auth secret for authentication.
338 : /// Returns option because user not found situation is special.
339 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
340 : async fn get_role_secret(
341 : &self,
342 : ctx: &mut RequestMonitoring,
343 : user_info: &ComputeUserInfo,
344 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
345 :
346 : async fn get_allowed_ips_and_secret(
347 : &self,
348 : ctx: &mut RequestMonitoring,
349 : user_info: &ComputeUserInfo,
350 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
351 :
352 : /// Wake up the compute node and return the corresponding connection info.
353 : async fn wake_compute(
354 : &self,
355 : ctx: &mut RequestMonitoring,
356 : user_info: &ComputeUserInfo,
357 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
358 : }
359 :
360 : #[non_exhaustive]
361 : pub enum ConsoleBackend {
362 : /// Current Cloud API (V2).
363 : Console(neon::Api),
364 : /// Local mock of Cloud API (V2).
365 : #[cfg(any(test, feature = "testing"))]
366 : Postgres(mock::Api),
367 : /// Internal testing
368 : #[cfg(test)]
369 : Test(Box<dyn crate::auth::backend::TestBackend>),
370 : }
371 :
372 : impl Api for ConsoleBackend {
373 0 : async fn get_role_secret(
374 0 : &self,
375 0 : ctx: &mut RequestMonitoring,
376 0 : user_info: &ComputeUserInfo,
377 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
378 0 : use ConsoleBackend::*;
379 0 : match self {
380 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
381 : #[cfg(any(test, feature = "testing"))]
382 0 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
383 : #[cfg(test)]
384 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
385 : }
386 0 : }
387 :
388 0 : async fn get_allowed_ips_and_secret(
389 0 : &self,
390 0 : ctx: &mut RequestMonitoring,
391 0 : user_info: &ComputeUserInfo,
392 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
393 0 : use ConsoleBackend::*;
394 0 : match self {
395 0 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
396 : #[cfg(any(test, feature = "testing"))]
397 0 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
398 : #[cfg(test)]
399 0 : Test(api) => api.get_allowed_ips_and_secret(),
400 : }
401 0 : }
402 :
403 26 : async fn wake_compute(
404 26 : &self,
405 26 : ctx: &mut RequestMonitoring,
406 26 : user_info: &ComputeUserInfo,
407 26 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
408 26 : use ConsoleBackend::*;
409 26 :
410 26 : match self {
411 0 : Console(api) => api.wake_compute(ctx, user_info).await,
412 : #[cfg(any(test, feature = "testing"))]
413 0 : Postgres(api) => api.wake_compute(ctx, user_info).await,
414 : #[cfg(test)]
415 26 : Test(api) => api.wake_compute(),
416 : }
417 26 : }
418 : }
419 :
420 : /// Various caches for [`console`](super).
421 : pub struct ApiCaches {
422 : /// Cache for the `wake_compute` API method.
423 : pub node_info: NodeInfoCache,
424 : /// Cache which stores project_id -> endpoint_ids mapping.
425 : pub project_info: Arc<ProjectInfoCacheImpl>,
426 : /// List of all valid endpoints.
427 : pub endpoints_cache: Arc<EndpointsCache>,
428 : }
429 :
430 : impl ApiCaches {
431 0 : pub fn new(
432 0 : wake_compute_cache_config: CacheOptions,
433 0 : project_info_cache_config: ProjectInfoCacheOptions,
434 0 : endpoint_cache_config: EndpointCacheConfig,
435 0 : ) -> Self {
436 0 : Self {
437 0 : node_info: NodeInfoCache::new(
438 0 : "node_info_cache",
439 0 : wake_compute_cache_config.size,
440 0 : wake_compute_cache_config.ttl,
441 0 : true,
442 0 : ),
443 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
444 0 : endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
445 0 : }
446 0 : }
447 : }
448 :
449 : /// Various caches for [`console`](super).
450 : pub struct ApiLocks {
451 : name: &'static str,
452 : node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
453 : permits: usize,
454 : timeout: Duration,
455 : epoch: std::time::Duration,
456 : metrics: &'static ApiLockMetrics,
457 : }
458 :
459 : impl ApiLocks {
460 0 : pub fn new(
461 0 : name: &'static str,
462 0 : permits: usize,
463 0 : shards: usize,
464 0 : timeout: Duration,
465 0 : epoch: std::time::Duration,
466 0 : metrics: &'static ApiLockMetrics,
467 0 : ) -> prometheus::Result<Self> {
468 0 : Ok(Self {
469 0 : name,
470 0 : node_locks: DashMap::with_shard_amount(shards),
471 0 : permits,
472 0 : timeout,
473 0 : epoch,
474 0 : metrics,
475 0 : })
476 0 : }
477 :
478 0 : pub async fn get_wake_compute_permit(
479 0 : &self,
480 0 : key: &EndpointCacheKey,
481 0 : ) -> Result<WakeComputePermit, errors::WakeComputeError> {
482 0 : if self.permits == 0 {
483 0 : return Ok(WakeComputePermit { permit: None });
484 0 : }
485 0 : let now = Instant::now();
486 0 : let semaphore = {
487 : // get fast path
488 0 : if let Some(semaphore) = self.node_locks.get(key) {
489 0 : semaphore.clone()
490 : } else {
491 0 : self.node_locks
492 0 : .entry(key.clone())
493 0 : .or_insert_with(|| {
494 0 : self.metrics.semaphores_registered.inc();
495 0 : Arc::new(Semaphore::new(self.permits))
496 0 : })
497 0 : .clone()
498 : }
499 : };
500 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
501 :
502 0 : self.metrics
503 0 : .semaphore_acquire_seconds
504 0 : .observe(now.elapsed().as_secs_f64());
505 0 :
506 0 : Ok(WakeComputePermit {
507 0 : permit: Some(permit??),
508 : })
509 0 : }
510 :
511 0 : pub async fn garbage_collect_worker(&self) {
512 0 : if self.permits == 0 {
513 0 : return;
514 0 : }
515 0 : let mut interval =
516 0 : tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
517 : loop {
518 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
519 0 : interval.tick().await;
520 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
521 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
522 : // therefore releasing it is safe from race conditions
523 0 : info!(
524 0 : name = self.name,
525 0 : shard = i,
526 0 : "performing epoch reclamation on api lock"
527 0 : );
528 0 : let mut lock = shard.write();
529 0 : let timer = self.metrics.reclamation_lag_seconds.start_timer();
530 0 : let count = lock
531 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
532 0 : .count();
533 0 : drop(lock);
534 0 : self.metrics.semaphores_unregistered.inc_by(count as u64);
535 0 : timer.observe();
536 : }
537 : }
538 0 : }
539 : }
540 :
541 : pub struct WakeComputePermit {
542 : // None if the lock is disabled
543 : permit: Option<OwnedSemaphorePermit>,
544 : }
545 :
546 : impl WakeComputePermit {
547 0 : pub fn should_check_cache(&self) -> bool {
548 0 : self.permit.is_some()
549 0 : }
550 : }
|