LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - mod.rs (source / functions) Coverage Total Hit
Test: bb45db3982713bfd5bec075773079136e362195e.info Lines: 78.3 % 494 387
Test Date: 2024-12-11 15:53:32 Functions: 42.9 % 70 30

            Line data    Source code
       1              : mod classic;
       2              : mod console_redirect;
       3              : mod hacks;
       4              : pub mod jwt;
       5              : pub mod local;
       6              : 
       7              : use std::net::IpAddr;
       8              : use std::sync::Arc;
       9              : 
      10              : pub use console_redirect::ConsoleRedirectBackend;
      11              : pub(crate) use console_redirect::ConsoleRedirectError;
      12              : use ipnet::{Ipv4Net, Ipv6Net};
      13              : use local::LocalBackend;
      14              : use postgres_client::config::AuthKeys;
      15              : use tokio::io::{AsyncRead, AsyncWrite};
      16              : use tracing::{debug, info, warn};
      17              : 
      18              : use crate::auth::credentials::check_peer_addr_is_in_list;
      19              : use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
      20              : use crate::cache::Cached;
      21              : use crate::config::AuthenticationConfig;
      22              : use crate::context::RequestContext;
      23              : use crate::control_plane::client::ControlPlaneClient;
      24              : use crate::control_plane::errors::GetAuthInfoError;
      25              : use crate::control_plane::{
      26              :     self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
      27              : };
      28              : use crate::intern::EndpointIdInt;
      29              : use crate::metrics::Metrics;
      30              : use crate::proxy::connect_compute::ComputeConnectBackend;
      31              : use crate::proxy::NeonOptions;
      32              : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
      33              : use crate::stream::Stream;
      34              : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
      35              : use crate::{scram, stream};
      36              : 
      37              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      38              : pub enum MaybeOwned<'a, T> {
      39              :     Owned(T),
      40              :     Borrowed(&'a T),
      41              : }
      42              : 
      43              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      44              :     type Target = T;
      45              : 
      46           13 :     fn deref(&self) -> &Self::Target {
      47           13 :         match self {
      48           13 :             MaybeOwned::Owned(t) => t,
      49            0 :             MaybeOwned::Borrowed(t) => t,
      50              :         }
      51           13 :     }
      52              : }
      53              : 
      54              : /// This type serves two purposes:
      55              : ///
      56              : /// * When `T` is `()`, it's just a regular auth backend selector
      57              : ///   which we use in [`crate::config::ProxyConfig`].
      58              : ///
      59              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      60              : ///   this helps us provide the credentials only to those auth
      61              : ///   backends which require them for the authentication process.
      62              : pub enum Backend<'a, T> {
      63              :     /// Cloud API (V2).
      64              :     ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
      65              :     /// Local proxy uses configured auth credentials and does not wake compute
      66              :     Local(MaybeOwned<'a, LocalBackend>),
      67              : }
      68              : 
      69              : impl std::fmt::Display for Backend<'_, ()> {
      70            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      71            0 :         match self {
      72            0 :             Self::ControlPlane(api, ()) => match &**api {
      73            0 :                 ControlPlaneClient::ProxyV1(endpoint) => fmt
      74            0 :                     .debug_tuple("ControlPlane::ProxyV1")
      75            0 :                     .field(&endpoint.url())
      76            0 :                     .finish(),
      77            0 :                 ControlPlaneClient::Neon(endpoint) => fmt
      78            0 :                     .debug_tuple("ControlPlane::Neon")
      79            0 :                     .field(&endpoint.url())
      80            0 :                     .finish(),
      81              :                 #[cfg(any(test, feature = "testing"))]
      82            0 :                 ControlPlaneClient::PostgresMock(endpoint) => fmt
      83            0 :                     .debug_tuple("ControlPlane::PostgresMock")
      84            0 :                     .field(&endpoint.url())
      85            0 :                     .finish(),
      86              :                 #[cfg(test)]
      87            0 :                 ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
      88              :             },
      89            0 :             Self::Local(_) => fmt.debug_tuple("Local").finish(),
      90              :         }
      91            0 :     }
      92              : }
      93              : 
      94              : impl<T> Backend<'_, T> {
      95              :     /// Very similar to [`std::option::Option::as_ref`].
      96              :     /// This helps us pass structured config to async tasks.
      97            0 :     pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
      98            0 :         match self {
      99            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
     100            0 :             Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
     101              :         }
     102            0 :     }
     103              : }
     104              : 
     105              : impl<'a, T> Backend<'a, T> {
     106              :     /// Very similar to [`std::option::Option::map`].
     107              :     /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
     108              :     /// a function to a contained value.
     109            0 :     pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
     110            0 :         match self {
     111            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
     112            0 :             Self::Local(l) => Backend::Local(l),
     113              :         }
     114            0 :     }
     115              : }
     116              : impl<'a, T, E> Backend<'a, Result<T, E>> {
     117              :     /// Very similar to [`std::option::Option::transpose`].
     118              :     /// This is most useful for error handling.
     119            0 :     pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
     120            0 :         match self {
     121            0 :             Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
     122            0 :             Self::Local(l) => Ok(Backend::Local(l)),
     123              :         }
     124            0 :     }
     125              : }
     126              : 
     127              : pub(crate) struct ComputeCredentials {
     128              :     pub(crate) info: ComputeUserInfo,
     129              :     pub(crate) keys: ComputeCredentialKeys,
     130              : }
     131              : 
     132              : #[derive(Debug, Clone)]
     133              : pub(crate) struct ComputeUserInfoNoEndpoint {
     134              :     pub(crate) user: RoleName,
     135              :     pub(crate) options: NeonOptions,
     136              : }
     137              : 
     138              : #[derive(Debug, Clone)]
     139              : pub(crate) struct ComputeUserInfo {
     140              :     pub(crate) endpoint: EndpointId,
     141              :     pub(crate) user: RoleName,
     142              :     pub(crate) options: NeonOptions,
     143              : }
     144              : 
     145              : impl ComputeUserInfo {
     146            2 :     pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
     147            2 :         self.options.get_cache_key(&self.endpoint)
     148            2 :     }
     149              : }
     150              : 
     151              : #[cfg_attr(test, derive(Debug))]
     152              : pub(crate) enum ComputeCredentialKeys {
     153              :     #[cfg(any(test, feature = "testing"))]
     154              :     Password(Vec<u8>),
     155              :     AuthKeys(AuthKeys),
     156              :     JwtPayload(Vec<u8>),
     157              :     None,
     158              : }
     159              : 
     160              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     161              :     // user name
     162              :     type Error = ComputeUserInfoNoEndpoint;
     163              : 
     164            3 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     165            3 :         match user_info.endpoint_id {
     166            1 :             None => Err(ComputeUserInfoNoEndpoint {
     167            1 :                 user: user_info.user,
     168            1 :                 options: user_info.options,
     169            1 :             }),
     170            2 :             Some(endpoint) => Ok(ComputeUserInfo {
     171            2 :                 endpoint,
     172            2 :                 user: user_info.user,
     173            2 :                 options: user_info.options,
     174            2 :             }),
     175              :         }
     176            3 :     }
     177              : }
     178              : 
     179              : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
     180              : pub struct MaskedIp(IpAddr);
     181              : 
     182              : impl MaskedIp {
     183           15 :     fn new(value: IpAddr, prefix: u8) -> Self {
     184           15 :         match value {
     185           11 :             IpAddr::V4(v4) => Self(IpAddr::V4(
     186           11 :                 Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
     187           11 :             )),
     188            4 :             IpAddr::V6(v6) => Self(IpAddr::V6(
     189            4 :                 Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
     190            4 :             )),
     191              :         }
     192           15 :     }
     193              : }
     194              : 
     195              : // This can't be just per IP because that would limit some PaaS that share IP addresses
     196              : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
     197              : 
     198              : impl AuthenticationConfig {
     199            3 :     pub(crate) fn check_rate_limit(
     200            3 :         &self,
     201            3 :         ctx: &RequestContext,
     202            3 :         secret: AuthSecret,
     203            3 :         endpoint: &EndpointId,
     204            3 :         is_cleartext: bool,
     205            3 :     ) -> auth::Result<AuthSecret> {
     206            3 :         // we have validated the endpoint exists, so let's intern it.
     207            3 :         let endpoint_int = EndpointIdInt::from(endpoint.normalize());
     208              : 
     209              :         // only count the full hash count if password hack or websocket flow.
     210              :         // in other words, if proxy needs to run the hashing
     211            3 :         let password_weight = if is_cleartext {
     212            2 :             match &secret {
     213              :                 #[cfg(any(test, feature = "testing"))]
     214            0 :                 AuthSecret::Md5(_) => 1,
     215            2 :                 AuthSecret::Scram(s) => s.iterations + 1,
     216              :             }
     217              :         } else {
     218              :             // validating scram takes just 1 hmac_sha_256 operation.
     219            1 :             1
     220              :         };
     221              : 
     222            3 :         let limit_not_exceeded = self.rate_limiter.check(
     223            3 :             (
     224            3 :                 endpoint_int,
     225            3 :                 MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
     226            3 :             ),
     227            3 :             password_weight,
     228            3 :         );
     229            3 : 
     230            3 :         if !limit_not_exceeded {
     231            0 :             warn!(
     232              :                 enabled = self.rate_limiter_enabled,
     233            0 :                 "rate limiting authentication"
     234              :             );
     235            0 :             Metrics::get().proxy.requests_auth_rate_limits_total.inc();
     236            0 :             Metrics::get()
     237            0 :                 .proxy
     238            0 :                 .endpoints_auth_rate_limits
     239            0 :                 .get_metric()
     240            0 :                 .measure(endpoint);
     241            0 : 
     242            0 :             if self.rate_limiter_enabled {
     243            0 :                 return Err(auth::AuthError::too_many_connections());
     244            0 :             }
     245            3 :         }
     246              : 
     247            3 :         Ok(secret)
     248            3 :     }
     249              : }
     250              : 
     251              : /// True to its name, this function encapsulates our current auth trade-offs.
     252              : /// Here, we choose the appropriate auth flow based on circumstances.
     253              : ///
     254              : /// All authentication flows will emit an AuthenticationOk message if successful.
     255            3 : async fn auth_quirks(
     256            3 :     ctx: &RequestContext,
     257            3 :     api: &impl control_plane::ControlPlaneApi,
     258            3 :     user_info: ComputeUserInfoMaybeEndpoint,
     259            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     260            3 :     allow_cleartext: bool,
     261            3 :     config: &'static AuthenticationConfig,
     262            3 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     263            3 : ) -> auth::Result<ComputeCredentials> {
     264              :     // If there's no project so far, that entails that client doesn't
     265              :     // support SNI or other means of passing the endpoint (project) name.
     266              :     // We now expect to see a very specific payload in the place of password.
     267            3 :     let (info, unauthenticated_password) = match user_info.try_into() {
     268            1 :         Err(info) => {
     269            1 :             let (info, password) =
     270            1 :                 hacks::password_hack_no_authentication(ctx, info, client).await?;
     271            1 :             ctx.set_endpoint_id(info.endpoint.clone());
     272            1 :             (info, Some(password))
     273              :         }
     274            2 :         Ok(info) => (info, None),
     275              :     };
     276              : 
     277            3 :     debug!("fetching user's authentication info");
     278            3 :     let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
     279              : 
     280              :     // check allowed list
     281            3 :     if config.ip_allowlist_check_enabled
     282            3 :         && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
     283              :     {
     284            0 :         return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
     285            3 :     }
     286            3 : 
     287            3 :     if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
     288            0 :         return Err(AuthError::too_many_connections());
     289            3 :     }
     290            3 :     let cached_secret = match maybe_secret {
     291            3 :         Some(secret) => secret,
     292            0 :         None => api.get_role_secret(ctx, &info).await?,
     293              :     };
     294            3 :     let (cached_entry, secret) = cached_secret.take_value();
     295              : 
     296            3 :     let secret = if let Some(secret) = secret {
     297            3 :         config.check_rate_limit(
     298            3 :             ctx,
     299            3 :             secret,
     300            3 :             &info.endpoint,
     301            3 :             unauthenticated_password.is_some() || allow_cleartext,
     302            0 :         )?
     303              :     } else {
     304              :         // If we don't have an authentication secret, we mock one to
     305              :         // prevent malicious probing (possible due to missing protocol steps).
     306              :         // This mocked secret will never lead to successful authentication.
     307            0 :         info!("authentication info not found, mocking it");
     308            0 :         AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     309              :     };
     310              : 
     311            3 :     match authenticate_with_secret(
     312            3 :         ctx,
     313            3 :         secret,
     314            3 :         info,
     315            3 :         client,
     316            3 :         unauthenticated_password,
     317            3 :         allow_cleartext,
     318            3 :         config,
     319            3 :     )
     320            3 :     .await
     321              :     {
     322            3 :         Ok(keys) => Ok(keys),
     323            0 :         Err(e) => {
     324            0 :             if e.is_password_failed() {
     325            0 :                 // The password could have been changed, so we invalidate the cache.
     326            0 :                 cached_entry.invalidate();
     327            0 :             }
     328            0 :             Err(e)
     329              :         }
     330              :     }
     331            3 : }
     332              : 
     333            3 : async fn authenticate_with_secret(
     334            3 :     ctx: &RequestContext,
     335            3 :     secret: AuthSecret,
     336            3 :     info: ComputeUserInfo,
     337            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     338            3 :     unauthenticated_password: Option<Vec<u8>>,
     339            3 :     allow_cleartext: bool,
     340            3 :     config: &'static AuthenticationConfig,
     341            3 : ) -> auth::Result<ComputeCredentials> {
     342            3 :     if let Some(password) = unauthenticated_password {
     343            1 :         let ep = EndpointIdInt::from(&info.endpoint);
     344              : 
     345            1 :         let auth_outcome =
     346            1 :             validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
     347            1 :         let keys = match auth_outcome {
     348            1 :             crate::sasl::Outcome::Success(key) => key,
     349            0 :             crate::sasl::Outcome::Failure(reason) => {
     350            0 :                 info!("auth backend failed with an error: {reason}");
     351            0 :                 return Err(auth::AuthError::password_failed(&*info.user));
     352              :             }
     353              :         };
     354              : 
     355              :         // we have authenticated the password
     356            1 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     357              : 
     358            1 :         return Ok(ComputeCredentials { info, keys });
     359            2 :     }
     360            2 : 
     361            2 :     // -- the remaining flows are self-authenticating --
     362            2 : 
     363            2 :     // Perform cleartext auth if we're allowed to do that.
     364            2 :     // Currently, we use it for websocket connections (latency).
     365            2 :     if allow_cleartext {
     366            1 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     367            1 :         return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
     368            1 :     }
     369            1 : 
     370            1 :     // Finally, proceed with the main auth flow (SCRAM-based).
     371            1 :     classic::authenticate(ctx, info, client, config, secret).await
     372            3 : }
     373              : 
     374              : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
     375              :     /// Get username from the credentials.
     376            0 :     pub(crate) fn get_user(&self) -> &str {
     377            0 :         match self {
     378            0 :             Self::ControlPlane(_, user_info) => &user_info.user,
     379            0 :             Self::Local(_) => "local",
     380              :         }
     381            0 :     }
     382              : 
     383              :     /// Authenticate the client via the requested backend, possibly using credentials.
     384            0 :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     385              :     pub(crate) async fn authenticate(
     386              :         self,
     387              :         ctx: &RequestContext,
     388              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     389              :         allow_cleartext: bool,
     390              :         config: &'static AuthenticationConfig,
     391              :         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     392              :     ) -> auth::Result<Backend<'a, ComputeCredentials>> {
     393              :         let res = match self {
     394              :             Self::ControlPlane(api, user_info) => {
     395              :                 debug!(
     396              :                     user = &*user_info.user,
     397              :                     project = user_info.endpoint(),
     398              :                     "performing authentication using the console"
     399              :                 );
     400              : 
     401              :                 let credentials = auth_quirks(
     402              :                     ctx,
     403              :                     &*api,
     404              :                     user_info,
     405              :                     client,
     406              :                     allow_cleartext,
     407              :                     config,
     408              :                     endpoint_rate_limiter,
     409              :                 )
     410              :                 .await?;
     411              :                 Backend::ControlPlane(api, credentials)
     412              :             }
     413              :             Self::Local(_) => {
     414              :                 return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
     415              :             }
     416              :         };
     417              : 
     418              :         // TODO: replace with some metric
     419              :         info!("user successfully authenticated");
     420              :         Ok(res)
     421              :     }
     422              : }
     423              : 
     424              : impl Backend<'_, ComputeUserInfo> {
     425            0 :     pub(crate) async fn get_role_secret(
     426            0 :         &self,
     427            0 :         ctx: &RequestContext,
     428            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     429            0 :         match self {
     430            0 :             Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
     431            0 :             Self::Local(_) => Ok(Cached::new_uncached(None)),
     432              :         }
     433            0 :     }
     434              : 
     435            0 :     pub(crate) async fn get_allowed_ips_and_secret(
     436            0 :         &self,
     437            0 :         ctx: &RequestContext,
     438            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     439            0 :         match self {
     440            0 :             Self::ControlPlane(api, user_info) => {
     441            0 :                 api.get_allowed_ips_and_secret(ctx, user_info).await
     442              :             }
     443            0 :             Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
     444              :         }
     445            0 :     }
     446              : }
     447              : 
     448              : #[async_trait::async_trait]
     449              : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
     450           13 :     async fn wake_compute(
     451           13 :         &self,
     452           13 :         ctx: &RequestContext,
     453           13 :     ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     454           13 :         match self {
     455           13 :             Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
     456            0 :             Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
     457              :         }
     458           26 :     }
     459              : 
     460            6 :     fn get_keys(&self) -> &ComputeCredentialKeys {
     461            6 :         match self {
     462            6 :             Self::ControlPlane(_, creds) => &creds.keys,
     463            0 :             Self::Local(_) => &ComputeCredentialKeys::None,
     464              :         }
     465            6 :     }
     466              : }
     467              : 
     468              : #[cfg(test)]
     469              : mod tests {
     470              :     use std::net::IpAddr;
     471              :     use std::sync::Arc;
     472              :     use std::time::Duration;
     473              : 
     474              :     use bytes::BytesMut;
     475              :     use control_plane::AuthSecret;
     476              :     use fallible_iterator::FallibleIterator;
     477              :     use once_cell::sync::Lazy;
     478              :     use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
     479              :     use postgres_protocol::message::backend::Message as PgMessage;
     480              :     use postgres_protocol::message::frontend;
     481              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     482              : 
     483              :     use super::jwt::JwkCache;
     484              :     use super::{auth_quirks, AuthRateLimiter};
     485              :     use crate::auth::backend::MaskedIp;
     486              :     use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
     487              :     use crate::config::AuthenticationConfig;
     488              :     use crate::context::RequestContext;
     489              :     use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
     490              :     use crate::proxy::NeonOptions;
     491              :     use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
     492              :     use crate::scram::threadpool::ThreadPool;
     493              :     use crate::scram::ServerSecret;
     494              :     use crate::stream::{PqStream, Stream};
     495              : 
     496              :     struct Auth {
     497              :         ips: Vec<IpPattern>,
     498              :         secret: AuthSecret,
     499              :     }
     500              : 
     501              :     impl control_plane::ControlPlaneApi for Auth {
     502            0 :         async fn get_role_secret(
     503            0 :             &self,
     504            0 :             _ctx: &RequestContext,
     505            0 :             _user_info: &super::ComputeUserInfo,
     506            0 :         ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
     507            0 :             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
     508            0 :         }
     509              : 
     510            3 :         async fn get_allowed_ips_and_secret(
     511            3 :             &self,
     512            3 :             _ctx: &RequestContext,
     513            3 :             _user_info: &super::ComputeUserInfo,
     514            3 :         ) -> Result<
     515            3 :             (CachedAllowedIps, Option<CachedRoleSecret>),
     516            3 :             control_plane::errors::GetAuthInfoError,
     517            3 :         > {
     518            3 :             Ok((
     519            3 :                 CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
     520            3 :                 Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
     521            3 :             ))
     522            3 :         }
     523              : 
     524            0 :         async fn get_endpoint_jwks(
     525            0 :             &self,
     526            0 :             _ctx: &RequestContext,
     527            0 :             _endpoint: crate::types::EndpointId,
     528            0 :         ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
     529            0 :         {
     530            0 :             unimplemented!()
     531              :         }
     532              : 
     533            0 :         async fn wake_compute(
     534            0 :             &self,
     535            0 :             _ctx: &RequestContext,
     536            0 :             _user_info: &super::ComputeUserInfo,
     537            0 :         ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     538            0 :             unimplemented!()
     539              :         }
     540              :     }
     541              : 
     542            3 :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     543            3 :         jwks_cache: JwkCache::default(),
     544            3 :         thread_pool: ThreadPool::new(1),
     545            3 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     546            3 :         rate_limiter_enabled: true,
     547            3 :         rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
     548            3 :         rate_limit_ip_subnet: 64,
     549            3 :         ip_allowlist_check_enabled: true,
     550            3 :         is_auth_broker: false,
     551            3 :         accept_jwts: false,
     552            3 :         console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
     553            3 :     });
     554              : 
     555            5 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     556              :         loop {
     557            7 :             r.read_buf(&mut *b).await.unwrap();
     558            7 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     559            5 :                 break m;
     560            2 :             }
     561              :         }
     562            5 :     }
     563              : 
     564              :     #[test]
     565            1 :     fn masked_ip() {
     566            1 :         let ip_a = IpAddr::V4([127, 0, 0, 1].into());
     567            1 :         let ip_b = IpAddr::V4([127, 0, 0, 2].into());
     568            1 :         let ip_c = IpAddr::V4([192, 168, 1, 101].into());
     569            1 :         let ip_d = IpAddr::V4([192, 168, 1, 102].into());
     570            1 :         let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
     571            1 :         let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
     572            1 : 
     573            1 :         assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
     574            1 :         assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
     575            1 :         assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
     576            1 :         assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
     577              : 
     578            1 :         assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
     579            1 :         assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
     580            1 :     }
     581              : 
     582              :     #[test]
     583            1 :     fn test_default_auth_rate_limit_set() {
     584            1 :         // these values used to exceed u32::MAX
     585            1 :         assert_eq!(
     586            1 :             RateBucketInfo::DEFAULT_AUTH_SET,
     587            1 :             [
     588            1 :                 RateBucketInfo {
     589            1 :                     interval: Duration::from_secs(1),
     590            1 :                     max_rpi: 1000 * 4096,
     591            1 :                 },
     592            1 :                 RateBucketInfo {
     593            1 :                     interval: Duration::from_secs(60),
     594            1 :                     max_rpi: 600 * 4096 * 60,
     595            1 :                 },
     596            1 :                 RateBucketInfo {
     597            1 :                     interval: Duration::from_secs(600),
     598            1 :                     max_rpi: 300 * 4096 * 600,
     599            1 :                 }
     600            1 :             ]
     601            1 :         );
     602              : 
     603            4 :         for x in RateBucketInfo::DEFAULT_AUTH_SET {
     604            3 :             let y = x.to_string().parse().unwrap();
     605            3 :             assert_eq!(x, y);
     606              :         }
     607            1 :     }
     608              : 
     609              :     #[tokio::test]
     610            1 :     async fn auth_quirks_scram() {
     611            1 :         let (mut client, server) = tokio::io::duplex(1024);
     612            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     613            1 : 
     614            1 :         let ctx = RequestContext::test();
     615            1 :         let api = Auth {
     616            1 :             ips: vec![],
     617            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     618            1 :         };
     619            1 : 
     620            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     621            1 :             user: "conrad".into(),
     622            1 :             endpoint_id: Some("endpoint".into()),
     623            1 :             options: NeonOptions::default(),
     624            1 :         };
     625            1 : 
     626            1 :         let handle = tokio::spawn(async move {
     627            1 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     628            1 : 
     629            1 :             let mut read = BytesMut::new();
     630            1 : 
     631            1 :             // server should offer scram
     632            1 :             match read_message(&mut client, &mut read).await {
     633            1 :                 PgMessage::AuthenticationSasl(a) => {
     634            1 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     635            1 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     636            1 :                 }
     637            1 :                 _ => panic!("wrong message"),
     638            1 :             }
     639            1 : 
     640            1 :             // client sends client-first-message
     641            1 :             let mut write = BytesMut::new();
     642            1 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     643            1 :             client.write_all(&write).await.unwrap();
     644            1 : 
     645            1 :             // server response with server-first-message
     646            1 :             match read_message(&mut client, &mut read).await {
     647            1 :                 PgMessage::AuthenticationSaslContinue(a) => {
     648            1 :                     scram.update(a.data()).await.unwrap();
     649            1 :                 }
     650            1 :                 _ => panic!("wrong message"),
     651            1 :             }
     652            1 : 
     653            1 :             // client response with client-final-message
     654            1 :             write.clear();
     655            1 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     656            1 :             client.write_all(&write).await.unwrap();
     657            1 : 
     658            1 :             // server response with server-final-message
     659            1 :             match read_message(&mut client, &mut read).await {
     660            1 :                 PgMessage::AuthenticationSaslFinal(a) => {
     661            1 :                     scram.finish(a.data()).unwrap();
     662            1 :                 }
     663            1 :                 _ => panic!("wrong message"),
     664            1 :             }
     665            1 :         });
     666            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     667            1 :             EndpointRateLimiter::DEFAULT,
     668            1 :             64,
     669            1 :         ));
     670            1 : 
     671            1 :         let _creds = auth_quirks(
     672            1 :             &ctx,
     673            1 :             &api,
     674            1 :             user_info,
     675            1 :             &mut stream,
     676            1 :             false,
     677            1 :             &CONFIG,
     678            1 :             endpoint_rate_limiter,
     679            1 :         )
     680            1 :         .await
     681            1 :         .unwrap();
     682            1 : 
     683            1 :         handle.await.unwrap();
     684            1 :     }
     685              : 
     686              :     #[tokio::test]
     687            1 :     async fn auth_quirks_cleartext() {
     688            1 :         let (mut client, server) = tokio::io::duplex(1024);
     689            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     690            1 : 
     691            1 :         let ctx = RequestContext::test();
     692            1 :         let api = Auth {
     693            1 :             ips: vec![],
     694            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     695            1 :         };
     696            1 : 
     697            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     698            1 :             user: "conrad".into(),
     699            1 :             endpoint_id: Some("endpoint".into()),
     700            1 :             options: NeonOptions::default(),
     701            1 :         };
     702            1 : 
     703            1 :         let handle = tokio::spawn(async move {
     704            1 :             let mut read = BytesMut::new();
     705            1 :             let mut write = BytesMut::new();
     706            1 : 
     707            1 :             // server should offer cleartext
     708            1 :             match read_message(&mut client, &mut read).await {
     709            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     710            1 :                 _ => panic!("wrong message"),
     711            1 :             }
     712            1 : 
     713            1 :             // client responds with password
     714            1 :             write.clear();
     715            1 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     716            1 :             client.write_all(&write).await.unwrap();
     717            1 :         });
     718            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     719            1 :             EndpointRateLimiter::DEFAULT,
     720            1 :             64,
     721            1 :         ));
     722            1 : 
     723            1 :         let _creds = auth_quirks(
     724            1 :             &ctx,
     725            1 :             &api,
     726            1 :             user_info,
     727            1 :             &mut stream,
     728            1 :             true,
     729            1 :             &CONFIG,
     730            1 :             endpoint_rate_limiter,
     731            1 :         )
     732            1 :         .await
     733            1 :         .unwrap();
     734            1 : 
     735            1 :         handle.await.unwrap();
     736            1 :     }
     737              : 
     738              :     #[tokio::test]
     739            1 :     async fn auth_quirks_password_hack() {
     740            1 :         let (mut client, server) = tokio::io::duplex(1024);
     741            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     742            1 : 
     743            1 :         let ctx = RequestContext::test();
     744            1 :         let api = Auth {
     745            1 :             ips: vec![],
     746            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     747            1 :         };
     748            1 : 
     749            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     750            1 :             user: "conrad".into(),
     751            1 :             endpoint_id: None,
     752            1 :             options: NeonOptions::default(),
     753            1 :         };
     754            1 : 
     755            1 :         let handle = tokio::spawn(async move {
     756            1 :             let mut read = BytesMut::new();
     757            1 : 
     758            1 :             // server should offer cleartext
     759            1 :             match read_message(&mut client, &mut read).await {
     760            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     761            1 :                 _ => panic!("wrong message"),
     762            1 :             }
     763            1 : 
     764            1 :             // client responds with password
     765            1 :             let mut write = BytesMut::new();
     766            1 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     767            1 :                 .unwrap();
     768            1 :             client.write_all(&write).await.unwrap();
     769            1 :         });
     770            1 : 
     771            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     772            1 :             EndpointRateLimiter::DEFAULT,
     773            1 :             64,
     774            1 :         ));
     775            1 : 
     776            1 :         let creds = auth_quirks(
     777            1 :             &ctx,
     778            1 :             &api,
     779            1 :             user_info,
     780            1 :             &mut stream,
     781            1 :             true,
     782            1 :             &CONFIG,
     783            1 :             endpoint_rate_limiter,
     784            1 :         )
     785            1 :         .await
     786            1 :         .unwrap();
     787            1 : 
     788            1 :         assert_eq!(creds.info.endpoint, "my-endpoint");
     789            1 : 
     790            1 :         handle.await.unwrap();
     791            1 :     }
     792              : }
        

Generated by: LCOV version 2.1-beta