LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - mock.rs (source / functions) Coverage Total Hit
Test: b9d67f908f91f00e353a27440ba89f642a869959.info Lines: 0.0 % 149 0
Test Date: 2024-11-19 21:44:13 Functions: 0.0 % 24 0

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use std::str::FromStr;
       4              : use std::sync::Arc;
       5              : 
       6              : use futures::TryFutureExt;
       7              : use thiserror::Error;
       8              : use tokio_postgres::config::SslMode;
       9              : use tokio_postgres::Client;
      10              : use tracing::{error, info, info_span, warn, Instrument};
      11              : 
      12              : use crate::auth::backend::jwt::AuthRule;
      13              : use crate::auth::backend::ComputeUserInfo;
      14              : use crate::auth::IpPattern;
      15              : use crate::cache::Cached;
      16              : use crate::context::RequestMonitoring;
      17              : use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret};
      18              : use crate::control_plane::errors::{
      19              :     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
      20              : };
      21              : use crate::control_plane::messages::MetricsAuxInfo;
      22              : use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
      23              : use crate::error::io_error;
      24              : use crate::intern::RoleNameInt;
      25              : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
      26              : use crate::url::ApiUrl;
      27              : use crate::{compute, scram};
      28              : 
      29            0 : #[derive(Debug, Error)]
      30              : enum MockApiError {
      31              :     #[error("Failed to read password: {0}")]
      32              :     PasswordNotSet(tokio_postgres::Error),
      33              : }
      34              : 
      35              : impl From<MockApiError> for ControlPlaneError {
      36            0 :     fn from(e: MockApiError) -> Self {
      37            0 :         io_error(e).into()
      38            0 :     }
      39              : }
      40              : 
      41              : impl From<tokio_postgres::Error> for ControlPlaneError {
      42            0 :     fn from(e: tokio_postgres::Error) -> Self {
      43            0 :         io_error(e).into()
      44            0 :     }
      45              : }
      46              : 
      47              : #[derive(Clone)]
      48              : pub struct MockControlPlane {
      49              :     endpoint: ApiUrl,
      50              :     ip_allowlist_check_enabled: bool,
      51              : }
      52              : 
      53              : impl MockControlPlane {
      54            0 :     pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
      55            0 :         Self {
      56            0 :             endpoint,
      57            0 :             ip_allowlist_check_enabled,
      58            0 :         }
      59            0 :     }
      60              : 
      61            0 :     pub(crate) fn url(&self) -> &str {
      62            0 :         self.endpoint.as_str()
      63            0 :     }
      64              : 
      65            0 :     async fn do_get_auth_info(
      66            0 :         &self,
      67            0 :         user_info: &ComputeUserInfo,
      68            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      69            0 :         let (secret, allowed_ips) = async {
      70              :             // Perhaps we could persist this connection, but then we'd have to
      71              :             // write more code for reopening it if it got closed, which doesn't
      72              :             // seem worth it.
      73            0 :             let (client, connection) =
      74            0 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      75              : 
      76            0 :             tokio::spawn(connection);
      77              : 
      78            0 :             let secret = if let Some(entry) = get_execute_postgres_query(
      79            0 :                 &client,
      80            0 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      81            0 :                 &[&&*user_info.user],
      82            0 :                 "rolpassword",
      83            0 :             )
      84            0 :             .await?
      85              :             {
      86            0 :                 info!("got a secret: {entry}"); // safe since it's not a prod scenario
      87            0 :                 let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
      88            0 :                 secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
      89              :             } else {
      90            0 :                 warn!("user '{}' does not exist", user_info.user);
      91            0 :                 None
      92              :             };
      93              : 
      94            0 :             let allowed_ips = if self.ip_allowlist_check_enabled {
      95            0 :                 match get_execute_postgres_query(
      96            0 :                     &client,
      97            0 :                     "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
      98            0 :                     &[&user_info.endpoint.as_str()],
      99            0 :                     "allowed_ips",
     100            0 :                 )
     101            0 :                 .await?
     102              :                 {
     103            0 :                     Some(s) => {
     104            0 :                         info!("got allowed_ips: {s}");
     105            0 :                         s.split(',')
     106            0 :                             .map(|s| IpPattern::from_str(s).unwrap())
     107            0 :                             .collect()
     108              :                     }
     109            0 :                     None => vec![],
     110              :                 }
     111              :             } else {
     112            0 :                 vec![]
     113              :             };
     114              : 
     115            0 :             Ok((secret, allowed_ips))
     116            0 :         }
     117            0 :         .map_err(crate::error::log_error::<GetAuthInfoError>)
     118            0 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
     119            0 :         .await?;
     120            0 :         Ok(AuthInfo {
     121            0 :             secret,
     122            0 :             allowed_ips,
     123            0 :             project_id: None,
     124            0 :         })
     125            0 :     }
     126              : 
     127            0 :     async fn do_get_endpoint_jwks(
     128            0 :         &self,
     129            0 :         endpoint: EndpointId,
     130            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     131            0 :         let (client, connection) =
     132            0 :             tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
     133              : 
     134            0 :         let connection = tokio::spawn(connection);
     135              : 
     136            0 :         let res = client.query(
     137            0 :                 "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
     138            0 :                 &[&endpoint.as_str()],
     139            0 :             )
     140            0 :             .await?;
     141              : 
     142            0 :         let mut rows = vec![];
     143            0 :         for row in res {
     144            0 :             rows.push(AuthRule {
     145            0 :                 id: row.get("id"),
     146            0 :                 jwks_url: url::Url::parse(row.get("jwks_url"))?,
     147            0 :                 audience: row.get("audience"),
     148            0 :                 role_names: row
     149            0 :                     .get::<_, Vec<String>>("role_names")
     150            0 :                     .into_iter()
     151            0 :                     .map(RoleName::from)
     152            0 :                     .map(|s| RoleNameInt::from(&s))
     153            0 :                     .collect(),
     154            0 :             });
     155            0 :         }
     156              : 
     157            0 :         drop(client);
     158            0 :         connection.await??;
     159              : 
     160            0 :         Ok(rows)
     161            0 :     }
     162              : 
     163            0 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     164            0 :         let mut config = compute::ConnCfg::new();
     165            0 :         config
     166            0 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
     167            0 :             .port(self.endpoint.port().unwrap_or(5432))
     168            0 :             .ssl_mode(SslMode::Disable);
     169            0 : 
     170            0 :         let node = NodeInfo {
     171            0 :             config,
     172            0 :             aux: MetricsAuxInfo {
     173            0 :                 endpoint_id: (&EndpointId::from("endpoint")).into(),
     174            0 :                 project_id: (&ProjectId::from("project")).into(),
     175            0 :                 branch_id: (&BranchId::from("branch")).into(),
     176            0 :                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     177            0 :             },
     178            0 :             allow_self_signed_compute: false,
     179            0 :         };
     180            0 : 
     181            0 :         Ok(node)
     182            0 :     }
     183              : }
     184              : 
     185            0 : async fn get_execute_postgres_query(
     186            0 :     client: &Client,
     187            0 :     query: &str,
     188            0 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     189            0 :     idx: &str,
     190            0 : ) -> Result<Option<String>, GetAuthInfoError> {
     191            0 :     let rows = client.query(query, params).await?;
     192              : 
     193              :     // We can get at most one row, because `rolname` is unique.
     194            0 :     let Some(row) = rows.first() else {
     195              :         // This means that the user doesn't exist, so there can be no secret.
     196              :         // However, this is still a *valid* outcome which is very similar
     197              :         // to getting `404 Not found` from the Neon console.
     198            0 :         return Ok(None);
     199              :     };
     200              : 
     201            0 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     202            0 :     Ok(Some(entry))
     203            0 : }
     204              : 
     205              : impl super::ControlPlaneApi for MockControlPlane {
     206            0 :     #[tracing::instrument(skip_all)]
     207              :     async fn get_role_secret(
     208              :         &self,
     209              :         _ctx: &RequestMonitoring,
     210              :         user_info: &ComputeUserInfo,
     211              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     212              :         Ok(CachedRoleSecret::new_uncached(
     213              :             self.do_get_auth_info(user_info).await?.secret,
     214              :         ))
     215              :     }
     216              : 
     217            0 :     async fn get_allowed_ips_and_secret(
     218            0 :         &self,
     219            0 :         _ctx: &RequestMonitoring,
     220            0 :         user_info: &ComputeUserInfo,
     221            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     222            0 :         Ok((
     223            0 :             Cached::new_uncached(Arc::new(
     224            0 :                 self.do_get_auth_info(user_info).await?.allowed_ips,
     225              :             )),
     226            0 :             None,
     227              :         ))
     228            0 :     }
     229              : 
     230            0 :     async fn get_endpoint_jwks(
     231            0 :         &self,
     232            0 :         _ctx: &RequestMonitoring,
     233            0 :         endpoint: EndpointId,
     234            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     235            0 :         self.do_get_endpoint_jwks(endpoint).await
     236            0 :     }
     237              : 
     238            0 :     #[tracing::instrument(skip_all)]
     239              :     async fn wake_compute(
     240              :         &self,
     241              :         _ctx: &RequestMonitoring,
     242              :         _user_info: &ComputeUserInfo,
     243              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     244              :         self.do_wake_compute().map_ok(Cached::new_uncached).await
     245              :     }
     246              : }
     247              : 
     248            0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     249            0 :     let text = input.strip_prefix("md5")?;
     250              : 
     251            0 :     let mut bytes = [0u8; 16];
     252            0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     253              : 
     254            0 :     Some(bytes)
     255            0 : }
        

Generated by: LCOV version 2.1-beta