Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : scram, EndpointCacheKey, ProjectId,
16 : };
17 : use async_trait::async_trait;
18 : use dashmap::DashMap;
19 : use std::{sync::Arc, time::Duration};
20 : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
21 : use tokio::time::Instant;
22 : use tracing::info;
23 :
24 : pub mod errors {
25 : use crate::{
26 : error::{io_error, ReportableError, UserFacingError},
27 : http,
28 : proxy::retry::ShouldRetry,
29 : };
30 : use thiserror::Error;
31 :
32 : /// A go-to error message which doesn't leak any detail.
33 : const REQUEST_FAILED: &str = "Console request failed";
34 :
35 : /// Common console API error.
36 2 : #[derive(Debug, Error)]
37 : pub enum ApiError {
38 : /// Error returned by the console itself.
39 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
40 : Console {
41 : status: http::StatusCode,
42 : text: Box<str>,
43 : },
44 :
45 : /// Various IO errors like broken pipe or malformed payload.
46 : #[error("{REQUEST_FAILED}: {0}")]
47 : Transport(#[from] std::io::Error),
48 : }
49 :
50 : impl ApiError {
51 : /// Returns HTTP status code if it's the reason for failure.
52 0 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
53 0 : use ApiError::*;
54 0 : match self {
55 0 : Console { status, .. } => Some(*status),
56 0 : _ => None,
57 : }
58 0 : }
59 : }
60 :
61 : impl UserFacingError for ApiError {
62 0 : fn to_string_client(&self) -> String {
63 0 : use ApiError::*;
64 0 : match self {
65 : // To minimize risks, only select errors are forwarded to users.
66 : // Ask @neondatabase/control-plane for review before adding more.
67 0 : Console { status, .. } => match *status {
68 : http::StatusCode::NOT_FOUND => {
69 : // Status 404: failed to get a project-related resource.
70 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
71 : }
72 : http::StatusCode::NOT_ACCEPTABLE => {
73 : // Status 406: endpoint is disabled (we don't allow connections).
74 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
75 : }
76 : http::StatusCode::LOCKED => {
77 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
78 0 : format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
79 : }
80 0 : _ => REQUEST_FAILED.to_owned(),
81 : },
82 0 : _ => REQUEST_FAILED.to_owned(),
83 : }
84 0 : }
85 : }
86 :
87 : impl ReportableError for ApiError {
88 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
89 0 : match self {
90 : ApiError::Console {
91 : status: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
92 : ..
93 0 : } => crate::error::ErrorKind::User,
94 : ApiError::Console {
95 : status: http::StatusCode::LOCKED,
96 0 : text,
97 0 : } if text.contains("quota exceeded")
98 0 : || text.contains("the limit for current plan reached") =>
99 0 : {
100 0 : crate::error::ErrorKind::User
101 : }
102 : ApiError::Console {
103 : status: http::StatusCode::TOO_MANY_REQUESTS,
104 : ..
105 0 : } => crate::error::ErrorKind::ServiceRateLimit,
106 0 : ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
107 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
108 : }
109 0 : }
110 : }
111 :
112 : impl ShouldRetry for ApiError {
113 12 : fn could_retry(&self) -> bool {
114 12 : match self {
115 : // retry some transport errors
116 0 : Self::Transport(io) => io.could_retry(),
117 : // retry some temporary failures because the compute was in a bad state
118 : // (bad request can be returned when the endpoint was in transition)
119 : Self::Console {
120 : status: http::StatusCode::BAD_REQUEST,
121 : ..
122 8 : } => true,
123 : // locked can be returned when the endpoint was in transition
124 : // or when quotas are exceeded. don't retry when quotas are exceeded
125 : Self::Console {
126 : status: http::StatusCode::LOCKED,
127 0 : ref text,
128 0 : } => {
129 0 : // written data quota exceeded
130 0 : // data transfer quota exceeded
131 0 : // compute time quota exceeded
132 0 : // logical size quota exceeded
133 0 : !text.contains("quota exceeded")
134 0 : && !text.contains("the limit for current plan reached")
135 : }
136 4 : _ => false,
137 : }
138 12 : }
139 : }
140 :
141 : impl From<reqwest::Error> for ApiError {
142 0 : fn from(e: reqwest::Error) -> Self {
143 0 : io_error(e).into()
144 0 : }
145 : }
146 :
147 : impl From<reqwest_middleware::Error> for ApiError {
148 0 : fn from(e: reqwest_middleware::Error) -> Self {
149 0 : io_error(e).into()
150 0 : }
151 : }
152 :
153 0 : #[derive(Debug, Error)]
154 : pub enum GetAuthInfoError {
155 : // We shouldn't include the actual secret here.
156 : #[error("Console responded with a malformed auth secret")]
157 : BadSecret,
158 :
159 : #[error(transparent)]
160 : ApiError(ApiError),
161 : }
162 :
163 : // This allows more useful interactions than `#[from]`.
164 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
165 0 : fn from(e: E) -> Self {
166 0 : Self::ApiError(e.into())
167 0 : }
168 : }
169 :
170 : impl UserFacingError for GetAuthInfoError {
171 0 : fn to_string_client(&self) -> String {
172 0 : use GetAuthInfoError::*;
173 0 : match self {
174 : // We absolutely should not leak any secrets!
175 0 : BadSecret => REQUEST_FAILED.to_owned(),
176 : // However, API might return a meaningful error.
177 0 : ApiError(e) => e.to_string_client(),
178 : }
179 0 : }
180 : }
181 :
182 : impl ReportableError for GetAuthInfoError {
183 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
184 0 : match self {
185 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
186 0 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
187 : }
188 0 : }
189 : }
190 :
191 2 : #[derive(Debug, Error)]
192 : pub enum WakeComputeError {
193 : #[error("Console responded with a malformed compute address: {0}")]
194 : BadComputeAddress(Box<str>),
195 :
196 : #[error(transparent)]
197 : ApiError(ApiError),
198 :
199 : #[error("Timeout waiting to acquire wake compute lock")]
200 : TimeoutError,
201 : }
202 :
203 : // This allows more useful interactions than `#[from]`.
204 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
205 0 : fn from(e: E) -> Self {
206 0 : Self::ApiError(e.into())
207 0 : }
208 : }
209 :
210 : impl From<tokio::sync::AcquireError> for WakeComputeError {
211 0 : fn from(_: tokio::sync::AcquireError) -> Self {
212 0 : WakeComputeError::TimeoutError
213 0 : }
214 : }
215 : impl From<tokio::time::error::Elapsed> for WakeComputeError {
216 0 : fn from(_: tokio::time::error::Elapsed) -> Self {
217 0 : WakeComputeError::TimeoutError
218 0 : }
219 : }
220 :
221 : impl UserFacingError for WakeComputeError {
222 0 : fn to_string_client(&self) -> String {
223 0 : use WakeComputeError::*;
224 0 : match self {
225 : // We shouldn't show user the address even if it's broken.
226 : // Besides, user is unlikely to care about this detail.
227 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
228 : // However, API might return a meaningful error.
229 0 : ApiError(e) => e.to_string_client(),
230 :
231 0 : TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
232 : }
233 0 : }
234 : }
235 :
236 : impl ReportableError for WakeComputeError {
237 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
238 0 : match self {
239 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
240 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
241 0 : WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit,
242 : }
243 0 : }
244 : }
245 : }
246 :
247 : /// Auth secret which is managed by the cloud.
248 24 : #[derive(Clone, Eq, PartialEq, Debug)]
249 : pub enum AuthSecret {
250 : #[cfg(any(test, feature = "testing"))]
251 : /// Md5 hash of user's password.
252 : Md5([u8; 16]),
253 :
254 : /// [SCRAM](crate::scram) authentication info.
255 : Scram(scram::ServerSecret),
256 : }
257 :
258 0 : #[derive(Default)]
259 : pub struct AuthInfo {
260 : pub secret: Option<AuthSecret>,
261 : /// List of IP addresses allowed for the autorization.
262 : pub allowed_ips: Vec<IpPattern>,
263 : /// Project ID. This is used for cache invalidation.
264 : pub project_id: Option<ProjectId>,
265 : }
266 :
267 : /// Info for establishing a connection to a compute node.
268 : /// This is what we get after auth succeeded, but not before!
269 20 : #[derive(Clone)]
270 : pub struct NodeInfo {
271 : /// Compute node connection params.
272 : /// It's sad that we have to clone this, but this will improve
273 : /// once we migrate to a bespoke connection logic.
274 : pub config: compute::ConnCfg,
275 :
276 : /// Labels for proxy's metrics.
277 : pub aux: MetricsAuxInfo,
278 :
279 : /// Whether we should accept self-signed certificates (for testing)
280 : pub allow_self_signed_compute: bool,
281 : }
282 :
283 : impl NodeInfo {
284 0 : pub async fn connect(
285 0 : &self,
286 0 : ctx: &mut RequestMonitoring,
287 0 : timeout: Duration,
288 0 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
289 0 : self.config
290 0 : .connect(
291 0 : ctx,
292 0 : self.allow_self_signed_compute,
293 0 : self.aux.clone(),
294 0 : timeout,
295 0 : )
296 0 : .await
297 0 : }
298 8 : pub fn reuse_settings(&mut self, other: Self) {
299 8 : self.allow_self_signed_compute = other.allow_self_signed_compute;
300 8 : self.config.reuse_password(other.config);
301 8 : }
302 :
303 12 : pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
304 12 : match keys {
305 12 : ComputeCredentialKeys::Password(password) => self.config.password(password),
306 0 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
307 : };
308 12 : }
309 : }
310 :
311 : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
312 : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
313 : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
314 : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
315 :
316 : /// This will allocate per each call, but the http requests alone
317 : /// already require a few allocations, so it should be fine.
318 : #[async_trait]
319 : pub trait Api {
320 : /// Get the client's auth secret for authentication.
321 : /// Returns option because user not found situation is special.
322 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
323 : async fn get_role_secret(
324 : &self,
325 : ctx: &mut RequestMonitoring,
326 : user_info: &ComputeUserInfo,
327 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
328 :
329 : async fn get_allowed_ips_and_secret(
330 : &self,
331 : ctx: &mut RequestMonitoring,
332 : user_info: &ComputeUserInfo,
333 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
334 :
335 : /// Wake up the compute node and return the corresponding connection info.
336 : async fn wake_compute(
337 : &self,
338 : ctx: &mut RequestMonitoring,
339 : user_info: &ComputeUserInfo,
340 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
341 : }
342 :
343 : #[non_exhaustive]
344 : pub enum ConsoleBackend {
345 : /// Current Cloud API (V2).
346 : Console(neon::Api),
347 : /// Local mock of Cloud API (V2).
348 : #[cfg(any(test, feature = "testing"))]
349 : Postgres(mock::Api),
350 : /// Internal testing
351 : #[cfg(test)]
352 : Test(Box<dyn crate::auth::backend::TestBackend>),
353 : }
354 :
355 : #[async_trait]
356 : impl Api for ConsoleBackend {
357 0 : async fn get_role_secret(
358 0 : &self,
359 0 : ctx: &mut RequestMonitoring,
360 0 : user_info: &ComputeUserInfo,
361 0 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
362 : use ConsoleBackend::*;
363 0 : match self {
364 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
365 : #[cfg(any(test, feature = "testing"))]
366 0 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
367 : #[cfg(test)]
368 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
369 : }
370 0 : }
371 :
372 0 : async fn get_allowed_ips_and_secret(
373 0 : &self,
374 0 : ctx: &mut RequestMonitoring,
375 0 : user_info: &ComputeUserInfo,
376 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
377 : use ConsoleBackend::*;
378 0 : match self {
379 0 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
380 : #[cfg(any(test, feature = "testing"))]
381 0 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
382 : #[cfg(test)]
383 0 : Test(api) => api.get_allowed_ips_and_secret(),
384 : }
385 0 : }
386 :
387 26 : async fn wake_compute(
388 26 : &self,
389 26 : ctx: &mut RequestMonitoring,
390 26 : user_info: &ComputeUserInfo,
391 26 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
392 : use ConsoleBackend::*;
393 :
394 26 : match self {
395 0 : Console(api) => api.wake_compute(ctx, user_info).await,
396 : #[cfg(any(test, feature = "testing"))]
397 0 : Postgres(api) => api.wake_compute(ctx, user_info).await,
398 : #[cfg(test)]
399 26 : Test(api) => api.wake_compute(),
400 : }
401 78 : }
402 : }
403 :
404 : /// Various caches for [`console`](super).
405 : pub struct ApiCaches {
406 : /// Cache for the `wake_compute` API method.
407 : pub node_info: NodeInfoCache,
408 : /// Cache which stores project_id -> endpoint_ids mapping.
409 : pub project_info: Arc<ProjectInfoCacheImpl>,
410 : }
411 :
412 : impl ApiCaches {
413 0 : pub fn new(
414 0 : wake_compute_cache_config: CacheOptions,
415 0 : project_info_cache_config: ProjectInfoCacheOptions,
416 0 : ) -> Self {
417 0 : Self {
418 0 : node_info: NodeInfoCache::new(
419 0 : "node_info_cache",
420 0 : wake_compute_cache_config.size,
421 0 : wake_compute_cache_config.ttl,
422 0 : true,
423 0 : ),
424 0 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
425 0 : }
426 0 : }
427 : }
428 :
429 : /// Various caches for [`console`](super).
430 : pub struct ApiLocks {
431 : name: &'static str,
432 : node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
433 : permits: usize,
434 : timeout: Duration,
435 : registered: prometheus::IntCounter,
436 : unregistered: prometheus::IntCounter,
437 : reclamation_lag: prometheus::Histogram,
438 : lock_acquire_lag: prometheus::Histogram,
439 : }
440 :
441 : impl ApiLocks {
442 0 : pub fn new(
443 0 : name: &'static str,
444 0 : permits: usize,
445 0 : shards: usize,
446 0 : timeout: Duration,
447 0 : ) -> prometheus::Result<Self> {
448 0 : let registered = prometheus::IntCounter::with_opts(
449 0 : prometheus::Opts::new(
450 0 : "semaphores_registered",
451 0 : "Number of semaphores registered in this api lock",
452 0 : )
453 0 : .namespace(name),
454 0 : )?;
455 0 : prometheus::register(Box::new(registered.clone()))?;
456 0 : let unregistered = prometheus::IntCounter::with_opts(
457 0 : prometheus::Opts::new(
458 0 : "semaphores_unregistered",
459 0 : "Number of semaphores unregistered in this api lock",
460 0 : )
461 0 : .namespace(name),
462 0 : )?;
463 0 : prometheus::register(Box::new(unregistered.clone()))?;
464 0 : let reclamation_lag = prometheus::Histogram::with_opts(
465 0 : prometheus::HistogramOpts::new(
466 0 : "reclamation_lag_seconds",
467 0 : "Time it takes to reclaim unused semaphores in the api lock",
468 0 : )
469 0 : .namespace(name)
470 0 : // 1us -> 65ms
471 0 : // benchmarks on my mac indicate it's usually in the range of 256us and 512us
472 0 : .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
473 0 : )?;
474 0 : prometheus::register(Box::new(reclamation_lag.clone()))?;
475 0 : let lock_acquire_lag = prometheus::Histogram::with_opts(
476 0 : prometheus::HistogramOpts::new(
477 0 : "semaphore_acquire_seconds",
478 0 : "Time it takes to reclaim unused semaphores in the api lock",
479 0 : )
480 0 : .namespace(name)
481 0 : // 0.1ms -> 6s
482 0 : .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
483 0 : )?;
484 0 : prometheus::register(Box::new(lock_acquire_lag.clone()))?;
485 :
486 0 : Ok(Self {
487 0 : name,
488 0 : node_locks: DashMap::with_shard_amount(shards),
489 0 : permits,
490 0 : timeout,
491 0 : lock_acquire_lag,
492 0 : registered,
493 0 : unregistered,
494 0 : reclamation_lag,
495 0 : })
496 0 : }
497 :
498 0 : pub async fn get_wake_compute_permit(
499 0 : &self,
500 0 : key: &EndpointCacheKey,
501 0 : ) -> Result<WakeComputePermit, errors::WakeComputeError> {
502 0 : if self.permits == 0 {
503 0 : return Ok(WakeComputePermit { permit: None });
504 0 : }
505 0 : let now = Instant::now();
506 0 : let semaphore = {
507 : // get fast path
508 0 : if let Some(semaphore) = self.node_locks.get(key) {
509 0 : semaphore.clone()
510 : } else {
511 0 : self.node_locks
512 0 : .entry(key.clone())
513 0 : .or_insert_with(|| {
514 0 : self.registered.inc();
515 0 : Arc::new(Semaphore::new(self.permits))
516 0 : })
517 0 : .clone()
518 : }
519 : };
520 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
521 :
522 0 : self.lock_acquire_lag
523 0 : .observe((Instant::now() - now).as_secs_f64());
524 0 :
525 0 : Ok(WakeComputePermit {
526 0 : permit: Some(permit??),
527 : })
528 0 : }
529 :
530 0 : pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
531 0 : if self.permits == 0 {
532 0 : return;
533 0 : }
534 0 :
535 0 : let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
536 : loop {
537 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
538 0 : interval.tick().await;
539 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
540 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
541 : // therefore releasing it is safe from race conditions
542 0 : info!(
543 0 : name = self.name,
544 0 : shard = i,
545 0 : "performing epoch reclamation on api lock"
546 0 : );
547 0 : let mut lock = shard.write();
548 0 : let timer = self.reclamation_lag.start_timer();
549 0 : let count = lock
550 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
551 0 : .count();
552 0 : drop(lock);
553 0 : self.unregistered.inc_by(count as u64);
554 0 : timer.observe_duration()
555 : }
556 : }
557 0 : }
558 : }
559 :
560 : pub struct WakeComputePermit {
561 : // None if the lock is disabled
562 : permit: Option<OwnedSemaphorePermit>,
563 : }
564 :
565 : impl WakeComputePermit {
566 0 : pub fn should_check_cache(&self) -> bool {
567 0 : self.permit.is_some()
568 0 : }
569 : }
|