LCOV - differential code coverage report
Current view: top level - proxy/src/auth - backend.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 86.8 % 205 178 27 178
Current Date: 2024-01-09 02:06:09 Functions: 35.9 % 92 33 59 33
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : mod classic;
       2                 : mod hacks;
       3                 : mod link;
       4                 : 
       5                 : pub use link::LinkAuthError;
       6                 : use smol_str::SmolStr;
       7                 : use tokio_postgres::config::AuthKeys;
       8                 : 
       9                 : use crate::auth::credentials::check_peer_addr_is_in_list;
      10                 : use crate::auth::validate_password_and_exchange;
      11                 : use crate::console::errors::GetAuthInfoError;
      12                 : use crate::console::AuthSecret;
      13                 : use crate::context::RequestMonitoring;
      14                 : use crate::proxy::connect_compute::handle_try_wake;
      15                 : use crate::proxy::retry::retry_after;
      16                 : use crate::scram;
      17                 : use crate::stream::Stream;
      18                 : use crate::{
      19                 :     auth::{self, ClientCredentials},
      20                 :     config::AuthenticationConfig,
      21                 :     console::{
      22                 :         self,
      23                 :         provider::{CachedNodeInfo, ConsoleReqExtra},
      24                 :         Api,
      25                 :     },
      26                 :     stream, url,
      27                 : };
      28                 : use futures::TryFutureExt;
      29                 : use std::borrow::Cow;
      30                 : use std::ops::ControlFlow;
      31                 : use std::sync::Arc;
      32                 : use tokio::io::{AsyncRead, AsyncWrite};
      33                 : use tracing::{error, info, warn};
      34                 : 
      35                 : /// This type serves two purposes:
      36                 : ///
      37                 : /// * When `T` is `()`, it's just a regular auth backend selector
      38                 : ///   which we use in [`crate::config::ProxyConfig`].
      39                 : ///
      40                 : /// * However, when we substitute `T` with [`ClientCredentials`],
      41                 : ///   this helps us provide the credentials only to those auth
      42                 : ///   backends which require them for the authentication process.
      43                 : pub enum BackendType<'a, T> {
      44                 :     /// Current Cloud API (V2).
      45                 :     Console(Cow<'a, console::provider::neon::Api>, T),
      46                 :     /// Local mock of Cloud API (V2).
      47                 :     #[cfg(feature = "testing")]
      48                 :     Postgres(Cow<'a, console::provider::mock::Api>, T),
      49                 :     /// Authentication via a web browser.
      50                 :     Link(Cow<'a, url::ApiUrl>),
      51                 :     #[cfg(test)]
      52                 :     /// Test backend.
      53                 :     Test(&'a dyn TestBackend),
      54                 : }
      55                 : 
      56                 : pub trait TestBackend: Send + Sync + 'static {
      57                 :     fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
      58                 :     fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
      59                 : }
      60                 : 
      61                 : impl std::fmt::Display for BackendType<'_, ()> {
      62 CBC          22 :     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      63              22 :         use BackendType::*;
      64              22 :         match self {
      65               1 :             Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
      66                 :             #[cfg(feature = "testing")]
      67              18 :             Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
      68               3 :             Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
      69                 :             #[cfg(test)]
      70 UBC           0 :             Test(_) => fmt.debug_tuple("Test").finish(),
      71                 :         }
      72 CBC          22 :     }
      73                 : }
      74                 : 
      75                 : impl<T> BackendType<'_, T> {
      76                 :     /// Very similar to [`std::option::Option::as_ref`].
      77                 :     /// This helps us pass structured config to async tasks.
      78              90 :     pub fn as_ref(&self) -> BackendType<'_, &T> {
      79              90 :         use BackendType::*;
      80              90 :         match self {
      81               4 :             Console(c, x) => Console(Cow::Borrowed(c), x),
      82                 :             #[cfg(feature = "testing")]
      83              83 :             Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
      84               3 :             Link(c) => Link(Cow::Borrowed(c)),
      85                 :             #[cfg(test)]
      86 UBC           0 :             Test(x) => Test(*x),
      87                 :         }
      88 CBC          90 :     }
      89                 : }
      90                 : 
      91                 : impl<'a, T> BackendType<'a, T> {
      92                 :     /// Very similar to [`std::option::Option::map`].
      93                 :     /// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
      94                 :     /// a function to a contained value.
      95              90 :     pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
      96              90 :         use BackendType::*;
      97              90 :         match self {
      98               4 :             Console(c, x) => Console(c, f(x)),
      99                 :             #[cfg(feature = "testing")]
     100              83 :             Postgres(c, x) => Postgres(c, f(x)),
     101               3 :             Link(c) => Link(c),
     102                 :             #[cfg(test)]
     103 UBC           0 :             Test(x) => Test(x),
     104                 :         }
     105 CBC          90 :     }
     106                 : }
     107                 : 
     108                 : impl<'a, T, E> BackendType<'a, Result<T, E>> {
     109                 :     /// Very similar to [`std::option::Option::transpose`].
     110                 :     /// This is most useful for error handling.
     111              49 :     pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
     112              49 :         use BackendType::*;
     113              49 :         match self {
     114               4 :             Console(c, x) => x.map(|x| Console(c, x)),
     115                 :             #[cfg(feature = "testing")]
     116              42 :             Postgres(c, x) => x.map(|x| Postgres(c, x)),
     117               3 :             Link(c) => Ok(Link(c)),
     118                 :             #[cfg(test)]
     119 UBC           0 :             Test(x) => Ok(Test(x)),
     120                 :         }
     121 CBC          49 :     }
     122                 : }
     123                 : 
     124                 : pub struct ComputeCredentials<T> {
     125                 :     pub info: ComputeUserInfo,
     126                 :     pub keys: T,
     127                 : }
     128                 : 
     129                 : pub struct ComputeUserInfoNoEndpoint {
     130                 :     pub user: SmolStr,
     131                 :     pub cache_key: SmolStr,
     132                 : }
     133                 : 
     134                 : pub struct ComputeUserInfo {
     135                 :     pub endpoint: SmolStr,
     136                 :     pub inner: ComputeUserInfoNoEndpoint,
     137                 : }
     138                 : 
     139                 : pub enum ComputeCredentialKeys {
     140                 :     #[cfg(feature = "testing")]
     141                 :     Password(Vec<u8>),
     142                 :     AuthKeys(AuthKeys),
     143                 : }
     144                 : 
     145                 : impl TryFrom<ClientCredentials> for ComputeUserInfo {
     146                 :     // user name
     147                 :     type Error = ComputeUserInfoNoEndpoint;
     148                 : 
     149              87 :     fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
     150              87 :         let inner = ComputeUserInfoNoEndpoint {
     151              87 :             user: creds.user,
     152              87 :             cache_key: creds.cache_key,
     153              87 :         };
     154              87 :         match creds.project {
     155               3 :             None => Err(inner),
     156              84 :             Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }),
     157                 :         }
     158              87 :     }
     159                 : }
     160                 : 
     161                 : /// True to its name, this function encapsulates our current auth trade-offs.
     162                 : /// Here, we choose the appropriate auth flow based on circumstances.
     163                 : ///
     164                 : /// All authentication flows will emit an AuthenticationOk message if successful.
     165              46 : async fn auth_quirks(
     166              46 :     ctx: &mut RequestMonitoring,
     167              46 :     api: &impl console::Api,
     168              46 :     creds: ClientCredentials,
     169              46 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     170              46 :     allow_cleartext: bool,
     171              46 :     config: &'static AuthenticationConfig,
     172              46 : ) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
     173                 :     // If there's no project so far, that entails that client doesn't
     174                 :     // support SNI or other means of passing the endpoint (project) name.
     175                 :     // We now expect to see a very specific payload in the place of password.
     176              46 :     let (info, unauthenticated_password) = match creds.try_into() {
     177               3 :         Err(info) => {
     178               3 :             let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
     179               3 :                 .await?;
     180               2 :             ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
     181               2 :             (res.info, Some(res.keys))
     182                 :         }
     183              43 :         Ok(info) => (info, None),
     184                 :     };
     185                 : 
     186              45 :     info!("fetching user's authentication info");
     187             279 :     let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
     188                 : 
     189                 :     // check allowed list
     190              41 :     if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
     191               3 :         return Err(auth::AuthError::ip_address_not_allowed());
     192              38 :     }
     193             264 :     let cached_secret = api.get_role_secret(ctx, &info).await?;
     194                 : 
     195              38 :     let secret = cached_secret.clone().unwrap_or_else(|| {
     196               1 :         // If we don't have an authentication secret, we mock one to
     197               1 :         // prevent malicious probing (possible due to missing protocol steps).
     198               1 :         // This mocked secret will never lead to successful authentication.
     199               1 :         info!("authentication info not found, mocking it");
     200               1 :         AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random()))
     201              38 :     });
     202              38 :     match authenticate_with_secret(
     203              38 :         ctx,
     204              38 :         secret,
     205              38 :         info,
     206              38 :         client,
     207              38 :         unauthenticated_password,
     208              38 :         allow_cleartext,
     209              38 :         config,
     210              38 :     )
     211              68 :     .await
     212                 :     {
     213              35 :         Ok(keys) => Ok(keys),
     214               3 :         Err(e) => {
     215               3 :             if e.is_auth_failed() {
     216               3 :                 // The password could have been changed, so we invalidate the cache.
     217               3 :                 cached_secret.invalidate();
     218               3 :             }
     219               3 :             Err(e)
     220                 :         }
     221                 :     }
     222              46 : }
     223                 : 
     224              38 : async fn authenticate_with_secret(
     225              38 :     ctx: &mut RequestMonitoring,
     226              38 :     secret: AuthSecret,
     227              38 :     info: ComputeUserInfo,
     228              38 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     229              38 :     unauthenticated_password: Option<Vec<u8>>,
     230              38 :     allow_cleartext: bool,
     231              38 :     config: &'static AuthenticationConfig,
     232              38 : ) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
     233              38 :     if let Some(password) = unauthenticated_password {
     234               2 :         let auth_outcome = validate_password_and_exchange(&password, secret)?;
     235               2 :         let keys = match auth_outcome {
     236               2 :             crate::sasl::Outcome::Success(key) => key,
     237 UBC           0 :             crate::sasl::Outcome::Failure(reason) => {
     238               0 :                 info!("auth backend failed with an error: {reason}");
     239               0 :                 return Err(auth::AuthError::auth_failed(&*info.inner.user));
     240                 :             }
     241                 :         };
     242                 : 
     243                 :         // we have authenticated the password
     244 CBC           2 :         client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
     245                 : 
     246               2 :         return Ok(ComputeCredentials { info, keys });
     247              36 :     }
     248              36 : 
     249              36 :     // -- the remaining flows are self-authenticating --
     250              36 : 
     251              36 :     // Perform cleartext auth if we're allowed to do that.
     252              36 :     // Currently, we use it for websocket connections (latency).
     253              36 :     if allow_cleartext {
     254 UBC           0 :         return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
     255 CBC          36 :     }
     256              36 : 
     257              36 :     // Finally, proceed with the main auth flow (SCRAM-based).
     258              68 :     classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
     259              38 : }
     260                 : 
     261                 : /// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
     262                 : /// only if authentication was successfuly.
     263              46 : async fn auth_and_wake_compute(
     264              46 :     ctx: &mut RequestMonitoring,
     265              46 :     api: &impl console::Api,
     266              46 :     extra: &ConsoleReqExtra,
     267              46 :     creds: ClientCredentials,
     268              46 :     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     269              46 :     allow_cleartext: bool,
     270              46 :     config: &'static AuthenticationConfig,
     271              46 : ) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
     272             614 :     let compute_credentials = auth_quirks(ctx, api, creds, client, allow_cleartext, config).await?;
     273                 : 
     274              35 :     let mut num_retries = 0;
     275              35 :     let mut node = loop {
     276              35 :         let wake_res = api
     277              35 :             .wake_compute(ctx, extra, &compute_credentials.info)
     278 UBC           0 :             .await;
     279 CBC          35 :         match handle_try_wake(wake_res, num_retries) {
     280 UBC           0 :             Err(e) => {
     281               0 :                 error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
     282               0 :                 return Err(e.into());
     283                 :             }
     284               0 :             Ok(ControlFlow::Continue(e)) => {
     285               0 :                 warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
     286                 :             }
     287 CBC          35 :             Ok(ControlFlow::Break(n)) => break n,
     288                 :         }
     289                 : 
     290 UBC           0 :         let wait_duration = retry_after(num_retries);
     291               0 :         num_retries += 1;
     292               0 :         tokio::time::sleep(wait_duration).await;
     293                 :     };
     294                 : 
     295 CBC          35 :     ctx.set_project(node.aux.clone());
     296              35 : 
     297              35 :     match compute_credentials.keys {
     298                 :         #[cfg(feature = "testing")]
     299 UBC           0 :         ComputeCredentialKeys::Password(password) => node.config.password(password),
     300 CBC          35 :         ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
     301                 :     };
     302                 : 
     303              35 :     Ok((node, compute_credentials.info))
     304              46 : }
     305                 : 
     306                 : impl<'a> BackendType<'a, ClientCredentials> {
     307                 :     /// Get compute endpoint name from the credentials.
     308             147 :     pub fn get_endpoint(&self) -> Option<SmolStr> {
     309             147 :         use BackendType::*;
     310             147 : 
     311             147 :         match self {
     312              12 :             Console(_, creds) => creds.project.clone(),
     313                 :             #[cfg(feature = "testing")]
     314             126 :             Postgres(_, creds) => creds.project.clone(),
     315               9 :             Link(_) => Some("link".into()),
     316                 :             #[cfg(test)]
     317 UBC           0 :             Test(_) => Some("test".into()),
     318                 :         }
     319 CBC         147 :     }
     320                 : 
     321                 :     /// Get username from the credentials.
     322              49 :     pub fn get_user(&self) -> &str {
     323              49 :         use BackendType::*;
     324              49 : 
     325              49 :         match self {
     326               4 :             Console(_, creds) => &creds.user,
     327                 :             #[cfg(feature = "testing")]
     328              42 :             Postgres(_, creds) => &creds.user,
     329               3 :             Link(_) => "link",
     330                 :             #[cfg(test)]
     331 UBC           0 :             Test(_) => "test",
     332                 :         }
     333 CBC          49 :     }
     334                 : 
     335                 :     /// Authenticate the client via the requested backend, possibly using credentials.
     336 UBC           0 :     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     337                 :     pub async fn authenticate(
     338                 :         self,
     339                 :         ctx: &mut RequestMonitoring,
     340                 :         extra: &ConsoleReqExtra,
     341                 :         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     342                 :         allow_cleartext: bool,
     343                 :         config: &'static AuthenticationConfig,
     344                 :     ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
     345                 :         use BackendType::*;
     346                 : 
     347                 :         let res = match self {
     348                 :             Console(api, creds) => {
     349 CBC           4 :                 info!(
     350               4 :                     user = &*creds.user,
     351               4 :                     project = creds.project(),
     352               4 :                     "performing authentication using the console"
     353               4 :                 );
     354                 : 
     355                 :                 let (cache_info, user_info) = auth_and_wake_compute(
     356                 :                     ctx,
     357                 :                     &*api,
     358                 :                     extra,
     359                 :                     creds,
     360                 :                     client,
     361                 :                     allow_cleartext,
     362                 :                     config,
     363                 :                 )
     364                 :                 .await?;
     365                 :                 (cache_info, BackendType::Console(api, user_info))
     366                 :             }
     367                 :             #[cfg(feature = "testing")]
     368                 :             Postgres(api, creds) => {
     369              42 :                 info!(
     370              42 :                     user = &*creds.user,
     371              42 :                     project = creds.project(),
     372              42 :                     "performing authentication using a local postgres instance"
     373              42 :                 );
     374                 : 
     375                 :                 let (cache_info, user_info) = auth_and_wake_compute(
     376                 :                     ctx,
     377                 :                     &*api,
     378                 :                     extra,
     379                 :                     creds,
     380                 :                     client,
     381                 :                     allow_cleartext,
     382                 :                     config,
     383                 :                 )
     384                 :                 .await?;
     385                 :                 (cache_info, BackendType::Postgres(api, user_info))
     386                 :             }
     387                 :             // NOTE: this auth backend doesn't use client credentials.
     388                 :             Link(url) => {
     389               3 :                 info!("performing link authentication");
     390                 : 
     391                 :                 let node_info = link::authenticate(&url, client).await?;
     392                 : 
     393                 :                 (
     394                 :                     CachedNodeInfo::new_uncached(node_info),
     395                 :                     BackendType::Link(url),
     396                 :                 )
     397                 :             }
     398                 :             #[cfg(test)]
     399                 :             Test(_) => {
     400                 :                 unreachable!("this function should never be called in the test backend")
     401                 :             }
     402                 :         };
     403                 : 
     404              38 :         info!("user successfully authenticated");
     405                 :         Ok(res)
     406                 :     }
     407                 : }
     408                 : 
     409                 : impl BackendType<'_, ComputeUserInfo> {
     410              41 :     pub async fn get_allowed_ips(
     411              41 :         &self,
     412              41 :         ctx: &mut RequestMonitoring,
     413              41 :     ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
     414              41 :         use BackendType::*;
     415              41 :         match self {
     416 UBC           0 :             Console(api, creds) => api.get_allowed_ips(ctx, creds).await,
     417                 :             #[cfg(feature = "testing")]
     418 CBC         313 :             Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await,
     419 UBC           0 :             Link(_) => Ok(Arc::new(vec![])),
     420                 :             #[cfg(test)]
     421               0 :             Test(x) => x.get_allowed_ips(),
     422                 :         }
     423 CBC          41 :     }
     424                 : 
     425                 :     /// When applicable, wake the compute node, gaining its connection info in the process.
     426                 :     /// The link auth flow doesn't support this, so we return [`None`] in that case.
     427              40 :     pub async fn wake_compute(
     428              40 :         &self,
     429              40 :         ctx: &mut RequestMonitoring,
     430              40 :         extra: &ConsoleReqExtra,
     431              40 :     ) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
     432              40 :         use BackendType::*;
     433              40 : 
     434              40 :         match self {
     435 UBC           0 :             Console(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
     436                 :             #[cfg(feature = "testing")]
     437 CBC          40 :             Postgres(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
     438 UBC           0 :             Link(_) => Ok(None),
     439                 :             #[cfg(test)]
     440               0 :             Test(x) => x.wake_compute().map(Some),
     441                 :         }
     442 CBC          40 :     }
     443                 : }
        

Generated by: LCOV version 2.1-beta