TLA Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use std::sync::Arc;
4 :
5 : use super::{
6 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
7 : AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
8 : };
9 : use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
10 : use crate::{console::provider::CachedRoleSecret, context::RequestMonitoring};
11 : use async_trait::async_trait;
12 : use futures::TryFutureExt;
13 : use thiserror::Error;
14 : use tokio_postgres::{config::SslMode, Client};
15 : use tracing::{error, info, info_span, warn, Instrument};
16 :
17 UBC 0 : #[derive(Debug, Error)]
18 : enum MockApiError {
19 : #[error("Failed to read password: {0}")]
20 : PasswordNotSet(tokio_postgres::Error),
21 : }
22 :
23 : impl From<MockApiError> for ApiError {
24 0 : fn from(e: MockApiError) -> Self {
25 0 : io_error(e).into()
26 0 : }
27 : }
28 :
29 : impl From<tokio_postgres::Error> for ApiError {
30 0 : fn from(e: tokio_postgres::Error) -> Self {
31 0 : io_error(e).into()
32 0 : }
33 : }
34 :
35 0 : #[derive(Clone)]
36 : pub struct Api {
37 : endpoint: ApiUrl,
38 : }
39 :
40 : impl Api {
41 CBC 18 : pub fn new(endpoint: ApiUrl) -> Self {
42 18 : Self { endpoint }
43 18 : }
44 :
45 18 : pub fn url(&self) -> &str {
46 18 : self.endpoint.as_str()
47 18 : }
48 :
49 120 : async fn do_get_auth_info(
50 120 : &self,
51 120 : creds: &ComputeUserInfo,
52 120 : ) -> Result<AuthInfo, GetAuthInfoError> {
53 120 : let (secret, allowed_ips) = async {
54 : // Perhaps we could persist this connection, but then we'd have to
55 : // write more code for reopening it if it got closed, which doesn't
56 : // seem worth it.
57 120 : let (client, connection) =
58 364 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
59 :
60 120 : tokio::spawn(connection);
61 120 : let secret = match get_execute_postgres_query(
62 120 : &client,
63 120 : "select rolpassword from pg_catalog.pg_authid where rolname = $1",
64 120 : &[&&*creds.inner.user],
65 120 : "rolpassword",
66 120 : )
67 240 : .await?
68 : {
69 118 : Some(entry) => {
70 118 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
71 118 : let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
72 118 : secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
73 : }
74 : None => {
75 2 : warn!("user '{}' does not exist", creds.inner.user);
76 2 : None
77 : }
78 : };
79 120 : let allowed_ips = match get_execute_postgres_query(
80 120 : &client,
81 120 : "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
82 120 : &[&creds.endpoint.as_str()],
83 120 : "allowed_ips",
84 120 : )
85 240 : .await?
86 : {
87 11 : Some(s) => {
88 11 : info!("got allowed_ips: {s}");
89 11 : s.split(',').map(String::from).collect()
90 : }
91 109 : None => vec![],
92 : };
93 :
94 120 : Ok((secret, allowed_ips))
95 120 : }
96 120 : .map_err(crate::error::log_error::<GetAuthInfoError>)
97 120 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
98 844 : .await?;
99 120 : Ok(AuthInfo {
100 120 : secret,
101 120 : allowed_ips,
102 120 : })
103 120 : }
104 :
105 77 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
106 77 : let mut config = compute::ConnCfg::new();
107 77 : config
108 77 : .host(self.endpoint.host_str().unwrap_or("localhost"))
109 77 : .port(self.endpoint.port().unwrap_or(5432))
110 77 : .ssl_mode(SslMode::Disable);
111 77 :
112 77 : let node = NodeInfo {
113 77 : config,
114 77 : aux: Default::default(),
115 77 : allow_self_signed_compute: false,
116 77 : };
117 77 :
118 77 : Ok(node)
119 77 : }
120 : }
121 :
122 240 : async fn get_execute_postgres_query(
123 240 : client: &Client,
124 240 : query: &str,
125 240 : params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
126 240 : idx: &str,
127 240 : ) -> Result<Option<String>, GetAuthInfoError> {
128 480 : let rows = client.query(query, params).await?;
129 :
130 : // We can get at most one row, because `rolname` is unique.
131 240 : let row = match rows.first() {
132 129 : Some(row) => row,
133 : // This means that the user doesn't exist, so there can be no secret.
134 : // However, this is still a *valid* outcome which is very similar
135 : // to getting `404 Not found` from the Neon console.
136 111 : None => return Ok(None),
137 : };
138 :
139 129 : let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
140 129 : Ok(Some(entry))
141 240 : }
142 :
143 : #[async_trait]
144 : impl super::Api for Api {
145 UBC 0 : #[tracing::instrument(skip_all)]
146 : async fn get_role_secret(
147 : &self,
148 : _ctx: &mut RequestMonitoring,
149 : creds: &ComputeUserInfo,
150 CBC 38 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
151 : Ok(CachedRoleSecret::new_uncached(
152 264 : self.do_get_auth_info(creds).await?.secret,
153 : ))
154 76 : }
155 :
156 82 : async fn get_allowed_ips(
157 82 : &self,
158 82 : _ctx: &mut RequestMonitoring,
159 82 : creds: &ComputeUserInfo,
160 82 : ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
161 580 : Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
162 164 : }
163 :
164 UBC 0 : #[tracing::instrument(skip_all)]
165 : async fn wake_compute(
166 : &self,
167 : _ctx: &mut RequestMonitoring,
168 : _extra: &ConsoleReqExtra,
169 : _creds: &ComputeUserInfo,
170 CBC 77 : ) -> Result<CachedNodeInfo, WakeComputeError> {
171 77 : self.do_wake_compute()
172 77 : .map_ok(CachedNodeInfo::new_uncached)
173 UBC 0 : .await
174 CBC 154 : }
175 : }
176 :
177 UBC 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
178 0 : let text = input.strip_prefix("md5")?;
179 :
180 0 : let mut bytes = [0u8; 16];
181 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
182 :
183 0 : Some(bytes)
184 0 : }
|