Line data Source code
1 : use crate::{
2 : auth, compute,
3 : config::AuthenticationConfig,
4 : console::{self, provider::NodeInfo},
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : stream::PqStream,
8 : waiters,
9 : };
10 : use pq_proto::BeMessage as Be;
11 : use thiserror::Error;
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tokio_postgres::config::SslMode;
14 : use tracing::{info, info_span};
15 :
16 0 : #[derive(Debug, Error)]
17 : pub(crate) enum WebAuthError {
18 : #[error(transparent)]
19 : WaiterRegister(#[from] waiters::RegisterError),
20 :
21 : #[error(transparent)]
22 : WaiterWait(#[from] waiters::WaitError),
23 :
24 : #[error(transparent)]
25 : Io(#[from] std::io::Error),
26 : }
27 :
28 : impl UserFacingError for WebAuthError {
29 0 : fn to_string_client(&self) -> String {
30 0 : "Internal error".to_string()
31 0 : }
32 : }
33 :
34 : impl ReportableError for WebAuthError {
35 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
36 0 : match self {
37 0 : Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
38 0 : Self::WaiterWait(_) => crate::error::ErrorKind::Service,
39 0 : Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
40 : }
41 0 : }
42 : }
43 :
44 0 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
45 0 : format!(
46 0 : concat![
47 0 : "Welcome to Neon!\n",
48 0 : "Authenticate by visiting:\n",
49 0 : " {redirect_uri}{session_id}\n\n",
50 0 : ],
51 0 : redirect_uri = redirect_uri,
52 0 : session_id = session_id,
53 0 : )
54 0 : }
55 :
56 0 : pub(crate) fn new_psql_session_id() -> String {
57 0 : hex::encode(rand::random::<[u8; 8]>())
58 0 : }
59 :
60 0 : pub(super) async fn authenticate(
61 0 : ctx: &RequestMonitoring,
62 0 : auth_config: &'static AuthenticationConfig,
63 0 : link_uri: &reqwest::Url,
64 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
65 0 : ) -> auth::Result<NodeInfo> {
66 0 : ctx.set_auth_method(crate::context::AuthMethod::Web);
67 :
68 : // registering waiter can fail if we get unlucky with rng.
69 : // just try again.
70 0 : let (psql_session_id, waiter) = loop {
71 0 : let psql_session_id = new_psql_session_id();
72 0 :
73 0 : match console::mgmt::get_waiter(&psql_session_id) {
74 0 : Ok(waiter) => break (psql_session_id, waiter),
75 0 : Err(_e) => continue,
76 : }
77 : };
78 :
79 0 : let span = info_span!("web", psql_session_id = &psql_session_id);
80 0 : let greeting = hello_message(link_uri, &psql_session_id);
81 0 :
82 0 : // Give user a URL to spawn a new database.
83 0 : info!(parent: &span, "sending the auth URL to the user");
84 0 : client
85 0 : .write_message_noflush(&Be::AuthenticationOk)?
86 0 : .write_message_noflush(&Be::CLIENT_ENCODING)?
87 0 : .write_message(&Be::NoticeResponse(&greeting))
88 0 : .await?;
89 :
90 : // Wait for web console response (see `mgmt`).
91 0 : info!(parent: &span, "waiting for console's reply...");
92 0 : let db_info = waiter.await.map_err(WebAuthError::from)?;
93 :
94 0 : if auth_config.ip_allowlist_check_enabled {
95 0 : if let Some(allowed_ips) = &db_info.allowed_ips {
96 0 : if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
97 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
98 0 : }
99 0 : }
100 0 : }
101 :
102 0 : client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
103 :
104 : // This config should be self-contained, because we won't
105 : // take username or dbname from client's startup message.
106 0 : let mut config = compute::ConnCfg::new();
107 0 : config
108 0 : .host(&db_info.host)
109 0 : .port(db_info.port)
110 0 : .dbname(&db_info.dbname)
111 0 : .user(&db_info.user);
112 0 :
113 0 : ctx.set_dbname(db_info.dbname.into());
114 0 : ctx.set_user(db_info.user.into());
115 0 : ctx.set_project(db_info.aux.clone());
116 0 : info!("woken up a compute node");
117 :
118 : // Backwards compatibility. pg_sni_proxy uses "--" in domain names
119 : // while direct connections do not. Once we migrate to pg_sni_proxy
120 : // everywhere, we can remove this.
121 0 : if db_info.host.contains("--") {
122 0 : // we need TLS connection with SNI info to properly route it
123 0 : config.ssl_mode(SslMode::Require);
124 0 : } else {
125 0 : config.ssl_mode(SslMode::Disable);
126 0 : }
127 :
128 0 : if let Some(password) = db_info.password {
129 0 : config.password(password.as_ref());
130 0 : }
131 :
132 0 : Ok(NodeInfo {
133 0 : config,
134 0 : aux: db_info.aux,
135 0 : allow_self_signed_compute: false, // caller may override
136 0 : })
137 0 : }
|