LCOV - code coverage report
Current view: top level - proxy/src/console/provider - neon.rs (source / functions) Coverage Total Hit
Test: 050dd70dd490b28fffe527eae9fb8a1222b5c59c.info Lines: 9.5 % 221 21
Test Date: 2024-06-25 21:28:46 Functions: 16.7 % 24 4

            Line data    Source code
       1              : //! Production console backend.
       2              : 
       3              : use super::{
       4              :     super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
       5              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       6              :     ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
       7              :     NodeInfo,
       8              : };
       9              : use crate::{
      10              :     auth::backend::ComputeUserInfo,
      11              :     compute,
      12              :     console::messages::ColdStartInfo,
      13              :     http,
      14              :     metrics::{CacheOutcome, Metrics},
      15              :     rate_limiter::EndpointRateLimiter,
      16              :     scram, EndpointCacheKey,
      17              : };
      18              : use crate::{cache::Cached, context::RequestMonitoring};
      19              : use futures::TryFutureExt;
      20              : use std::sync::Arc;
      21              : use tokio::time::Instant;
      22              : use tokio_postgres::config::SslMode;
      23              : use tracing::{error, info, info_span, warn, Instrument};
      24              : 
      25              : pub struct Api {
      26              :     endpoint: http::Endpoint,
      27              :     pub caches: &'static ApiCaches,
      28              :     pub locks: &'static ApiLocks<EndpointCacheKey>,
      29              :     pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      30              :     jwt: String,
      31              : }
      32              : 
      33              : impl Api {
      34              :     /// Construct an API object containing the auth parameters.
      35            0 :     pub fn new(
      36            0 :         endpoint: http::Endpoint,
      37            0 :         caches: &'static ApiCaches,
      38            0 :         locks: &'static ApiLocks<EndpointCacheKey>,
      39            0 :         wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
      40            0 :     ) -> Self {
      41            0 :         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
      42            0 :             Ok(v) => v,
      43            0 :             Err(_) => "".to_string(),
      44              :         };
      45            0 :         Self {
      46            0 :             endpoint,
      47            0 :             caches,
      48            0 :             locks,
      49            0 :             wake_compute_endpoint_rate_limiter,
      50            0 :             jwt,
      51            0 :         }
      52            0 :     }
      53              : 
      54            0 :     pub fn url(&self) -> &str {
      55            0 :         self.endpoint.url().as_str()
      56            0 :     }
      57              : 
      58            0 :     async fn do_get_auth_info(
      59            0 :         &self,
      60            0 :         ctx: &mut RequestMonitoring,
      61            0 :         user_info: &ComputeUserInfo,
      62            0 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      63            0 :         if !self
      64            0 :             .caches
      65            0 :             .endpoints_cache
      66            0 :             .is_valid(ctx, &user_info.endpoint.normalize())
      67            0 :             .await
      68              :         {
      69            0 :             info!("endpoint is not valid, skipping the request");
      70            0 :             return Ok(AuthInfo::default());
      71            0 :         }
      72            0 :         let request_id = ctx.session_id.to_string();
      73            0 :         let application_name = ctx.console_application_name();
      74            0 :         async {
      75            0 :             let request = self
      76            0 :                 .endpoint
      77            0 :                 .get("proxy_get_role_secret")
      78            0 :                 .header("X-Request-ID", &request_id)
      79            0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
      80            0 :                 .query(&[("session_id", ctx.session_id)])
      81            0 :                 .query(&[
      82            0 :                     ("application_name", application_name.as_str()),
      83            0 :                     ("project", user_info.endpoint.as_str()),
      84            0 :                     ("role", user_info.user.as_str()),
      85            0 :                 ])
      86            0 :                 .build()?;
      87              : 
      88            0 :             info!(url = request.url().as_str(), "sending http request");
      89            0 :             let start = Instant::now();
      90            0 :             let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
      91            0 :             let response = self.endpoint.execute(request).await?;
      92            0 :             drop(pause);
      93            0 :             info!(duration = ?start.elapsed(), "received http response");
      94            0 :             let body = match parse_body::<GetRoleSecret>(response).await {
      95            0 :                 Ok(body) => body,
      96              :                 // Error 404 is special: it's ok not to have a secret.
      97              :                 // TODO(anna): retry
      98            0 :                 Err(e) => {
      99            0 :                     if e.get_reason().is_not_found() {
     100            0 :                         return Ok(AuthInfo::default());
     101              :                     } else {
     102            0 :                         return Err(e.into());
     103              :                     }
     104              :                 }
     105              :             };
     106              : 
     107            0 :             let secret = if body.role_secret.is_empty() {
     108            0 :                 None
     109              :             } else {
     110            0 :                 let secret = scram::ServerSecret::parse(&body.role_secret)
     111            0 :                     .map(AuthSecret::Scram)
     112            0 :                     .ok_or(GetAuthInfoError::BadSecret)?;
     113            0 :                 Some(secret)
     114              :             };
     115            0 :             let allowed_ips = body.allowed_ips.unwrap_or_default();
     116            0 :             Metrics::get()
     117            0 :                 .proxy
     118            0 :                 .allowed_ips_number
     119            0 :                 .observe(allowed_ips.len() as f64);
     120            0 :             Ok(AuthInfo {
     121            0 :                 secret,
     122            0 :                 allowed_ips,
     123            0 :                 project_id: body.project_id,
     124            0 :             })
     125            0 :         }
     126            0 :         .map_err(crate::error::log_error)
     127            0 :         .instrument(info_span!("http", id = request_id))
     128            0 :         .await
     129            0 :     }
     130              : 
     131            0 :     async fn do_wake_compute(
     132            0 :         &self,
     133            0 :         ctx: &mut RequestMonitoring,
     134            0 :         user_info: &ComputeUserInfo,
     135            0 :     ) -> Result<NodeInfo, WakeComputeError> {
     136            0 :         let request_id = ctx.session_id.to_string();
     137            0 :         let application_name = ctx.console_application_name();
     138            0 :         async {
     139            0 :             let mut request_builder = self
     140            0 :                 .endpoint
     141            0 :                 .get("proxy_wake_compute")
     142            0 :                 .header("X-Request-ID", &request_id)
     143            0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
     144            0 :                 .query(&[("session_id", ctx.session_id)])
     145            0 :                 .query(&[
     146            0 :                     ("application_name", application_name.as_str()),
     147            0 :                     ("project", user_info.endpoint.as_str()),
     148            0 :                 ]);
     149            0 : 
     150            0 :             let options = user_info.options.to_deep_object();
     151            0 :             if !options.is_empty() {
     152            0 :                 request_builder = request_builder.query(&options);
     153            0 :             }
     154              : 
     155            0 :             let request = request_builder.build()?;
     156              : 
     157            0 :             info!(url = request.url().as_str(), "sending http request");
     158            0 :             let start = Instant::now();
     159            0 :             let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
     160            0 :             let response = self.endpoint.execute(request).await?;
     161            0 :             drop(pause);
     162            0 :             info!(duration = ?start.elapsed(), "received http response");
     163            0 :             let body = parse_body::<WakeCompute>(response).await?;
     164              : 
     165              :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     166            0 :             let (host, port) = match parse_host_port(&body.address) {
     167            0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     168            0 :                 Some(x) => x,
     169            0 :             };
     170            0 : 
     171            0 :             // Don't set anything but host and port! This config will be cached.
     172            0 :             // We'll set username and such later using the startup message.
     173            0 :             // TODO: add more type safety (in progress).
     174            0 :             let mut config = compute::ConnCfg::new();
     175            0 :             config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
     176            0 : 
     177            0 :             let node = NodeInfo {
     178            0 :                 config,
     179            0 :                 aux: body.aux,
     180            0 :                 allow_self_signed_compute: false,
     181            0 :             };
     182            0 : 
     183            0 :             Ok(node)
     184            0 :         }
     185            0 :         .map_err(crate::error::log_error)
     186            0 :         .instrument(info_span!("http", id = request_id))
     187            0 :         .await
     188            0 :     }
     189              : }
     190              : 
     191              : impl super::Api for Api {
     192            0 :     #[tracing::instrument(skip_all)]
     193              :     async fn get_role_secret(
     194              :         &self,
     195              :         ctx: &mut RequestMonitoring,
     196              :         user_info: &ComputeUserInfo,
     197              :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     198              :         let normalized_ep = &user_info.endpoint.normalize();
     199              :         let user = &user_info.user;
     200              :         if let Some(role_secret) = self
     201              :             .caches
     202              :             .project_info
     203              :             .get_role_secret(normalized_ep, user)
     204              :         {
     205              :             return Ok(role_secret);
     206              :         }
     207              :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     208              :         if let Some(project_id) = auth_info.project_id {
     209              :             let normalized_ep_int = normalized_ep.into();
     210              :             self.caches.project_info.insert_role_secret(
     211              :                 project_id,
     212              :                 normalized_ep_int,
     213              :                 user.into(),
     214              :                 auth_info.secret.clone(),
     215              :             );
     216              :             self.caches.project_info.insert_allowed_ips(
     217              :                 project_id,
     218              :                 normalized_ep_int,
     219              :                 Arc::new(auth_info.allowed_ips),
     220              :             );
     221              :             ctx.set_project_id(project_id);
     222              :         }
     223              :         // When we just got a secret, we don't need to invalidate it.
     224              :         Ok(Cached::new_uncached(auth_info.secret))
     225              :     }
     226              : 
     227            0 :     async fn get_allowed_ips_and_secret(
     228            0 :         &self,
     229            0 :         ctx: &mut RequestMonitoring,
     230            0 :         user_info: &ComputeUserInfo,
     231            0 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     232            0 :         let normalized_ep = &user_info.endpoint.normalize();
     233            0 :         if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
     234            0 :             Metrics::get()
     235            0 :                 .proxy
     236            0 :                 .allowed_ips_cache_misses
     237            0 :                 .inc(CacheOutcome::Hit);
     238            0 :             return Ok((allowed_ips, None));
     239            0 :         }
     240            0 :         Metrics::get()
     241            0 :             .proxy
     242            0 :             .allowed_ips_cache_misses
     243            0 :             .inc(CacheOutcome::Miss);
     244            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     245            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     246            0 :         let user = &user_info.user;
     247            0 :         if let Some(project_id) = auth_info.project_id {
     248            0 :             let normalized_ep_int = normalized_ep.into();
     249            0 :             self.caches.project_info.insert_role_secret(
     250            0 :                 project_id,
     251            0 :                 normalized_ep_int,
     252            0 :                 user.into(),
     253            0 :                 auth_info.secret.clone(),
     254            0 :             );
     255            0 :             self.caches.project_info.insert_allowed_ips(
     256            0 :                 project_id,
     257            0 :                 normalized_ep_int,
     258            0 :                 allowed_ips.clone(),
     259            0 :             );
     260            0 :             ctx.set_project_id(project_id);
     261            0 :         }
     262            0 :         Ok((
     263            0 :             Cached::new_uncached(allowed_ips),
     264            0 :             Some(Cached::new_uncached(auth_info.secret)),
     265            0 :         ))
     266            0 :     }
     267              : 
     268            0 :     #[tracing::instrument(skip_all)]
     269              :     async fn wake_compute(
     270              :         &self,
     271              :         ctx: &mut RequestMonitoring,
     272              :         user_info: &ComputeUserInfo,
     273              :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     274              :         let key = user_info.endpoint_cache_key();
     275              : 
     276              :         // Every time we do a wakeup http request, the compute node will stay up
     277              :         // for some time (highly depends on the console's scale-to-zero policy);
     278              :         // The connection info remains the same during that period of time,
     279              :         // which means that we might cache it to reduce the load and latency.
     280              :         if let Some(cached) = self.caches.node_info.get(&key) {
     281              :             info!(key = &*key, "found cached compute node info");
     282              :             ctx.set_project(cached.aux.clone());
     283              :             return Ok(cached);
     284              :         }
     285              : 
     286              :         let permit = self.locks.get_permit(&key).await?;
     287              : 
     288              :         // after getting back a permit - it's possible the cache was filled
     289              :         // double check
     290              :         if permit.should_check_cache() {
     291              :             if let Some(cached) = self.caches.node_info.get(&key) {
     292              :                 info!(key = &*key, "found cached compute node info");
     293              :                 ctx.set_project(cached.aux.clone());
     294              :                 return Ok(cached);
     295              :             }
     296              :         }
     297              : 
     298              :         // check rate limit
     299              :         if !self
     300              :             .wake_compute_endpoint_rate_limiter
     301              :             .check(user_info.endpoint.normalize_intern(), 1)
     302              :         {
     303              :             info!(key = &*key, "found cached compute node info");
     304              :             return Err(WakeComputeError::TooManyConnections);
     305              :         }
     306              : 
     307              :         let mut node = permit.release_result(self.do_wake_compute(ctx, user_info).await)?;
     308              :         ctx.set_project(node.aux.clone());
     309              :         let cold_start_info = node.aux.cold_start_info;
     310              :         info!("woken up a compute node");
     311              : 
     312              :         // store the cached node as 'warm'
     313              :         node.aux.cold_start_info = ColdStartInfo::WarmCached;
     314              :         let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
     315              :         cached.aux.cold_start_info = cold_start_info;
     316              : 
     317              :         info!(key = &*key, "created a cache entry for compute node info");
     318              : 
     319              :         Ok(cached)
     320              :     }
     321              : }
     322              : 
     323              : /// Parse http response body, taking status code into account.
     324            0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     325            0 :     response: http::Response,
     326            0 : ) -> Result<T, ApiError> {
     327            0 :     let status = response.status();
     328            0 :     if status.is_success() {
     329              :         // We shouldn't log raw body because it may contain secrets.
     330            0 :         info!("request succeeded, processing the body");
     331            0 :         return Ok(response.json().await?);
     332            0 :     }
     333            0 :     let s = response.bytes().await?;
     334              :     // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
     335            0 :     info!("response_error plaintext: {:?}", s);
     336              : 
     337              :     // Don't throw an error here because it's not as important
     338              :     // as the fact that the request itself has failed.
     339            0 :     let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
     340            0 :         warn!("failed to parse error body: {e}");
     341            0 :         ConsoleError {
     342            0 :             error: "reason unclear (malformed error message)".into(),
     343            0 :             http_status_code: status,
     344            0 :             status: None,
     345            0 :         }
     346            0 :     });
     347            0 :     body.http_status_code = status;
     348            0 : 
     349            0 :     error!("console responded with an error ({status}): {body:?}");
     350            0 :     Err(ApiError::Console(body))
     351            0 : }
     352              : 
     353            6 : fn parse_host_port(input: &str) -> Option<(&str, u16)> {
     354            6 :     let (host, port) = input.rsplit_once(':')?;
     355            6 :     let ipv6_brackets: &[_] = &['[', ']'];
     356            6 :     Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
     357            6 : }
     358              : 
     359              : #[cfg(test)]
     360              : mod tests {
     361              :     use super::*;
     362              : 
     363              :     #[test]
     364            2 :     fn test_parse_host_port_v4() {
     365            2 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     366            2 :         assert_eq!(host, "127.0.0.1");
     367            2 :         assert_eq!(port, 5432);
     368            2 :     }
     369              : 
     370              :     #[test]
     371            2 :     fn test_parse_host_port_v6() {
     372            2 :         let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
     373            2 :         assert_eq!(host, "2001:db8::1");
     374            2 :         assert_eq!(port, 5432);
     375            2 :     }
     376              : 
     377              :     #[test]
     378            2 :     fn test_parse_host_port_url() {
     379            2 :         let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
     380            2 :             .expect("failed to parse");
     381            2 :         assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
     382            2 :         assert_eq!(port, 5432);
     383            2 :     }
     384              : }
        

Generated by: LCOV version 2.1-beta