LCOV - code coverage report
Current view: top level - proxy/src/console/provider - mock.rs (source / functions) Coverage Total Hit
Test: fabb29a6339542ee130cd1d32b534fafdc0be240.info Lines: 0.0 % 109 0
Test Date: 2024-06-25 13:20:00 Functions: 0.0 % 21 0

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use super::{
       4              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       5              :     AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
       6              : };
       7              : use crate::context::RequestMonitoring;
       8              : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
       9              : use crate::{auth::IpPattern, cache::Cached};
      10              : use crate::{
      11              :     console::{
      12              :         messages::MetricsAuxInfo,
      13              :         provider::{CachedAllowedIps, CachedRoleSecret},
      14              :     },
      15              :     BranchId, EndpointId, ProjectId,
      16              : };
      17              : use futures::TryFutureExt;
      18              : use std::{str::FromStr, sync::Arc};
      19              : use thiserror::Error;
      20              : use tokio_postgres::{config::SslMode, Client};
      21              : use tracing::{error, info, info_span, warn, Instrument};
      22              : 
      23            0 : #[derive(Debug, Error)]
      24              : enum MockApiError {
      25              :     #[error("Failed to read password: {0}")]
      26              :     PasswordNotSet(tokio_postgres::Error),
      27              : }
      28              : 
      29              : impl From<MockApiError> for ApiError {
      30            0 :     fn from(e: MockApiError) -> Self {
      31            0 :         io_error(e).into()
      32            0 :     }
      33              : }
      34              : 
      35              : impl From<tokio_postgres::Error> for ApiError {
      36            0 :     fn from(e: tokio_postgres::Error) -> Self {
      37            0 :         io_error(e).into()
      38            0 :     }
      39              : }
      40              : 
      41              : #[derive(Clone)]
      42              : pub struct Api {
      43              :     endpoint: ApiUrl,
      44              : }
      45              : 
      46              : impl Api {
      47            0 :     pub fn new(endpoint: ApiUrl) -> Self {
      48            0 :         Self { endpoint }
      49            0 :     }
      50              : 
      51            0 :     pub fn url(&self) -> &str {
      52            0 :         self.endpoint.as_str()
      53            0 :     }
      54              : 
      55            0 :     async fn do_get_auth_info(
      56            0 :         &self,
      57            0 :         user_info: &ComputeUserInfo,
      58            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      59            0 :         let (secret, allowed_ips) = async {
      60              :             // Perhaps we could persist this connection, but then we'd have to
      61              :             // write more code for reopening it if it got closed, which doesn't
      62              :             // seem worth it.
      63            0 :             let (client, connection) =
      64            0 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      65              : 
      66            0 :             tokio::spawn(connection);
      67            0 :             let secret = match get_execute_postgres_query(
      68            0 :                 &client,
      69            0 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      70            0 :                 &[&&*user_info.user],
      71            0 :                 "rolpassword",
      72            0 :             )
      73            0 :             .await?
      74              :             {
      75            0 :                 Some(entry) => {
      76            0 :                     info!("got a secret: {entry}"); // safe since it's not a prod scenario
      77            0 :                     let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
      78            0 :                     secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
      79              :                 }
      80              :                 None => {
      81            0 :                     warn!("user '{}' does not exist", user_info.user);
      82            0 :                     None
      83              :                 }
      84              :             };
      85            0 :             let allowed_ips = match get_execute_postgres_query(
      86            0 :                 &client,
      87            0 :                 "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
      88            0 :                 &[&user_info.endpoint.as_str()],
      89            0 :                 "allowed_ips",
      90            0 :             )
      91            0 :             .await?
      92              :             {
      93            0 :                 Some(s) => {
      94            0 :                     info!("got allowed_ips: {s}");
      95            0 :                     s.split(',')
      96            0 :                         .map(|s| IpPattern::from_str(s).unwrap())
      97            0 :                         .collect()
      98              :                 }
      99            0 :                 None => vec![],
     100              :             };
     101              : 
     102            0 :             Ok((secret, allowed_ips))
     103            0 :         }
     104            0 :         .map_err(crate::error::log_error::<GetAuthInfoError>)
     105            0 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
     106            0 :         .await?;
     107            0 :         Ok(AuthInfo {
     108            0 :             secret,
     109            0 :             allowed_ips,
     110            0 :             project_id: None,
     111            0 :         })
     112            0 :     }
     113              : 
     114            0 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     115            0 :         let mut config = compute::ConnCfg::new();
     116            0 :         config
     117            0 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
     118            0 :             .port(self.endpoint.port().unwrap_or(5432))
     119            0 :             .ssl_mode(SslMode::Disable);
     120            0 : 
     121            0 :         let node = NodeInfo {
     122            0 :             config,
     123            0 :             aux: MetricsAuxInfo {
     124            0 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     125            0 :                 project_id: (&ProjectId::from("project")).into(),
     126            0 :                 branch_id: (&BranchId::from("branch")).into(),
     127            0 :                 cold_start_info: crate::console::messages::ColdStartInfo::Warm,
     128            0 :             },
     129            0 :             allow_self_signed_compute: false,
     130            0 :         };
     131            0 : 
     132            0 :         Ok(node)
     133            0 :     }
     134              : }
     135              : 
     136            0 : async fn get_execute_postgres_query(
     137            0 :     client: &Client,
     138            0 :     query: &str,
     139            0 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     140            0 :     idx: &str,
     141            0 : ) -> Result<Option<String>, GetAuthInfoError> {
     142            0 :     let rows = client.query(query, params).await?;
     143              : 
     144              :     // We can get at most one row, because `rolname` is unique.
     145            0 :     let row = match rows.first() {
     146            0 :         Some(row) => row,
     147              :         // This means that the user doesn't exist, so there can be no secret.
     148              :         // However, this is still a *valid* outcome which is very similar
     149              :         // to getting `404 Not found` from the Neon console.
     150            0 :         None => return Ok(None),
     151              :     };
     152              : 
     153            0 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     154            0 :     Ok(Some(entry))
     155            0 : }
     156              : 
     157              : impl super::Api for Api {
     158            0 :     #[tracing::instrument(skip_all)]
     159              :     async fn get_role_secret(
     160              :         &self,
     161              :         _ctx: &mut RequestMonitoring,
     162              :         user_info: &ComputeUserInfo,
     163              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     164              :         Ok(CachedRoleSecret::new_uncached(
     165              :             self.do_get_auth_info(user_info).await?.secret,
     166              :         ))
     167              :     }
     168              : 
     169            0 :     async fn get_allowed_ips_and_secret(
     170            0 :         &self,
     171            0 :         _ctx: &mut RequestMonitoring,
     172            0 :         user_info: &ComputeUserInfo,
     173            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     174            0 :         Ok((
     175            0 :             Cached::new_uncached(Arc::new(
     176            0 :                 self.do_get_auth_info(user_info).await?.allowed_ips,
     177              :             )),
     178            0 :             None,
     179              :         ))
     180            0 :     }
     181              : 
     182            0 :     #[tracing::instrument(skip_all)]
     183              :     async fn wake_compute(
     184              :         &self,
     185              :         _ctx: &mut RequestMonitoring,
     186              :         _user_info: &ComputeUserInfo,
     187              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     188              :         self.do_wake_compute().map_ok(Cached::new_uncached).await
     189              :     }
     190              : }
     191              : 
     192            0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     193            0 :     let text = input.strip_prefix("md5")?;
     194              : 
     195            0 :     let mut bytes = [0u8; 16];
     196            0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     197              : 
     198            0 :     Some(bytes)
     199            0 : }
        

Generated by: LCOV version 2.1-beta