LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - console_redirect.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 118 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 15 0

            Line data    Source code
       1              : use std::fmt;
       2              : 
       3              : use async_trait::async_trait;
       4              : use postgres_client::config::SslMode;
       5              : use thiserror::Error;
       6              : use tokio::io::{AsyncRead, AsyncWrite};
       7              : use tracing::{info, info_span};
       8              : 
       9              : use crate::auth::backend::ComputeUserInfo;
      10              : use crate::cache::Cached;
      11              : use crate::cache::node_info::CachedNodeInfo;
      12              : use crate::compute::AuthInfo;
      13              : use crate::config::AuthenticationConfig;
      14              : use crate::context::RequestContext;
      15              : use crate::control_plane::client::cplane_proxy_v1;
      16              : use crate::control_plane::{self, NodeInfo};
      17              : use crate::error::{ReportableError, UserFacingError};
      18              : use crate::pqproto::BeMessage;
      19              : use crate::proxy::NeonOptions;
      20              : use crate::proxy::wake_compute::WakeComputeBackend;
      21              : use crate::stream::PqStream;
      22              : use crate::types::RoleName;
      23              : use crate::{auth, compute, waiters};
      24              : 
      25              : #[derive(Debug, Error)]
      26              : pub(crate) enum ConsoleRedirectError {
      27              :     #[error(transparent)]
      28              :     WaiterRegister(#[from] waiters::RegisterError),
      29              : 
      30              :     #[error(transparent)]
      31              :     WaiterWait(#[from] waiters::WaitError),
      32              : 
      33              :     #[error(transparent)]
      34              :     Io(#[from] std::io::Error),
      35              : }
      36              : 
      37              : #[derive(Debug)]
      38              : pub struct ConsoleRedirectBackend {
      39              :     console_uri: reqwest::Url,
      40              :     api: cplane_proxy_v1::NeonControlPlaneClient,
      41              : }
      42              : 
      43              : impl fmt::Debug for cplane_proxy_v1::NeonControlPlaneClient {
      44            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      45            0 :         write!(f, "NeonControlPlaneClient")
      46            0 :     }
      47              : }
      48              : 
      49              : impl UserFacingError for ConsoleRedirectError {
      50            0 :     fn to_string_client(&self) -> String {
      51            0 :         "Internal error".to_string()
      52            0 :     }
      53              : }
      54              : 
      55              : impl ReportableError for ConsoleRedirectError {
      56            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      57            0 :         match self {
      58            0 :             Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
      59            0 :             Self::WaiterWait(_) => crate::error::ErrorKind::Service,
      60            0 :             Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      61              :         }
      62            0 :     }
      63              : }
      64              : 
      65            0 : fn hello_message(
      66            0 :     redirect_uri: &reqwest::Url,
      67            0 :     session_id: &str,
      68            0 :     duration: std::time::Duration,
      69            0 : ) -> String {
      70            0 :     let formatted_duration = humantime::format_duration(duration).to_string();
      71            0 :     format!(
      72            0 :         concat![
      73              :             "Welcome to Neon!\n",
      74              :             "Authenticate by visiting (will expire in {duration}):\n",
      75              :             "    {redirect_uri}{session_id}\n\n",
      76              :         ],
      77              :         duration = formatted_duration,
      78              :         redirect_uri = redirect_uri,
      79              :         session_id = session_id,
      80              :     )
      81            0 : }
      82              : 
      83            0 : pub(crate) fn new_psql_session_id() -> String {
      84            0 :     hex::encode(rand::random::<[u8; 8]>())
      85            0 : }
      86              : 
      87              : impl ConsoleRedirectBackend {
      88            0 :     pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self {
      89            0 :         Self { console_uri, api }
      90            0 :     }
      91              : 
      92            0 :     pub(crate) fn get_api(&self) -> &cplane_proxy_v1::NeonControlPlaneClient {
      93            0 :         &self.api
      94            0 :     }
      95              : 
      96            0 :     pub(crate) async fn authenticate(
      97            0 :         &self,
      98            0 :         ctx: &RequestContext,
      99            0 :         auth_config: &'static AuthenticationConfig,
     100            0 :         client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     101            0 :     ) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
     102            0 :         authenticate(ctx, auth_config, &self.console_uri, client)
     103            0 :             .await
     104            0 :             .map(|(node_info, auth_info, user_info)| {
     105            0 :                 (ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
     106            0 :             })
     107            0 :     }
     108              : }
     109              : 
     110              : pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
     111              : 
     112              : #[async_trait]
     113              : impl WakeComputeBackend for ConsoleRedirectNodeInfo {
     114            0 :     async fn wake_compute(
     115              :         &self,
     116              :         _ctx: &RequestContext,
     117            0 :     ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     118            0 :         Ok(Cached::new_uncached(self.0.clone()))
     119            0 :     }
     120              : }
     121              : 
     122            0 : async fn authenticate(
     123            0 :     ctx: &RequestContext,
     124            0 :     auth_config: &'static AuthenticationConfig,
     125            0 :     link_uri: &reqwest::Url,
     126            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     127            0 : ) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
     128            0 :     ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
     129              : 
     130              :     // registering waiter can fail if we get unlucky with rng.
     131              :     // just try again.
     132            0 :     let (psql_session_id, waiter) = loop {
     133            0 :         let psql_session_id = new_psql_session_id();
     134              : 
     135            0 :         if let Ok(waiter) = control_plane::mgmt::get_waiter(&psql_session_id) {
     136            0 :             break (psql_session_id, waiter);
     137            0 :         }
     138              :     };
     139              : 
     140            0 :     let span = info_span!("console_redirect", psql_session_id = &psql_session_id);
     141            0 :     let greeting = hello_message(
     142            0 :         link_uri,
     143            0 :         &psql_session_id,
     144            0 :         auth_config.console_redirect_confirmation_timeout,
     145              :     );
     146              : 
     147              :     // Give user a URL to spawn a new database.
     148            0 :     info!(parent: &span, "sending the auth URL to the user");
     149            0 :     client.write_message(BeMessage::AuthenticationOk);
     150            0 :     client.write_message(BeMessage::ParameterStatus {
     151            0 :         name: b"client_encoding",
     152            0 :         value: b"UTF8",
     153            0 :     });
     154            0 :     client.write_message(BeMessage::NoticeResponse(&greeting));
     155            0 :     client.flush().await?;
     156              : 
     157              :     // Wait for console response via control plane (see `mgmt`).
     158            0 :     info!(parent: &span, "waiting for console's reply...");
     159            0 :     let db_info = tokio::time::timeout(auth_config.console_redirect_confirmation_timeout, waiter)
     160            0 :         .await
     161            0 :         .map_err(|_elapsed| {
     162            0 :             auth::AuthError::confirmation_timeout(
     163            0 :                 auth_config.console_redirect_confirmation_timeout.into(),
     164              :             )
     165            0 :         })?
     166            0 :         .map_err(ConsoleRedirectError::from)?;
     167              : 
     168            0 :     if auth_config.ip_allowlist_check_enabled
     169            0 :         && let Some(allowed_ips) = &db_info.allowed_ips
     170            0 :         && !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
     171              :     {
     172            0 :         return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
     173            0 :     }
     174              : 
     175              :     // Check if the access over the public internet is allowed, otherwise block. Note that
     176              :     // the console redirect is not behind the VPC service endpoint, so we don't need to check
     177              :     // the VPC endpoint ID.
     178            0 :     if let Some(public_access_allowed) = db_info.public_access_allowed
     179            0 :         && !public_access_allowed
     180              :     {
     181            0 :         return Err(auth::AuthError::NetworkNotAllowed);
     182            0 :     }
     183              : 
     184              :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     185              :     // while direct connections do not. Once we migrate to pg_sni_proxy
     186              :     // everywhere, we can remove this.
     187            0 :     let ssl_mode = if db_info.host.contains("--") {
     188              :         // we need TLS connection with SNI info to properly route it
     189            0 :         SslMode::Require
     190              :     } else {
     191            0 :         SslMode::Disable
     192              :     };
     193              : 
     194            0 :     let conn_info = compute::ConnectInfo {
     195            0 :         host: db_info.host.into(),
     196            0 :         port: db_info.port,
     197            0 :         ssl_mode,
     198            0 :         host_addr: None,
     199            0 :     };
     200            0 :     let auth_info =
     201            0 :         AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
     202              : 
     203            0 :     let user: RoleName = db_info.user.into();
     204            0 :     let user_info = ComputeUserInfo {
     205            0 :         endpoint: db_info.aux.endpoint_id.as_str().into(),
     206            0 :         user: user.clone(),
     207            0 :         options: NeonOptions::default(),
     208            0 :     };
     209              : 
     210            0 :     ctx.set_dbname(db_info.dbname.into());
     211            0 :     ctx.set_user(user);
     212            0 :     ctx.set_project(db_info.aux.clone());
     213            0 :     info!("woken up a compute node");
     214              : 
     215            0 :     Ok((
     216            0 :         NodeInfo {
     217            0 :             conn_info,
     218            0 :             aux: db_info.aux,
     219            0 :         },
     220            0 :         auth_info,
     221            0 :         user_info,
     222            0 :     ))
     223            0 : }
        

Generated by: LCOV version 2.1-beta