Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use std::io;
4 : use std::net::{IpAddr, Ipv4Addr};
5 : use std::str::FromStr;
6 : use std::sync::Arc;
7 :
8 : use futures::TryFutureExt;
9 : use postgres_client::config::SslMode;
10 : use thiserror::Error;
11 : use tokio_postgres::Client;
12 : use tracing::{Instrument, error, info, info_span, warn};
13 :
14 : use crate::auth::IpPattern;
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::auth::backend::jwt::AuthRule;
17 : use crate::cache::Cached;
18 : use crate::compute::ConnectInfo;
19 : use crate::context::RequestContext;
20 : use crate::control_plane::errors::{
21 : ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
22 : };
23 : use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
24 : use crate::control_plane::{
25 : AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
26 : RoleAccessControl,
27 : };
28 : use crate::intern::RoleNameInt;
29 : use crate::scram;
30 : use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
31 : use crate::url::ApiUrl;
32 :
/// Errors specific to the postgres-backed mock control plane.
#[derive(Debug, Error)]
enum MockApiError {
    /// A column could not be read from the first row returned by a lookup query.
    ///
    /// Despite the name, this variant is produced for every column fetched via
    /// `get_execute_postgres_query` in this file (both `rolpassword` and
    /// `allowed_ips`), e.g. when the value is NULL or has an unexpected type.
    #[error("Failed to read password: {0}")]
    PasswordNotSet(tokio_postgres::Error),
}
38 :
39 : impl From<MockApiError> for ControlPlaneError {
40 0 : fn from(e: MockApiError) -> Self {
41 0 : io::Error::other(e).into()
42 0 : }
43 : }
44 :
45 : impl From<tokio_postgres::Error> for ControlPlaneError {
46 0 : fn from(e: tokio_postgres::Error) -> Self {
47 0 : io::Error::other(e).into()
48 0 : }
49 : }
50 :
/// Control plane stand-in that answers requests by querying a postgres
/// instance directly instead of calling a real control plane.
#[derive(Clone)]
pub struct MockControlPlane {
    /// Connection URL of the backing postgres instance. It is used both to
    /// run the lookup queries and to derive the compute address handed out by
    /// `do_wake_compute`.
    endpoint: ApiUrl,
    /// When `false`, the `allowed_ips` lookup is skipped and an empty
    /// allowlist is reported.
    ip_allowlist_check_enabled: bool,
}
56 :
57 : impl MockControlPlane {
58 0 : pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
59 0 : Self {
60 0 : endpoint,
61 0 : ip_allowlist_check_enabled,
62 0 : }
63 0 : }
64 :
65 0 : pub(crate) fn url(&self) -> &str {
66 0 : self.endpoint.as_str()
67 0 : }
68 :
69 0 : async fn do_get_auth_info(
70 0 : &self,
71 0 : endpoint: &EndpointId,
72 0 : role: &RoleName,
73 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
74 0 : let (secret, allowed_ips) = async {
75 : // Perhaps we could persist this connection, but then we'd have to
76 : // write more code for reopening it if it got closed, which doesn't
77 : // seem worth it.
78 0 : let (client, connection) =
79 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
80 :
81 0 : tokio::spawn(connection);
82 :
83 0 : let secret = if let Some(entry) = get_execute_postgres_query(
84 0 : &client,
85 0 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
86 0 : &[&role.as_str()],
87 0 : "rolpassword",
88 0 : )
89 0 : .await?
90 : {
91 0 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
92 0 : scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
93 : } else {
94 0 : warn!("user '{role}' does not exist");
95 0 : None
96 : };
97 :
98 0 : let allowed_ips = if self.ip_allowlist_check_enabled {
99 0 : match get_execute_postgres_query(
100 0 : &client,
101 0 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
102 0 : &[&endpoint.as_str()],
103 0 : "allowed_ips",
104 0 : )
105 0 : .await?
106 : {
107 0 : Some(s) => {
108 0 : info!("got allowed_ips: {s}");
109 0 : s.split(',')
110 0 : .map(|s| {
111 0 : IpPattern::from_str(s).expect("mocked ip pattern should be correct")
112 0 : })
113 0 : .collect()
114 : }
115 0 : None => vec![],
116 : }
117 : } else {
118 0 : vec![]
119 : };
120 :
121 0 : Ok((secret, allowed_ips))
122 0 : }
123 0 : .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
124 0 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
125 0 : .await?;
126 0 : Ok(AuthInfo {
127 0 : secret,
128 0 : allowed_ips,
129 0 : allowed_vpc_endpoint_ids: vec![],
130 0 : project_id: None,
131 0 : account_id: None,
132 0 : access_blocker_flags: AccessBlockerFlags::default(),
133 0 : rate_limits: EndpointRateLimitConfig::default(),
134 0 : })
135 0 : }
136 :
137 0 : async fn do_get_endpoint_jwks(
138 0 : &self,
139 0 : endpoint: &EndpointId,
140 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
141 0 : let (client, connection) =
142 0 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
143 :
144 0 : let connection = tokio::spawn(connection);
145 :
146 0 : let res = client.query(
147 0 : "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
148 0 : &[&endpoint.as_str()],
149 0 : )
150 0 : .await?;
151 :
152 0 : let mut rows = vec![];
153 0 : for row in res {
154 0 : rows.push(AuthRule {
155 0 : id: row.get("id"),
156 0 : jwks_url: url::Url::parse(row.get("jwks_url"))?,
157 0 : audience: row.get("audience"),
158 0 : role_names: row
159 0 : .get::<_, Vec<String>>("role_names")
160 0 : .into_iter()
161 0 : .map(RoleName::from)
162 0 : .map(|s| RoleNameInt::from(&s))
163 0 : .collect(),
164 : });
165 : }
166 :
167 0 : drop(client);
168 0 : connection.await??;
169 :
170 0 : Ok(rows)
171 0 : }
172 :
173 0 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
174 0 : let port = self.endpoint.port().unwrap_or(5432);
175 0 : let conn_info = match self.endpoint.host_str() {
176 0 : None => ConnectInfo {
177 0 : host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
178 0 : host: "localhost".into(),
179 0 : port,
180 0 : ssl_mode: SslMode::Disable,
181 0 : },
182 0 : Some(host) => ConnectInfo {
183 0 : host_addr: IpAddr::from_str(host).ok(),
184 0 : host: host.into(),
185 0 : port,
186 0 : ssl_mode: SslMode::Disable,
187 0 : },
188 : };
189 :
190 0 : let node = NodeInfo {
191 0 : conn_info,
192 0 : aux: MetricsAuxInfo {
193 0 : endpoint_id: (&EndpointId::from("endpoint")).into(),
194 0 : project_id: (&ProjectId::from("project")).into(),
195 0 : branch_id: (&BranchId::from("branch")).into(),
196 0 : compute_id: "compute".into(),
197 0 : cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
198 0 : },
199 0 : };
200 :
201 0 : Ok(node)
202 0 : }
203 : }
204 :
205 0 : async fn get_execute_postgres_query(
206 0 : client: &Client,
207 0 : query: &str,
208 0 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
209 0 : idx: &str,
210 0 : ) -> Result<Option<String>, GetAuthInfoError> {
211 0 : let rows = client.query(query, params).await?;
212 :
213 : // We can get at most one row, because `rolname` is unique.
214 0 : let Some(row) = rows.first() else {
215 : // This means that the user doesn't exist, so there can be no secret.
216 : // However, this is still a *valid* outcome which is very similar
217 : // to getting `404 Not found` from the Neon console.
218 0 : return Ok(None);
219 : };
220 :
221 0 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
222 0 : Ok(Some(entry))
223 0 : }
224 :
225 : impl super::ControlPlaneApi for MockControlPlane {
226 0 : async fn get_endpoint_access_control(
227 0 : &self,
228 0 : _ctx: &RequestContext,
229 0 : endpoint: &EndpointId,
230 0 : role: &RoleName,
231 0 : ) -> Result<EndpointAccessControl, GetAuthInfoError> {
232 0 : let info = self.do_get_auth_info(endpoint, role).await?;
233 0 : Ok(EndpointAccessControl {
234 0 : allowed_ips: Arc::new(info.allowed_ips),
235 0 : allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
236 0 : flags: info.access_blocker_flags,
237 0 : rate_limits: info.rate_limits,
238 0 : })
239 0 : }
240 :
241 0 : async fn get_role_access_control(
242 0 : &self,
243 0 : _ctx: &RequestContext,
244 0 : endpoint: &EndpointId,
245 0 : role: &RoleName,
246 0 : ) -> Result<RoleAccessControl, GetAuthInfoError> {
247 0 : let info = self.do_get_auth_info(endpoint, role).await?;
248 0 : Ok(RoleAccessControl {
249 0 : secret: info.secret,
250 0 : })
251 0 : }
252 :
253 0 : async fn get_endpoint_jwks(
254 0 : &self,
255 0 : _ctx: &RequestContext,
256 0 : endpoint: &EndpointId,
257 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
258 0 : self.do_get_endpoint_jwks(endpoint).await
259 0 : }
260 :
261 : #[tracing::instrument(skip_all)]
262 : async fn wake_compute(
263 : &self,
264 : _ctx: &RequestContext,
265 : _user_info: &ComputeUserInfo,
266 : ) -> Result<CachedNodeInfo, WakeComputeError> {
267 : self.do_wake_compute().map_ok(Cached::new_uncached).await
268 : }
269 : }
|