#[cfg(any(test, feature = "testing"))]
pub mod mock;
pub mod neon;

use super::messages::MetricsAuxInfo;
use crate::{
    auth::{
        backend::{ComputeCredentialKeys, ComputeUserInfo},
        IpPattern,
    },
    cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
    compute,
    config::{CacheOptions, ProjectInfoCacheOptions},
    context::RequestMonitoring,
    intern::ProjectIdInt,
    scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
use tracing::info;

pub mod errors {
    use crate::{
        error::{io_error, ReportableError, UserFacingError},
        http,
        proxy::retry::ShouldRetry,
    };
    use thiserror::Error;

    /// A go-to error message which doesn't leak any detail.
    const REQUEST_FAILED: &str = "Console request failed";

    /// Common console API error.
    #[derive(Debug, Error)]
    pub enum ApiError {
        /// Error returned by the console itself.
        #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
        Console {
            status: http::StatusCode,
            text: Box<str>,
        },

        /// Various IO errors like broken pipe or malformed payload.
        #[error("{REQUEST_FAILED}: {0}")]
        Transport(#[from] std::io::Error),
    }

    impl ApiError {
        /// Returns the HTTP status code if it's the reason for the failure.
        pub fn http_status_code(&self) -> Option<http::StatusCode> {
            use ApiError::*;
            match self {
                Console { status, .. } => Some(*status),
                _ => None,
            }
        }
    }

    impl UserFacingError for ApiError {
        fn to_string_client(&self) -> String {
            use ApiError::*;
            match self {
                // To minimize risks, only select errors are forwarded to users.
                // Ask @neondatabase/control-plane for review before adding more.
                Console { status, .. } => match *status {
                    http::StatusCode::NOT_FOUND => {
                        // Status 404: failed to get a project-related resource.
                        format!("{REQUEST_FAILED}: endpoint cannot be found")
                    }
                    http::StatusCode::NOT_ACCEPTABLE => {
                        // Status 406: endpoint is disabled (we don't allow connections).
                        format!("{REQUEST_FAILED}: endpoint is disabled")
                    }
                    http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
                        // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
                        format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support")
                    }
                    _ => REQUEST_FAILED.to_owned(),
                },
                _ => REQUEST_FAILED.to_owned(),
            }
        }
    }

    impl ReportableError for ApiError {
        fn get_error_kind(&self) -> crate::error::ErrorKind {
            match self {
                ApiError::Console {
                    status: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
                    ..
                } => crate::error::ErrorKind::User,
                ApiError::Console {
                    status: http::StatusCode::UNPROCESSABLE_ENTITY,
                    text,
                } if text.contains("compute time quota of non-primary branches is exceeded") => {
                    crate::error::ErrorKind::User
                }
                ApiError::Console {
                    status: http::StatusCode::LOCKED,
                    text,
                } if text.contains("quota exceeded")
                    || text.contains("the limit for current plan reached") =>
                {
                    crate::error::ErrorKind::User
                }
                ApiError::Console {
                    status: http::StatusCode::TOO_MANY_REQUESTS,
                    ..
                } => crate::error::ErrorKind::ServiceRateLimit,
                ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
                ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
            }
        }
    }

    impl ShouldRetry for ApiError {
        fn could_retry(&self) -> bool {
            match self {
                // retry some transport errors
                Self::Transport(io) => io.could_retry(),
                // retry some temporary failures because the compute was in a bad state
                // (bad request can be returned when the endpoint was in transition)
                Self::Console {
                    status: http::StatusCode::BAD_REQUEST,
                    ..
                } => true,
                // don't retry when quotas are exceeded
                Self::Console {
                    status: http::StatusCode::UNPROCESSABLE_ENTITY,
                    ref text,
                } => !text.contains("compute time quota of non-primary branches is exceeded"),
                // locked can be returned when the endpoint was in transition
                // or when quotas are exceeded. don't retry when quotas are exceeded
                Self::Console {
                    status: http::StatusCode::LOCKED,
                    ref text,
                } => {
                    // written data quota exceeded
                    // data transfer quota exceeded
                    // compute time quota exceeded
                    // logical size quota exceeded
                    !text.contains("quota exceeded")
                        && !text.contains("the limit for current plan reached")
                }
                _ => false,
            }
        }
    }

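    // A retry-classification sketch for `could_retry`, illustrative only: the
    // statuses and messages below are hypothetical console payloads chosen to
    // exercise the arms above, not strings this crate emits.
    //
    //     let transient = ApiError::Console {
    //         status: http::StatusCode::BAD_REQUEST,
    //         text: "endpoint is in transition".into(),
    //     };
    //     assert!(transient.could_retry());
    //
    //     let quota = ApiError::Console {
    //         status: http::StatusCode::LOCKED,
    //         text: "compute time quota exceeded".into(),
    //     };
    //     assert!(!quota.could_retry());
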
    impl From<reqwest::Error> for ApiError {
        fn from(e: reqwest::Error) -> Self {
            io_error(e).into()
        }
    }

    impl From<reqwest_middleware::Error> for ApiError {
        fn from(e: reqwest_middleware::Error) -> Self {
            io_error(e).into()
        }
    }

    #[derive(Debug, Error)]
    pub enum GetAuthInfoError {
        // We shouldn't include the actual secret here.
        #[error("Console responded with a malformed auth secret")]
        BadSecret,

        #[error(transparent)]
        ApiError(ApiError),
    }

    // This allows more useful interactions than `#[from]`.
    impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
        fn from(e: E) -> Self {
            Self::ApiError(e.into())
        }
    }

    impl UserFacingError for GetAuthInfoError {
        fn to_string_client(&self) -> String {
            use GetAuthInfoError::*;
            match self {
                // We absolutely should not leak any secrets!
                BadSecret => REQUEST_FAILED.to_owned(),
                // However, the API might return a meaningful error.
                ApiError(e) => e.to_string_client(),
            }
        }
    }

    impl ReportableError for GetAuthInfoError {
        fn get_error_kind(&self) -> crate::error::ErrorKind {
            match self {
                GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
                GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
            }
        }
    }

    #[derive(Debug, Error)]
    pub enum WakeComputeError {
        #[error("Console responded with a malformed compute address: {0}")]
        BadComputeAddress(Box<str>),

        #[error(transparent)]
        ApiError(ApiError),

        #[error("Timeout waiting to acquire wake compute lock")]
        TimeoutError,
    }

    // This allows more useful interactions than `#[from]`.
    impl<E: Into<ApiError>> From<E> for WakeComputeError {
        fn from(e: E) -> Self {
            Self::ApiError(e.into())
        }
    }

    impl From<tokio::sync::AcquireError> for WakeComputeError {
        fn from(_: tokio::sync::AcquireError) -> Self {
            WakeComputeError::TimeoutError
        }
    }

    impl From<tokio::time::error::Elapsed> for WakeComputeError {
        fn from(_: tokio::time::error::Elapsed) -> Self {
            WakeComputeError::TimeoutError
        }
    }

    impl UserFacingError for WakeComputeError {
        fn to_string_client(&self) -> String {
            use WakeComputeError::*;
            match self {
                // We shouldn't show the user the address even if it's broken.
                // Besides, the user is unlikely to care about this detail.
                BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
                // However, the API might return a meaningful error.
                ApiError(e) => e.to_string_client(),

                TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
            }
        }
    }

    impl ReportableError for WakeComputeError {
        fn get_error_kind(&self) -> crate::error::ErrorKind {
            match self {
                WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
                WakeComputeError::ApiError(e) => e.get_error_kind(),
                WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit,
            }
        }
    }
}

/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
    #[cfg(any(test, feature = "testing"))]
    /// MD5 hash of the user's password.
    Md5([u8; 16]),

    /// [SCRAM](crate::scram) authentication info.
    Scram(scram::ServerSecret),
}

#[derive(Default)]
pub struct AuthInfo {
    pub secret: Option<AuthSecret>,
    /// List of IP addresses allowed for authorization.
    pub allowed_ips: Vec<IpPattern>,
    /// Project ID. This is used for cache invalidation.
    pub project_id: Option<ProjectIdInt>,
}

/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub struct NodeInfo {
    /// Compute node connection params.
    /// It's sad that we have to clone this, but this will improve
    /// once we migrate to bespoke connection logic.
    pub config: compute::ConnCfg,

    /// Labels for proxy's metrics.
    pub aux: MetricsAuxInfo,

    /// Whether we should accept self-signed certificates (for testing).
    pub allow_self_signed_compute: bool,
}

impl NodeInfo {
    pub async fn connect(
        &self,
        ctx: &mut RequestMonitoring,
        timeout: Duration,
    ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
        self.config
            .connect(
                ctx,
                self.allow_self_signed_compute,
                self.aux.clone(),
                timeout,
            )
            .await
    }

    pub fn reuse_settings(&mut self, other: Self) {
        self.allow_self_signed_compute = other.allow_self_signed_compute;
        self.config.reuse_password(other.config);
    }

    pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
        match keys {
            ComputeCredentialKeys::Password(password) => self.config.password(password),
            ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
        };
    }
}

pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;

/// This will allocate on each call, but the http requests alone
/// already require a few allocations, so it should be fine.
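///
/// A minimal usage sketch (hypothetical caller code, so not compiled here;
/// `api`, `ctx`, and `user_info` are assumed to be in scope):
///
/// ```ignore
/// let secret = api.get_role_secret(&mut ctx, &user_info).await?;
/// let node = api.wake_compute(&mut ctx, &user_info).await?;
/// let conn = node.connect(&mut ctx, std::time::Duration::from_secs(2)).await?;
/// ```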
pub(crate) trait Api {
    /// Get the client's auth secret for authentication.
    /// Returns an option because the user-not-found case is special:
    /// we still have to mock SCRAM to avoid leaking the fact that the user doesn't exist.
    async fn get_role_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;

    async fn get_allowed_ips_and_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

    /// Wake up the compute node and return the corresponding connection info.
    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}

#[non_exhaustive]
pub enum ConsoleBackend {
    /// Current Cloud API (V2).
    Console(neon::Api),
    /// Local mock of Cloud API (V2).
    #[cfg(any(test, feature = "testing"))]
    Postgres(mock::Api),
    /// Internal testing
    #[cfg(test)]
    Test(Box<dyn crate::auth::backend::TestBackend>),
}

impl Api for ConsoleBackend {
    async fn get_role_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
        use ConsoleBackend::*;
        match self {
            Console(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(test)]
            Test(_) => unreachable!("this function should never be called in the test backend"),
        }
    }

    async fn get_allowed_ips_and_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
        use ConsoleBackend::*;
        match self {
            Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(test)]
            Test(api) => api.get_allowed_ips_and_secret(),
        }
    }

    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
        use ConsoleBackend::*;

        match self {
            Console(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(test)]
            Test(api) => api.wake_compute(),
        }
    }
}

/// Various caches for [`console`](super).
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub node_info: NodeInfoCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    pub project_info: Arc<ProjectInfoCacheImpl>,
}

impl ApiCaches {
    pub fn new(
        wake_compute_cache_config: CacheOptions,
        project_info_cache_config: ProjectInfoCacheOptions,
    ) -> Self {
        Self {
            node_info: NodeInfoCache::new(
                "node_info_cache",
                wake_compute_cache_config.size,
                wake_compute_cache_config.ttl,
                true,
            ),
            project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
        }
    }
}

/// Per-endpoint locks used to serialize `wake_compute` calls.
pub struct ApiLocks {
    name: &'static str,
    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
    permits: usize,
    timeout: Duration,
    registered: prometheus::IntCounter,
    unregistered: prometheus::IntCounter,
    reclamation_lag: prometheus::Histogram,
    lock_acquire_lag: prometheus::Histogram,
}

impl ApiLocks {
    pub fn new(
        name: &'static str,
        permits: usize,
        shards: usize,
        timeout: Duration,
    ) -> prometheus::Result<Self> {
        let registered = prometheus::IntCounter::with_opts(
            prometheus::Opts::new(
                "semaphores_registered",
                "Number of semaphores registered in this api lock",
            )
            .namespace(name),
        )?;
        prometheus::register(Box::new(registered.clone()))?;
        let unregistered = prometheus::IntCounter::with_opts(
            prometheus::Opts::new(
                "semaphores_unregistered",
                "Number of semaphores unregistered in this api lock",
            )
            .namespace(name),
        )?;
        prometheus::register(Box::new(unregistered.clone()))?;
        let reclamation_lag = prometheus::Histogram::with_opts(
            prometheus::HistogramOpts::new(
                "reclamation_lag_seconds",
                "Time it takes to reclaim unused semaphores in the api lock",
            )
            .namespace(name)
            // 1us -> 65ms
            // benchmarks on my mac indicate it's usually in the range of 256us and 512us
            .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
        )?;
        prometheus::register(Box::new(reclamation_lag.clone()))?;
        let lock_acquire_lag = prometheus::Histogram::with_opts(
            prometheus::HistogramOpts::new(
                "semaphore_acquire_seconds",
                "Time it takes to acquire a semaphore in the api lock",
            )
            .namespace(name)
            // 0.1ms -> 6s
            .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
        )?;
        prometheus::register(Box::new(lock_acquire_lag.clone()))?;

        Ok(Self {
            name,
            node_locks: DashMap::with_shard_amount(shards),
            permits,
            timeout,
            lock_acquire_lag,
            registered,
            unregistered,
            reclamation_lag,
        })
    }

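    /// Acquire a permit for a `wake_compute` call on this endpoint. A
    /// semaphore with `self.permits` slots is created lazily per endpoint
    /// (a plain read of the DashMap is tried first, to avoid write-locking
    /// the shard), and waiting for a slot is bounded by `self.timeout`.
    /// With `permits == 0` the lock is disabled and an empty permit is
    /// returned immediately.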
    pub async fn get_wake_compute_permit(
        &self,
        key: &EndpointCacheKey,
    ) -> Result<WakeComputePermit, errors::WakeComputeError> {
        if self.permits == 0 {
            return Ok(WakeComputePermit { permit: None });
        }
        let now = Instant::now();
        let semaphore = {
            // fast path: the semaphore for this endpoint already exists
            if let Some(semaphore) = self.node_locks.get(key) {
                semaphore.clone()
            } else {
                self.node_locks
                    .entry(key.clone())
                    .or_insert_with(|| {
                        self.registered.inc();
                        Arc::new(Semaphore::new(self.permits))
                    })
                    .clone()
            }
        };
        let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;

        self.lock_acquire_lag
            .observe((Instant::now() - now).as_secs_f64());

        Ok(WakeComputePermit {
            permit: Some(permit??),
        })
    }

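    /// Periodically reclaim semaphores that no task currently holds or waits
    /// on. Ticks are spread across the DashMap shards so that each shard is
    /// scanned roughly once per `epoch`.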
    pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
        if self.permits == 0 {
            return;
        }

        let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
        loop {
            for (i, shard) in self.node_locks.shards().iter().enumerate() {
                interval.tick().await;
                // temporarily lock a single shard and then clear any semaphores that aren't currently checked out
                // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked,
                // therefore releasing it is safe from race conditions
                info!(
                    name = self.name,
                    shard = i,
                    "performing epoch reclamation on api lock"
                );
                let mut lock = shard.write();
                let timer = self.reclamation_lag.start_timer();
                let count = lock
                    .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
                    .count();
                drop(lock);
                self.unregistered.inc_by(count as u64);
                timer.observe_duration()
            }
        }
    }
}

pub struct WakeComputePermit {
    // None if the lock is disabled
    permit: Option<OwnedSemaphorePermit>,
}

impl WakeComputePermit {
    pub fn should_check_cache(&self) -> bool {
        self.permit.is_some()
    }
}
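
// How the permit interacts with the node-info cache, as a usage sketch
// (hypothetical caller code: `locks`, `caches`, `api`, `ctx`, `user_info`,
// and `key` are assumed to be in scope, and `caches.node_info.get` stands in
// for whatever cache lookup the caller uses):
//
//     let permit = locks.get_wake_compute_permit(&key).await?;
//     // Someone else may have woken the compute while we waited for the
//     // permit, so it's worth re-checking the cache before waking again.
//     if permit.should_check_cache() {
//         if let Some(cached) = caches.node_info.get(&key) {
//             return Ok(cached);
//         }
//     }
//     let node = api.wake_compute(&mut ctx, &user_info).await?;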