LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - neon.rs (source / functions) Coverage Total Hit
Test: 8ff8efadb0253cf618c612650348666c0c564111.info Lines: 7.5 % 281 21
Test Date: 2024-11-20 17:53:50 Functions: 10.8 % 37 4

            Line data    Source code
       1              : //! Production console backend.
       2              : 
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use ::http::header::AUTHORIZATION;
       7              : use ::http::HeaderName;
       8              : use futures::TryFutureExt;
       9              : use tokio::time::Instant;
      10              : use tokio_postgres::config::SslMode;
      11              : use tracing::{debug, info, info_span, warn, Instrument};
      12              : 
      13              : use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute};
      14              : use crate::auth::backend::jwt::AuthRule;
      15              : use crate::auth::backend::ComputeUserInfo;
      16              : use crate::cache::Cached;
      17              : use crate::context::RequestContext;
      18              : use crate::control_plane::caches::ApiCaches;
      19              : use crate::control_plane::errors::{
      20              :     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
      21              : };
      22              : use crate::control_plane::locks::ApiLocks;
      23              : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
      24              : use crate::control_plane::{
      25              :     AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
      26              : };
      27              : use crate::metrics::{CacheOutcome, Metrics};
      28              : use crate::rate_limiter::WakeComputeRateLimiter;
      29              : use crate::types::{EndpointCacheKey, EndpointId};
      30              : use crate::{compute, http, scram};
      31              : 
/// Name of the HTTP header that carries the request id (the session id
/// rendered as a string) on requests sent to the control plane.
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
      33              : 
