TLA Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra,
7 : NodeInfo,
8 : };
9 : use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
10 : use crate::{
11 : context::RequestMonitoring,
12 : metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
13 : };
14 : use async_trait::async_trait;
15 : use futures::TryFutureExt;
16 : use itertools::Itertools;
17 : use std::sync::Arc;
18 : use tokio::time::Instant;
19 : use tokio_postgres::config::SslMode;
20 : use tracing::{error, info, info_span, warn, Instrument};
21 :
22 UBC 0 : #[derive(Clone)]
23 : pub struct Api {
24 : endpoint: http::Endpoint,
25 : caches: &'static ApiCaches,
26 : locks: &'static ApiLocks,
27 : jwt: String,
28 : }
29 :
30 : impl Api {
31 : /// Construct an API object containing the auth parameters.
32 CBC 1 : pub fn new(
33 1 : endpoint: http::Endpoint,
34 1 : caches: &'static ApiCaches,
35 1 : locks: &'static ApiLocks,
36 1 : ) -> Self {
37 1 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
38 UBC 0 : Ok(v) => v,
39 CBC 1 : Err(_) => "".to_string(),
40 : };
41 1 : Self {
42 1 : endpoint,
43 1 : caches,
44 1 : locks,
45 1 : jwt,
46 1 : }
47 1 : }
48 :
49 1 : pub fn url(&self) -> &str {
50 1 : self.endpoint.url().as_str()
51 1 : }
52 :
53 4 : async fn do_get_auth_info(
54 4 : &self,
55 4 : ctx: &mut RequestMonitoring,
56 4 : creds: &ComputeUserInfo,
57 4 : ) -> Result<AuthInfo, GetAuthInfoError> {
58 4 : let request_id = uuid::Uuid::new_v4().to_string();
59 4 : let application_name = ctx.console_application_name();
60 4 : async {
61 4 : let request = self
62 4 : .endpoint
63 4 : .get("proxy_get_role_secret")
64 4 : .header("X-Request-ID", &request_id)
65 4 : .header("Authorization", format!("Bearer {}", &self.jwt))
66 4 : .query(&[("session_id", ctx.session_id)])
67 4 : .query(&[
68 4 : ("application_name", application_name.as_str()),
69 4 : ("project", creds.endpoint.as_str()),
70 4 : ("role", creds.inner.user.as_str()),
71 4 : ])
72 4 : .build()?;
73 :
74 4 : info!(url = request.url().as_str(), "sending http request");
75 4 : let start = Instant::now();
76 12 : let response = self.endpoint.execute(request).await?;
77 3 : info!(duration = ?start.elapsed(), "received http response");
78 3 : let body = match parse_body::<GetRoleSecret>(response).await {
79 UBC 0 : Ok(body) => body,
80 : // Error 404 is special: it's ok not to have a secret.
81 CBC 3 : Err(e) => match e.http_status_code() {
82 UBC 0 : Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
83 CBC 3 : _otherwise => return Err(e.into()),
84 : },
85 : };
86 :
87 UBC 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
88 0 : .map(AuthSecret::Scram)
89 0 : .ok_or(GetAuthInfoError::BadSecret)?;
90 0 : let allowed_ips = body
91 0 : .allowed_ips
92 0 : .into_iter()
93 0 : .flatten()
94 0 : .map(String::from)
95 0 : .collect_vec();
96 0 : ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
97 0 : Ok(AuthInfo {
98 0 : secret: Some(secret),
99 0 : allowed_ips,
100 0 : })
101 CBC 4 : }
102 4 : .map_err(crate::error::log_error)
103 4 : .instrument(info_span!("http", id = request_id))
104 12 : .await
105 4 : }
106 :
107 UBC 0 : async fn do_wake_compute(
108 0 : &self,
109 0 : ctx: &mut RequestMonitoring,
110 0 : extra: &ConsoleReqExtra,
111 0 : creds: &ComputeUserInfo,
112 0 : ) -> Result<NodeInfo, WakeComputeError> {
113 0 : let request_id = uuid::Uuid::new_v4().to_string();
114 0 : let application_name = ctx.console_application_name();
115 0 : async {
116 0 : let mut request_builder = self
117 0 : .endpoint
118 0 : .get("proxy_wake_compute")
119 0 : .header("X-Request-ID", &request_id)
120 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
121 0 : .query(&[("session_id", ctx.session_id)])
122 0 : .query(&[
123 0 : ("application_name", application_name.as_str()),
124 0 : ("project", creds.endpoint.as_str()),
125 0 : ]);
126 :
127 0 : request_builder = if extra.options.is_empty() {
128 0 : request_builder
129 : } else {
130 0 : request_builder.query(&extra.options_as_deep_object())
131 : };
132 0 : let request = request_builder.build()?;
133 :
134 0 : info!(url = request.url().as_str(), "sending http request");
135 0 : let start = Instant::now();
136 0 : let response = self.endpoint.execute(request).await?;
137 0 : info!(duration = ?start.elapsed(), "received http response");
138 0 : let body = parse_body::<WakeCompute>(response).await?;
139 :
140 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
141 0 : let (host, port) = match parse_host_port(&body.address) {
142 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
143 0 : Some(x) => x,
144 0 : };
145 0 :
146 0 : // Don't set anything but host and port! This config will be cached.
147 0 : // We'll set username and such later using the startup message.
148 0 : // TODO: add more type safety (in progress).
149 0 : let mut config = compute::ConnCfg::new();
150 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
151 0 :
152 0 : let node = NodeInfo {
153 0 : config,
154 0 : aux: body.aux,
155 0 : allow_self_signed_compute: false,
156 0 : };
157 0 :
158 0 : Ok(node)
159 0 : }
160 0 : .map_err(crate::error::log_error)
161 0 : .instrument(info_span!("http", id = request_id))
162 0 : .await
163 0 : }
164 : }
165 :
166 : #[async_trait]
167 : impl super::Api for Api {
168 0 : #[tracing::instrument(skip_all)]
169 : async fn get_role_secret(
170 : &self,
171 : ctx: &mut RequestMonitoring,
172 : creds: &ComputeUserInfo,
173 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
174 0 : let ep = creds.endpoint.clone();
175 0 : let user = creds.inner.user.clone();
176 0 : if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
177 0 : return Ok(role_secret);
178 0 : }
179 0 : let auth_info = self.do_get_auth_info(ctx, creds).await?;
180 0 : let (_, secret) = self
181 0 : .caches
182 0 : .role_secret
183 0 : .insert((ep.clone(), user), auth_info.secret.clone());
184 0 : self.caches
185 0 : .allowed_ips
186 0 : .insert(ep, Arc::new(auth_info.allowed_ips));
187 0 : Ok(secret)
188 0 : }
189 :
190 CBC 4 : async fn get_allowed_ips(
191 4 : &self,
192 4 : ctx: &mut RequestMonitoring,
193 4 : creds: &ComputeUserInfo,
194 4 : ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
195 4 : if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
196 UBC 0 : ALLOWED_IPS_BY_CACHE_OUTCOME
197 0 : .with_label_values(&["hit"])
198 0 : .inc();
199 0 : return Ok(Arc::new(allowed_ips.to_vec()));
200 CBC 4 : }
201 4 : ALLOWED_IPS_BY_CACHE_OUTCOME
202 4 : .with_label_values(&["miss"])
203 4 : .inc();
204 12 : let auth_info = self.do_get_auth_info(ctx, creds).await?;
205 UBC 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
206 0 : let ep = creds.endpoint.clone();
207 0 : let user = creds.inner.user.clone();
208 0 : self.caches
209 0 : .role_secret
210 0 : .insert((ep.clone(), user), auth_info.secret);
211 0 : self.caches.allowed_ips.insert(ep, allowed_ips.clone());
212 0 : Ok(allowed_ips)
213 CBC 8 : }
214 :
215 UBC 0 : #[tracing::instrument(skip_all)]
216 : async fn wake_compute(
217 : &self,
218 : ctx: &mut RequestMonitoring,
219 : extra: &ConsoleReqExtra,
220 : creds: &ComputeUserInfo,
221 0 : ) -> Result<CachedNodeInfo, WakeComputeError> {
222 0 : let key: &str = &creds.inner.cache_key;
223 :
224 : // Every time we do a wakeup http request, the compute node will stay up
225 : // for some time (highly depends on the console's scale-to-zero policy);
226 : // The connection info remains the same during that period of time,
227 : // which means that we might cache it to reduce the load and latency.
228 0 : if let Some(cached) = self.caches.node_info.get(key) {
229 0 : info!(key = key, "found cached compute node info");
230 0 : return Ok(cached);
231 0 : }
232 0 :
233 0 : let key: Arc<str> = key.into();
234 :
235 0 : let permit = self.locks.get_wake_compute_permit(&key).await?;
236 :
237 : // after getting back a permit - it's possible the cache was filled
238 : // double check
239 0 : if permit.should_check_cache() {
240 0 : if let Some(cached) = self.caches.node_info.get(&key) {
241 0 : info!(key = &*key, "found cached compute node info");
242 0 : return Ok(cached);
243 0 : }
244 0 : }
245 :
246 0 : let node = self.do_wake_compute(ctx, extra, creds).await?;
247 0 : let (_, cached) = self.caches.node_info.insert(key.clone(), node);
248 0 : info!(key = &*key, "created a cache entry for compute node info");
249 :
250 0 : Ok(cached)
251 0 : }
252 : }
253 :
254 : /// Parse http response body, taking status code into account.
255 CBC 3 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
256 3 : response: http::Response,
257 3 : ) -> Result<T, ApiError> {
258 3 : let status = response.status();
259 3 : if status.is_success() {
260 : // We shouldn't log raw body because it may contain secrets.
261 1 : info!("request succeeded, processing the body");
262 1 : return Ok(response.json().await?);
263 2 : }
264 :
265 : // Don't throw an error here because it's not as important
266 : // as the fact that the request itself has failed.
267 2 : let body = response.json().await.unwrap_or_else(|e| {
268 2 : warn!("failed to parse error body: {e}");
269 2 : ConsoleError {
270 2 : error: "reason unclear (malformed error message)".into(),
271 2 : }
272 2 : });
273 2 :
274 2 : let text = body.error;
275 2 : error!("console responded with an error ({status}): {text}");
276 2 : Err(ApiError::Console { status, text })
277 3 : }
278 :
/// Split a `host:port` address into its host and its numeric port.
///
/// The split happens at the last `:` so that bracketed IPv6 literals such as
/// `[2001:db8::1]:5432` work; square brackets surrounding the host are removed.
///
/// Returns `None` when there is no `:` separator, when the port is not a valid
/// `u16`, or when the host is empty once the brackets are removed (for example
/// `":5432"` or `"[]:5432"`), since an empty host can never be connected to.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;

    let ipv6_brackets: &[_] = &['[', ']'];
    let host = host.trim_matches(ipv6_brackets);
    // Reject an address without a host here rather than handing it onwards.
    if host.is_empty() {
        return None;
    }

    let port = port.parse().ok()?;
    Some((host, port))
}
284 :
#[cfg(test)]
mod tests {
    use super::*;

    /// An IPv4 address followed by a port.
    #[test]
    fn test_parse_host_port_v4() {
        let parsed = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(parsed.0, "127.0.0.1");
        assert_eq!(parsed.1, 5432);
    }

    /// A bracketed IPv6 address: the brackets must not end up in the host.
    #[test]
    fn test_parse_host_port_v6() {
        let parsed = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(parsed.0, "2001:db8::1");
        assert_eq!(parsed.1, 5432);
    }

    /// A DNS name followed by a port.
    #[test]
    fn test_parse_host_port_url() {
        let address = "compute-foo-bar-1234.default.svc.cluster.local:5432";
        let parsed = parse_host_port(address).expect("failed to parse");
        assert_eq!(parsed.0, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(parsed.1, 5432);
    }
}
|