LCOV - code coverage report
Current view: top level - proxy/src/auth - backend.rs (source / functions) Coverage Total Hit
Test: fabb29a6339542ee130cd1d32b534fafdc0be240.info Lines: 75.7 % 502 380
Test Date: 2024-06-25 13:20:00 Functions: 40.5 % 74 30

            Line data    Source code
       1              : mod classic;
       2              : mod hacks;
       3              : mod link;
       4              : 
       5              : use std::net::IpAddr;
       6              : use std::sync::Arc;
       7              : use std::time::Duration;
       8              : 
       9              : use ipnet::{Ipv4Net, Ipv6Net};
      10              : pub use link::LinkAuthError;
      11              : use tokio::io::{AsyncRead, AsyncWrite};
      12              : use tokio_postgres::config::AuthKeys;
      13              : use tracing::{info, warn};
      14              : 
      15              : use crate::auth::credentials::check_peer_addr_is_in_list;
      16              : use crate::auth::{validate_password_and_exchange, AuthError};
      17              : use crate::cache::Cached;
      18              : use crate::console::errors::GetAuthInfoError;
      19              : use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
      20              : use crate::console::{AuthSecret, NodeInfo};
      21              : use crate::context::RequestMonitoring;
      22              : use crate::intern::EndpointIdInt;
      23              : use crate::metrics::Metrics;
      24              : use crate::proxy::connect_compute::ComputeConnectBackend;
      25              : use crate::proxy::NeonOptions;
      26              : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
      27              : use crate::stream::Stream;
      28              : use crate::{
      29              :     auth::{self, ComputeUserInfoMaybeEndpoint},
      30              :     config::AuthenticationConfig,
      31              :     console::{
      32              :         self,
      33              :         provider::{CachedAllowedIps, CachedNodeInfo},
      34              :         Api,
      35              :     },
      36              :     stream, url,
      37              : };
      38              : use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
      39              : 
      40              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      41              : pub enum MaybeOwned<'a, T> {
      42              :     Owned(T),
      43              :     Borrowed(&'a T),
      44              : }
      45              : 
      46              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      47              :     type Target = T;
      48              : 
      49           26 :     fn deref(&self) -> &Self::Target {
      50           26 :         match self {
      51           26 :             MaybeOwned::Owned(t) => t,
      52            0 :             MaybeOwned::Borrowed(t) => t,
      53              :         }
      54           26 :     }
      55              : }
      56              : 
      57              : /// This type serves two purposes:
      58              : ///
      59              : /// * When `T` is `()`, it's just a regular auth backend selector
      60              : ///   which we use in [`crate::config::ProxyConfig`].
      61              : ///
      62              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      63              : ///   this helps us provide the credentials only to those auth
      64              : ///   backends which require them for the authentication process.
      65              : pub enum BackendType<'a, T, D> {
      66              :     /// Cloud API (V2).
      67              :     Console(MaybeOwned<'a, ConsoleBackend>, T),
      68              :     /// Authentication via a web browser.
      69              :     Link(MaybeOwned<'a, url::ApiUrl>, D),
      70              : }
      71              : 
      72              : pub trait TestBackend: Send + Sync + 'static {
      73              :     fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
      74              :     fn get_allowed_ips_and_secret(
      75              :         &self,
      76              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
      77              :     fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
      78              : }
      79              : 
      80              : impl std::fmt::Display for BackendType<'_, (), ()> {
      81            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      82            0 :         use BackendType::*;
      83            0 :         match self {
      84            0 :             Console(api, _) => match &**api {
      85            0 :                 ConsoleBackend::Console(endpoint) => {
      86            0 :                     fmt.debug_tuple("Console").field(&endpoint.url()).finish()
      87              :                 }
      88              :                 #[cfg(any(test, feature = "testing"))]
      89            0 :                 ConsoleBackend::Postgres(endpoint) => {
      90            0 :                     fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
      91              :                 }
      92              :                 #[cfg(test)]
      93            0 :                 ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
      94              :             },
      95            0 :             Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
      96              :         }
      97            0 :     }
      98              : }
      99              : 
     100              : impl<T, D> BackendType<'_, T, D> {
     101              :     /// Very similar to [`std::option::Option::as_ref`].
     102              :     /// This helps us pass structured config to async tasks.
     103            0 :     pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
     104            0 :         use BackendType::*;
     105            0 :         match self {
     106            0 :             Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
     107            0 :             Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
     108              :         }
     109            0 :     }
     110              : }
     111              : 
     112              : impl<'a, T, D> BackendType<'a, T, D> {
     113              :     /// Very similar to [`std::option::Option::map`].
     114              :     /// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
     115              :     /// a function to a contained value.
     116            0 :     pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
     117            0 :         use BackendType::*;
     118            0 :         match self {
     119            0 :             Console(c, x) => Console(c, f(x)),
     120            0 :             Link(c, x) => Link(c, x),
     121              :         }
     122            0 :     }
     123              : }
     124              : impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
     125              :     /// Very similar to [`std::option::Option::transpose`].
     126              :     /// This is most useful for error handling.
     127            0 :     pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
     128            0 :         use BackendType::*;
     129            0 :         match self {
     130            0 :             Console(c, x) => x.map(|x| Console(c, x)),
     131            0 :             Link(c, x) => Ok(Link(c, x)),
     132              :         }
     133            0 :     }
     134              : }
     135              : 
     136              : pub struct ComputeCredentials {
     137              :     pub info: ComputeUserInfo,
     138              :     pub keys: ComputeCredentialKeys,
     139              : }
     140              : 
     141              : #[derive(Debug, Clone)]
     142              : pub struct ComputeUserInfoNoEndpoint {
     143              :     pub user: RoleName,
     144              :     pub options: NeonOptions,
     145              : }
     146              : 
     147              : #[derive(Debug, Clone)]
     148              : pub struct ComputeUserInfo {
     149              :     pub endpoint: EndpointId,
     150              :     pub user: RoleName,
     151              :     pub options: NeonOptions,
     152              : }
     153              : 
     154              : impl ComputeUserInfo {
     155            4 :     pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
     156            4 :         self.options.get_cache_key(&self.endpoint)
     157            4 :     }
     158              : }
     159              : 
     160              : pub enum ComputeCredentialKeys {
     161              :     Password(Vec<u8>),
     162              :     AuthKeys(AuthKeys),
     163              : }
     164              : 
     165              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     166              :     // user name
     167              :     type Error = ComputeUserInfoNoEndpoint;
     168              : 
     169            6 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     170            6 :         match user_info.endpoint_id {
     171            2 :             None => Err(ComputeUserInfoNoEndpoint {
     172            2 :                 user: user_info.user,
     173            2 :                 options: user_info.options,
     174            2 :             }),
     175            4 :             Some(endpoint) => Ok(ComputeUserInfo {
     176            4 :                 endpoint,
     177            4 :                 user: user_info.user,
     178            4 :                 options: user_info.options,
     179            4 :             }),
     180              :         }
     181            6 :     }
     182              : }
     183              : 
     184              : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
     185              : pub struct MaskedIp(IpAddr);
     186              : 
     187              : impl MaskedIp {
     188           30 :     fn new(value: IpAddr, prefix: u8) -> Self {
     189           30 :         match value {
     190           22 :             IpAddr::V4(v4) => Self(IpAddr::V4(
     191           22 :                 Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
     192           22 :             )),
     193            8 :             IpAddr::V6(v6) => Self(IpAddr::V6(
     194            8 :                 Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
     195            8 :             )),
     196              :         }
     197           30 :     }
     198              : }
     199              : 
     200              : // This can't be just per IP because that would limit some PaaS that share IP addresses
     201              : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
     202              : 
     203              : impl RateBucketInfo {
     204              :     /// All of these are per endpoint-maskedip pair.
     205              :     /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
     206              :     ///
     207              :     /// First bucket: 1000mcpus total per endpoint-ip pair
     208              :     /// * 4096000 requests per second with 1 hash rounds.
     209              :     /// * 1000 requests per second with 4096 hash rounds.
     210              :     /// * 6.8 requests per second with 600000 hash rounds.
     211              :     pub const DEFAULT_AUTH_SET: [Self; 3] = [
     212              :         Self::new(1000 * 4096, Duration::from_secs(1)),
     213              :         Self::new(600 * 4096, Duration::from_secs(60)),
     214              :         Self::new(300 * 4096, Duration::from_secs(600)),
     215              :     ];
     216              : }
     217              : 
     218              : impl AuthenticationConfig {
     219            6 :     pub fn check_rate_limit(
     220            6 :         &self,
     221            6 :         ctx: &mut RequestMonitoring,
     222            6 :         config: &AuthenticationConfig,
     223            6 :         secret: AuthSecret,
     224            6 :         endpoint: &EndpointId,
     225            6 :         is_cleartext: bool,
     226            6 :     ) -> auth::Result<AuthSecret> {
     227            6 :         // we have validated the endpoint exists, so let's intern it.
     228            6 :         let endpoint_int = EndpointIdInt::from(endpoint.normalize());
     229              : 
     230              :         // only count the full hash count if password hack or websocket flow.
     231              :         // in other words, if proxy needs to run the hashing
     232            6 :         let password_weight = if is_cleartext {
     233            4 :             match &secret {
     234              :                 #[cfg(any(test, feature = "testing"))]
     235            0 :                 AuthSecret::Md5(_) => 1,
     236            4 :                 AuthSecret::Scram(s) => s.iterations + 1,
     237              :             }
     238              :         } else {
     239              :             // validating scram takes just 1 hmac_sha_256 operation.
     240            2 :             1
     241              :         };
     242              : 
     243            6 :         let limit_not_exceeded = self.rate_limiter.check(
     244            6 :             (
     245            6 :                 endpoint_int,
     246            6 :                 MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
     247            6 :             ),
     248            6 :             password_weight,
     249            6 :         );
     250            6 : 
     251            6 :         if !limit_not_exceeded {
     252            0 :             warn!(
     253              :                 enabled = self.rate_limiter_enabled,
     254            0 :                 "rate limiting authentication"
     255              :             );
     256            0 :             Metrics::get().proxy.requests_auth_rate_limits_total.inc();
     257            0 :             Metrics::get()
     258            0 :                 .proxy
     259            0 :                 .endpoints_auth_rate_limits
     260            0 :                 .get_metric()
     261            0 :                 .measure(endpoint);
     262            0 : 
     263            0 :             if self.rate_limiter_enabled {
     264            0 :                 return Err(auth::AuthError::too_many_connections());
     265            0 :             }
     266            6 :         }
     267              : 
     268            6 :         Ok(secret)
     269            6 :     }
     270              : }
     271              : 
     272              : /// True to its name, this function encapsulates our current auth trade-offs.
     273              : /// Here, we choose the appropriate auth flow based on circumstances.
     274              : ///
     275              : /// All authentication flows will emit an AuthenticationOk message if successful.
     276            6 : async fn auth_quirks(
     277            6 :     ctx: &mut RequestMonitoring,
     278            6 :     api: &impl console::Api,
     279            6 :     user_info: ComputeUserInfoMaybeEndpoint,
     280            6 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     281            6 :     allow_cleartext: bool,
     282            6 :     config: &'static AuthenticationConfig,
     283            6 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     284            6 : ) -> auth::Result<ComputeCredentials> {
     285              :     // If there's no project so far, that entails that client doesn't
     286              :     // support SNI or other means of passing the endpoint (project) name.
     287              :     // We now expect to see a very specific payload in the place of password.
     288            6 :     let (info, unauthenticated_password) = match user_info.try_into() {
     289            2 :         Err(info) => {
     290            2 :             let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
     291              : 
     292            2 :             ctx.set_endpoint_id(res.info.endpoint.clone());
     293            2 :             let password = match res.keys {
     294            2 :                 ComputeCredentialKeys::Password(p) => p,
     295            0 :                 _ => unreachable!("password hack should return a password"),
     296              :             };
     297            2 :             (res.info, Some(password))
     298              :         }
     299            4 :         Ok(info) => (info, None),
     300              :     };
     301              : 
     302            6 :     info!("fetching user's authentication info");
     303            6 :     let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
     304              : 
     305              :     // check allowed list
     306            6 :     if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
     307            0 :         return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
     308            6 :     }
     309            6 : 
     310            6 :     if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
     311            0 :         return Err(AuthError::too_many_connections());
     312            6 :     }
     313            6 :     let cached_secret = match maybe_secret {
     314            6 :         Some(secret) => secret,
     315            0 :         None => api.get_role_secret(ctx, &info).await?,
     316              :     };
     317            6 :     let (cached_entry, secret) = cached_secret.take_value();
     318              : 
     319            6 :     let secret = match secret {
     320            6 :         Some(secret) => config.check_rate_limit(
     321            6 :             ctx,
     322            6 :             config,
     323            6 :             secret,
     324            6 :             &info.endpoint,
     325            6 :             unauthenticated_password.is_some() || allow_cleartext,
     326            0 :         )?,
     327              :         None => {
     328              :             // If we don't have an authentication secret, we mock one to
     329              :             // prevent malicious probing (possible due to missing protocol steps).
     330              :             // This mocked secret will never lead to successful authentication.
     331            0 :             info!("authentication info not found, mocking it");
     332            0 :             AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     333              :         }
     334              :     };
     335              : 
     336            6 :     match authenticate_with_secret(
     337            6 :         ctx,
     338            6 :         secret,
     339            6 :         info,
     340            6 :         client,
     341            6 :         unauthenticated_password,
     342            6 :         allow_cleartext,
     343            6 :         config,
     344            6 :     )
     345           10 :     .await
     346              :     {
     347            6 :         Ok(keys) => Ok(keys),
     348            0 :         Err(e) => {
     349            0 :             if e.is_auth_failed() {
     350            0 :                 // The password could have been changed, so we invalidate the cache.
     351            0 :                 cached_entry.invalidate();
     352            0 :             }
     353            0 :             Err(e)
     354              :         }
     355              :     }
     356            6 : }
     357              : 
     358            6 : async fn authenticate_with_secret(
     359            6 :     ctx: &mut RequestMonitoring,
     360            6 :     secret: AuthSecret,
     361            6 :     info: ComputeUserInfo,
     362            6 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     363            6 :     unauthenticated_password: Option<Vec<u8>>,
     364            6 :     allow_cleartext: bool,
     365            6 :     config: &'static AuthenticationConfig,
     366            6 : ) -> auth::Result<ComputeCredentials> {
     367            6 :     if let Some(password) = unauthenticated_password {
     368            2 :         let ep = EndpointIdInt::from(&info.endpoint);
     369              : 
     370            2 :         let auth_outcome =
     371            2 :             validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
     372            2 :         let keys = match auth_outcome {
     373            2 :             crate::sasl::Outcome::Success(key) => key,
     374            0 :             crate::sasl::Outcome::Failure(reason) => {
     375            0 :                 info!("auth backend failed with an error: {reason}");
     376            0 :                 return Err(auth::AuthError::auth_failed(&*info.user));
     377              :             }
     378              :         };
     379              : 
     380              :         // we have authenticated the password
     381            2 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     382              : 
     383            2 :         return Ok(ComputeCredentials { info, keys });
     384            4 :     }
     385            4 : 
     386            4 :     // -- the remaining flows are self-authenticating --
     387            4 : 
     388            4 :     // Perform cleartext auth if we're allowed to do that.
     389            4 :     // Currently, we use it for websocket connections (latency).
     390            4 :     if allow_cleartext {
     391            2 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     392            4 :         return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
     393            2 :     }
     394            2 : 
     395            2 :     // Finally, proceed with the main auth flow (SCRAM-based).
     396            4 :     classic::authenticate(ctx, info, client, config, secret).await
     397            6 : }
     398              : 
     399              : impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
     400              :     /// Get compute endpoint name from the credentials.
     401            0 :     pub fn get_endpoint(&self) -> Option<EndpointId> {
     402            0 :         use BackendType::*;
     403            0 : 
     404            0 :         match self {
     405            0 :             Console(_, user_info) => user_info.endpoint_id.clone(),
     406            0 :             Link(_, _) => Some("link".into()),
     407              :         }
     408            0 :     }
     409              : 
     410              :     /// Get username from the credentials.
     411            0 :     pub fn get_user(&self) -> &str {
     412            0 :         use BackendType::*;
     413            0 : 
     414            0 :         match self {
     415            0 :             Console(_, user_info) => &user_info.user,
     416            0 :             Link(_, _) => "link",
     417              :         }
     418            0 :     }
     419              : 
     420              :     /// Authenticate the client via the requested backend, possibly using credentials.
     421            0 :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     422              :     pub async fn authenticate(
     423              :         self,
     424              :         ctx: &mut RequestMonitoring,
     425              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     426              :         allow_cleartext: bool,
     427              :         config: &'static AuthenticationConfig,
     428              :         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     429              :     ) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
     430              :         use BackendType::*;
     431              : 
     432              :         let res = match self {
     433              :             Console(api, user_info) => {
     434              :                 info!(
     435              :                     user = &*user_info.user,
     436              :                     project = user_info.endpoint(),
     437              :                     "performing authentication using the console"
     438              :                 );
     439              : 
     440              :                 let credentials = auth_quirks(
     441              :                     ctx,
     442              :                     &*api,
     443              :                     user_info,
     444              :                     client,
     445              :                     allow_cleartext,
     446              :                     config,
     447              :                     endpoint_rate_limiter,
     448              :                 )
     449              :                 .await?;
     450              :                 BackendType::Console(api, credentials)
     451              :             }
     452              :             // NOTE: this auth backend doesn't use client credentials.
     453              :             Link(url, _) => {
     454              :                 info!("performing link authentication");
     455              : 
     456              :                 let info = link::authenticate(ctx, &url, client).await?;
     457              : 
     458              :                 BackendType::Link(url, info)
     459              :             }
     460              :         };
     461              : 
     462              :         info!("user successfully authenticated");
     463              :         Ok(res)
     464              :     }
     465              : }
     466              : 
     467              : impl BackendType<'_, ComputeUserInfo, &()> {
     468            0 :     pub async fn get_role_secret(
     469            0 :         &self,
     470            0 :         ctx: &mut RequestMonitoring,
     471            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     472            0 :         use BackendType::*;
     473            0 :         match self {
     474            0 :             Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
     475            0 :             Link(_, _) => Ok(Cached::new_uncached(None)),
     476              :         }
     477            0 :     }
     478              : 
     479            0 :     pub async fn get_allowed_ips_and_secret(
     480            0 :         &self,
     481            0 :         ctx: &mut RequestMonitoring,
     482            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     483            0 :         use BackendType::*;
     484            0 :         match self {
     485            0 :             Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     486            0 :             Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
     487              :         }
     488            0 :     }
     489              : }
     490              : 
     491              : #[async_trait::async_trait]
     492              : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
     493            0 :     async fn wake_compute(
     494            0 :         &self,
     495            0 :         ctx: &mut RequestMonitoring,
     496            0 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     497            0 :         use BackendType::*;
     498            0 : 
     499            0 :         match self {
     500            0 :             Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     501            0 :             Link(_, info) => Ok(Cached::new_uncached(info.clone())),
     502            0 :         }
     503            0 :     }
     504              : 
     505            0 :     fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
     506            0 :         match self {
     507            0 :             BackendType::Console(_, creds) => Some(&creds.keys),
     508            0 :             BackendType::Link(_, _) => None,
     509              :         }
     510            0 :     }
     511              : }
     512              : 
     513              : #[async_trait::async_trait]
     514              : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
     515           26 :     async fn wake_compute(
     516           26 :         &self,
     517           26 :         ctx: &mut RequestMonitoring,
     518           26 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     519           26 :         use BackendType::*;
     520           26 : 
     521           26 :         match self {
     522           26 :             Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     523           26 :             Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
     524           26 :         }
     525           26 :     }
     526              : 
     527           12 :     fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
     528           12 :         match self {
     529           12 :             BackendType::Console(_, creds) => Some(&creds.keys),
     530            0 :             BackendType::Link(_, _) => None,
     531              :         }
     532           12 :     }
     533              : }
     534              : 
     535              : #[cfg(test)]
     536              : mod tests {
     537              :     use std::{net::IpAddr, sync::Arc, time::Duration};
     538              : 
     539              :     use bytes::BytesMut;
     540              :     use fallible_iterator::FallibleIterator;
     541              :     use once_cell::sync::Lazy;
     542              :     use postgres_protocol::{
     543              :         authentication::sasl::{ChannelBinding, ScramSha256},
     544              :         message::{backend::Message as PgMessage, frontend},
     545              :     };
     546              :     use provider::AuthSecret;
     547              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     548              : 
     549              :     use crate::{
     550              :         auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
     551              :         config::AuthenticationConfig,
     552              :         console::{
     553              :             self,
     554              :             provider::{self, CachedAllowedIps, CachedRoleSecret},
     555              :             CachedNodeInfo,
     556              :         },
     557              :         context::RequestMonitoring,
     558              :         proxy::NeonOptions,
     559              :         rate_limiter::{EndpointRateLimiter, RateBucketInfo},
     560              :         scram::{threadpool::ThreadPool, ServerSecret},
     561              :         stream::{PqStream, Stream},
     562              :     };
     563              : 
     564              :     use super::{auth_quirks, AuthRateLimiter};
     565              : 
     566              :     struct Auth {
     567              :         ips: Vec<IpPattern>,
     568              :         secret: AuthSecret,
     569              :     }
     570              : 
     571              :     impl console::Api for Auth {
     572            0 :         async fn get_role_secret(
     573            0 :             &self,
     574            0 :             _ctx: &mut RequestMonitoring,
     575            0 :             _user_info: &super::ComputeUserInfo,
     576            0 :         ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
     577            0 :             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
     578            0 :         }
     579              : 
     580            6 :         async fn get_allowed_ips_and_secret(
     581            6 :             &self,
     582            6 :             _ctx: &mut RequestMonitoring,
     583            6 :             _user_info: &super::ComputeUserInfo,
     584            6 :         ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
     585            6 :         {
     586            6 :             Ok((
     587            6 :                 CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
     588            6 :                 Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
     589            6 :             ))
     590            6 :         }
     591              : 
     592            0 :         async fn wake_compute(
     593            0 :             &self,
     594            0 :             _ctx: &mut RequestMonitoring,
     595            0 :             _user_info: &super::ComputeUserInfo,
     596            0 :         ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     597            0 :             unimplemented!()
     598              :         }
     599              :     }
     600              : 
     601            6 :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     602            6 :         thread_pool: ThreadPool::new(1),
     603            6 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     604            6 :         rate_limiter_enabled: true,
     605            6 :         rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
     606            6 :         rate_limit_ip_subnet: 64,
     607            6 :     });
     608              : 
     609           10 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     610              :         loop {
     611           14 :             r.read_buf(&mut *b).await.unwrap();
     612           14 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     613           10 :                 break m;
     614            4 :             }
     615              :         }
     616           10 :     }
     617              : 
     618              :     #[test]
     619            2 :     fn masked_ip() {
     620            2 :         let ip_a = IpAddr::V4([127, 0, 0, 1].into());
     621            2 :         let ip_b = IpAddr::V4([127, 0, 0, 2].into());
     622            2 :         let ip_c = IpAddr::V4([192, 168, 1, 101].into());
     623            2 :         let ip_d = IpAddr::V4([192, 168, 1, 102].into());
     624            2 :         let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
     625            2 :         let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
     626            2 : 
     627            2 :         assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
     628            2 :         assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
     629            2 :         assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
     630            2 :         assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
     631              : 
     632            2 :         assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
     633            2 :         assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
     634            2 :     }
     635              : 
     636              :     #[test]
     637            2 :     fn test_default_auth_rate_limit_set() {
     638            2 :         // these values used to exceed u32::MAX
     639            2 :         assert_eq!(
     640            2 :             RateBucketInfo::DEFAULT_AUTH_SET,
     641            2 :             [
     642            2 :                 RateBucketInfo {
     643            2 :                     interval: Duration::from_secs(1),
     644            2 :                     max_rpi: 1000 * 4096,
     645            2 :                 },
     646            2 :                 RateBucketInfo {
     647            2 :                     interval: Duration::from_secs(60),
     648            2 :                     max_rpi: 600 * 4096 * 60,
     649            2 :                 },
     650            2 :                 RateBucketInfo {
     651            2 :                     interval: Duration::from_secs(600),
     652            2 :                     max_rpi: 300 * 4096 * 600,
     653            2 :                 }
     654            2 :             ]
     655            2 :         );
     656              : 
     657            8 :         for x in RateBucketInfo::DEFAULT_AUTH_SET {
     658            6 :             let y = x.to_string().parse().unwrap();
     659            6 :             assert_eq!(x, y);
     660              :         }
     661            2 :     }
     662              : 
     663              :     #[tokio::test]
     664            2 :     async fn auth_quirks_scram() {
     665            2 :         let (mut client, server) = tokio::io::duplex(1024);
     666            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     667            2 : 
     668            2 :         let mut ctx = RequestMonitoring::test();
     669            2 :         let api = Auth {
     670            2 :             ips: vec![],
     671            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     672            2 :         };
     673            2 : 
     674            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     675            2 :             user: "conrad".into(),
     676            2 :             endpoint_id: Some("endpoint".into()),
     677            2 :             options: NeonOptions::default(),
     678            2 :         };
     679            2 : 
     680            2 :         let handle = tokio::spawn(async move {
     681            2 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     682            2 : 
     683            2 :             let mut read = BytesMut::new();
     684            2 : 
     685            2 :             // server should offer scram
     686            2 :             match read_message(&mut client, &mut read).await {
     687            2 :                 PgMessage::AuthenticationSasl(a) => {
     688            2 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     689            2 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     690            2 :                 }
     691            2 :                 _ => panic!("wrong message"),
     692            2 :             }
     693            2 : 
     694            2 :             // client sends client-first-message
     695            2 :             let mut write = BytesMut::new();
     696            2 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     697            2 :             client.write_all(&write).await.unwrap();
     698            2 : 
     699            2 :             // server response with server-first-message
     700            2 :             match read_message(&mut client, &mut read).await {
     701            2 :                 PgMessage::AuthenticationSaslContinue(a) => {
     702            6 :                     scram.update(a.data()).await.unwrap();
     703            2 :                 }
     704            2 :                 _ => panic!("wrong message"),
     705            2 :             }
     706            2 : 
     707            2 :             // client response with client-final-message
     708            2 :             write.clear();
     709            2 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     710            2 :             client.write_all(&write).await.unwrap();
     711            2 : 
     712            2 :             // server response with server-final-message
     713            2 :             match read_message(&mut client, &mut read).await {
     714            2 :                 PgMessage::AuthenticationSaslFinal(a) => {
     715            2 :                     scram.finish(a.data()).unwrap();
     716            2 :                 }
     717            2 :                 _ => panic!("wrong message"),
     718            2 :             }
     719            2 :         });
     720            2 :         let endpoint_rate_limiter =
     721            2 :             Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
     722            2 : 
     723            2 :         let _creds = auth_quirks(
     724            2 :             &mut ctx,
     725            2 :             &api,
     726            2 :             user_info,
     727            2 :             &mut stream,
     728            2 :             false,
     729            2 :             &CONFIG,
     730            2 :             endpoint_rate_limiter,
     731            2 :         )
     732            4 :         .await
     733            2 :         .unwrap();
     734            2 : 
     735            2 :         handle.await.unwrap();
     736            2 :     }
     737              : 
     738              :     #[tokio::test]
     739            2 :     async fn auth_quirks_cleartext() {
     740            2 :         let (mut client, server) = tokio::io::duplex(1024);
     741            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     742            2 : 
     743            2 :         let mut ctx = RequestMonitoring::test();
     744            2 :         let api = Auth {
     745            2 :             ips: vec![],
     746            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     747            2 :         };
     748            2 : 
     749            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     750            2 :             user: "conrad".into(),
     751            2 :             endpoint_id: Some("endpoint".into()),
     752            2 :             options: NeonOptions::default(),
     753            2 :         };
     754            2 : 
     755            2 :         let handle = tokio::spawn(async move {
     756            2 :             let mut read = BytesMut::new();
     757            2 :             let mut write = BytesMut::new();
     758            2 : 
     759            2 :             // server should offer cleartext
     760            2 :             match read_message(&mut client, &mut read).await {
     761            2 :                 PgMessage::AuthenticationCleartextPassword => {}
     762            2 :                 _ => panic!("wrong message"),
     763            2 :             }
     764            2 : 
     765            2 :             // client responds with password
     766            2 :             write.clear();
     767            2 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     768            2 :             client.write_all(&write).await.unwrap();
     769            2 :         });
     770            2 :         let endpoint_rate_limiter =
     771            2 :             Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
     772            2 : 
     773            2 :         let _creds = auth_quirks(
     774            2 :             &mut ctx,
     775            2 :             &api,
     776            2 :             user_info,
     777            2 :             &mut stream,
     778            2 :             true,
     779            2 :             &CONFIG,
     780            2 :             endpoint_rate_limiter,
     781            2 :         )
     782            4 :         .await
     783            2 :         .unwrap();
     784            2 : 
     785            2 :         handle.await.unwrap();
     786            2 :     }
     787              : 
     788              :     #[tokio::test]
     789            2 :     async fn auth_quirks_password_hack() {
     790            2 :         let (mut client, server) = tokio::io::duplex(1024);
     791            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     792            2 : 
     793            2 :         let mut ctx = RequestMonitoring::test();
     794            2 :         let api = Auth {
     795            2 :             ips: vec![],
     796            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     797            2 :         };
     798            2 : 
     799            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     800            2 :             user: "conrad".into(),
     801            2 :             endpoint_id: None,
     802            2 :             options: NeonOptions::default(),
     803            2 :         };
     804            2 : 
     805            2 :         let handle = tokio::spawn(async move {
     806            2 :             let mut read = BytesMut::new();
     807            2 : 
     808            2 :             // server should offer cleartext
     809            2 :             match read_message(&mut client, &mut read).await {
     810            2 :                 PgMessage::AuthenticationCleartextPassword => {}
     811            2 :                 _ => panic!("wrong message"),
     812            2 :             }
     813            2 : 
     814            2 :             // client responds with password
     815            2 :             let mut write = BytesMut::new();
     816            2 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     817            2 :                 .unwrap();
     818            2 :             client.write_all(&write).await.unwrap();
     819            2 :         });
     820            2 : 
     821            2 :         let endpoint_rate_limiter =
     822            2 :             Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
     823            2 : 
     824            2 :         let creds = auth_quirks(
     825            2 :             &mut ctx,
     826            2 :             &api,
     827            2 :             user_info,
     828            2 :             &mut stream,
     829            2 :             true,
     830            2 :             &CONFIG,
     831            2 :             endpoint_rate_limiter,
     832            2 :         )
     833            4 :         .await
     834            2 :         .unwrap();
     835            2 : 
     836            2 :         assert_eq!(creds.info.endpoint, "my-endpoint");
     837            2 : 
     838            2 :         handle.await.unwrap();
     839            2 :     }
     840              : }
        

Generated by: LCOV version 2.1-beta