LCOV - code coverage report
Current view: top level - proxy/src/console/provider - mock.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 84.7 % 111 94
Test Date: 2024-02-07 07:37:29 Functions: 65.4 % 26 17

            Line data    Source code
       1              : //! Mock console backend which relies on a user-provided postgres instance.
       2              : 
       3              : use super::{
       4              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       5              :     AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
       6              : };
       7              : use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
       8              : use crate::context::RequestMonitoring;
       9              : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
      10              : use crate::{auth::IpPattern, cache::Cached};
      11              : use async_trait::async_trait;
      12              : use futures::TryFutureExt;
      13              : use std::{str::FromStr, sync::Arc};
      14              : use thiserror::Error;
      15              : use tokio_postgres::{config::SslMode, Client};
      16              : use tracing::{error, info, info_span, warn, Instrument};
      17              : 
      18            0 : #[derive(Debug, Error)]
      19              : enum MockApiError {
      20              :     #[error("Failed to read password: {0}")]
      21              :     PasswordNotSet(tokio_postgres::Error),
      22              : }
      23              : 
      24              : impl From<MockApiError> for ApiError {
      25            0 :     fn from(e: MockApiError) -> Self {
      26            0 :         io_error(e).into()
      27            0 :     }
      28              : }
      29              : 
      30              : impl From<tokio_postgres::Error> for ApiError {
      31            0 :     fn from(e: tokio_postgres::Error) -> Self {
      32            0 :         io_error(e).into()
      33            0 :     }
      34              : }
      35              : 
      36            0 : #[derive(Clone)]
      37              : pub struct Api {
      38              :     endpoint: ApiUrl,
      39              : }
      40              : 
      41              : impl Api {
      42           19 :     pub fn new(endpoint: ApiUrl) -> Self {
      43           19 :         Self { endpoint }
      44           19 :     }
      45              : 
      46           19 :     pub fn url(&self) -> &str {
      47           19 :         self.endpoint.as_str()
      48           19 :     }
      49              : 
      50          123 :     async fn do_get_auth_info(
      51          123 :         &self,
      52          123 :         user_info: &ComputeUserInfo,
      53          123 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      54          123 :         let (secret, allowed_ips) = async {
      55              :             // Perhaps we could persist this connection, but then we'd have to
      56              :             // write more code for reopening it if it got closed, which doesn't
      57              :             // seem worth it.
      58          123 :             let (client, connection) =
      59          373 :                 tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
      60              : 
      61          123 :             tokio::spawn(connection);
      62          123 :             let secret = match get_execute_postgres_query(
      63          123 :                 &client,
      64          123 :                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
      65          123 :                 &[&&*user_info.user],
      66          123 :                 "rolpassword",
      67          123 :             )
      68          246 :             .await?
      69              :             {
      70          121 :                 Some(entry) => {
      71          121 :                     info!("got a secret: {entry}"); // safe since it's not a prod scenario
      72          121 :                     let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
      73          121 :                     secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
      74              :                 }
      75              :                 None => {
      76            2 :                     warn!("user '{}' does not exist", user_info.user);
      77            2 :                     None
      78              :                 }
      79              :             };
      80          123 :             let allowed_ips = match get_execute_postgres_query(
      81          123 :                 &client,
      82          123 :                 "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
      83          123 :                 &[&user_info.endpoint.as_str()],
      84          123 :                 "allowed_ips",
      85          123 :             )
      86          246 :             .await?
      87              :             {
      88           11 :                 Some(s) => {
      89           11 :                     info!("got allowed_ips: {s}");
      90           11 :                     s.split(',')
      91           18 :                         .map(|s| IpPattern::from_str(s).unwrap())
      92           11 :                         .collect()
      93              :                 }
      94          112 :                 None => vec![],
      95              :             };
      96              : 
      97          123 :             Ok((secret, allowed_ips))
      98          123 :         }
      99          123 :         .map_err(crate::error::log_error::<GetAuthInfoError>)
     100          123 :         .instrument(info_span!("postgres", url = self.endpoint.as_str()))
     101          865 :         .await?;
     102          123 :         Ok(AuthInfo {
     103          123 :             secret,
     104          123 :             allowed_ips,
     105          123 :             project_id: None,
     106          123 :         })
     107          123 :     }
     108              : 
     109           79 :     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
     110           79 :         let mut config = compute::ConnCfg::new();
     111           79 :         config
     112           79 :             .host(self.endpoint.host_str().unwrap_or("localhost"))
     113           79 :             .port(self.endpoint.port().unwrap_or(5432))
     114           79 :             .ssl_mode(SslMode::Disable);
     115           79 : 
     116           79 :         let node = NodeInfo {
     117           79 :             config,
     118           79 :             aux: Default::default(),
     119           79 :             allow_self_signed_compute: false,
     120           79 :         };
     121           79 : 
     122           79 :         Ok(node)
     123           79 :     }
     124              : }
     125              : 
     126          246 : async fn get_execute_postgres_query(
     127          246 :     client: &Client,
     128          246 :     query: &str,
     129          246 :     params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
     130          246 :     idx: &str,
     131          246 : ) -> Result<Option<String>, GetAuthInfoError> {
     132          492 :     let rows = client.query(query, params).await?;
     133              : 
     134              :     // We can get at most one row, because `rolname` is unique.
     135          246 :     let row = match rows.first() {
     136          132 :         Some(row) => row,
     137              :         // This means that the user doesn't exist, so there can be no secret.
     138              :         // However, this is still a *valid* outcome which is very similar
     139              :         // to getting `404 Not found` from the Neon console.
     140          114 :         None => return Ok(None),
     141              :     };
     142              : 
     143          132 :     let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
     144          132 :     Ok(Some(entry))
     145          246 : }
     146              : 
     147              : #[async_trait]
     148              : impl super::Api for Api {
     149            0 :     #[tracing::instrument(skip_all)]
     150              :     async fn get_role_secret(
     151              :         &self,
     152              :         _ctx: &mut RequestMonitoring,
     153              :         user_info: &ComputeUserInfo,
     154           39 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     155              :         Ok(CachedRoleSecret::new_uncached(
     156          267 :             self.do_get_auth_info(user_info).await?.secret,
     157              :         ))
     158           78 :     }
     159              : 
     160           84 :     async fn get_allowed_ips_and_secret(
     161           84 :         &self,
     162           84 :         _ctx: &mut RequestMonitoring,
     163           84 :         user_info: &ComputeUserInfo,
     164           84 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     165              :         Ok((
     166              :             Cached::new_uncached(Arc::new(
     167          598 :                 self.do_get_auth_info(user_info).await?.allowed_ips,
     168              :             )),
     169           84 :             None,
     170              :         ))
     171          168 :     }
     172              : 
     173            0 :     #[tracing::instrument(skip_all)]
     174              :     async fn wake_compute(
     175              :         &self,
     176              :         _ctx: &mut RequestMonitoring,
     177              :         _user_info: &ComputeUserInfo,
     178           79 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     179           79 :         self.do_wake_compute()
     180           79 :             .map_ok(CachedNodeInfo::new_uncached)
     181            0 :             .await
     182          158 :     }
     183              : }
     184              : 
     185            0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
     186            0 :     let text = input.strip_prefix("md5")?;
     187              : 
     188            0 :     let mut bytes = [0u8; 16];
     189            0 :     hex::decode_to_slice(text, &mut bytes).ok()?;
     190              : 
     191            0 :     Some(bytes)
     192            0 : }
        

Generated by: LCOV version 2.1-beta