Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use async_trait::async_trait;
4 : use tracing::{field::display, info};
5 :
6 : use crate::{
7 : auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
8 : compute,
9 : config::{AuthenticationConfig, ProxyConfig},
10 : console::{
11 : errors::{GetAuthInfoError, WakeComputeError},
12 : locks::ApiLocks,
13 : provider::ApiLockError,
14 : CachedNodeInfo,
15 : },
16 : context::RequestMonitoring,
17 : error::{ErrorKind, ReportableError, UserFacingError},
18 : proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
19 : rate_limiter::EndpointRateLimiter,
20 : Host,
21 : };
22 :
23 : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
24 :
25 : pub struct PoolingBackend {
26 : pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
27 : pub config: &'static ProxyConfig,
28 : pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
29 : }
30 :
31 : impl PoolingBackend {
32 0 : pub async fn authenticate(
33 0 : &self,
34 0 : ctx: &mut RequestMonitoring,
35 0 : config: &AuthenticationConfig,
36 0 : conn_info: &ConnInfo,
37 0 : ) -> Result<ComputeCredentials, AuthError> {
38 0 : let user_info = conn_info.user_info.clone();
39 0 : let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
40 0 : let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
41 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
42 0 : return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
43 0 : }
44 0 : if !self
45 0 : .endpoint_rate_limiter
46 0 : .check(conn_info.user_info.endpoint.clone().into(), 1)
47 : {
48 0 : return Err(AuthError::too_many_connections());
49 0 : }
50 0 : let cached_secret = match maybe_secret {
51 0 : Some(secret) => secret,
52 0 : None => backend.get_role_secret(ctx).await?,
53 : };
54 :
55 0 : let secret = match cached_secret.value.clone() {
56 0 : Some(secret) => self.config.authentication_config.check_rate_limit(
57 0 : ctx,
58 0 : config,
59 0 : secret,
60 0 : &user_info.endpoint,
61 0 : true,
62 0 : )?,
63 : None => {
64 : // If we don't have an authentication secret, for the http flow we can just return an error.
65 0 : info!("authentication info not found");
66 0 : return Err(AuthError::auth_failed(&*user_info.user));
67 : }
68 : };
69 0 : let auth_outcome =
70 0 : crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?;
71 0 : let res = match auth_outcome {
72 0 : crate::sasl::Outcome::Success(key) => {
73 0 : info!("user successfully authenticated");
74 0 : Ok(key)
75 : }
76 0 : crate::sasl::Outcome::Failure(reason) => {
77 0 : info!("auth backend failed with an error: {reason}");
78 0 : Err(AuthError::auth_failed(&*conn_info.user_info.user))
79 : }
80 : };
81 0 : res.map(|key| ComputeCredentials {
82 0 : info: user_info,
83 0 : keys: key,
84 0 : })
85 0 : }
86 :
87 : // Wake up the destination if needed. Code here is a bit involved because
88 : // we reuse the code from the usual proxy and we need to prepare few structures
89 : // that this code expects.
90 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
91 : pub async fn connect_to_compute(
92 : &self,
93 : ctx: &mut RequestMonitoring,
94 : conn_info: ConnInfo,
95 : keys: ComputeCredentials,
96 : force_new: bool,
97 : ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
98 : let maybe_client = if !force_new {
99 : info!("pool: looking for an existing connection");
100 : self.pool.get(ctx, &conn_info).await?
101 : } else {
102 : info!("pool: pool is disabled");
103 : None
104 : };
105 :
106 : if let Some(client) = maybe_client {
107 : return Ok(client);
108 : }
109 : let conn_id = uuid::Uuid::new_v4();
110 : tracing::Span::current().record("conn_id", display(conn_id));
111 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
112 0 : let backend = self.config.auth_backend.as_ref().map(|_| keys);
113 : crate::proxy::connect_compute::connect_to_compute(
114 : ctx,
115 : &TokioMechanism {
116 : conn_id,
117 : conn_info,
118 : pool: self.pool.clone(),
119 : locks: &self.config.connect_compute_locks,
120 : },
121 : &backend,
122 : false, // do not allow self signed compute for http flow
123 : self.config.wake_compute_retry_config,
124 : self.config.connect_to_compute_retry_config,
125 : )
126 : .await
127 : }
128 : }
129 :
130 0 : #[derive(Debug, thiserror::Error)]
131 : pub enum HttpConnError {
132 : #[error("pooled connection closed at inconsistent state")]
133 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
134 : #[error("could not connection to compute")]
135 : ConnectionError(#[from] tokio_postgres::Error),
136 :
137 : #[error("could not get auth info")]
138 : GetAuthInfo(#[from] GetAuthInfoError),
139 : #[error("user not authenticated")]
140 : AuthError(#[from] AuthError),
141 : #[error("wake_compute returned error")]
142 : WakeCompute(#[from] WakeComputeError),
143 : #[error("error acquiring resource permit: {0}")]
144 : TooManyConnectionAttempts(#[from] ApiLockError),
145 : }
146 :
147 : impl ReportableError for HttpConnError {
148 0 : fn get_error_kind(&self) -> ErrorKind {
149 0 : match self {
150 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
151 0 : HttpConnError::ConnectionError(p) => p.get_error_kind(),
152 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
153 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
154 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
155 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
156 : }
157 0 : }
158 : }
159 :
160 : impl UserFacingError for HttpConnError {
161 0 : fn to_string_client(&self) -> String {
162 0 : match self {
163 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
164 0 : HttpConnError::ConnectionError(p) => p.to_string(),
165 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
166 0 : HttpConnError::AuthError(c) => c.to_string_client(),
167 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
168 : HttpConnError::TooManyConnectionAttempts(_) => {
169 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
170 : }
171 : }
172 0 : }
173 : }
174 :
175 : impl ShouldRetry for HttpConnError {
176 0 : fn could_retry(&self) -> bool {
177 0 : match self {
178 0 : HttpConnError::ConnectionError(e) => e.could_retry(),
179 0 : HttpConnError::ConnectionClosedAbruptly(_) => false,
180 0 : HttpConnError::GetAuthInfo(_) => false,
181 0 : HttpConnError::AuthError(_) => false,
182 0 : HttpConnError::WakeCompute(_) => false,
183 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
184 : }
185 0 : }
186 0 : fn should_retry_database_address(&self) -> bool {
187 0 : match self {
188 0 : HttpConnError::ConnectionError(e) => e.should_retry_database_address(),
189 : // we never checked cache validity
190 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
191 0 : _ => true,
192 : }
193 0 : }
194 : }
195 :
196 : struct TokioMechanism {
197 : pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
198 : conn_info: ConnInfo,
199 : conn_id: uuid::Uuid,
200 :
201 : /// connect_to_compute concurrency lock
202 : locks: &'static ApiLocks<Host>,
203 : }
204 :
205 : #[async_trait]
206 : impl ConnectMechanism for TokioMechanism {
207 : type Connection = Client<tokio_postgres::Client>;
208 : type ConnectError = HttpConnError;
209 : type Error = HttpConnError;
210 :
211 0 : async fn connect_once(
212 0 : &self,
213 0 : ctx: &mut RequestMonitoring,
214 0 : node_info: &CachedNodeInfo,
215 0 : timeout: Duration,
216 0 : ) -> Result<Self::Connection, Self::ConnectError> {
217 0 : let host = node_info.config.get_host()?;
218 0 : let permit = self.locks.get_permit(&host).await?;
219 0 :
220 0 : let mut config = (*node_info.config).clone();
221 0 : let config = config
222 0 : .user(&self.conn_info.user_info.user)
223 0 : .password(&*self.conn_info.password)
224 0 : .dbname(&self.conn_info.dbname)
225 0 : .connect_timeout(timeout);
226 0 :
227 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
228 0 : let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
229 0 : drop(pause);
230 0 : drop(permit);
231 0 :
232 0 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
233 0 : Ok(poll_client(
234 0 : self.pool.clone(),
235 0 : ctx,
236 0 : self.conn_info.clone(),
237 0 : client,
238 0 : connection,
239 0 : self.conn_id,
240 0 : node_info.aux.clone(),
241 0 : ))
242 0 : }
243 :
244 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
245 : }
|