Line data Source code
1 : //! Production console backend.
2 :
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use ::http::header::AUTHORIZATION;
7 : use ::http::HeaderName;
8 : use futures::TryFutureExt;
9 : use postgres_client::config::SslMode;
10 : use tokio::time::Instant;
11 : use tracing::{debug, info, info_span, warn, Instrument};
12 :
13 : use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
14 : use crate::auth::backend::jwt::AuthRule;
15 : use crate::auth::backend::ComputeUserInfo;
16 : use crate::cache::Cached;
17 : use crate::context::RequestContext;
18 : use crate::control_plane::caches::ApiCaches;
19 : use crate::control_plane::errors::{
20 : ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
21 : };
22 : use crate::control_plane::locks::ApiLocks;
23 : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
24 : use crate::control_plane::{
25 : AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
26 : };
27 : use crate::metrics::{CacheOutcome, Metrics};
28 : use crate::rate_limiter::WakeComputeRateLimiter;
29 : use crate::types::{EndpointCacheKey, EndpointId};
30 : use crate::{compute, http, scram};
31 :
32 : pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
33 :
34 : #[derive(Clone)]
35 : pub struct NeonControlPlaneClient {
36 : endpoint: http::Endpoint,
37 : pub caches: &'static ApiCaches,
38 : pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
39 : pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
40 : // put in a shared ref so we don't copy secrets all over in memory
41 : jwt: Arc<str>,
42 : }
43 :
44 : impl NeonControlPlaneClient {
45 : /// Construct an API object containing the auth parameters.
46 0 : pub fn new(
47 0 : endpoint: http::Endpoint,
48 0 : jwt: Arc<str>,
49 0 : caches: &'static ApiCaches,
50 0 : locks: &'static ApiLocks<EndpointCacheKey>,
51 0 : wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
52 0 : ) -> Self {
53 0 : Self {
54 0 : endpoint,
55 0 : caches,
56 0 : locks,
57 0 : wake_compute_endpoint_rate_limiter,
58 0 : jwt,
59 0 : }
60 0 : }
61 :
62 0 : pub(crate) fn url(&self) -> &str {
63 0 : self.endpoint.url().as_str()
64 0 : }
65 :
66 0 : async fn do_get_auth_info(
67 0 : &self,
68 0 : ctx: &RequestContext,
69 0 : user_info: &ComputeUserInfo,
70 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
71 0 : if !self
72 0 : .caches
73 0 : .endpoints_cache
74 0 : .is_valid(ctx, &user_info.endpoint.normalize())
75 : {
76 : // TODO: refactor this because it's weird
77 : // this is a failure to authenticate but we return Ok.
78 0 : info!("endpoint is not valid, skipping the request");
79 0 : return Ok(AuthInfo::default());
80 0 : }
81 0 : self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
82 0 : .await
83 0 : }
84 :
85 0 : async fn do_get_auth_req(
86 0 : &self,
87 0 : user_info: &ComputeUserInfo,
88 0 : session_id: &uuid::Uuid,
89 0 : ctx: Option<&RequestContext>,
90 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
91 0 : let request_id: String = session_id.to_string();
92 0 : let application_name = if let Some(ctx) = ctx {
93 0 : ctx.console_application_name()
94 : } else {
95 0 : "auth_cancellation".to_string()
96 : };
97 :
98 0 : async {
99 0 : let request = self
100 0 : .endpoint
101 0 : .get_path("get_endpoint_access_control")
102 0 : .header(X_REQUEST_ID, &request_id)
103 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
104 0 : .query(&[("session_id", session_id)])
105 0 : .query(&[
106 0 : ("application_name", application_name.as_str()),
107 0 : ("endpointish", user_info.endpoint.as_str()),
108 0 : ("role", user_info.user.as_str()),
109 0 : ])
110 0 : .build()?;
111 :
112 0 : debug!(url = request.url().as_str(), "sending http request");
113 0 : let start = Instant::now();
114 0 : let response = match ctx {
115 0 : Some(ctx) => {
116 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
117 0 : let rsp = self.endpoint.execute(request).await;
118 0 : drop(pause);
119 0 : rsp?
120 : }
121 0 : None => self.endpoint.execute(request).await?,
122 : };
123 :
124 0 : info!(duration = ?start.elapsed(), "received http response");
125 0 : let body = match parse_body::<GetEndpointAccessControl>(response).await {
126 0 : Ok(body) => body,
127 : // Error 404 is special: it's ok not to have a secret.
128 : // TODO(anna): retry
129 0 : Err(e) => {
130 0 : return if e.get_reason().is_not_found() {
131 : // TODO: refactor this because it's weird
132 : // this is a failure to authenticate but we return Ok.
133 0 : Ok(AuthInfo::default())
134 : } else {
135 0 : Err(e.into())
136 : };
137 : }
138 : };
139 :
140 : // Ivan: don't know where it will be used, so I leave it here
141 0 : let _endpoint_vpc_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
142 :
143 0 : let secret = if body.role_secret.is_empty() {
144 0 : None
145 : } else {
146 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
147 0 : .map(AuthSecret::Scram)
148 0 : .ok_or(GetAuthInfoError::BadSecret)?;
149 0 : Some(secret)
150 : };
151 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
152 0 : Metrics::get()
153 0 : .proxy
154 0 : .allowed_ips_number
155 0 : .observe(allowed_ips.len() as f64);
156 0 : Ok(AuthInfo {
157 0 : secret,
158 0 : allowed_ips,
159 0 : project_id: body.project_id,
160 0 : })
161 0 : }
162 0 : .inspect_err(|e| tracing::debug!(error = ?e))
163 0 : .instrument(info_span!("do_get_auth_info"))
164 0 : .await
165 0 : }
166 :
167 0 : async fn do_get_endpoint_jwks(
168 0 : &self,
169 0 : ctx: &RequestContext,
170 0 : endpoint: EndpointId,
171 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
172 0 : if !self
173 0 : .caches
174 0 : .endpoints_cache
175 0 : .is_valid(ctx, &endpoint.normalize())
176 : {
177 0 : return Err(GetEndpointJwksError::EndpointNotFound);
178 0 : }
179 0 : let request_id = ctx.session_id().to_string();
180 0 : async {
181 0 : let request = self
182 0 : .endpoint
183 0 : .get_with_url(|url| {
184 0 : url.path_segments_mut()
185 0 : .push("endpoints")
186 0 : .push(endpoint.as_str())
187 0 : .push("jwks");
188 0 : })
189 0 : .header(X_REQUEST_ID, &request_id)
190 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
191 0 : .query(&[("session_id", ctx.session_id())])
192 0 : .build()
193 0 : .map_err(GetEndpointJwksError::RequestBuild)?;
194 :
195 0 : debug!(url = request.url().as_str(), "sending http request");
196 0 : let start = Instant::now();
197 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
198 0 : let response = self
199 0 : .endpoint
200 0 : .execute(request)
201 0 : .await
202 0 : .map_err(GetEndpointJwksError::RequestExecute)?;
203 0 : drop(pause);
204 0 : info!(duration = ?start.elapsed(), "received http response");
205 :
206 0 : let body = parse_body::<EndpointJwksResponse>(response).await?;
207 :
208 0 : let rules = body
209 0 : .jwks
210 0 : .into_iter()
211 0 : .map(|jwks| AuthRule {
212 0 : id: jwks.id,
213 0 : jwks_url: jwks.jwks_url,
214 0 : audience: jwks.jwt_audience,
215 0 : role_names: jwks.role_names,
216 0 : })
217 0 : .collect();
218 0 :
219 0 : Ok(rules)
220 0 : }
221 0 : .inspect_err(|e| tracing::debug!(error = ?e))
222 0 : .instrument(info_span!("do_get_endpoint_jwks"))
223 0 : .await
224 0 : }
225 :
226 0 : async fn do_wake_compute(
227 0 : &self,
228 0 : ctx: &RequestContext,
229 0 : user_info: &ComputeUserInfo,
230 0 : ) -> Result<NodeInfo, WakeComputeError> {
231 0 : let request_id = ctx.session_id().to_string();
232 0 : let application_name = ctx.console_application_name();
233 0 : async {
234 0 : let mut request_builder = self
235 0 : .endpoint
236 0 : .get_path("wake_compute")
237 0 : .header("X-Request-ID", &request_id)
238 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
239 0 : .query(&[("session_id", ctx.session_id())])
240 0 : .query(&[
241 0 : ("application_name", application_name.as_str()),
242 0 : ("endpointish", user_info.endpoint.as_str()),
243 0 : ]);
244 0 :
245 0 : let options = user_info.options.to_deep_object();
246 0 : if !options.is_empty() {
247 0 : request_builder = request_builder.query(&options);
248 0 : }
249 :
250 0 : let request = request_builder.build()?;
251 :
252 0 : debug!(url = request.url().as_str(), "sending http request");
253 0 : let start = Instant::now();
254 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
255 0 : let response = self.endpoint.execute(request).await?;
256 0 : drop(pause);
257 0 : info!(duration = ?start.elapsed(), "received http response");
258 0 : let body = parse_body::<WakeCompute>(response).await?;
259 :
260 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
261 0 : let (host, port) = match parse_host_port(&body.address) {
262 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
263 0 : Some(x) => x,
264 0 : };
265 0 :
266 0 : // Don't set anything but host and port! This config will be cached.
267 0 : // We'll set username and such later using the startup message.
268 0 : // TODO: add more type safety (in progress).
269 0 : let mut config = compute::ConnCfg::new(host.to_owned(), port);
270 0 : config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
271 0 :
272 0 : let node = NodeInfo {
273 0 : config,
274 0 : aux: body.aux,
275 0 : };
276 0 :
277 0 : Ok(node)
278 0 : }
279 0 : .inspect_err(|e| tracing::debug!(error = ?e))
280 0 : .instrument(info_span!("do_wake_compute"))
281 0 : .await
282 0 : }
283 : }
284 :
285 : impl super::ControlPlaneApi for NeonControlPlaneClient {
286 : #[tracing::instrument(skip_all)]
287 : async fn get_role_secret(
288 : &self,
289 : ctx: &RequestContext,
290 : user_info: &ComputeUserInfo,
291 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
292 : let normalized_ep = &user_info.endpoint.normalize();
293 : let user = &user_info.user;
294 : if let Some(role_secret) = self
295 : .caches
296 : .project_info
297 : .get_role_secret(normalized_ep, user)
298 : {
299 : return Ok(role_secret);
300 : }
301 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
302 : if let Some(project_id) = auth_info.project_id {
303 : let normalized_ep_int = normalized_ep.into();
304 : self.caches.project_info.insert_role_secret(
305 : project_id,
306 : normalized_ep_int,
307 : user.into(),
308 : auth_info.secret.clone(),
309 : );
310 : self.caches.project_info.insert_allowed_ips(
311 : project_id,
312 : normalized_ep_int,
313 : Arc::new(auth_info.allowed_ips),
314 : );
315 : ctx.set_project_id(project_id);
316 : }
317 : // When we just got a secret, we don't need to invalidate it.
318 : Ok(Cached::new_uncached(auth_info.secret))
319 : }
320 :
321 0 : async fn get_allowed_ips_and_secret(
322 0 : &self,
323 0 : ctx: &RequestContext,
324 0 : user_info: &ComputeUserInfo,
325 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
326 0 : let normalized_ep = &user_info.endpoint.normalize();
327 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
328 0 : Metrics::get()
329 0 : .proxy
330 0 : .allowed_ips_cache_misses
331 0 : .inc(CacheOutcome::Hit);
332 0 : return Ok((allowed_ips, None));
333 0 : }
334 0 : Metrics::get()
335 0 : .proxy
336 0 : .allowed_ips_cache_misses
337 0 : .inc(CacheOutcome::Miss);
338 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
339 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
340 0 : let user = &user_info.user;
341 0 : if let Some(project_id) = auth_info.project_id {
342 0 : let normalized_ep_int = normalized_ep.into();
343 0 : self.caches.project_info.insert_role_secret(
344 0 : project_id,
345 0 : normalized_ep_int,
346 0 : user.into(),
347 0 : auth_info.secret.clone(),
348 0 : );
349 0 : self.caches.project_info.insert_allowed_ips(
350 0 : project_id,
351 0 : normalized_ep_int,
352 0 : allowed_ips.clone(),
353 0 : );
354 0 : ctx.set_project_id(project_id);
355 0 : }
356 0 : Ok((
357 0 : Cached::new_uncached(allowed_ips),
358 0 : Some(Cached::new_uncached(auth_info.secret)),
359 0 : ))
360 0 : }
361 :
362 : #[tracing::instrument(skip_all)]
363 : async fn get_endpoint_jwks(
364 : &self,
365 : ctx: &RequestContext,
366 : endpoint: EndpointId,
367 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
368 : self.do_get_endpoint_jwks(ctx, endpoint).await
369 : }
370 :
371 : #[tracing::instrument(skip_all)]
372 : async fn wake_compute(
373 : &self,
374 : ctx: &RequestContext,
375 : user_info: &ComputeUserInfo,
376 : ) -> Result<CachedNodeInfo, WakeComputeError> {
377 : let key = user_info.endpoint_cache_key();
378 :
379 : macro_rules! check_cache {
380 : () => {
381 : if let Some(cached) = self.caches.node_info.get(&key) {
382 : let (cached, info) = cached.take_value();
383 0 : let info = info.map_err(|c| {
384 0 : info!(key = &*key, "found cached wake_compute error");
385 0 : WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
386 0 : })?;
387 :
388 : debug!(key = &*key, "found cached compute node info");
389 : ctx.set_project(info.aux.clone());
390 0 : return Ok(cached.map(|()| info));
391 : }
392 : };
393 : }
394 :
395 : // Every time we do a wakeup http request, the compute node will stay up
396 : // for some time (highly depends on the console's scale-to-zero policy);
397 : // The connection info remains the same during that period of time,
398 : // which means that we might cache it to reduce the load and latency.
399 : check_cache!();
400 :
401 : let permit = self.locks.get_permit(&key).await?;
402 :
403 : // after getting back a permit - it's possible the cache was filled
404 : // double check
405 : if permit.should_check_cache() {
406 : // TODO: if there is something in the cache, mark the permit as success.
407 : check_cache!();
408 : }
409 :
410 : // check rate limit
411 : if !self
412 : .wake_compute_endpoint_rate_limiter
413 : .check(user_info.endpoint.normalize_intern(), 1)
414 : {
415 : return Err(WakeComputeError::TooManyConnections);
416 : }
417 :
418 : let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
419 : match node {
420 : Ok(node) => {
421 : ctx.set_project(node.aux.clone());
422 : debug!(key = &*key, "created a cache entry for woken compute node");
423 :
424 : let mut stored_node = node.clone();
425 : // store the cached node as 'warm_cached'
426 : stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
427 :
428 : let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
429 :
430 0 : Ok(cached.map(|()| node))
431 : }
432 : Err(err) => match err {
433 : WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
434 : let Some(status) = &err.status else {
435 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
436 : err,
437 : )));
438 : };
439 :
440 : let reason = status
441 : .details
442 : .error_info
443 0 : .map_or(Reason::Unknown, |x| x.reason);
444 :
445 : // if we can retry this error, do not cache it.
446 : if reason.can_retry() {
447 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
448 : err,
449 : )));
450 : }
451 :
452 : // at this point, we should only have quota errors.
453 : debug!(
454 : key = &*key,
455 : "created a cache entry for the wake compute error"
456 : );
457 :
458 : self.caches.node_info.insert_ttl(
459 : key,
460 : Err(err.clone()),
461 : Duration::from_secs(30),
462 : );
463 :
464 : Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
465 : err,
466 : )))
467 : }
468 : err => return Err(err),
469 : },
470 : }
471 : }
472 : }
473 :
474 : /// Parse http response body, taking status code into account.
475 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
476 0 : response: http::Response,
477 0 : ) -> Result<T, ControlPlaneError> {
478 0 : let status = response.status();
479 0 : if status.is_success() {
480 : // We shouldn't log raw body because it may contain secrets.
481 0 : info!("request succeeded, processing the body");
482 0 : return Ok(response.json().await?);
483 0 : }
484 0 : let s = response.bytes().await?;
485 : // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
486 0 : info!("response_error plaintext: {:?}", s);
487 :
488 : // Don't throw an error here because it's not as important
489 : // as the fact that the request itself has failed.
490 0 : let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
491 0 : warn!("failed to parse error body: {e}");
492 0 : ControlPlaneErrorMessage {
493 0 : error: "reason unclear (malformed error message)".into(),
494 0 : http_status_code: status,
495 0 : status: None,
496 0 : }
497 0 : });
498 0 : body.http_status_code = status;
499 0 :
500 0 : warn!("console responded with an error ({status}): {body:?}");
501 0 : Err(ControlPlaneError::Message(Box::new(body)))
502 0 : }
503 :
/// Split a `host:port` address into its host and numeric port.
///
/// The split happens at the last `:`. An IPv6 literal must be wrapped in one
/// matching pair of square brackets (`[2001:db8::1]:5432`); the brackets are
/// not part of the returned host.
///
/// Returns `None` when there is no `:` separator, the port is not a valid
/// `u16`, the host is empty, the brackets are unbalanced or nested, or the
/// host contains `:` without being bracketed (an unbracketed IPv6 literal
/// would otherwise be split in the middle of the address).
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port = port.parse().ok()?;

    let host = match host.strip_prefix('[') {
        // Bracketed form: the closing bracket is mandatory.
        Some(rest) => rest.strip_suffix(']')?,
        // A bare host with a colon is ambiguous (e.g. `2001:db8::1`).
        None if host.contains(':') => return None,
        None => host,
    };

    let brackets: &[char] = &['[', ']'];
    if host.is_empty() || host.contains(brackets) {
        return None;
    }

    Some((host, port))
}
509 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input` and compare both parts with the expected host and port.
    fn assert_parses_to(input: &str, expected_host: &str, expected_port: u16) {
        let (host, port) = parse_host_port(input).expect("failed to parse");
        assert_eq!(host, expected_host);
        assert_eq!(port, expected_port);
    }

    #[test]
    fn test_parse_host_port_v4() {
        assert_parses_to("127.0.0.1:5432", "127.0.0.1", 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        assert_parses_to("[2001:db8::1]:5432", "2001:db8::1", 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        assert_parses_to(
            "compute-foo-bar-1234.default.svc.cluster.local:5432",
            "compute-foo-bar-1234.default.svc.cluster.local",
            5432,
        );
    }
}
|