LCOV - differential code coverage report
Current view: top level - proxy/src/console/provider - mock.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 78.6 % 70 55 15 55
Current Date: 2023-10-19 02:04:12 Functions: 68.2 % 22 15 7 15
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Mock console backend which relies on a user-provided postgres instance.
       2                 : 
       3                 : use super::{
       4                 :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       5                 :     AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
       6                 : };
       7                 : use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
       8                 : use async_trait::async_trait;
       9                 : use futures::TryFutureExt;
      10                 : use thiserror::Error;
      11                 : use tokio_postgres::config::SslMode;
      12                 : use tracing::{error, info, info_span, warn, Instrument};
      13                 : 
      14 UBC           0 : #[derive(Debug, Error)]
      15                 : enum MockApiError {
      16                 :     #[error("Failed to read password: {0}")]
      17                 :     PasswordNotSet(tokio_postgres::Error),
      18                 : }
      19                 : 
      20                 : impl From<MockApiError> for ApiError {
      21               0 :     fn from(e: MockApiError) -> Self {
      22               0 :         io_error(e).into()
      23               0 :     }
      24                 : }
      25                 : 
      26                 : impl From<tokio_postgres::Error> for ApiError {
      27               0 :     fn from(e: tokio_postgres::Error) -> Self {
      28               0 :         io_error(e).into()
      29               0 :     }
      30                 : }
      31                 : 
      32               0 : #[derive(Clone)]
      33                 : pub struct Api {
      34                 :     endpoint: ApiUrl,
      35                 : }
      36                 : 
      37                 : impl Api {
      38 CBC          13 :     pub fn new(endpoint: ApiUrl) -> Self {
      39              13 :         Self { endpoint }
      40              13 :     }
      41                 : 
      42              13 :     pub fn url(&self) -> &str {
      43              13 :         self.endpoint.as_str()
      44              13 :     }
      45                 : 
      46              27 :     async fn do_get_auth_info(
      47              27 :         &self,
      48              27 :         creds: &ClientCredentials<'_>,
      49              27 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
      50              27 :         async {
      51                 :             // Perhaps we could persist this connection, but then we'd have to
      52                 :             // write more code for reopening it if it got closed, which doesn't
      53                 :             // seem worth it.
      54              27 :             let (client, connection) =
      55              81 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      56                 : 
      57              27 :             tokio::spawn(connection);
      58              27 :             let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
      59              54 :             let rows = client.query(query, &[&creds.user]).await?;
      60                 : 
      61                 :             // We can get at most one row, because `rolname` is unique.
      62              27 :             let row = match rows.get(0) {
      63              26 :                 Some(row) => row,
      64                 :                 // This means that the user doesn't exist, so there can be no secret.
      65                 :                 // However, this is still a *valid* outcome which is very similar
      66                 :                 // to getting `404 Not found` from the Neon console.
      67                 :                 None => {
      68               1 :                     warn!("user '{}' does not exist", creds.user);
      69               1 :                     return Ok(None);
      70                 :                 }
      71                 :             };
      72                 : 
      73              26 :             let entry = row
      74              26 :                 .try_get("rolpassword")
      75              26 :                 .map_err(MockApiError::PasswordNotSet)?;
      76                 : 
      77              26 :             info!("got a secret: {entry}"); // safe since it's not a prod scenario
      78              26 :             let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
      79              26 :             Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
      80              27 :         }
      81              27 :         .map_err(crate::error::log_error)
      82              27 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
      83             135 :         .await
      84              27 :     }
      85                 : 
      86              55 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
      87              55 :         let mut config = compute::ConnCfg::new();
      88              55 :         config
      89              55 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
      90              55 :             .port(self.endpoint.port().unwrap_or(5432))
      91              55 :             .ssl_mode(SslMode::Disable);
      92              55 : 
      93              55 :         let node = NodeInfo {
      94              55 :             config,
      95              55 :             aux: Default::default(),
      96              55 :             allow_self_signed_compute: false,
      97              55 :         };
      98              55 : 
      99              55 :         Ok(node)
     100              55 :     }
     101                 : }
     102                 : 
     103                 : #[async_trait]
     104                 : impl super::Api for Api {
     105              81 :     #[tracing::instrument(skip_all)]
     106                 :     async fn get_auth_info(
     107                 :         &self,
     108                 :         _extra: &ConsoleReqExtra<'_>,
     109                 :         creds: &ClientCredentials,
     110              27 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
     111             135 :         self.do_get_auth_info(creds).await
     112              54 :     }
     113                 : 
     114             165 :     #[tracing::instrument(skip_all)]
     115                 :     async fn wake_compute(
     116                 :         &self,
     117                 :         _extra: &ConsoleReqExtra<'_>,
     118                 :         _creds: &ClientCredentials,
     119              55 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     120              55 :         self.do_wake_compute()
     121              55 :             .map_ok(CachedNodeInfo::new_uncached)
     122 UBC           0 :             .await
     123 CBC         110 :     }
     124                 : }
     125                 : 
     126 UBC           0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     127               0 :     let text = input.strip_prefix("md5")?;
     128                 : 
     129               0 :     let mut bytes = [0u8; 16];
     130               0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     131                 : 
     132               0 :     Some(bytes)
     133               0 : }
        

Generated by: LCOV version 2.1-beta