Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use async_trait::async_trait;
4 : use tracing::{field::display, info};
5 :
6 : use crate::{
7 : auth::{
8 : backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
9 : check_peer_addr_is_in_list, AuthError,
10 : },
11 : compute,
12 : config::{AuthenticationConfig, ProxyConfig},
13 : console::{
14 : errors::{GetAuthInfoError, WakeComputeError},
15 : locks::ApiLocks,
16 : provider::ApiLockError,
17 : CachedNodeInfo,
18 : },
19 : context::RequestMonitoring,
20 : error::{ErrorKind, ReportableError, UserFacingError},
21 : intern::EndpointIdInt,
22 : proxy::{
23 : connect_compute::ConnectMechanism,
24 : retry::{CouldRetry, ShouldRetryWakeCompute},
25 : },
26 : rate_limiter::EndpointRateLimiter,
27 : Host,
28 : };
29 :
30 : use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};
31 :
32 : pub(crate) struct PoolingBackend {
33 : pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
34 : pub(crate) config: &'static ProxyConfig,
35 : pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
36 : }
37 :
38 : impl PoolingBackend {
39 0 : pub(crate) async fn authenticate_with_password(
40 0 : &self,
41 0 : ctx: &RequestMonitoring,
42 0 : config: &AuthenticationConfig,
43 0 : user_info: &ComputeUserInfo,
44 0 : password: &[u8],
45 0 : ) -> Result<ComputeCredentials, AuthError> {
46 0 : let user_info = user_info.clone();
47 0 : let backend = self
48 0 : .config
49 0 : .auth_backend
50 0 : .as_ref()
51 0 : .map(|()| user_info.clone());
52 0 : let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
53 0 : if config.ip_allowlist_check_enabled
54 0 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
55 : {
56 0 : return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
57 0 : }
58 0 : if !self
59 0 : .endpoint_rate_limiter
60 0 : .check(user_info.endpoint.clone().into(), 1)
61 : {
62 0 : return Err(AuthError::too_many_connections());
63 0 : }
64 0 : let cached_secret = match maybe_secret {
65 0 : Some(secret) => secret,
66 0 : None => backend.get_role_secret(ctx).await?,
67 : };
68 :
69 0 : let secret = match cached_secret.value.clone() {
70 0 : Some(secret) => self.config.authentication_config.check_rate_limit(
71 0 : ctx,
72 0 : config,
73 0 : secret,
74 0 : &user_info.endpoint,
75 0 : true,
76 0 : )?,
77 : None => {
78 : // If we don't have an authentication secret, for the http flow we can just return an error.
79 0 : info!("authentication info not found");
80 0 : return Err(AuthError::auth_failed(&*user_info.user));
81 : }
82 : };
83 0 : let ep = EndpointIdInt::from(&user_info.endpoint);
84 0 : let auth_outcome =
85 0 : crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
86 0 : .await?;
87 0 : let res = match auth_outcome {
88 0 : crate::sasl::Outcome::Success(key) => {
89 0 : info!("user successfully authenticated");
90 0 : Ok(key)
91 : }
92 0 : crate::sasl::Outcome::Failure(reason) => {
93 0 : info!("auth backend failed with an error: {reason}");
94 0 : Err(AuthError::auth_failed(&*user_info.user))
95 : }
96 : };
97 0 : res.map(|key| ComputeCredentials {
98 0 : info: user_info,
99 0 : keys: key,
100 0 : })
101 0 : }
102 :
103 0 : pub(crate) async fn authenticate_with_jwt(
104 0 : &self,
105 0 : ctx: &RequestMonitoring,
106 0 : user_info: &ComputeUserInfo,
107 0 : jwt: &str,
108 0 : ) -> Result<ComputeCredentials, AuthError> {
109 0 : match &self.config.auth_backend {
110 : crate::auth::Backend::Console(_, ()) => {
111 0 : Err(AuthError::auth_failed("JWT login is not yet supported"))
112 : }
113 0 : crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
114 0 : "JWT login over web auth proxy is not supported",
115 0 : )),
116 0 : crate::auth::Backend::Local(cache) => {
117 0 : cache
118 0 : .jwks_cache
119 0 : .check_jwt(
120 0 : ctx,
121 0 : user_info.endpoint.clone(),
122 0 : user_info.user.clone(),
123 0 : &StaticAuthRules,
124 0 : jwt,
125 0 : )
126 0 : .await
127 0 : .map_err(|e| AuthError::auth_failed(e.to_string()))?;
128 0 : Ok(ComputeCredentials {
129 0 : info: user_info.clone(),
130 0 : keys: crate::auth::backend::ComputeCredentialKeys::None,
131 0 : })
132 : }
133 : }
134 0 : }
135 :
136 : // Wake up the destination if needed. Code here is a bit involved because
137 : // we reuse the code from the usual proxy and we need to prepare few structures
138 : // that this code expects.
139 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
140 : pub(crate) async fn connect_to_compute(
141 : &self,
142 : ctx: &RequestMonitoring,
143 : conn_info: ConnInfo,
144 : keys: ComputeCredentials,
145 : force_new: bool,
146 : ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
147 : let maybe_client = if force_new {
148 : info!("pool: pool is disabled");
149 : None
150 : } else {
151 : info!("pool: looking for an existing connection");
152 : self.pool.get(ctx, &conn_info)?
153 : };
154 :
155 : if let Some(client) = maybe_client {
156 : return Ok(client);
157 : }
158 : let conn_id = uuid::Uuid::new_v4();
159 : tracing::Span::current().record("conn_id", display(conn_id));
160 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
161 0 : let backend = self.config.auth_backend.as_ref().map(|()| keys);
162 : crate::proxy::connect_compute::connect_to_compute(
163 : ctx,
164 : &TokioMechanism {
165 : conn_id,
166 : conn_info,
167 : pool: self.pool.clone(),
168 : locks: &self.config.connect_compute_locks,
169 : },
170 : &backend,
171 : false, // do not allow self signed compute for http flow
172 : self.config.wake_compute_retry_config,
173 : self.config.connect_to_compute_retry_config,
174 : )
175 : .await
176 : }
177 : }
178 :
179 0 : #[derive(Debug, thiserror::Error)]
180 : pub(crate) enum HttpConnError {
181 : #[error("pooled connection closed at inconsistent state")]
182 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
183 : #[error("could not connection to compute")]
184 : ConnectionError(#[from] tokio_postgres::Error),
185 :
186 : #[error("could not get auth info")]
187 : GetAuthInfo(#[from] GetAuthInfoError),
188 : #[error("user not authenticated")]
189 : AuthError(#[from] AuthError),
190 : #[error("wake_compute returned error")]
191 : WakeCompute(#[from] WakeComputeError),
192 : #[error("error acquiring resource permit: {0}")]
193 : TooManyConnectionAttempts(#[from] ApiLockError),
194 : }
195 :
196 : impl ReportableError for HttpConnError {
197 0 : fn get_error_kind(&self) -> ErrorKind {
198 0 : match self {
199 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
200 0 : HttpConnError::ConnectionError(p) => p.get_error_kind(),
201 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
202 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
203 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
204 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
205 : }
206 0 : }
207 : }
208 :
209 : impl UserFacingError for HttpConnError {
210 0 : fn to_string_client(&self) -> String {
211 0 : match self {
212 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
213 0 : HttpConnError::ConnectionError(p) => p.to_string(),
214 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
215 0 : HttpConnError::AuthError(c) => c.to_string_client(),
216 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
217 : HttpConnError::TooManyConnectionAttempts(_) => {
218 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
219 : }
220 : }
221 0 : }
222 : }
223 :
224 : impl CouldRetry for HttpConnError {
225 0 : fn could_retry(&self) -> bool {
226 0 : match self {
227 0 : HttpConnError::ConnectionError(e) => e.could_retry(),
228 0 : HttpConnError::ConnectionClosedAbruptly(_) => false,
229 0 : HttpConnError::GetAuthInfo(_) => false,
230 0 : HttpConnError::AuthError(_) => false,
231 0 : HttpConnError::WakeCompute(_) => false,
232 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
233 : }
234 0 : }
235 : }
236 : impl ShouldRetryWakeCompute for HttpConnError {
237 0 : fn should_retry_wake_compute(&self) -> bool {
238 0 : match self {
239 0 : HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
240 : // we never checked cache validity
241 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
242 0 : _ => true,
243 : }
244 0 : }
245 : }
246 :
247 : struct TokioMechanism {
248 : pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
249 : conn_info: ConnInfo,
250 : conn_id: uuid::Uuid,
251 :
252 : /// connect_to_compute concurrency lock
253 : locks: &'static ApiLocks<Host>,
254 : }
255 :
256 : #[async_trait]
257 : impl ConnectMechanism for TokioMechanism {
258 : type Connection = Client<tokio_postgres::Client>;
259 : type ConnectError = HttpConnError;
260 : type Error = HttpConnError;
261 :
262 0 : async fn connect_once(
263 0 : &self,
264 0 : ctx: &RequestMonitoring,
265 0 : node_info: &CachedNodeInfo,
266 0 : timeout: Duration,
267 0 : ) -> Result<Self::Connection, Self::ConnectError> {
268 0 : let host = node_info.config.get_host()?;
269 0 : let permit = self.locks.get_permit(&host).await?;
270 :
271 0 : let mut config = (*node_info.config).clone();
272 0 : let config = config
273 0 : .user(&self.conn_info.user_info.user)
274 0 : .dbname(&self.conn_info.dbname)
275 0 : .connect_timeout(timeout);
276 0 :
277 0 : match &self.conn_info.auth {
278 0 : AuthData::Jwt(_) => {}
279 0 : AuthData::Password(pw) => {
280 0 : config.password(pw);
281 0 : }
282 : }
283 :
284 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
285 0 : let res = config.connect(tokio_postgres::NoTls).await;
286 0 : drop(pause);
287 0 : let (client, connection) = permit.release_result(res)?;
288 :
289 0 : tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
290 0 : Ok(poll_client(
291 0 : self.pool.clone(),
292 0 : ctx,
293 0 : self.conn_info.clone(),
294 0 : client,
295 0 : connection,
296 0 : self.conn_id,
297 0 : node_info.aux.clone(),
298 0 : ))
299 0 : }
300 :
301 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
302 : }
|