Line data Source code
1 : use std::sync::Arc;
2 : use std::time::Duration;
3 :
4 : use ed25519_dalek::SigningKey;
5 : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
6 : use jose_jwk::jose_b64;
7 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
8 : use rand_core::OsRng;
9 : use tracing::field::display;
10 : use tracing::{debug, info};
11 :
12 : use super::AsyncRW;
13 : use super::conn_pool::poll_client;
14 : use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
15 : use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
16 : use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
17 : use crate::auth::backend::local::StaticAuthRules;
18 : use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
19 : use crate::auth::{self, AuthError};
20 : use crate::compute;
21 : use crate::compute_ctl::{
22 : ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
23 : };
24 : use crate::config::ProxyConfig;
25 : use crate::context::RequestContext;
26 : use crate::control_plane::client::ApiLockError;
27 : use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
28 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
29 : use crate::intern::EndpointIdInt;
30 : use crate::pqproto::StartupMessageParams;
31 : use crate::proxy::{connect_auth, connect_compute};
32 : use crate::rate_limiter::EndpointRateLimiter;
33 : use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
34 :
35 : pub(crate) struct PoolingBackend {
36 : pub(crate) http_conn_pool:
37 : Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
38 : pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
39 : pub(crate) pool:
40 : Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
41 :
42 : pub(crate) config: &'static ProxyConfig,
43 : pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
44 : pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
45 : }
46 :
47 : impl PoolingBackend {
48 0 : pub(crate) async fn authenticate_with_password(
49 0 : &self,
50 0 : ctx: &RequestContext,
51 0 : user_info: &ComputeUserInfo,
52 0 : password: &[u8],
53 0 : ) -> Result<ComputeCredentials, AuthError> {
54 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
55 :
56 0 : let user_info = user_info.clone();
57 0 : let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
58 0 : let access_control = backend.get_endpoint_access_control(ctx).await?;
59 0 : access_control.check(
60 0 : ctx,
61 0 : self.config.authentication_config.ip_allowlist_check_enabled,
62 0 : self.config.authentication_config.is_vpc_acccess_proxy,
63 0 : )?;
64 :
65 0 : access_control.connection_attempt_rate_limit(
66 0 : ctx,
67 0 : &user_info.endpoint,
68 0 : &self.endpoint_rate_limiter,
69 0 : )?;
70 :
71 0 : let role_access = backend.get_role_secret(ctx).await?;
72 0 : let Some(secret) = role_access.secret else {
73 : // If we don't have an authentication secret, for the http flow we can just return an error.
74 0 : info!("authentication info not found");
75 0 : return Err(AuthError::password_failed(&*user_info.user));
76 : };
77 :
78 0 : let ep = EndpointIdInt::from(&user_info.endpoint);
79 0 : let auth_outcome = crate::auth::validate_password_and_exchange(
80 0 : &self.config.authentication_config.thread_pool,
81 0 : ep,
82 0 : password,
83 0 : secret,
84 0 : )
85 0 : .await?;
86 0 : let res = match auth_outcome {
87 0 : crate::sasl::Outcome::Success(key) => {
88 0 : info!("user successfully authenticated");
89 0 : Ok(key)
90 : }
91 0 : crate::sasl::Outcome::Failure(reason) => {
92 0 : info!("auth backend failed with an error: {reason}");
93 0 : Err(AuthError::password_failed(&*user_info.user))
94 : }
95 : };
96 0 : res.map(|key| ComputeCredentials {
97 0 : info: user_info,
98 0 : keys: key,
99 0 : })
100 0 : }
101 :
102 0 : pub(crate) async fn authenticate_with_jwt(
103 0 : &self,
104 0 : ctx: &RequestContext,
105 0 : user_info: &ComputeUserInfo,
106 0 : jwt: String,
107 0 : ) -> Result<ComputeCredentials, AuthError> {
108 0 : ctx.set_auth_method(crate::context::AuthMethod::Jwt);
109 :
110 0 : match &self.auth_backend {
111 0 : crate::auth::Backend::ControlPlane(console, ()) => {
112 0 : let keys = self
113 0 : .config
114 0 : .authentication_config
115 0 : .jwks_cache
116 0 : .check_jwt(
117 0 : ctx,
118 0 : user_info.endpoint.clone(),
119 0 : &user_info.user,
120 0 : &**console,
121 0 : &jwt,
122 0 : )
123 0 : .await?;
124 :
125 0 : Ok(ComputeCredentials {
126 0 : info: user_info.clone(),
127 0 : keys,
128 0 : })
129 : }
130 : crate::auth::Backend::Local(_) => {
131 0 : let keys = self
132 0 : .config
133 0 : .authentication_config
134 0 : .jwks_cache
135 0 : .check_jwt(
136 0 : ctx,
137 0 : user_info.endpoint.clone(),
138 0 : &user_info.user,
139 0 : &StaticAuthRules,
140 0 : &jwt,
141 0 : )
142 0 : .await?;
143 :
144 0 : Ok(ComputeCredentials {
145 0 : info: user_info.clone(),
146 0 : keys,
147 0 : })
148 : }
149 : }
150 0 : }
151 :
152 : // Wake up the destination if needed. Code here is a bit involved because
153 : // we reuse the code from the usual proxy and we need to prepare few structures
154 : // that this code expects.
155 : #[tracing::instrument(skip_all, fields(
156 : pid = tracing::field::Empty,
157 : compute_id = tracing::field::Empty,
158 : conn_id = tracing::field::Empty,
159 : ))]
160 : pub(crate) async fn connect_to_compute(
161 : &self,
162 : ctx: &RequestContext,
163 : conn_info: ConnInfo,
164 : keys: ComputeCredentials,
165 : force_new: bool,
166 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
167 : let maybe_client = if force_new {
168 : debug!("pool: pool is disabled");
169 : None
170 : } else {
171 : debug!("pool: looking for an existing connection");
172 : self.pool.get(ctx, &conn_info)?
173 : };
174 :
175 : if let Some(client) = maybe_client {
176 : return Ok(client);
177 : }
178 : let conn_id = uuid::Uuid::new_v4();
179 : tracing::Span::current().record("conn_id", display(conn_id));
180 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
181 : let backend = self.auth_backend.as_ref().map(|()| keys.info);
182 :
183 : let mut params = StartupMessageParams::default();
184 : params.insert("database", &conn_info.dbname);
185 : params.insert("user", &conn_info.user_info.user);
186 :
187 : let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys);
188 : auth_info.set_startup_params(¶ms, true);
189 :
190 : let node = connect_auth::connect_to_compute_and_auth(
191 : ctx,
192 : self.config,
193 : &backend,
194 : auth_info,
195 : connect_compute::TlsNegotiation::Postgres,
196 : )
197 : .await?;
198 :
199 : let (client, connection) = postgres_client::connect::managed(
200 : node.stream,
201 : Some(node.socket_addr.ip()),
202 : postgres_client::config::Host::Tcp(node.hostname.to_string()),
203 : node.socket_addr.port(),
204 : node.ssl_mode,
205 : Some(self.config.connect_to_compute.timeout),
206 : )
207 : .await?;
208 :
209 : Ok(poll_client(
210 : self.pool.clone(),
211 : ctx,
212 : conn_info,
213 : client,
214 : connection,
215 : conn_id,
216 : node.aux,
217 : ))
218 : }
219 :
220 : // Wake up the destination if needed
221 : #[tracing::instrument(skip_all, fields(
222 : compute_id = tracing::field::Empty,
223 : conn_id = tracing::field::Empty,
224 : ))]
225 : pub(crate) async fn connect_to_local_proxy(
226 : &self,
227 : ctx: &RequestContext,
228 : conn_info: ConnInfo,
229 : ) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
230 : debug!("pool: looking for an existing connection");
231 : if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
232 : return Ok(client);
233 : }
234 :
235 : let conn_id = uuid::Uuid::new_v4();
236 : tracing::Span::current().record("conn_id", display(conn_id));
237 : debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
238 : let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
239 0 : user: conn_info.user_info.user.clone(),
240 0 : endpoint: EndpointId::from(format!(
241 0 : "{}{LOCAL_PROXY_SUFFIX}",
242 0 : conn_info.user_info.endpoint.normalize()
243 : )),
244 0 : options: conn_info.user_info.options.clone(),
245 0 : });
246 :
247 : let node = connect_compute::connect_to_compute(
248 : ctx,
249 : self.config,
250 : &backend,
251 : connect_compute::TlsNegotiation::Direct,
252 : )
253 : .await?;
254 :
255 : let stream = match node.stream.into_framed().into_inner() {
256 : MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW,
257 : MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW,
258 : };
259 :
260 : let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
261 : .timer(TokioTimer::new())
262 : .keep_alive_interval(Duration::from_secs(20))
263 : .keep_alive_while_idle(true)
264 : .keep_alive_timeout(Duration::from_secs(5))
265 : .handshake(TokioIo::new(stream))
266 : .await
267 : .map_err(LocalProxyConnError::H2)?;
268 :
269 : Ok(poll_http2_client(
270 : self.http_conn_pool.clone(),
271 : ctx,
272 : &conn_info,
273 : client,
274 : connection,
275 : conn_id,
276 : node.aux.clone(),
277 : ))
278 : }
279 :
280 : /// Connect to postgres over localhost.
281 : ///
282 : /// We expect postgres to be started here, so we won't do any retries.
283 : ///
284 : /// # Panics
285 : ///
286 : /// Panics if called with a non-local_proxy backend.
287 : #[tracing::instrument(skip_all, fields(
288 : pid = tracing::field::Empty,
289 : conn_id = tracing::field::Empty,
290 : ))]
291 : pub(crate) async fn connect_to_local_postgres(
292 : &self,
293 : ctx: &RequestContext,
294 : conn_info: ConnInfo,
295 : disable_pg_session_jwt: bool,
296 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
297 : if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
298 : return Ok(client);
299 : }
300 :
301 : let local_backend = match &self.auth_backend {
302 : auth::Backend::ControlPlane(_, ()) => {
303 : unreachable!("only local_proxy can connect to local postgres")
304 : }
305 : auth::Backend::Local(local) => local,
306 : };
307 :
308 : if !self.local_pool.initialized(&conn_info) {
309 : // only install and grant usage one at a time.
310 : let _permit = local_backend
311 : .initialize
312 : .acquire()
313 : .await
314 : .expect("semaphore should never be closed");
315 :
316 : // check again for race
317 : if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
318 : local_backend
319 : .compute_ctl
320 : .install_extension(&ExtensionInstallRequest {
321 : extension: EXT_NAME,
322 : database: conn_info.dbname.clone(),
323 : version: EXT_VERSION,
324 : })
325 : .await?;
326 :
327 : local_backend
328 : .compute_ctl
329 : .grant_role(&SetRoleGrantsRequest {
330 : schema: EXT_SCHEMA,
331 : privileges: vec![Privilege::Usage],
332 : database: conn_info.dbname.clone(),
333 : role: conn_info.user_info.user.clone(),
334 : })
335 : .await?;
336 :
337 : self.local_pool.set_initialized(&conn_info);
338 : }
339 : }
340 :
341 : let conn_id = uuid::Uuid::new_v4();
342 : tracing::Span::current().record("conn_id", display(conn_id));
343 : info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
344 :
345 : let (key, jwk) = create_random_jwk();
346 :
347 : let mut config = local_backend
348 : .node_info
349 : .conn_info
350 : .to_postgres_client_config();
351 : config
352 : .user(&conn_info.user_info.user)
353 : .dbname(&conn_info.dbname);
354 : if !disable_pg_session_jwt {
355 : config.set_param(
356 : "options",
357 : &format!(
358 : "-c pg_session_jwt.jwk={}",
359 : serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
360 : ),
361 : );
362 : }
363 :
364 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
365 : let (client, connection) = config.connect(&postgres_client::NoTls).await?;
366 : drop(pause);
367 :
368 : let pid = client.get_process_id();
369 : tracing::Span::current().record("pid", pid);
370 :
371 : let mut handle = local_conn_pool::poll_client(
372 : self.local_pool.clone(),
373 : ctx,
374 : conn_info,
375 : client,
376 : connection,
377 : key,
378 : conn_id,
379 : local_backend.node_info.aux.clone(),
380 : );
381 :
382 : {
383 : let (client, mut discard) = handle.inner();
384 : debug!("setting up backend session state");
385 :
386 : // initiates the auth session
387 : if !disable_pg_session_jwt
388 : && let Err(e) = client.batch_execute("select auth.init();").await
389 : {
390 : discard.discard();
391 : return Err(e.into());
392 : }
393 :
394 : info!("backend session state initialized");
395 : }
396 :
397 : Ok(handle)
398 : }
399 : }
400 :
401 0 : fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
402 0 : let key = SigningKey::generate(&mut OsRng);
403 :
404 0 : let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
405 0 : crv: jose_jwk::OkpCurves::Ed25519,
406 0 : x: jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
407 0 : d: None,
408 0 : });
409 :
410 0 : (key, jwk)
411 0 : }
412 :
413 : #[derive(Debug, thiserror::Error)]
414 : pub(crate) enum HttpConnError {
415 : #[error("pooled connection closed at inconsistent state")]
416 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
417 : #[error("could not connect to compute")]
418 : ConnectError(#[from] compute::ConnectionError),
419 : #[error("could not connect to postgres in compute")]
420 : PostgresConnectionError(#[from] postgres_client::Error),
421 : #[error("could not connect to local-proxy in compute")]
422 : LocalProxyConnectionError(#[from] LocalProxyConnError),
423 : #[error("could not parse JWT payload")]
424 : JwtPayloadError(serde_json::Error),
425 :
426 : #[error("could not install extension: {0}")]
427 : ComputeCtl(#[from] ComputeCtlError),
428 : #[error("could not get auth info")]
429 : GetAuthInfo(#[from] GetAuthInfoError),
430 : #[error("user not authenticated")]
431 : AuthError(#[from] AuthError),
432 : #[error("wake_compute returned error")]
433 : WakeCompute(#[from] WakeComputeError),
434 : #[error("error acquiring resource permit: {0}")]
435 : TooManyConnectionAttempts(#[from] ApiLockError),
436 : }
437 :
438 : impl From<connect_auth::AuthError> for HttpConnError {
439 0 : fn from(value: connect_auth::AuthError) -> Self {
440 0 : match value {
441 0 : connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => {
442 0 : Self::PostgresConnectionError(error)
443 : }
444 0 : connect_auth::AuthError::Connect(error) => Self::ConnectError(error),
445 : }
446 0 : }
447 : }
448 :
449 : #[derive(Debug, thiserror::Error)]
450 : pub(crate) enum LocalProxyConnError {
451 : #[error("could not establish h2 connection")]
452 : H2(#[from] hyper::Error),
453 : }
454 :
455 : impl ReportableError for HttpConnError {
456 0 : fn get_error_kind(&self) -> ErrorKind {
457 0 : match self {
458 0 : HttpConnError::ConnectError(_) => ErrorKind::Compute,
459 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
460 0 : HttpConnError::PostgresConnectionError(p) => {
461 0 : if p.as_db_error().is_some() {
462 : // postgres rejected the connection
463 0 : ErrorKind::Postgres
464 : } else {
465 : // couldn't even reach postgres
466 0 : ErrorKind::Compute
467 : }
468 : }
469 0 : HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
470 0 : HttpConnError::ComputeCtl(_) => ErrorKind::Service,
471 0 : HttpConnError::JwtPayloadError(_) => ErrorKind::User,
472 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
473 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
474 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
475 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
476 : }
477 0 : }
478 : }
479 :
480 : impl UserFacingError for HttpConnError {
481 0 : fn to_string_client(&self) -> String {
482 0 : match self {
483 0 : HttpConnError::ConnectError(p) => p.to_string_client(),
484 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
485 0 : HttpConnError::PostgresConnectionError(p) => p.to_string(),
486 0 : HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
487 0 : HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
488 0 : HttpConnError::JwtPayloadError(p) => p.to_string(),
489 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
490 0 : HttpConnError::AuthError(c) => c.to_string_client(),
491 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
492 : HttpConnError::TooManyConnectionAttempts(_) => {
493 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
494 : }
495 : }
496 0 : }
497 : }
498 :
499 : impl ReportableError for LocalProxyConnError {
500 0 : fn get_error_kind(&self) -> ErrorKind {
501 0 : match self {
502 0 : LocalProxyConnError::H2(_) => ErrorKind::Compute,
503 : }
504 0 : }
505 : }
506 :
507 : impl UserFacingError for LocalProxyConnError {
508 0 : fn to_string_client(&self) -> String {
509 0 : "Could not establish HTTP connection to the database".to_string()
510 0 : }
511 : }
|