Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use async_trait::async_trait;
4 : use tracing::{field::display, info};
5 :
6 : use crate::{
7 : auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
8 : compute,
9 : config::{AuthenticationConfig, ProxyConfig},
10 : console::{
11 : errors::{GetAuthInfoError, WakeComputeError},
12 : CachedNodeInfo,
13 : },
14 : context::RequestMonitoring,
15 : error::{ErrorKind, ReportableError, UserFacingError},
16 : proxy::connect_compute::ConnectMechanism,
17 : };
18 :
19 : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
20 :
21 : pub struct PoolingBackend {
22 : pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
23 : pub config: &'static ProxyConfig,
24 : }
25 :
26 : impl PoolingBackend {
27 0 : pub async fn authenticate(
28 0 : &self,
29 0 : ctx: &mut RequestMonitoring,
30 0 : config: &AuthenticationConfig,
31 0 : conn_info: &ConnInfo,
32 0 : ) -> Result<ComputeCredentials, AuthError> {
33 0 : let user_info = conn_info.user_info.clone();
34 0 : let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
35 0 : let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
36 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
37 0 : return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
38 0 : }
39 0 : let cached_secret = match maybe_secret {
40 0 : Some(secret) => secret,
41 0 : None => backend.get_role_secret(ctx).await?,
42 : };
43 :
44 0 : let secret = match cached_secret.value.clone() {
45 0 : Some(secret) => self.config.authentication_config.check_rate_limit(
46 0 : ctx,
47 0 : config,
48 0 : secret,
49 0 : &user_info.endpoint,
50 0 : true,
51 0 : )?,
52 : None => {
53 : // If we don't have an authentication secret, for the http flow we can just return an error.
54 0 : info!("authentication info not found");
55 0 : return Err(AuthError::auth_failed(&*user_info.user));
56 : }
57 : };
58 0 : let auth_outcome =
59 0 : crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?;
60 0 : let res = match auth_outcome {
61 0 : crate::sasl::Outcome::Success(key) => {
62 0 : info!("user successfully authenticated");
63 0 : Ok(key)
64 : }
65 0 : crate::sasl::Outcome::Failure(reason) => {
66 0 : info!("auth backend failed with an error: {reason}");
67 0 : Err(AuthError::auth_failed(&*conn_info.user_info.user))
68 : }
69 : };
70 0 : res.map(|key| ComputeCredentials {
71 0 : info: user_info,
72 0 : keys: key,
73 0 : })
74 0 : }
75 :
76 : // Wake up the destination if needed. Code here is a bit involved because
77 : // we reuse the code from the usual proxy and we need to prepare few structures
78 : // that this code expects.
79 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
80 : pub async fn connect_to_compute(
81 : &self,
82 : ctx: &mut RequestMonitoring,
83 : conn_info: ConnInfo,
84 : keys: ComputeCredentials,
85 : force_new: bool,
86 : ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
87 : let maybe_client = if !force_new {
88 0 : info!("pool: looking for an existing connection");
89 : self.pool.get(ctx, &conn_info).await?
90 : } else {
91 0 : info!("pool: pool is disabled");
92 : None
93 : };
94 :
95 : if let Some(client) = maybe_client {
96 : return Ok(client);
97 : }
98 : let conn_id = uuid::Uuid::new_v4();
99 : tracing::Span::current().record("conn_id", display(conn_id));
100 0 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
101 0 : let backend = self.config.auth_backend.as_ref().map(|_| keys);
102 : crate::proxy::connect_compute::connect_to_compute(
103 : ctx,
104 : &TokioMechanism {
105 : conn_id,
106 : conn_info,
107 : pool: self.pool.clone(),
108 : },
109 : &backend,
110 : false, // do not allow self signed compute for http flow
111 : )
112 : .await
113 : }
114 : }
115 :
116 0 : #[derive(Debug, thiserror::Error)]
117 : pub enum HttpConnError {
118 : #[error("pooled connection closed at inconsistent state")]
119 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
120 : #[error("could not connection to compute")]
121 : ConnectionError(#[from] tokio_postgres::Error),
122 :
123 : #[error("could not get auth info")]
124 : GetAuthInfo(#[from] GetAuthInfoError),
125 : #[error("user not authenticated")]
126 : AuthError(#[from] AuthError),
127 : #[error("wake_compute returned error")]
128 : WakeCompute(#[from] WakeComputeError),
129 : }
130 :
131 : impl ReportableError for HttpConnError {
132 0 : fn get_error_kind(&self) -> ErrorKind {
133 0 : match self {
134 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
135 0 : HttpConnError::ConnectionError(p) => p.get_error_kind(),
136 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
137 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
138 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
139 : }
140 0 : }
141 : }
142 :
143 : impl UserFacingError for HttpConnError {
144 0 : fn to_string_client(&self) -> String {
145 0 : match self {
146 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
147 0 : HttpConnError::ConnectionError(p) => p.to_string(),
148 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
149 0 : HttpConnError::AuthError(c) => c.to_string_client(),
150 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
151 : }
152 0 : }
153 : }
154 :
155 : struct TokioMechanism {
156 : pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
157 : conn_info: ConnInfo,
158 : conn_id: uuid::Uuid,
159 : }
160 :
161 : #[async_trait]
162 : impl ConnectMechanism for TokioMechanism {
163 : type Connection = Client<tokio_postgres::Client>;
164 : type ConnectError = tokio_postgres::Error;
165 : type Error = HttpConnError;
166 :
167 0 : async fn connect_once(
168 0 : &self,
169 0 : ctx: &mut RequestMonitoring,
170 0 : node_info: &CachedNodeInfo,
171 0 : timeout: Duration,
172 0 : ) -> Result<Self::Connection, Self::ConnectError> {
173 0 : let mut config = (*node_info.config).clone();
174 0 : let config = config
175 0 : .user(&self.conn_info.user_info.user)
176 0 : .password(&*self.conn_info.password)
177 0 : .dbname(&self.conn_info.dbname)
178 0 : .connect_timeout(timeout);
179 :
180 0 : let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
181 :
182 0 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
183 0 : Ok(poll_client(
184 0 : self.pool.clone(),
185 0 : ctx,
186 0 : self.conn_info.clone(),
187 0 : client,
188 0 : connection,
189 0 : self.conn_id,
190 0 : node_info.aux.clone(),
191 0 : ))
192 0 : }
193 :
194 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
195 : }
|