Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use std::str::FromStr;
4 : use std::sync::Arc;
5 :
6 : use futures::TryFutureExt;
7 : use thiserror::Error;
8 : use tokio_postgres::config::SslMode;
9 : use tokio_postgres::Client;
10 : use tracing::{error, info, info_span, warn, Instrument};
11 :
12 : use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
13 : use super::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
14 : use crate::auth::backend::jwt::AuthRule;
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::auth::IpPattern;
17 : use crate::cache::Cached;
18 : use crate::context::RequestMonitoring;
19 : use crate::control_plane::errors::GetEndpointJwksError;
20 : use crate::control_plane::messages::MetricsAuxInfo;
21 : use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
22 : use crate::error::io_error;
23 : use crate::intern::RoleNameInt;
24 : use crate::url::ApiUrl;
25 : use crate::{compute, scram, BranchId, EndpointId, ProjectId, RoleName};
26 :
27 0 : #[derive(Debug, Error)]
28 : enum MockApiError {
29 : #[error("Failed to read password: {0}")]
30 : PasswordNotSet(tokio_postgres::Error),
31 : }
32 :
33 : impl From<MockApiError> for ApiError {
34 0 : fn from(e: MockApiError) -> Self {
35 0 : io_error(e).into()
36 0 : }
37 : }
38 :
39 : impl From<tokio_postgres::Error> for ApiError {
40 0 : fn from(e: tokio_postgres::Error) -> Self {
41 0 : io_error(e).into()
42 0 : }
43 : }
44 :
45 : #[derive(Clone)]
46 : pub struct Api {
47 : endpoint: ApiUrl,
48 : ip_allowlist_check_enabled: bool,
49 : }
50 :
51 : impl Api {
52 0 : pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
53 0 : Self {
54 0 : endpoint,
55 0 : ip_allowlist_check_enabled,
56 0 : }
57 0 : }
58 :
59 0 : pub(crate) fn url(&self) -> &str {
60 0 : self.endpoint.as_str()
61 0 : }
62 :
63 0 : async fn do_get_auth_info(
64 0 : &self,
65 0 : user_info: &ComputeUserInfo,
66 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
67 0 : let (secret, allowed_ips) = async {
68 : // Perhaps we could persist this connection, but then we'd have to
69 : // write more code for reopening it if it got closed, which doesn't
70 : // seem worth it.
71 0 : let (client, connection) =
72 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
73 :
74 0 : tokio::spawn(connection);
75 :
76 0 : let secret = if let Some(entry) = get_execute_postgres_query(
77 0 : &client,
78 0 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
79 0 : &[&&*user_info.user],
80 0 : "rolpassword",
81 0 : )
82 0 : .await?
83 : {
84 0 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
85 0 : let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
86 0 : secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
87 : } else {
88 0 : warn!("user '{}' does not exist", user_info.user);
89 0 : None
90 : };
91 :
92 0 : let allowed_ips = if self.ip_allowlist_check_enabled {
93 0 : match get_execute_postgres_query(
94 0 : &client,
95 0 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
96 0 : &[&user_info.endpoint.as_str()],
97 0 : "allowed_ips",
98 0 : )
99 0 : .await?
100 : {
101 0 : Some(s) => {
102 0 : info!("got allowed_ips: {s}");
103 0 : s.split(',')
104 0 : .map(|s| IpPattern::from_str(s).unwrap())
105 0 : .collect()
106 : }
107 0 : None => vec![],
108 : }
109 : } else {
110 0 : vec![]
111 : };
112 :
113 0 : Ok((secret, allowed_ips))
114 0 : }
115 0 : .map_err(crate::error::log_error::<GetAuthInfoError>)
116 0 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
117 0 : .await?;
118 0 : Ok(AuthInfo {
119 0 : secret,
120 0 : allowed_ips,
121 0 : project_id: None,
122 0 : })
123 0 : }
124 :
125 0 : async fn do_get_endpoint_jwks(
126 0 : &self,
127 0 : endpoint: EndpointId,
128 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
129 0 : let (client, connection) =
130 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
131 :
132 0 : let connection = tokio::spawn(connection);
133 :
134 0 : let res = client.query(
135 0 : "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
136 0 : &[&endpoint.as_str()],
137 0 : )
138 0 : .await?;
139 :
140 0 : let mut rows = vec![];
141 0 : for row in res {
142 0 : rows.push(AuthRule {
143 0 : id: row.get("id"),
144 0 : jwks_url: url::Url::parse(row.get("jwks_url"))?,
145 0 : audience: row.get("audience"),
146 0 : role_names: row
147 0 : .get::<_, Vec<String>>("role_names")
148 0 : .into_iter()
149 0 : .map(RoleName::from)
150 0 : .map(|s| RoleNameInt::from(&s))
151 0 : .collect(),
152 0 : });
153 0 : }
154 :
155 0 : drop(client);
156 0 : connection.await??;
157 :
158 0 : Ok(rows)
159 0 : }
160 :
161 0 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
162 0 : let mut config = compute::ConnCfg::new();
163 0 : config
164 0 : .host(self.endpoint.host_str().unwrap_or("localhost"))
165 0 : .port(self.endpoint.port().unwrap_or(5432))
166 0 : .ssl_mode(SslMode::Disable);
167 0 :
168 0 : let node = NodeInfo {
169 0 : config,
170 0 : aux: MetricsAuxInfo {
171 0 : endpoint_id: (&EndpointId::from("endpoint")).into(),
172 0 : project_id: (&ProjectId::from("project")).into(),
173 0 : branch_id: (&BranchId::from("branch")).into(),
174 0 : cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
175 0 : },
176 0 : allow_self_signed_compute: false,
177 0 : };
178 0 :
179 0 : Ok(node)
180 0 : }
181 : }
182 :
183 0 : async fn get_execute_postgres_query(
184 0 : client: &Client,
185 0 : query: &str,
186 0 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
187 0 : idx: &str,
188 0 : ) -> Result<Option<String>, GetAuthInfoError> {
189 0 : let rows = client.query(query, params).await?;
190 :
191 : // We can get at most one row, because `rolname` is unique.
192 0 : let Some(row) = rows.first() else {
193 : // This means that the user doesn't exist, so there can be no secret.
194 : // However, this is still a *valid* outcome which is very similar
195 : // to getting `404 Not found` from the Neon console.
196 0 : return Ok(None);
197 : };
198 :
199 0 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
200 0 : Ok(Some(entry))
201 0 : }
202 :
203 : impl super::Api for Api {
204 0 : #[tracing::instrument(skip_all)]
205 : async fn get_role_secret(
206 : &self,
207 : _ctx: &RequestMonitoring,
208 : user_info: &ComputeUserInfo,
209 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
210 : Ok(CachedRoleSecret::new_uncached(
211 : self.do_get_auth_info(user_info).await?.secret,
212 : ))
213 : }
214 :
215 0 : async fn get_allowed_ips_and_secret(
216 0 : &self,
217 0 : _ctx: &RequestMonitoring,
218 0 : user_info: &ComputeUserInfo,
219 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
220 0 : Ok((
221 0 : Cached::new_uncached(Arc::new(
222 0 : self.do_get_auth_info(user_info).await?.allowed_ips,
223 : )),
224 0 : None,
225 : ))
226 0 : }
227 :
228 0 : async fn get_endpoint_jwks(
229 0 : &self,
230 0 : _ctx: &RequestMonitoring,
231 0 : endpoint: EndpointId,
232 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
233 0 : self.do_get_endpoint_jwks(endpoint).await
234 0 : }
235 :
236 0 : #[tracing::instrument(skip_all)]
237 : async fn wake_compute(
238 : &self,
239 : _ctx: &RequestMonitoring,
240 : _user_info: &ComputeUserInfo,
241 : ) -> Result<CachedNodeInfo, WakeComputeError> {
242 : self.do_wake_compute().map_ok(Cached::new_uncached).await
243 : }
244 : }
245 :
246 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
247 0 : let text = input.strip_prefix("md5")?;
248 :
249 0 : let mut bytes = [0u8; 16];
250 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
251 :
252 0 : Some(bytes)
253 0 : }
|