TLA Line data Source code
1 : //! Mock console backend which relies on a user-provided postgres instance.
2 :
3 : use super::{
4 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
5 : AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
6 : };
7 : use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
8 : use async_trait::async_trait;
9 : use futures::TryFutureExt;
10 : use thiserror::Error;
11 : use tokio_postgres::config::SslMode;
12 : use tracing::{error, info, info_span, warn, Instrument};
13 :
14 UBC 0 : #[derive(Debug, Error)]
15 : enum MockApiError {
16 : #[error("Failed to read password: {0}")]
17 : PasswordNotSet(tokio_postgres::Error),
18 : }
19 :
20 : impl From<MockApiError> for ApiError {
21 0 : fn from(e: MockApiError) -> Self {
22 0 : io_error(e).into()
23 0 : }
24 : }
25 :
26 : impl From<tokio_postgres::Error> for ApiError {
27 0 : fn from(e: tokio_postgres::Error) -> Self {
28 0 : io_error(e).into()
29 0 : }
30 : }
31 :
32 0 : #[derive(Clone)]
33 : pub struct Api {
34 : endpoint: ApiUrl,
35 : }
36 :
37 : impl Api {
38 CBC 13 : pub fn new(endpoint: ApiUrl) -> Self {
39 13 : Self { endpoint }
40 13 : }
41 :
42 13 : pub fn url(&self) -> &str {
43 13 : self.endpoint.as_str()
44 13 : }
45 :
46 27 : async fn do_get_auth_info(
47 27 : &self,
48 27 : creds: &ClientCredentials<'_>,
49 27 : ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
50 27 : async {
51 : // Perhaps we could persist this connection, but then we'd have to
52 : // write more code for reopening it if it got closed, which doesn't
53 : // seem worth it.
54 27 : let (client, connection) =
55 81 : tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
56 :
57 27 : tokio::spawn(connection);
58 27 : let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
59 54 : let rows = client.query(query, &[&creds.user]).await?;
60 :
61 : // We can get at most one row, because `rolname` is unique.
62 27 : let row = match rows.get(0) {
63 26 : Some(row) => row,
64 : // This means that the user doesn't exist, so there can be no secret.
65 : // However, this is still a *valid* outcome which is very similar
66 : // to getting `404 Not found` from the Neon console.
67 : None => {
68 1 : warn!("user '{}' does not exist", creds.user);
69 1 : return Ok(None);
70 : }
71 : };
72 :
73 26 : let entry = row
74 26 : .try_get("rolpassword")
75 26 : .map_err(MockApiError::PasswordNotSet)?;
76 :
77 26 : info!("got a secret: {entry}"); // safe since it's not a prod scenario
78 26 : let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
79 26 : Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
80 27 : }
81 27 : .map_err(crate::error::log_error)
82 27 : .instrument(info_span!("postgres", url = self.endpoint.as_str()))
83 135 : .await
84 27 : }
85 :
86 55 : async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
87 55 : let mut config = compute::ConnCfg::new();
88 55 : config
89 55 : .host(self.endpoint.host_str().unwrap_or("localhost"))
90 55 : .port(self.endpoint.port().unwrap_or(5432))
91 55 : .ssl_mode(SslMode::Disable);
92 55 :
93 55 : let node = NodeInfo {
94 55 : config,
95 55 : aux: Default::default(),
96 55 : allow_self_signed_compute: false,
97 55 : };
98 55 :
99 55 : Ok(node)
100 55 : }
101 : }
102 :
103 : #[async_trait]
104 : impl super::Api for Api {
105 81 : #[tracing::instrument(skip_all)]
106 : async fn get_auth_info(
107 : &self,
108 : _extra: &ConsoleReqExtra<'_>,
109 : creds: &ClientCredentials,
110 27 : ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
111 135 : self.do_get_auth_info(creds).await
112 54 : }
113 :
114 165 : #[tracing::instrument(skip_all)]
115 : async fn wake_compute(
116 : &self,
117 : _extra: &ConsoleReqExtra<'_>,
118 : _creds: &ClientCredentials,
119 55 : ) -> Result<CachedNodeInfo, WakeComputeError> {
120 55 : self.do_wake_compute()
121 55 : .map_ok(CachedNodeInfo::new_uncached)
122 UBC 0 : .await
123 CBC 110 : }
124 : }
125 :
126 UBC 0 : fn parse_md5(input: &str) -> Option<[u8; 16]> {
127 0 : let text = input.strip_prefix("md5")?;
128 :
129 0 : let mut bytes = [0u8; 16];
130 0 : hex::decode_to_slice(text, &mut bytes).ok()?;
131 :
132 0 : Some(bytes)
133 0 : }
|