LCOV - code coverage report
Current view: top level - proxy/src/console/provider - mock.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 78.6 % 70 55
Test Date: 2023-09-06 10:18:01 Functions: 68.2 % 22 15

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use super::{
       4              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       5              :     AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
       6              : };
       7              : use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
       8              : use async_trait::async_trait;
       9              : use futures::TryFutureExt;
      10              : use thiserror::Error;
      11              : use tokio_postgres::config::SslMode;
      12              : use tracing::{error, info, info_span, warn, Instrument};
      13              : 
      14            0 : #[derive(Debug, Error)]
      15              : enum MockApiError {
      16              :     #[error("Failed to read password: {0}")]
      17              :     PasswordNotSet(tokio_postgres::Error),
      18              : }
      19              : 
      20              : impl From<MockApiError> for ApiError {
      21            0 :     fn from(e: MockApiError) -> Self {
      22            0 :         io_error(e).into()
      23            0 :     }
      24              : }
      25              : 
      26              : impl From<tokio_postgres::Error> for ApiError {
      27            0 :     fn from(e: tokio_postgres::Error) -> Self {
      28            0 :         io_error(e).into()
      29            0 :     }
      30              : }
      31              : 
      32            0 : #[derive(Clone)]
      33              : pub struct Api {
      34              :     endpoint: ApiUrl,
      35              : }
      36              : 
      37              : impl Api {
      38           11 :     pub fn new(endpoint: ApiUrl) -> Self {
      39           11 :         Self { endpoint }
      40           11 :     }
      41              : 
      42           11 :     pub fn url(&self) -> &str {
      43           11 :         self.endpoint.as_str()
      44           11 :     }
      45              : 
      46           25 :     async fn do_get_auth_info(
      47           25 :         &self,
      48           25 :         creds: &ClientCredentials<'_>,
      49           25 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
      50           25 :         async {
      51              :             // Perhaps we could persist this connection, but then we'd have to
      52              :             // write more code for reopening it if it got closed, which doesn't
      53              :             // seem worth it.
      54           25 :             let (client, connection) =
      55           66 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      56              : 
      57           25 :             tokio::spawn(connection);
      58           25 :             let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
      59           50 :             let rows = client.query(query, &[&creds.user]).await?;
      60              : 
      61              :             // We can get at most one row, because `rolname` is unique.
      62           25 :             let row = match rows.get(0) {
      63           24 :                 Some(row) => row,
      64              :                 // This means that the user doesn't exist, so there can be no secret.
      65              :                 // However, this is still a *valid* outcome which is very similar
      66              :                 // to getting `404 Not found` from the Neon console.
      67              :                 None => {
      68            1 :                     warn!("user '{}' does not exist", creds.user);
      69            1 :                     return Ok(None);
      70              :                 }
      71              :             };
      72              : 
      73           24 :             let entry = row
      74           24 :                 .try_get("rolpassword")
      75           24 :                 .map_err(MockApiError::PasswordNotSet)?;
      76              : 
      77           24 :             info!("got a secret: {entry}"); // safe since it's not a prod scenario
      78           24 :             let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
      79           24 :             Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
      80           25 :         }
      81           25 :         .map_err(crate::error::log_error)
      82           25 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
      83          116 :         .await
      84           25 :     }
      85              : 
      86           46 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
      87           46 :         let mut config = compute::ConnCfg::new();
      88           46 :         config
      89           46 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
      90           46 :             .port(self.endpoint.port().unwrap_or(5432))
      91           46 :             .ssl_mode(SslMode::Disable);
      92           46 : 
      93           46 :         let node = NodeInfo {
      94           46 :             config,
      95           46 :             aux: Default::default(),
      96           46 :             allow_self_signed_compute: false,
      97           46 :         };
      98           46 : 
      99           46 :         Ok(node)
     100           46 :     }
     101              : }
     102              : 
     103              : #[async_trait]
     104              : impl super::Api for Api {
     105           75 :     #[tracing::instrument(skip_all)]
     106              :     async fn get_auth_info(
     107              :         &self,
     108              :         _extra: &ConsoleReqExtra<'_>,
     109              :         creds: &ClientCredentials,
     110           25 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
     111          116 :         self.do_get_auth_info(creds).await
     112           50 :     }
     113              : 
     114          138 :     #[tracing::instrument(skip_all)]
     115              :     async fn wake_compute(
     116              :         &self,
     117              :         _extra: &ConsoleReqExtra<'_>,
     118              :         _creds: &ClientCredentials,
     119           46 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     120           46 :         self.do_wake_compute()
     121           46 :             .map_ok(CachedNodeInfo::new_uncached)
     122            0 :             .await
     123           92 :     }
     124              : }
     125              : 
     126            0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     127            0 :     let text = input.strip_prefix("md5")?;
     128              : 
     129            0 :     let mut bytes = [0u8; 16];
     130            0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     131              : 
     132            0 :     Some(bytes)
     133            0 : }
        

Generated by: LCOV version 2.1-beta