LCOV - code coverage report
Current view: top level - proxy/src/auth - backend.rs (source / functions) Coverage Total Hit
Test: a43a77853355b937a79c57b07a8f05607cf29e6c.info Lines: 76.6 % 501 384
Test Date: 2024-09-19 12:04:32 Functions: 42.3 % 71 30

            Line data    Source code
       1              : mod classic;
       2              : mod hacks;
       3              : pub mod jwt;
       4              : pub mod local;
       5              : mod web;
       6              : 
       7              : use std::net::IpAddr;
       8              : use std::sync::Arc;
       9              : use std::time::Duration;
      10              : 
      11              : use ipnet::{Ipv4Net, Ipv6Net};
      12              : use local::LocalBackend;
      13              : use tokio::io::{AsyncRead, AsyncWrite};
      14              : use tokio_postgres::config::AuthKeys;
      15              : use tracing::{info, warn};
      16              : pub(crate) use web::WebAuthError;
      17              : 
      18              : use crate::auth::credentials::check_peer_addr_is_in_list;
      19              : use crate::auth::{validate_password_and_exchange, AuthError};
      20              : use crate::cache::Cached;
      21              : use crate::console::errors::GetAuthInfoError;
      22              : use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
      23              : use crate::console::{AuthSecret, NodeInfo};
      24              : use crate::context::RequestMonitoring;
      25              : use crate::intern::EndpointIdInt;
      26              : use crate::metrics::Metrics;
      27              : use crate::proxy::connect_compute::ComputeConnectBackend;
      28              : use crate::proxy::NeonOptions;
      29              : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
      30              : use crate::stream::Stream;
      31              : use crate::{
      32              :     auth::{self, ComputeUserInfoMaybeEndpoint},
      33              :     config::AuthenticationConfig,
      34              :     console::{
      35              :         self,
      36              :         provider::{CachedAllowedIps, CachedNodeInfo},
      37              :         Api,
      38              :     },
      39              :     stream, url,
      40              : };
      41              : use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
      42              : 
      43              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      44              : pub enum MaybeOwned<'a, T> {
      45              :     Owned(T),
      46              :     Borrowed(&'a T),
      47              : }
      48              : 
      49              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      50              :     type Target = T;
      51              : 
      52           13 :     fn deref(&self) -> &Self::Target {
      53           13 :         match self {
      54           13 :             MaybeOwned::Owned(t) => t,
      55            0 :             MaybeOwned::Borrowed(t) => t,
      56              :         }
      57           13 :     }
      58              : }
      59              : 
      60              : /// This type serves two purposes:
      61              : ///
      62              : /// * When `T` is `()`, it's just a regular auth backend selector
      63              : ///   which we use in [`crate::config::ProxyConfig`].
      64              : ///
      65              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      66              : ///   this helps us provide the credentials only to those auth
      67              : ///   backends which require them for the authentication process.
      68              : pub enum Backend<'a, T, D> {
      69              :     /// Cloud API (V2).
      70              :     Console(MaybeOwned<'a, ConsoleBackend>, T),
      71              :     /// Authentication via a web browser.
      72              :     Web(MaybeOwned<'a, url::ApiUrl>, D),
      73              :     /// Local proxy uses configured auth credentials and does not wake compute
      74              :     Local(MaybeOwned<'a, LocalBackend>),
      75              : }
      76              : 
      77              : #[cfg(test)]
      78              : pub(crate) trait TestBackend: Send + Sync + 'static {
      79              :     fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
      80              :     fn get_allowed_ips_and_secret(
      81              :         &self,
      82              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
      83              : }
      84              : 
      85              : impl std::fmt::Display for Backend<'_, (), ()> {
      86            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      87            0 :         match self {
      88            0 :             Self::Console(api, ()) => match &**api {
      89            0 :                 ConsoleBackend::Console(endpoint) => {
      90            0 :                     fmt.debug_tuple("Console").field(&endpoint.url()).finish()
      91              :                 }
      92              :                 #[cfg(any(test, feature = "testing"))]
      93            0 :                 ConsoleBackend::Postgres(endpoint) => {
      94            0 :                     fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
      95              :                 }
      96              :                 #[cfg(test)]
      97            0 :                 ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
      98              :             },
      99            0 :             Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(),
     100            0 :             Self::Local(_) => fmt.debug_tuple("Local").finish(),
     101              :         }
     102            0 :     }
     103              : }
     104              : 
     105              : impl<T, D> Backend<'_, T, D> {
     106              :     /// Very similar to [`std::option::Option::as_ref`].
     107              :     /// This helps us pass structured config to async tasks.
     108            0 :     pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> {
     109            0 :         match self {
     110            0 :             Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
     111            0 :             Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x),
     112            0 :             Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
     113              :         }
     114            0 :     }
     115              : }
     116              : 
     117              : impl<'a, T, D> Backend<'a, T, D> {
     118              :     /// Very similar to [`std::option::Option::map`].
     119              :     /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
     120              :     /// a function to a contained value.
     121            0 :     pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> {
     122            0 :         match self {
     123            0 :             Self::Console(c, x) => Backend::Console(c, f(x)),
     124            0 :             Self::Web(c, x) => Backend::Web(c, x),
     125            0 :             Self::Local(l) => Backend::Local(l),
     126              :         }
     127            0 :     }
     128              : }
     129              : impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
     130              :     /// Very similar to [`std::option::Option::transpose`].
     131              :     /// This is most useful for error handling.
     132            0 :     pub(crate) fn transpose(self) -> Result<Backend<'a, T, D>, E> {
     133            0 :         match self {
     134            0 :             Self::Console(c, x) => x.map(|x| Backend::Console(c, x)),
     135            0 :             Self::Web(c, x) => Ok(Backend::Web(c, x)),
     136            0 :             Self::Local(l) => Ok(Backend::Local(l)),
     137              :         }
     138            0 :     }
     139              : }
     140              : 
     141              : pub(crate) struct ComputeCredentials {
     142              :     pub(crate) info: ComputeUserInfo,
     143              :     pub(crate) keys: ComputeCredentialKeys,
     144              : }
     145              : 
     146              : #[derive(Debug, Clone)]
     147              : pub(crate) struct ComputeUserInfoNoEndpoint {
     148              :     pub(crate) user: RoleName,
     149              :     pub(crate) options: NeonOptions,
     150              : }
     151              : 
     152              : #[derive(Debug, Clone)]
     153              : pub(crate) struct ComputeUserInfo {
     154              :     pub(crate) endpoint: EndpointId,
     155              :     pub(crate) user: RoleName,
     156              :     pub(crate) options: NeonOptions,
     157              : }
     158              : 
     159              : impl ComputeUserInfo {
     160            2 :     pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
     161            2 :         self.options.get_cache_key(&self.endpoint)
     162            2 :     }
     163              : }
     164              : 
     165              : pub(crate) enum ComputeCredentialKeys {
     166              :     Password(Vec<u8>),
     167              :     AuthKeys(AuthKeys),
     168              :     None,
     169              : }
     170              : 
     171              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     172              :     // user name
     173              :     type Error = ComputeUserInfoNoEndpoint;
     174              : 
     175            3 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     176            3 :         match user_info.endpoint_id {
     177            1 :             None => Err(ComputeUserInfoNoEndpoint {
     178            1 :                 user: user_info.user,
     179            1 :                 options: user_info.options,
     180            1 :             }),
     181            2 :             Some(endpoint) => Ok(ComputeUserInfo {
     182            2 :                 endpoint,
     183            2 :                 user: user_info.user,
     184            2 :                 options: user_info.options,
     185            2 :             }),
     186              :         }
     187            3 :     }
     188              : }
     189              : 
     190              : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
     191              : pub struct MaskedIp(IpAddr);
     192              : 
     193              : impl MaskedIp {
     194           15 :     fn new(value: IpAddr, prefix: u8) -> Self {
     195           15 :         match value {
     196           11 :             IpAddr::V4(v4) => Self(IpAddr::V4(
     197           11 :                 Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
     198           11 :             )),
     199            4 :             IpAddr::V6(v6) => Self(IpAddr::V6(
     200            4 :                 Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
     201            4 :             )),
     202              :         }
     203           15 :     }
     204              : }
     205              : 
     206              : // This can't be just per IP because that would limit some PaaS that share IP addresses
     207              : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
     208              : 
     209              : impl RateBucketInfo {
     210              :     /// All of these are per endpoint-maskedip pair.
     211              :     /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
     212              :     ///
     213              :     /// First bucket: 1000mcpus total per endpoint-ip pair
     214              :     /// * 4096000 requests per second with 1 hash rounds.
     215              :     /// * 1000 requests per second with 4096 hash rounds.
     216              :     /// * 6.8 requests per second with 600000 hash rounds.
     217              :     pub const DEFAULT_AUTH_SET: [Self; 3] = [
     218              :         Self::new(1000 * 4096, Duration::from_secs(1)),
     219              :         Self::new(600 * 4096, Duration::from_secs(60)),
     220              :         Self::new(300 * 4096, Duration::from_secs(600)),
     221              :     ];
     222              : }
     223              : 
     224              : impl AuthenticationConfig {
     225            3 :     pub(crate) fn check_rate_limit(
     226            3 :         &self,
     227            3 :         ctx: &RequestMonitoring,
     228            3 :         config: &AuthenticationConfig,
     229            3 :         secret: AuthSecret,
     230            3 :         endpoint: &EndpointId,
     231            3 :         is_cleartext: bool,
     232            3 :     ) -> auth::Result<AuthSecret> {
     233            3 :         // we have validated the endpoint exists, so let's intern it.
     234            3 :         let endpoint_int = EndpointIdInt::from(endpoint.normalize());
     235              : 
     236              :         // only count the full hash count if password hack or websocket flow.
     237              :         // in other words, if proxy needs to run the hashing
     238            3 :         let password_weight = if is_cleartext {
     239            2 :             match &secret {
     240              :                 #[cfg(any(test, feature = "testing"))]
     241            0 :                 AuthSecret::Md5(_) => 1,
     242            2 :                 AuthSecret::Scram(s) => s.iterations + 1,
     243              :             }
     244              :         } else {
     245              :             // validating scram takes just 1 hmac_sha_256 operation.
     246            1 :             1
     247              :         };
     248              : 
     249            3 :         let limit_not_exceeded = self.rate_limiter.check(
     250            3 :             (
     251            3 :                 endpoint_int,
     252            3 :                 MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
     253            3 :             ),
     254            3 :             password_weight,
     255            3 :         );
     256            3 : 
     257            3 :         if !limit_not_exceeded {
     258            0 :             warn!(
     259              :                 enabled = self.rate_limiter_enabled,
     260            0 :                 "rate limiting authentication"
     261              :             );
     262            0 :             Metrics::get().proxy.requests_auth_rate_limits_total.inc();
     263            0 :             Metrics::get()
     264            0 :                 .proxy
     265            0 :                 .endpoints_auth_rate_limits
     266            0 :                 .get_metric()
     267            0 :                 .measure(endpoint);
     268            0 : 
     269            0 :             if self.rate_limiter_enabled {
     270            0 :                 return Err(auth::AuthError::too_many_connections());
     271            0 :             }
     272            3 :         }
     273              : 
     274            3 :         Ok(secret)
     275            3 :     }
     276              : }
     277              : 
     278              : /// True to its name, this function encapsulates our current auth trade-offs.
     279              : /// Here, we choose the appropriate auth flow based on circumstances.
     280              : ///
     281              : /// All authentication flows will emit an AuthenticationOk message if successful.
     282            3 : async fn auth_quirks(
     283            3 :     ctx: &RequestMonitoring,
     284            3 :     api: &impl console::Api,
     285            3 :     user_info: ComputeUserInfoMaybeEndpoint,
     286            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     287            3 :     allow_cleartext: bool,
     288            3 :     config: &'static AuthenticationConfig,
     289            3 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     290            3 : ) -> auth::Result<ComputeCredentials> {
     291              :     // If there's no project so far, that entails that client doesn't
     292              :     // support SNI or other means of passing the endpoint (project) name.
     293              :     // We now expect to see a very specific payload in the place of password.
     294            3 :     let (info, unauthenticated_password) = match user_info.try_into() {
     295            1 :         Err(info) => {
     296            1 :             let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
     297              : 
     298            1 :             ctx.set_endpoint_id(res.info.endpoint.clone());
     299            1 :             let password = match res.keys {
     300            1 :                 ComputeCredentialKeys::Password(p) => p,
     301              :                 ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
     302            0 :                     unreachable!("password hack should return a password")
     303              :                 }
     304              :             };
     305            1 :             (res.info, Some(password))
     306              :         }
     307            2 :         Ok(info) => (info, None),
     308              :     };
     309              : 
     310            3 :     info!("fetching user's authentication info");
     311            3 :     let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
     312              : 
     313              :     // check allowed list
     314            3 :     if config.ip_allowlist_check_enabled
     315            3 :         && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
     316              :     {
     317            0 :         return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
     318            3 :     }
     319            3 : 
     320            3 :     if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
     321            0 :         return Err(AuthError::too_many_connections());
     322            3 :     }
     323            3 :     let cached_secret = match maybe_secret {
     324            3 :         Some(secret) => secret,
     325            0 :         None => api.get_role_secret(ctx, &info).await?,
     326              :     };
     327            3 :     let (cached_entry, secret) = cached_secret.take_value();
     328              : 
     329            3 :     let secret = if let Some(secret) = secret {
     330            3 :         config.check_rate_limit(
     331            3 :             ctx,
     332            3 :             config,
     333            3 :             secret,
     334            3 :             &info.endpoint,
     335            3 :             unauthenticated_password.is_some() || allow_cleartext,
     336            0 :         )?
     337              :     } else {
     338              :         // If we don't have an authentication secret, we mock one to
     339              :         // prevent malicious probing (possible due to missing protocol steps).
     340              :         // This mocked secret will never lead to successful authentication.
     341            0 :         info!("authentication info not found, mocking it");
     342            0 :         AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     343              :     };
     344              : 
     345            3 :     match authenticate_with_secret(
     346            3 :         ctx,
     347            3 :         secret,
     348            3 :         info,
     349            3 :         client,
     350            3 :         unauthenticated_password,
     351            3 :         allow_cleartext,
     352            3 :         config,
     353            3 :     )
     354            5 :     .await
     355              :     {
     356            3 :         Ok(keys) => Ok(keys),
     357            0 :         Err(e) => {
     358            0 :             if e.is_auth_failed() {
     359            0 :                 // The password could have been changed, so we invalidate the cache.
     360            0 :                 cached_entry.invalidate();
     361            0 :             }
     362            0 :             Err(e)
     363              :         }
     364              :     }
     365            3 : }
     366              : 
     367            3 : async fn authenticate_with_secret(
     368            3 :     ctx: &RequestMonitoring,
     369            3 :     secret: AuthSecret,
     370            3 :     info: ComputeUserInfo,
     371            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     372            3 :     unauthenticated_password: Option<Vec<u8>>,
     373            3 :     allow_cleartext: bool,
     374            3 :     config: &'static AuthenticationConfig,
     375            3 : ) -> auth::Result<ComputeCredentials> {
     376            3 :     if let Some(password) = unauthenticated_password {
     377            1 :         let ep = EndpointIdInt::from(&info.endpoint);
     378              : 
     379            1 :         let auth_outcome =
     380            1 :             validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
     381            1 :         let keys = match auth_outcome {
     382            1 :             crate::sasl::Outcome::Success(key) => key,
     383            0 :             crate::sasl::Outcome::Failure(reason) => {
     384            0 :                 info!("auth backend failed with an error: {reason}");
     385            0 :                 return Err(auth::AuthError::auth_failed(&*info.user));
     386              :             }
     387              :         };
     388              : 
     389              :         // we have authenticated the password
     390            1 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     391              : 
     392            1 :         return Ok(ComputeCredentials { info, keys });
     393            2 :     }
     394            2 : 
     395            2 :     // -- the remaining flows are self-authenticating --
     396            2 : 
     397            2 :     // Perform cleartext auth if we're allowed to do that.
     398            2 :     // Currently, we use it for websocket connections (latency).
     399            2 :     if allow_cleartext {
     400            1 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     401            2 :         return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
     402            1 :     }
     403            1 : 
     404            1 :     // Finally, proceed with the main auth flow (SCRAM-based).
     405            2 :     classic::authenticate(ctx, info, client, config, secret).await
     406            3 : }
     407              : 
     408              : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
     409              :     /// Get username from the credentials.
     410            0 :     pub(crate) fn get_user(&self) -> &str {
     411            0 :         match self {
     412            0 :             Self::Console(_, user_info) => &user_info.user,
     413            0 :             Self::Web(_, ()) => "web",
     414            0 :             Self::Local(_) => "local",
     415              :         }
     416            0 :     }
     417              : 
     418              :     /// Authenticate the client via the requested backend, possibly using credentials.
     419            0 :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     420              :     pub(crate) async fn authenticate(
     421              :         self,
     422              :         ctx: &RequestMonitoring,
     423              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     424              :         allow_cleartext: bool,
     425              :         config: &'static AuthenticationConfig,
     426              :         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     427              :     ) -> auth::Result<Backend<'a, ComputeCredentials, NodeInfo>> {
     428              :         let res = match self {
     429              :             Self::Console(api, user_info) => {
     430              :                 info!(
     431              :                     user = &*user_info.user,
     432              :                     project = user_info.endpoint(),
     433              :                     "performing authentication using the console"
     434              :                 );
     435              : 
     436              :                 let credentials = auth_quirks(
     437              :                     ctx,
     438              :                     &*api,
     439              :                     user_info,
     440              :                     client,
     441              :                     allow_cleartext,
     442              :                     config,
     443              :                     endpoint_rate_limiter,
     444              :                 )
     445              :                 .await?;
     446              :                 Backend::Console(api, credentials)
     447              :             }
     448              :             // NOTE: this auth backend doesn't use client credentials.
     449              :             Self::Web(url, ()) => {
     450              :                 info!("performing web authentication");
     451              : 
     452              :                 let info = web::authenticate(ctx, &url, client).await?;
     453              : 
     454              :                 Backend::Web(url, info)
     455              :             }
     456              :             Self::Local(_) => {
     457              :                 return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
     458              :             }
     459              :         };
     460              : 
     461              :         info!("user successfully authenticated");
     462              :         Ok(res)
     463              :     }
     464              : }
     465              : 
     466              : impl Backend<'_, ComputeUserInfo, &()> {
     467            0 :     pub(crate) async fn get_role_secret(
     468            0 :         &self,
     469            0 :         ctx: &RequestMonitoring,
     470            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     471            0 :         match self {
     472            0 :             Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
     473            0 :             Self::Web(_, ()) => Ok(Cached::new_uncached(None)),
     474            0 :             Self::Local(_) => Ok(Cached::new_uncached(None)),
     475              :         }
     476            0 :     }
     477              : 
     478            0 :     pub(crate) async fn get_allowed_ips_and_secret(
     479            0 :         &self,
     480            0 :         ctx: &RequestMonitoring,
     481            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     482            0 :         match self {
     483            0 :             Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     484            0 :             Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
     485            0 :             Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
     486              :         }
     487            0 :     }
     488              : }
     489              : 
     490              : #[async_trait::async_trait]
     491              : impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> {
     492            0 :     async fn wake_compute(
     493            0 :         &self,
     494            0 :         ctx: &RequestMonitoring,
     495            0 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     496            0 :         match self {
     497            0 :             Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     498            0 :             Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())),
     499            0 :             Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
     500              :         }
     501            0 :     }
     502              : 
     503            0 :     fn get_keys(&self) -> &ComputeCredentialKeys {
     504            0 :         match self {
     505            0 :             Self::Console(_, creds) => &creds.keys,
     506            0 :             Self::Web(_, _) => &ComputeCredentialKeys::None,
     507            0 :             Self::Local(_) => &ComputeCredentialKeys::None,
     508              :         }
     509            0 :     }
     510              : }
     511              : 
     512              : #[async_trait::async_trait]
     513              : impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> {
     514           13 :     async fn wake_compute(
     515           13 :         &self,
     516           13 :         ctx: &RequestMonitoring,
     517           13 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     518           13 :         match self {
     519           13 :             Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     520            0 :             Self::Web(_, ()) => {
     521            0 :                 unreachable!("web auth flow doesn't support waking the compute")
     522              :             }
     523            0 :             Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
     524              :         }
     525           26 :     }
     526              : 
     527            6 :     fn get_keys(&self) -> &ComputeCredentialKeys {
     528            6 :         match self {
     529            6 :             Self::Console(_, creds) => &creds.keys,
     530            0 :             Self::Web(_, ()) => &ComputeCredentialKeys::None,
     531            0 :             Self::Local(_) => &ComputeCredentialKeys::None,
     532              :         }
     533            6 :     }
     534              : }
     535              : 
     536              : #[cfg(test)]
     537              : mod tests {
     538              :     use std::{net::IpAddr, sync::Arc, time::Duration};
     539              : 
     540              :     use bytes::BytesMut;
     541              :     use fallible_iterator::FallibleIterator;
     542              :     use once_cell::sync::Lazy;
     543              :     use postgres_protocol::{
     544              :         authentication::sasl::{ChannelBinding, ScramSha256},
     545              :         message::{backend::Message as PgMessage, frontend},
     546              :     };
     547              :     use provider::AuthSecret;
     548              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     549              : 
     550              :     use crate::{
     551              :         auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
     552              :         config::AuthenticationConfig,
     553              :         console::{
     554              :             self,
     555              :             provider::{self, CachedAllowedIps, CachedRoleSecret},
     556              :             CachedNodeInfo,
     557              :         },
     558              :         context::RequestMonitoring,
     559              :         proxy::NeonOptions,
     560              :         rate_limiter::{EndpointRateLimiter, RateBucketInfo},
     561              :         scram::{threadpool::ThreadPool, ServerSecret},
     562              :         stream::{PqStream, Stream},
     563              :     };
     564              : 
     565              :     use super::{auth_quirks, AuthRateLimiter};
     566              : 
     567              :     struct Auth {
     568              :         ips: Vec<IpPattern>,
     569              :         secret: AuthSecret,
     570              :     }
     571              : 
     572              :     impl console::Api for Auth {
     573            0 :         async fn get_role_secret(
     574            0 :             &self,
     575            0 :             _ctx: &RequestMonitoring,
     576            0 :             _user_info: &super::ComputeUserInfo,
     577            0 :         ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
     578            0 :             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
     579            0 :         }
     580              : 
     581            3 :         async fn get_allowed_ips_and_secret(
     582            3 :             &self,
     583            3 :             _ctx: &RequestMonitoring,
     584            3 :             _user_info: &super::ComputeUserInfo,
     585            3 :         ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
     586            3 :         {
     587            3 :             Ok((
     588            3 :                 CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
     589            3 :                 Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
     590            3 :             ))
     591            3 :         }
     592              : 
     593            0 :         async fn wake_compute(
     594            0 :             &self,
     595            0 :             _ctx: &RequestMonitoring,
     596            0 :             _user_info: &super::ComputeUserInfo,
     597            0 :         ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     598            0 :             unimplemented!()
     599              :         }
     600              :     }
     601              : 
     602            3 :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     603            3 :         thread_pool: ThreadPool::new(1),
     604            3 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     605            3 :         rate_limiter_enabled: true,
     606            3 :         rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
     607            3 :         rate_limit_ip_subnet: 64,
     608            3 :         ip_allowlist_check_enabled: true,
     609            3 :     });
     610              : 
     611            5 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     612              :         loop {
     613            7 :             r.read_buf(&mut *b).await.unwrap();
     614            7 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     615            5 :                 break m;
     616            2 :             }
     617              :         }
     618            5 :     }
     619              : 
     620              :     #[test]
     621            1 :     fn masked_ip() {
     622            1 :         let ip_a = IpAddr::V4([127, 0, 0, 1].into());
     623            1 :         let ip_b = IpAddr::V4([127, 0, 0, 2].into());
     624            1 :         let ip_c = IpAddr::V4([192, 168, 1, 101].into());
     625            1 :         let ip_d = IpAddr::V4([192, 168, 1, 102].into());
     626            1 :         let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
     627            1 :         let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
     628            1 : 
     629            1 :         assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
     630            1 :         assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
     631            1 :         assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
     632            1 :         assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
     633              : 
     634            1 :         assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
     635            1 :         assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
     636            1 :     }
     637              : 
     638              :     #[test]
     639            1 :     fn test_default_auth_rate_limit_set() {
     640            1 :         // these values used to exceed u32::MAX
     641            1 :         assert_eq!(
     642            1 :             RateBucketInfo::DEFAULT_AUTH_SET,
     643            1 :             [
     644            1 :                 RateBucketInfo {
     645            1 :                     interval: Duration::from_secs(1),
     646            1 :                     max_rpi: 1000 * 4096,
     647            1 :                 },
     648            1 :                 RateBucketInfo {
     649            1 :                     interval: Duration::from_secs(60),
     650            1 :                     max_rpi: 600 * 4096 * 60,
     651            1 :                 },
     652            1 :                 RateBucketInfo {
     653            1 :                     interval: Duration::from_secs(600),
     654            1 :                     max_rpi: 300 * 4096 * 600,
     655            1 :                 }
     656            1 :             ]
     657            1 :         );
     658              : 
     659            4 :         for x in RateBucketInfo::DEFAULT_AUTH_SET {
     660            3 :             let y = x.to_string().parse().unwrap();
     661            3 :             assert_eq!(x, y);
     662              :         }
     663            1 :     }
     664              : 
     665              :     #[tokio::test]
     666            1 :     async fn auth_quirks_scram() {
     667            1 :         let (mut client, server) = tokio::io::duplex(1024);
     668            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     669            1 : 
     670            1 :         let ctx = RequestMonitoring::test();
     671            1 :         let api = Auth {
     672            1 :             ips: vec![],
     673            3 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     674            1 :         };
     675            1 : 
     676            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     677            1 :             user: "conrad".into(),
     678            1 :             endpoint_id: Some("endpoint".into()),
     679            1 :             options: NeonOptions::default(),
     680            1 :         };
     681            1 : 
     682            1 :         let handle = tokio::spawn(async move {
     683            1 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     684            1 : 
     685            1 :             let mut read = BytesMut::new();
     686            1 : 
     687            1 :             // server should offer scram
     688            1 :             match read_message(&mut client, &mut read).await {
     689            1 :                 PgMessage::AuthenticationSasl(a) => {
     690            1 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     691            1 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     692            1 :                 }
     693            1 :                 _ => panic!("wrong message"),
     694            1 :             }
     695            1 : 
     696            1 :             // client sends client-first-message
     697            1 :             let mut write = BytesMut::new();
     698            1 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     699            1 :             client.write_all(&write).await.unwrap();
     700            1 : 
     701            1 :             // server response with server-first-message
     702            1 :             match read_message(&mut client, &mut read).await {
     703            1 :                 PgMessage::AuthenticationSaslContinue(a) => {
     704            3 :                     scram.update(a.data()).await.unwrap();
     705            1 :                 }
     706            1 :                 _ => panic!("wrong message"),
     707            1 :             }
     708            1 : 
     709            1 :             // client response with client-final-message
     710            1 :             write.clear();
     711            1 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     712            1 :             client.write_all(&write).await.unwrap();
     713            1 : 
     714            1 :             // server response with server-final-message
     715            1 :             match read_message(&mut client, &mut read).await {
     716            1 :                 PgMessage::AuthenticationSaslFinal(a) => {
     717            1 :                     scram.finish(a.data()).unwrap();
     718            1 :                 }
     719            1 :                 _ => panic!("wrong message"),
     720            1 :             }
     721            1 :         });
     722            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     723            1 :             EndpointRateLimiter::DEFAULT,
     724            1 :             64,
     725            1 :         ));
     726            1 : 
     727            1 :         let _creds = auth_quirks(
     728            1 :             &ctx,
     729            1 :             &api,
     730            1 :             user_info,
     731            1 :             &mut stream,
     732            1 :             false,
     733            1 :             &CONFIG,
     734            1 :             endpoint_rate_limiter,
     735            1 :         )
     736            2 :         .await
     737            1 :         .unwrap();
     738            1 : 
     739            1 :         handle.await.unwrap();
     740            1 :     }
     741              : 
     742              :     #[tokio::test]
     743            1 :     async fn auth_quirks_cleartext() {
     744            1 :         let (mut client, server) = tokio::io::duplex(1024);
     745            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     746            1 : 
     747            1 :         let ctx = RequestMonitoring::test();
     748            1 :         let api = Auth {
     749            1 :             ips: vec![],
     750            3 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     751            1 :         };
     752            1 : 
     753            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     754            1 :             user: "conrad".into(),
     755            1 :             endpoint_id: Some("endpoint".into()),
     756            1 :             options: NeonOptions::default(),
     757            1 :         };
     758            1 : 
     759            1 :         let handle = tokio::spawn(async move {
     760            1 :             let mut read = BytesMut::new();
     761            1 :             let mut write = BytesMut::new();
     762            1 : 
     763            1 :             // server should offer cleartext
     764            1 :             match read_message(&mut client, &mut read).await {
     765            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     766            1 :                 _ => panic!("wrong message"),
     767            1 :             }
     768            1 : 
     769            1 :             // client responds with password
     770            1 :             write.clear();
     771            1 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     772            1 :             client.write_all(&write).await.unwrap();
     773            1 :         });
     774            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     775            1 :             EndpointRateLimiter::DEFAULT,
     776            1 :             64,
     777            1 :         ));
     778            1 : 
     779            1 :         let _creds = auth_quirks(
     780            1 :             &ctx,
     781            1 :             &api,
     782            1 :             user_info,
     783            1 :             &mut stream,
     784            1 :             true,
     785            1 :             &CONFIG,
     786            1 :             endpoint_rate_limiter,
     787            1 :         )
     788            2 :         .await
     789            1 :         .unwrap();
     790            1 : 
     791            1 :         handle.await.unwrap();
     792            1 :     }
     793              : 
     794              :     #[tokio::test]
     795            1 :     async fn auth_quirks_password_hack() {
     796            1 :         let (mut client, server) = tokio::io::duplex(1024);
     797            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     798            1 : 
     799            1 :         let ctx = RequestMonitoring::test();
     800            1 :         let api = Auth {
     801            1 :             ips: vec![],
     802            3 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     803            1 :         };
     804            1 : 
     805            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     806            1 :             user: "conrad".into(),
     807            1 :             endpoint_id: None,
     808            1 :             options: NeonOptions::default(),
     809            1 :         };
     810            1 : 
     811            1 :         let handle = tokio::spawn(async move {
     812            1 :             let mut read = BytesMut::new();
     813            1 : 
     814            1 :             // server should offer cleartext
     815            1 :             match read_message(&mut client, &mut read).await {
     816            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     817            1 :                 _ => panic!("wrong message"),
     818            1 :             }
     819            1 : 
     820            1 :             // client responds with password
     821            1 :             let mut write = BytesMut::new();
     822            1 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     823            1 :                 .unwrap();
     824            1 :             client.write_all(&write).await.unwrap();
     825            1 :         });
     826            1 : 
     827            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     828            1 :             EndpointRateLimiter::DEFAULT,
     829            1 :             64,
     830            1 :         ));
     831            1 : 
     832            1 :         let creds = auth_quirks(
     833            1 :             &ctx,
     834            1 :             &api,
     835            1 :             user_info,
     836            1 :             &mut stream,
     837            1 :             true,
     838            1 :             &CONFIG,
     839            1 :             endpoint_rate_limiter,
     840            1 :         )
     841            2 :         .await
     842            1 :         .unwrap();
     843            1 : 
     844            1 :         assert_eq!(creds.info.endpoint, "my-endpoint");
     845            1 : 
     846            1 :         handle.await.unwrap();
     847            1 :     }
     848              : }
        

Generated by: LCOV version 2.1-beta