LCOV - code coverage report
Current view: top level - proxy/src/control_plane/client - cplane_proxy_v1.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 4.9 % 432 21
Test Date: 2025-03-12 18:28:53 Functions: 10.0 % 40 4

            Line data    Source code
       1              : //! Production console backend.
       2              : 
       3              : use std::net::IpAddr;
       4              : use std::str::FromStr;
       5              : use std::sync::Arc;
       6              : use std::time::Duration;
       7              : 
       8              : use ::http::HeaderName;
       9              : use ::http::header::AUTHORIZATION;
      10              : use futures::TryFutureExt;
      11              : use postgres_client::config::SslMode;
      12              : use tokio::time::Instant;
      13              : use tracing::{Instrument, debug, info, info_span, warn};
      14              : 
      15              : use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
      16              : use crate::auth::backend::ComputeUserInfo;
      17              : use crate::auth::backend::jwt::AuthRule;
      18              : use crate::cache::Cached;
      19              : use crate::context::RequestContext;
      20              : use crate::control_plane::caches::ApiCaches;
      21              : use crate::control_plane::errors::{
      22              :     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
      23              : };
      24              : use crate::control_plane::locks::ApiLocks;
      25              : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
      26              : use crate::control_plane::{
      27              :     AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
      28              :     CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
      29              : };
      30              : use crate::metrics::{CacheOutcome, Metrics};
      31              : use crate::rate_limiter::WakeComputeRateLimiter;
      32              : use crate::types::{EndpointCacheKey, EndpointId};
      33              : use crate::{compute, http, scram};
      34              : 
      35              : pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
      36              : 
      37              : #[derive(Clone)]
      38              : pub struct NeonControlPlaneClient {
      39              :     endpoint: http::Endpoint,
      40              :     pub caches: &'static ApiCaches,
      41              :     pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
      42              :     pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
      43              :     // put in a shared ref so we don't copy secrets all over in memory
      44              :     jwt: Arc<str>,
      45              : }
      46              : 
      47              : impl NeonControlPlaneClient {
      48              :     /// Construct an API object containing the auth parameters.
      49            0 :     pub fn new(
      50            0 :         endpoint: http::Endpoint,
      51            0 :         jwt: Arc<str>,
      52            0 :         caches: &'static ApiCaches,
      53            0 :         locks: &'static ApiLocks<EndpointCacheKey>,
      54            0 :         wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
      55            0 :     ) -> Self {
      56            0 :         Self {
      57            0 :             endpoint,
      58            0 :             caches,
      59            0 :             locks,
      60            0 :             wake_compute_endpoint_rate_limiter,
      61            0 :             jwt,
      62            0 :         }
      63            0 :     }
      64              : 
      65            0 :     pub(crate) fn url(&self) -> &str {
      66            0 :         self.endpoint.url().as_str()
      67            0 :     }
      68              : 
      69            0 :     async fn do_get_auth_info(
      70            0 :         &self,
      71            0 :         ctx: &RequestContext,
      72            0 :         user_info: &ComputeUserInfo,
      73            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      74            0 :         if !self
      75            0 :             .caches
      76            0 :             .endpoints_cache
      77            0 :             .is_valid(ctx, &user_info.endpoint.normalize())
      78              :         {
      79              :             // TODO: refactor this because it's weird
      80              :             // this is a failure to authenticate but we return Ok.
      81            0 :             info!("endpoint is not valid, skipping the request");
      82            0 :             return Ok(AuthInfo::default());
      83            0 :         }
      84            0 :         self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
      85            0 :             .await
      86            0 :     }
      87              : 
      88            0 :     async fn do_get_auth_req(
      89            0 :         &self,
      90            0 :         user_info: &ComputeUserInfo,
      91            0 :         session_id: &uuid::Uuid,
      92            0 :         ctx: Option<&RequestContext>,
      93            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      94            0 :         let request_id: String = session_id.to_string();
      95            0 :         let application_name = if let Some(ctx) = ctx {
      96            0 :             ctx.console_application_name()
      97              :         } else {
      98            0 :             "auth_cancellation".to_string()
      99              :         };
     100              : 
     101            0 :         async {
     102            0 :             let request = self
     103            0 :                 .endpoint
     104            0 :                 .get_path("get_endpoint_access_control")
     105            0 :                 .header(X_REQUEST_ID, &request_id)
     106            0 :                 .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
     107            0 :                 .query(&[("session_id", session_id)])
     108            0 :                 .query(&[
     109            0 :                     ("application_name", application_name.as_str()),
     110            0 :                     ("endpointish", user_info.endpoint.as_str()),
     111            0 :                     ("role", user_info.user.as_str()),
     112            0 :                 ])
     113            0 :                 .build()?;
     114              : 
     115            0 :             debug!(url = request.url().as_str(), "sending http request");
     116            0 :             let start = Instant::now();
     117            0 :             let response = match ctx {
     118            0 :                 Some(ctx) => {
     119            0 :                     let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
     120            0 :                     let rsp = self.endpoint.execute(request).await;
     121            0 :                     drop(pause);
     122            0 :                     rsp?
     123              :                 }
     124            0 :                 None => self.endpoint.execute(request).await?,
     125              :             };
     126              : 
     127            0 :             info!(duration = ?start.elapsed(), "received http response");
     128            0 :             let body = match parse_body::<GetEndpointAccessControl>(response).await {
     129            0 :                 Ok(body) => body,
     130              :                 // Error 404 is special: it's ok not to have a secret.
     131              :                 // TODO(anna): retry
     132            0 :                 Err(e) => {
     133            0 :                     return if e.get_reason().is_not_found() {
     134              :                         // TODO: refactor this because it's weird
     135              :                         // this is a failure to authenticate but we return Ok.
     136            0 :                         Ok(AuthInfo::default())
     137              :                     } else {
     138            0 :                         Err(e.into())
     139              :                     };
     140              :                 }
     141              :             };
     142              : 
     143            0 :             let secret = if body.role_secret.is_empty() {
     144            0 :                 None
     145              :             } else {
     146            0 :                 let secret = scram::ServerSecret::parse(&body.role_secret)
     147            0 :                     .map(AuthSecret::Scram)
     148            0 :                     .ok_or(GetAuthInfoError::BadSecret)?;
     149            0 :                 Some(secret)
     150              :             };
     151            0 :             let allowed_ips = body.allowed_ips.unwrap_or_default();
     152            0 :             Metrics::get()
     153            0 :                 .proxy
     154            0 :                 .allowed_ips_number
     155            0 :                 .observe(allowed_ips.len() as f64);
     156            0 :             let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
     157            0 :             Metrics::get()
     158            0 :                 .proxy
     159            0 :                 .allowed_vpc_endpoint_ids
     160            0 :                 .observe(allowed_vpc_endpoint_ids.len() as f64);
     161            0 :             let block_public_connections = body.block_public_connections.unwrap_or_default();
     162            0 :             let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
     163            0 :             Ok(AuthInfo {
     164            0 :                 secret,
     165            0 :                 allowed_ips,
     166            0 :                 allowed_vpc_endpoint_ids,
     167            0 :                 project_id: body.project_id,
     168            0 :                 account_id: body.account_id,
     169            0 :                 access_blocker_flags: AccessBlockerFlags {
     170            0 :                     public_access_blocked: block_public_connections,
     171            0 :                     vpc_access_blocked: block_vpc_connections,
     172            0 :                 },
     173            0 :             })
     174            0 :         }
     175            0 :         .inspect_err(|e| tracing::debug!(error = ?e))
     176            0 :         .instrument(info_span!("do_get_auth_info"))
     177            0 :         .await
     178            0 :     }
     179              : 
     180            0 :     async fn do_get_endpoint_jwks(
     181            0 :         &self,
     182            0 :         ctx: &RequestContext,
     183            0 :         endpoint: EndpointId,
     184            0 :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     185            0 :         if !self
     186            0 :             .caches
     187            0 :             .endpoints_cache
     188            0 :             .is_valid(ctx, &endpoint.normalize())
     189              :         {
     190            0 :             return Err(GetEndpointJwksError::EndpointNotFound);
     191            0 :         }
     192            0 :         let request_id = ctx.session_id().to_string();
     193            0 :         async {
     194            0 :             let request = self
     195            0 :                 .endpoint
     196            0 :                 .get_with_url(|url| {
     197            0 :                     url.path_segments_mut()
     198            0 :                         .push("endpoints")
     199            0 :                         .push(endpoint.as_str())
     200            0 :                         .push("jwks");
     201            0 :                 })
     202            0 :                 .header(X_REQUEST_ID, &request_id)
     203            0 :                 .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
     204            0 :                 .query(&[("session_id", ctx.session_id())])
     205            0 :                 .build()
     206            0 :                 .map_err(GetEndpointJwksError::RequestBuild)?;
     207              : 
     208            0 :             debug!(url = request.url().as_str(), "sending http request");
     209            0 :             let start = Instant::now();
     210            0 :             let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
     211            0 :             let response = self
     212            0 :                 .endpoint
     213            0 :                 .execute(request)
     214            0 :                 .await
     215            0 :                 .map_err(GetEndpointJwksError::RequestExecute)?;
     216            0 :             drop(pause);
     217            0 :             info!(duration = ?start.elapsed(), "received http response");
     218              : 
     219            0 :             let body = parse_body::<EndpointJwksResponse>(response).await?;
     220              : 
     221            0 :             let rules = body
     222            0 :                 .jwks
     223            0 :                 .into_iter()
     224            0 :                 .map(|jwks| AuthRule {
     225            0 :                     id: jwks.id,
     226            0 :                     jwks_url: jwks.jwks_url,
     227            0 :                     audience: jwks.jwt_audience,
     228            0 :                     role_names: jwks.role_names,
     229            0 :                 })
     230            0 :                 .collect();
     231            0 : 
     232            0 :             Ok(rules)
     233            0 :         }
     234            0 :         .inspect_err(|e| tracing::debug!(error = ?e))
     235            0 :         .instrument(info_span!("do_get_endpoint_jwks"))
     236            0 :         .await
     237            0 :     }
     238              : 
     239            0 :     async fn do_wake_compute(
     240            0 :         &self,
     241            0 :         ctx: &RequestContext,
     242            0 :         user_info: &ComputeUserInfo,
     243            0 :     ) -> Result<NodeInfo, WakeComputeError> {
     244            0 :         let request_id = ctx.session_id().to_string();
     245            0 :         let application_name = ctx.console_application_name();
     246            0 :         async {
     247            0 :             let mut request_builder = self
     248            0 :                 .endpoint
     249            0 :                 .get_path("wake_compute")
     250            0 :                 .header("X-Request-ID", &request_id)
     251            0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
     252            0 :                 .query(&[("session_id", ctx.session_id())])
     253            0 :                 .query(&[
     254            0 :                     ("application_name", application_name.as_str()),
     255            0 :                     ("endpointish", user_info.endpoint.as_str()),
     256            0 :                 ]);
     257            0 : 
     258            0 :             let options = user_info.options.to_deep_object();
     259            0 :             if !options.is_empty() {
     260            0 :                 request_builder = request_builder.query(&options);
     261            0 :             }
     262              : 
     263            0 :             let request = request_builder.build()?;
     264              : 
     265            0 :             debug!(url = request.url().as_str(), "sending http request");
     266            0 :             let start = Instant::now();
     267            0 :             let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
     268            0 :             let response = self.endpoint.execute(request).await?;
     269            0 :             drop(pause);
     270            0 :             info!(duration = ?start.elapsed(), "received http response");
     271            0 :             let body = parse_body::<WakeCompute>(response).await?;
     272              : 
     273              :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     274            0 :             let (host, port) = match parse_host_port(&body.address) {
     275            0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     276            0 :                 Some(x) => x,
     277            0 :             };
     278            0 : 
     279            0 :             let host_addr = IpAddr::from_str(host).ok();
     280              : 
     281            0 :             let ssl_mode = match &body.server_name {
     282            0 :                 Some(_) => SslMode::Require,
     283            0 :                 None => SslMode::Disable,
     284              :             };
     285            0 :             let host_name = match body.server_name {
     286            0 :                 Some(host) => host,
     287            0 :                 None => host.to_owned(),
     288              :             };
     289              : 
     290              :             // Don't set anything but host and port! This config will be cached.
     291              :             // We'll set username and such later using the startup message.
     292              :             // TODO: add more type safety (in progress).
     293            0 :             let mut config = compute::ConnCfg::new(host_name, port);
     294              : 
     295            0 :             if let Some(addr) = host_addr {
     296            0 :                 config.set_host_addr(addr);
     297            0 :             }
     298              : 
     299            0 :             config.ssl_mode(ssl_mode);
     300            0 : 
     301            0 :             let node = NodeInfo {
     302            0 :                 config,
     303            0 :                 aux: body.aux,
     304            0 :             };
     305            0 : 
     306            0 :             Ok(node)
     307            0 :         }
     308            0 :         .inspect_err(|e| tracing::debug!(error = ?e))
     309            0 :         .instrument(info_span!("do_wake_compute"))
     310            0 :         .await
     311            0 :     }
     312              : }
     313              : 
     314              : impl super::ControlPlaneApi for NeonControlPlaneClient {
     315              :     #[tracing::instrument(skip_all)]
     316              :     async fn get_role_secret(
     317              :         &self,
     318              :         ctx: &RequestContext,
     319              :         user_info: &ComputeUserInfo,
     320              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     321              :         let normalized_ep = &user_info.endpoint.normalize();
     322              :         let user = &user_info.user;
     323              :         if let Some(role_secret) = self
     324              :             .caches
     325              :             .project_info
     326              :             .get_role_secret(normalized_ep, user)
     327              :         {
     328              :             return Ok(role_secret);
     329              :         }
     330              :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     331              :         let account_id = auth_info.account_id;
     332              :         if let Some(project_id) = auth_info.project_id {
     333              :             let normalized_ep_int = normalized_ep.into();
     334              :             self.caches.project_info.insert_role_secret(
     335              :                 project_id,
     336              :                 normalized_ep_int,
     337              :                 user.into(),
     338              :                 auth_info.secret.clone(),
     339              :             );
     340              :             self.caches.project_info.insert_allowed_ips(
     341              :                 project_id,
     342              :                 normalized_ep_int,
     343              :                 Arc::new(auth_info.allowed_ips),
     344              :             );
     345              :             self.caches.project_info.insert_allowed_vpc_endpoint_ids(
     346              :                 account_id,
     347              :                 project_id,
     348              :                 normalized_ep_int,
     349              :                 Arc::new(auth_info.allowed_vpc_endpoint_ids),
     350              :             );
     351              :             self.caches.project_info.insert_block_public_or_vpc_access(
     352              :                 project_id,
     353              :                 normalized_ep_int,
     354              :                 auth_info.access_blocker_flags,
     355              :             );
     356              :             ctx.set_project_id(project_id);
     357              :         }
     358              :         // When we just got a secret, we don't need to invalidate it.
     359              :         Ok(Cached::new_uncached(auth_info.secret))
     360              :     }
     361              : 
     362            0 :     async fn get_allowed_ips(
     363            0 :         &self,
     364            0 :         ctx: &RequestContext,
     365            0 :         user_info: &ComputeUserInfo,
     366            0 :     ) -> Result<CachedAllowedIps, GetAuthInfoError> {
     367            0 :         let normalized_ep = &user_info.endpoint.normalize();
     368            0 :         if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
     369            0 :             Metrics::get()
     370            0 :                 .proxy
     371            0 :                 .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
     372            0 :                 .inc(CacheOutcome::Hit);
     373            0 :             return Ok(allowed_ips);
     374            0 :         }
     375            0 :         Metrics::get()
     376            0 :             .proxy
     377            0 :             .allowed_ips_cache_misses
     378            0 :             .inc(CacheOutcome::Miss);
     379            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     380            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     381            0 :         let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
     382            0 :         let access_blocker_flags = auth_info.access_blocker_flags;
     383            0 :         let user = &user_info.user;
     384            0 :         let account_id = auth_info.account_id;
     385            0 :         if let Some(project_id) = auth_info.project_id {
     386            0 :             let normalized_ep_int = normalized_ep.into();
     387            0 :             self.caches.project_info.insert_role_secret(
     388            0 :                 project_id,
     389            0 :                 normalized_ep_int,
     390            0 :                 user.into(),
     391            0 :                 auth_info.secret.clone(),
     392            0 :             );
     393            0 :             self.caches.project_info.insert_allowed_ips(
     394            0 :                 project_id,
     395            0 :                 normalized_ep_int,
     396            0 :                 allowed_ips.clone(),
     397            0 :             );
     398            0 :             self.caches.project_info.insert_allowed_vpc_endpoint_ids(
     399            0 :                 account_id,
     400            0 :                 project_id,
     401            0 :                 normalized_ep_int,
     402            0 :                 allowed_vpc_endpoint_ids.clone(),
     403            0 :             );
     404            0 :             self.caches.project_info.insert_block_public_or_vpc_access(
     405            0 :                 project_id,
     406            0 :                 normalized_ep_int,
     407            0 :                 access_blocker_flags,
     408            0 :             );
     409            0 :             ctx.set_project_id(project_id);
     410            0 :         }
     411            0 :         Ok(Cached::new_uncached(allowed_ips))
     412            0 :     }
     413              : 
     414            0 :     async fn get_allowed_vpc_endpoint_ids(
     415            0 :         &self,
     416            0 :         ctx: &RequestContext,
     417            0 :         user_info: &ComputeUserInfo,
     418            0 :     ) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
     419            0 :         let normalized_ep = &user_info.endpoint.normalize();
     420            0 :         if let Some(allowed_vpc_endpoint_ids) = self
     421            0 :             .caches
     422            0 :             .project_info
     423            0 :             .get_allowed_vpc_endpoint_ids(normalized_ep)
     424              :         {
     425            0 :             Metrics::get()
     426            0 :                 .proxy
     427            0 :                 .vpc_endpoint_id_cache_stats
     428            0 :                 .inc(CacheOutcome::Hit);
     429            0 :             return Ok(allowed_vpc_endpoint_ids);
     430            0 :         }
     431            0 : 
     432            0 :         Metrics::get()
     433            0 :             .proxy
     434            0 :             .vpc_endpoint_id_cache_stats
     435            0 :             .inc(CacheOutcome::Miss);
     436              : 
     437            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     438            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     439            0 :         let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
     440            0 :         let access_blocker_flags = auth_info.access_blocker_flags;
     441            0 :         let user = &user_info.user;
     442            0 :         let account_id = auth_info.account_id;
     443            0 :         if let Some(project_id) = auth_info.project_id {
     444            0 :             let normalized_ep_int = normalized_ep.into();
     445            0 :             self.caches.project_info.insert_role_secret(
     446            0 :                 project_id,
     447            0 :                 normalized_ep_int,
     448            0 :                 user.into(),
     449            0 :                 auth_info.secret.clone(),
     450            0 :             );
     451            0 :             self.caches.project_info.insert_allowed_ips(
     452            0 :                 project_id,
     453            0 :                 normalized_ep_int,
     454            0 :                 allowed_ips.clone(),
     455            0 :             );
     456            0 :             self.caches.project_info.insert_allowed_vpc_endpoint_ids(
     457            0 :                 account_id,
     458            0 :                 project_id,
     459            0 :                 normalized_ep_int,
     460            0 :                 allowed_vpc_endpoint_ids.clone(),
     461            0 :             );
     462            0 :             self.caches.project_info.insert_block_public_or_vpc_access(
     463            0 :                 project_id,
     464            0 :                 normalized_ep_int,
     465            0 :                 access_blocker_flags,
     466            0 :             );
     467            0 :             ctx.set_project_id(project_id);
     468            0 :         }
     469            0 :         Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
     470            0 :     }
     471              : 
     472            0 :     async fn get_block_public_or_vpc_access(
     473            0 :         &self,
     474            0 :         ctx: &RequestContext,
     475            0 :         user_info: &ComputeUserInfo,
     476            0 :     ) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
     477            0 :         let normalized_ep = &user_info.endpoint.normalize();
     478            0 :         if let Some(access_blocker_flags) = self
     479            0 :             .caches
     480            0 :             .project_info
     481            0 :             .get_block_public_or_vpc_access(normalized_ep)
     482              :         {
     483            0 :             Metrics::get()
     484            0 :                 .proxy
     485            0 :                 .access_blocker_flags_cache_stats
     486            0 :                 .inc(CacheOutcome::Hit);
     487            0 :             return Ok(access_blocker_flags);
     488            0 :         }
     489            0 : 
     490            0 :         Metrics::get()
     491            0 :             .proxy
     492            0 :             .access_blocker_flags_cache_stats
     493            0 :             .inc(CacheOutcome::Miss);
     494              : 
     495            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     496            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     497            0 :         let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
     498            0 :         let access_blocker_flags = auth_info.access_blocker_flags;
     499            0 :         let user = &user_info.user;
     500            0 :         let account_id = auth_info.account_id;
     501            0 :         if let Some(project_id) = auth_info.project_id {
     502            0 :             let normalized_ep_int = normalized_ep.into();
     503            0 :             self.caches.project_info.insert_role_secret(
     504            0 :                 project_id,
     505            0 :                 normalized_ep_int,
     506            0 :                 user.into(),
     507            0 :                 auth_info.secret.clone(),
     508            0 :             );
     509            0 :             self.caches.project_info.insert_allowed_ips(
     510            0 :                 project_id,
     511            0 :                 normalized_ep_int,
     512            0 :                 allowed_ips.clone(),
     513            0 :             );
     514            0 :             self.caches.project_info.insert_allowed_vpc_endpoint_ids(
     515            0 :                 account_id,
     516            0 :                 project_id,
     517            0 :                 normalized_ep_int,
     518            0 :                 allowed_vpc_endpoint_ids.clone(),
     519            0 :             );
     520            0 :             self.caches.project_info.insert_block_public_or_vpc_access(
     521            0 :                 project_id,
     522            0 :                 normalized_ep_int,
     523            0 :                 access_blocker_flags.clone(),
     524            0 :             );
     525            0 :             ctx.set_project_id(project_id);
     526            0 :         }
     527            0 :         Ok(Cached::new_uncached(access_blocker_flags))
     528            0 :     }
     529              : 
     530              :     #[tracing::instrument(skip_all)]
     531              :     async fn get_endpoint_jwks(
     532              :         &self,
     533              :         ctx: &RequestContext,
     534              :         endpoint: EndpointId,
     535              :     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
     536              :         self.do_get_endpoint_jwks(ctx, endpoint).await
     537              :     }
     538              : 
     539              :     #[tracing::instrument(skip_all)]
     540              :     async fn wake_compute(
     541              :         &self,
     542              :         ctx: &RequestContext,
     543              :         user_info: &ComputeUserInfo,
     544              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     545              :         let key = user_info.endpoint_cache_key();
     546              : 
     547              :         macro_rules! check_cache {
     548              :             () => {
     549              :                 if let Some(cached) = self.caches.node_info.get(&key) {
     550              :                     let (cached, info) = cached.take_value();
     551            0 :                     let info = info.map_err(|c| {
     552            0 :                         info!(key = &*key, "found cached wake_compute error");
     553            0 :                         WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
     554            0 :                     })?;
     555              : 
     556              :                     debug!(key = &*key, "found cached compute node info");
     557              :                     ctx.set_project(info.aux.clone());
     558            0 :                     return Ok(cached.map(|()| info));
     559              :                 }
     560              :             };
     561              :         }
     562              : 
     563              :         // Every time we do a wakeup http request, the compute node will stay up
     564              :         // for some time (highly depends on the console's scale-to-zero policy);
     565              :         // The connection info remains the same during that period of time,
     566              :         // which means that we might cache it to reduce the load and latency.
     567              :         check_cache!();
     568              : 
     569              :         let permit = self.locks.get_permit(&key).await?;
     570              : 
     571              :         // after getting back a permit - it's possible the cache was filled
     572              :         // double check
     573              :         if permit.should_check_cache() {
     574              :             // TODO: if there is something in the cache, mark the permit as success.
     575              :             check_cache!();
     576              :         }
     577              : 
     578              :         // check rate limit
     579              :         if !self
     580              :             .wake_compute_endpoint_rate_limiter
     581              :             .check(user_info.endpoint.normalize_intern(), 1)
     582              :         {
     583              :             return Err(WakeComputeError::TooManyConnections);
     584              :         }
     585              : 
     586              :         let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
     587              :         match node {
     588              :             Ok(node) => {
     589              :                 ctx.set_project(node.aux.clone());
     590              :                 debug!(key = &*key, "created a cache entry for woken compute node");
     591              : 
     592              :                 let mut stored_node = node.clone();
     593              :                 // store the cached node as 'warm_cached'
     594              :                 stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
     595              : 
     596              :                 let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
     597              : 
     598            0 :                 Ok(cached.map(|()| node))
     599              :             }
     600              :             Err(err) => match err {
     601              :                 WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
     602              :                     let Some(status) = &err.status else {
     603              :                         return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
     604              :                             err,
     605              :                         )));
     606              :                     };
     607              : 
     608              :                     let reason = status
     609              :                         .details
     610              :                         .error_info
     611            0 :                         .map_or(Reason::Unknown, |x| x.reason);
     612              : 
     613              :                     // if we can retry this error, do not cache it.
     614              :                     if reason.can_retry() {
     615              :                         return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
     616              :                             err,
     617              :                         )));
     618              :                     }
     619              : 
     620              :                     // at this point, we should only have quota errors.
     621              :                     debug!(
     622              :                         key = &*key,
     623              :                         "created a cache entry for the wake compute error"
     624              :                     );
     625              : 
     626              :                     self.caches.node_info.insert_ttl(
     627              :                         key,
     628              :                         Err(err.clone()),
     629              :                         Duration::from_secs(30),
     630              :                     );
     631              : 
     632              :                     Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
     633              :                         err,
     634              :                     )))
     635              :                 }
     636              :                 err => return Err(err),
     637              :             },
     638              :         }
     639              :     }
     640              : }
     641              : 
     642              : /// Parse http response body, taking status code into account.
     643            0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     644            0 :     response: http::Response,
     645            0 : ) -> Result<T, ControlPlaneError> {
     646            0 :     let status = response.status();
     647            0 :     if status.is_success() {
     648              :         // We shouldn't log raw body because it may contain secrets.
     649            0 :         info!("request succeeded, processing the body");
     650            0 :         return Ok(response.json().await?);
     651            0 :     }
     652            0 :     let s = response.bytes().await?;
     653              :     // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
     654            0 :     info!("response_error plaintext: {:?}", s);
     655              : 
     656              :     // Don't throw an error here because it's not as important
     657              :     // as the fact that the request itself has failed.
     658            0 :     let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
     659            0 :         warn!("failed to parse error body: {e}");
     660            0 :         ControlPlaneErrorMessage {
     661            0 :             error: "reason unclear (malformed error message)".into(),
     662            0 :             http_status_code: status,
     663            0 :             status: None,
     664            0 :         }
     665            0 :     });
     666            0 :     body.http_status_code = status;
     667            0 : 
     668            0 :     warn!("console responded with an error ({status}): {body:?}");
     669            0 :     Err(ControlPlaneError::Message(Box::new(body)))
     670            0 : }
     671              : 
     672            3 : fn parse_host_port(input: &str) -> Option<(&str, u16)> {
     673            3 :     let (host, port) = input.rsplit_once(':')?;
     674            3 :     let ipv6_brackets: &[_] = &['[', ']'];
     675            3 :     Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
     676            3 : }
     677              : 
     678              : #[cfg(test)]
     679              : mod tests {
     680              :     use super::*;
     681              : 
     682              :     #[test]
     683            1 :     fn test_parse_host_port_v4() {
     684            1 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     685            1 :         assert_eq!(host, "127.0.0.1");
     686            1 :         assert_eq!(port, 5432);
     687            1 :     }
     688              : 
     689              :     #[test]
     690            1 :     fn test_parse_host_port_v6() {
     691            1 :         let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
     692            1 :         assert_eq!(host, "2001:db8::1");
     693            1 :         assert_eq!(port, 5432);
     694            1 :     }
     695              : 
     696              :     #[test]
     697            1 :     fn test_parse_host_port_url() {
     698            1 :         let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
     699            1 :             .expect("failed to parse");
     700            1 :         assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
     701            1 :         assert_eq!(port, 5432);
     702            1 :     }
     703              : }
        

Generated by: LCOV version 2.1-beta