LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - web.rs (source / functions) Coverage Total Hit
Test: fc67f8dc6087a0b4f4f0bcd74f6e1dc25fab8cf3.info Lines: 0.0 % 83 0
Test Date: 2024-09-24 13:57:57 Functions: 0.0 % 15 0

            Line data    Source code
       1              : use crate::{
       2              :     auth, compute,
       3              :     config::AuthenticationConfig,
       4              :     console::{self, provider::NodeInfo},
       5              :     context::RequestMonitoring,
       6              :     error::{ReportableError, UserFacingError},
       7              :     stream::PqStream,
       8              :     waiters,
       9              : };
      10              : use pq_proto::BeMessage as Be;
      11              : use thiserror::Error;
      12              : use tokio::io::{AsyncRead, AsyncWrite};
      13              : use tokio_postgres::config::SslMode;
      14              : use tracing::{info, info_span};
      15              : 
      16            0 : #[derive(Debug, Error)]
      17              : pub(crate) enum WebAuthError {
      18              :     #[error(transparent)]
      19              :     WaiterRegister(#[from] waiters::RegisterError),
      20              : 
      21              :     #[error(transparent)]
      22              :     WaiterWait(#[from] waiters::WaitError),
      23              : 
      24              :     #[error(transparent)]
      25              :     Io(#[from] std::io::Error),
      26              : }
      27              : 
      28              : impl UserFacingError for WebAuthError {
      29            0 :     fn to_string_client(&self) -> String {
      30            0 :         "Internal error".to_string()
      31            0 :     }
      32              : }
      33              : 
      34              : impl ReportableError for WebAuthError {
      35            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      36            0 :         match self {
      37            0 :             Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
      38            0 :             Self::WaiterWait(_) => crate::error::ErrorKind::Service,
      39            0 :             Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      40              :         }
      41            0 :     }
      42              : }
      43              : 
      44            0 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
      45            0 :     format!(
      46            0 :         concat![
      47            0 :             "Welcome to Neon!\n",
      48            0 :             "Authenticate by visiting:\n",
      49            0 :             "    {redirect_uri}{session_id}\n\n",
      50            0 :         ],
      51            0 :         redirect_uri = redirect_uri,
      52            0 :         session_id = session_id,
      53            0 :     )
      54            0 : }
      55              : 
      56            0 : pub(crate) fn new_psql_session_id() -> String {
      57            0 :     hex::encode(rand::random::<[u8; 8]>())
      58            0 : }
      59              : 
      60            0 : pub(super) async fn authenticate(
      61            0 :     ctx: &RequestMonitoring,
      62            0 :     auth_config: &'static AuthenticationConfig,
      63            0 :     link_uri: &reqwest::Url,
      64            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      65            0 : ) -> auth::Result<NodeInfo> {
      66            0 :     ctx.set_auth_method(crate::context::AuthMethod::Web);
      67              : 
      68              :     // registering waiter can fail if we get unlucky with rng.
      69              :     // just try again.
      70            0 :     let (psql_session_id, waiter) = loop {
      71            0 :         let psql_session_id = new_psql_session_id();
      72            0 : 
      73            0 :         match console::mgmt::get_waiter(&psql_session_id) {
      74            0 :             Ok(waiter) => break (psql_session_id, waiter),
      75            0 :             Err(_e) => continue,
      76              :         }
      77              :     };
      78              : 
      79            0 :     let span = info_span!("web", psql_session_id = &psql_session_id);
      80            0 :     let greeting = hello_message(link_uri, &psql_session_id);
      81            0 : 
      82            0 :     // Give user a URL to spawn a new database.
      83            0 :     info!(parent: &span, "sending the auth URL to the user");
      84            0 :     client
      85            0 :         .write_message_noflush(&Be::AuthenticationOk)?
      86            0 :         .write_message_noflush(&Be::CLIENT_ENCODING)?
      87            0 :         .write_message(&Be::NoticeResponse(&greeting))
      88            0 :         .await?;
      89              : 
      90              :     // Wait for web console response (see `mgmt`).
      91            0 :     info!(parent: &span, "waiting for console's reply...");
      92            0 :     let db_info = waiter.await.map_err(WebAuthError::from)?;
      93              : 
      94            0 :     if auth_config.ip_allowlist_check_enabled {
      95            0 :         if let Some(allowed_ips) = &db_info.allowed_ips {
      96            0 :             if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
      97            0 :                 return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
      98            0 :             }
      99            0 :         }
     100            0 :     }
     101              : 
     102            0 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
     103              : 
     104              :     // This config should be self-contained, because we won't
     105              :     // take username or dbname from client's startup message.
     106            0 :     let mut config = compute::ConnCfg::new();
     107            0 :     config
     108            0 :         .host(&db_info.host)
     109            0 :         .port(db_info.port)
     110            0 :         .dbname(&db_info.dbname)
     111            0 :         .user(&db_info.user);
     112            0 : 
     113            0 :     ctx.set_dbname(db_info.dbname.into());
     114            0 :     ctx.set_user(db_info.user.into());
     115            0 :     ctx.set_project(db_info.aux.clone());
     116            0 :     info!("woken up a compute node");
     117              : 
     118              :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     119              :     // while direct connections do not. Once we migrate to pg_sni_proxy
     120              :     // everywhere, we can remove this.
     121            0 :     if db_info.host.contains("--") {
     122            0 :         // we need TLS connection with SNI info to properly route it
     123            0 :         config.ssl_mode(SslMode::Require);
     124            0 :     } else {
     125            0 :         config.ssl_mode(SslMode::Disable);
     126            0 :     }
     127              : 
     128            0 :     if let Some(password) = db_info.password {
     129            0 :         config.password(password.as_ref());
     130            0 :     }
     131              : 
     132            0 :     Ok(NodeInfo {
     133            0 :         config,
     134            0 :         aux: db_info.aux,
     135            0 :         allow_self_signed_compute: false, // caller may override
     136            0 :     })
     137            0 : }
        

Generated by: LCOV version 2.1-beta