Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use std::net::{IpAddr, Ipv4Addr};
4 : use std::str::FromStr;
5 : use std::sync::Arc;
6 :
7 : use futures::TryFutureExt;
8 : use thiserror::Error;
9 : use tokio_postgres::Client;
10 : use tracing::{Instrument, error, info, info_span, warn};
11 :
12 : use crate::auth::IpPattern;
13 : use crate::auth::backend::ComputeUserInfo;
14 : use crate::auth::backend::jwt::AuthRule;
15 : use crate::cache::Cached;
16 : use crate::context::RequestContext;
17 : use crate::control_plane::client::{
18 : CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
19 : };
20 : use crate::control_plane::errors::{
21 : ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
22 : };
23 : use crate::control_plane::messages::MetricsAuxInfo;
24 : use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
25 : use crate::error::io_error;
26 : use crate::intern::RoleNameInt;
27 : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
28 : use crate::url::ApiUrl;
29 : use crate::{compute, scram};
30 :
31 : #[derive(Debug, Error)]
32 : enum MockApiError {
33 : #[error("Failed to read password: {0}")]
34 : PasswordNotSet(tokio_postgres::Error),
35 : }
36 :
37 : impl From<MockApiError> for ControlPlaneError {
38 0 : fn from(e: MockApiError) -> Self {
39 0 : io_error(e).into()
40 0 : }
41 : }
42 :
43 : impl From<tokio_postgres::Error> for ControlPlaneError {
44 0 : fn from(e: tokio_postgres::Error) -> Self {
45 0 : io_error(e).into()
46 0 : }
47 : }
48 :
49 : #[derive(Clone)]
50 : pub struct MockControlPlane {
51 : endpoint: ApiUrl,
52 : ip_allowlist_check_enabled: bool,
53 : }
54 :
55 : impl MockControlPlane {
56 0 : pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
57 0 : Self {
58 0 : endpoint,
59 0 : ip_allowlist_check_enabled,
60 0 : }
61 0 : }
62 :
63 0 : pub(crate) fn url(&self) -> &str {
64 0 : self.endpoint.as_str()
65 0 : }
66 :
67 0 : async fn do_get_auth_info(
68 0 : &self,
69 0 : user_info: &ComputeUserInfo,
70 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
71 0 : let (secret, allowed_ips) = async {
72 : // Perhaps we could persist this connection, but then we'd have to
73 : // write more code for reopening it if it got closed, which doesn't
74 : // seem worth it.
75 0 : let (client, connection) =
76 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
77 :
78 0 : tokio::spawn(connection);
79 :
80 0 : let secret = if let Some(entry) = get_execute_postgres_query(
81 0 : &client,
82 0 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
83 0 : &[&&*user_info.user],
84 0 : "rolpassword",
85 0 : )
86 0 : .await?
87 : {
88 0 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
89 0 : let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
90 0 : secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
91 : } else {
92 0 : warn!("user '{}' does not exist", user_info.user);
93 0 : None
94 : };
95 :
96 0 : let allowed_ips = if self.ip_allowlist_check_enabled {
97 0 : match get_execute_postgres_query(
98 0 : &client,
99 0 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
100 0 : &[&user_info.endpoint.as_str()],
101 0 : "allowed_ips",
102 0 : )
103 0 : .await?
104 : {
105 0 : Some(s) => {
106 0 : info!("got allowed_ips: {s}");
107 0 : s.split(',')
108 0 : .map(|s| {
109 0 : IpPattern::from_str(s).expect("mocked ip pattern should be correct")
110 0 : })
111 0 : .collect()
112 : }
113 0 : None => vec![],
114 : }
115 : } else {
116 0 : vec![]
117 : };
118 :
119 0 : Ok((secret, allowed_ips))
120 0 : }
121 0 : .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
122 0 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
123 0 : .await?;
124 0 : Ok(AuthInfo {
125 0 : secret,
126 0 : allowed_ips,
127 0 : allowed_vpc_endpoint_ids: vec![],
128 0 : project_id: None,
129 0 : account_id: None,
130 0 : access_blocker_flags: AccessBlockerFlags::default(),
131 0 : })
132 0 : }
133 :
134 0 : async fn do_get_endpoint_jwks(
135 0 : &self,
136 0 : endpoint: EndpointId,
137 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
138 0 : let (client, connection) =
139 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
140 :
141 0 : let connection = tokio::spawn(connection);
142 :
143 0 : let res = client.query(
144 0 : "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
145 0 : &[&endpoint.as_str()],
146 0 : )
147 0 : .await?;
148 :
149 0 : let mut rows = vec![];
150 0 : for row in res {
151 0 : rows.push(AuthRule {
152 0 : id: row.get("id"),
153 0 : jwks_url: url::Url::parse(row.get("jwks_url"))?,
154 0 : audience: row.get("audience"),
155 0 : role_names: row
156 0 : .get::<_, Vec<String>>("role_names")
157 0 : .into_iter()
158 0 : .map(RoleName::from)
159 0 : .map(|s| RoleNameInt::from(&s))
160 0 : .collect(),
161 0 : });
162 0 : }
163 :
164 0 : drop(client);
165 0 : connection.await??;
166 :
167 0 : Ok(rows)
168 0 : }
169 :
170 0 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
171 0 : let port = self.endpoint.port().unwrap_or(5432);
172 0 : let mut config = match self.endpoint.host_str() {
173 : None => {
174 0 : let mut config = compute::ConnCfg::new("localhost".to_string(), port);
175 0 : config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
176 0 : config
177 : }
178 0 : Some(host) => {
179 0 : let mut config = compute::ConnCfg::new(host.to_string(), port);
180 0 : if let Ok(addr) = IpAddr::from_str(host) {
181 0 : config.set_host_addr(addr);
182 0 : }
183 0 : config
184 : }
185 : };
186 :
187 0 : config.ssl_mode(postgres_client::config::SslMode::Disable);
188 0 :
189 0 : let node = NodeInfo {
190 0 : config,
191 0 : aux: MetricsAuxInfo {
192 0 : endpoint_id: (&EndpointId::from("endpoint")).into(),
193 0 : project_id: (&ProjectId::from("project")).into(),
194 0 : branch_id: (&BranchId::from("branch")).into(),
195 0 : compute_id: "compute".into(),
196 0 : cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
197 0 : },
198 0 : };
199 0 :
200 0 : Ok(node)
201 0 : }
202 : }
203 :
204 0 : async fn get_execute_postgres_query(
205 0 : client: &Client,
206 0 : query: &str,
207 0 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
208 0 : idx: &str,
209 0 : ) -> Result<Option<String>, GetAuthInfoError> {
210 0 : let rows = client.query(query, params).await?;
211 :
212 : // We can get at most one row, because `rolname` is unique.
213 0 : let Some(row) = rows.first() else {
214 : // This means that the user doesn't exist, so there can be no secret.
215 : // However, this is still a *valid* outcome which is very similar
216 : // to getting `404 Not found` from the Neon console.
217 0 : return Ok(None);
218 : };
219 :
220 0 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
221 0 : Ok(Some(entry))
222 0 : }
223 :
224 : impl super::ControlPlaneApi for MockControlPlane {
225 : #[tracing::instrument(skip_all)]
226 : async fn get_role_secret(
227 : &self,
228 : _ctx: &RequestContext,
229 : user_info: &ComputeUserInfo,
230 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
231 : Ok(CachedRoleSecret::new_uncached(
232 : self.do_get_auth_info(user_info).await?.secret,
233 : ))
234 : }
235 :
236 0 : async fn get_allowed_ips(
237 0 : &self,
238 0 : _ctx: &RequestContext,
239 0 : user_info: &ComputeUserInfo,
240 0 : ) -> Result<CachedAllowedIps, GetAuthInfoError> {
241 0 : Ok(Cached::new_uncached(Arc::new(
242 0 : self.do_get_auth_info(user_info).await?.allowed_ips,
243 : )))
244 0 : }
245 :
246 0 : async fn get_allowed_vpc_endpoint_ids(
247 0 : &self,
248 0 : _ctx: &RequestContext,
249 0 : user_info: &ComputeUserInfo,
250 0 : ) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
251 0 : Ok(Cached::new_uncached(Arc::new(
252 0 : self.do_get_auth_info(user_info)
253 0 : .await?
254 : .allowed_vpc_endpoint_ids,
255 : )))
256 0 : }
257 :
258 0 : async fn get_block_public_or_vpc_access(
259 0 : &self,
260 0 : _ctx: &RequestContext,
261 0 : user_info: &ComputeUserInfo,
262 0 : ) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
263 0 : Ok(Cached::new_uncached(
264 0 : self.do_get_auth_info(user_info).await?.access_blocker_flags,
265 : ))
266 0 : }
267 :
268 0 : async fn get_endpoint_jwks(
269 0 : &self,
270 0 : _ctx: &RequestContext,
271 0 : endpoint: EndpointId,
272 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
273 0 : self.do_get_endpoint_jwks(endpoint).await
274 0 : }
275 :
276 : #[tracing::instrument(skip_all)]
277 : async fn wake_compute(
278 : &self,
279 : _ctx: &RequestContext,
280 : _user_info: &ComputeUserInfo,
281 : ) -> Result<CachedNodeInfo, WakeComputeError> {
282 : self.do_wake_compute().map_ok(Cached::new_uncached).await
283 : }
284 : }
285 :
286 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
287 0 : let text = input.strip_prefix("md5")?;
288 :
289 0 : let mut bytes = [0u8; 16];
290 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
291 :
292 0 : Some(bytes)
293 0 : }
|