LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - mod.rs (source / functions) Coverage Total Hit
Test: a14d6a1f0ccf210374e9eaed9918e97cd6f5d5ba.info Lines: 70.6 % 354 250
Test Date: 2025-08-04 14:37:31 Functions: 44.6 % 56 25

            Line data    Source code
       1              : mod classic;
       2              : mod console_redirect;
       3              : mod hacks;
       4              : pub mod jwt;
       5              : pub mod local;
       6              : 
       7              : use std::sync::Arc;
       8              : 
       9              : pub use console_redirect::ConsoleRedirectBackend;
      10              : pub(crate) use console_redirect::ConsoleRedirectError;
      11              : use local::LocalBackend;
      12              : use postgres_client::config::AuthKeys;
      13              : use serde::{Deserialize, Serialize};
      14              : use tokio::io::{AsyncRead, AsyncWrite};
      15              : use tracing::{debug, info};
      16              : 
      17              : use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
      18              : use crate::cache::Cached;
      19              : use crate::cache::node_info::CachedNodeInfo;
      20              : use crate::config::AuthenticationConfig;
      21              : use crate::context::RequestContext;
      22              : use crate::control_plane::client::ControlPlaneClient;
      23              : use crate::control_plane::errors::GetAuthInfoError;
      24              : use crate::control_plane::messages::EndpointRateLimitConfig;
      25              : use crate::control_plane::{
      26              :     self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
      27              : };
      28              : use crate::intern::{EndpointIdInt, RoleNameInt};
      29              : use crate::pqproto::BeMessage;
      30              : use crate::proxy::NeonOptions;
      31              : use crate::proxy::wake_compute::WakeComputeBackend;
      32              : use crate::rate_limiter::EndpointRateLimiter;
      33              : use crate::stream::Stream;
      34              : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
      35              : use crate::{scram, stream};
      36              : 
      37              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      38              : pub enum MaybeOwned<'a, T> {
      39              :     Owned(T),
      40              :     Borrowed(&'a T),
      41              : }
      42              : 
      43              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      44              :     type Target = T;
      45              : 
      46           21 :     fn deref(&self) -> &Self::Target {
      47           21 :         match self {
      48           21 :             MaybeOwned::Owned(t) => t,
      49            0 :             MaybeOwned::Borrowed(t) => t,
      50              :         }
      51           21 :     }
      52              : }
      53              : 
      54              : /// This type serves two purposes:
      55              : ///
      56              : /// * When `T` is `()`, it's just a regular auth backend selector
      57              : ///   which we use in [`crate::config::ProxyConfig`].
      58              : ///
      59              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      60              : ///   this helps us provide the credentials only to those auth
      61              : ///   backends which require them for the authentication process.
      62              : pub enum Backend<'a, T> {
      63              :     /// Cloud API (V2).
      64              :     ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
      65              :     /// Local proxy uses configured auth credentials and does not wake compute
      66              :     Local(MaybeOwned<'a, LocalBackend>),
      67              : }
      68              : 
      69              : impl std::fmt::Display for Backend<'_, ()> {
      70            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      71            0 :         match self {
      72            0 :             Self::ControlPlane(api, ()) => match &**api {
      73            0 :                 ControlPlaneClient::ProxyV1(endpoint) => fmt
      74            0 :                     .debug_tuple("ControlPlane::ProxyV1")
      75            0 :                     .field(&endpoint.url())
      76            0 :                     .finish(),
      77              :                 #[cfg(any(test, feature = "testing"))]
      78            0 :                 ControlPlaneClient::PostgresMock(endpoint) => {
      79            0 :                     let url = endpoint.url();
      80            0 :                     match url::Url::parse(url) {
      81            0 :                         Ok(mut url) => {
      82            0 :                             let _ = url.set_password(Some("_redacted_"));
      83            0 :                             let url = url.as_str();
      84            0 :                             fmt.debug_tuple("ControlPlane::PostgresMock")
      85            0 :                                 .field(&url)
      86            0 :                                 .finish()
      87              :                         }
      88            0 :                         Err(_) => fmt
      89            0 :                             .debug_tuple("ControlPlane::PostgresMock")
      90            0 :                             .field(&url)
      91            0 :                             .finish(),
      92              :                     }
      93              :                 }
      94              :                 #[cfg(test)]
      95            0 :                 ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
      96              :             },
      97            0 :             Self::Local(_) => fmt.debug_tuple("Local").finish(),
      98              :         }
      99            0 :     }
     100              : }
     101              : 
     102              : impl<T> Backend<'_, T> {
     103              :     /// Very similar to [`std::option::Option::as_ref`].
     104              :     /// This helps us pass structured config to async tasks.
     105            0 :     pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
     106            0 :         match self {
     107            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
     108            0 :             Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
     109              :         }
     110            0 :     }
     111              : 
     112            0 :     pub(crate) fn get_api(&self) -> &ControlPlaneClient {
     113            0 :         match self {
     114            0 :             Self::ControlPlane(api, _) => api,
     115            0 :             Self::Local(_) => panic!("Local backend has no API"),
     116              :         }
     117            0 :     }
     118              : 
     119            0 :     pub(crate) fn is_local_proxy(&self) -> bool {
     120            0 :         matches!(self, Self::Local(_))
     121            0 :     }
     122              : }
     123              : 
     124              : impl<'a, T> Backend<'a, T> {
     125              :     /// Very similar to [`std::option::Option::map`].
     126              :     /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
     127              :     /// a function to a contained value.
     128            0 :     pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
     129            0 :         match self {
     130            0 :             Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
     131            0 :             Self::Local(l) => Backend::Local(l),
     132              :         }
     133            0 :     }
     134              : }
     135              : impl<'a, T, E> Backend<'a, Result<T, E>> {
     136              :     /// Very similar to [`std::option::Option::transpose`].
     137              :     /// This is most useful for error handling.
     138            0 :     pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
     139            0 :         match self {
     140            0 :             Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
     141            0 :             Self::Local(l) => Ok(Backend::Local(l)),
     142              :         }
     143            0 :     }
     144              : }
     145              : 
     146              : pub(crate) struct ComputeCredentials {
     147              :     pub(crate) info: ComputeUserInfo,
     148              :     pub(crate) keys: ComputeCredentialKeys,
     149              : }
     150              : 
     151              : #[derive(Debug, Clone)]
     152              : pub(crate) struct ComputeUserInfoNoEndpoint {
     153              :     pub(crate) user: RoleName,
     154              :     pub(crate) options: NeonOptions,
     155              : }
     156              : 
     157            0 : #[derive(Debug, Clone, Default, Serialize, Deserialize)]
     158              : pub(crate) struct ComputeUserInfo {
     159              :     pub(crate) endpoint: EndpointId,
     160              :     pub(crate) user: RoleName,
     161              :     pub(crate) options: NeonOptions,
     162              : }
     163              : 
     164              : impl ComputeUserInfo {
     165            2 :     pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
     166            2 :         self.options.get_cache_key(&self.endpoint)
     167            2 :     }
     168              : }
     169              : 
     170              : #[cfg_attr(test, derive(Debug))]
     171              : pub(crate) enum ComputeCredentialKeys {
     172              :     AuthKeys(AuthKeys),
     173              :     JwtPayload(Vec<u8>),
     174              : }
     175              : 
     176              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     177              :     // user name
     178              :     type Error = ComputeUserInfoNoEndpoint;
     179              : 
     180            3 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     181            3 :         match user_info.endpoint_id {
     182            1 :             None => Err(ComputeUserInfoNoEndpoint {
     183            1 :                 user: user_info.user,
     184            1 :                 options: user_info.options,
     185            1 :             }),
     186            2 :             Some(endpoint) => Ok(ComputeUserInfo {
     187            2 :                 endpoint,
     188            2 :                 user: user_info.user,
     189            2 :                 options: user_info.options,
     190            2 :             }),
     191              :         }
     192            3 :     }
     193              : }
     194              : 
     195              : /// True to its name, this function encapsulates our current auth trade-offs.
     196              : /// Here, we choose the appropriate auth flow based on circumstances.
     197              : ///
     198              : /// All authentication flows will emit an AuthenticationOk message if successful.
     199            3 : async fn auth_quirks(
     200            3 :     ctx: &RequestContext,
     201            3 :     api: &impl control_plane::ControlPlaneApi,
     202            3 :     user_info: ComputeUserInfoMaybeEndpoint,
     203            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     204            3 :     allow_cleartext: bool,
     205            3 :     config: &'static AuthenticationConfig,
     206            3 :     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     207            3 : ) -> auth::Result<ComputeCredentials> {
     208              :     // If there's no project so far, that entails that client doesn't
     209              :     // support SNI or other means of passing the endpoint (project) name.
     210              :     // We now expect to see a very specific payload in the place of password.
     211            3 :     let (info, unauthenticated_password) = match user_info.try_into() {
     212            1 :         Err(info) => {
     213            1 :             let (info, password) =
     214            1 :                 hacks::password_hack_no_authentication(ctx, info, client).await?;
     215            1 :             ctx.set_endpoint_id(info.endpoint.clone());
     216            1 :             (info, Some(password))
     217              :         }
     218            2 :         Ok(info) => (info, None),
     219              :     };
     220              : 
     221            3 :     debug!("fetching authentication info and allowlists");
     222              : 
     223            3 :     let access_controls = api
     224            3 :         .get_endpoint_access_control(ctx, &info.endpoint, &info.user)
     225            3 :         .await?;
     226              : 
     227            3 :     access_controls.check(
     228            3 :         ctx,
     229            3 :         config.ip_allowlist_check_enabled,
     230            3 :         config.is_vpc_acccess_proxy,
     231            0 :     )?;
     232              : 
     233            3 :     access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
     234              : 
     235            3 :     let role_access = api
     236            3 :         .get_role_access_control(ctx, &info.endpoint, &info.user)
     237            3 :         .await?;
     238              : 
     239            3 :     let secret = if let Some(secret) = role_access.secret {
     240            3 :         secret
     241              :     } else {
     242              :         // If we don't have an authentication secret, we mock one to
     243              :         // prevent malicious probing (possible due to missing protocol steps).
     244              :         // This mocked secret will never lead to successful authentication.
     245            0 :         info!("authentication info not found, mocking it");
     246            0 :         AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     247              :     };
     248              : 
     249            3 :     match authenticate_with_secret(
     250            3 :         ctx,
     251            3 :         secret,
     252            3 :         info,
     253            3 :         client,
     254            3 :         unauthenticated_password,
     255            3 :         allow_cleartext,
     256            3 :         config,
     257              :     )
     258            3 :     .await
     259              :     {
     260            3 :         Ok(keys) => Ok(keys),
     261            0 :         Err(e) => Err(e),
     262              :     }
     263            3 : }
     264              : 
     265            3 : async fn authenticate_with_secret(
     266            3 :     ctx: &RequestContext,
     267            3 :     secret: AuthSecret,
     268            3 :     info: ComputeUserInfo,
     269            3 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     270            3 :     unauthenticated_password: Option<Vec<u8>>,
     271            3 :     allow_cleartext: bool,
     272            3 :     config: &'static AuthenticationConfig,
     273            3 : ) -> auth::Result<ComputeCredentials> {
     274            3 :     if let Some(password) = unauthenticated_password {
     275            1 :         let ep = EndpointIdInt::from(&info.endpoint);
     276            1 :         let role = RoleNameInt::from(&info.user);
     277              : 
     278            1 :         let auth_outcome =
     279            1 :             validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
     280            1 :                 .await?;
     281            1 :         let keys = match auth_outcome {
     282            1 :             crate::sasl::Outcome::Success(key) => key,
     283            0 :             crate::sasl::Outcome::Failure(reason) => {
     284            0 :                 info!("auth backend failed with an error: {reason}");
     285            0 :                 return Err(auth::AuthError::password_failed(&*info.user));
     286              :             }
     287              :         };
     288              : 
     289              :         // we have authenticated the password
     290            1 :         client.write_message(BeMessage::AuthenticationOk);
     291              : 
     292            1 :         return Ok(ComputeCredentials { info, keys });
     293            2 :     }
     294              : 
     295              :     // -- the remaining flows are self-authenticating --
     296              : 
     297              :     // Perform cleartext auth if we're allowed to do that.
     298              :     // Currently, we use it for websocket connections (latency).
     299            2 :     if allow_cleartext {
     300            1 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     301            1 :         return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
     302            1 :     }
     303              : 
     304              :     // Finally, proceed with the main auth flow (SCRAM-based).
     305            1 :     classic::authenticate(ctx, info, client, config, secret).await
     306            3 : }
     307              : 
     308              : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
     309              :     /// Get username from the credentials.
     310            0 :     pub(crate) fn get_user(&self) -> &str {
     311            0 :         match self {
     312            0 :             Self::ControlPlane(_, user_info) => &user_info.user,
     313            0 :             Self::Local(_) => "local",
     314              :         }
     315            0 :     }
     316              : 
     317              :     /// Authenticate the client via the requested backend, possibly using credentials.
     318              :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     319              :     pub(crate) async fn authenticate(
     320              :         self,
     321              :         ctx: &RequestContext,
     322              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     323              :         allow_cleartext: bool,
     324              :         config: &'static AuthenticationConfig,
     325              :         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     326              :     ) -> auth::Result<Backend<'a, ComputeCredentials>> {
     327              :         let res = match self {
     328              :             Self::ControlPlane(api, user_info) => {
     329              :                 debug!(
     330              :                     user = &*user_info.user,
     331              :                     project = user_info.endpoint(),
     332              :                     "performing authentication using the console"
     333              :                 );
     334              : 
     335              :                 let auth_res = auth_quirks(
     336              :                     ctx,
     337              :                     &*api,
     338              :                     user_info.clone(),
     339              :                     client,
     340              :                     allow_cleartext,
     341              :                     config,
     342              :                     endpoint_rate_limiter,
     343              :                 )
     344              :                 .await;
     345              :                 match auth_res {
     346              :                     Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
     347              :                     Err(e) => {
     348              :                         // The password could have been changed, so we invalidate the cache.
     349              :                         // We should only invalidate the cache if the TTL might have expired.
     350              :                         if e.is_password_failed()
     351              :                             && let ControlPlaneClient::ProxyV1(api) = &*api
     352              :                             && let Some(ep) = &user_info.endpoint_id
     353              :                         {
     354              :                             api.caches
     355              :                                 .project_info
     356              :                                 .maybe_invalidate_role_secret(ep, &user_info.user);
     357              :                         }
     358              : 
     359              :                         Err(e)
     360              :                     }
     361              :                 }
     362              :             }
     363              :             Self::Local(_) => {
     364              :                 return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
     365              :             }
     366              :         };
     367              : 
     368              :         // TODO: replace with some metric
     369              :         info!("user successfully authenticated");
     370              :         res
     371              :     }
     372              : }
     373              : 
     374              : impl Backend<'_, ComputeUserInfo> {
     375            0 :     pub(crate) async fn get_role_secret(
     376            0 :         &self,
     377            0 :         ctx: &RequestContext,
     378            0 :     ) -> Result<RoleAccessControl, GetAuthInfoError> {
     379            0 :         match self {
     380            0 :             Self::ControlPlane(api, user_info) => {
     381            0 :                 api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
     382            0 :                     .await
     383              :             }
     384            0 :             Self::Local(_) => Ok(RoleAccessControl { secret: None }),
     385              :         }
     386            0 :     }
     387              : 
     388            0 :     pub(crate) async fn get_endpoint_access_control(
     389            0 :         &self,
     390            0 :         ctx: &RequestContext,
     391            0 :     ) -> Result<EndpointAccessControl, GetAuthInfoError> {
     392            0 :         match self {
     393            0 :             Self::ControlPlane(api, user_info) => {
     394            0 :                 api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
     395            0 :                     .await
     396              :             }
     397            0 :             Self::Local(_) => Ok(EndpointAccessControl {
     398            0 :                 allowed_ips: Arc::new(vec![]),
     399            0 :                 allowed_vpce: Arc::new(vec![]),
     400            0 :                 flags: AccessBlockerFlags::default(),
     401            0 :                 rate_limits: EndpointRateLimitConfig::default(),
     402            0 :             }),
     403              :         }
     404            0 :     }
     405              : }
     406              : 
     407              : #[async_trait::async_trait]
     408              : impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
     409           21 :     async fn wake_compute(
     410              :         &self,
     411              :         ctx: &RequestContext,
     412           21 :     ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     413           21 :         match self {
     414           21 :             Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
     415            0 :             Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
     416              :         }
     417           42 :     }
     418              : }
     419              : 
     420              : #[cfg(test)]
     421              : mod tests {
     422              :     #![allow(clippy::unimplemented, clippy::unwrap_used)]
     423              : 
     424              :     use std::sync::Arc;
     425              : 
     426              :     use bytes::BytesMut;
     427              :     use control_plane::AuthSecret;
     428              :     use fallible_iterator::FallibleIterator;
     429              :     use once_cell::sync::Lazy;
     430              :     use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
     431              :     use postgres_protocol::message::backend::Message as PgMessage;
     432              :     use postgres_protocol::message::frontend;
     433              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     434              : 
     435              :     use super::auth_quirks;
     436              :     use super::jwt::JwkCache;
     437              :     use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
     438              :     use crate::cache::node_info::CachedNodeInfo;
     439              :     use crate::config::AuthenticationConfig;
     440              :     use crate::context::RequestContext;
     441              :     use crate::control_plane::messages::EndpointRateLimitConfig;
     442              :     use crate::control_plane::{
     443              :         self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl,
     444              :     };
     445              :     use crate::proxy::NeonOptions;
     446              :     use crate::rate_limiter::EndpointRateLimiter;
     447              :     use crate::scram::ServerSecret;
     448              :     use crate::scram::threadpool::ThreadPool;
     449              :     use crate::stream::{PqStream, Stream};
     450              : 
     451              :     struct Auth {
     452              :         ips: Vec<IpPattern>,
     453              :         vpc_endpoint_ids: Vec<String>,
     454              :         access_blocker_flags: AccessBlockerFlags,
     455              :         secret: AuthSecret,
     456              :     }
     457              : 
     458              :     impl control_plane::ControlPlaneApi for Auth {
     459            3 :         async fn get_role_access_control(
     460            3 :             &self,
     461            3 :             _ctx: &RequestContext,
     462            3 :             _endpoint: &crate::types::EndpointId,
     463            3 :             _role: &crate::types::RoleName,
     464            3 :         ) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
     465            3 :             Ok(RoleAccessControl {
     466            3 :                 secret: Some(self.secret.clone()),
     467            3 :             })
     468            3 :         }
     469              : 
     470            3 :         async fn get_endpoint_access_control(
     471            3 :             &self,
     472            3 :             _ctx: &RequestContext,
     473            3 :             _endpoint: &crate::types::EndpointId,
     474            3 :             _role: &crate::types::RoleName,
     475            3 :         ) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
     476            3 :             Ok(EndpointAccessControl {
     477            3 :                 allowed_ips: Arc::new(self.ips.clone()),
     478            3 :                 allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
     479            3 :                 flags: self.access_blocker_flags,
     480            3 :                 rate_limits: EndpointRateLimitConfig::default(),
     481            3 :             })
     482            3 :         }
     483              : 
     484            0 :         async fn get_endpoint_jwks(
     485            0 :             &self,
     486            0 :             _ctx: &RequestContext,
     487            0 :             _endpoint: &crate::types::EndpointId,
     488            0 :         ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
     489            0 :         {
     490            0 :             unimplemented!()
     491              :         }
     492              : 
     493            0 :         async fn wake_compute(
     494            0 :             &self,
     495            0 :             _ctx: &RequestContext,
     496            0 :             _user_info: &super::ComputeUserInfo,
     497            0 :         ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     498            0 :             unimplemented!()
     499              :         }
     500              :     }
     501              : 
     502              :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     503            3 :         jwks_cache: JwkCache::default(),
     504            3 :         scram_thread_pool: ThreadPool::new(1),
     505            3 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     506              :         ip_allowlist_check_enabled: true,
     507              :         is_vpc_acccess_proxy: false,
     508              :         is_auth_broker: false,
     509              :         accept_jwts: false,
     510            3 :         console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
     511            3 :     });
     512              : 
     513            5 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     514              :         loop {
     515            7 :             r.read_buf(&mut *b).await.unwrap();
     516            7 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     517            5 :                 break m;
     518            2 :             }
     519              :         }
     520            5 :     }
     521              : 
     522              :     #[tokio::test]
     523            1 :     async fn auth_quirks_scram() {
     524            1 :         let (mut client, server) = tokio::io::duplex(1024);
     525            1 :         let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
     526              : 
     527            1 :         let ctx = RequestContext::test();
     528            1 :         let api = Auth {
     529            1 :             ips: vec![],
     530            1 :             vpc_endpoint_ids: vec![],
     531            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     532            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     533              :         };
     534              : 
     535            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     536            1 :             user: "conrad".into(),
     537            1 :             endpoint_id: Some("endpoint".into()),
     538            1 :             options: NeonOptions::default(),
     539            1 :         };
     540              : 
     541            1 :         let handle = tokio::spawn(async move {
     542            1 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     543              : 
     544            1 :             let mut read = BytesMut::new();
     545              : 
     546              :             // server should offer scram
     547            1 :             match read_message(&mut client, &mut read).await {
     548            1 :                 PgMessage::AuthenticationSasl(a) => {
     549            1 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     550            1 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     551              :                 }
     552            0 :                 _ => panic!("wrong message"),
     553              :             }
     554              : 
     555              :             // client sends client-first-message
     556            1 :             let mut write = BytesMut::new();
     557            1 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     558            1 :             client.write_all(&write).await.unwrap();
     559              : 
     560              :             // server response with server-first-message
     561            1 :             match read_message(&mut client, &mut read).await {
     562            1 :                 PgMessage::AuthenticationSaslContinue(a) => {
     563            1 :                     scram.update(a.data()).await.unwrap();
     564              :                 }
     565            0 :                 _ => panic!("wrong message"),
     566              :             }
     567              : 
     568              :             // client response with client-final-message
     569            1 :             write.clear();
     570            1 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     571            1 :             client.write_all(&write).await.unwrap();
     572              : 
     573              :             // server response with server-final-message
     574            1 :             match read_message(&mut client, &mut read).await {
     575            1 :                 PgMessage::AuthenticationSaslFinal(a) => {
     576            1 :                     scram.finish(a.data()).unwrap();
     577            1 :                 }
     578            0 :                 _ => panic!("wrong message"),
     579              :             }
     580            1 :         });
     581            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     582              :             EndpointRateLimiter::DEFAULT,
     583              :             64,
     584              :         ));
     585              : 
     586            1 :         let _creds = auth_quirks(
     587            1 :             &ctx,
     588            1 :             &api,
     589            1 :             user_info,
     590            1 :             &mut stream,
     591            1 :             false,
     592            1 :             &CONFIG,
     593            1 :             endpoint_rate_limiter,
     594            1 :         )
     595            1 :         .await
     596            1 :         .unwrap();
     597              : 
     598              :         // flush the final server message
     599            1 :         stream.flush().await.unwrap();
     600              : 
     601            1 :         handle.await.unwrap();
     602            1 :     }
     603              : 
     604              :     #[tokio::test]
     605            1 :     async fn auth_quirks_cleartext() {
     606            1 :         let (mut client, server) = tokio::io::duplex(1024);
     607            1 :         let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
     608              : 
     609            1 :         let ctx = RequestContext::test();
     610            1 :         let api = Auth {
     611            1 :             ips: vec![],
     612            1 :             vpc_endpoint_ids: vec![],
     613            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     614            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     615              :         };
     616              : 
     617            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     618            1 :             user: "conrad".into(),
     619            1 :             endpoint_id: Some("endpoint".into()),
     620            1 :             options: NeonOptions::default(),
     621            1 :         };
     622              : 
     623            1 :         let handle = tokio::spawn(async move {
     624            1 :             let mut read = BytesMut::new();
     625            1 :             let mut write = BytesMut::new();
     626              : 
     627              :             // server should offer cleartext
     628            1 :             match read_message(&mut client, &mut read).await {
     629            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     630            0 :                 _ => panic!("wrong message"),
     631              :             }
     632              : 
     633              :             // client responds with password
     634            1 :             write.clear();
     635            1 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     636            1 :             client.write_all(&write).await.unwrap();
     637            1 :         });
     638            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     639              :             EndpointRateLimiter::DEFAULT,
     640              :             64,
     641              :         ));
     642              : 
     643            1 :         let _creds = auth_quirks(
     644            1 :             &ctx,
     645            1 :             &api,
     646            1 :             user_info,
     647            1 :             &mut stream,
     648            1 :             true,
     649            1 :             &CONFIG,
     650            1 :             endpoint_rate_limiter,
     651            1 :         )
     652            1 :         .await
     653            1 :         .unwrap();
     654              : 
     655            1 :         handle.await.unwrap();
     656            1 :     }
     657              : 
     658              :     #[tokio::test]
     659            1 :     async fn auth_quirks_password_hack() {
     660            1 :         let (mut client, server) = tokio::io::duplex(1024);
     661            1 :         let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
     662              : 
     663            1 :         let ctx = RequestContext::test();
     664            1 :         let api = Auth {
     665            1 :             ips: vec![],
     666            1 :             vpc_endpoint_ids: vec![],
     667            1 :             access_blocker_flags: AccessBlockerFlags::default(),
     668            1 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     669              :         };
     670              : 
     671            1 :         let user_info = ComputeUserInfoMaybeEndpoint {
     672            1 :             user: "conrad".into(),
     673            1 :             endpoint_id: None,
     674            1 :             options: NeonOptions::default(),
     675            1 :         };
     676              : 
     677            1 :         let handle = tokio::spawn(async move {
     678            1 :             let mut read = BytesMut::new();
     679              : 
     680              :             // server should offer cleartext
     681            1 :             match read_message(&mut client, &mut read).await {
     682            1 :                 PgMessage::AuthenticationCleartextPassword => {}
     683            0 :                 _ => panic!("wrong message"),
     684              :             }
     685              : 
     686              :             // client responds with password
     687            1 :             let mut write = BytesMut::new();
     688            1 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     689            1 :                 .unwrap();
     690            1 :             client.write_all(&write).await.unwrap();
     691            1 :         });
     692              : 
     693            1 :         let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
     694              :             EndpointRateLimiter::DEFAULT,
     695              :             64,
     696              :         ));
     697              : 
     698            1 :         let creds = auth_quirks(
     699            1 :             &ctx,
     700            1 :             &api,
     701            1 :             user_info,
     702            1 :             &mut stream,
     703            1 :             true,
     704            1 :             &CONFIG,
     705            1 :             endpoint_rate_limiter,
     706            1 :         )
     707            1 :         .await
     708            1 :         .unwrap();
     709              : 
     710            1 :         assert_eq!(creds.info.endpoint, "my-endpoint");
     711              : 
     712            1 :         handle.await.unwrap();
     713            1 :     }
     714              : }
        

Generated by: LCOV version 2.1-beta