Line data Source code
1 : #[cfg(any(test, feature = "testing"))]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::{backend::ComputeUserInfo, IpPattern},
8 : cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
9 : compute,
10 : config::{CacheOptions, ProjectInfoCacheOptions},
11 : context::RequestMonitoring,
12 : scram, EndpointCacheKey, ProjectId,
13 : };
14 : use async_trait::async_trait;
15 : use dashmap::DashMap;
16 : use std::{sync::Arc, time::Duration};
17 : use tokio::sync::{OwnedSemaphorePermit, Semaphore};
18 : use tokio::time::Instant;
19 : use tracing::info;
20 :
21 : pub mod errors {
22 : use crate::{
23 : error::{io_error, ReportableError, UserFacingError},
24 : http,
25 : proxy::retry::ShouldRetry,
26 : };
27 : use thiserror::Error;
28 :
29 : /// A go-to error message which doesn't leak any detail.
30 : const REQUEST_FAILED: &str = "Console request failed";
31 :
32 : /// Common console API error.
33 57 : #[derive(Debug, Error)]
34 : pub enum ApiError {
35 : /// Error returned by the console itself.
36 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
37 : Console {
38 : status: http::StatusCode,
39 : text: Box<str>,
40 : },
41 :
42 : /// Various IO errors like broken pipe or malformed payload.
43 : #[error("{REQUEST_FAILED}: {0}")]
44 : Transport(#[from] std::io::Error),
45 : }
46 :
47 : impl ApiError {
48 : /// Returns HTTP status code if it's the reason for failure.
49 3 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
50 3 : use ApiError::*;
51 3 : match self {
52 2 : Console { status, .. } => Some(*status),
53 1 : _ => None,
54 : }
55 3 : }
56 : }
57 :
58 : impl UserFacingError for ApiError {
59 4 : fn to_string_client(&self) -> String {
60 4 : use ApiError::*;
61 4 : match self {
62 : // To minimize risks, only select errors are forwarded to users.
63 : // Ask @neondatabase/control-plane for review before adding more.
64 2 : Console { status, .. } => match *status {
65 : http::StatusCode::NOT_FOUND => {
66 : // Status 404: failed to get a project-related resource.
67 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
68 : }
69 : http::StatusCode::NOT_ACCEPTABLE => {
70 : // Status 406: endpoint is disabled (we don't allow connections).
71 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
72 : }
73 : http::StatusCode::LOCKED => {
74 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
75 0 : format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
76 : }
77 2 : _ => REQUEST_FAILED.to_owned(),
78 : },
79 2 : _ => REQUEST_FAILED.to_owned(),
80 : }
81 4 : }
82 : }
83 :
84 : impl ReportableError for ApiError {
85 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
86 0 : match self {
87 0 : ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
88 0 : ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
89 : }
90 0 : }
91 : }
92 :
93 : impl ShouldRetry for ApiError {
94 8 : fn could_retry(&self) -> bool {
95 8 : match self {
96 : // retry some transport errors
97 0 : Self::Transport(io) => io.could_retry(),
98 : // retry some temporary failures because the compute was in a bad state
99 : // (bad request can be returned when the endpoint was in transition)
100 : Self::Console {
101 : status: http::StatusCode::BAD_REQUEST,
102 : ..
103 4 : } => true,
104 : // locked can be returned when the endpoint was in transition
105 : // or when quotas are exceeded. don't retry when quotas are exceeded
106 : Self::Console {
107 : status: http::StatusCode::LOCKED,
108 0 : ref text,
109 0 : } => {
110 0 : // written data quota exceeded
111 0 : // data transfer quota exceeded
112 0 : // compute time quota exceeded
113 0 : // logical size quota exceeded
114 0 : !text.contains("quota exceeded")
115 0 : && !text.contains("the limit for current plan reached")
116 : }
117 4 : _ => false,
118 : }
119 8 : }
120 : }
121 :
122 : impl From<reqwest::Error> for ApiError {
123 1 : fn from(e: reqwest::Error) -> Self {
124 1 : io_error(e).into()
125 1 : }
126 : }
127 :
128 : impl From<reqwest_middleware::Error> for ApiError {
129 1 : fn from(e: reqwest_middleware::Error) -> Self {
130 1 : io_error(e).into()
131 1 : }
132 : }
133 :
134 40 : #[derive(Debug, Error)]
135 : pub enum GetAuthInfoError {
136 : // We shouldn't include the actual secret here.
137 : #[error("Console responded with a malformed auth secret")]
138 : BadSecret,
139 :
140 : #[error(transparent)]
141 : ApiError(ApiError),
142 : }
143 :
144 : // This allows more useful interactions than `#[from]`.
145 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
146 4 : fn from(e: E) -> Self {
147 4 : Self::ApiError(e.into())
148 4 : }
149 : }
150 :
151 : impl UserFacingError for GetAuthInfoError {
152 4 : fn to_string_client(&self) -> String {
153 4 : use GetAuthInfoError::*;
154 4 : match self {
155 : // We absolutely should not leak any secrets!
156 0 : BadSecret => REQUEST_FAILED.to_owned(),
157 : // However, API might return a meaningful error.
158 4 : ApiError(e) => e.to_string_client(),
159 : }
160 4 : }
161 : }
162 :
163 : impl ReportableError for GetAuthInfoError {
164 4 : fn get_error_kind(&self) -> crate::error::ErrorKind {
165 4 : match self {
166 0 : GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
167 4 : GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
168 : }
169 4 : }
170 : }
171 :
172 0 : #[derive(Debug, Error)]
173 : pub enum WakeComputeError {
174 : #[error("Console responded with a malformed compute address: {0}")]
175 : BadComputeAddress(Box<str>),
176 :
177 : #[error(transparent)]
178 : ApiError(ApiError),
179 :
180 : #[error("Timeout waiting to acquire wake compute lock")]
181 : TimeoutError,
182 : }
183 :
184 : // This allows more useful interactions than `#[from]`.
185 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
186 0 : fn from(e: E) -> Self {
187 0 : Self::ApiError(e.into())
188 0 : }
189 : }
190 :
191 : impl From<tokio::sync::AcquireError> for WakeComputeError {
192 0 : fn from(_: tokio::sync::AcquireError) -> Self {
193 0 : WakeComputeError::TimeoutError
194 0 : }
195 : }
196 : impl From<tokio::time::error::Elapsed> for WakeComputeError {
197 0 : fn from(_: tokio::time::error::Elapsed) -> Self {
198 0 : WakeComputeError::TimeoutError
199 0 : }
200 : }
201 :
202 : impl UserFacingError for WakeComputeError {
203 0 : fn to_string_client(&self) -> String {
204 0 : use WakeComputeError::*;
205 0 : match self {
206 : // We shouldn't show user the address even if it's broken.
207 : // Besides, user is unlikely to care about this detail.
208 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
209 : // However, API might return a meaningful error.
210 0 : ApiError(e) => e.to_string_client(),
211 :
212 0 : TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
213 : }
214 0 : }
215 : }
216 :
217 : impl ReportableError for WakeComputeError {
218 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
219 0 : match self {
220 0 : WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
221 0 : WakeComputeError::ApiError(e) => e.get_error_kind(),
222 0 : WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
223 : }
224 0 : }
225 : }
226 : }
227 :
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
    #[cfg(any(test, feature = "testing"))]
    /// Md5 hash of user's password.
    /// Only compiled into test builds or when the `testing` feature is enabled.
    Md5([u8; 16]),

    /// [SCRAM](crate::scram) authentication info.
    Scram(scram::ServerSecret),
}
238 :
/// Authentication-related information for a role: its secret, IP allow-list and project.
/// NOTE(review): presumably populated from a console API response — confirm in the providers.
#[derive(Default)]
pub struct AuthInfo {
    /// The role's auth secret, if any.
    pub secret: Option<AuthSecret>,
    /// List of IP patterns allowed for the authorization.
    pub allowed_ips: Vec<IpPattern>,
    /// Project ID. This is used for cache invalidation.
    pub project_id: Option<ProjectId>,
}
247 :
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
///
/// Values of this type are what [`NodeInfoCache`] stores per endpoint.
#[derive(Clone)]
pub struct NodeInfo {
    /// Compute node connection params.
    /// It's sad that we have to clone this, but this will improve
    /// once we migrate to a bespoke connection logic.
    pub config: compute::ConnCfg,

