Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use async_trait::async_trait;
4 : use tracing::{field::display, info};
5 :
6 : use crate::{
7 : auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
8 : compute,
9 : config::{AuthenticationConfig, ProxyConfig},
10 : console::{
11 : errors::{GetAuthInfoError, WakeComputeError},
12 : locks::ApiLocks,
13 : provider::ApiLockError,
14 : CachedNodeInfo,
15 : },
16 : context::RequestMonitoring,
17 : error::{ErrorKind, ReportableError, UserFacingError},
18 : intern::EndpointIdInt,
19 : proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
20 : rate_limiter::EndpointRateLimiter,
21 : Host,
22 : };
23 :
24 : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
25 :
26 : pub struct PoolingBackend {
27 : pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
28 : pub config: &'static ProxyConfig,
29 : pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
30 : }
31 :
32 : impl PoolingBackend {
33 0 : pub async fn authenticate(
34 0 : &self,
35 0 : ctx: &mut RequestMonitoring,
36 0 : config: &AuthenticationConfig,
37 0 : conn_info: &ConnInfo,
38 0 : ) -> Result<ComputeCredentials, AuthError> {
39 0 : let user_info = conn_info.user_info.clone();
40 0 : let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
41 0 : let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
42 0 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
43 0 : return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
44 0 : }
45 0 : if !self
46 0 : .endpoint_rate_limiter
47 0 : .check(conn_info.user_info.endpoint.clone().into(), 1)
48 : {
49 0 : return Err(AuthError::too_many_connections());
50 0 : }
51 0 : let cached_secret = match maybe_secret {
52 0 : Some(secret) => secret,
53 0 : None => backend.get_role_secret(ctx).await?,
54 : };
55 :
56 0 : let secret = match cached_secret.value.clone() {
57 0 : Some(secret) => self.config.authentication_config.check_rate_limit(
58 0 : ctx,
59 0 : config,
60 0 : secret,
61 0 : &user_info.endpoint,
62 0 : true,
63 0 : )?,
64 : None => {
65 : // If we don't have an authentication secret, for the http flow we can just return an error.
66 0 : info!("authentication info not found");
67 0 : return Err(AuthError::auth_failed(&*user_info.user));
68 : }
69 : };
70 0 : let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
71 0 : let auth_outcome = crate::auth::validate_password_and_exchange(
72 0 : &config.thread_pool,
73 0 : ep,
74 0 : &conn_info.password,
75 0 : secret,
76 0 : )
77 0 : .await?;
78 0 : let res = match auth_outcome {
79 0 : crate::sasl::Outcome::Success(key) => {
80 0 : info!("user successfully authenticated");
81 0 : Ok(key)
82 : }
83 0 : crate::sasl::Outcome::Failure(reason) => {
84 0 : info!("auth backend failed with an error: {reason}");
85 0 : Err(AuthError::auth_failed(&*conn_info.user_info.user))
86 : }
87 : };
88 0 : res.map(|key| ComputeCredentials {
89 0 : info: user_info,
90 0 : keys: key,
91 0 : })
92 0 : }
93 :
94 : // Wake up the destination if needed. Code here is a bit involved because
95 : // we reuse the code from the usual proxy and we need to prepare few structures
96 : // that this code expects.
97 0 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
98 : pub async fn connect_to_compute(
99 : &self,
100 : ctx: &mut RequestMonitoring,
101 : conn_info: ConnInfo,
102 : keys: ComputeCredentials,
103 : force_new: bool,
104 : ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
105 : let maybe_client = if !force_new {
106 : info!("pool: looking for an existing connection");
107 : self.pool.get(ctx, &conn_info)?
108 : } else {
109 : info!("pool: pool is disabled");
110 : None
111 : };
112 :
113 : if let Some(client) = maybe_client {
114 : return Ok(client);
115 : }
116 : let conn_id = uuid::Uuid::new_v4();
117 : tracing::Span::current().record("conn_id", display(conn_id));
118 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
119 0 : let backend = self.config.auth_backend.as_ref().map(|_| keys);
120 : crate::proxy::connect_compute::connect_to_compute(
121 : ctx,
122 : &TokioMechanism {
123 : conn_id,
124 : conn_info,
125 : pool: self.pool.clone(),
126 : locks: &self.config.connect_compute_locks,
127 : },
128 : &backend,
129 : false, // do not allow self signed compute for http flow
130 : self.config.wake_compute_retry_config,
131 : self.config.connect_to_compute_retry_config,
132 : )
133 : .await
134 : }
135 : }
136 :
137 0 : #[derive(Debug, thiserror::Error)]
138 : pub enum HttpConnError {
139 : #[error("pooled connection closed at inconsistent state")]
140 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
141 : #[error("could not connection to compute")]
142 : ConnectionError(#[from] tokio_postgres::Error),
143 :
144 : #[error("could not get auth info")]
145 : GetAuthInfo(#[from] GetAuthInfoError),
146 : #[error("user not authenticated")]
147 : AuthError(#[from] AuthError),
148 : #[error("wake_compute returned error")]
149 : WakeCompute(#[from] WakeComputeError),
150 : #[error("error acquiring resource permit: {0}")]
151 : TooManyConnectionAttempts(#[from] ApiLockError),
152 : }
153 :
154 : impl ReportableError for HttpConnError {
155 0 : fn get_error_kind(&self) -> ErrorKind {
156 0 : match self {
157 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
158 0 : HttpConnError::ConnectionError(p) => p.get_error_kind(),
159 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
160 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
161 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
162 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
163 : }
164 0 : }
165 : }
166 :
167 : impl UserFacingError for HttpConnError {
168 0 : fn to_string_client(&self) -> String {
169 0 : match self {
170 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
171 0 : HttpConnError::ConnectionError(p) => p.to_string(),
172 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
173 0 : HttpConnError::AuthError(c) => c.to_string_client(),
174 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
175 : HttpConnError::TooManyConnectionAttempts(_) => {
176 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
177 : }
178 : }
179 0 : }
180 : }
181 :
182 : impl ShouldRetry for HttpConnError {
183 0 : fn could_retry(&self) -> bool {
184 0 : match self {
185 0 : HttpConnError::ConnectionError(e) => e.could_retry(),
186 0 : HttpConnError::ConnectionClosedAbruptly(_) => false,
187 0 : HttpConnError::GetAuthInfo(_) => false,
188 0 : HttpConnError::AuthError(_) => false,
189 0 : HttpConnError::WakeCompute(_) => false,
190 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
191 : }
192 0 : }
193 0 : fn should_retry_database_address(&self) -> bool {
194 0 : match self {
195 0 : HttpConnError::ConnectionError(e) => e.should_retry_database_address(),
196 : // we never checked cache validity
197 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
198 0 : _ => true,
199 : }
200 0 : }
201 : }
202 :
203 : struct TokioMechanism {
204 : pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
205 : conn_info: ConnInfo,
206 : conn_id: uuid::Uuid,
207 :
208 : /// connect_to_compute concurrency lock
209 : locks: &'static ApiLocks<Host>,
210 : }
211 :
212 : #[async_trait]
213 : impl ConnectMechanism for TokioMechanism {
214 : type Connection = Client<tokio_postgres::Client>;
215 : type ConnectError = HttpConnError;
216 : type Error = HttpConnError;
217 :
218 0 : async fn connect_once(
219 0 : &self,
220 0 : ctx: &mut RequestMonitoring,
221 0 : node_info: &CachedNodeInfo,
222 0 : timeout: Duration,
223 0 : ) -> Result<Self::Connection, Self::ConnectError> {
224 0 : let host = node_info.config.get_host()?;
225 0 : let permit = self.locks.get_permit(&host).await?;
226 0 :
227 0 : let mut config = (*node_info.config).clone();
228 0 : let config = config
229 0 : .user(&self.conn_info.user_info.user)
230 0 : .password(&*self.conn_info.password)
231 0 : .dbname(&self.conn_info.dbname)
232 0 : .connect_timeout(timeout);
233 0 :
234 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
235 0 : let res = config.connect(tokio_postgres::NoTls).await;
236 0 : drop(pause);
237 0 : let (client, connection) = permit.release_result(res)?;
238 0 :
239 0 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
240 0 : Ok(poll_client(
241 0 : self.pool.clone(),
242 0 : ctx,
243 0 : self.conn_info.clone(),
244 0 : client,
245 0 : connection,
246 0 : self.conn_id,
247 0 : node_info.aux.clone(),
248 0 : ))
249 0 : }
250 :
251 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
252 : }
|