Line data Source code
1 : use std::{sync::Arc, time::Duration};
2 :
3 : use async_trait::async_trait;
4 : use tracing::{field::display, info};
5 :
6 : use crate::{
7 : auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
8 : compute,
9 : config::ProxyConfig,
10 : console::{
11 : errors::{GetAuthInfoError, WakeComputeError},
12 : CachedNodeInfo,
13 : },
14 : context::RequestMonitoring,
15 : proxy::connect_compute::ConnectMechanism,
16 : };
17 :
18 : use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
19 :
20 : pub struct PoolingBackend {
21 : pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
22 : pub config: &'static ProxyConfig,
23 : }
24 :
25 : impl PoolingBackend {
26 47 : pub async fn authenticate(
27 47 : &self,
28 47 : ctx: &mut RequestMonitoring,
29 47 : conn_info: &ConnInfo,
30 47 : ) -> Result<ComputeCredentialKeys, AuthError> {
31 47 : let user_info = conn_info.user_info.clone();
32 47 : let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
33 366 : let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
34 47 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
35 1 : return Err(AuthError::ip_address_not_allowed());
36 46 : }
37 46 : let cached_secret = match maybe_secret {
38 0 : Some(secret) => secret,
39 317 : None => backend.get_role_secret(ctx).await?,
40 : };
41 :
42 46 : let secret = match cached_secret.value.clone() {
43 45 : Some(secret) => secret,
44 : None => {
45 : // If we don't have an authentication secret, for the http flow we can just return an error.
46 1 : info!("authentication info not found");
47 1 : return Err(AuthError::auth_failed(&*user_info.user));
48 : }
49 : };
50 45 : let auth_outcome =
51 45 : crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
52 45 : match auth_outcome {
53 44 : crate::sasl::Outcome::Success(key) => Ok(key),
54 1 : crate::sasl::Outcome::Failure(reason) => {
55 1 : info!("auth backend failed with an error: {reason}");
56 1 : Err(AuthError::auth_failed(&*conn_info.user_info.user))
57 : }
58 : }
59 47 : }
60 :
61 : // Wake up the destination if needed. Code here is a bit involved because
62 : // we reuse the code from the usual proxy and we need to prepare few structures
63 : // that this code expects.
64 88 : #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
65 : pub async fn connect_to_compute(
66 : &self,
67 : ctx: &mut RequestMonitoring,
68 : conn_info: ConnInfo,
69 : keys: ComputeCredentialKeys,
70 : force_new: bool,
71 : ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
72 : let maybe_client = if !force_new {
73 24 : info!("pool: looking for an existing connection");
74 : self.pool.get(ctx, &conn_info).await?
75 : } else {
76 20 : info!("pool: pool is disabled");
77 : None
78 : };
79 :
80 : if let Some(client) = maybe_client {
81 : return Ok(client);
82 : }
83 : let conn_id = uuid::Uuid::new_v4();
84 : tracing::Span::current().record("conn_id", display(conn_id));
85 40 : info!("pool: opening a new connection '{conn_info}'");
86 : let backend = self
87 : .config
88 : .auth_backend
89 : .as_ref()
90 40 : .map(|_| conn_info.user_info.clone());
91 :
92 : let mut node_info = backend
93 : .wake_compute(ctx)
94 : .await?
95 : .ok_or(HttpConnError::NoComputeInfo)?;
96 :
97 : match keys {
98 : #[cfg(any(test, feature = "testing"))]
99 : ComputeCredentialKeys::Password(password) => node_info.config.password(password),
100 : ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
101 : };
102 :
103 : ctx.set_project(node_info.aux.clone());
104 :
105 : crate::proxy::connect_compute::connect_to_compute(
106 : ctx,
107 : &TokioMechanism {
108 : conn_id,
109 : conn_info,
110 : pool: self.pool.clone(),
111 : },
112 : node_info,
113 : &backend,
114 : )
115 : .await
116 : }
117 : }
118 :
119 18 : #[derive(Debug, thiserror::Error)]
120 : pub enum HttpConnError {
121 : #[error("pooled connection closed at inconsistent state")]
122 : ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
123 : #[error("could not connection to compute")]
124 : ConnectionError(#[from] tokio_postgres::Error),
125 :
126 : #[error("could not get auth info")]
127 : GetAuthInfo(#[from] GetAuthInfoError),
128 : #[error("user not authenticated")]
129 : AuthError(#[from] AuthError),
130 : #[error("wake_compute returned error")]
131 : WakeCompute(#[from] WakeComputeError),
132 : #[error("wake_compute returned nothing")]
133 : NoComputeInfo,
134 : }
135 :
136 : struct TokioMechanism {
137 : pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
138 : conn_info: ConnInfo,
139 : conn_id: uuid::Uuid,
140 : }
141 :
142 : #[async_trait]
143 : impl ConnectMechanism for TokioMechanism {
144 : type Connection = Client<tokio_postgres::Client>;
145 : type ConnectError = tokio_postgres::Error;
146 : type Error = HttpConnError;
147 :
148 40 : async fn connect_once(
149 40 : &self,
150 40 : ctx: &mut RequestMonitoring,
151 40 : node_info: &CachedNodeInfo,
152 40 : timeout: Duration,
153 40 : ) -> Result<Self::Connection, Self::ConnectError> {
154 40 : let mut config = (*node_info.config).clone();
155 40 : let config = config
156 40 : .user(&self.conn_info.user_info.user)
157 40 : .password(&*self.conn_info.password)
158 40 : .dbname(&self.conn_info.dbname)
159 40 : .connect_timeout(timeout);
160 :
161 128 : let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
162 :
163 40 : tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
164 40 : Ok(poll_client(
165 40 : self.pool.clone(),
166 40 : ctx,
167 40 : self.conn_info.clone(),
168 40 : client,
169 40 : connection,
170 40 : self.conn_id,
171 40 : node_info.aux.clone(),
172 40 : ))
173 120 : }
174 :
175 40 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
176 : }
|