Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use super::{
4 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
5 : AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
6 : };
7 : use crate::context::RequestMonitoring;
8 : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
9 : use crate::{auth::IpPattern, cache::Cached};
10 : use crate::{
11 : console::{
12 : messages::MetricsAuxInfo,
13 : provider::{CachedAllowedIps, CachedRoleSecret},
14 : },
15 : BranchId, EndpointId, ProjectId,
16 : };
17 : use futures::TryFutureExt;
18 : use std::{str::FromStr, sync::Arc};
19 : use thiserror::Error;
20 : use tokio_postgres::{config::SslMode, Client};
21 : use tracing::{error, info, info_span, warn, Instrument};
22 :
23 0 : #[derive(Debug, Error)]
24 : enum MockApiError {
25 : #[error("Failed to read password: {0}")]
26 : PasswordNotSet(tokio_postgres::Error),
27 : }
28 :
29 : impl From<MockApiError> for ApiError {
30 0 : fn from(e: MockApiError) -> Self {
31 0 : io_error(e).into()
32 0 : }
33 : }
34 :
35 : impl From<tokio_postgres::Error> for ApiError {
36 0 : fn from(e: tokio_postgres::Error) -> Self {
37 0 : io_error(e).into()
38 0 : }
39 : }
40 :
41 : #[derive(Clone)]
42 : pub struct Api {
43 : endpoint: ApiUrl,
44 : ip_allowlist_check_enabled: bool,
45 : }
46 :
47 : impl Api {
48 0 : pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
49 0 : Self {
50 0 : endpoint,
51 0 : ip_allowlist_check_enabled,
52 0 : }
53 0 : }
54 :
55 0 : pub(crate) fn url(&self) -> &str {
56 0 : self.endpoint.as_str()
57 0 : }
58 :
59 0 : async fn do_get_auth_info(
60 0 : &self,
61 0 : user_info: &ComputeUserInfo,
62 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
63 0 : let (secret, allowed_ips) = async {
64 : // Perhaps we could persist this connection, but then we'd have to
65 : // write more code for reopening it if it got closed, which doesn't
66 : // seem worth it.
67 0 : let (client, connection) =
68 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
69 :
70 0 : tokio::spawn(connection);
71 :
72 0 : let secret = if let Some(entry) = get_execute_postgres_query(
73 0 : &client,
74 0 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
75 0 : &[&&*user_info.user],
76 0 : "rolpassword",
77 0 : )
78 0 : .await?
79 : {
80 0 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
81 0 : let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
82 0 : secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
83 : } else {
84 0 : warn!("user '{}' does not exist", user_info.user);
85 0 : None
86 : };
87 :
88 0 : let allowed_ips = if self.ip_allowlist_check_enabled {
89 0 : match get_execute_postgres_query(
90 0 : &client,
91 0 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
92 0 : &[&user_info.endpoint.as_str()],
93 0 : "allowed_ips",
94 0 : )
95 0 : .await?
96 : {
97 0 : Some(s) => {
98 0 : info!("got allowed_ips: {s}");
99 0 : s.split(',')
100 0 : .map(|s| IpPattern::from_str(s).unwrap())
101 0 : .collect()
102 : }
103 0 : None => vec![],
104 : }
105 : } else {
106 0 : vec![]
107 : };
108 :
109 0 : Ok((secret, allowed_ips))
110 0 : }
111 0 : .map_err(crate::error::log_error::<GetAuthInfoError>)
112 0 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
113 0 : .await?;
114 0 : Ok(AuthInfo {
115 0 : secret,
116 0 : allowed_ips,
117 0 : project_id: None,
118 0 : })
119 0 : }
120 :
121 0 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
122 0 : let mut config = compute::ConnCfg::new();
123 0 : config
124 0 : .host(self.endpoint.host_str().unwrap_or("localhost"))
125 0 : .port(self.endpoint.port().unwrap_or(5432))
126 0 : .ssl_mode(SslMode::Disable);
127 0 :
128 0 : let node = NodeInfo {
129 0 : config,
130 0 : aux: MetricsAuxInfo {
131 0 : endpoint_id: (&EndpointId::from("endpoint")).into(),
132 0 : project_id: (&ProjectId::from("project")).into(),
133 0 : branch_id: (&BranchId::from("branch")).into(),
134 0 : cold_start_info: crate::console::messages::ColdStartInfo::Warm,
135 0 : },
136 0 : allow_self_signed_compute: false,
137 0 : };
138 0 :
139 0 : Ok(node)
140 0 : }
141 : }
142 :
143 0 : async fn get_execute_postgres_query(
144 0 : client: &Client,
145 0 : query: &str,
146 0 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
147 0 : idx: &str,
148 0 : ) -> Result<Option<String>, GetAuthInfoError> {
149 0 : let rows = client.query(query, params).await?;
150 :
151 : // We can get at most one row, because `rolname` is unique.
152 0 : let Some(row) = rows.first() else {
153 : // This means that the user doesn't exist, so there can be no secret.
154 : // However, this is still a *valid* outcome which is very similar
155 : // to getting `404 Not found` from the Neon console.
156 0 : return Ok(None);
157 : };
158 :
159 0 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
160 0 : Ok(Some(entry))
161 0 : }
162 :
163 : impl super::Api for Api {
164 0 : #[tracing::instrument(skip_all)]
165 : async fn get_role_secret(
166 : &self,
167 : _ctx: &RequestMonitoring,
168 : user_info: &ComputeUserInfo,
169 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
170 : Ok(CachedRoleSecret::new_uncached(
171 : self.do_get_auth_info(user_info).await?.secret,
172 : ))
173 : }
174 :
175 0 : async fn get_allowed_ips_and_secret(
176 0 : &self,
177 0 : _ctx: &RequestMonitoring,
178 0 : user_info: &ComputeUserInfo,
179 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
180 0 : Ok((
181 0 : Cached::new_uncached(Arc::new(
182 0 : self.do_get_auth_info(user_info).await?.allowed_ips,
183 : )),
184 0 : None,
185 : ))
186 0 : }
187 :
188 0 : #[tracing::instrument(skip_all)]
189 : async fn wake_compute(
190 : &self,
191 : _ctx: &RequestMonitoring,
192 : _user_info: &ComputeUserInfo,
193 : ) -> Result<CachedNodeInfo, WakeComputeError> {
194 : self.do_wake_compute().map_ok(Cached::new_uncached).await
195 : }
196 : }
197 :
198 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
199 0 : let text = input.strip_prefix("md5")?;
200 :
201 0 : let mut bytes = [0u8; 16];
202 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
203 :
204 0 : Some(bytes)
205 0 : }
|