LCOV - code coverage report
Current view: top level - proxy/src/auth - backend.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 69.7 % 409 285
Test Date: 2024-04-08 10:22:05 Functions: 26.6 % 94 25

            Line data    Source code
       1              : mod classic;
       2              : mod hacks;
       3              : mod link;
       4              : 
       5              : pub use link::LinkAuthError;
       6              : use tokio_postgres::config::AuthKeys;
       7              : 
       8              : use crate::auth::credentials::check_peer_addr_is_in_list;
       9              : use crate::auth::validate_password_and_exchange;
      10              : use crate::cache::Cached;
      11              : use crate::console::errors::GetAuthInfoError;
      12              : use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
      13              : use crate::console::{AuthSecret, NodeInfo};
      14              : use crate::context::RequestMonitoring;
      15              : use crate::intern::EndpointIdInt;
      16              : use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED};
      17              : use crate::proxy::connect_compute::ComputeConnectBackend;
      18              : use crate::proxy::NeonOptions;
      19              : use crate::stream::Stream;
      20              : use crate::{
      21              :     auth::{self, ComputeUserInfoMaybeEndpoint},
      22              :     config::AuthenticationConfig,
      23              :     console::{
      24              :         self,
      25              :         provider::{CachedAllowedIps, CachedNodeInfo},
      26              :         Api,
      27              :     },
      28              :     stream, url,
      29              : };
      30              : use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
      31              : use std::sync::Arc;
      32              : use tokio::io::{AsyncRead, AsyncWrite};
      33              : use tracing::{info, warn};
      34              : 
      35              : /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
      36              : pub enum MaybeOwned<'a, T> {
      37              :     Owned(T),
      38              :     Borrowed(&'a T),
      39              : }
      40              : 
      41              : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
      42              :     type Target = T;
      43              : 
      44           26 :     fn deref(&self) -> &Self::Target {
      45           26 :         match self {
      46           26 :             MaybeOwned::Owned(t) => t,
      47            0 :             MaybeOwned::Borrowed(t) => t,
      48              :         }
      49           26 :     }
      50              : }
      51              : 
      52              : /// This type serves two purposes:
      53              : ///
      54              : /// * When `T` is `()`, it's just a regular auth backend selector
      55              : ///   which we use in [`crate::config::ProxyConfig`].
      56              : ///
      57              : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
      58              : ///   this helps us provide the credentials only to those auth
      59              : ///   backends which require them for the authentication process.
      60              : pub enum BackendType<'a, T, D> {
      61              :     /// Cloud API (V2).
      62              :     Console(MaybeOwned<'a, ConsoleBackend>, T),
      63              :     /// Authentication via a web browser.
      64              :     Link(MaybeOwned<'a, url::ApiUrl>, D),
      65              : }
      66              : 
      67              : pub trait TestBackend: Send + Sync + 'static {
      68              :     fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
      69              :     fn get_allowed_ips_and_secret(
      70              :         &self,
      71              :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
      72              :     fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
      73              : }
      74              : 
      75              : impl std::fmt::Display for BackendType<'_, (), ()> {
      76            0 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      77            0 :         use BackendType::*;
      78            0 :         match self {
      79            0 :             Console(api, _) => match &**api {
      80            0 :                 ConsoleBackend::Console(endpoint) => {
      81            0 :                     fmt.debug_tuple("Console").field(&endpoint.url()).finish()
      82              :                 }
      83              :                 #[cfg(any(test, feature = "testing"))]
      84            0 :                 ConsoleBackend::Postgres(endpoint) => {
      85            0 :                     fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
      86              :                 }
      87              :                 #[cfg(test)]
      88            0 :                 ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
      89              :             },
      90            0 :             Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
      91              :         }
      92            0 :     }
      93              : }
      94              : 
      95              : impl<T, D> BackendType<'_, T, D> {
      96              :     /// Very similar to [`std::option::Option::as_ref`].
      97              :     /// This helps us pass structured config to async tasks.
      98            0 :     pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
      99            0 :         use BackendType::*;
     100            0 :         match self {
     101            0 :             Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
     102            0 :             Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
     103              :         }
     104            0 :     }
     105              : }
     106              : 
     107              : impl<'a, T, D> BackendType<'a, T, D> {
     108              :     /// Very similar to [`std::option::Option::map`].
     109              :     /// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
     110              :     /// a function to a contained value.
     111            0 :     pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
     112            0 :         use BackendType::*;
     113            0 :         match self {
     114            0 :             Console(c, x) => Console(c, f(x)),
     115            0 :             Link(c, x) => Link(c, x),
     116              :         }
     117            0 :     }
     118              : }
     119              : impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
     120              :     /// Very similar to [`std::option::Option::transpose`].
     121              :     /// This is most useful for error handling.
     122            0 :     pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
     123            0 :         use BackendType::*;
     124            0 :         match self {
     125            0 :             Console(c, x) => x.map(|x| Console(c, x)),
     126            0 :             Link(c, x) => Ok(Link(c, x)),
     127              :         }
     128            0 :     }
     129              : }
     130              : 
     131              : pub struct ComputeCredentials {
     132              :     pub info: ComputeUserInfo,
     133              :     pub keys: ComputeCredentialKeys,
     134              : }
     135              : 
     136              : #[derive(Debug, Clone)]
     137              : pub struct ComputeUserInfoNoEndpoint {
     138              :     pub user: RoleName,
     139              :     pub options: NeonOptions,
     140              : }
     141              : 
     142              : #[derive(Debug, Clone)]
     143              : pub struct ComputeUserInfo {
     144              :     pub endpoint: EndpointId,
     145              :     pub user: RoleName,
     146              :     pub options: NeonOptions,
     147              : }
     148              : 
     149              : impl ComputeUserInfo {
     150            4 :     pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
     151            4 :         self.options.get_cache_key(&self.endpoint)
     152            4 :     }
     153              : }
     154              : 
     155              : pub enum ComputeCredentialKeys {
     156              :     Password(Vec<u8>),
     157              :     AuthKeys(AuthKeys),
     158              : }
     159              : 
     160              : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     161              :     // user name
     162              :     type Error = ComputeUserInfoNoEndpoint;
     163              : 
     164            6 :     fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
     165            6 :         match user_info.endpoint_id {
     166            2 :             None => Err(ComputeUserInfoNoEndpoint {
     167            2 :                 user: user_info.user,
     168            2 :                 options: user_info.options,
     169            2 :             }),
     170            4 :             Some(endpoint) => Ok(ComputeUserInfo {
     171            4 :                 endpoint,
     172            4 :                 user: user_info.user,
     173            4 :                 options: user_info.options,
     174            4 :             }),
     175              :         }
     176            6 :     }
     177              : }
     178              : 
     179              : impl AuthenticationConfig {
     180            6 :     pub fn check_rate_limit(
     181            6 :         &self,
     182            6 : 
     183            6 :         ctx: &mut RequestMonitoring,
     184            6 :         secret: AuthSecret,
     185            6 :         endpoint: &EndpointId,
     186            6 :         is_cleartext: bool,
     187            6 :     ) -> auth::Result<AuthSecret> {
     188            6 :         // we have validated the endpoint exists, so let's intern it.
     189            6 :         let endpoint_int = EndpointIdInt::from(endpoint);
     190              : 
     191              :         // only count the full hash count if password hack or websocket flow.
     192              :         // in other words, if proxy needs to run the hashing
     193            6 :         let password_weight = if is_cleartext {
     194            4 :             match &secret {
     195              :                 #[cfg(any(test, feature = "testing"))]
     196            0 :                 AuthSecret::Md5(_) => 1,
     197            4 :                 AuthSecret::Scram(s) => s.iterations + 1,
     198              :             }
     199              :         } else {
     200              :             // validating scram takes just 1 hmac_sha_256 operation.
     201            2 :             1
     202              :         };
     203              : 
     204            6 :         let limit_not_exceeded = self
     205            6 :             .rate_limiter
     206            6 :             .check((endpoint_int, ctx.peer_addr), password_weight);
     207            6 : 
     208            6 :         if !limit_not_exceeded {
     209            0 :             warn!(
     210            0 :                 enabled = self.rate_limiter_enabled,
     211            0 :                 "rate limiting authentication"
     212            0 :             );
     213            0 :             AUTH_RATE_LIMIT_HITS.inc();
     214            0 :             ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint);
     215            0 : 
     216            0 :             if self.rate_limiter_enabled {
     217            0 :                 return Err(auth::AuthError::too_many_connections());
     218            0 :             }
     219            6 :         }
     220              : 
     221            6 :         Ok(secret)
     222            6 :     }
     223              : }
     224              : 
     225              : /// True to its name, this function encapsulates our current auth trade-offs.
     226              : /// Here, we choose the appropriate auth flow based on circumstances.
     227              : ///
     228              : /// All authentication flows will emit an AuthenticationOk message if successful.
     229            6 : async fn auth_quirks(
     230            6 :     ctx: &mut RequestMonitoring,
     231            6 :     api: &impl console::Api,
     232            6 :     user_info: ComputeUserInfoMaybeEndpoint,
     233            6 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     234            6 :     allow_cleartext: bool,
     235            6 :     config: &'static AuthenticationConfig,
     236            6 : ) -> auth::Result<ComputeCredentials> {
     237              :     // If there's no project so far, that entails that client doesn't
     238              :     // support SNI or other means of passing the endpoint (project) name.
     239              :     // We now expect to see a very specific payload in the place of password.
     240            6 :     let (info, unauthenticated_password) = match user_info.try_into() {
     241            2 :         Err(info) => {
     242            2 :             let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
     243              : 
     244            2 :             ctx.set_endpoint_id(res.info.endpoint.clone());
     245            2 :             let password = match res.keys {
     246            2 :                 ComputeCredentialKeys::Password(p) => p,
     247            0 :                 _ => unreachable!("password hack should return a password"),
     248              :             };
     249            2 :             (res.info, Some(password))
     250              :         }
     251            4 :         Ok(info) => (info, None),
     252              :     };
     253              : 
     254            6 :     info!("fetching user's authentication info");
     255            6 :     let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
     256              : 
     257              :     // check allowed list
     258            6 :     if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
     259            0 :         return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
     260            6 :     }
     261            6 :     let cached_secret = match maybe_secret {
     262            6 :         Some(secret) => secret,
     263            0 :         None => api.get_role_secret(ctx, &info).await?,
     264              :     };
     265            6 :     let (cached_entry, secret) = cached_secret.take_value();
     266              : 
     267            6 :     let secret = match secret {
     268            6 :         Some(secret) => config.check_rate_limit(
     269            6 :             ctx,
     270            6 :             secret,
     271            6 :             &info.endpoint,
     272            6 :             unauthenticated_password.is_some() || allow_cleartext,
     273            0 :         )?,
     274              :         None => {
     275              :             // If we don't have an authentication secret, we mock one to
     276              :             // prevent malicious probing (possible due to missing protocol steps).
     277              :             // This mocked secret will never lead to successful authentication.
     278            0 :             info!("authentication info not found, mocking it");
     279            0 :             AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
     280              :         }
     281              :     };
     282              : 
     283            6 :     match authenticate_with_secret(
     284            6 :         ctx,
     285            6 :         secret,
     286            6 :         info,
     287            6 :         client,
     288            6 :         unauthenticated_password,
     289            6 :         allow_cleartext,
     290            6 :         config,
     291            6 :     )
     292           18 :     .await
     293              :     {
     294            6 :         Ok(keys) => Ok(keys),
     295            0 :         Err(e) => {
     296            0 :             if e.is_auth_failed() {
     297            0 :                 // The password could have been changed, so we invalidate the cache.
     298            0 :                 cached_entry.invalidate();
     299            0 :             }
     300            0 :             Err(e)
     301              :         }
     302              :     }
     303            6 : }
     304              : 
     305            6 : async fn authenticate_with_secret(
     306            6 :     ctx: &mut RequestMonitoring,
     307            6 :     secret: AuthSecret,
     308            6 :     info: ComputeUserInfo,
     309            6 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     310            6 :     unauthenticated_password: Option<Vec<u8>>,
     311            6 :     allow_cleartext: bool,
     312            6 :     config: &'static AuthenticationConfig,
     313            6 : ) -> auth::Result<ComputeCredentials> {
     314            6 :     if let Some(password) = unauthenticated_password {
     315            6 :         let auth_outcome = validate_password_and_exchange(&password, secret).await?;
     316            2 :         let keys = match auth_outcome {
     317            2 :             crate::sasl::Outcome::Success(key) => key,
     318            0 :             crate::sasl::Outcome::Failure(reason) => {
     319            0 :                 info!("auth backend failed with an error: {reason}");
     320            0 :                 return Err(auth::AuthError::auth_failed(&*info.user));
     321              :             }
     322              :         };
     323              : 
     324              :         // we have authenticated the password
     325            2 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     326              : 
     327            2 :         return Ok(ComputeCredentials { info, keys });
     328            4 :     }
     329            4 : 
     330            4 :     // -- the remaining flows are self-authenticating --
     331            4 : 
     332            4 :     // Perform cleartext auth if we're allowed to do that.
     333            4 :     // Currently, we use it for websocket connections (latency).
     334            4 :     if allow_cleartext {
     335            2 :         ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     336            8 :         return hacks::authenticate_cleartext(ctx, info, client, secret).await;
     337            2 :     }
     338            2 : 
     339            2 :     // Finally, proceed with the main auth flow (SCRAM-based).
     340            4 :     classic::authenticate(ctx, info, client, config, secret).await
     341            6 : }
     342              : 
     343              : impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
     344              :     /// Get compute endpoint name from the credentials.
     345            0 :     pub fn get_endpoint(&self) -> Option<EndpointId> {
     346            0 :         use BackendType::*;
     347            0 : 
     348            0 :         match self {
     349            0 :             Console(_, user_info) => user_info.endpoint_id.clone(),
     350            0 :             Link(_, _) => Some("link".into()),
     351              :         }
     352            0 :     }
     353              : 
     354              :     /// Get username from the credentials.
     355            0 :     pub fn get_user(&self) -> &str {
     356            0 :         use BackendType::*;
     357            0 : 
     358            0 :         match self {
     359            0 :             Console(_, user_info) => &user_info.user,
     360            0 :             Link(_, _) => "link",
     361              :         }
     362            0 :     }
     363              : 
     364              :     /// Authenticate the client via the requested backend, possibly using credentials.
     365            0 :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     366              :     pub async fn authenticate(
     367              :         self,
     368              :         ctx: &mut RequestMonitoring,
     369              :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     370              :         allow_cleartext: bool,
     371              :         config: &'static AuthenticationConfig,
     372              :     ) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
     373              :         use BackendType::*;
     374              : 
     375              :         let res = match self {
     376              :             Console(api, user_info) => {
     377            0 :                 info!(
     378            0 :                     user = &*user_info.user,
     379            0 :                     project = user_info.endpoint(),
     380            0 :                     "performing authentication using the console"
     381            0 :                 );
     382              : 
     383              :                 let credentials =
     384              :                     auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
     385              :                 BackendType::Console(api, credentials)
     386              :             }
     387              :             // NOTE: this auth backend doesn't use client credentials.
     388              :             Link(url, _) => {
     389            0 :                 info!("performing link authentication");
     390              : 
     391              :                 let info = link::authenticate(ctx, &url, client).await?;
     392              : 
     393              :                 BackendType::Link(url, info)
     394              :             }
     395              :         };
     396              : 
     397            0 :         info!("user successfully authenticated");
     398              :         Ok(res)
     399              :     }
     400              : }
     401              : 
     402              : impl BackendType<'_, ComputeUserInfo, &()> {
     403            0 :     pub async fn get_role_secret(
     404            0 :         &self,
     405            0 :         ctx: &mut RequestMonitoring,
     406            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     407            0 :         use BackendType::*;
     408            0 :         match self {
     409            0 :             Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
     410            0 :             Link(_, _) => Ok(Cached::new_uncached(None)),
     411              :         }
     412            0 :     }
     413              : 
     414            0 :     pub async fn get_allowed_ips_and_secret(
     415            0 :         &self,
     416            0 :         ctx: &mut RequestMonitoring,
     417            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     418            0 :         use BackendType::*;
     419            0 :         match self {
     420            0 :             Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
     421            0 :             Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
     422              :         }
     423            0 :     }
     424              : }
     425              : 
     426              : #[async_trait::async_trait]
     427              : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
     428            0 :     async fn wake_compute(
     429            0 :         &self,
     430            0 :         ctx: &mut RequestMonitoring,
     431            0 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     432              :         use BackendType::*;
     433              : 
     434            0 :         match self {
     435            0 :             Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     436            0 :             Link(_, info) => Ok(Cached::new_uncached(info.clone())),
     437              :         }
     438            0 :     }
     439              : 
     440            0 :     fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
     441            0 :         match self {
     442            0 :             BackendType::Console(_, creds) => Some(&creds.keys),
     443            0 :             BackendType::Link(_, _) => None,
     444              :         }
     445            0 :     }
     446              : }
     447              : 
     448              : #[async_trait::async_trait]
     449              : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
     450           26 :     async fn wake_compute(
     451           26 :         &self,
     452           26 :         ctx: &mut RequestMonitoring,
     453           26 :     ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     454              :         use BackendType::*;
     455              : 
     456           26 :         match self {
     457           26 :             Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
     458            0 :             Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
     459              :         }
     460           78 :     }
     461              : 
     462           12 :     fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
     463           12 :         match self {
     464           12 :             BackendType::Console(_, creds) => Some(&creds.keys),
     465            0 :             BackendType::Link(_, _) => None,
     466              :         }
     467           12 :     }
     468              : }
     469              : 
     470              : #[cfg(test)]
     471              : mod tests {
     472              :     use std::sync::Arc;
     473              : 
     474              :     use bytes::BytesMut;
     475              :     use fallible_iterator::FallibleIterator;
     476              :     use once_cell::sync::Lazy;
     477              :     use postgres_protocol::{
     478              :         authentication::sasl::{ChannelBinding, ScramSha256},
     479              :         message::{backend::Message as PgMessage, frontend},
     480              :     };
     481              :     use provider::AuthSecret;
     482              :     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
     483              : 
     484              :     use crate::{
     485              :         auth::{ComputeUserInfoMaybeEndpoint, IpPattern},
     486              :         config::AuthenticationConfig,
     487              :         console::{
     488              :             self,
     489              :             provider::{self, CachedAllowedIps, CachedRoleSecret},
     490              :             CachedNodeInfo,
     491              :         },
     492              :         context::RequestMonitoring,
     493              :         proxy::NeonOptions,
     494              :         rate_limiter::{AuthRateLimiter, RateBucketInfo},
     495              :         scram::ServerSecret,
     496              :         stream::{PqStream, Stream},
     497              :     };
     498              : 
     499              :     use super::auth_quirks;
     500              : 
     501              :     struct Auth {
     502              :         ips: Vec<IpPattern>,
     503              :         secret: AuthSecret,
     504              :     }
     505              : 
     506              :     impl console::Api for Auth {
     507            0 :         async fn get_role_secret(
     508            0 :             &self,
     509            0 :             _ctx: &mut RequestMonitoring,
     510            0 :             _user_info: &super::ComputeUserInfo,
     511            0 :         ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
     512            0 :             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
     513            0 :         }
     514              : 
     515            6 :         async fn get_allowed_ips_and_secret(
     516            6 :             &self,
     517            6 :             _ctx: &mut RequestMonitoring,
     518            6 :             _user_info: &super::ComputeUserInfo,
     519            6 :         ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
     520            6 :         {
     521            6 :             Ok((
     522            6 :                 CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
     523            6 :                 Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
     524            6 :             ))
     525            6 :         }
     526              : 
     527            0 :         async fn wake_compute(
     528            0 :             &self,
     529            0 :             _ctx: &mut RequestMonitoring,
     530            0 :             _user_info: &super::ComputeUserInfo,
     531            0 :         ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
     532            0 :             unimplemented!()
     533              :         }
     534              :     }
     535              : 
     536            6 :     static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
     537            6 :         scram_protocol_timeout: std::time::Duration::from_secs(5),
     538            6 :         rate_limiter_enabled: true,
     539            6 :         rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
     540            6 :     });
     541              : 
     542           10 :     async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
     543              :         loop {
     544           14 :             r.read_buf(&mut *b).await.unwrap();
     545           14 :             if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
     546           10 :                 break m;
     547            4 :             }
     548              :         }
     549           10 :     }
     550              : 
     551              :     #[tokio::test]
     552            2 :     async fn auth_quirks_scram() {
     553            2 :         let (mut client, server) = tokio::io::duplex(1024);
     554            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     555            2 : 
     556            2 :         let mut ctx = RequestMonitoring::test();
     557            2 :         let api = Auth {
     558            2 :             ips: vec![],
     559            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     560            2 :         };
     561            2 : 
     562            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     563            2 :             user: "conrad".into(),
     564            2 :             endpoint_id: Some("endpoint".into()),
     565            2 :             options: NeonOptions::default(),
     566            2 :         };
     567            2 : 
     568            2 :         let handle = tokio::spawn(async move {
     569            2 :             let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
     570            2 : 
     571            2 :             let mut read = BytesMut::new();
     572            2 : 
     573            2 :             // server should offer scram
     574            2 :             match read_message(&mut client, &mut read).await {
     575            2 :                 PgMessage::AuthenticationSasl(a) => {
     576            2 :                     let options: Vec<&str> = a.mechanisms().collect().unwrap();
     577            2 :                     assert_eq!(options, ["SCRAM-SHA-256"]);
     578            2 :                 }
     579            2 :                 _ => panic!("wrong message"),
     580            2 :             }
     581            2 : 
     582            2 :             // client sends client-first-message
     583            2 :             let mut write = BytesMut::new();
     584            2 :             frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
     585            2 :             client.write_all(&write).await.unwrap();
     586            2 : 
     587            2 :             // server response with server-first-message
     588            2 :             match read_message(&mut client, &mut read).await {
     589            2 :                 PgMessage::AuthenticationSaslContinue(a) => {
     590            6 :                     scram.update(a.data()).await.unwrap();
     591            2 :                 }
     592            2 :                 _ => panic!("wrong message"),
     593            2 :             }
     594            2 : 
     595            2 :             // client response with client-final-message
     596            2 :             write.clear();
     597            2 :             frontend::sasl_response(scram.message(), &mut write).unwrap();
     598            2 :             client.write_all(&write).await.unwrap();
     599            2 : 
     600            2 :             // server response with server-final-message
     601            2 :             match read_message(&mut client, &mut read).await {
     602            2 :                 PgMessage::AuthenticationSaslFinal(a) => {
     603            2 :                     scram.finish(a.data()).unwrap();
     604            2 :                 }
     605            2 :                 _ => panic!("wrong message"),
     606            2 :             }
     607            2 :         });
     608            2 : 
     609            2 :         let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
     610            4 :             .await
     611            2 :             .unwrap();
     612            2 : 
     613            2 :         handle.await.unwrap();
     614            2 :     }
     615              : 
     616              :     #[tokio::test]
     617            2 :     async fn auth_quirks_cleartext() {
     618            2 :         let (mut client, server) = tokio::io::duplex(1024);
     619            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     620            2 : 
     621            2 :         let mut ctx = RequestMonitoring::test();
     622            2 :         let api = Auth {
     623            2 :             ips: vec![],
     624            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     625            2 :         };
     626            2 : 
     627            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     628            2 :             user: "conrad".into(),
     629            2 :             endpoint_id: Some("endpoint".into()),
     630            2 :             options: NeonOptions::default(),
     631            2 :         };
     632            2 : 
     633            2 :         let handle = tokio::spawn(async move {
     634            2 :             let mut read = BytesMut::new();
     635            2 :             let mut write = BytesMut::new();
     636            2 : 
     637            2 :             // server should offer cleartext
     638            2 :             match read_message(&mut client, &mut read).await {
     639            2 :                 PgMessage::AuthenticationCleartextPassword => {}
     640            2 :                 _ => panic!("wrong message"),
     641            2 :             }
     642            2 : 
     643            2 :             // client responds with password
     644            2 :             write.clear();
     645            2 :             frontend::password_message(b"my-secret-password", &mut write).unwrap();
     646            2 :             client.write_all(&write).await.unwrap();
     647            2 :         });
     648            2 : 
     649            2 :         let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
     650            8 :             .await
     651            2 :             .unwrap();
     652            2 : 
     653            2 :         handle.await.unwrap();
     654            2 :     }
     655              : 
     656              :     #[tokio::test]
     657            2 :     async fn auth_quirks_password_hack() {
     658            2 :         let (mut client, server) = tokio::io::duplex(1024);
     659            2 :         let mut stream = PqStream::new(Stream::from_raw(server));
     660            2 : 
     661            2 :         let mut ctx = RequestMonitoring::test();
     662            2 :         let api = Auth {
     663            2 :             ips: vec![],
     664            6 :             secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
     665            2 :         };
     666            2 : 
     667            2 :         let user_info = ComputeUserInfoMaybeEndpoint {
     668            2 :             user: "conrad".into(),
     669            2 :             endpoint_id: None,
     670            2 :             options: NeonOptions::default(),
     671            2 :         };
     672            2 : 
     673            2 :         let handle = tokio::spawn(async move {
     674            2 :             let mut read = BytesMut::new();
     675            2 : 
     676            2 :             // server should offer cleartext
     677            2 :             match read_message(&mut client, &mut read).await {
     678            2 :                 PgMessage::AuthenticationCleartextPassword => {}
     679            2 :                 _ => panic!("wrong message"),
     680            2 :             }
     681            2 : 
     682            2 :             // client responds with password
     683            2 :             let mut write = BytesMut::new();
     684            2 :             frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
     685            2 :                 .unwrap();
     686            2 :             client.write_all(&write).await.unwrap();
     687            2 :         });
     688            2 : 
     689            2 :         let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
     690            8 :             .await
     691            2 :             .unwrap();
     692            2 : 
     693            2 :         assert_eq!(creds.info.endpoint, "my-endpoint");
     694            2 : 
     695            2 :         handle.await.unwrap();
     696            2 :     }
     697              : }
        

Generated by: LCOV version 2.1-beta