/// Client for the production control plane ("console") HTTP API.
///
/// Bundles the HTTP endpoint with the shared caches, the per-endpoint locks
/// and the wake-compute rate limiter consulted by the request methods.
#[derive(Clone)]
pub struct NeonControlPlaneClient {
    /// Endpoint that requests are built against and executed on.
    endpoint: http::Endpoint,
    /// Shared caches; the methods in this file use `endpoints_cache`,
    /// `project_info` and `node_info`.
    pub caches: &'static ApiCaches,
    /// Locks from which `wake_compute` obtains a permit keyed by the endpoint cache key.
    pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
    /// Rate limiter checked per endpoint before a wake-compute request is sent.
    pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    /// Token sent as `Authorization: Bearer <jwt>` with every request.
    // put in a shared ref so we don't copy secrets all over in memory
    jwt: Arc<str>,
}
      43              : 
      44              : impl NeonControlPlaneClient {
      45              :     /// Construct an API object containing the auth parameters.
      46            0 :     pub fn new(
      47            0 :         endpoint: http::Endpoint,
      48            0 :         jwt: Arc<str>,
      49            0 :         caches: &'static ApiCaches,
      50            0 :         locks: &'static ApiLocks<EndpointCacheKey>,
      51            0 :         wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
      52            0 :     ) -> Self {
      53            0 :         Self {
      54            0 :             endpoint,
      55            0 :             caches,
      56            0 :             locks,
      57            0 :             wake_compute_endpoint_rate_limiter,
      58            0 :             jwt,
      59            0 :         }
      60            0 :     }
      61              : 
      62            0 :     pub(crate) fn url(&self) -> &str {
      63            0 :         self.endpoint.url().as_str()
      64            0 :     }
      65              : 
    /// Fetch the role secret, allowed IPs and project id for `user_info` by
    /// calling the control plane's `proxy_get_role_secret` path.
    ///
    /// Returns `Ok(AuthInfo::default())` without sending a request when the
    /// endpoints cache reports the endpoint as not valid, and also when the
    /// control plane answers with a "not found" reason.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be built or executed, if the control plane
    /// returns any error other than "not found", or with
    /// `GetAuthInfoError::BadSecret` if a non-empty secret cannot be parsed.
    /// Errors are passed through `crate::error::log_error` before being returned.
    async fn do_get_auth_info(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<AuthInfo, GetAuthInfoError> {
        if !self
            .caches
            .endpoints_cache
            .is_valid(ctx, &user_info.endpoint.normalize())
        {
            // TODO: refactor this because it's weird
            // this is a failure to authenticate but we return Ok.
            info!("endpoint is not valid, skipping the request");
            return Ok(AuthInfo::default());
        }
        // The session id doubles as the request id header and as the span id.
        let request_id = ctx.session_id().to_string();
        let application_name = ctx.console_application_name();
        async {
            let request = self
                .endpoint
                .get_path("proxy_get_role_secret")
                .header(X_REQUEST_ID, &request_id)
                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .query(&[
                    ("application_name", application_name.as_str()),
                    ("project", user_info.endpoint.as_str()),
                    ("role", user_info.user.as_str()),
                ])
                .build()?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            // The pause guard is held only while waiting for the control plane
            // and is dropped as soon as the response arrives.
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self.endpoint.execute(request).await?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");
            let body = match parse_body::<GetRoleSecret>(response).await {
                Ok(body) => body,
                // Error 404 is special: it's ok not to have a secret.
                // TODO(anna): retry
                Err(e) => {
                    return if e.get_reason().is_not_found() {
                        // TODO: refactor this because it's weird
                        // this is a failure to authenticate but we return Ok.
                        Ok(AuthInfo::default())
                    } else {
                        Err(e.into())
                    };
                }
            };

            // An empty secret string means "no secret"; anything else must
            // parse as a SCRAM server secret.
            let secret = if body.role_secret.is_empty() {
                None
            } else {
                let secret = scram::ServerSecret::parse(&body.role_secret)
                    .map(AuthSecret::Scram)
                    .ok_or(GetAuthInfoError::BadSecret)?;
                Some(secret)
            };
            // A missing allowed-IPs list is treated as an empty one.
            let allowed_ips = body.allowed_ips.unwrap_or_default();
            Metrics::get()
                .proxy
                .allowed_ips_number
                .observe(allowed_ips.len() as f64);
            Ok(AuthInfo {
                secret,
                allowed_ips,
                project_id: body.project_id,
            })
        }
        .map_err(crate::error::log_error)
        .instrument(info_span!("http", id = request_id))
        .await
    }
     141              : 
    /// Fetch the JWKS-based auth rules configured for `endpoint` from the
    /// control plane (`endpoints/<endpoint>/jwks`).
    ///
    /// # Errors
    ///
    /// Returns `GetEndpointJwksError::EndpointNotFound` without sending a
    /// request when the endpoints cache reports the endpoint as not valid.
    /// Otherwise fails if the request cannot be built
    /// (`RequestBuild`), cannot be executed (`RequestExecute`), or if the
    /// response cannot be parsed. Errors from the request itself are passed
    /// through `crate::error::log_error` before being returned.
    async fn do_get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        if !self
            .caches
            .endpoints_cache
            .is_valid(ctx, &endpoint.normalize())
        {
            return Err(GetEndpointJwksError::EndpointNotFound);
        }
        // The session id doubles as the request id header and as the span id.
        let request_id = ctx.session_id().to_string();
        async {
            let request = self
                .endpoint
                .get_with_url(|url| {
                    // Append the path segments one by one rather than
                    // formatting the endpoint id into a path string.
                    url.path_segments_mut()
                        .push("endpoints")
                        .push(endpoint.as_str())
                        .push("jwks");
                })
                .header(X_REQUEST_ID, &request_id)
                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .build()
                .map_err(GetEndpointJwksError::RequestBuild)?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            // The pause guard is held only while waiting for the control plane
            // and is dropped as soon as the response arrives.
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self
                .endpoint
                .execute(request)
                .await
                .map_err(GetEndpointJwksError::RequestExecute)?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");

            let body = parse_body::<EndpointJwksResponse>(response).await?;

            // Convert each JWKS entry of the response into an `AuthRule`.
            let rules = body
                .jwks
                .into_iter()
                .map(|jwks| AuthRule {
                    id: jwks.id,
                    jwks_url: jwks.jwks_url,
                    audience: jwks.jwt_audience,
                    role_names: jwks.role_names,
                })
                .collect();

            Ok(rules)
        }
        .map_err(crate::error::log_error)
        .instrument(info_span!("http", id = request_id))
        .await
    }
     200              : 
     201            0 :     async fn do_wake_compute(
     202            0 :         &self,
     203            0 :         ctx: &RequestContext,
     204            0 :         user_info: &ComputeUserInfo,
     205            0 :     ) -> Result<NodeInfo, WakeComputeError> {
     206            0 :         let request_id = ctx.session_id().to_string();
     207            0 :         let application_name = ctx.console_application_name();
     208            0 :         async {
     209            0 :             let mut request_builder = self
     210            0 :                 .endpoint
     211            0 :                 .get_path("proxy_wake_compute")
     212            0 :                 .header("X-Request-ID", &request_id)
     213            0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
     214            0 :                 .query(&[("session_id", ctx.session_id())])
     215            0 :                 .query(&[
     216            0 :                     ("application_name", application_name.as_str()),
     217            0 :                     ("project", user_info.endpoint.as_str()),
     218            0 :                 ]);
     219            0 : 
     220            0 :             let options = user_info.options.to_deep_object();
     221            0 :             if !options.is_empty() {
     222            0 :                 request_builder = request_builder.query(&options);
     223            0 :             }
     224              : 
     225            0 :             let request = request_builder.build()?;
     226              : 
     227            0 :             debug!(url = request.url().as_str(), "sending http request");
     228            0 :             let start = Instant::now();
     229            0 :             let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
     230            0 :             let response = self.endpoint.execute(request).await?;
     231            0 :             drop(pause);
     232            0 :             info!(duration = ?start.elapsed(), "received http response");
     233            0 :             let body = parse_body::<WakeCompute>(response).await?;
     234              : 
     235              :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     236            0 :             let (host, port) = match parse_host_port(&body.address) {
     237            0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     238            0 :                 Some(x) => x,
     239            0 :             };
     240            0 : 
     241            0 :             // Don't set anything but host and port! This config will be cached.
     242            0 :             // We'll set username and such later using the startup message.
     243            0 :             // TODO: add more type safety (in progress).
     244            0 :             let mut config = compute::ConnCfg::new();
     245            0 :             config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
     246            0 : 
     247            0 :             let node = NodeInfo {
     248            0 :                 config,
     249            0 :                 aux: body.aux,
     250            0 :                 allow_self_signed_compute: false,
     251            0 :             };
     252            0 : 
     253            0 :             Ok(node)
     254            0 :         }
     255            0 :         .map_err(crate::error::log_error)
     256            0 :         // TODO: redo this span stuff
     257            0 :         .instrument(info_span!("http", id = request_id))
     258            0 :         .await
     259            0 :     }
     260              : }
     261              : 
     262              : impl super::ControlPlaneApi for NeonControlPlaneClient {
     263            0 :     #[tracing::instrument(skip_all)]
     264              :     async fn get_role_secret(
     265              :         &self,
     266              :         ctx: &RequestContext,
     267              :         user_info: &ComputeUserInfo,
     268              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     269              :         let normalized_ep = &user_info.endpoint.normalize();
     270              :         let user = &user_info.user;
     271              :         if let Some(role_secret) = self
     272              :             .caches
     273              :             .project_info
     274              :             .get_role_secret(normalized_ep, user)
     275              :         {
     276              :             return Ok(role_secret);
     277              :         }
     278              :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     279              :         if let Some(project_id) = auth_info.project_id {
     280              :             let normalized_ep_int = normalized_ep.into();
     281              :             self.caches.project_info.insert_role_secret(
     282              :                 project_id,
     283              :                 normalized_ep_int,
     284              :                 user.into(),
     285              :                 auth_info.secret.clone(),
     286              :             );
     287              :             self.caches.project_info.insert_allowed_ips(
     288              :                 project_id,
     289              :                 normalized_ep_int,
     290              :                 Arc::new(auth_info.allowed_ips),
     291              :             );
     292              :             ctx.set_project_id(project_id);
     293              :         }
     294              :         // When we just got a secret, we don't need to invalidate it.
     295              :         Ok(Cached::new_uncached(auth_info.secret))
     296              :     }
     297              : 
     298            0 :     async fn get_allowed_ips_and_secret(
     299            0 :         &self,
     300            0 :         ctx: &RequestContext,
     301            0 :         user_info: &ComputeUserInfo,
     302            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     303            0 :         let normalized_ep = &user_info.endpoint.normalize();
     304            0 :         if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
     305            0 :             Metrics::get()
     306            0 :                 .proxy
     307            0 :                 .allowed_ips_cache_misses
     308            0 :                 .inc(CacheOutcome::Hit);
     309            0 :             return Ok((allowed_ips, None));
     310            0 :         }
     311            0 :         Metrics::get()
     312            0 :             .proxy
     313            0 :             .allowed_ips_cache_misses
     314            0 :             .inc(CacheOutcome::Miss);
     315            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     316            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     317            0 :         let user = &user_info.user;
     318            0 :         if let Some(project_id) = auth_info.project_id {
     319            0 :             let normalized_ep_int = normalized_ep.into();
     320            0 :             self.caches.project_info.insert_role_secret(
     321            0 :                 project_id,
     322            0 :                 normalized_ep_int,
     323            0 :                 user.into(),
     324            0 :                 auth_info.secret.clone(),
     325            0 :             );
     326            0 :             self.caches.project_info.insert_allowed_ips(
     327            0 :                 project_id,
     328            0 :                 normalized_ep_int,
     329            0 :                 allowed_ips.clone(),
     330            0 :             );
     331            0 :             ctx.set_project_id(project_id);
     332            0 :         }
     333            0 :         Ok((
     334            0 :             Cached::new_uncached(allowed_ips),
     335            0 :             Some(Cached::new_uncached(auth_info.secret)),
     336            0 :         ))
     337            0 :     }
     338              : 
    /// Fetch the JWKS auth rules for `endpoint`.
    ///
    /// Delegates directly to `do_get_endpoint_jwks`; no result caching is
    /// performed at this level.
    #[tracing::instrument(skip_all)]
    async fn get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        self.do_get_endpoint_jwks(ctx, endpoint).await
    }
     347              : 
    /// Return connection info for the compute node of `user_info`, waking it
    /// through the control plane if nothing usable is cached.
    ///
    /// Order of operations:
    /// 1. check the node-info cache (which may hold a node or a cached error);
    /// 2. obtain a permit for the endpoint cache key and, if the permit says
    ///    so, check the cache again;
    /// 3. check the wake-compute rate limiter;
    /// 4. send the wake-compute request and cache the outcome.
    ///
    /// # Errors
    ///
    /// Returns a cached or fresh control plane error,
    /// `WakeComputeError::TooManyConnections` when the rate limiter rejects
    /// the request, or any error from obtaining the permit or from the
    /// wake-compute request itself.
    #[tracing::instrument(skip_all)]
    async fn wake_compute(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        let key = user_info.endpoint_cache_key();

        // Cache lookup shared by the two check sites below. It is a macro
        // rather than a closure so that `return` and `?` exit `wake_compute` itself.
        macro_rules! check_cache {
            () => {
                if let Some(cached) = self.caches.node_info.get(&key) {
                    let (cached, info) = cached.take_value();
                    // A cached `Err` is turned back into a control plane error.
                    let info = info.map_err(|c| {
                        info!(key = &*key, "found cached wake_compute error");
                        WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
                    })?;

                    debug!(key = &*key, "found cached compute node info");
                    ctx.set_project(info.aux.clone());
                    return Ok(cached.map(|()| info));
                }
            };
        }

        // Every time we do a wakeup http request, the compute node will stay up
        // for some time (highly depends on the console's scale-to-zero policy);
        // The connection info remains the same during that period of time,
        // which means that we might cache it to reduce the load and latency.
        check_cache!();

        let permit = self.locks.get_permit(&key).await?;

        // after getting back a permit - it's possible the cache was filled
        // double check
        if permit.should_check_cache() {
            // TODO: if there is something in the cache, mark the permit as success.
            check_cache!();
        }

        // check rate limit
        if !self
            .wake_compute_endpoint_rate_limiter
            .check(user_info.endpoint.normalize_intern(), 1)
        {
            return Err(WakeComputeError::TooManyConnections);
        }

        // The request's result is handed to the permit before being inspected.
        let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
        match node {
            Ok(node) => {
                ctx.set_project(node.aux.clone());
                debug!(key = &*key, "created a cache entry for woken compute node");

                let mut stored_node = node.clone();
                // store the cached node as 'warm_cached'
                stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;

                let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));

                // The caller gets the original node, not the 'warm_cached' copy.
                Ok(cached.map(|()| node))
            }
            Err(err) => match err {
                WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
                    // Errors without a status are returned as-is and not cached.
                    let Some(status) = &err.status else {
                        return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                            err,
                        )));
                    };

                    // A status without error info is treated as `Reason::Unknown`.
                    let reason = status
                        .details
                        .error_info
                        .map_or(Reason::Unknown, |x| x.reason);

                    // if we can retry this error, do not cache it.
                    if reason.can_retry() {
                        return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                            err,
                        )));
                    }

                    // at this point, we should only have quota errors.
                    debug!(
                        key = &*key,
                        "created a cache entry for the wake compute error"
                    );

                    // Non-retryable errors are cached with a 30 second TTL.
                    self.caches.node_info.insert_ttl(
                        key,
                        Err(err.clone()),
                        Duration::from_secs(30),
                    );

                    Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                        err,
                    )))
                }
                // Any other kind of error is returned without caching.
                err => return Err(err),
            },
        }
    }
     449              : }
     450              : 
     451              : /// Parse http response body, taking status code into account.
     452            0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     453            0 :     response: http::Response,
     454            0 : ) -> Result<T, ControlPlaneError> {
     455            0 :     let status = response.status();
     456            0 :     if status.is_success() {
     457              :         // We shouldn't log raw body because it may contain secrets.
     458            0 :         info!("request succeeded, processing the body");
     459            0 :         return Ok(response.json().await?);
     460            0 :     }
     461            0 :     let s = response.bytes().await?;
     462              :     // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
     463            0 :     info!("response_error plaintext: {:?}", s);
     464              : 
     465              :     // Don't throw an error here because it's not as important
     466              :     // as the fact that the request itself has failed.
     467            0 :     let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
     468            0 :         warn!("failed to parse error body: {e}");
     469            0 :         ControlPlaneErrorMessage {
     470            0 :             error: "reason unclear (malformed error message)".into(),
     471            0 :             http_status_code: status,
     472            0 :             status: None,
     473            0 :         }
     474            0 :     });
     475            0 :     body.http_status_code = status;
     476            0 : 
     477            0 :     warn!("console responded with an error ({status}): {body:?}");
     478            0 :     Err(ControlPlaneError::Message(Box::new(body)))
     479            0 : }
     480              : 
     481            3 : fn parse_host_port(input: &str) -> Option<(&str, u16)> {
     482            3 :     let (host, port) = input.rsplit_once(':')?;
     483            3 :     let ipv6_brackets: &[_] = &['[', ']'];
     484            3 :     Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
     485            3 : }
     486              : 
     487              : #[cfg(test)]
     488              : mod tests {
     489              :     use super::*;
     490              : 
     491              :     #[test]
     492            1 :     fn test_parse_host_port_v4() {
     493            1 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     494            1 :         assert_eq!(host, "127.0.0.1");
     495            1 :         assert_eq!(port, 5432);
     496            1 :     }
     497              : 
     498              :     #[test]
     499            1 :     fn test_parse_host_port_v6() {
     500            1 :         let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
     501            1 :         assert_eq!(host, "2001:db8::1");
     502            1 :         assert_eq!(port, 5432);
     503            1 :     }
     504              : 
     505              :     #[test]
     506            1 :     fn test_parse_host_port_url() {
     507            1 :         let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
     508            1 :             .expect("failed to parse");
     509            1 :         assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
     510            1 :         assert_eq!(port, 5432);
     511            1 :     }
     512              : }
        

Generated by: LCOV version 2.1-beta