Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use super::{
4 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
5 : AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
6 : };
7 : use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
8 : use crate::context::RequestMonitoring;
9 : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
10 : use crate::{auth::IpPattern, cache::Cached};
11 : use async_trait::async_trait;
12 : use futures::TryFutureExt;
13 : use std::{str::FromStr, sync::Arc};
14 : use thiserror::Error;
15 : use tokio_postgres::{config::SslMode, Client};
16 : use tracing::{error, info, info_span, warn, Instrument};
17 :
18 0 : #[derive(Debug, Error)]
19 : enum MockApiError {
20 : #[error("Failed to read password: {0}")]
21 : PasswordNotSet(tokio_postgres::Error),
22 : }
23 :
24 : impl From<MockApiError> for ApiError {
25 0 : fn from(e: MockApiError) -> Self {
26 0 : io_error(e).into()
27 0 : }
28 : }
29 :
30 : impl From<tokio_postgres::Error> for ApiError {
31 0 : fn from(e: tokio_postgres::Error) -> Self {
32 0 : io_error(e).into()
33 0 : }
34 : }
35 :
36 0 : #[derive(Clone)]
37 : pub struct Api {
38 : endpoint: ApiUrl,
39 : }
40 :
41 : impl Api {
42 19 : pub fn new(endpoint: ApiUrl) -> Self {
43 19 : Self { endpoint }
44 19 : }
45 :
46 19 : pub fn url(&self) -> &str {
47 19 : self.endpoint.as_str()
48 19 : }
49 :
50 123 : async fn do_get_auth_info(
51 123 : &self,
52 123 : user_info: &ComputeUserInfo,
53 123 : ) -> Result<AuthInfo, GetAuthInfoError> {
54 123 : let (secret, allowed_ips) = async {
55 : // Perhaps we could persist this connection, but then we'd have to
56 : // write more code for reopening it if it got closed, which doesn't
57 : // seem worth it.
58 123 : let (client, connection) =
59 373 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
60 :
61 123 : tokio::spawn(connection);
62 123 : let secret = match get_execute_postgres_query(
63 123 : &client,
64 123 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
65 123 : &[&&*user_info.user],
66 123 : "rolpassword",
67 123 : )
68 246 : .await?
69 : {
70 121 : Some(entry) => {
71 121 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
72 121 : let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
73 121 : secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
74 : }
75 : None => {
76 2 : warn!("user '{}' does not exist", user_info.user);
77 2 : None
78 : }
79 : };
80 123 : let allowed_ips = match get_execute_postgres_query(
81 123 : &client,
82 123 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
83 123 : &[&user_info.endpoint.as_str()],
84 123 : "allowed_ips",
85 123 : )
86 246 : .await?
87 : {
88 11 : Some(s) => {
89 11 : info!("got allowed_ips: {s}");
90 11 : s.split(',')
91 18 : .map(|s| IpPattern::from_str(s).unwrap())
92 11 : .collect()
93 : }
94 112 : None => vec![],
95 : };
96 :
97 123 : Ok((secret, allowed_ips))
98 123 : }
99 123 : .map_err(crate::error::log_error::<GetAuthInfoError>)
100 123 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
101 865 : .await?;
102 123 : Ok(AuthInfo {
103 123 : secret,
104 123 : allowed_ips,
105 123 : project_id: None,
106 123 : })
107 123 : }
108 :
109 79 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
110 79 : let mut config = compute::ConnCfg::new();
111 79 : config
112 79 : .host(self.endpoint.host_str().unwrap_or("localhost"))
113 79 : .port(self.endpoint.port().unwrap_or(5432))
114 79 : .ssl_mode(SslMode::Disable);
115 79 :
116 79 : let node = NodeInfo {
117 79 : config,
118 79 : aux: Default::default(),
119 79 : allow_self_signed_compute: false,
120 79 : };
121 79 :
122 79 : Ok(node)
123 79 : }
124 : }
125 :
126 246 : async fn get_execute_postgres_query(
127 246 : client: &Client,
128 246 : query: &str,
129 246 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
130 246 : idx: &str,
131 246 : ) -> Result<Option<String>, GetAuthInfoError> {
132 492 : let rows = client.query(query, params).await?;
133 :
134 : // We can get at most one row, because `rolname` is unique.
135 246 : let row = match rows.first() {
136 132 : Some(row) => row,
137 : // This means that the user doesn't exist, so there can be no secret.
138 : // However, this is still a *valid* outcome which is very similar
139 : // to getting `404 Not found` from the Neon console.
140 114 : None => return Ok(None),
141 : };
142 :
143 132 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
144 132 : Ok(Some(entry))
145 246 : }
146 :
147 : #[async_trait]
148 : impl super::Api for Api {
149 0 : #[tracing::instrument(skip_all)]
150 : async fn get_role_secret(
151 : &self,
152 : _ctx: &mut RequestMonitoring,
153 : user_info: &ComputeUserInfo,
154 39 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
155 : Ok(CachedRoleSecret::new_uncached(
156 267 : self.do_get_auth_info(user_info).await?.secret,
157 : ))
158 78 : }
159 :
160 84 : async fn get_allowed_ips_and_secret(
161 84 : &self,
162 84 : _ctx: &mut RequestMonitoring,
163 84 : user_info: &ComputeUserInfo,
164 84 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
165 : Ok((
166 : Cached::new_uncached(Arc::new(
167 598 : self.do_get_auth_info(user_info).await?.allowed_ips,
168 : )),
169 84 : None,
170 : ))
171 168 : }
172 :
173 0 : #[tracing::instrument(skip_all)]
174 : async fn wake_compute(
175 : &self,
176 : _ctx: &mut RequestMonitoring,
177 : _user_info: &ComputeUserInfo,
178 79 : ) -> Result<CachedNodeInfo, WakeComputeError> {
179 79 : self.do_wake_compute()
180 79 : .map_ok(CachedNodeInfo::new_uncached)
181 0 : .await
182 158 : }
183 : }
184 :
185 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
186 0 : let text = input.strip_prefix("md5")?;
187 :
188 0 : let mut bytes = [0u8; 16];
189 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
190 :
191 0 : Some(bytes)
192 0 : }
|