Line data Source code
1 : use async_trait::async_trait;
2 : use pq_proto::BeMessage as Be;
3 : use thiserror::Error;
4 : use tokio::io::{AsyncRead, AsyncWrite};
5 : use tokio_postgres::config::SslMode;
6 : use tracing::{info, info_span};
7 :
8 : use super::ComputeCredentialKeys;
9 : use crate::cache::Cached;
10 : use crate::config::AuthenticationConfig;
11 : use crate::context::RequestMonitoring;
12 : use crate::control_plane::provider::NodeInfo;
13 : use crate::control_plane::{self, CachedNodeInfo};
14 : use crate::error::{ReportableError, UserFacingError};
15 : use crate::proxy::connect_compute::ComputeConnectBackend;
16 : use crate::stream::PqStream;
17 : use crate::{auth, compute, waiters};
18 :
19 0 : #[derive(Debug, Error)]
20 : pub(crate) enum WebAuthError {
21 : #[error(transparent)]
22 : WaiterRegister(#[from] waiters::RegisterError),
23 :
24 : #[error(transparent)]
25 : WaiterWait(#[from] waiters::WaitError),
26 :
27 : #[error(transparent)]
28 : Io(#[from] std::io::Error),
29 : }
30 :
31 : #[derive(Debug)]
32 : pub struct ConsoleRedirectBackend {
33 : console_uri: reqwest::Url,
34 : }
35 :
36 : impl UserFacingError for WebAuthError {
37 0 : fn to_string_client(&self) -> String {
38 0 : "Internal error".to_string()
39 0 : }
40 : }
41 :
42 : impl ReportableError for WebAuthError {
43 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
44 0 : match self {
45 0 : Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
46 0 : Self::WaiterWait(_) => crate::error::ErrorKind::Service,
47 0 : Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
48 : }
49 0 : }
50 : }
51 :
52 0 : fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
53 0 : format!(
54 0 : concat![
55 0 : "Welcome to Neon!\n",
56 0 : "Authenticate by visiting:\n",
57 0 : " {redirect_uri}{session_id}\n\n",
58 0 : ],
59 0 : redirect_uri = redirect_uri,
60 0 : session_id = session_id,
61 0 : )
62 0 : }
63 :
64 0 : pub(crate) fn new_psql_session_id() -> String {
65 0 : hex::encode(rand::random::<[u8; 8]>())
66 0 : }
67 :
68 : impl ConsoleRedirectBackend {
69 0 : pub fn new(console_uri: reqwest::Url) -> Self {
70 0 : Self { console_uri }
71 0 : }
72 :
73 0 : pub(crate) async fn authenticate(
74 0 : &self,
75 0 : ctx: &RequestMonitoring,
76 0 : auth_config: &'static AuthenticationConfig,
77 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
78 0 : ) -> auth::Result<ConsoleRedirectNodeInfo> {
79 0 : authenticate(ctx, auth_config, &self.console_uri, client)
80 0 : .await
81 0 : .map(ConsoleRedirectNodeInfo)
82 0 : }
83 : }
84 :
85 : pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
86 :
87 : #[async_trait]
88 : impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
89 0 : async fn wake_compute(
90 0 : &self,
91 0 : _ctx: &RequestMonitoring,
92 0 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
93 0 : Ok(Cached::new_uncached(self.0.clone()))
94 0 : }
95 :
96 0 : fn get_keys(&self) -> &ComputeCredentialKeys {
97 0 : &ComputeCredentialKeys::None
98 0 : }
99 : }
100 :
101 0 : async fn authenticate(
102 0 : ctx: &RequestMonitoring,
103 0 : auth_config: &'static AuthenticationConfig,
104 0 : link_uri: &reqwest::Url,
105 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
106 0 : ) -> auth::Result<NodeInfo> {
107 0 : ctx.set_auth_method(crate::context::AuthMethod::Web);
108 :
109 : // registering waiter can fail if we get unlucky with rng.
110 : // just try again.
111 0 : let (psql_session_id, waiter) = loop {
112 0 : let psql_session_id = new_psql_session_id();
113 0 :
114 0 : match control_plane::mgmt::get_waiter(&psql_session_id) {
115 0 : Ok(waiter) => break (psql_session_id, waiter),
116 0 : Err(_e) => continue,
117 : }
118 : };
119 :
120 0 : let span = info_span!("web", psql_session_id = &psql_session_id);
121 0 : let greeting = hello_message(link_uri, &psql_session_id);
122 0 :
123 0 : // Give user a URL to spawn a new database.
124 0 : info!(parent: &span, "sending the auth URL to the user");
125 0 : client
126 0 : .write_message_noflush(&Be::AuthenticationOk)?
127 0 : .write_message_noflush(&Be::CLIENT_ENCODING)?
128 0 : .write_message(&Be::NoticeResponse(&greeting))
129 0 : .await?;
130 :
131 : // Wait for web console response (see `mgmt`).
132 0 : info!(parent: &span, "waiting for console's reply...");
133 0 : let db_info = tokio::time::timeout(auth_config.webauth_confirmation_timeout, waiter)
134 0 : .await
135 0 : .map_err(|_elapsed| {
136 0 : auth::AuthError::confirmation_timeout(auth_config.webauth_confirmation_timeout.into())
137 0 : })?
138 0 : .map_err(WebAuthError::from)?;
139 :
140 0 : if auth_config.ip_allowlist_check_enabled {
141 0 : if let Some(allowed_ips) = &db_info.allowed_ips {
142 0 : if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
143 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
144 0 : }
145 0 : }
146 0 : }
147 :
148 0 : client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
149 :
150 : // This config should be self-contained, because we won't
151 : // take username or dbname from client's startup message.
152 0 : let mut config = compute::ConnCfg::new();
153 0 : config
154 0 : .host(&db_info.host)
155 0 : .port(db_info.port)
156 0 : .dbname(&db_info.dbname)
157 0 : .user(&db_info.user);
158 0 :
159 0 : ctx.set_dbname(db_info.dbname.into());
160 0 : ctx.set_user(db_info.user.into());
161 0 : ctx.set_project(db_info.aux.clone());
162 0 : info!("woken up a compute node");
163 :
164 : // Backwards compatibility. pg_sni_proxy uses "--" in domain names
165 : // while direct connections do not. Once we migrate to pg_sni_proxy
166 : // everywhere, we can remove this.
167 0 : if db_info.host.contains("--") {
168 0 : // we need TLS connection with SNI info to properly route it
169 0 : config.ssl_mode(SslMode::Require);
170 0 : } else {
171 0 : config.ssl_mode(SslMode::Disable);
172 0 : }
173 :
174 0 : if let Some(password) = db_info.password {
175 0 : config.password(password.as_ref());
176 0 : }
177 :
178 0 : Ok(NodeInfo {
179 0 : config,
180 0 : aux: db_info.aux,
181 0 : allow_self_signed_compute: false, // caller may override
182 0 : })
183 0 : }
|