Line data Source code
1 : //! Production console backend.
2 :
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use ::http::header::AUTHORIZATION;
7 : use ::http::HeaderName;
8 : use futures::TryFutureExt;
9 : use tokio::time::Instant;
10 : use tokio_postgres::config::SslMode;
11 : use tracing::{debug, info, info_span, warn, Instrument};
12 :
13 : use super::super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute};
14 : use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
15 : use super::{
16 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
17 : NodeInfo,
18 : };
19 : use crate::auth::backend::jwt::AuthRule;
20 : use crate::auth::backend::ComputeUserInfo;
21 : use crate::cache::Cached;
22 : use crate::context::RequestMonitoring;
23 : use crate::control_plane::errors::GetEndpointJwksError;
24 : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
25 : use crate::metrics::{CacheOutcome, Metrics};
26 : use crate::rate_limiter::WakeComputeRateLimiter;
27 : use crate::{compute, http, scram, EndpointCacheKey, EndpointId};
28 :
29 : const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
30 :
31 : #[derive(Clone)]
32 : pub struct Api {
33 : endpoint: http::Endpoint,
34 : pub caches: &'static ApiCaches,
35 : pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
36 : pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
37 : // put in a shared ref so we don't copy secrets all over in memory
38 : jwt: Arc<str>,
39 : }
40 :
41 : impl Api {
42 : /// Construct an API object containing the auth parameters.
43 0 : pub fn new(
44 0 : endpoint: http::Endpoint,
45 0 : caches: &'static ApiCaches,
46 0 : locks: &'static ApiLocks<EndpointCacheKey>,
47 0 : wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
48 0 : ) -> Self {
49 0 : let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
50 0 : .unwrap_or_default()
51 0 : .into();
52 0 : Self {
53 0 : endpoint,
54 0 : caches,
55 0 : locks,
56 0 : wake_compute_endpoint_rate_limiter,
57 0 : jwt,
58 0 : }
59 0 : }
60 :
61 0 : pub(crate) fn url(&self) -> &str {
62 0 : self.endpoint.url().as_str()
63 0 : }
64 :
65 0 : async fn do_get_auth_info(
66 0 : &self,
67 0 : ctx: &RequestMonitoring,
68 0 : user_info: &ComputeUserInfo,
69 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
70 0 : if !self
71 0 : .caches
72 0 : .endpoints_cache
73 0 : .is_valid(ctx, &user_info.endpoint.normalize())
74 0 : .await
75 : {
76 0 : info!("endpoint is not valid, skipping the request");
77 0 : return Ok(AuthInfo::default());
78 0 : }
79 0 : let request_id = ctx.session_id().to_string();
80 0 : let application_name = ctx.console_application_name();
81 0 : async {
82 0 : let request = self
83 0 : .endpoint
84 0 : .get_path("proxy_get_role_secret")
85 0 : .header(X_REQUEST_ID, &request_id)
86 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
87 0 : .query(&[("session_id", ctx.session_id())])
88 0 : .query(&[
89 0 : ("application_name", application_name.as_str()),
90 0 : ("project", user_info.endpoint.as_str()),
91 0 : ("role", user_info.user.as_str()),
92 0 : ])
93 0 : .build()?;
94 :
95 0 : info!(url = request.url().as_str(), "sending http request");
96 0 : let start = Instant::now();
97 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
98 0 : let response = self.endpoint.execute(request).await?;
99 0 : drop(pause);
100 0 : info!(duration = ?start.elapsed(), "received http response");
101 0 : let body = match parse_body::<GetRoleSecret>(response).await {
102 0 : Ok(body) => body,
103 : // Error 404 is special: it's ok not to have a secret.
104 : // TODO(anna): retry
105 0 : Err(e) => {
106 0 : return if e.get_reason().is_not_found() {
107 0 : Ok(AuthInfo::default())
108 : } else {
109 0 : Err(e.into())
110 : }
111 : }
112 : };
113 :
114 0 : let secret = if body.role_secret.is_empty() {
115 0 : None
116 : } else {
117 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
118 0 : .map(AuthSecret::Scram)
119 0 : .ok_or(GetAuthInfoError::BadSecret)?;
120 0 : Some(secret)
121 : };
122 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
123 0 : Metrics::get()
124 0 : .proxy
125 0 : .allowed_ips_number
126 0 : .observe(allowed_ips.len() as f64);
127 0 : Ok(AuthInfo {
128 0 : secret,
129 0 : allowed_ips,
130 0 : project_id: body.project_id,
131 0 : })
132 0 : }
133 0 : .map_err(crate::error::log_error)
134 0 : .instrument(info_span!("http", id = request_id))
135 0 : .await
136 0 : }
137 :
138 0 : async fn do_get_endpoint_jwks(
139 0 : &self,
140 0 : ctx: &RequestMonitoring,
141 0 : endpoint: EndpointId,
142 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
143 0 : if !self
144 0 : .caches
145 0 : .endpoints_cache
146 0 : .is_valid(ctx, &endpoint.normalize())
147 0 : .await
148 : {
149 0 : return Err(GetEndpointJwksError::EndpointNotFound);
150 0 : }
151 0 : let request_id = ctx.session_id().to_string();
152 0 : async {
153 0 : let request = self
154 0 : .endpoint
155 0 : .get_with_url(|url| {
156 0 : url.path_segments_mut()
157 0 : .push("endpoints")
158 0 : .push(endpoint.as_str())
159 0 : .push("jwks");
160 0 : })
161 0 : .header(X_REQUEST_ID, &request_id)
162 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
163 0 : .query(&[("session_id", ctx.session_id())])
164 0 : .build()
165 0 : .map_err(GetEndpointJwksError::RequestBuild)?;
166 :
167 0 : info!(url = request.url().as_str(), "sending http request");
168 0 : let start = Instant::now();
169 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
170 0 : let response = self
171 0 : .endpoint
172 0 : .execute(request)
173 0 : .await
174 0 : .map_err(GetEndpointJwksError::RequestExecute)?;
175 0 : drop(pause);
176 0 : info!(duration = ?start.elapsed(), "received http response");
177 :
178 0 : let body = parse_body::<EndpointJwksResponse>(response).await?;
179 :
180 0 : let rules = body
181 0 : .jwks
182 0 : .into_iter()
183 0 : .map(|jwks| AuthRule {
184 0 : id: jwks.id,
185 0 : jwks_url: jwks.jwks_url,
186 0 : audience: jwks.jwt_audience,
187 0 : role_names: jwks.role_names,
188 0 : })
189 0 : .collect();
190 0 :
191 0 : Ok(rules)
192 0 : }
193 0 : .map_err(crate::error::log_error)
194 0 : .instrument(info_span!("http", id = request_id))
195 0 : .await
196 0 : }
197 :
198 0 : async fn do_wake_compute(
199 0 : &self,
200 0 : ctx: &RequestMonitoring,
201 0 : user_info: &ComputeUserInfo,
202 0 : ) -> Result<NodeInfo, WakeComputeError> {
203 0 : let request_id = ctx.session_id().to_string();
204 0 : let application_name = ctx.console_application_name();
205 0 : async {
206 0 : let mut request_builder = self
207 0 : .endpoint
208 0 : .get_path("proxy_wake_compute")
209 0 : .header("X-Request-ID", &request_id)
210 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
211 0 : .query(&[("session_id", ctx.session_id())])
212 0 : .query(&[
213 0 : ("application_name", application_name.as_str()),
214 0 : ("project", user_info.endpoint.as_str()),
215 0 : ]);
216 0 :
217 0 : let options = user_info.options.to_deep_object();
218 0 : if !options.is_empty() {
219 0 : request_builder = request_builder.query(&options);
220 0 : }
221 :
222 0 : let request = request_builder.build()?;
223 :
224 0 : info!(url = request.url().as_str(), "sending http request");
225 0 : let start = Instant::now();
226 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
227 0 : let response = self.endpoint.execute(request).await?;
228 0 : drop(pause);
229 0 : info!(duration = ?start.elapsed(), "received http response");
230 0 : let body = parse_body::<WakeCompute>(response).await?;
231 :
232 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
233 0 : let (host, port) = match parse_host_port(&body.address) {
234 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
235 0 : Some(x) => x,
236 0 : };
237 0 :
238 0 : // Don't set anything but host and port! This config will be cached.
239 0 : // We'll set username and such later using the startup message.
240 0 : // TODO: add more type safety (in progress).
241 0 : let mut config = compute::ConnCfg::new();
242 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
243 0 :
244 0 : let node = NodeInfo {
245 0 : config,
246 0 : aux: body.aux,
247 0 : allow_self_signed_compute: false,
248 0 : };
249 0 :
250 0 : Ok(node)
251 0 : }
252 0 : .map_err(crate::error::log_error)
253 0 : .instrument(info_span!("http", id = request_id))
254 0 : .await
255 0 : }
256 : }
257 :
258 : impl super::Api for Api {
259 0 : #[tracing::instrument(skip_all)]
260 : async fn get_role_secret(
261 : &self,
262 : ctx: &RequestMonitoring,
263 : user_info: &ComputeUserInfo,
264 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
265 : let normalized_ep = &user_info.endpoint.normalize();
266 : let user = &user_info.user;
267 : if let Some(role_secret) = self
268 : .caches
269 : .project_info
270 : .get_role_secret(normalized_ep, user)
271 : {
272 : return Ok(role_secret);
273 : }
274 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
275 : if let Some(project_id) = auth_info.project_id {
276 : let normalized_ep_int = normalized_ep.into();
277 : self.caches.project_info.insert_role_secret(
278 : project_id,
279 : normalized_ep_int,
280 : user.into(),
281 : auth_info.secret.clone(),
282 : );
283 : self.caches.project_info.insert_allowed_ips(
284 : project_id,
285 : normalized_ep_int,
286 : Arc::new(auth_info.allowed_ips),
287 : );
288 : ctx.set_project_id(project_id);
289 : }
290 : // When we just got a secret, we don't need to invalidate it.
291 : Ok(Cached::new_uncached(auth_info.secret))
292 : }
293 :
294 0 : async fn get_allowed_ips_and_secret(
295 0 : &self,
296 0 : ctx: &RequestMonitoring,
297 0 : user_info: &ComputeUserInfo,
298 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
299 0 : let normalized_ep = &user_info.endpoint.normalize();
300 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
301 0 : Metrics::get()
302 0 : .proxy
303 0 : .allowed_ips_cache_misses
304 0 : .inc(CacheOutcome::Hit);
305 0 : return Ok((allowed_ips, None));
306 0 : }
307 0 : Metrics::get()
308 0 : .proxy
309 0 : .allowed_ips_cache_misses
310 0 : .inc(CacheOutcome::Miss);
311 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
312 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
313 0 : let user = &user_info.user;
314 0 : if let Some(project_id) = auth_info.project_id {
315 0 : let normalized_ep_int = normalized_ep.into();
316 0 : self.caches.project_info.insert_role_secret(
317 0 : project_id,
318 0 : normalized_ep_int,
319 0 : user.into(),
320 0 : auth_info.secret.clone(),
321 0 : );
322 0 : self.caches.project_info.insert_allowed_ips(
323 0 : project_id,
324 0 : normalized_ep_int,
325 0 : allowed_ips.clone(),
326 0 : );
327 0 : ctx.set_project_id(project_id);
328 0 : }
329 0 : Ok((
330 0 : Cached::new_uncached(allowed_ips),
331 0 : Some(Cached::new_uncached(auth_info.secret)),
332 0 : ))
333 0 : }
334 :
335 0 : #[tracing::instrument(skip_all)]
336 : async fn get_endpoint_jwks(
337 : &self,
338 : ctx: &RequestMonitoring,
339 : endpoint: EndpointId,
340 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
341 : self.do_get_endpoint_jwks(ctx, endpoint).await
342 : }
343 :
344 0 : #[tracing::instrument(skip_all)]
345 : async fn wake_compute(
346 : &self,
347 : ctx: &RequestMonitoring,
348 : user_info: &ComputeUserInfo,
349 : ) -> Result<CachedNodeInfo, WakeComputeError> {
350 : let key = user_info.endpoint_cache_key();
351 :
352 : macro_rules! check_cache {
353 : () => {
354 : if let Some(cached) = self.caches.node_info.get(&key) {
355 : let (cached, info) = cached.take_value();
356 0 : let info = info.map_err(|c| {
357 0 : info!(key = &*key, "found cached wake_compute error");
358 0 : WakeComputeError::ApiError(ApiError::ControlPlane(Box::new(*c)))
359 0 : })?;
360 :
361 : debug!(key = &*key, "found cached compute node info");
362 : ctx.set_project(info.aux.clone());
363 0 : return Ok(cached.map(|()| info));
364 : }
365 : };
366 : }
367 :
368 : // Every time we do a wakeup http request, the compute node will stay up
369 : // for some time (highly depends on the console's scale-to-zero policy);
370 : // The connection info remains the same during that period of time,
371 : // which means that we might cache it to reduce the load and latency.
372 : check_cache!();
373 :
374 : let permit = self.locks.get_permit(&key).await?;
375 :
376 : // after getting back a permit - it's possible the cache was filled
377 : // double check
378 : if permit.should_check_cache() {
379 : check_cache!();
380 : }
381 :
382 : // check rate limit
383 : if !self
384 : .wake_compute_endpoint_rate_limiter
385 : .check(user_info.endpoint.normalize_intern(), 1)
386 : {
387 : return Err(WakeComputeError::TooManyConnections);
388 : }
389 :
390 : let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
391 : match node {
392 : Ok(node) => {
393 : ctx.set_project(node.aux.clone());
394 : debug!(key = &*key, "created a cache entry for woken compute node");
395 :
396 : let mut stored_node = node.clone();
397 : // store the cached node as 'warm_cached'
398 : stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
399 :
400 : let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
401 :
402 0 : Ok(cached.map(|()| node))
403 : }
404 : Err(err) => match err {
405 : WakeComputeError::ApiError(ApiError::ControlPlane(err)) => {
406 : let Some(status) = &err.status else {
407 : return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
408 : };
409 :
410 : let reason = status
411 : .details
412 : .error_info
413 0 : .map_or(Reason::Unknown, |x| x.reason);
414 :
415 : // if we can retry this error, do not cache it.
416 : if reason.can_retry() {
417 : return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
418 : }
419 :
420 : // at this point, we should only have quota errors.
421 : debug!(
422 : key = &*key,
423 : "created a cache entry for the wake compute error"
424 : );
425 :
426 : self.caches.node_info.insert_ttl(
427 : key,
428 : Err(err.clone()),
429 : Duration::from_secs(30),
430 : );
431 :
432 : Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)))
433 : }
434 : err => return Err(err),
435 : },
436 : }
437 : }
438 : }
439 :
440 : /// Parse http response body, taking status code into account.
441 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
442 0 : response: http::Response,
443 0 : ) -> Result<T, ApiError> {
444 0 : let status = response.status();
445 0 : if status.is_success() {
446 : // We shouldn't log raw body because it may contain secrets.
447 0 : info!("request succeeded, processing the body");
448 0 : return Ok(response.json().await?);
449 0 : }
450 0 : let s = response.bytes().await?;
451 : // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
452 0 : info!("response_error plaintext: {:?}", s);
453 :
454 : // Don't throw an error here because it's not as important
455 : // as the fact that the request itself has failed.
456 0 : let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
457 0 : warn!("failed to parse error body: {e}");
458 0 : ControlPlaneError {
459 0 : error: "reason unclear (malformed error message)".into(),
460 0 : http_status_code: status,
461 0 : status: None,
462 0 : }
463 0 : });
464 0 : body.http_status_code = status;
465 0 :
466 0 : warn!("console responded with an error ({status}): {body:?}");
467 0 : Err(ApiError::ControlPlane(Box::new(body)))
468 0 : }
469 :
/// Split a compute address of the form `host:port` or `[ipv6]:port`.
///
/// The port is taken after the last `:`. A host wrapped in brackets has exactly
/// one matching `[`/`]` pair removed. Returns `None` in these cases:
/// - the port is missing or not a valid `u16`;
/// - an opening bracket has no closing bracket;
/// - the resulting host is empty.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port = port.parse().ok()?;
    // Strip a single matching pair of IPv6 brackets rather than trimming any number
    // of (possibly unbalanced) brackets from both ends.
    let host = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']')?,
        None => host,
    };
    if host.is_empty() {
        return None;
    }
    Some((host, port))
}
475 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }

    /// Addresses without a usable port must be rejected rather than defaulted.
    #[test]
    fn test_parse_host_port_rejects_missing_or_invalid_port() {
        assert_eq!(parse_host_port("localhost"), None);
        assert_eq!(parse_host_port("localhost:"), None);
        assert_eq!(parse_host_port("localhost:abc"), None);
        assert_eq!(parse_host_port("localhost:65536"), None);
        // A bracketed IPv6 literal with no port: the last `:` is inside the brackets.
        assert_eq!(parse_host_port("[::1]"), None);
    }
}
|