LCOV - differential code coverage report
Current view: top level - proxy/src/auth/backend - link.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 84.1 % 69 58 11 58
Current Date: 2023-10-19 02:04:12 Functions: 29.6 % 27 8 19 8
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use super::AuthSuccess;
       2                 : use crate::{
       3                 :     auth, compute,
       4                 :     console::{self, provider::NodeInfo},
       5                 :     error::UserFacingError,
       6                 :     stream::PqStream,
       7                 :     waiters,
       8                 : };
       9                 : use pq_proto::BeMessage as Be;
      10                 : use thiserror::Error;
      11                 : use tokio::io::{AsyncRead, AsyncWrite};
      12                 : use tokio_postgres::config::SslMode;
      13                 : use tracing::{info, info_span};
      14                 : 
      15 UBC           0 : #[derive(Debug, Error)]
      16                 : pub enum LinkAuthError {
      17                 :     /// Authentication error reported by the console.
      18                 :     #[error("Authentication failed: {0}")]
      19                 :     AuthFailed(String),
      20                 : 
      21                 :     #[error(transparent)]
      22                 :     WaiterRegister(#[from] waiters::RegisterError),
      23                 : 
      24                 :     #[error(transparent)]
      25                 :     WaiterWait(#[from] waiters::WaitError),
      26                 : 
      27                 :     #[error(transparent)]
      28                 :     Io(#[from] std::io::Error),
      29                 : }
      30                 : 
      31                 : impl UserFacingError for LinkAuthError {
      32               0 :     fn to_string_client(&self) -> String {
      33               0 :         use LinkAuthError::*;
      34               0 :         match self {
      35               0 :             AuthFailed(_) => self.to_string(),
      36               0 :             _ => "Internal error".to_string(),
      37                 :         }
      38               0 :     }
      39                 : }
      40                 : 
      41 CBC           3 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
      42               3 :     format!(
      43               3 :         concat![
      44               3 :             "Welcome to Neon!\n",
      45               3 :             "Authenticate by visiting:\n",
      46               3 :             "    {redirect_uri}{session_id}\n\n",
      47               3 :         ],
      48               3 :         redirect_uri = redirect_uri,
      49               3 :         session_id = session_id,
      50               3 :     )
      51               3 : }
      52                 : 
      53               3 : pub fn new_psql_session_id() -> String {
      54               3 :     hex::encode(rand::random::<[u8; 8]>())
      55               3 : }
      56                 : 
      57               3 : pub(super) async fn authenticate(
      58               3 :     link_uri: &reqwest::Url,
      59               3 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      60               3 : ) -> auth::Result<AuthSuccess<NodeInfo>> {
      61               3 :     let psql_session_id = new_psql_session_id();
      62               3 :     let span = info_span!("link", psql_session_id = &psql_session_id);
      63               3 :     let greeting = hello_message(link_uri, &psql_session_id);
      64                 : 
      65               3 :     let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
      66               3 :         // Give user a URL to spawn a new database.
      67               3 :         info!(parent: &span, "sending the auth URL to the user");
      68               3 :         client
      69               3 :             .write_message_noflush(&Be::AuthenticationOk)?
      70               3 :             .write_message_noflush(&Be::CLIENT_ENCODING)?
      71               3 :             .write_message(&Be::NoticeResponse(&greeting))
      72 UBC           0 :             .await?;
      73                 : 
      74                 :         // Wait for web console response (see `mgmt`).
      75 CBC           6 :         info!(parent: &span, "waiting for console's reply...");
      76               3 :         waiter.await?.map_err(LinkAuthError::AuthFailed)
      77               3 :     })
      78               3 :     .await?;
      79                 : 
      80               3 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
      81                 : 
      82                 :     // This config should be self-contained, because we won't
      83                 :     // take username or dbname from client's startup message.
      84               3 :     let mut config = compute::ConnCfg::new();
      85               3 :     config
      86               3 :         .host(&db_info.host)
      87               3 :         .port(db_info.port)
      88               3 :         .dbname(&db_info.dbname)
      89               3 :         .user(&db_info.user);
      90               3 : 
      91               3 :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
      92               3 :     // while direct connections do not. Once we migrate to pg_sni_proxy
      93               3 :     // everywhere, we can remove this.
      94               3 :     if db_info.host.contains("--") {
      95 UBC           0 :         // we need TLS connection with SNI info to properly route it
      96               0 :         config.ssl_mode(SslMode::Require);
      97 CBC           3 :     } else {
      98               3 :         config.ssl_mode(SslMode::Disable);
      99               3 :     }
     100                 : 
     101               3 :     if let Some(password) = db_info.password {
     102 UBC           0 :         config.password(password.as_ref());
     103 CBC           3 :     }
     104                 : 
     105               3 :     Ok(AuthSuccess {
     106               3 :         reported_auth_ok: true,
     107               3 :         value: NodeInfo {
     108               3 :             config,
     109               3 :             aux: db_info.aux.into(),
     110               3 :             allow_self_signed_compute: false, // caller may override
     111               3 :         },
     112               3 :     })
     113               3 : }
        

Generated by: LCOV version 2.1-beta