    /// Labels for proxy's metrics.
    pub aux: MetricsAuxInfo,

    /// Whether we should accept self-signed certificates (for testing)
    pub allow_self_signed_compute: bool,
}
263 :
/// Time-bounded LRU of `wake_compute` results, keyed by endpoint.
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
/// Node info obtained through a `'static` [`NodeInfoCache`].
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
/// Role secret (possibly absent) obtained through the `'static` project info cache.
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
/// Shared IP allow-list obtained through the `'static` project info cache.
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
268 :
/// Control-plane operations the proxy relies on: fetching role secrets and IP
/// allow-lists, and waking compute nodes.
///
/// This will allocate per each call (`async_trait` boxes every returned future),
/// but the http requests alone already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
    /// Get the client's auth secret for authentication.
    /// Returns option because user not found situation is special.
    /// We still have to mock the scram to avoid leaking information that user doesn't exist.
    async fn get_role_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;

    /// Get the allowed IP patterns for the client, together with an optional role secret.
    async fn get_allowed_ips_and_secret(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

    /// Wake up the compute node and return the corresponding connection info.
    async fn wake_compute(
        &self,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
295 :
/// The console implementation the proxy talks to; implements [`Api`] by delegation.
#[non_exhaustive]
pub enum ConsoleBackend {
    /// Current Cloud API (V2).
    Console(neon::Api),
    /// Local mock of Cloud API (V2).
    /// Only available in test builds or with the `testing` feature.
    #[cfg(any(test, feature = "testing"))]
    Postgres(mock::Api),
    /// Internal testing backend; only compiled in test builds.
    #[cfg(test)]
    Test(Box<dyn crate::auth::backend::TestBackend>),
}
307 :
308 : #[async_trait]
309 : impl Api for ConsoleBackend {
310 87 : async fn get_role_secret(
311 87 : &self,
312 87 : ctx: &mut RequestMonitoring,
313 87 : user_info: &ComputeUserInfo,
314 87 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
315 : use ConsoleBackend::*;
316 87 : match self {
317 0 : Console(api) => api.get_role_secret(ctx, user_info).await,
318 : #[cfg(any(test, feature = "testing"))]
319 602 : Postgres(api) => api.get_role_secret(ctx, user_info).await,
320 : #[cfg(test)]
321 0 : Test(_) => unreachable!("this function should never be called in the test backend"),
322 : }
323 261 : }
324 :
325 95 : async fn get_allowed_ips_and_secret(
326 95 : &self,
327 95 : ctx: &mut RequestMonitoring,
328 95 : user_info: &ComputeUserInfo,
329 95 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
330 : use ConsoleBackend::*;
331 95 : match self {
332 13 : Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
333 : #[cfg(any(test, feature = "testing"))]
334 650 : Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
335 : #[cfg(test)]
336 0 : Test(api) => api.get_allowed_ips_and_secret(),
337 : }
338 285 : }
339 :
340 92 : async fn wake_compute(
341 92 : &self,
342 92 : ctx: &mut RequestMonitoring,
343 92 : user_info: &ComputeUserInfo,
344 92 : ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
345 : use ConsoleBackend::*;
346 :
347 92 : match self {
348 0 : Console(api) => api.wake_compute(ctx, user_info).await,
349 : #[cfg(any(test, feature = "testing"))]
350 78 : Postgres(api) => api.wake_compute(ctx, user_info).await,
351 : #[cfg(test)]
352 14 : Test(api) => api.wake_compute(),
353 : }
354 276 : }
355 : }
356 :
/// Various caches for [`console`](super).
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub node_info: NodeInfoCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    /// NOTE(review): the `CachedRoleSecret` / `CachedAllowedIps` aliases are also backed by
    /// `ProjectInfoCacheImpl`, so this cache appears to hold more than that mapping — confirm.
    pub project_info: Arc<ProjectInfoCacheImpl>,
}
364 :
365 : impl ApiCaches {
366 1 : pub fn new(
367 1 : wake_compute_cache_config: CacheOptions,
368 1 : project_info_cache_config: ProjectInfoCacheOptions,
369 1 : ) -> Self {
370 1 : Self {
371 1 : node_info: NodeInfoCache::new(
372 1 : "node_info_cache",
373 1 : wake_compute_cache_config.size,
374 1 : wake_compute_cache_config.ttl,
375 1 : true,
376 1 : ),
377 1 : project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
378 1 : }
379 1 : }
380 : }
381 :
/// Per-endpoint semaphores that limit concurrent `wake_compute` calls for the same endpoint.
pub struct ApiLocks {
    /// Metric namespace and log label for this lock table.
    name: &'static str,
    /// One semaphore per endpoint; entries nobody holds are reclaimed by the GC worker.
    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
    /// Permits per endpoint semaphore; 0 disables locking entirely.
    permits: usize,
    /// Maximum time to wait for a permit.
    timeout: Duration,
    /// Counts semaphores created.
    registered: prometheus::IntCounter,
    /// Counts semaphores reclaimed by the GC worker.
    unregistered: prometheus::IntCounter,
    /// Duration of each per-shard reclamation pass.
    reclamation_lag: prometheus::Histogram,
    /// Time spent waiting for a permit.
    lock_acquire_lag: prometheus::Histogram,
}
393 :
394 : impl ApiLocks {
395 1 : pub fn new(
396 1 : name: &'static str,
397 1 : permits: usize,
398 1 : shards: usize,
399 1 : timeout: Duration,
400 1 : ) -> prometheus::Result<Self> {
401 1 : let registered = prometheus::IntCounter::with_opts(
402 1 : prometheus::Opts::new(
403 1 : "semaphores_registered",
404 1 : "Number of semaphores registered in this api lock",
405 1 : )
406 1 : .namespace(name),
407 1 : )?;
408 1 : prometheus::register(Box::new(registered.clone()))?;
409 1 : let unregistered = prometheus::IntCounter::with_opts(
410 1 : prometheus::Opts::new(
411 1 : "semaphores_unregistered",
412 1 : "Number of semaphores unregistered in this api lock",
413 1 : )
414 1 : .namespace(name),
415 1 : )?;
416 1 : prometheus::register(Box::new(unregistered.clone()))?;
417 1 : let reclamation_lag = prometheus::Histogram::with_opts(
418 1 : prometheus::HistogramOpts::new(
419 1 : "reclamation_lag_seconds",
420 1 : "Time it takes to reclaim unused semaphores in the api lock",
421 1 : )
422 1 : .namespace(name)
423 1 : // 1us -> 65ms
424 1 : // benchmarks on my mac indicate it's usually in the range of 256us and 512us
425 1 : .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
426 0 : )?;
427 1 : prometheus::register(Box::new(reclamation_lag.clone()))?;
428 1 : let lock_acquire_lag = prometheus::Histogram::with_opts(
429 1 : prometheus::HistogramOpts::new(
430 1 : "semaphore_acquire_seconds",
431 1 : "Time it takes to reclaim unused semaphores in the api lock",
432 1 : )
433 1 : .namespace(name)
434 1 : // 0.1ms -> 6s
435 1 : .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
436 0 : )?;
437 1 : prometheus::register(Box::new(lock_acquire_lag.clone()))?;
438 :
439 1 : Ok(Self {
440 1 : name,
441 1 : node_locks: DashMap::with_shard_amount(shards),
442 1 : permits,
443 1 : timeout,
444 1 : lock_acquire_lag,
445 1 : registered,
446 1 : unregistered,
447 1 : reclamation_lag,
448 1 : })
449 1 : }
450 :
451 0 : pub async fn get_wake_compute_permit(
452 0 : &self,
453 0 : key: &EndpointCacheKey,
454 0 : ) -> Result<WakeComputePermit, errors::WakeComputeError> {
455 0 : if self.permits == 0 {
456 0 : return Ok(WakeComputePermit { permit: None });
457 0 : }
458 0 : let now = Instant::now();
459 0 : let semaphore = {
460 : // get fast path
461 0 : if let Some(semaphore) = self.node_locks.get(key) {
462 0 : semaphore.clone()
463 : } else {
464 0 : self.node_locks
465 0 : .entry(key.clone())
466 0 : .or_insert_with(|| {
467 0 : self.registered.inc();
468 0 : Arc::new(Semaphore::new(self.permits))
469 0 : })
470 0 : .clone()
471 : }
472 : };
473 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
474 :
475 0 : self.lock_acquire_lag
476 0 : .observe((Instant::now() - now).as_secs_f64());
477 0 :
478 0 : Ok(WakeComputePermit {
479 0 : permit: Some(permit??),
480 : })
481 0 : }
482 :
483 1 : pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
484 1 : if self.permits == 0 {
485 1 : return;
486 0 : }
487 0 :
488 0 : let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
489 : loop {
490 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
491 0 : interval.tick().await;
492 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
493 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
494 : // therefore releasing it is safe from race conditions
495 0 : info!(
496 0 : name = self.name,
497 0 : shard = i,
498 0 : "performing epoch reclamation on api lock"
499 0 : );
500 0 : let mut lock = shard.write();
501 0 : let timer = self.reclamation_lag.start_timer();
502 0 : let count = lock
503 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
504 0 : .count();
505 0 : drop(lock);
506 0 : self.unregistered.inc_by(count as u64);
507 0 : timer.observe_duration()
508 : }
509 : }
510 1 : }
511 : }
512 :
/// Permit returned by [`ApiLocks::get_wake_compute_permit`]; the semaphore slot, if any,
/// is held until this value is dropped.
pub struct WakeComputePermit {
    // None if the lock is disabled
    permit: Option<OwnedSemaphorePermit>,
}
517 :
518 : impl WakeComputePermit {
519 0 : pub fn should_check_cache(&self) -> bool {
520 0 : self.permit.is_some()
521 0 : }
522 : }
|