LCOV - code coverage report
Current view: top level - proxy/src/auth/backend - link.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 80.8 % 78 63
Test Date: 2024-02-14 18:05:35 Functions: 27.3 % 22 6

            Line data    Source code
       1              : use crate::{
       2              :     auth, compute,
       3              :     console::{self, provider::NodeInfo},
       4              :     context::RequestMonitoring,
       5              :     error::{ReportableError, UserFacingError},
       6              :     stream::PqStream,
       7              :     waiters,
       8              : };
       9              : use pq_proto::BeMessage as Be;
      10              : use thiserror::Error;
      11              : use tokio::io::{AsyncRead, AsyncWrite};
      12              : use tokio_postgres::config::SslMode;
      13              : use tracing::{info, info_span};
      14              : 
      15            0 : #[derive(Debug, Error)]
      16              : pub enum LinkAuthError {
      17              :     #[error(transparent)]
      18              :     WaiterRegister(#[from] waiters::RegisterError),
      19              : 
      20              :     #[error(transparent)]
      21              :     WaiterWait(#[from] waiters::WaitError),
      22              : 
      23              :     #[error(transparent)]
      24              :     Io(#[from] std::io::Error),
      25              : }
      26              : 
      27              : impl UserFacingError for LinkAuthError {
      28            0 :     fn to_string_client(&self) -> String {
      29            0 :         "Internal error".to_string()
      30            0 :     }
      31              : }
      32              : 
      33              : impl ReportableError for LinkAuthError {
      34            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
      35            0 :         match self {
      36            0 :             LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
      37            0 :             LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
      38            0 :             LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
      39              :         }
      40            0 :     }
      41              : }
      42              : 
      43            3 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
      44            3 :     format!(
      45            3 :         concat![
      46            3 :             "Welcome to Neon!\n",
      47            3 :             "Authenticate by visiting:\n",
      48            3 :             "    {redirect_uri}{session_id}\n\n",
      49            3 :         ],
      50            3 :         redirect_uri = redirect_uri,
      51            3 :         session_id = session_id,
      52            3 :     )
      53            3 : }
      54              : 
      55            3 : pub fn new_psql_session_id() -> String {
      56            3 :     hex::encode(rand::random::<[u8; 8]>())
      57            3 : }
      58              : 
      59            3 : pub(super) async fn authenticate(
      60            3 :     ctx: &mut RequestMonitoring,
      61            3 :     link_uri: &reqwest::Url,
      62            3 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      63            3 : ) -> auth::Result<NodeInfo> {
      64            3 :     ctx.set_auth_method(crate::context::AuthMethod::Web);
      65              : 
      66              :     // registering waiter can fail if we get unlucky with rng.
      67              :     // just try again.
      68            3 :     let (psql_session_id, waiter) = loop {
      69            3 :         let psql_session_id = new_psql_session_id();
      70            3 : 
      71            3 :         match console::mgmt::get_waiter(&psql_session_id) {
      72            3 :             Ok(waiter) => break (psql_session_id, waiter),
      73            0 :             Err(_e) => continue,
      74              :         }
      75              :     };
      76              : 
      77            3 :     let span = info_span!("link", psql_session_id = &psql_session_id);
      78            3 :     let greeting = hello_message(link_uri, &psql_session_id);
      79            3 : 
      80            3 :     // Give user a URL to spawn a new database.
      81            3 :     info!(parent: &span, "sending the auth URL to the user");
      82            3 :     client
      83            3 :         .write_message_noflush(&Be::AuthenticationOk)?
      84            3 :         .write_message_noflush(&Be::CLIENT_ENCODING)?
      85            3 :         .write_message(&Be::NoticeResponse(&greeting))
      86            0 :         .await?;
      87              : 
      88              :     // Wait for web console response (see `mgmt`).
      89            3 :     info!(parent: &span, "waiting for console's reply...");
      90            3 :     let db_info = waiter.await.map_err(LinkAuthError::from)?;
      91              : 
      92            3 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
      93              : 
      94              :     // This config should be self-contained, because we won't
      95              :     // take username or dbname from client's startup message.
      96            3 :     let mut config = compute::ConnCfg::new();
      97            3 :     config
      98            3 :         .host(&db_info.host)
      99            3 :         .port(db_info.port)
     100            3 :         .dbname(&db_info.dbname)
     101            3 :         .user(&db_info.user);
     102            3 : 
     103            3 :     ctx.set_user(db_info.user.into());
     104            3 :     ctx.set_project(db_info.aux.clone());
     105            3 :     tracing::Span::current().record("ep", &tracing::field::display(&db_info.aux.endpoint_id));
     106            3 : 
     107            3 :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     108            3 :     // while direct connections do not. Once we migrate to pg_sni_proxy
     109            3 :     // everywhere, we can remove this.
     110            3 :     if db_info.host.contains("--") {
     111            0 :         // we need TLS connection with SNI info to properly route it
     112            0 :         config.ssl_mode(SslMode::Require);
     113            3 :     } else {
     114            3 :         config.ssl_mode(SslMode::Disable);
     115            3 :     }
     116              : 
     117            3 :     if let Some(password) = db_info.password {
     118            0 :         config.password(password.as_ref());
     119            3 :     }
     120              : 
     121            3 :     Ok(NodeInfo {
     122            3 :         config,
     123            3 :         aux: db_info.aux,
     124            3 :         allow_self_signed_compute: false, // caller may override
     125            3 :     })
     126            3 : }
        

Generated by: LCOV version 2.1-beta