Line data Source code
1 : use std::sync::Arc;
2 : use std::time::Duration;
3 :
4 : use ed25519_dalek::SigningKey;
5 : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
6 : use jose_jwk::jose_b64;
7 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
8 : use rand_core::OsRng;
9 : use tracing::field::display;
10 : use tracing::{debug, info};
11 :
12 : use super::AsyncRW;
13 : use super::conn_pool::poll_client;
14 : use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
15 : use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
16 : use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
17 : use crate::auth::backend::local::StaticAuthRules;
18 : use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
19 : use crate::auth::{self, AuthError};
20 : use crate::compute;
21 : use crate::compute_ctl::{
22 : ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
23 : };
24 : use crate::config::ProxyConfig;
25 : use crate::context::RequestContext;
26 : use crate::control_plane::client::ApiLockError;
27 : use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
28 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
29 : use crate::intern::{EndpointIdInt, RoleNameInt};
30 : use crate::pqproto::StartupMessageParams;
31 : use crate::proxy::{connect_auth, connect_compute};
32 : use crate::rate_limiter::EndpointRateLimiter;
33 : use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
34 :
35 : pub(crate) struct PoolingBackend {
36 : pub(crate) http_conn_pool:
37 : Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
38 : pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
39 : pub(crate) pool:
40 : Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
41 :
42 : pub(crate) config: &'static ProxyConfig,
43 : pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
44 : pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
45 : }
46 :
47 : impl PoolingBackend {
48 0 : pub(crate) async fn authenticate_with_password(
49 0 : &self,
50 0 : ctx: &RequestContext,
51 0 : user_info: &ComputeUserInfo,
52 0 : password: &[u8],
53 0 : ) -> Result<ComputeCredentials, AuthError> {
54 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
55 :
56 0 : let user_info = user_info.clone();
57 0 : let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
58 0 : let access_control = backend.get_endpoint_access_control(ctx).await?;
59 0 : access_control.check(
60 0 : ctx,
61 0 : self.config.authentication_config.ip_allowlist_check_enabled,
62 0 : self.config.authentication_config.is_vpc_acccess_proxy,
63 0 : )?;
64 :
65 0 : access_control.connection_attempt_rate_limit(
66 0 : ctx,
67 0 : &user_info.endpoint,
68 0 : &self.endpoint_rate_limiter,
69 0 : )?;
70 :
71 0 : let role_access = backend.get_role_secret(ctx).await?;
72 0 : let Some(secret) = role_access.secret else {
73 : // If we don't have an authentication secret, for the http flow we can just return an error.
74 0 : info!("authentication info not found");
75 0 : return Err(AuthError::password_failed(&*user_info.user));
76 : };
77 :
78 0 : let ep = EndpointIdInt::from(&user_info.endpoint);
79 0 : let role = RoleNameInt::from(&user_info.user);
80 0 : let auth_outcome = crate::auth::validate_password_and_exchange(
81 0 : &self.config.authentication_config.scram_thread_pool,
82 0 : ep,
83 0 : role,
84 0 : password,
85 0 : secret,
86 0 : )
87 0 : .await?;
88 0 : let res = match auth_outcome {
89 0 : crate::sasl::Outcome::Success(key) => {
90 0 : info!("user successfully authenticated");
91 0 : Ok(key)
92 : }
93 0 : crate::sasl::Outcome::Failure(reason) => {
94 0 : info!("auth backend failed with an error: {reason}");
95 0 : Err(AuthError::password_failed(&*user_info.user))
96 : }
97 : };
98 0 : res.map(|key| ComputeCredentials {
99 0 : info: user_info,
100 0 : keys: key,
101 0 : })
102 0 : }
103 :
104 0 : pub(crate) async fn authenticate_with_jwt(
105 0 : &self,
106 0 : ctx: &RequestContext,
107 0 : user_info: &ComputeUserInfo,
108 0 : jwt: String,
109 0 : ) -> Result<ComputeCredentials, AuthError> {
110 0 : ctx.set_auth_method(crate::context::AuthMethod::Jwt);
111 :
112 0 : match &self.auth_backend {
113 0 : crate::auth::Backend::ControlPlane(console, ()) => {
114 0 : let keys = self
115 0 : .config
116 0 : .authentication_config
117 0 : .jwks_cache
118 0 : .check_jwt(
119 0 : ctx,
120 0 : user_info.endpoint.clone(),
121 0 : &user_info.user,
122 0 : &**console,
123 0 : &jwt,
124 0 : )
125 0 : .await?;
126 :
127 0 : Ok(ComputeCredentials {
128 0 : info: user_info.clone(),
129 0 : keys,
130 0 : })
131 : }
132 : crate::auth::Backend::Local(_) => {
133 0 : let keys = self
134 0 : .config
135 0 : .authentication_config
136 0 : .jwks_cache
137 0 : .check_jwt(
138 0 : ctx,
139 0 : user_info.endpoint.clone(),
140 0 : &user_info.user,
141 0 : &StaticAuthRules,
142 0 : &jwt,
143 0 : )
144 0 : .await?;
145 :
146 0 : Ok(ComputeCredentials {
147 0 : info: user_info.clone(),
148 0 : keys,
149 0 : })
150 : }
151 : }
152 0 : }
153 :
154 : // Wake up the destination if needed. Code here is a bit involved because
155 : // we reuse the code from the usual proxy and we need to prepare few structures
156 : // that this code expects.
157 : #[tracing::instrument(skip_all, fields(
158 : pid = tracing::field::Empty,
159 : compute_id = tracing::field::Empty,
160 : conn_id = tracing::field::Empty,
161 : ))]
162 : pub(crate) async fn connect_to_compute(
163 : &self,
164 : ctx: &RequestContext,
165 : conn_info: ConnInfo,
166 : keys: ComputeCredentials,
167 : force_new: bool,
168 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
169 : let maybe_client = if force_new {
170 : debug!("pool: pool is disabled");
171 : None
172 : } else {
173 : debug!("pool: looking for an existing connection");
174 : self.pool.get(ctx, &conn_info)?
175 : };
176 :
177 : if let Some(client) = maybe_client {
178 : return Ok(client);
179 : }
180 : let conn_id = uuid::Uuid::new_v4();
181 : tracing::Span::current().record("conn_id", display(conn_id));
182 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
183 : let backend = self.auth_backend.as_ref().map(|()| keys.info);
184 :
185 : let mut params = StartupMessageParams::default();
186 : params.insert("database", &conn_info.dbname);
187 : params.insert("user", &conn_info.user_info.user);
188 :
189 : let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys);
190 : auth_info.set_startup_params(¶ms, true);
191 :
192 : let node = connect_auth::connect_to_compute_and_auth(
193 : ctx,
194 : self.config,
195 : &backend,
196 : auth_info,
197 : connect_compute::TlsNegotiation::Postgres,
198 : )
199 : .await?;
200 :
201 : let (client, connection) = postgres_client::connect::managed(
202 : node.stream,
203 : Some(node.socket_addr.ip()),
204 : postgres_client::config::Host::Tcp(node.hostname.to_string()),
205 : node.socket_addr.port(),
206 : node.ssl_mode,
207 : Some(self.config.connect_to_compute.timeout),
208 : )
209 : .await?;
210 :
211 : Ok(poll_client(
212 : self.pool.clone(),
213 : ctx,
214 : conn_info,
215 : client,
216 : connection,
217 : conn_id,
218 : node.aux,
219 : ))
220 : }
221 :
222 : // Wake up the destination if needed
223 : #[tracing::instrument(skip_all, fields(
224 : compute_id = tracing::field::Empty,
225 : conn_id = tracing::field::Empty,
226 : ))]
227 : pub(crate) async fn connect_to_local_proxy(
228 : &self,
229 : ctx: &RequestContext,
230 : conn_info: ConnInfo,
231 : ) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
232 : debug!("pool: looking for an existing connection");
233 : if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
234 : return Ok(client);
235 : }
236 :
237 : let conn_id = uuid::Uuid::new_v4();
238 : tracing::Span::current().record("conn_id", display(conn_id));
239 : debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
240 : let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
241 0 : user: conn_info.user_info.user.clone(),
242 0 : endpoint: EndpointId::from(format!(
243 0 : "{}{LOCAL_PROXY_SUFFIX}",
244 0 : conn_info.user_info.endpoint.normalize()
245 : )),
246 0 : options: conn_info.user_info.options.clone(),
247 0 : });
248 :
249 : let node = connect_compute::connect_to_compute(
250 : ctx,
251 : self.config,
252 : &backend,
253 : connect_compute::TlsNegotiation::Direct,
254 : )
255 : .await?;
256 :
257 : let stream = match node.stream.into_framed().into_inner() {
258 : MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW,
259 : MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW,
260 : };
261 :
262 : let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
263 : .timer(TokioTimer::new())
264 : .keep_alive_interval(Duration::from_secs(20))
265 : .keep_alive_while_idle(true)
266 : .keep_alive_timeout(Duration::from_secs(5))
267 : .handshake(TokioIo::new(stream))
268 : .await
269 : .map_err(LocalProxyConnError::H2)?;
270 :
271 : Ok(poll_http2_client(
272 : self.http_conn_pool.clone(),
273 : ctx,
274 : &conn_info,
275 : client,
276 : connection,
277 : conn_id,
278 : node.aux.clone(),
279 : ))
280 : }
281 :
282 : /// Connect to postgres over localhost.
283 : ///
284 : /// We expect postgres to be started here, so we won't do any retries.
285 : ///
286 : /// # Panics
287 : ///
288 : /// Panics if called with a non-local_proxy backend.
289 : #[tracing::instrument(skip_all, fields(
290 : pid = tracing::field::Empty,
291 : conn_id = tracing::field::Empty,
292 : ))]
293 : pub(crate) async fn connect_to_local_postgres(
294 : &self,
295 : ctx: &RequestContext,
296 : conn_info: ConnInfo,
297 : disable_pg_session_jwt: bool,
298 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
299 : if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
300 : return Ok(client);
301 : }
302 :
303 : let local_backend = match &self.auth_backend {
304 : auth::Backend::ControlPlane(_, ()) => {
305 : unreachable!("only local_proxy can connect to local postgres")
306 : }
307 : auth::Backend::Local(local) => local,
308 : };
309 :
310 : if !self.local_pool.initialized(&conn_info) {
311 : // only install and grant usage one at a time.
312 : let _permit = local_backend
313 : .initialize
314 : .acquire()
315 : .await
316 : .expect("semaphore should never be closed");
317 :
318 : // check again for race
319 : if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
320 : local_backend
321 : .compute_ctl
322 : .install_extension(&ExtensionInstallRequest {
323 : extension: EXT_NAME,
324 : database: conn_info.dbname.clone(),
325 : version: EXT_VERSION,
326 : })
327 : .await?;
328 :
329 : local_backend
330 : .compute_ctl
331 : .grant_role(&SetRoleGrantsRequest {
332 : schema: EXT_SCHEMA,
333 : privileges: vec![Privilege::Usage],
334 : database: conn_info.dbname.clone(),
335 : role: conn_info.user_info.user.clone(),
336 : })
337 : .await?;
338 :
339 : self.local_pool.set_initialized(&conn_info);
340 : }
341 : }
342 :
343 : let conn_id = uuid::Uuid::new_v4();
344 : tracing::Span::current().record("conn_id", display(conn_id));
345 : info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
346 :
347 : let (key, jwk) = create_random_jwk();
348 :
349 : let mut config = local_backend
350 : .node_info
351 : .conn_info
352 : .to_postgres_client_config();
353 : config
354 : .user(&conn_info.user_info.user)
355 : .dbname(&conn_info.dbname);
356 : if !disable_pg_session_jwt {
357 : config.set_param(
358 : "options",
359 : &format!(
360 : "-c pg_session_jwt.jwk={}",
361 : serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
362 : ),
363 : );
364 : }
365 :
366 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
367 : let (client, connection) = config.connect(&postgres_client::NoTls).await?;
368 : drop(pause);
369 :
370 : let pid = client.get_process_id();
371 : tracing::Span::current().record("pid", pid);
372 :
373 : let mut handle = local_conn_pool::poll_client(
374 : self.local_pool.clone(),
375 : ctx,
376 : conn_info,
377 : client,
378 : connection,
379 : key,
380 : conn_id,
381 : local_backend.node_info.aux.clone(),
382 : );
383 :
384 : {
385 : let (client, mut discard) = handle.inner();
386 : debug!("setting up backend session state");
387 :
388 : // initiates the auth session
389 : if !disable_pg_session_jwt
390 : && let Err(e) = client.batch_execute("select auth.init();").await
391 : {
392 : discard.discard();
393 : return Err(e.into());
394 : }
395 :
396 : info!("backend session state initialized");
397 : }
398 :
399 : Ok(handle)
400 : }
401 : }
402 :
403 0 : fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
404 0 : let key = SigningKey::generate(&mut OsRng);
405 :
406 0 : let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
407 0 : crv: jose_jwk::OkpCurves::Ed25519,
408 0 : x: jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
409 0 : d: None,
410 0 : });
411 :
412 0 : (key, jwk)
413 0 : }
414 :
415 : #[derive(Debug, thiserror::Error)]
416 : pub(crate) enum HttpConnError {
417 : #[error("pooled connection closed at inconsistent state")]
418 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
419 : #[error("could not connect to compute")]
420 : ConnectError(#[from] compute::ConnectionError),
421 : #[error("could not connect to postgres in compute")]
422 : PostgresConnectionError(#[from] postgres_client::Error),
423 : #[error("could not connect to local-proxy in compute")]
424 : LocalProxyConnectionError(#[from] LocalProxyConnError),
425 : #[error("could not parse JWT payload")]
426 : JwtPayloadError(serde_json::Error),
427 :
428 : #[error("could not install extension: {0}")]
429 : ComputeCtl(#[from] ComputeCtlError),
430 : #[error("could not get auth info")]
431 : GetAuthInfo(#[from] GetAuthInfoError),
432 : #[error("user not authenticated")]
433 : AuthError(#[from] AuthError),
434 : #[error("wake_compute returned error")]
435 : WakeCompute(#[from] WakeComputeError),
436 : #[error("error acquiring resource permit: {0}")]
437 : TooManyConnectionAttempts(#[from] ApiLockError),
438 : }
439 :
440 : impl From<connect_auth::AuthError> for HttpConnError {
441 0 : fn from(value: connect_auth::AuthError) -> Self {
442 0 : match value {
443 0 : connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => {
444 0 : Self::PostgresConnectionError(error)
445 : }
446 0 : connect_auth::AuthError::Connect(error) => Self::ConnectError(error),
447 : }
448 0 : }
449 : }
450 :
451 : #[derive(Debug, thiserror::Error)]
452 : pub(crate) enum LocalProxyConnError {
453 : #[error("could not establish h2 connection")]
454 : H2(#[from] hyper::Error),
455 : }
456 :
457 : impl ReportableError for HttpConnError {
458 0 : fn get_error_kind(&self) -> ErrorKind {
459 0 : match self {
460 0 : HttpConnError::ConnectError(_) => ErrorKind::Compute,
461 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
462 0 : HttpConnError::PostgresConnectionError(p) => {
463 0 : if p.as_db_error().is_some() {
464 : // postgres rejected the connection
465 0 : ErrorKind::Postgres
466 : } else {
467 : // couldn't even reach postgres
468 0 : ErrorKind::Compute
469 : }
470 : }
471 0 : HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
472 0 : HttpConnError::ComputeCtl(_) => ErrorKind::Service,
473 0 : HttpConnError::JwtPayloadError(_) => ErrorKind::User,
474 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
475 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
476 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
477 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
478 : }
479 0 : }
480 : }
481 :
482 : impl UserFacingError for HttpConnError {
483 0 : fn to_string_client(&self) -> String {
484 0 : match self {
485 0 : HttpConnError::ConnectError(p) => p.to_string_client(),
486 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
487 0 : HttpConnError::PostgresConnectionError(p) => p.to_string(),
488 0 : HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
489 0 : HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
490 0 : HttpConnError::JwtPayloadError(p) => p.to_string(),
491 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
492 0 : HttpConnError::AuthError(c) => c.to_string_client(),
493 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
494 : HttpConnError::TooManyConnectionAttempts(_) => {
495 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
496 : }
497 : }
498 0 : }
499 : }
500 :
501 : impl ReportableError for LocalProxyConnError {
502 0 : fn get_error_kind(&self) -> ErrorKind {
503 0 : match self {
504 0 : LocalProxyConnError::H2(_) => ErrorKind::Compute,
505 : }
506 0 : }
507 : }
508 :
509 : impl UserFacingError for LocalProxyConnError {
510 0 : fn to_string_client(&self) -> String {
511 0 : "Could not establish HTTP connection to the database".to_string()
512 0 : }
513 : }
|