Line data Source code
1 : use std::sync::Arc;
2 : use std::time::Duration;
3 :
4 : use ed25519_dalek::SigningKey;
5 : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
6 : use jose_jwk::jose_b64;
7 : use postgres_client::error::SqlState;
8 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
9 : use rand_core::OsRng;
10 : use tracing::field::display;
11 : use tracing::{debug, info};
12 :
13 : use super::AsyncRW;
14 : use super::conn_pool::poll_client;
15 : use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
16 : use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
17 : use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
18 : use crate::auth::backend::local::StaticAuthRules;
19 : use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
20 : use crate::auth::{self, AuthError};
21 : use crate::compute;
22 : use crate::compute_ctl::{
23 : ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
24 : };
25 : use crate::config::ProxyConfig;
26 : use crate::context::RequestContext;
27 : use crate::control_plane::client::ApiLockError;
28 : use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
29 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
30 : use crate::intern::{EndpointIdInt, RoleNameInt};
31 : use crate::pqproto::StartupMessageParams;
32 : use crate::proxy::{connect_auth, connect_compute};
33 : use crate::rate_limiter::EndpointRateLimiter;
34 : use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
35 :
36 : pub(crate) struct PoolingBackend {
37 : pub(crate) http_conn_pool:
38 : Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
39 : pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
40 : pub(crate) pool:
41 : Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
42 :
43 : pub(crate) config: &'static ProxyConfig,
44 : pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
45 : pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
46 : }
47 :
48 : impl PoolingBackend {
49 0 : pub(crate) async fn authenticate_with_password(
50 0 : &self,
51 0 : ctx: &RequestContext,
52 0 : user_info: &ComputeUserInfo,
53 0 : password: &[u8],
54 0 : ) -> Result<ComputeCredentials, AuthError> {
55 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
56 :
57 0 : let user_info = user_info.clone();
58 0 : let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
59 0 : let access_control = backend.get_endpoint_access_control(ctx).await?;
60 0 : access_control.check(
61 0 : ctx,
62 0 : self.config.authentication_config.ip_allowlist_check_enabled,
63 0 : self.config.authentication_config.is_vpc_acccess_proxy,
64 0 : )?;
65 :
66 0 : access_control.connection_attempt_rate_limit(
67 0 : ctx,
68 0 : &user_info.endpoint,
69 0 : &self.endpoint_rate_limiter,
70 0 : )?;
71 :
72 0 : let role_access = backend.get_role_secret(ctx).await?;
73 0 : let Some(secret) = role_access.secret else {
74 : // If we don't have an authentication secret, for the http flow we can just return an error.
75 0 : info!("authentication info not found");
76 0 : return Err(AuthError::password_failed(&*user_info.user));
77 : };
78 :
79 0 : let ep = EndpointIdInt::from(&user_info.endpoint);
80 0 : let role = RoleNameInt::from(&user_info.user);
81 0 : let auth_outcome = crate::auth::validate_password_and_exchange(
82 0 : &self.config.authentication_config.scram_thread_pool,
83 0 : ep,
84 0 : role,
85 0 : password,
86 0 : secret,
87 0 : )
88 0 : .await?;
89 0 : let res = match auth_outcome {
90 0 : crate::sasl::Outcome::Success(key) => {
91 0 : info!("user successfully authenticated");
92 0 : Ok(key)
93 : }
94 0 : crate::sasl::Outcome::Failure(reason) => {
95 0 : info!("auth backend failed with an error: {reason}");
96 0 : Err(AuthError::password_failed(&*user_info.user))
97 : }
98 : };
99 0 : res.map(|key| ComputeCredentials {
100 0 : info: user_info,
101 0 : keys: key,
102 0 : })
103 0 : }
104 :
105 0 : pub(crate) async fn authenticate_with_jwt(
106 0 : &self,
107 0 : ctx: &RequestContext,
108 0 : user_info: &ComputeUserInfo,
109 0 : jwt: String,
110 0 : ) -> Result<ComputeCredentials, AuthError> {
111 0 : ctx.set_auth_method(crate::context::AuthMethod::Jwt);
112 :
113 0 : match &self.auth_backend {
114 0 : crate::auth::Backend::ControlPlane(console, ()) => {
115 0 : let keys = self
116 0 : .config
117 0 : .authentication_config
118 0 : .jwks_cache
119 0 : .check_jwt(
120 0 : ctx,
121 0 : user_info.endpoint.clone(),
122 0 : &user_info.user,
123 0 : &**console,
124 0 : &jwt,
125 0 : )
126 0 : .await?;
127 :
128 0 : Ok(ComputeCredentials {
129 0 : info: user_info.clone(),
130 0 : keys,
131 0 : })
132 : }
133 : crate::auth::Backend::Local(_) => {
134 0 : let keys = self
135 0 : .config
136 0 : .authentication_config
137 0 : .jwks_cache
138 0 : .check_jwt(
139 0 : ctx,
140 0 : user_info.endpoint.clone(),
141 0 : &user_info.user,
142 0 : &StaticAuthRules,
143 0 : &jwt,
144 0 : )
145 0 : .await?;
146 :
147 0 : Ok(ComputeCredentials {
148 0 : info: user_info.clone(),
149 0 : keys,
150 0 : })
151 : }
152 : }
153 0 : }
154 :
155 : // Wake up the destination if needed. Code here is a bit involved because
156 : // we reuse the code from the usual proxy and we need to prepare few structures
157 : // that this code expects.
158 : #[tracing::instrument(skip_all, fields(
159 : pid = tracing::field::Empty,
160 : compute_id = tracing::field::Empty,
161 : conn_id = tracing::field::Empty,
162 : ))]
163 : pub(crate) async fn connect_to_compute(
164 : &self,
165 : ctx: &RequestContext,
166 : conn_info: ConnInfo,
167 : keys: ComputeCredentials,
168 : force_new: bool,
169 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
170 : let maybe_client = if force_new {
171 : debug!("pool: pool is disabled");
172 : None
173 : } else {
174 : debug!("pool: looking for an existing connection");
175 : self.pool.get(ctx, &conn_info)?
176 : };
177 :
178 : if let Some(client) = maybe_client {
179 : return Ok(client);
180 : }
181 : let conn_id = uuid::Uuid::new_v4();
182 : tracing::Span::current().record("conn_id", display(conn_id));
183 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
184 : let backend = self.auth_backend.as_ref().map(|()| keys.info);
185 :
186 : let mut params = StartupMessageParams::default();
187 : params.insert("database", &conn_info.dbname);
188 : params.insert("user", &conn_info.user_info.user);
189 :
190 : let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys);
191 : auth_info.set_startup_params(¶ms, true);
192 :
193 : let node = connect_auth::connect_to_compute_and_auth(
194 : ctx,
195 : self.config,
196 : &backend,
197 : auth_info,
198 : connect_compute::TlsNegotiation::Postgres,
199 : )
200 : .await?;
201 :
202 : let (client, connection) = postgres_client::connect::managed(
203 : node.stream,
204 : Some(node.socket_addr.ip()),
205 : postgres_client::config::Host::Tcp(node.hostname.to_string()),
206 : node.socket_addr.port(),
207 : node.ssl_mode,
208 : Some(self.config.connect_to_compute.timeout),
209 : )
210 : .await?;
211 :
212 : Ok(poll_client(
213 : self.pool.clone(),
214 : ctx,
215 : conn_info,
216 : client,
217 : connection,
218 : conn_id,
219 : node.aux,
220 : ))
221 : }
222 :
223 : // Wake up the destination if needed
224 : #[tracing::instrument(skip_all, fields(
225 : compute_id = tracing::field::Empty,
226 : conn_id = tracing::field::Empty,
227 : ))]
228 : pub(crate) async fn connect_to_local_proxy(
229 : &self,
230 : ctx: &RequestContext,
231 : conn_info: ConnInfo,
232 : ) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
233 : debug!("pool: looking for an existing connection");
234 : if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
235 : return Ok(client);
236 : }
237 :
238 : let conn_id = uuid::Uuid::new_v4();
239 : tracing::Span::current().record("conn_id", display(conn_id));
240 : debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
241 : let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
242 0 : user: conn_info.user_info.user.clone(),
243 0 : endpoint: EndpointId::from(format!(
244 0 : "{}{LOCAL_PROXY_SUFFIX}",
245 0 : conn_info.user_info.endpoint.normalize()
246 : )),
247 0 : options: conn_info.user_info.options.clone(),
248 0 : });
249 :
250 : let node = connect_compute::connect_to_compute(
251 : ctx,
252 : self.config,
253 : &backend,
254 : connect_compute::TlsNegotiation::Direct,
255 : )
256 : .await?;
257 :
258 : let stream = match node.stream.into_framed().into_inner() {
259 : MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW,
260 : MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW,
261 : };
262 :
263 : let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
264 : .timer(TokioTimer::new())
265 : .keep_alive_interval(Duration::from_secs(20))
266 : .keep_alive_while_idle(true)
267 : .keep_alive_timeout(Duration::from_secs(5))
268 : .handshake(TokioIo::new(stream))
269 : .await
270 : .map_err(LocalProxyConnError::H2)?;
271 :
272 : Ok(poll_http2_client(
273 : self.http_conn_pool.clone(),
274 : ctx,
275 : &conn_info,
276 : client,
277 : connection,
278 : conn_id,
279 : node.aux.clone(),
280 : ))
281 : }
282 :
283 : /// Connect to postgres over localhost.
284 : ///
285 : /// We expect postgres to be started here, so we won't do any retries.
286 : ///
287 : /// # Panics
288 : ///
289 : /// Panics if called with a non-local_proxy backend.
290 : #[tracing::instrument(skip_all, fields(
291 : pid = tracing::field::Empty,
292 : conn_id = tracing::field::Empty,
293 : ))]
294 : pub(crate) async fn connect_to_local_postgres(
295 : &self,
296 : ctx: &RequestContext,
297 : conn_info: ConnInfo,
298 : disable_pg_session_jwt: bool,
299 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
300 : if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
301 : return Ok(client);
302 : }
303 :
304 : let local_backend = match &self.auth_backend {
305 : auth::Backend::ControlPlane(_, ()) => {
306 : unreachable!("only local_proxy can connect to local postgres")
307 : }
308 : auth::Backend::Local(local) => local,
309 : };
310 :
311 : if !self.local_pool.initialized(&conn_info) {
312 : // only install and grant usage one at a time.
313 : let _permit = local_backend
314 : .initialize
315 : .acquire()
316 : .await
317 : .expect("semaphore should never be closed");
318 :
319 : // check again for race
320 : if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
321 : local_backend
322 : .compute_ctl
323 : .install_extension(&ExtensionInstallRequest {
324 : extension: EXT_NAME,
325 : database: conn_info.dbname.clone(),
326 : version: EXT_VERSION,
327 : })
328 : .await?;
329 :
330 : local_backend
331 : .compute_ctl
332 : .grant_role(&SetRoleGrantsRequest {
333 : schema: EXT_SCHEMA,
334 : privileges: vec![Privilege::Usage],
335 : database: conn_info.dbname.clone(),
336 : role: conn_info.user_info.user.clone(),
337 : })
338 : .await?;
339 :
340 : self.local_pool.set_initialized(&conn_info);
341 : }
342 : }
343 :
344 : let conn_id = uuid::Uuid::new_v4();
345 : tracing::Span::current().record("conn_id", display(conn_id));
346 : info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
347 :
348 : let (key, jwk) = create_random_jwk();
349 :
350 : let mut config = local_backend
351 : .node_info
352 : .conn_info
353 : .to_postgres_client_config();
354 : config
355 : .user(&conn_info.user_info.user)
356 : .dbname(&conn_info.dbname);
357 : if !disable_pg_session_jwt {
358 : config.set_param(
359 : "options",
360 : &format!(
361 : "-c pg_session_jwt.jwk={}",
362 : serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
363 : ),
364 : );
365 : }
366 :
367 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
368 : let (client, connection) = config.connect(&postgres_client::NoTls).await?;
369 : drop(pause);
370 :
371 : let pid = client.get_process_id();
372 : tracing::Span::current().record("pid", pid);
373 :
374 : let mut handle = local_conn_pool::poll_client(
375 : self.local_pool.clone(),
376 : ctx,
377 : conn_info,
378 : client,
379 : connection,
380 : key,
381 : conn_id,
382 : local_backend.node_info.aux.clone(),
383 : );
384 :
385 : {
386 : let (client, mut discard) = handle.inner();
387 : debug!("setting up backend session state");
388 :
389 : // initiates the auth session
390 : if !disable_pg_session_jwt
391 : && let Err(e) = client.batch_execute("select auth.init();").await
392 : {
393 : discard.discard();
394 : return Err(e.into());
395 : }
396 :
397 : info!("backend session state initialized");
398 : }
399 :
400 : Ok(handle)
401 : }
402 : }
403 :
404 0 : fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
405 0 : let key = SigningKey::generate(&mut OsRng);
406 :
407 0 : let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
408 0 : crv: jose_jwk::OkpCurves::Ed25519,
409 0 : x: jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
410 0 : d: None,
411 0 : });
412 :
413 0 : (key, jwk)
414 0 : }
415 :
416 : #[derive(Debug, thiserror::Error)]
417 : pub(crate) enum HttpConnError {
418 : #[error("pooled connection closed at inconsistent state")]
419 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
420 : #[error("could not connect to compute")]
421 : ConnectError(#[from] compute::ConnectionError),
422 : #[error("could not connect to postgres in compute")]
423 : PostgresConnectionError(#[from] postgres_client::Error),
424 : #[error("could not connect to local-proxy in compute")]
425 : LocalProxyConnectionError(#[from] LocalProxyConnError),
426 : #[error("could not parse JWT payload")]
427 : JwtPayloadError(serde_json::Error),
428 :
429 : #[error("could not install extension: {0}")]
430 : ComputeCtl(#[from] ComputeCtlError),
431 : #[error("could not get auth info")]
432 : GetAuthInfo(#[from] GetAuthInfoError),
433 : #[error("user not authenticated")]
434 : AuthError(#[from] AuthError),
435 : #[error("wake_compute returned error")]
436 : WakeCompute(#[from] WakeComputeError),
437 : #[error("error acquiring resource permit: {0}")]
438 : TooManyConnectionAttempts(#[from] ApiLockError),
439 : }
440 :
441 : impl From<connect_auth::AuthError> for HttpConnError {
442 0 : fn from(value: connect_auth::AuthError) -> Self {
443 0 : match value {
444 0 : connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => {
445 0 : Self::PostgresConnectionError(error)
446 : }
447 0 : connect_auth::AuthError::Connect(error) => Self::ConnectError(error),
448 : }
449 0 : }
450 : }
451 :
452 : #[derive(Debug, thiserror::Error)]
453 : pub(crate) enum LocalProxyConnError {
454 : #[error("could not establish h2 connection")]
455 : H2(#[from] hyper::Error),
456 : }
457 :
458 : impl ReportableError for HttpConnError {
459 0 : fn get_error_kind(&self) -> ErrorKind {
460 0 : match self {
461 0 : HttpConnError::ConnectError(e) => e.get_error_kind(),
462 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
463 0 : HttpConnError::PostgresConnectionError(p) => match p.as_db_error() {
464 : // user provided a wrong database name
465 0 : Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
466 : // postgres rejected the connection
467 0 : Some(_) => ErrorKind::Postgres,
468 : // couldn't even reach postgres
469 0 : None => ErrorKind::Compute,
470 : },
471 0 : HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
472 0 : HttpConnError::ComputeCtl(_) => ErrorKind::Service,
473 0 : HttpConnError::JwtPayloadError(_) => ErrorKind::User,
474 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
475 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
476 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
477 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
478 : }
479 0 : }
480 : }
481 :
482 : impl UserFacingError for HttpConnError {
483 0 : fn to_string_client(&self) -> String {
484 0 : match self {
485 0 : HttpConnError::ConnectError(p) => p.to_string_client(),
486 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
487 0 : HttpConnError::PostgresConnectionError(p) => p.to_string(),
488 0 : HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
489 0 : HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
490 0 : HttpConnError::JwtPayloadError(p) => p.to_string(),
491 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
492 0 : HttpConnError::AuthError(c) => c.to_string_client(),
493 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
494 : HttpConnError::TooManyConnectionAttempts(_) => {
495 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
496 : }
497 : }
498 0 : }
499 : }
500 :
501 : impl ReportableError for LocalProxyConnError {
502 0 : fn get_error_kind(&self) -> ErrorKind {
503 0 : match self {
504 0 : LocalProxyConnError::H2(_) => ErrorKind::Compute,
505 : }
506 0 : }
507 : }
508 :
509 : impl UserFacingError for LocalProxyConnError {
510 0 : fn to_string_client(&self) -> String {
511 0 : "Could not establish HTTP connection to the database".to_string()
512 0 : }
513 : }
|