LCOV - differential code coverage report
Current view: top level - proxy/src/console - provider.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 51.4 % 177 91 86 91
Current Date: 2024-01-09 02:06:09 Functions: 41.5 % 41 17 24 17
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : #[cfg(feature = "testing")]
       2                 : pub mod mock;
       3                 : pub mod neon;
       4                 : 
       5                 : use super::messages::MetricsAuxInfo;
       6                 : use crate::{
       7                 :     auth::backend::ComputeUserInfo,
       8                 :     cache::{timed_lru, TimedLru},
       9                 :     compute,
      10                 :     context::RequestMonitoring,
      11                 :     scram,
      12                 : };
      13                 : use async_trait::async_trait;
      14                 : use dashmap::DashMap;
      15                 : use smol_str::SmolStr;
      16                 : use std::{sync::Arc, time::Duration};
      17                 : use tokio::{
      18                 :     sync::{OwnedSemaphorePermit, Semaphore},
      19                 :     time::Instant,
      20                 : };
      21                 : use tracing::info;
      22                 : 
      23                 : pub mod errors {
      24                 :     use crate::{
      25                 :         error::{io_error, UserFacingError},
      26                 :         http,
      27                 :         proxy::retry::ShouldRetry,
      28                 :     };
      29                 :     use thiserror::Error;
      30                 : 
      31                 :     /// A go-to error message which doesn't leak any detail.
      32                 :     const REQUEST_FAILED: &str = "Console request failed";
      33                 : 
      34                 :     /// Common console API error.
      35 CBC          25 :     #[derive(Debug, Error)]
      36                 :     pub enum ApiError {
      37                 :         /// Error returned by the console itself.
      38                 :         #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
      39                 :         Console {
      40                 :             status: http::StatusCode,
      41                 :             text: Box<str>,
      42                 :         },
      43                 : 
      44                 :         /// Various IO errors like broken pipe or malformed payload.
      45                 :         #[error("{REQUEST_FAILED}: {0}")]
      46                 :         Transport(#[from] std::io::Error),
      47                 :     }
      48                 : 
      49                 :     impl ApiError {
      50                 :         /// Returns HTTP status code if it's the reason for failure.
      51               3 :         pub fn http_status_code(&self) -> Option<http::StatusCode> {
      52               3 :             use ApiError::*;
      53               3 :             match self {
      54               2 :                 Console { status, .. } => Some(*status),
      55               1 :                 _ => None,
      56                 :             }
      57               3 :         }
      58                 :     }
      59                 : 
      60                 :     impl UserFacingError for ApiError {
      61               4 :         fn to_string_client(&self) -> String {
      62               4 :             use ApiError::*;
      63               4 :             match self {
      64                 :                 // To minimize risks, only select errors are forwarded to users.
      65                 :                 // Ask @neondatabase/control-plane for review before adding more.
      66               2 :                 Console { status, .. } => match *status {
      67                 :                     http::StatusCode::NOT_FOUND => {
      68                 :                         // Status 404: failed to get a project-related resource.
      69 UBC           0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      70                 :                     }
      71                 :                     http::StatusCode::NOT_ACCEPTABLE => {
      72                 :                         // Status 406: endpoint is disabled (we don't allow connections).
      73               0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      74                 :                     }
      75                 :                     http::StatusCode::LOCKED => {
      76                 :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      77               0 :                         format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
      78                 :                     }
      79 CBC           2 :                     _ => REQUEST_FAILED.to_owned(),
      80                 :                 },
      81               2 :                 _ => REQUEST_FAILED.to_owned(),
      82                 :             }
      83               4 :         }
      84                 :     }
      85                 : 
      86                 :     impl ShouldRetry for ApiError {
      87               4 :         fn could_retry(&self) -> bool {
      88               4 :             match self {
      89                 :                 // retry some transport errors
      90 UBC           0 :                 Self::Transport(io) => io.could_retry(),
      91                 :                 // retry some temporary failures because the compute was in a bad state
      92                 :                 // (bad request can be returned when the endpoint was in transition)
      93                 :                 Self::Console {
      94                 :                     status: http::StatusCode::BAD_REQUEST,
      95                 :                     ..
      96 CBC           2 :                 } => true,
      97                 :                 // locked can be returned when the endpoint was in transition
      98                 :                 // or when quotas are exceeded. don't retry when quotas are exceeded
      99                 :                 Self::Console {
     100                 :                     status: http::StatusCode::LOCKED,
     101 UBC           0 :                     ref text,
     102               0 :                 } => {
     103               0 :                     // written data quota exceeded
     104               0 :                     // data transfer quota exceeded
     105               0 :                     // compute time quota exceeded
     106               0 :                     // logical size quota exceeded
     107               0 :                     !text.contains("quota exceeded")
     108               0 :                         && !text.contains("the limit for current plan reached")
     109                 :                 }
     110 CBC           2 :                 _ => false,
     111                 :             }
     112               4 :         }
     113                 :     }
     114                 : 
     115                 :     impl From<reqwest::Error> for ApiError {
     116               1 :         fn from(e: reqwest::Error) -> Self {
     117               1 :             io_error(e).into()
     118               1 :         }
     119                 :     }
     120                 : 
     121                 :     impl From<reqwest_middleware::Error> for ApiError {
     122               1 :         fn from(e: reqwest_middleware::Error) -> Self {
     123               1 :             io_error(e).into()
     124               1 :         }
     125                 :     }
     126                 : 
     127              16 :     #[derive(Debug, Error)]
     128                 :     pub enum GetAuthInfoError {
     129                 :         // We shouldn't include the actual secret here.
     130                 :         #[error("Console responded with a malformed auth secret")]
     131                 :         BadSecret,
     132                 : 
     133                 :         #[error(transparent)]
     134                 :         ApiError(ApiError),
     135                 :     }
     136                 : 
     137                 :     // This allows more useful interactions than `#[from]`.
     138                 :     impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
     139               4 :         fn from(e: E) -> Self {
     140               4 :             Self::ApiError(e.into())
     141               4 :         }
     142                 :     }
     143                 : 
     144                 :     impl UserFacingError for GetAuthInfoError {
     145               4 :         fn to_string_client(&self) -> String {
     146               4 :             use GetAuthInfoError::*;
     147               4 :             match self {
     148                 :                 // We absolutely should not leak any secrets!
     149 UBC           0 :                 BadSecret => REQUEST_FAILED.to_owned(),
     150                 :                 // However, API might return a meaningful error.
     151 CBC           4 :                 ApiError(e) => e.to_string_client(),
     152                 :             }
     153               4 :         }
     154                 :     }
     155 UBC           0 :     #[derive(Debug, Error)]
     156                 :     pub enum WakeComputeError {
     157                 :         #[error("Console responded with a malformed compute address: {0}")]
     158                 :         BadComputeAddress(Box<str>),
     159                 : 
     160                 :         #[error(transparent)]
     161                 :         ApiError(ApiError),
     162                 : 
     163                 :         #[error("Timeout waiting to acquire wake compute lock")]
     164                 :         TimeoutError,
     165                 :     }
     166                 : 
     167                 :     // This allows more useful interactions than `#[from]`.
     168                 :     impl<E: Into<ApiError>> From<E> for WakeComputeError {
     169               0 :         fn from(e: E) -> Self {
     170               0 :             Self::ApiError(e.into())
     171               0 :         }
     172                 :     }
     173                 : 
     174                 :     impl From<tokio::sync::AcquireError> for WakeComputeError {
     175               0 :         fn from(_: tokio::sync::AcquireError) -> Self {
     176               0 :             WakeComputeError::TimeoutError
     177               0 :         }
     178                 :     }
     179                 :     impl From<tokio::time::error::Elapsed> for WakeComputeError {
     180               0 :         fn from(_: tokio::time::error::Elapsed) -> Self {
     181               0 :             WakeComputeError::TimeoutError
     182               0 :         }
     183                 :     }
     184                 : 
     185                 :     impl UserFacingError for WakeComputeError {
     186               0 :         fn to_string_client(&self) -> String {
     187               0 :             use WakeComputeError::*;
     188               0 :             match self {
     189                 :                 // We shouldn't show user the address even if it's broken.
     190                 :                 // Besides, user is unlikely to care about this detail.
     191               0 :                 BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
     192                 :                 // However, API might return a meaningful error.
     193               0 :                 ApiError(e) => e.to_string_client(),
     194                 : 
     195               0 :                 TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
     196                 :             }
     197               0 :         }
     198                 :     }
     199                 : }
     200                 : 
     201                 : /// Extra query params we'd like to pass to the console.
     202                 : pub struct ConsoleReqExtra {
     203                 :     pub options: Vec<(String, String)>,
     204                 : }
     205                 : 
     206                 : impl ConsoleReqExtra {
     207                 :     // https://swagger.io/docs/specification/serialization/ DeepObject format
     208                 :     // paramName[prop1]=value1&paramName[prop2]=value2&....
     209               0 :     pub fn options_as_deep_object(&self) -> Vec<(String, String)> {
     210               0 :         self.options
     211               0 :             .iter()
     212               0 :             .map(|(k, v)| (format!("options[{}]", k), v.to_string()))
     213               0 :             .collect()
     214               0 :     }
     215                 : }
     216                 : 
     217                 : /// Auth secret which is managed by the cloud.
     218 CBC          37 : #[derive(Clone)]
     219                 : pub enum AuthSecret {
     220                 :     #[cfg(feature = "testing")]
     221                 :     /// Md5 hash of user's password.
     222                 :     Md5([u8; 16]),
     223                 : 
     224                 :     /// [SCRAM](crate::scram) authentication info.
     225                 :     Scram(scram::ServerSecret),
     226                 : }
     227                 : 
     228 UBC           0 : #[derive(Default)]
     229                 : pub struct AuthInfo {
     230                 :     pub secret: Option<AuthSecret>,
     231                 :     /// List of IP addresses allowed for the autorization.
     232                 :     pub allowed_ips: Vec<String>,
     233                 : }
     234                 : 
     235                 : /// Info for establishing a connection to a compute node.
     236                 : /// This is what we get after auth succeeded, but not before!
     237               0 : #[derive(Clone)]
     238                 : pub struct NodeInfo {
     239                 :     /// Compute node connection params.
     240                 :     /// It's sad that we have to clone this, but this will improve
     241                 :     /// once we migrate to a bespoke connection logic.
     242                 :     pub config: compute::ConnCfg,
     243                 : 
     244                 :     /// Labels for proxy's metrics.
     245                 :     pub aux: MetricsAuxInfo,
     246                 : 
     247                 :     /// Whether we should accept self-signed certificates (for testing)
     248                 :     pub allow_self_signed_compute: bool,
     249                 : }
     250                 : 
     251                 : pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
     252                 : pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
     253                 : pub type AllowedIpsCache = TimedLru<SmolStr, Arc<Vec<String>>>;
     254                 : pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option<AuthSecret>>;
     255                 : pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>;
     256                 : 
     257                 : /// This will allocate per each call, but the http requests alone
     258                 : /// already require a few allocations, so it should be fine.
     259                 : #[async_trait]
     260                 : pub trait Api {
     261                 :     /// Get the client's auth secret for authentication.
     262                 :     async fn get_role_secret(
     263                 :         &self,
     264                 :         ctx: &mut RequestMonitoring,
     265                 :         creds: &ComputeUserInfo,
     266                 :     ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
     267                 : 
     268                 :     async fn get_allowed_ips(
     269                 :         &self,
     270                 :         ctx: &mut RequestMonitoring,
     271                 :         creds: &ComputeUserInfo,
     272                 :     ) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
     273                 : 
     274                 :     /// Wake up the compute node and return the corresponding connection info.
     275                 :     async fn wake_compute(
     276                 :         &self,
     277                 :         ctx: &mut RequestMonitoring,
     278                 :         extra: &ConsoleReqExtra,
     279                 :         creds: &ComputeUserInfo,
     280                 :     ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
     281                 : }
     282                 : 
     283                 : /// Various caches for [`console`](super).
     284                 : pub struct ApiCaches {
     285                 :     /// Cache for the `wake_compute` API method.
     286                 :     pub node_info: NodeInfoCache,
     287                 :     /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
     288                 :     pub allowed_ips: AllowedIpsCache,
     289                 :     /// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead.
     290                 :     pub role_secret: RoleSecretCache,
     291                 : }
     292                 : 
     293                 : /// Various caches for [`console`](super).
     294                 : pub struct ApiLocks {
     295                 :     name: &'static str,
     296                 :     node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
     297                 :     permits: usize,
     298                 :     timeout: Duration,
     299                 :     registered: prometheus::IntCounter,
     300                 :     unregistered: prometheus::IntCounter,
     301                 :     reclamation_lag: prometheus::Histogram,
     302                 :     lock_acquire_lag: prometheus::Histogram,
     303                 : }
     304                 : 
     305                 : impl ApiLocks {
     306 CBC           1 :     pub fn new(
     307               1 :         name: &'static str,
     308               1 :         permits: usize,
     309               1 :         shards: usize,
     310               1 :         timeout: Duration,
     311               1 :     ) -> prometheus::Result<Self> {
     312               1 :         let registered = prometheus::IntCounter::with_opts(
     313               1 :             prometheus::Opts::new(
     314               1 :                 "semaphores_registered",
     315               1 :                 "Number of semaphores registered in this api lock",
     316               1 :             )
     317               1 :             .namespace(name),
     318               1 :         )?;
     319               1 :         prometheus::register(Box::new(registered.clone()))?;
     320               1 :         let unregistered = prometheus::IntCounter::with_opts(
     321               1 :             prometheus::Opts::new(
     322               1 :                 "semaphores_unregistered",
     323               1 :                 "Number of semaphores unregistered in this api lock",
     324               1 :             )
     325               1 :             .namespace(name),
     326               1 :         )?;
     327               1 :         prometheus::register(Box::new(unregistered.clone()))?;
     328               1 :         let reclamation_lag = prometheus::Histogram::with_opts(
     329               1 :             prometheus::HistogramOpts::new(
     330               1 :                 "reclamation_lag_seconds",
     331               1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     332               1 :             )
     333               1 :             .namespace(name)
     334               1 :             // 1us -> 65ms
     335               1 :             // benchmarks on my mac indicate it's usually in the range of 256us and 512us
     336               1 :             .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
     337 UBC           0 :         )?;
     338 CBC           1 :         prometheus::register(Box::new(reclamation_lag.clone()))?;
     339               1 :         let lock_acquire_lag = prometheus::Histogram::with_opts(
     340               1 :             prometheus::HistogramOpts::new(
     341               1 :                 "semaphore_acquire_seconds",
     342               1 :                 "Time it takes to reclaim unused semaphores in the api lock",
     343               1 :             )
     344               1 :             .namespace(name)
     345               1 :             // 0.1ms -> 6s
     346               1 :             .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
     347 UBC           0 :         )?;
     348 CBC           1 :         prometheus::register(Box::new(lock_acquire_lag.clone()))?;
     349                 : 
     350               1 :         Ok(Self {
     351               1 :             name,
     352               1 :             node_locks: DashMap::with_shard_amount(shards),
     353               1 :             permits,
     354               1 :             timeout,
     355               1 :             lock_acquire_lag,
     356               1 :             registered,
     357               1 :             unregistered,
     358               1 :             reclamation_lag,
     359               1 :         })
     360               1 :     }
     361                 : 
     362 UBC           0 :     pub async fn get_wake_compute_permit(
     363               0 :         &self,
     364               0 :         key: &Arc<str>,
     365               0 :     ) -> Result<WakeComputePermit, errors::WakeComputeError> {
     366               0 :         if self.permits == 0 {
     367               0 :             return Ok(WakeComputePermit { permit: None });
     368               0 :         }
     369               0 :         let now = Instant::now();
     370               0 :         let semaphore = {
     371                 :             // get fast path
     372               0 :             if let Some(semaphore) = self.node_locks.get(key) {
     373               0 :                 semaphore.clone()
     374                 :             } else {
     375               0 :                 self.node_locks
     376               0 :                     .entry(key.clone())
     377               0 :                     .or_insert_with(|| {
     378               0 :                         self.registered.inc();
     379               0 :                         Arc::new(Semaphore::new(self.permits))
     380               0 :                     })
     381               0 :                     .clone()
     382                 :             }
     383                 :         };
     384               0 :         let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
     385                 : 
     386               0 :         self.lock_acquire_lag
     387               0 :             .observe((Instant::now() - now).as_secs_f64());
     388               0 : 
     389               0 :         Ok(WakeComputePermit {
     390               0 :             permit: Some(permit??),
     391                 :         })
     392               0 :     }
     393                 : 
     394 CBC           1 :     pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
     395               1 :         if self.permits == 0 {
     396               1 :             return;
     397 UBC           0 :         }
     398               0 : 
     399               0 :         let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
     400                 :         loop {
     401               0 :             for (i, shard) in self.node_locks.shards().iter().enumerate() {
     402               0 :                 interval.tick().await;
     403                 :                 // temporary lock a single shard and then clear any semaphores that aren't currently checked out
     404                 :                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
     405                 :                 // therefore releasing it is safe from race conditions
     406               0 :                 info!(
     407               0 :                     name = self.name,
     408               0 :                     shard = i,
     409               0 :                     "performing epoch reclamation on api lock"
     410               0 :                 );
     411               0 :                 let mut lock = shard.write();
     412               0 :                 let timer = self.reclamation_lag.start_timer();
     413               0 :                 let count = lock
     414               0 :                     .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
     415               0 :                     .count();
     416               0 :                 drop(lock);
     417               0 :                 self.unregistered.inc_by(count as u64);
     418               0 :                 timer.observe_duration()
     419                 :             }
     420                 :         }
     421 CBC           1 :     }
     422                 : }
     423                 : 
     424                 : pub struct WakeComputePermit {
     425                 :     // None if the lock is disabled
     426                 :     permit: Option<OwnedSemaphorePermit>,
     427                 : }
     428                 : 
     429                 : impl WakeComputePermit {
     430 UBC           0 :     pub fn should_check_cache(&self) -> bool {
     431               0 :         self.permit.is_some()
     432               0 :     }
     433                 : }
        

Generated by: LCOV version 2.1-beta