LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - mock.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 166 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 22 0

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use std::io;
       4              : use std::net::{IpAddr, Ipv4Addr};
       5              : use std::str::FromStr;
       6              : use std::sync::Arc;
       7              : 
       8              : use futures::TryFutureExt;
       9              : use postgres_client::config::SslMode;
      10              : use thiserror::Error;
      11              : use tokio_postgres::Client;
      12              : use tracing::{Instrument, error, info, info_span, warn};
      13              : 
      14              : use crate::auth::IpPattern;
      15              : use crate::auth::backend::ComputeUserInfo;
      16              : use crate::auth::backend::jwt::AuthRule;
      17              : use crate::cache::Cached;
      18              : use crate::compute::ConnectInfo;
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::errors::{
      21              :     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
      22              : };
      23              : use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
      24              : use crate::control_plane::{
      25              :     AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
      26              :     RoleAccessControl,
      27              : };
      28              : use crate::intern::RoleNameInt;
      29              : use crate::scram;
      30              : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
      31              : use crate::url::ApiUrl;
      32              : 
      33              : #[derive(Debug, Error)]
      34              : enum MockApiError {
      35              :     #[error("Failed to read password: {0}")]
      36              :     PasswordNotSet(tokio_postgres::Error),
      37              : }
      38              : 
      39              : impl From<MockApiError> for ControlPlaneError {
      40            0 :     fn from(e: MockApiError) -> Self {
      41            0 :         io::Error::other(e).into()
      42            0 :     }
      43              : }
      44              : 
      45              : impl From<tokio_postgres::Error> for ControlPlaneError {
      46            0 :     fn from(e: tokio_postgres::Error) -> Self {
      47            0 :         io::Error::other(e).into()
      48            0 :     }
      49              : }
      50              : 
      51              : #[derive(Clone)]
      52              : pub struct MockControlPlane {
      53              :     endpoint: ApiUrl,
      54              :     ip_allowlist_check_enabled: bool,
      55              : }
      56              : 
      57              : impl MockControlPlane {
      58            0 :     pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
      59            0 :         Self {
      60            0 :             endpoint,
      61            0 :             ip_allowlist_check_enabled,
      62            0 :         }
      63            0 :     }
      64              : 
      65            0 :     pub(crate) fn url(&self) -> &str {
      66            0 :         self.endpoint.as_str()
      67            0 :     }
      68              : 
      69            0 :     async fn do_get_auth_info(
      70            0 :         &self,
      71            0 :         endpoint: &EndpointId,
      72            0 :         role: &RoleName,
      73            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      74            0 :         let (secret, allowed_ips) = async {
      75              :             // Perhaps we could persist this connection, but then we'd have to
      76              :             // write more code for reopening it if it got closed, which doesn't
      77              :             // seem worth it.
      78            0 :             let (client, connection) =
      79            0 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      80              : 
      81            0 :             tokio::spawn(connection);
      82              : 
      83            0 :             let secret = if let Some(entry) = get_execute_postgres_query(
      84            0 :                 &client,
      85            0 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      86            0 :                 &[&role.as_str()],
      87            0 :                 "rolpassword",
      88            0 :             )
      89            0 :             .await?
      90              :             {
      91            0 :                 info!("got a secret: {entry}"); // safe since it's not a prod scenario
      92            0 :                 scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
      93              :             } else {
      94            0 :                 warn!("user '{role}' does not exist");
      95            0 :                 None
      96              :             };
      97              : 
      98            0 :             let allowed_ips = if self.ip_allowlist_check_enabled {
      99            0 :                 match get_execute_postgres_query(
     100            0 :                     &client,
     101            0 :                     "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
     102            0 :                     &[&endpoint.as_str()],
     103            0 :                     "allowed_ips",
     104            0 :                 )
     105            0 :                 .await?
     106              :                 {
     107            0 :                     Some(s) => {
     108            0 :                         info!("got allowed_ips: {s}");
     109            0 :                         s.split(',')
     110            0 :                             .map(|s| {
     111            0 :                                 IpPattern::from_str(s).expect("mocked ip pattern should be correct")
     112            0 :                             })
     113            0 :                             .collect()
     114              :                     }
     115            0 :                     None => vec![],
     116              :                 }
     117              :             } else {
     118            0 :                 vec![]
     119              :             };
     120              : 
     121            0 :             Ok((secret, allowed_ips))
     122            0 :         }
     123            0 :         .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
     124            0 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
     125            0 :         .await?;
     126            0 :         Ok(AuthInfo {
     127            0 :             secret,
     128            0 :             allowed_ips,
     129            0 :             allowed_vpc_endpoint_ids: vec![],
     130            0 :             project_id: None,
     131            0 :             account_id: None,
     132            0 :             access_blocker_flags: AccessBlockerFlags::default(),
     133            0 :             rate_limits: EndpointRateLimitConfig::default(),
     134            0 :         })
     135            0 :     }
     136              : 
     137            0 :     async fn do_get_endpoint_jwks(
     138            0 :         &self,
     139            0 :         endpoint: &EndpointId,
     140            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     141            0 :         let (client, connection) =
     142            0 :             tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
     143              : 
     144            0 :         let connection = tokio::spawn(connection);
     145              : 
     146            0 :         let res = client.query(
     147            0 :                 "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
     148            0 :                 &[&endpoint.as_str()],
     149            0 :             )
     150            0 :             .await?;
     151              : 
     152            0 :         let mut rows = vec![];
     153            0 :         for row in res {
     154            0 :             rows.push(AuthRule {
     155            0 :                 id: row.get("id"),
     156            0 :                 jwks_url: url::Url::parse(row.get("jwks_url"))?,
     157            0 :                 audience: row.get("audience"),
     158            0 :                 role_names: row
     159            0 :                     .get::<_, Vec<String>>("role_names")
     160            0 :                     .into_iter()
     161            0 :                     .map(RoleName::from)
     162            0 :                     .map(|s| RoleNameInt::from(&s))
     163            0 :                     .collect(),
     164              :             });
     165              :         }
     166              : 
     167            0 :         drop(client);
     168            0 :         connection.await??;
     169              : 
     170            0 :         Ok(rows)
     171            0 :     }
     172              : 
     173            0 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     174            0 :         let port = self.endpoint.port().unwrap_or(5432);
     175            0 :         let conn_info = match self.endpoint.host_str() {
     176            0 :             None => ConnectInfo {
     177            0 :                 host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
     178            0 :                 host: "localhost".into(),
     179            0 :                 port,
     180            0 :                 ssl_mode: SslMode::Disable,
     181            0 :             },
     182            0 :             Some(host) => ConnectInfo {
     183            0 :                 host_addr: IpAddr::from_str(host).ok(),
     184            0 :                 host: host.into(),
     185            0 :                 port,
     186            0 :                 ssl_mode: SslMode::Disable,
     187            0 :             },
     188              :         };
     189              : 
     190            0 :         let node = NodeInfo {
     191            0 :             conn_info,
     192            0 :             aux: MetricsAuxInfo {
     193            0 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     194            0 :                 project_id: (&ProjectId::from("project")).into(),
     195            0 :                 branch_id: (&BranchId::from("branch")).into(),
     196            0 :                 compute_id: "compute".into(),
     197            0 :                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     198            0 :             },
     199            0 :         };
     200              : 
     201            0 :         Ok(node)
     202            0 :     }
     203              : }
     204              : 
     205            0 : async fn get_execute_postgres_query(
     206            0 :     client: &Client,
     207            0 :     query: &str,
     208            0 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     209            0 :     idx: &str,
     210            0 : ) -> Result<Option<String>, GetAuthInfoError> {
     211            0 :     let rows = client.query(query, params).await?;
     212              : 
     213              :     // We can get at most one row, because `rolname` is unique.
     214            0 :     let Some(row) = rows.first() else {
     215              :         // This means that the user doesn't exist, so there can be no secret.
     216              :         // However, this is still a *valid* outcome which is very similar
     217              :         // to getting `404 Not found` from the Neon console.
     218            0 :         return Ok(None);
     219              :     };
     220              : 
     221            0 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     222            0 :     Ok(Some(entry))
     223            0 : }
     224              : 
     225              : impl super::ControlPlaneApi for MockControlPlane {
     226            0 :     async fn get_endpoint_access_control(
     227            0 :         &self,
     228            0 :         _ctx: &RequestContext,
     229            0 :         endpoint: &EndpointId,
     230            0 :         role: &RoleName,
     231            0 :     ) -> Result<EndpointAccessControl, GetAuthInfoError> {
     232            0 :         let info = self.do_get_auth_info(endpoint, role).await?;
     233            0 :         Ok(EndpointAccessControl {
     234            0 :             allowed_ips: Arc::new(info.allowed_ips),
     235            0 :             allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
     236            0 :             flags: info.access_blocker_flags,
     237            0 :             rate_limits: info.rate_limits,
     238            0 :         })
     239            0 :     }
     240              : 
     241            0 :     async fn get_role_access_control(
     242            0 :         &self,
     243            0 :         _ctx: &RequestContext,
     244            0 :         endpoint: &EndpointId,
     245            0 :         role: &RoleName,
     246            0 :     ) -> Result<RoleAccessControl, GetAuthInfoError> {
     247            0 :         let info = self.do_get_auth_info(endpoint, role).await?;
     248            0 :         Ok(RoleAccessControl {
     249            0 :             secret: info.secret,
     250            0 :         })
     251            0 :     }
     252              : 
     253            0 :     async fn get_endpoint_jwks(
     254            0 :         &self,
     255            0 :         _ctx: &RequestContext,
     256            0 :         endpoint: &EndpointId,
     257            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     258            0 :         self.do_get_endpoint_jwks(endpoint).await
     259            0 :     }
     260              : 
     261              :     #[tracing::instrument(skip_all)]
     262              :     async fn wake_compute(
     263              :         &self,
     264              :         _ctx: &RequestContext,
     265              :         _user_info: &ComputeUserInfo,
     266              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     267              :         self.do_wake_compute().map_ok(Cached::new_uncached).await
     268              :     }
     269              : }
        

Generated by: LCOV version 2.1-beta