LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - link.rs (source / functions) Coverage Total Hit
Test: f081ec316c96fa98335efd15ef501745aa4f015d.info Lines: 0.0 % 71 0
Test Date: 2024-06-25 15:11:17 Functions: 0.0 % 15 0

            Line data    Source code
       1              : use crate::{
       2              :     auth, compute,
       3              :     console::{self, provider::NodeInfo},
       4              :     context::RequestMonitoring,
       5              :     error::{ReportableError, UserFacingError},
       6              :     stream::PqStream,
       7              :     waiters,
       8              : };
       9              : use pq_proto::BeMessage as Be;
      10              : use thiserror::Error;
      11              : use tokio::io::{AsyncRead, AsyncWrite};
      12              : use tokio_postgres::config::SslMode;
      13              : use tracing::{info, info_span};
      14              : 
      15            0 : #[derive(Debug, Error)]
      16              : pub enum LinkAuthError {
      17              :     #[error(transparent)]
      18              :     WaiterRegister(#[from] waiters::RegisterError),
      19              : 
      20              :     #[error(transparent)]
      21              :     WaiterWait(#[from] waiters::WaitError),
      22              : 
      23              :     #[error(transparent)]
      24              :     Io(#[from] std::io::Error),
      25              : }
      26              : 
      27              : impl UserFacingError for LinkAuthError {
      28            0 :     fn to_string_client(&self) -> String {
      29            0 :         "Internal error".to_string()
      30            0 :     }
      31              : }
      32              : 
      33              : impl ReportableError for LinkAuthError {
      34            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      35            0 :         match self {
      36            0 :             LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
      37            0 :             LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
      38            0 :             LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      39              :         }
      40            0 :     }
      41              : }
      42              : 
      43            0 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
      44            0 :     format!(
      45            0 :         concat![
      46            0 :             "Welcome to Neon!\n",
      47            0 :             "Authenticate by visiting:\n",
      48            0 :             "    {redirect_uri}{session_id}\n\n",
      49            0 :         ],
      50            0 :         redirect_uri = redirect_uri,
      51            0 :         session_id = session_id,
      52            0 :     )
      53            0 : }
      54              : 
      55            0 : pub fn new_psql_session_id() -> String {
      56            0 :     hex::encode(rand::random::<[u8; 8]>())
      57            0 : }
      58              : 
      59            0 : pub(super) async fn authenticate(
      60            0 :     ctx: &mut RequestMonitoring,
      61            0 :     link_uri: &reqwest::Url,
      62            0 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      63            0 : ) -> auth::Result<NodeInfo> {
      64            0 :     ctx.set_auth_method(crate::context::AuthMethod::Web);
      65              : 
      66              :     // registering waiter can fail if we get unlucky with rng.
      67              :     // just try again.
      68            0 :     let (psql_session_id, waiter) = loop {
      69            0 :         let psql_session_id = new_psql_session_id();
      70            0 : 
      71            0 :         match console::mgmt::get_waiter(&psql_session_id) {
      72            0 :             Ok(waiter) => break (psql_session_id, waiter),
      73            0 :             Err(_e) => continue,
      74              :         }
      75              :     };
      76              : 
      77            0 :     let span = info_span!("link", psql_session_id = &psql_session_id);
      78            0 :     let greeting = hello_message(link_uri, &psql_session_id);
      79              : 
      80              :     // Give user a URL to spawn a new database.
      81              :     info!(parent: &span, "sending the auth URL to the user");
      82            0 :     client
      83            0 :         .write_message_noflush(&Be::AuthenticationOk)?
      84            0 :         .write_message_noflush(&Be::CLIENT_ENCODING)?
      85            0 :         .write_message(&Be::NoticeResponse(&greeting))
      86            0 :         .await?;
      87              : 
      88              :     // Wait for web console response (see `mgmt`).
      89              :     info!(parent: &span, "waiting for console's reply...");
      90            0 :     let db_info = waiter.await.map_err(LinkAuthError::from)?;
      91              : 
      92            0 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
      93              : 
      94              :     // This config should be self-contained, because we won't
      95              :     // take username or dbname from client's startup message.
      96            0 :     let mut config = compute::ConnCfg::new();
      97            0 :     config
      98            0 :         .host(&db_info.host)
      99            0 :         .port(db_info.port)
     100            0 :         .dbname(&db_info.dbname)
     101            0 :         .user(&db_info.user);
     102            0 : 
     103            0 :     ctx.set_dbname(db_info.dbname.into());
     104            0 :     ctx.set_user(db_info.user.into());
     105            0 :     ctx.set_project(db_info.aux.clone());
     106            0 :     info!("woken up a compute node");
     107              : 
     108              :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     109              :     // while direct connections do not. Once we migrate to pg_sni_proxy
     110              :     // everywhere, we can remove this.
     111            0 :     if db_info.host.contains("--") {
     112            0 :         // we need TLS connection with SNI info to properly route it
     113            0 :         config.ssl_mode(SslMode::Require);
     114            0 :     } else {
     115            0 :         config.ssl_mode(SslMode::Disable);
     116            0 :     }
     117              : 
     118            0 :     if let Some(password) = db_info.password {
     119            0 :         config.password(password.as_ref());
     120            0 :     }
     121              : 
     122            0 :     Ok(NodeInfo {
     123            0 :         config,
     124            0 :         aux: db_info.aux,
     125            0 :         allow_self_signed_compute: false, // caller may override
     126            0 :     })
     127            0 : }
        

Generated by: LCOV version 2.1-beta