LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - mod.rs (source / functions) Coverage Total Hit
Test: 472031e0b71f3195f7f21b1f2b20de09fd07bb56.info Lines: 72.7 % 564 410
Test Date: 2025-05-26 10:37:33 Functions: 45.9 % 74 34

            Line data    Source code
       1              : mod classic;
       2              : mod console_redirect;
       3              : mod hacks;
       4              : pub mod jwt;
       5              : pub mod local;
       6              : 
       7              : use std::net::IpAddr;
       8              : use std::sync::Arc;
       9              : 
      10              : pub use console_redirect::ConsoleRedirectBackend;
      11              : pub(crate) use console_redirect::ConsoleRedirectError;
      12              : use ipnet::{Ipv4Net, Ipv6Net};
      13              : use local::LocalBackend;
      14              : use postgres_client::config::AuthKeys;
      15              : use serde::{Deserialize, Serialize};
      16              : use tokio::io::{AsyncRead, AsyncWrite};
      17              : use tracing::{debug, info, warn};
      18              : 
      19              : use crate::auth::credentials::check_peer_addr_is_in_list;
      20              : use crate::auth::{
      21              :     self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
      22              : };
      23              : use crate::cache::Cached;
      24              : use crate::config::AuthenticationConfig;
      25              : use crate::context::RequestContext;
      26              : use crate::control_plane::client::ControlPlaneClient;
      27              : use crate::control_plane::errors::GetAuthInfoError;
      28              : use crate::control_plane::{
      29              :     self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
      30              :     CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
      31              : };
      32              : use crate::intern::EndpointIdInt;
      33              : use crate::metrics::Metrics;
      34              : use crate::protocol2::ConnectionInfoExtra;
      35              : use crate::proxy::NeonOptions;
      36              : use crate::proxy::connect_compute::ComputeConnectBackend;
      37              : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
      38              : use crate::stream::Stream;
      39              : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
      40              : use crate::{scram, stream};
      41              : 
      42              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      43              : pub enum MaybeOwned<'a, T> {
      44              :     Owned(T),
      45              :     Borrowed(&'a T),
      46              : }
      47              : 
      48              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      49              :     type Target = T;
      50              : 
      51           19 :     fn deref(&self) -> &Self::Target {
      52           19 :         match self {
      53           19 :             MaybeOwned::Owned(t) => t,
      54            0 :             MaybeOwned::Borrowed(t) => t,
      55              :         }
      56           19 :     }
      57              : }
      58              : 
      59              : /// This type serves two purposes:
      60              : ///
      61              : /// * When `T` is `()`, it's just a regular auth backend selector
      62              : ///   which we use in [`crate::config::ProxyConfig`].
      63              : ///
      64              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      65              : ///   this helps us provide the credentials only to those auth
      66              : ///   backends which require them for the authentication process.
      67              : pub enum Backend<'a, T> {
      68              :     /// Cloud API (V2).
      69              :     ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
      70              :     /// Local proxy uses configured auth credentials and does not wake compute
      71              :     Local(MaybeOwned<'a, LocalBackend>),
      72              : }
      73              : 
      74              : impl std::fmt::Display for Backend<'_, ()> {
      75            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      76            0 :         match self {
      77            0 :             Self::ControlPlane(api, ()) => match &**api {
      78            0 :                 ControlPlaneClient::ProxyV1(endpoint) => fmt
      79            0 :                     .debug_tuple("ControlPlane::ProxyV1")
      80            0 :                     .field(&endpoint.url())
      81            0 :                     .finish(),
      82              :                 #[cfg(any(test, feature = "testing"))]
      83            0 :                 ControlPlaneClient::PostgresMock(endpoint) => {
      84            0 :                     let url = endpoint.url();
      85            0 :                     match url::Url::parse(url) {
      86            0 :                         Ok(mut url) => {
      87            0 :                             let _ = url.set_password(Some("_redacted_"));
      88            0 :                             let url = url.as_str();
      89            0 :                             fmt.debug_tuple("ControlPlane::PostgresMock")
      90            0 :                                 .field(&url)
      91            0 :                                 .finish()
      92              :                         }
      93            0 :                         Err(_) => fmt
      94            0 :                             .debug_tuple("ControlPlane::PostgresMock")
      95            0 :                             .field(&url)
      96            0 :                             .finish(),
      97              :                     }
      98              :                 }
      99              :                 #[cfg(test)]
     100            0 :                 ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
     101              :             },
     102            0 :             Self::Local(_) => fmt.debug_tuple("Local").finish(),
     103              :         }
     104            0 :     }
     105              : }
     106              : 
     107              : impl<T> Backend<'_, T> {
     108              :     /// Very similar to [`std::option::Option::as_ref`].
     109              :     /// This helps us pass structured config to async tasks.
     110            0 :     pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
     111            0 :         match self {
     112            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
     113            0 :             Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
     114              :         }
     115            0 :     }
     116              : 
     117            0 :     pub(crate) fn get_api(&self) -> &ControlPlaneClient {
     118            0 :         match self {
     119            0 :             Self::ControlPlane(api, _) => api,
     120            0 :             Self::Local(_) => panic!("Local backend has no API"),
     121              :         }
     122            0 :     }
     123              : 
     124            0 :     pub(crate) fn is_local_proxy(&self) -> bool {
     125            0 :         matches!(self, Self::Local(_))
     126            0 :     }
     127              : }
     128              : 
     129              : impl<'a, T> Backend<'a, T> {
     130              :     /// Very similar to [`std::option::Option::map`].
     131              :     /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
     132              :     /// a function to a contained value.
     133            0 :     pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
     134            0 :         match self {
     135            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
     136            0 :             Self::Local(l) => Backend::Local(l),
     137              :         }
     138            0 :     }
     139              : }
     140              : impl<'a, T, E> Backend<'a, Result<T, E>> {
     141              :     /// Very similar to [`std::option::Option::transpose`].
     142              :     /// This is most useful for error handling.
     143            0 :     pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
     144            0 :         match self {
     145            0 :             Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
     146            0 :             Self::Local(l) => Ok(Backend::Local(l)),
     147              :         }
     148            0 :     }
     149              : }
     150              : 
     151              : pub(crate) struct ComputeCredentials {
     152              :     pub(crate) info: ComputeUserInfo,
     153              :     pub(crate) keys: ComputeCredentialKeys,
     154              : }
     155              : 
     156              : #[derive(Debug, Clone)]
     157              : pub(crate) struct ComputeUserInfoNoEndpoint {
     158              :     pub(crate) user: RoleName,
     159              :     pub(crate) options: NeonOptions,
     160              : }
     161              : 
     162            0 : #[derive(Debug, Clone, Default, Serialize, Deserialize)]
     163              : pub(crate) struct ComputeUserInfo {
     164              :     pub(crate) endpoint: EndpointId,
     165              :     pub(crate) user: RoleName,
     166              :     pub(crate) options: NeonOptions,
     167              : }
     168              : 
     169              : impl ComputeUserInfo {
     170            2 :     pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
     171            2 :         self.options.get_cache_key(&self.endpoint)
     172            2 :     }
     173              : }
     174              : 
     175              : #[cfg_attr(test, derive(Debug))]
     176              : pub(crate) enum ComputeCredentialKeys {
     177              :     #[cfg(any(test, feature = "testing"))]
     178              :     Password(Vec<u8>),
     179              :     AuthKeys(AuthKeys),
     180              :     JwtPayload(Vec<u8>),
     181              :     None,
     182              : }
     183              : 
     184              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     185              :     // user name
     186              :     type Error = ComputeUserInfoNoEndpoint;
     187              : 
     188            3 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     189            3 :         match user_info.endpoint_id {
     190            1 :             None => Err(ComputeUserInfoNoEndpoint {
     191            1 :                 user: user_info.user,
     192            1 :                 options: user_info.options,
     193            1 :             }),
     194            2 :             Some(endpoint) => Ok(ComputeUserInfo {
     195            2 :                 endpoint,
     196            2 :                 user: user_info.user,
     197            2 :                 options: user_info.options,
     198            2 :             }),
     199              :         }
     200            3 :     }
     201              : }
     202              : 
     203              : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
     204              : pub struct MaskedIp(IpAddr);
     205              : 
     206              : impl MaskedIp {
     207           15 :     fn new(value: IpAddr, prefix: u8) -> Self {
     208           15 :         match value {
     209           11 :             IpAddr::V4(v4) => Self(IpAddr::V4(
     210           11 :                 Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
     211           11 :             )),
     212            4 :             IpAddr::V6(v6) => Self(IpAddr::V6(
     213            4 :                 Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
     214            4 :             )),
     215              :         }
     216           15 :     }
     217              : }
     218              : 
     219              : // This can't be just per IP because that would limit some PaaS that share IP addresses
     220              : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
     221              : 
     222              : impl AuthenticationConfig {
     223            3 :     pub(crate) fn check_rate_limit(
     224            3 :         &self,
     225            3 :         ctx: &RequestContext,
     226            3 :         secret: AuthSecret,
     227            3 :         endpoint: &EndpointId,
     228            3 :         is_cleartext: bool,
     229            3 :     ) -> auth::Result<AuthSecret> {
     230            3 :         // we have validated the endpoint exists, so let's intern it.
     231            3 :         let endpoint_int = EndpointIdInt::from(endpoint.normalize());
     232              : 
     233              :         // only count the full hash count if password hack or websocket flow.
     234              :         // in other words, if proxy needs to run the hashing
     235            3 :         let password_weight = if is_cleartext {
     236            2 :             match &secret {
     237              :                 #[cfg(any(test, feature = "testing"))]
     238            0 :                 AuthSecret::Md5(_) => 1,
     239            2 :                 AuthSecret::Scram(s) => s.iterations + 1,
     240              :             }
     241              :         } else {
     242              :             // validating scram takes just 1 hmac_sha_256 operation.
     243            1 :             1
     244              :         };
     245              : 
     246            3 :         let limit_not_exceeded = self.rate_limiter.check(
     247            3 :             (
     248            3 :                 endpoint_int,
     249            3 :                 MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
     250            3 :             ),
     251            3 :             password_weight,
     252            3 :         );
     253            3 : 
     254            3 :         if !limit_not_exceeded {
     255            0 :             warn!(
     256              :                 enabled = self.rate_limiter_enabled,
     257            0 :                 "rate limiting authentication"
     258              :             );
     259            0 :             Metrics::get().proxy.requests_auth_rate_limits_total.inc();
     260            0 :             Metrics::get()
     261            0 :                 .proxy
     262            0 :                 .endpoints_auth_rate_limits
     263            0 :                 .get_metric()
     264            0 :                 .measure(endpoint);
     265            0 : 
     266            0 :             if self.rate_limiter_enabled {
     267            0 :                 return Err(auth::AuthError::too_many_connections());
     268            0 :             }
     269            3 :         }
     270              : 
     271            3 :         Ok(secret)
     272            3 :     }
     273              : }
     274              : 
     275              : /// True to its name, this function encapsulates our current auth trade-offs.
     276              : /// Here, we choose the appropriate auth flow based on circumstances.
     277              : ///
     278              : /// All authentication flows will emit an AuthenticationOk message if successful.
     279            3 : async fn auth_quirks(
     280            3 :     ctx: &RequestContext,
     281            3 :     api: &impl control_plane::ControlPlaneApi,
     282            3 :     user_info: ComputeUserInfoMaybeEndpoint,
     283            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     284            3 :     allow_cleartext: bool,
     285            3 :     config: &'static AuthenticationConfig,
     286            3 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     287            3 : ) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
     288              :     // If there's no project so far, that entails that client doesn't
     289              :     // support SNI or other means of passing the endpoint (project) name.
     290              :     // We now expect to see a very specific payload in the place of password.
     291            3 :     let (info, unauthenticated_password) = match user_info.try_into() {
     292            1 :         Err(info) => {
     293            1 :             let (info, password) =
     294            1 :                 hacks::password_hack_no_authentication(ctx, info, client).await?;
     295            1 :             ctx.set_endpoint_id(info.endpoint.clone());
     296            1 :             (info, Some(password))
     297              :         }
     298            2 :         Ok(info) => (info, None),
     299              :     };
     300              : 
     301            3 :     debug!("fetching authentication info and allowlists");
     302              : 
     303              :     // check allowed list
     304            3 :     let allowed_ips = if config.ip_allowlist_check_enabled {
     305            3 :         let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
     306            3 :         if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
     307            0 :             return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
     308            3 :         }
     309            3 :         allowed_ips
     310              :     } else {
     311            0 :         Cached::new_uncached(Arc::new(vec![]))
     312              :     };
     313              : 
     314              :     // check if a VPC endpoint ID is coming in and if yes, if it's allowed
     315            3 :     let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
     316            3 :     if config.is_vpc_acccess_proxy {
     317            0 :         if access_blocks.vpc_access_blocked {
     318            0 :             return Err(AuthError::NetworkNotAllowed);
     319            0 :         }
     320              : 
     321            0 :         let incoming_vpc_endpoint_id = match ctx.extra() {
     322            0 :             None => return Err(AuthError::MissingEndpointName),
     323            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
     324            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
     325              :         };
     326            0 :         let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
     327              :         // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
     328            0 :         if !allowed_vpc_endpoint_ids.is_empty()
     329            0 :             && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
     330              :         {
     331            0 :             return Err(AuthError::vpc_endpoint_id_not_allowed(
     332            0 :                 incoming_vpc_endpoint_id,
     333            0 :             ));
     334            0 :         }
     335            3 :     } else if access_blocks.public_access_blocked {
     336            0 :         return Err(AuthError::NetworkNotAllowed);
     337            3 :     }
     338              : 
     339            3 :     if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
     340            0 :         return Err(AuthError::too_many_connections());
     341            3 :     }
     342            3 :     let cached_secret = api.get_role_secret(ctx, &info).await?;
     343            3 :     let (cached_entry, secret) = cached_secret.take_value();
     344              : 
     345            3 :     let secret = if let Some(secret) = secret {
     346            3 :         config.check_rate_limit(
     347            3 :             ctx,
     348            3 :             secret,
     349            3 :             &info.endpoint,
     350            3 :             unauthenticated_password.is_some() || allow_cleartext,
     351            0 :         )?
     352              :     } else {
     353              :         // If we don't have an authentication secret, we mock one to
     354              :         // prevent malicious probing (possible due to missing protocol steps).
     355              :         // This mocked secret will never lead to successful authentication.
     356            0 :         info!("authentication info not found, mocking it");
     357            0 :         AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     358              :     };
     359              : 
     360            3 :     match authenticate_with_secret(
     361            3 :         ctx,
     362            3 :         secret,
     363            3 :         info,
     364            3 :         client,
     365            3 :         unauthenticated_password,
     366            3 :         allow_cleartext,
     367            3 :         config,
     368            3 :     )
     369            3 :     .await
     370              :     {
     371            3 :         Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
     372            0 :         Err(e) => {
     373            0 :             if e.is_password_failed() {
     374            0 :                 // The password could have been changed, so we invalidate the cache.
     375            0 :                 cached_entry.invalidate();
     376            0 :             }
     377            0 :             Err(e)
     378              :         }
     379              :     }
     380            3 : }
     381              : 
     382            3 : async fn authenticate_with_secret(
     383            3 :     ctx: &RequestContext,
     384            3 :     secret: AuthSecret,
     385            3 :     info: ComputeUserInfo,
     386            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     387            3 :     unauthenticated_password: Option<Vec<u8>>,
     388            3 :     allow_cleartext: bool,
     389            3 :     config: &'static AuthenticationConfig,
     390            3 : ) -> auth::Result<ComputeCredentials> {
     391            3 :     if let Some(password) = unauthenticated_password {
     392            1 :         let ep = EndpointIdInt::from(&info.endpoint);
     393              : 
     394            1 :         let auth_outcome =
     395            1 :             validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
     396            1 :         let keys = match auth_outcome {
     397            1 :             crate::sasl::Outcome::Success(key) => key,
     398            0 :             crate::sasl::Outcome::Failure(reason) => {
     399            0 :                 info!("auth backend failed with an error: {reason}");
     400            0 :                 return Err(auth::AuthError::password_failed(&*info.user));
     401              :             }
     402              :         };
     403              : 
     404              :         // we have authenticated the password
     405            1 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     406              : 
     407            1 :         return Ok(ComputeCredentials { info, keys });
     408            2 :     }
     409            2 : 
     410            2 :     // -- the remaining flows are self-authenticating --
     411            2 : 
     412            2 :     // Perform cleartext auth if we're allowed to do that.
     413            2 :     // Currently, we use it for websocket connections (latency).
     414            2 :     if allow_cleartext {
     415            1 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     416            1 :         return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
     417            1 :     }
     418            1 : 
     419            1 :     // Finally, proceed with the main auth flow (SCRAM-based).
     420            1 :     classic::authenticate(ctx, info, client, config, secret).await
     421            3 : }
     422              : 
     423              : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
     424              :     /// Get username from the credentials.
     425            0 :     pub(crate) fn get_user(&self) -> &str {
     426            0 :         match self {
     427            0 :             Self::ControlPlane(_, user_info) => &user_info.user,
     428            0 :             Self::Local(_) => "local",
     429              :         }
     430            0 :     }
     431              : 
     432              :     /// Authenticate the client via the requested backend, possibly using credentials.
     433              :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     434              :     pub(crate) async fn authenticate(
     435              :         self,
     436              :         ctx: &RequestContext,
     437              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     438              :         allow_cleartext: bool,
     439              :         config: &'static AuthenticationConfig,
     440              :         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     441              :     ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
     442              :         let res = match self {
     443              :             Self::ControlPlane(api, user_info) => {
     444              :                 debug!(
     445              :                     user = &*user_info.user,
     446              :                     project = user_info.endpoint(),
     447              :                     "performing authentication using the console"
     448              :                 );
     449              : 
     450              :                 let (credentials, ip_allowlist) = auth_quirks(
     451              :                     ctx,
     452              :                     &*api,
     453              :                     user_info,
     454              :                     client,
     455              :                     allow_cleartext,
     456              :                     config,
     457              :                     endpoint_rate_limiter,
     458              :                 )
     459              :                 .await?;
     460              :                 Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
     461              :             }
     462              :             Self::Local(_) => {
     463              :                 return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
     464              :             }
     465              :         };
     466              : 
     467              :         // TODO: replace with some metric
     468              :         info!("user successfully authenticated");
     469              :         res
     470              :     }
     471              : }
     472              : 
     473              : impl Backend<'_, ComputeUserInfo> {
     474            0 :     pub(crate) async fn get_role_secret(
     475            0 :         &self,
     476            0 :         ctx: &RequestContext,
     477            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     478            0 :         match self {
     479            0 :             Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
     480            0 :             Self::Local(_) => Ok(Cached::new_uncached(None)),
     481              :         }
     482            0 :     }
     483              : 
     484            0 :     pub(crate) async fn get_allowed_ips(
     485            0 :         &self,
     486            0 :         ctx: &RequestContext,
     487            0 :     ) -> Result<CachedAllowedIps, GetAuthInfoError> {
     488            0 :         match self {
     489            0 :             Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
     490            0 :             Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
     491              :         }
     492            0 :     }
     493              : 
     494            0 :     pub(crate) async fn get_allowed_vpc_endpoint_ids(
     495            0 :         &self,
     496            0 :         ctx: &RequestContext,
     497            0 :     ) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
     498            0 :         match self {
     499            0 :             Self::ControlPlane(api, user_info) => {
     500            0 :                 api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
     501              :             }
     502            0 :             Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
     503              :         }
     504            0 :     }
     505              : 
     506            0 :     pub(crate) async fn get_block_public_or_vpc_access(
     507            0 :         &self,
     508            0 :         ctx: &RequestContext,
     509            0 :     ) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
     510            0 :         match self {
     511            0 :             Self::ControlPlane(api, user_info) => {
     512            0 :                 api.get_block_public_or_vpc_access(ctx, user_info).await
     513              :             }
     514            0 :             Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
     515              :         }
     516            0 :     }
     517              : }
     518              : 
     519              : #[async_trait::async_trait]
     520              : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
     521           19 :     async fn wake_compute(
     522           19 :         &self,
     523           19 :         ctx: &RequestContext,
     524           19 :     ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     525           19 :         match self {
     526           19 :             Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
     527            0 :             Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
     528              :         }
     529           38 :     }
     530              : 
     531           10 :     fn get_keys(&self) -> &ComputeCredentialKeys {
     532           10 :         match self {
     533           10 :             Self::ControlPlane(_, creds) => &creds.keys,
     534            0 :             Self::Local(_) => &ComputeCredentialKeys::None,
     535              :         }
     536           10 :     }
     537              : }
     538              : 
     539              : #[cfg(test)]
     540              : mod tests {
     541              :     #![allow(clippy::unimplemented, clippy::unwrap_used)]
     542              : 
     543              :     use std::net::IpAddr;
     544              :     use std::sync::Arc;
     545              :     use std::time::Duration;
     546              : 
     547              :     use bytes::BytesMut;
     548              :     use control_plane::AuthSecret;
     549              :     use fallible_iterator::FallibleIterator;
     550              :     use once_cell::sync::Lazy;
     551              :     use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
     552              :     use postgres_protocol::message::backend::Message as PgMessage;
     553              :     use postgres_protocol::message::frontend;
     554              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     555              : 
     556              :     use super::jwt::JwkCache;
     557              :     use super::{AuthRateLimiter, auth_quirks};
     558              :     use crate::auth::backend::MaskedIp;
     559              :     use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
     560              :     use crate::config::AuthenticationConfig;
     561              :     use crate::context::RequestContext;
     562              :     use crate::control_plane::{
     563              :         self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
     564              :         CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
     565              :     };
     566              :     use crate::proxy::NeonOptions;
     567              :     use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
     568              :     use crate::scram::ServerSecret;
     569              :     use crate::scram::threadpool::ThreadPool;
     570              :     use crate::stream::{PqStream, Stream};
     571              : 
     572              :     struct Auth {
     573              :         ips: Vec<IpPattern>,
     574              :         vpc_endpoint_ids: Vec<String>,
     575              :         access_blocker_flags: AccessBlockerFlags,
     576              :         secret: AuthSecret,
     577              :     }
     578              : 
     579              :     impl control_plane::ControlPlaneApi for Auth {
     580            3 :         async fn get_role_secret(
     581            3 :             &self,
     582            3 :             _ctx: &RequestContext,
     583            3 :             _user_info: &super::ComputeUserInfo,
     584            3 :         ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
     585            3 :             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
     586            3 :         }
     587              : 
     588            3 :         async fn get_allowed_ips(
     589            3 :             &self,
     590            3 :             _ctx: &RequestContext,
     591            3 :             _user_info: &super::ComputeUserInfo,
     592            3 :         ) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
     593            3 :             Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
     594            3 :         }
     595              : 
     596            0 :         async fn get_allowed_vpc_endpoint_ids(
     597            0 :             &self,
     598            0 :             _ctx: &RequestContext,
     599            0 :             _user_info: &super::ComputeUserInfo,
     600            0 :         ) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
     601            0 :             Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
     602            0 :                 self.vpc_endpoint_ids.clone(),
     603            0 :             )))
     604            0 :         }
     605              : 
     606            3 :         async fn get_block_public_or_vpc_access(
     607            3 :             &self,
     608            3 :             _ctx: &RequestContext,
     609            3 :             _user_info: &super::ComputeUserInfo,
     610            3 :         ) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
     611            3 :             Ok(CachedAccessBlockerFlags::new_uncached(
     612            3 :                 self.access_blocker_flags.clone(),
     613            3 :             ))
     614            3 :         }
     615              : 
     616            0 :         async fn get_endpoint_jwks(
     617            0 :             &self,
     618            0 :             _ctx: &RequestContext,
     619            0 :             _endpoint: crate::types::EndpointId,
     620            0 :         ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
     621            0 :         {
     622            0 :             unimplemented!()
     623              :         }
     624              : 
     625            0 :         async fn wake_compute(
     626            0 :             &self,
     627            0 :             _ctx: &RequestContext,
     628            0 :             _user_info: &super::ComputeUserInfo,
     629            0 :         ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     630            0 :             unimplemented!()
     631              :         }
     632              :     }
     633              : 
     634            3 :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     635            3 :         jwks_cache: JwkCache::default(),
     636            3 :         thread_pool: ThreadPool::new(1),
     637            3 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     638            3 :         rate_limiter_enabled: true,
     639            3 :         rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
     640            3 :         rate_limit_ip_subnet: 64,
     641            3 :         ip_allowlist_check_enabled: true,
     642            3 :         is_vpc_acccess_proxy: false,
     643            3 :         is_auth_broker: false,
     644            3 :         accept_jwts: false,
     645            3 :         console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
     646            3 :     });
     647              : 
     648            5 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     649              :         loop {
     650            7 :             r.read_buf(&mut *b).await.unwrap();
     651            7 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     652            5 :                 break m;
     653            2 :             }
     654              :         }
     655            5 :     }
     656              : 
     657              :     #[test]
     658            1 :     fn masked_ip() {
     659            1 :         let ip_a = IpAddr::V4([127, 0, 0, 1].into());
     660            1 :         let ip_b = IpAddr::V4([127, 0, 0, 2].into());
     661            1 :         let ip_c = IpAddr::V4([192, 168, 1, 101].into());
     662            1 :         let ip_d = IpAddr::V4([192, 168, 1, 102].into());
     663            1 :         let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
     664            1 :         let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
     665            1 : 
     666            1 :         assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
     667            1 :         assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
     668            1 :         assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
     669            1 :         assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
     670              : 
     671            1 :         assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
     672            1 :         assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
     673            1 :     }
     674              : 
     675              :     #[test]
     676            1 :     fn test_default_auth_rate_limit_set() {
     677            1 :         // these values used to exceed u32::MAX
     678            1 :         assert_eq!(
     679            1 :             RateBucketInfo::DEFAULT_AUTH_SET,
     680            1 :             [
     681            1 :                 RateBucketInfo {
     682            1 :                     interval: Duration::from_secs(1),
     683            1 :                     max_rpi: 1000 * 4096,
     684            1 :                 },
     685            1 :                 RateBucketInfo {
     686            1 :                     interval: Duration::from_secs(60),
     687            1 :                     max_rpi: 600 * 4096 * 60,
     688            1 :                 },
     689            1 :                 RateBucketInfo {
     690            1 :                     interval: Duration::from_secs(600),
     691            1 :                     max_rpi: 300 * 4096 * 600,
     692            1 :                 }
     693            1 :             ]
     694            1 :         );
     695              : 
     696            4 :         for x in RateBucketInfo::DEFAULT_AUTH_SET {
     697            3 :             let y = x.to_string().parse().unwrap();
     698            3 :             assert_eq!(x, y);
     699              :         }
     700            1 :     }
     701              : 
     702              :     #[tokio::test]
     703            1 :     async fn auth_quirks_scram() {
     704            1 :         let (mut client, server) = tokio::io::duplex(1024);
     705            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     706            1 : 
     707            1 :         let ctx = RequestContext::test();
     708            1 :         let api = Auth {
     709            1 :             ips: vec![],
     710            1 :             vpc_endpoint_ids: vec![],
     711            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     712            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     713            1 :         };
     714            1 : 
     715            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     716            1 :             user: "conrad".into(),
     717            1 :             endpoint_id: Some("endpoint".into()),
     718            1 :             options: NeonOptions::default(),
     719            1 :         };
     720            1 : 
     721            1 :         let handle = tokio::spawn(async move {
     722            1 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     723            1 : 
     724            1 :             let mut read = BytesMut::new();
     725            1 : 
     726            1 :             // server should offer scram
     727            1 :             match read_message(&mut client, &mut read).await {
     728            1 :                 PgMessage::AuthenticationSasl(a) => {
     729            1 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     730            1 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     731            1 :                 }
     732            1 :                 _ => panic!("wrong message"),
     733            1 :             }
     734            1 : 
     735            1 :             // client sends client-first-message
     736            1 :             let mut write = BytesMut::new();
     737            1 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     738            1 :             client.write_all(&write).await.unwrap();
     739            1 : 
     740            1 :             // server response with server-first-message
     741            1 :             match read_message(&mut client, &mut read).await {
     742            1 :                 PgMessage::AuthenticationSaslContinue(a) => {
     743            1 :                     scram.update(a.data()).await.unwrap();
     744            1 :                 }
     745            1 :                 _ => panic!("wrong message"),
     746            1 :             }
     747            1 : 
     748            1 :             // client response with client-final-message
     749            1 :             write.clear();
     750            1 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     751            1 :             client.write_all(&write).await.unwrap();
     752            1 : 
     753            1 :             // server response with server-final-message
     754            1 :             match read_message(&mut client, &mut read).await {
     755            1 :                 PgMessage::AuthenticationSaslFinal(a) => {
     756            1 :                     scram.finish(a.data()).unwrap();
     757            1 :                 }
     758            1 :                 _ => panic!("wrong message"),
     759            1 :             }
     760            1 :         });
     761            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     762            1 :             EndpointRateLimiter::DEFAULT,
     763            1 :             64,
     764            1 :         ));
     765            1 : 
     766            1 :         let _creds = auth_quirks(
     767            1 :             &ctx,
     768            1 :             &api,
     769            1 :             user_info,
     770            1 :             &mut stream,
     771            1 :             false,
     772            1 :             &CONFIG,
     773            1 :             endpoint_rate_limiter,
     774            1 :         )
     775            1 :         .await
     776            1 :         .unwrap();
     777            1 : 
     778            1 :         // flush the final server message
     779            1 :         stream.flush().await.unwrap();
     780            1 : 
     781            1 :         handle.await.unwrap();
     782            1 :     }
     783              : 
     784              :     #[tokio::test]
     785            1 :     async fn auth_quirks_cleartext() {
     786            1 :         let (mut client, server) = tokio::io::duplex(1024);
     787            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     788            1 : 
     789            1 :         let ctx = RequestContext::test();
     790            1 :         let api = Auth {
     791            1 :             ips: vec![],
     792            1 :             vpc_endpoint_ids: vec![],
     793            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     794            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     795            1 :         };
     796            1 : 
     797            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     798            1 :             user: "conrad".into(),
     799            1 :             endpoint_id: Some("endpoint".into()),
     800            1 :             options: NeonOptions::default(),
     801            1 :         };
     802            1 : 
     803            1 :         let handle = tokio::spawn(async move {
     804            1 :             let mut read = BytesMut::new();
     805            1 :             let mut write = BytesMut::new();
     806            1 : 
     807            1 :             // server should offer cleartext
     808            1 :             match read_message(&mut client, &mut read).await {
     809            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     810            1 :                 _ => panic!("wrong message"),
     811            1 :             }
     812            1 : 
     813            1 :             // client responds with password
     814            1 :             write.clear();
     815            1 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     816            1 :             client.write_all(&write).await.unwrap();
     817            1 :         });
     818            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     819            1 :             EndpointRateLimiter::DEFAULT,
     820            1 :             64,
     821            1 :         ));
     822            1 : 
     823            1 :         let _creds = auth_quirks(
     824            1 :             &ctx,
     825            1 :             &api,
     826            1 :             user_info,
     827            1 :             &mut stream,
     828            1 :             true,
     829            1 :             &CONFIG,
     830            1 :             endpoint_rate_limiter,
     831            1 :         )
     832            1 :         .await
     833            1 :         .unwrap();
     834            1 : 
     835            1 :         handle.await.unwrap();
     836            1 :     }
     837              : 
     838              :     #[tokio::test]
     839            1 :     async fn auth_quirks_password_hack() {
     840            1 :         let (mut client, server) = tokio::io::duplex(1024);
     841            1 :         let mut stream = PqStream::new(Stream::from_raw(server));
     842            1 : 
     843            1 :         let ctx = RequestContext::test();
     844            1 :         let api = Auth {
     845            1 :             ips: vec![],
     846            1 :             vpc_endpoint_ids: vec![],
     847            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     848            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     849            1 :         };
     850            1 : 
     851            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     852            1 :             user: "conrad".into(),
     853            1 :             endpoint_id: None,
     854            1 :             options: NeonOptions::default(),
     855            1 :         };
     856            1 : 
     857            1 :         let handle = tokio::spawn(async move {
     858            1 :             let mut read = BytesMut::new();
     859            1 : 
     860            1 :             // server should offer cleartext
     861            1 :             match read_message(&mut client, &mut read).await {
     862            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     863            1 :                 _ => panic!("wrong message"),
     864            1 :             }
     865            1 : 
     866            1 :             // client responds with password
     867            1 :             let mut write = BytesMut::new();
     868            1 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     869            1 :                 .unwrap();
     870            1 :             client.write_all(&write).await.unwrap();
     871            1 :         });
     872            1 : 
     873            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     874            1 :             EndpointRateLimiter::DEFAULT,
     875            1 :             64,
     876            1 :         ));
     877            1 : 
     878            1 :         let creds = auth_quirks(
     879            1 :             &ctx,
     880            1 :             &api,
     881            1 :             user_info,
     882            1 :             &mut stream,
     883            1 :             true,
     884            1 :             &CONFIG,
     885            1 :             endpoint_rate_limiter,
     886            1 :         )
     887            1 :         .await
     888            1 :         .unwrap();
     889            1 : 
     890            1 :         assert_eq!(creds.0.info.endpoint, "my-endpoint");
     891            1 : 
     892            1 :         handle.await.unwrap();
     893            1 :     }
     894              : }
        

Generated by: LCOV version 2.1-beta