LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - console_redirect.rs (source / functions) Coverage Total Hit
Test: 4f58e98c51285c7fa348e0b410c88a10caf68ad2.info Lines: 0.0 % 119 0
Test Date: 2025-01-07 20:58:07 Functions: 0.0 % 20 0

            Line data    Source code
       1              : use async_trait::async_trait;
       2              : use postgres_client::config::SslMode;
       3              : use pq_proto::BeMessage as Be;
       4              : use thiserror::Error;
       5              : use tokio::io::{AsyncRead, AsyncWrite};
       6              : use tracing::{info, info_span};
       7              : 
       8              : use super::ComputeCredentialKeys;
       9              : use crate::auth::IpPattern;
      10              : use crate::cache::Cached;
      11              : use crate::config::AuthenticationConfig;
      12              : use crate::context::RequestContext;
      13              : use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
      14              : use crate::error::{ReportableError, UserFacingError};
      15              : use crate::proxy::connect_compute::ComputeConnectBackend;
      16              : use crate::stream::PqStream;
      17              : use crate::{auth, compute, waiters};
      18              : 
      19              : #[derive(Debug, Error)]
      20              : pub(crate) enum ConsoleRedirectError {
      21              :     #[error(transparent)]
      22              :     WaiterRegister(#[from] waiters::RegisterError),
      23              : 
      24              :     #[error(transparent)]
      25              :     WaiterWait(#[from] waiters::WaitError),
      26              : 
      27              :     #[error(transparent)]
      28              :     Io(#[from] std::io::Error),
      29              : }
      30              : 
      31              : #[derive(Debug)]
      32              : pub struct ConsoleRedirectBackend {
      33              :     console_uri: reqwest::Url,
      34              : }
      35              : 
      36              : impl UserFacingError for ConsoleRedirectError {
      37            0 :     fn to_string_client(&self) -> String {
      38            0 :         "Internal error".to_string()
      39            0 :     }
      40              : }
      41              : 
      42              : impl ReportableError for ConsoleRedirectError {
      43            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      44            0 :         match self {
      45            0 :             Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
      46            0 :             Self::WaiterWait(_) => crate::error::ErrorKind::Service,
      47            0 :             Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      48              :         }
      49            0 :     }
      50              : }
      51              : 
      52            0 : fn hello_message(
      53            0 :     redirect_uri: &reqwest::Url,
      54            0 :     session_id: &str,
      55            0 :     duration: std::time::Duration,
      56            0 : ) -> String {
      57            0 :     let formatted_duration = humantime::format_duration(duration).to_string();
      58            0 :     format!(
      59            0 :         concat![
      60            0 :             "Welcome to Neon!\n",
      61            0 :             "Authenticate by visiting (will expire in {duration}):\n",
      62            0 :             "    {redirect_uri}{session_id}\n\n",
      63            0 :         ],
      64            0 :         duration = formatted_duration,
      65            0 :         redirect_uri = redirect_uri,
      66            0 :         session_id = session_id,
      67            0 :     )
      68            0 : }
      69              : 
      70            0 : pub(crate) fn new_psql_session_id() -> String {
      71            0 :     hex::encode(rand::random::<[u8; 8]>())
      72            0 : }
      73              : 
      74              : impl ConsoleRedirectBackend {
      75            0 :     pub fn new(console_uri: reqwest::Url) -> Self {
      76            0 :         Self { console_uri }
      77            0 :     }
      78              : 
      79            0 :     pub(crate) async fn authenticate(
      80            0 :         &self,
      81            0 :         ctx: &RequestContext,
      82            0 :         auth_config: &'static AuthenticationConfig,
      83            0 :         client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      84            0 :     ) -> auth::Result<(ConsoleRedirectNodeInfo, Option<Vec<IpPattern>>)> {
      85            0 :         authenticate(ctx, auth_config, &self.console_uri, client)
      86            0 :             .await
      87            0 :             .map(|(node_info, ip_allowlist)| (ConsoleRedirectNodeInfo(node_info), ip_allowlist))
      88            0 :     }
      89              : }
      90              : 
      91              : pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
      92              : 
      93              : #[async_trait]
      94              : impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
      95            0 :     async fn wake_compute(
      96            0 :         &self,
      97            0 :         _ctx: &RequestContext,
      98            0 :     ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
      99            0 :         Ok(Cached::new_uncached(self.0.clone()))
     100            0 :     }
     101              : 
     102            0 :     fn get_keys(&self) -> &ComputeCredentialKeys {
     103            0 :         &ComputeCredentialKeys::None
     104            0 :     }
     105              : }
     106              : 
     107            0 : async fn authenticate(
     108            0 :     ctx: &RequestContext,
     109            0 :     auth_config: &'static AuthenticationConfig,
     110            0 :     link_uri: &reqwest::Url,
     111            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
     112            0 : ) -> auth::Result<(NodeInfo, Option<Vec<IpPattern>>)> {
     113            0 :     ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
     114              : 
     115              :     // registering waiter can fail if we get unlucky with rng.
     116              :     // just try again.
     117            0 :     let (psql_session_id, waiter) = loop {
     118            0 :         let psql_session_id = new_psql_session_id();
     119            0 : 
     120            0 :         match control_plane::mgmt::get_waiter(&psql_session_id) {
     121            0 :             Ok(waiter) => break (psql_session_id, waiter),
     122            0 :             Err(_e) => continue,
     123              :         }
     124              :     };
     125              : 
     126            0 :     let span = info_span!("console_redirect", psql_session_id = &psql_session_id);
     127            0 :     let greeting = hello_message(
     128            0 :         link_uri,
     129            0 :         &psql_session_id,
     130            0 :         auth_config.console_redirect_confirmation_timeout,
     131            0 :     );
     132            0 : 
     133            0 :     // Give user a URL to spawn a new database.
     134            0 :     info!(parent: &span, "sending the auth URL to the user");
     135            0 :     client
     136            0 :         .write_message_noflush(&Be::AuthenticationOk)?
     137            0 :         .write_message_noflush(&Be::CLIENT_ENCODING)?
     138            0 :         .write_message(&Be::NoticeResponse(&greeting))
     139            0 :         .await?;
     140              : 
     141              :     // Wait for console response via control plane (see `mgmt`).
     142            0 :     info!(parent: &span, "waiting for console's reply...");
     143            0 :     let db_info = tokio::time::timeout(auth_config.console_redirect_confirmation_timeout, waiter)
     144            0 :         .await
     145            0 :         .map_err(|_elapsed| {
     146            0 :             auth::AuthError::confirmation_timeout(
     147            0 :                 auth_config.console_redirect_confirmation_timeout.into(),
     148            0 :             )
     149            0 :         })?
     150            0 :         .map_err(ConsoleRedirectError::from)?;
     151              : 
     152            0 :     if auth_config.ip_allowlist_check_enabled {
     153            0 :         if let Some(allowed_ips) = &db_info.allowed_ips {
     154            0 :             if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
     155            0 :                 return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
     156            0 :             }
     157            0 :         }
     158            0 :     }
     159              : 
     160            0 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
     161              : 
     162              :     // This config should be self-contained, because we won't
     163              :     // take username or dbname from client's startup message.
     164            0 :     let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
     165            0 :     config.dbname(&db_info.dbname).user(&db_info.user);
     166            0 : 
     167            0 :     ctx.set_dbname(db_info.dbname.into());
     168            0 :     ctx.set_user(db_info.user.into());
     169            0 :     ctx.set_project(db_info.aux.clone());
     170            0 :     info!("woken up a compute node");
     171              : 
     172              :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     173              :     // while direct connections do not. Once we migrate to pg_sni_proxy
     174              :     // everywhere, we can remove this.
     175            0 :     if db_info.host.contains("--") {
     176            0 :         // we need TLS connection with SNI info to properly route it
     177            0 :         config.ssl_mode(SslMode::Require);
     178            0 :     } else {
     179            0 :         config.ssl_mode(SslMode::Disable);
     180            0 :     }
     181              : 
     182            0 :     if let Some(password) = db_info.password {
     183            0 :         config.password(password.as_ref());
     184            0 :     }
     185              : 
     186            0 :     Ok((
     187            0 :         NodeInfo {
     188            0 :             config,
     189            0 :             aux: db_info.aux,
     190            0 :         },
     191            0 :         db_info.allowed_ips,
     192            0 :     ))
     193            0 : }
        

Generated by: LCOV version 2.1-beta