LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - mock.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 0.0 % 173 0
Test Date: 2025-03-12 00:01:28 Functions: 0.0 % 26 0

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use std::net::{IpAddr, Ipv4Addr};
       4              : use std::str::FromStr;
       5              : use std::sync::Arc;
       6              : 
       7              : use futures::TryFutureExt;
       8              : use thiserror::Error;
       9              : use tokio_postgres::Client;
      10              : use tracing::{Instrument, error, info, info_span, warn};
      11              : 
      12              : use crate::auth::IpPattern;
      13              : use crate::auth::backend::ComputeUserInfo;
      14              : use crate::auth::backend::jwt::AuthRule;
      15              : use crate::cache::Cached;
      16              : use crate::context::RequestContext;
      17              : use crate::control_plane::client::{
      18              :     CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
      19              : };
      20              : use crate::control_plane::errors::{
      21              :     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
      22              : };
      23              : use crate::control_plane::messages::MetricsAuxInfo;
      24              : use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
      25              : use crate::error::io_error;
      26              : use crate::intern::RoleNameInt;
      27              : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
      28              : use crate::url::ApiUrl;
      29              : use crate::{compute, scram};
      30              : 
      31              : #[derive(Debug, Error)]
      32              : enum MockApiError {
      33              :     #[error("Failed to read password: {0}")]
      34              :     PasswordNotSet(tokio_postgres::Error),
      35              : }
      36              : 
      37              : impl From<MockApiError> for ControlPlaneError {
      38            0 :     fn from(e: MockApiError) -> Self {
      39            0 :         io_error(e).into()
      40            0 :     }
      41              : }
      42              : 
      43              : impl From<tokio_postgres::Error> for ControlPlaneError {
      44            0 :     fn from(e: tokio_postgres::Error) -> Self {
      45            0 :         io_error(e).into()
      46            0 :     }
      47              : }
      48              : 
      49              : #[derive(Clone)]
      50              : pub struct MockControlPlane {
      51              :     endpoint: ApiUrl,
      52              :     ip_allowlist_check_enabled: bool,
      53              : }
      54              : 
      55              : impl MockControlPlane {
      56            0 :     pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
      57            0 :         Self {
      58            0 :             endpoint,
      59            0 :             ip_allowlist_check_enabled,
      60            0 :         }
      61            0 :     }
      62              : 
      63            0 :     pub(crate) fn url(&self) -> &str {
      64            0 :         self.endpoint.as_str()
      65            0 :     }
      66              : 
      67            0 :     async fn do_get_auth_info(
      68            0 :         &self,
      69            0 :         user_info: &ComputeUserInfo,
      70            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      71            0 :         let (secret, allowed_ips) = async {
      72              :             // Perhaps we could persist this connection, but then we'd have to
      73              :             // write more code for reopening it if it got closed, which doesn't
      74              :             // seem worth it.
      75            0 :             let (client, connection) =
      76            0 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      77              : 
      78            0 :             tokio::spawn(connection);
      79              : 
      80            0 :             let secret = if let Some(entry) = get_execute_postgres_query(
      81            0 :                 &client,
      82            0 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      83            0 :                 &[&&*user_info.user],
      84            0 :                 "rolpassword",
      85            0 :             )
      86            0 :             .await?
      87              :             {
      88            0 :                 info!("got a secret: {entry}"); // safe since it's not a prod scenario
      89            0 :                 let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
      90            0 :                 secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
      91              :             } else {
      92            0 :                 warn!("user '{}' does not exist", user_info.user);
      93            0 :                 None
      94              :             };
      95              : 
      96            0 :             let allowed_ips = if self.ip_allowlist_check_enabled {
      97            0 :                 match get_execute_postgres_query(
      98            0 :                     &client,
      99            0 :                     "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
     100            0 :                     &[&user_info.endpoint.as_str()],
     101            0 :                     "allowed_ips",
     102            0 :                 )
     103            0 :                 .await?
     104              :                 {
     105            0 :                     Some(s) => {
     106            0 :                         info!("got allowed_ips: {s}");
     107            0 :                         s.split(',')
     108            0 :                             .map(|s| {
     109            0 :                                 IpPattern::from_str(s).expect("mocked ip pattern should be correct")
     110            0 :                             })
     111            0 :                             .collect()
     112              :                     }
     113            0 :                     None => vec![],
     114              :                 }
     115              :             } else {
     116            0 :                 vec![]
     117              :             };
     118              : 
     119            0 :             Ok((secret, allowed_ips))
     120            0 :         }
     121            0 :         .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
     122            0 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
     123            0 :         .await?;
     124            0 :         Ok(AuthInfo {
     125            0 :             secret,
     126            0 :             allowed_ips,
     127            0 :             allowed_vpc_endpoint_ids: vec![],
     128            0 :             project_id: None,
     129            0 :             account_id: None,
     130            0 :             access_blocker_flags: AccessBlockerFlags::default(),
     131            0 :         })
     132            0 :     }
     133              : 
     134            0 :     async fn do_get_endpoint_jwks(
     135            0 :         &self,
     136            0 :         endpoint: EndpointId,
     137            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     138            0 :         let (client, connection) =
     139            0 :             tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
     140              : 
     141            0 :         let connection = tokio::spawn(connection);
     142              : 
     143            0 :         let res = client.query(
     144            0 :                 "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
     145            0 :                 &[&endpoint.as_str()],
     146            0 :             )
     147            0 :             .await?;
     148              : 
     149            0 :         let mut rows = vec![];
     150            0 :         for row in res {
     151            0 :             rows.push(AuthRule {
     152            0 :                 id: row.get("id"),
     153            0 :                 jwks_url: url::Url::parse(row.get("jwks_url"))?,
     154            0 :                 audience: row.get("audience"),
     155            0 :                 role_names: row
     156            0 :                     .get::<_, Vec<String>>("role_names")
     157            0 :                     .into_iter()
     158            0 :                     .map(RoleName::from)
     159            0 :                     .map(|s| RoleNameInt::from(&s))
     160            0 :                     .collect(),
     161            0 :             });
     162            0 :         }
     163              : 
     164            0 :         drop(client);
     165            0 :         connection.await??;
     166              : 
     167            0 :         Ok(rows)
     168            0 :     }
     169              : 
     170            0 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     171            0 :         let port = self.endpoint.port().unwrap_or(5432);
     172            0 :         let mut config = match self.endpoint.host_str() {
     173              :             None => {
     174            0 :                 let mut config = compute::ConnCfg::new("localhost".to_string(), port);
     175            0 :                 config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
     176            0 :                 config
     177              :             }
     178            0 :             Some(host) => {
     179            0 :                 let mut config = compute::ConnCfg::new(host.to_string(), port);
     180            0 :                 if let Ok(addr) = IpAddr::from_str(host) {
     181            0 :                     config.set_host_addr(addr);
     182            0 :                 }
     183            0 :                 config
     184              :             }
     185              :         };
     186              : 
     187            0 :         config.ssl_mode(postgres_client::config::SslMode::Disable);
     188            0 : 
     189            0 :         let node = NodeInfo {
     190            0 :             config,
     191            0 :             aux: MetricsAuxInfo {
     192            0 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     193            0 :                 project_id: (&ProjectId::from("project")).into(),
     194            0 :                 branch_id: (&BranchId::from("branch")).into(),
     195            0 :                 compute_id: "compute".into(),
     196            0 :                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     197            0 :             },
     198            0 :         };
     199            0 : 
     200            0 :         Ok(node)
     201            0 :     }
     202              : }
     203              : 
     204            0 : async fn get_execute_postgres_query(
     205            0 :     client: &Client,
     206            0 :     query: &str,
     207            0 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     208            0 :     idx: &str,
     209            0 : ) -> Result<Option<String>, GetAuthInfoError> {
     210            0 :     let rows = client.query(query, params).await?;
     211              : 
     212              :     // We can get at most one row, because `rolname` is unique.
     213            0 :     let Some(row) = rows.first() else {
     214              :         // This means that the user doesn't exist, so there can be no secret.
     215              :         // However, this is still a *valid* outcome which is very similar
     216              :         // to getting `404 Not found` from the Neon console.
     217            0 :         return Ok(None);
     218              :     };
     219              : 
     220            0 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     221            0 :     Ok(Some(entry))
     222            0 : }
     223              : 
     224              : impl super::ControlPlaneApi for MockControlPlane {
     225              :     #[tracing::instrument(skip_all)]
     226              :     async fn get_role_secret(
     227              :         &self,
     228              :         _ctx: &RequestContext,
     229              :         user_info: &ComputeUserInfo,
     230              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     231              :         Ok(CachedRoleSecret::new_uncached(
     232              :             self.do_get_auth_info(user_info).await?.secret,
     233              :         ))
     234              :     }
     235              : 
     236            0 :     async fn get_allowed_ips(
     237            0 :         &self,
     238            0 :         _ctx: &RequestContext,
     239            0 :         user_info: &ComputeUserInfo,
     240            0 :     ) -> Result<CachedAllowedIps, GetAuthInfoError> {
     241            0 :         Ok(Cached::new_uncached(Arc::new(
     242            0 :             self.do_get_auth_info(user_info).await?.allowed_ips,
     243              :         )))
     244            0 :     }
     245              : 
     246            0 :     async fn get_allowed_vpc_endpoint_ids(
     247            0 :         &self,
     248            0 :         _ctx: &RequestContext,
     249            0 :         user_info: &ComputeUserInfo,
     250            0 :     ) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
     251            0 :         Ok(Cached::new_uncached(Arc::new(
     252            0 :             self.do_get_auth_info(user_info)
     253            0 :                 .await?
     254              :                 .allowed_vpc_endpoint_ids,
     255              :         )))
     256            0 :     }
     257              : 
     258            0 :     async fn get_block_public_or_vpc_access(
     259            0 :         &self,
     260            0 :         _ctx: &RequestContext,
     261            0 :         user_info: &ComputeUserInfo,
     262            0 :     ) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
     263            0 :         Ok(Cached::new_uncached(
     264            0 :             self.do_get_auth_info(user_info).await?.access_blocker_flags,
     265              :         ))
     266            0 :     }
     267              : 
     268            0 :     async fn get_endpoint_jwks(
     269            0 :         &self,
     270            0 :         _ctx: &RequestContext,
     271            0 :         endpoint: EndpointId,
     272            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     273            0 :         self.do_get_endpoint_jwks(endpoint).await
     274            0 :     }
     275              : 
     276              :     #[tracing::instrument(skip_all)]
     277              :     async fn wake_compute(
     278              :         &self,
     279              :         _ctx: &RequestContext,
     280              :         _user_info: &ComputeUserInfo,
     281              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     282              :         self.do_wake_compute().map_ok(Cached::new_uncached).await
     283              :     }
     284              : }
     285              : 
     286            0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     287            0 :     let text = input.strip_prefix("md5")?;
     288              : 
     289            0 :     let mut bytes = [0u8; 16];
     290            0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     291              : 
     292            0 :     Some(bytes)
     293            0 : }
        

Generated by: LCOV version 2.1-beta