TLA Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
7 : };
8 : use crate::{auth::ClientCredentials, compute, http, scram};
9 : use async_trait::async_trait;
10 : use futures::TryFutureExt;
11 : use std::net::SocketAddr;
12 : use tokio::time::Instant;
13 : use tokio_postgres::config::SslMode;
14 : use tracing::{error, info, info_span, warn, Instrument};
15 :
/// Production console backend: talks to the console over HTTP
/// (see `proxy_get_role_secret` / `proxy_wake_compute` requests below).
#[derive(Clone)]
pub struct Api {
    /// Endpoint all console requests are built from.
    endpoint: http::Endpoint,
    /// Shared caches (`'static`, so presumably process-wide); used here to cache
    /// compute node info keyed by project.
    caches: &'static ApiCaches,
    /// Sent verbatim as the `Authorization` header of every console request.
    /// NOTE(review): this is a credential — if `Debug` is ever derived, redact it.
    jwt: String,
}
22 :
23 : impl Api {
24 : /// Construct an API object containing the auth parameters.
25 0 : pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
26 0 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
27 0 : Ok(v) => v,
28 0 : Err(_) => "".to_string(),
29 : };
30 0 : Self {
31 0 : endpoint,
32 0 : caches,
33 0 : jwt,
34 0 : }
35 0 : }
36 :
37 0 : pub fn url(&self) -> &str {
38 0 : self.endpoint.url().as_str()
39 0 : }
40 :
41 0 : async fn do_get_auth_info(
42 0 : &self,
43 0 : extra: &ConsoleReqExtra<'_>,
44 0 : creds: &ClientCredentials<'_>,
45 0 : ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
46 0 : let request_id = uuid::Uuid::new_v4().to_string();
47 0 : async {
48 0 : let request = self
49 0 : .endpoint
50 0 : .get("proxy_get_role_secret")
51 0 : .header("X-Request-ID", &request_id)
52 0 : .header("Authorization", &self.jwt)
53 0 : .query(&[("session_id", extra.session_id)])
54 0 : .query(&[
55 0 : ("application_name", extra.application_name),
56 0 : ("project", Some(creds.project().expect("impossible"))),
57 0 : ("role", Some(creds.user)),
58 0 : ])
59 0 : .build()?;
60 :
61 0 : info!(url = request.url().as_str(), "sending http request");
62 0 : let start = Instant::now();
63 0 : let response = self.endpoint.execute(request).await?;
64 0 : info!(duration = ?start.elapsed(), "received http response");
65 0 : let body = match parse_body::<GetRoleSecret>(response).await {
66 0 : Ok(body) => body,
67 : // Error 404 is special: it's ok not to have a secret.
68 0 : Err(e) => match e.http_status_code() {
69 0 : Some(http::StatusCode::NOT_FOUND) => return Ok(None),
70 0 : _otherwise => return Err(e.into()),
71 : },
72 : };
73 :
74 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
75 0 : .map(AuthInfo::Scram)
76 0 : .ok_or(GetAuthInfoError::BadSecret)?;
77 :
78 0 : Ok(Some(secret))
79 0 : }
80 0 : .map_err(crate::error::log_error)
81 0 : .instrument(info_span!("http", id = request_id))
82 0 : .await
83 0 : }
84 :
85 0 : async fn do_wake_compute(
86 0 : &self,
87 0 : extra: &ConsoleReqExtra<'_>,
88 0 : creds: &ClientCredentials<'_>,
89 0 : ) -> Result<NodeInfo, WakeComputeError> {
90 0 : let project = creds.project().expect("impossible");
91 0 : let request_id = uuid::Uuid::new_v4().to_string();
92 0 : async {
93 0 : let request = self
94 0 : .endpoint
95 0 : .get("proxy_wake_compute")
96 0 : .header("X-Request-ID", &request_id)
97 0 : .header("Authorization", &self.jwt)
98 0 : .query(&[("session_id", extra.session_id)])
99 0 : .query(&[
100 0 : ("application_name", extra.application_name),
101 0 : ("project", Some(project)),
102 0 : ])
103 0 : .build()?;
104 :
105 0 : info!(url = request.url().as_str(), "sending http request");
106 0 : let start = Instant::now();
107 0 : let response = self.endpoint.execute(request).await?;
108 0 : info!(duration = ?start.elapsed(), "received http response");
109 0 : let body = parse_body::<WakeCompute>(response).await?;
110 :
111 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
112 0 : let (host, port) = match parse_host_port(&body.address) {
113 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
114 0 : Some(x) => x,
115 0 : };
116 0 :
117 0 : // Don't set anything but host and port! This config will be cached.
118 0 : // We'll set username and such later using the startup message.
119 0 : // TODO: add more type safety (in progress).
120 0 : let mut config = compute::ConnCfg::new();
121 0 : config.host(&host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
122 0 :
123 0 : let node = NodeInfo {
124 0 : config,
125 0 : aux: body.aux.into(),
126 0 : allow_self_signed_compute: false,
127 0 : };
128 0 :
129 0 : Ok(node)
130 0 : }
131 0 : .map_err(crate::error::log_error)
132 0 : .instrument(info_span!("http", id = request_id))
133 0 : .await
134 0 : }
135 : }
136 :
137 : #[async_trait]
138 : impl super::Api for Api {
139 0 : #[tracing::instrument(skip_all)]
140 : async fn get_auth_info(
141 : &self,
142 : extra: &ConsoleReqExtra<'_>,
143 : creds: &ClientCredentials,
144 0 : ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
145 0 : self.do_get_auth_info(extra, creds).await
146 0 : }
147 :
148 0 : #[tracing::instrument(skip_all)]
149 : async fn wake_compute(
150 : &self,
151 : extra: &ConsoleReqExtra<'_>,
152 : creds: &ClientCredentials,
153 0 : ) -> Result<CachedNodeInfo, WakeComputeError> {
154 0 : let key = creds.project().expect("impossible");
155 :
156 : // Every time we do a wakeup http request, the compute node will stay up
157 : // for some time (highly depends on the console's scale-to-zero policy);
158 : // The connection info remains the same during that period of time,
159 : // which means that we might cache it to reduce the load and latency.
160 0 : if let Some(cached) = self.caches.node_info.get(key) {
161 0 : info!(key = key, "found cached compute node info");
162 0 : return Ok(cached);
163 0 : }
164 :
165 0 : let node = self.do_wake_compute(extra, creds).await?;
166 0 : let (_, cached) = self.caches.node_info.insert(key.into(), node);
167 0 : info!(key = key, "created a cache entry for compute node info");
168 :
169 0 : Ok(cached)
170 0 : }
171 : }
172 :
173 : /// Parse http response body, taking status code into account.
174 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
175 0 : response: http::Response,
176 0 : ) -> Result<T, ApiError> {
177 0 : let status = response.status();
178 0 : if status.is_success() {
179 : // We shouldn't log raw body because it may contain secrets.
180 0 : info!("request succeeded, processing the body");
181 0 : return Ok(response.json().await?);
182 0 : }
183 :
184 : // Don't throw an error here because it's not as important
185 : // as the fact that the request itself has failed.
186 0 : let body = response.json().await.unwrap_or_else(|e| {
187 0 : warn!("failed to parse error body: {e}");
188 0 : ConsoleError {
189 0 : error: "reason unclear (malformed error message)".into(),
190 0 : }
191 0 : });
192 0 :
193 0 : let text = body.error;
194 0 : error!("console responded with an error ({status}): {text}");
195 0 : Err(ApiError::Console { status, text })
196 0 : }
197 :
/// Split a compute address of the form `host:port` into its host and port.
///
/// Accepts:
/// * IPv4 literals: `127.0.0.1:5432`;
/// * bracketed IPv6 literals: `[::1]:5432` (the host is returned without brackets);
/// * DNS host names: `compute.svc.cluster.local:5432`.
///
/// Returns `None` for anything else, e.g. a missing, non-numeric or out-of-range
/// port, an empty host, or an unbracketed IPv6 address (which is ambiguous).
fn parse_host_port(input: &str) -> Option<(String, u16)> {
    // IP literals keep exactly the normalisation `SocketAddr` applies.
    if let Ok(addr) = input.parse::<SocketAddr>() {
        return Some((addr.ip().to_string(), addr.port()));
    }

    // Otherwise treat the input as `hostname:port`.
    let (host, port) = input.rsplit_once(':')?;

    // Restricting the host to DNS-style characters rejects unbracketed IPv6
    // such as `2001:db8::1:5432`, where the port boundary is ambiguous.
    let is_hostname_char = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_');
    if host.is_empty() || !host.chars().all(is_hostname_char) {
        return None;
    }

    // `u16::from_str` would also accept a leading `+`; require plain digits.
    if port.is_empty() || !port.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }

    Some((host.to_owned(), port.parse().ok()?))
}
202 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain IPv4 `ip:port` address splits into its parts.
    #[test]
    fn test_parse_host_port() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    /// A bracketed IPv6 address is returned without its brackets.
    #[test]
    fn test_parse_host_port_ipv6() {
        let (host, port) = parse_host_port("[::1]:5432").expect("failed to parse");
        assert_eq!(host, "::1");
        assert_eq!(port, 5432);
    }

    /// Inputs without a usable port are rejected rather than guessed at.
    #[test]
    fn test_parse_host_port_rejects_malformed() {
        assert!(parse_host_port("").is_none(), "empty input");
        assert!(parse_host_port("127.0.0.1").is_none(), "missing port");
        assert!(parse_host_port("127.0.0.1:port").is_none(), "non-numeric port");
        assert!(parse_host_port("127.0.0.1:65536").is_none(), "port out of range");
    }
}
|