LCOV - differential code coverage report
Current view: top level - proxy/src/console/provider - mock.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 84.1 % 107 90 17 90
Current Date: 2024-01-09 02:06:09 Functions: 64.0 % 25 16 9 16
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Mock console backend which relies on a user-provided postgres instance.
       2                 : 
       3                 : use std::sync::Arc;
       4                 : 
       5                 : use super::{
       6                 :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       7                 :     AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
       8                 : };
       9                 : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
      10                 : use crate::{console::provider::CachedRoleSecret, context::RequestMonitoring};
      11                 : use async_trait::async_trait;
      12                 : use futures::TryFutureExt;
      13                 : use thiserror::Error;
      14                 : use tokio_postgres::{config::SslMode, Client};
      15                 : use tracing::{error, info, info_span, warn, Instrument};
      16                 : 
      17 UBC           0 : #[derive(Debug, Error)]
      18                 : enum MockApiError {
      19                 :     #[error("Failed to read password: {0}")]
      20                 :     PasswordNotSet(tokio_postgres::Error),
      21                 : }
      22                 : 
      23                 : impl From<MockApiError> for ApiError {
      24               0 :     fn from(e: MockApiError) -> Self {
      25               0 :         io_error(e).into()
      26               0 :     }
      27                 : }
      28                 : 
      29                 : impl From<tokio_postgres::Error> for ApiError {
      30               0 :     fn from(e: tokio_postgres::Error) -> Self {
      31               0 :         io_error(e).into()
      32               0 :     }
      33                 : }
      34                 : 
      35               0 : #[derive(Clone)]
      36                 : pub struct Api {
      37                 :     endpoint: ApiUrl,
      38                 : }
      39                 : 
      40                 : impl Api {
      41 CBC          18 :     pub fn new(endpoint: ApiUrl) -> Self {
      42              18 :         Self { endpoint }
      43              18 :     }
      44                 : 
      45              18 :     pub fn url(&self) -> &str {
      46              18 :         self.endpoint.as_str()
      47              18 :     }
      48                 : 
      49             120 :     async fn do_get_auth_info(
      50             120 :         &self,
      51             120 :         creds: &ComputeUserInfo,
      52             120 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      53             120 :         let (secret, allowed_ips) = async {
      54                 :             // Perhaps we could persist this connection, but then we'd have to
      55                 :             // write more code for reopening it if it got closed, which doesn't
      56                 :             // seem worth it.
      57             120 :             let (client, connection) =
      58             364 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      59                 : 
      60             120 :             tokio::spawn(connection);
      61             120 :             let secret = match get_execute_postgres_query(
      62             120 :                 &client,
      63             120 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      64             120 :                 &[&&*creds.inner.user],
      65             120 :                 "rolpassword",
      66             120 :             )
      67             240 :             .await?
      68                 :             {
      69             118 :                 Some(entry) => {
      70             118 :                     info!("got a secret: {entry}"); // safe since it's not a prod scenario
      71             118 :                     let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
      72             118 :                     secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
      73                 :                 }
      74                 :                 None => {
      75               2 :                     warn!("user '{}' does not exist", creds.inner.user);
      76               2 :                     None
      77                 :                 }
      78                 :             };
      79             120 :             let allowed_ips = match get_execute_postgres_query(
      80             120 :                 &client,
      81             120 :                 "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
      82             120 :                 &[&creds.endpoint.as_str()],
      83             120 :                 "allowed_ips",
      84             120 :             )
      85             240 :             .await?
      86                 :             {
      87              11 :                 Some(s) => {
      88              11 :                     info!("got allowed_ips: {s}");
      89              11 :                     s.split(',').map(String::from).collect()
      90                 :                 }
      91             109 :                 None => vec![],
      92                 :             };
      93                 : 
      94             120 :             Ok((secret, allowed_ips))
      95             120 :         }
      96             120 :         .map_err(crate::error::log_error::<GetAuthInfoError>)
      97             120 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
      98             844 :         .await?;
      99             120 :         Ok(AuthInfo {
     100             120 :             secret,
     101             120 :             allowed_ips,
     102             120 :         })
     103             120 :     }
     104                 : 
     105              77 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     106              77 :         let mut config = compute::ConnCfg::new();
     107              77 :         config
     108              77 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
     109              77 :             .port(self.endpoint.port().unwrap_or(5432))
     110              77 :             .ssl_mode(SslMode::Disable);
     111              77 : 
     112              77 :         let node = NodeInfo {
     113              77 :             config,
     114              77 :             aux: Default::default(),
     115              77 :             allow_self_signed_compute: false,
     116              77 :         };
     117              77 : 
     118              77 :         Ok(node)
     119              77 :     }
     120                 : }
     121                 : 
     122             240 : async fn get_execute_postgres_query(
     123             240 :     client: &Client,
     124             240 :     query: &str,
     125             240 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     126             240 :     idx: &str,
     127             240 : ) -> Result<Option<String>, GetAuthInfoError> {
     128             480 :     let rows = client.query(query, params).await?;
     129                 : 
     130                 :     // We can get at most one row, because `rolname` is unique.
     131             240 :     let row = match rows.first() {
     132             129 :         Some(row) => row,
     133                 :         // This means that the user doesn't exist, so there can be no secret.
     134                 :         // However, this is still a *valid* outcome which is very similar
     135                 :         // to getting `404 Not found` from the Neon console.
     136             111 :         None => return Ok(None),
     137                 :     };
     138                 : 
     139             129 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     140             129 :     Ok(Some(entry))
     141             240 : }
     142                 : 
     143                 : #[async_trait]
     144                 : impl super::Api for Api {
     145 UBC           0 :     #[tracing::instrument(skip_all)]
     146                 :     async fn get_role_secret(
     147                 :         &self,
     148                 :         _ctx: &mut RequestMonitoring,
     149                 :         creds: &ComputeUserInfo,
     150 CBC          38 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     151                 :         Ok(CachedRoleSecret::new_uncached(
     152             264 :             self.do_get_auth_info(creds).await?.secret,
     153                 :         ))
     154              76 :     }
     155                 : 
     156              82 :     async fn get_allowed_ips(
     157              82 :         &self,
     158              82 :         _ctx: &mut RequestMonitoring,
     159              82 :         creds: &ComputeUserInfo,
     160              82 :     ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
     161             580 :         Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
     162             164 :     }
     163                 : 
     164 UBC           0 :     #[tracing::instrument(skip_all)]
     165                 :     async fn wake_compute(
     166                 :         &self,
     167                 :         _ctx: &mut RequestMonitoring,
     168                 :         _extra: &ConsoleReqExtra,
     169                 :         _creds: &ComputeUserInfo,
     170 CBC          77 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     171              77 :         self.do_wake_compute()
     172              77 :             .map_ok(CachedNodeInfo::new_uncached)
     173 UBC           0 :             .await
     174 CBC         154 :     }
     175                 : }
     176                 : 
     177 UBC           0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     178               0 :     let text = input.strip_prefix("md5")?;
     179                 : 
     180               0 :     let mut bytes = [0u8; 16];
     181               0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     182                 : 
     183               0 :     Some(bytes)
     184               0 : }
        

Generated by: LCOV version 2.1-beta