Line data Source code
1 : use std::fmt;
2 :
3 : use async_trait::async_trait;
4 : use postgres_client::config::SslMode;
5 : use thiserror::Error;
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 : use tracing::{info, info_span};
8 :
9 : use crate::auth::backend::ComputeUserInfo;
10 : use crate::cache::Cached;
11 : use crate::compute::AuthInfo;
12 : use crate::config::AuthenticationConfig;
13 : use crate::context::RequestContext;
14 : use crate::control_plane::client::cplane_proxy_v1;
15 : use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
16 : use crate::error::{ReportableError, UserFacingError};
17 : use crate::pqproto::BeMessage;
18 : use crate::proxy::NeonOptions;
19 : use crate::proxy::wake_compute::WakeComputeBackend;
20 : use crate::stream::PqStream;
21 : use crate::types::RoleName;
22 : use crate::{auth, compute, waiters};
23 :
24 : #[derive(Debug, Error)]
25 : pub(crate) enum ConsoleRedirectError {
26 : #[error(transparent)]
27 : WaiterRegister(#[from] waiters::RegisterError),
28 :
29 : #[error(transparent)]
30 : WaiterWait(#[from] waiters::WaitError),
31 :
32 : #[error(transparent)]
33 : Io(#[from] std::io::Error),
34 : }
35 :
36 : #[derive(Debug)]
37 : pub struct ConsoleRedirectBackend {
38 : console_uri: reqwest::Url,
39 : api: cplane_proxy_v1::NeonControlPlaneClient,
40 : }
41 :
42 : impl fmt::Debug for cplane_proxy_v1::NeonControlPlaneClient {
43 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 0 : write!(f, "NeonControlPlaneClient")
45 0 : }
46 : }
47 :
48 : impl UserFacingError for ConsoleRedirectError {
49 0 : fn to_string_client(&self) -> String {
50 0 : "Internal error".to_string()
51 0 : }
52 : }
53 :
54 : impl ReportableError for ConsoleRedirectError {
55 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
56 0 : match self {
57 0 : Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
58 0 : Self::WaiterWait(_) => crate::error::ErrorKind::Service,
59 0 : Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
60 : }
61 0 : }
62 : }
63 :
64 0 : fn hello_message(
65 0 : redirect_uri: &reqwest::Url,
66 0 : session_id: &str,
67 0 : duration: std::time::Duration,
68 0 : ) -> String {
69 0 : let formatted_duration = humantime::format_duration(duration).to_string();
70 0 : format!(
71 0 : concat![
72 : "Welcome to Neon!\n",
73 : "Authenticate by visiting (will expire in {duration}):\n",
74 : " {redirect_uri}{session_id}\n\n",
75 : ],
76 : duration = formatted_duration,
77 : redirect_uri = redirect_uri,
78 : session_id = session_id,
79 : )
80 0 : }
81 :
82 0 : pub(crate) fn new_psql_session_id() -> String {
83 0 : hex::encode(rand::random::<[u8; 8]>())
84 0 : }
85 :
86 : impl ConsoleRedirectBackend {
87 0 : pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self {
88 0 : Self { console_uri, api }
89 0 : }
90 :
91 0 : pub(crate) fn get_api(&self) -> &cplane_proxy_v1::NeonControlPlaneClient {
92 0 : &self.api
93 0 : }
94 :
95 0 : pub(crate) async fn authenticate(
96 0 : &self,
97 0 : ctx: &RequestContext,
98 0 : auth_config: &'static AuthenticationConfig,
99 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
100 0 : ) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
101 0 : authenticate(ctx, auth_config, &self.console_uri, client)
102 0 : .await
103 0 : .map(|(node_info, auth_info, user_info)| {
104 0 : (ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
105 0 : })
106 0 : }
107 : }
108 :
109 : pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
110 :
111 : #[async_trait]
112 : impl WakeComputeBackend for ConsoleRedirectNodeInfo {
113 0 : async fn wake_compute(
114 : &self,
115 : _ctx: &RequestContext,
116 0 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
117 0 : Ok(Cached::new_uncached(self.0.clone()))
118 0 : }
119 : }
120 :
121 0 : async fn authenticate(
122 0 : ctx: &RequestContext,
123 0 : auth_config: &'static AuthenticationConfig,
124 0 : link_uri: &reqwest::Url,
125 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
126 0 : ) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
127 0 : ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
128 :
129 : // registering waiter can fail if we get unlucky with rng.
130 : // just try again.
131 0 : let (psql_session_id, waiter) = loop {
132 0 : let psql_session_id = new_psql_session_id();
133 :
134 0 : if let Ok(waiter) = control_plane::mgmt::get_waiter(&psql_session_id) {
135 0 : break (psql_session_id, waiter);
136 0 : }
137 : };
138 :
139 0 : let span = info_span!("console_redirect", psql_session_id = &psql_session_id);
140 0 : let greeting = hello_message(
141 0 : link_uri,
142 0 : &psql_session_id,
143 0 : auth_config.console_redirect_confirmation_timeout,
144 : );
145 :
146 : // Give user a URL to spawn a new database.
147 0 : info!(parent: &span, "sending the auth URL to the user");
148 0 : client.write_message(BeMessage::AuthenticationOk);
149 0 : client.write_message(BeMessage::ParameterStatus {
150 0 : name: b"client_encoding",
151 0 : value: b"UTF8",
152 0 : });
153 0 : client.write_message(BeMessage::NoticeResponse(&greeting));
154 0 : client.flush().await?;
155 :
156 : // Wait for console response via control plane (see `mgmt`).
157 0 : info!(parent: &span, "waiting for console's reply...");
158 0 : let db_info = tokio::time::timeout(auth_config.console_redirect_confirmation_timeout, waiter)
159 0 : .await
160 0 : .map_err(|_elapsed| {
161 0 : auth::AuthError::confirmation_timeout(
162 0 : auth_config.console_redirect_confirmation_timeout.into(),
163 : )
164 0 : })?
165 0 : .map_err(ConsoleRedirectError::from)?;
166 :
167 0 : if auth_config.ip_allowlist_check_enabled
168 0 : && let Some(allowed_ips) = &db_info.allowed_ips
169 0 : && !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
170 : {
171 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
172 0 : }
173 :
174 : // Check if the access over the public internet is allowed, otherwise block. Note that
175 : // the console redirect is not behind the VPC service endpoint, so we don't need to check
176 : // the VPC endpoint ID.
177 0 : if let Some(public_access_allowed) = db_info.public_access_allowed
178 0 : && !public_access_allowed
179 : {
180 0 : return Err(auth::AuthError::NetworkNotAllowed);
181 0 : }
182 :
183 0 : client.write_message(BeMessage::NoticeResponse("Connecting to database."));
184 :
185 : // Backwards compatibility. pg_sni_proxy uses "--" in domain names
186 : // while direct connections do not. Once we migrate to pg_sni_proxy
187 : // everywhere, we can remove this.
188 0 : let ssl_mode = if db_info.host.contains("--") {
189 : // we need TLS connection with SNI info to properly route it
190 0 : SslMode::Require
191 : } else {
192 0 : SslMode::Disable
193 : };
194 :
195 0 : let conn_info = compute::ConnectInfo {
196 0 : host: db_info.host.into(),
197 0 : port: db_info.port,
198 0 : ssl_mode,
199 0 : host_addr: None,
200 0 : };
201 0 : let auth_info =
202 0 : AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
203 :
204 0 : let user: RoleName = db_info.user.into();
205 0 : let user_info = ComputeUserInfo {
206 0 : endpoint: db_info.aux.endpoint_id.as_str().into(),
207 0 : user: user.clone(),
208 0 : options: NeonOptions::default(),
209 0 : };
210 :
211 0 : ctx.set_dbname(db_info.dbname.into());
212 0 : ctx.set_user(user);
213 0 : ctx.set_project(db_info.aux.clone());
214 0 : info!("woken up a compute node");
215 :
216 0 : Ok((
217 0 : NodeInfo {
218 0 : conn_info,
219 0 : aux: db_info.aux,
220 0 : },
221 0 : auth_info,
222 0 : user_info,
223 0 : ))
224 0 : }
|