LCOV - differential code coverage report
Current view: top level - proxy/src/auth/backend - link.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 83.3 % 66 55 11 55
Current Date: 2024-01-09 02:06:09 Functions: 29.2 % 24 7 17 7
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     auth, compute,
       3                 :     console::{self, provider::NodeInfo},
       4                 :     error::UserFacingError,
       5                 :     stream::PqStream,
       6                 :     waiters,
       7                 : };
       8                 : use pq_proto::BeMessage as Be;
       9                 : use thiserror::Error;
      10                 : use tokio::io::{AsyncRead, AsyncWrite};
      11                 : use tokio_postgres::config::SslMode;
      12                 : use tracing::{info, info_span};
      13                 : 
      14 UBC           0 : #[derive(Debug, Error)]
      15                 : pub enum LinkAuthError {
      16                 :     /// Authentication error reported by the console.
      17                 :     #[error("Authentication failed: {0}")]
      18                 :     AuthFailed(String),
      19                 : 
      20                 :     #[error(transparent)]
      21                 :     WaiterRegister(#[from] waiters::RegisterError),
      22                 : 
      23                 :     #[error(transparent)]
      24                 :     WaiterWait(#[from] waiters::WaitError),
      25                 : 
      26                 :     #[error(transparent)]
      27                 :     Io(#[from] std::io::Error),
      28                 : }
      29                 : 
      30                 : impl UserFacingError for LinkAuthError {
      31               0 :     fn to_string_client(&self) -> String {
      32               0 :         use LinkAuthError::*;
      33               0 :         match self {
      34               0 :             AuthFailed(_) => self.to_string(),
      35               0 :             _ => "Internal error".to_string(),
      36                 :         }
      37               0 :     }
      38                 : }
      39                 : 
      40 CBC           3 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
      41               3 :     format!(
      42               3 :         concat![
      43               3 :             "Welcome to Neon!\n",
      44               3 :             "Authenticate by visiting:\n",
      45               3 :             "    {redirect_uri}{session_id}\n\n",
      46               3 :         ],
      47               3 :         redirect_uri = redirect_uri,
      48               3 :         session_id = session_id,
      49               3 :     )
      50               3 : }
      51                 : 
      52               3 : pub fn new_psql_session_id() -> String {
      53               3 :     hex::encode(rand::random::<[u8; 8]>())
      54               3 : }
      55                 : 
      56               3 : pub(super) async fn authenticate(
      57               3 :     link_uri: &reqwest::Url,
      58               3 :     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
      59               3 : ) -> auth::Result<NodeInfo> {
      60               3 :     let psql_session_id = new_psql_session_id();
      61               3 :     let span = info_span!("link", psql_session_id = &psql_session_id);
      62               3 :     let greeting = hello_message(link_uri, &psql_session_id);
      63                 : 
      64               3 :     let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
      65               3 :         // Give user a URL to spawn a new database.
      66               3 :         info!(parent: &span, "sending the auth URL to the user");
      67               3 :         client
      68               3 :             .write_message_noflush(&Be::AuthenticationOk)?
      69               3 :             .write_message_noflush(&Be::CLIENT_ENCODING)?
      70               3 :             .write_message(&Be::NoticeResponse(&greeting))
      71 UBC           0 :             .await?;
      72                 : 
      73                 :         // Wait for web console response (see `mgmt`).
      74 CBC           6 :         info!(parent: &span, "waiting for console's reply...");
      75               3 :         waiter.await?.map_err(LinkAuthError::AuthFailed)
      76               3 :     })
      77               3 :     .await?;
      78                 : 
      79               3 :     client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
      80                 : 
      81                 :     // This config should be self-contained, because we won't
      82                 :     // take username or dbname from client's startup message.
      83               3 :     let mut config = compute::ConnCfg::new();
      84               3 :     config
      85               3 :         .host(&db_info.host)
      86               3 :         .port(db_info.port)
      87               3 :         .dbname(&db_info.dbname)
      88               3 :         .user(&db_info.user);
      89               3 : 
      90               3 :     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
      91               3 :     // while direct connections do not. Once we migrate to pg_sni_proxy
      92               3 :     // everywhere, we can remove this.
      93               3 :     if db_info.host.contains("--") {
      94 UBC           0 :         // we need TLS connection with SNI info to properly route it
      95               0 :         config.ssl_mode(SslMode::Require);
      96 CBC           3 :     } else {
      97               3 :         config.ssl_mode(SslMode::Disable);
      98               3 :     }
      99                 : 
     100               3 :     if let Some(password) = db_info.password {
     101 UBC           0 :         config.password(password.as_ref());
     102 CBC           3 :     }
     103                 : 
     104               3 :     Ok(NodeInfo {
     105               3 :         config,
     106               3 :         aux: db_info.aux,
     107               3 :         allow_self_signed_compute: false, // caller may override
     108               3 :     })
     109               3 : }
        

Generated by: LCOV version 2.1-beta