Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{
8 : backend::{ComputeCredentialKeys, ComputeUserInfo},
9 : IpPattern,
10 : },
11 : cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
12 : compute,
13 : config::{CacheOptions, ProjectInfoCacheOptions},
14 : context::RequestMonitoring,
15 : scram, EndpointCacheKey, ProjectId,
16 : };
17 : use async_trait::async_trait;
18 : use dashmap::DashMap;
19 : use std::{sync::Arc, time::Duration};
20 : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
21 : use tokio::time::Instant;
22 : use tracing::info;
23 :
24 : pub mod errors {
25 : use crate::{
26 : error::{io_error, ReportableError, UserFacingError},
27 : http,
28 : proxy::retry::ShouldRetry,
29 : };
30 : use thiserror::Error;
31 :
32 : /// A go-to error message which doesn't leak any detail.
33 : const REQUEST_FAILED: &str = "Console request failed";
34 :
35 : /// Common console API error.
///
/// The `Display` text of both variants starts with the generic
/// `REQUEST_FAILED` message.
36 57 : #[derive(Debug, Error)]
37 : pub enum ApiError {
38 : /// Error returned by the console itself.
39 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
40 : Console {
/// HTTP status code reported for the failed request.
41 : status: http::StatusCode,
/// Text accompanying the status (presumably the response body or
/// error message -- TODO confirm against the code constructing it).
42 : text: Box<str>,
43 : },
44 :
45 : /// Various IO errors like broken pipe or malformed payload.
/// Created from `std::io::Error` through the derived `From` impl (`#[from]`).
46 : #[error("{REQUEST_FAILED}: {0}")]
47 : Transport(#[from] std::io::Error),
48 : }
49 :
50 : impl ApiError {
51 : /// Returns HTTP status code if it's the reason for failure.
52 3 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
53 3 : use ApiError::*;
54 3 : match self {
55 2 : Console { status, .. } => Some(*status),
56 1 : _ => None,
57 : }
58 3 : }
59 : }
60 :
61 : impl UserFacingError for ApiError {
62 4 : fn to_string_client(&self) -> String {
63 4 : use ApiError::*;
64 4 : match self {
65 : // To minimize risks, only select errors are forwarded to users.
66 : // Ask @neondatabase/control-plane for review before adding more.
67 2 : Console { status, .. } => match *status {
68 : http::StatusCode::NOT_FOUND => {
69 : // Status 404: failed to get a project-related resource.
70 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
71 : }
72 : http::StatusCode::NOT_ACCEPTABLE => {
73 : // Status 406: endpoint is disabled (we don't allow connections).
74 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
75 : }
76 : http::StatusCode::LOCKED => {
77 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
78 0 : format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
79 : }
80 2 : _ => REQUEST_FAILED.to_owned(),
81 : },
82 2 : _ => REQUEST_FAILED.to_owned(),
83 : }
84 4 : }
85 : }
86 :
87 : impl ReportableError for ApiError {
88 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
89 0 : match self {
90 0 : ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
91 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
92 : }
93 0 : }
94 : }
95 :
96 : impl ShouldRetry for ApiError {
97 12 : fn could_retry(&self) -> bool {
98 12 : match self {
99 : // retry some transport errors
100 0 : Self::Transport(io) => io.could_retry(),
101 : // retry some temporary failures because the compute was in a bad state
102 : // (bad request can be returned when the endpoint was in transition)
103 : Self::Console {
104 : status: http::StatusCode::BAD_REQUEST,
105 : ..
106 8 : } => true,
107 : // locked can be returned when the endpoint was in transition
108 : // or when quotas are exceeded. don't retry when quotas are exceeded
109 : Self::Console {
110 : status: http::StatusCode::LOCKED,
111 0 : ref text,
112 0 : } => {
113 0 : // written data quota exceeded
114 0 : // data transfer quota exceeded
115 0 : // compute time quota exceeded
116 0 : // logical size quota exceeded
117 0 : !text.contains("quota exceeded")
118 0 : && !text.contains("the limit for current plan reached")
119 : }
120 4 : _ => false,
121 : }
122 12 : }
123 : }
124 :
125 : impl From<reqwest::Error> for ApiError {
126 1 : fn from(e: reqwest::Error) -> Self {
127 1 : io_error(e).into()
128 1 : }
129 : }
130 :
131 : impl From<reqwest_middleware::Error> for ApiError {
132 1 : fn from(e: reqwest_middleware::Error) -> Self {
133 1 : io_error(e).into()
134 1 : }
135 : }
136 :
/// Error type for fetching authentication info from the console.
137 40 : #[derive(Debug, Error)]
138 : pub enum GetAuthInfoError {
/// The console's response contained an auth secret that could not be used.
139 : // We shouldn't include the actual secret here.
140 : #[error("Console responded with a malformed auth secret")]
141 : BadSecret,
142 :
/// Any other console API failure; displayed exactly like the inner
/// [`ApiError`] (`#[error(transparent)]`).
143 : #[error(transparent)]
144 : ApiError(ApiError),
145 : }
146 :
147 : // This allows more useful interactions than `#[from]`.
148 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
149 4 : fn from(e: E) -> Self {
150 4 : Self::ApiError(e.into())
151 4 : }
152 : }
153 :
154 : impl UserFacingError for GetAuthInfoError {
155 4 : fn to_string_client(&self) -> String {
156 4 : use GetAuthInfoError::*;
157 4 : match self {
158 : // We absolutely should not leak any secrets!
159 0 : BadSecret => REQUEST_FAILED.to_owned(),
160 : // However, API might return a meaningful error.
161 4 : ApiError(e) => e.to_string_client(),
162 : }
163 4 : }
164 : }
165 :
166 : impl ReportableError for GetAuthInfoError {
167 4 : fn get_error_kind(&self) -> crate::error::ErrorKind {
168 4 : match self {
169 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
170 4 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
171 : }
172 4 : }
173 : }
174 :
/// Error type for waking up a compute node.
175 2 : #[derive(Debug, Error)]
176 : pub enum WakeComputeError {
/// The console returned a compute address that could not be used;
/// the payload is included in the `Display` output of this variant.
177 : #[error("Console responded with a malformed compute address: {0}")]
178 : BadComputeAddress(Box<str>),
179 :
/// Any other console API failure; displayed exactly like the inner
/// [`ApiError`] (`#[error(transparent)]`).
180 : #[error(transparent)]
181 : ApiError(ApiError),
182 :
/// The wake compute lock could not be acquired.
/// Produced from both `tokio::time::error::Elapsed` and
/// `tokio::sync::AcquireError`.
183 : #[error("Timeout waiting to acquire wake compute lock")]
184 : TimeoutError,
185 : }
186 :
187 : // This allows more useful interactions than `#[from]`.
188 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
189 0 : fn from(e: E) -> Self {
190 0 : Self::ApiError(e.into())
191 0 : }
192 : }
193 :
194 : impl From<tokio::sync::AcquireError> for WakeComputeError {
195 0 : fn from(_: tokio::sync::AcquireError) -> Self {
196 0 : WakeComputeError::TimeoutError
197 0 : }
198 : }
199 : impl From<tokio::time::error::Elapsed> for WakeComputeError {
200 0 : fn from(_: tokio::time::error::Elapsed) -> Self {
201 0 : WakeComputeError::TimeoutError
202 0 : }
203 : }
204 :
205 : impl UserFacingError for WakeComputeError {
206 0 : fn to_string_client(&self) -> String {
207 0 : use WakeComputeError::*;
208 0 : match self {
209 : // We shouldn't show user the address even if it's broken.
210 : // Besides, user is unlikely to care about this detail.
211 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
212 : // However, API might return a meaningful error.
213 0 : ApiError(e) => e.to_string_client(),
214 :
215 0 : TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
216 : }
217 0 : }
218 : }
219 :
220 : impl ReportableError for WakeComputeError {
221 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
222 0 : match self {
223 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
224 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
225 0 : WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
226 : }
227 0 : }
228 : }
229 : }
230 :
231 : /// Auth secret which is managed by the cloud.
232 109 : #[derive(Clone, Eq, PartialEq, Debug)]
233 : pub enum AuthSecret {
234 : #[cfg(any(test, feature = "testing"))]
235 : /// Md5 hash of user's password.
/// Only compiled in tests or when the `testing` feature is enabled.
236 : Md5([u8; 16]),
237 :
238 : /// [SCRAM](crate::scram) authentication info.
239 : Scram(scram::ServerSecret),
240 : }
241 :
/// Authentication-related information about a role/endpoint.
/// The derived `Default` is: no secret, no allowed IPs, no project id.
242 0 : #[derive(Default)]
243 : pub struct AuthInfo {
/// The role's auth secret, if there is one.
244 : pub secret: Option<AuthSecret>,
245 : /// List of IP addresses allowed for the authorization.
246 : pub allowed_ips: Vec<IpPattern>,
247 : /// Project ID. This is used for cache invalidation.
248 : pub project_id: Option<ProjectId>,
249 : }
250 :
251 : /// Info for establishing a connection to a compute node.
252 : /// This is what we get after auth succeeded, but not before!
253 27 : #[derive(Clone)]
254 : pub struct NodeInfo {
255 : /// Compute node connection params.
256 : /// It's sad that we have to clone this, but this will improve
257 : /// once we migrate to a bespoke connection logic.
258 : pub config: compute::ConnCfg,
259 :
260 : /// Labels for proxy's metrics.
/// Cloned and handed to the connection on every `connect` call.
261 : pub aux: MetricsAuxInfo,
262 :
263 : /// Whether we should accept self-signed certificates (for testing)
264 : pub allow_self_signed_compute: bool,
265 : }
266 :
267 : impl NodeInfo {
268 41 : pub async fn connect(
269 41 : &self,
270 41 : ctx: &mut RequestMonitoring,
271 41 : timeout: Duration,
272 41 : ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
273 41 : self.config
274 41 : .connect(
275 41 : ctx,
276 41 : self.allow_self_signed_compute,
277 41 : self.aux.clone(),
278 41 : timeout,
279 41 : )
280 117 : .await
281 41 : }
282 8 : pub fn reuse_settings(&mut self, other: Self) {
283 8 : self.allow_self_signed_compute = other.allow_self_signed_compute;
284 8 : self.config.reuse_password(other.config);
285 8 : }
286 :
287 94 : pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
288 94 : match keys {
289 16 : ComputeCredentialKeys::Password(password) => self.config.password(password),
290 78 : ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
291 : };
292 94 : }
293 : }
294 :
/// Timed LRU cache mapping an endpoint cache key to its [`NodeInfo`].
295 : pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
/// A [`NodeInfo`] value tied to the [`NodeInfoCache`] it belongs to.
296 : pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
/// A role secret (`None` if there is no secret) tied to the project info cache.
297 : pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
/// A shared list of allowed IP patterns tied to the project info cache.
298 : pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
299 :
/// Interface to the console API: fetching auth info and waking up compute nodes.
///
300 : /// This will allocate per each call, but the http requests alone
301 : /// already require a few allocations, so it should be fine.
302 : #[async_trait]
303 : pub trait Api {
304 : /// Get the client's auth secret for authentication.
305 : /// Returns option because user not found situation is special.
306 : /// We still have to mock the scram to avoid leaking information that user doesn't exist.
307 : async fn get_role_secret(
308 : &self,
309 : ctx: &mut RequestMonitoring,
310 : user_info: &ComputeUserInfo,
311 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
312 :
/// Get the allowed IP patterns together with, optionally, the role secret.
/// NOTE(review): when the secret part is `None` is implementation-specific
/// (presumably when it was not fetched alongside the IPs) -- confirm.
313 : async fn get_allowed_ips_and_secret(
314 : &self,
315 : ctx: &mut RequestMonitoring,
316 : user_info: &ComputeUserInfo,
317 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
318 :
319 : /// Wake up the compute node and return the corresponding connection info.
320 : async fn wake_compute(
321 : &self,
322 : ctx: &mut RequestMonitoring,
323 : user_info: &ComputeUserInfo,
324 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
325 : }
326 :
/// The available console API backends.
/// Its `Api` implementation forwards each call to the selected variant.
327 : #[non_exhaustive]
328 : pub enum ConsoleBackend {
329 : /// Current Cloud API (V2).
330 : Console(neon::Api),
331 : /// Local mock of Cloud API (V2).
/// Only compiled in tests or when the `testing` feature is enabled.
332 : #[cfg(any(test, feature = "testing"))]
333 : Postgres(mock::Api),
334 : /// Internal testing
/// Only compiled in tests.
335 : #[cfg(test)]
336 : Test(Box<dyn crate::auth::backend::TestBackend>),
337 : }
338 :
339 : #[async_trait]
340 : impl Api for ConsoleBackend {
341 87 : async fn get_role_secret(
342 87 : &self,
343 87 : ctx: &mut RequestMonitoring,
344 87 : user_info: &ComputeUserInfo,
345 87 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
346 : use ConsoleBackend::*;
347 87 : match self {
348 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
349 : #[cfg(any(test, feature = "testing"))]
350 603 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
351 : #[cfg(test)]
352 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
353 : }
354 261 : }
355 :
356 95 : async fn get_allowed_ips_and_secret(
357 95 : &self,
358 95 : ctx: &mut RequestMonitoring,
359 95 : user_info: &ComputeUserInfo,
360 95 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
361 : use ConsoleBackend::*;
362 95 : match self {
363 12 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
364 : #[cfg(any(test, feature = "testing"))]
365 663 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
366 : #[cfg(test)]
367 0 : Test(api) => api.get_allowed_ips_and_secret(),
368 : }
369 285 : }
370 :
371 108 : async fn wake_compute(
372 108 : &self,
373 108 : ctx: &mut RequestMonitoring,
374 108 : user_info: &ComputeUserInfo,
375 108 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
376 : use ConsoleBackend::*;
377 :
378 108 : match self {
379 0 : Console(api) => api.wake_compute(ctx, user_info).await,
380 : #[cfg(any(test, feature = "testing"))]
381 78 : Postgres(api) => api.wake_compute(ctx, user_info).await,
382 : #[cfg(test)]
383 30 : Test(api) => api.wake_compute(),
384 : }
385 324 : }
386 : }
387 :
388 : /// Various caches for [`console`](super).
389 : pub struct ApiCaches {
390 : /// Cache for the `wake_compute` API method.
/// Keyed by endpoint cache key, storing the resulting `NodeInfo`.
391 : pub node_info: NodeInfoCache,
392 : /// Cache which stores project_id -> endpoint_ids mapping.
/// Held in an `Arc` so that it can be shared.
393 : pub project_info: Arc<ProjectInfoCacheImpl>,
394 : }
395 :
396 : impl ApiCaches {
397 1 : pub fn new(
398 1 : wake_compute_cache_config: CacheOptions,
399 1 : project_info_cache_config: ProjectInfoCacheOptions,
400 1 : ) -> Self {
401 1 : Self {
402 1 : node_info: NodeInfoCache::new(
403 1 : "node_info_cache",
404 1 : wake_compute_cache_config.size,
405 1 : wake_compute_cache_config.ttl,
406 1 : true,
407 1 : ),
408 1 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
409 1 : }
410 1 : }
411 : }
412 :
413 : /// Per-endpoint semaphores ("locks") for [`console`](super) API calls, plus their metrics.
414 : pub struct ApiLocks {
/// Name of this lock set; used as the metrics namespace and in log messages.
415 : name: &'static str,
/// One semaphore per endpoint cache key, created on first use.
416 : node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
/// Permits given to each new semaphore; `0` disables locking.
417 : permits: usize,
/// Maximum time to wait for a permit.
418 : timeout: Duration,
/// Counts semaphores created.
419 : registered: prometheus::IntCounter,
/// Counts semaphores removed by the garbage collection worker.
420 : unregistered: prometheus::IntCounter,
/// Time spent reclaiming unused semaphores from one shard.
421 : reclamation_lag: prometheus::Histogram,
/// Time spent waiting for a permit.
422 : lock_acquire_lag: prometheus::Histogram,
423 : }
424 :
425 : impl ApiLocks {
426 1 : pub fn new(
427 1 : name: &'static str,
428 1 : permits: usize,
429 1 : shards: usize,
430 1 : timeout: Duration,
431 1 : ) -> prometheus::Result<Self> {
432 1 : let registered = prometheus::IntCounter::with_opts(
433 1 : prometheus::Opts::new(
434 1 : "semaphores_registered",
435 1 : "Number of semaphores registered in this api lock",
436 1 : )
437 1 : .namespace(name),
438 1 : )?;
439 1 : prometheus::register(Box::new(registered.clone()))?;
440 1 : let unregistered = prometheus::IntCounter::with_opts(
441 1 : prometheus::Opts::new(
442 1 : "semaphores_unregistered",
443 1 : "Number of semaphores unregistered in this api lock",
444 1 : )
445 1 : .namespace(name),
446 1 : )?;
447 1 : prometheus::register(Box::new(unregistered.clone()))?;
448 1 : let reclamation_lag = prometheus::Histogram::with_opts(
449 1 : prometheus::HistogramOpts::new(
450 1 : "reclamation_lag_seconds",
451 1 : "Time it takes to reclaim unused semaphores in the api lock",
452 1 : )
453 1 : .namespace(name)
454 1 : // 1us -> 65ms
455 1 : // benchmarks on my mac indicate it's usually in the range of 256us and 512us
456 1 : .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
457 0 : )?;
458 1 : prometheus::register(Box::new(reclamation_lag.clone()))?;
459 1 : let lock_acquire_lag = prometheus::Histogram::with_opts(
460 1 : prometheus::HistogramOpts::new(
461 1 : "semaphore_acquire_seconds",
462 1 : "Time it takes to reclaim unused semaphores in the api lock",
463 1 : )
464 1 : .namespace(name)
465 1 : // 0.1ms -> 6s
466 1 : .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
467 0 : )?;
468 1 : prometheus::register(Box::new(lock_acquire_lag.clone()))?;
469 :
470 1 : Ok(Self {
471 1 : name,
472 1 : node_locks: DashMap::with_shard_amount(shards),
473 1 : permits,
474 1 : timeout,
475 1 : lock_acquire_lag,
476 1 : registered,
477 1 : unregistered,
478 1 : reclamation_lag,
479 1 : })
480 1 : }
481 :
482 0 : pub async fn get_wake_compute_permit(
483 0 : &self,
484 0 : key: &EndpointCacheKey,
485 0 : ) -> Result<WakeComputePermit, errors::WakeComputeError> {
486 0 : if self.permits == 0 {
487 0 : return Ok(WakeComputePermit { permit: None });
488 0 : }
489 0 : let now = Instant::now();
490 0 : let semaphore = {
491 : // get fast path
492 0 : if let Some(semaphore) = self.node_locks.get(key) {
493 0 : semaphore.clone()
494 : } else {
495 0 : self.node_locks
496 0 : .entry(key.clone())
497 0 : .or_insert_with(|| {
498 0 : self.registered.inc();
499 0 : Arc::new(Semaphore::new(self.permits))
500 0 : })
501 0 : .clone()
502 : }
503 : };
504 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
505 :
506 0 : self.lock_acquire_lag
507 0 : .observe((Instant::now() - now).as_secs_f64());
508 0 :
509 0 : Ok(WakeComputePermit {
510 0 : permit: Some(permit??),
511 : })
512 0 : }
513 :
/// Background worker that periodically removes semaphores nobody is using.
///
/// Returns immediately when locking is disabled (`permits == 0`); otherwise
/// it loops forever. One shard of the lock map is processed per timer tick,
/// and the tick period is `epoch` divided by the number of shards, so a full
/// pass over all shards takes roughly one `epoch`.
///
/// NOTE(review): `tokio::time::interval` is documented to panic on a zero
/// period, which would happen if `epoch / shards` is zero -- confirm callers
/// never pass such an `epoch`.
514 1 : pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
515 1 : if self.permits == 0 {
516 1 : return;
517 0 : }
518 0 :
// Spread one epoch evenly over all shards (the shard count is cast to `u32`).
519 0 : let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
520 : loop {
521 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
522 0 : interval.tick().await;
523 : // temporarily lock a single shard and then clear any semaphores that aren't currently checked out
524 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
525 : // therefore releasing it is safe from race conditions
526 0 : info!(
527 0 : name = self.name,
528 0 : shard = i,
529 0 : "performing epoch reclamation on api lock"
530 0 : );
531 0 : let mut lock = shard.write();
// The timer starts once the shard write lock is held.
532 0 : let timer = self.reclamation_lag.start_timer();
// Remove every entry whose `Arc` is referenced only by the map itself,
// counting how many were removed.
533 0 : let count = lock
534 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
535 0 : .count();
// Release the shard before updating the metrics.
536 0 : drop(lock);
537 0 : self.unregistered.inc_by(count as u64);
538 0 : timer.observe_duration()
539 : }
540 : }
541 1 : }
542 : }
543 :
/// Permit returned by `ApiLocks::get_wake_compute_permit`.
/// While locking is enabled it owns a semaphore permit.
544 : pub struct WakeComputePermit {
545 : // None if the lock is disabled
546 : permit: Option<OwnedSemaphorePermit>,
547 : }
548 :
549 : impl WakeComputePermit {
550 0 : pub fn should_check_cache(&self) -> bool {
551 0 : self.permit.is_some()
552 0 : }
553 : }
|