LCOV - differential code coverage report
Current view: top level - proxy/src/console/provider - neon.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 46.6 % 223 104 119 104
Current Date: 2024-01-09 02:06:09 Functions: 53.7 % 41 22 19 22
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Production console backend.
       2                 : 
       3                 : use super::{
       4                 :     super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
       5                 :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       6                 :     ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra,
       7                 :     NodeInfo,
       8                 : };
       9                 : use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
      10                 : use crate::{
      11                 :     context::RequestMonitoring,
      12                 :     metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
      13                 : };
      14                 : use async_trait::async_trait;
      15                 : use futures::TryFutureExt;
      16                 : use itertools::Itertools;
      17                 : use std::sync::Arc;
      18                 : use tokio::time::Instant;
      19                 : use tokio_postgres::config::SslMode;
      20                 : use tracing::{error, info, info_span, warn, Instrument};
      21                 : 
      22 UBC           0 : #[derive(Clone)]
      23                 : pub struct Api {
      24                 :     endpoint: http::Endpoint,
      25                 :     caches: &'static ApiCaches,
      26                 :     locks: &'static ApiLocks,
      27                 :     jwt: String,
      28                 : }
      29                 : 
      30                 : impl Api {
      31                 :     /// Construct an API object containing the auth parameters.
      32 CBC           1 :     pub fn new(
      33               1 :         endpoint: http::Endpoint,
      34               1 :         caches: &'static ApiCaches,
      35               1 :         locks: &'static ApiLocks,
      36               1 :     ) -> Self {
      37               1 :         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
      38 UBC           0 :             Ok(v) => v,
      39 CBC           1 :             Err(_) => "".to_string(),
      40                 :         };
      41               1 :         Self {
      42               1 :             endpoint,
      43               1 :             caches,
      44               1 :             locks,
      45               1 :             jwt,
      46               1 :         }
      47               1 :     }
      48                 : 
      49               1 :     pub fn url(&self) -> &str {
      50               1 :         self.endpoint.url().as_str()
      51               1 :     }
      52                 : 
      53               4 :     async fn do_get_auth_info(
      54               4 :         &self,
      55               4 :         ctx: &mut RequestMonitoring,
      56               4 :         creds: &ComputeUserInfo,
      57               4 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      58               4 :         let request_id = uuid::Uuid::new_v4().to_string();
      59               4 :         let application_name = ctx.console_application_name();
      60               4 :         async {
      61               4 :             let request = self
      62               4 :                 .endpoint
      63               4 :                 .get("proxy_get_role_secret")
      64               4 :                 .header("X-Request-ID", &request_id)
      65               4 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
      66               4 :                 .query(&[("session_id", ctx.session_id)])
      67               4 :                 .query(&[
      68               4 :                     ("application_name", application_name.as_str()),
      69               4 :                     ("project", creds.endpoint.as_str()),
      70               4 :                     ("role", creds.inner.user.as_str()),
      71               4 :                 ])
      72               4 :                 .build()?;
      73                 : 
      74               4 :             info!(url = request.url().as_str(), "sending http request");
      75               4 :             let start = Instant::now();
      76              12 :             let response = self.endpoint.execute(request).await?;
      77               3 :             info!(duration = ?start.elapsed(), "received http response");
      78               3 :             let body = match parse_body::<GetRoleSecret>(response).await {
      79 UBC           0 :                 Ok(body) => body,
      80                 :                 // Error 404 is special: it's ok not to have a secret.
      81 CBC           3 :                 Err(e) => match e.http_status_code() {
      82 UBC           0 :                     Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
      83 CBC           3 :                     _otherwise => return Err(e.into()),
      84                 :                 },
      85                 :             };
      86                 : 
      87 UBC           0 :             let secret = scram::ServerSecret::parse(&body.role_secret)
      88               0 :                 .map(AuthSecret::Scram)
      89               0 :                 .ok_or(GetAuthInfoError::BadSecret)?;
      90               0 :             let allowed_ips = body
      91               0 :                 .allowed_ips
      92               0 :                 .into_iter()
      93               0 :                 .flatten()
      94               0 :                 .map(String::from)
      95               0 :                 .collect_vec();
      96               0 :             ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
      97               0 :             Ok(AuthInfo {
      98               0 :                 secret: Some(secret),
      99               0 :                 allowed_ips,
     100               0 :             })
     101 CBC           4 :         }
     102               4 :         .map_err(crate::error::log_error)
     103               4 :         .instrument(info_span!("http", id = request_id))
     104              12 :         .await
     105               4 :     }
     106                 : 
     107 UBC           0 :     async fn do_wake_compute(
     108               0 :         &self,
     109               0 :         ctx: &mut RequestMonitoring,
     110               0 :         extra: &ConsoleReqExtra,
     111               0 :         creds: &ComputeUserInfo,
     112               0 :     ) -> Result<NodeInfo, WakeComputeError> {
     113               0 :         let request_id = uuid::Uuid::new_v4().to_string();
     114               0 :         let application_name = ctx.console_application_name();
     115               0 :         async {
     116               0 :             let mut request_builder = self
     117               0 :                 .endpoint
     118               0 :                 .get("proxy_wake_compute")
     119               0 :                 .header("X-Request-ID", &request_id)
     120               0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
     121               0 :                 .query(&[("session_id", ctx.session_id)])
     122               0 :                 .query(&[
     123               0 :                     ("application_name", application_name.as_str()),
     124               0 :                     ("project", creds.endpoint.as_str()),
     125               0 :                 ]);
     126                 : 
     127               0 :             request_builder = if extra.options.is_empty() {
     128               0 :                 request_builder
     129                 :             } else {
     130               0 :                 request_builder.query(&extra.options_as_deep_object())
     131                 :             };
     132               0 :             let request = request_builder.build()?;
     133                 : 
     134               0 :             info!(url = request.url().as_str(), "sending http request");
     135               0 :             let start = Instant::now();
     136               0 :             let response = self.endpoint.execute(request).await?;
     137               0 :             info!(duration = ?start.elapsed(), "received http response");
     138               0 :             let body = parse_body::<WakeCompute>(response).await?;
     139                 : 
     140                 :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     141               0 :             let (host, port) = match parse_host_port(&body.address) {
     142               0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     143               0 :                 Some(x) => x,
     144               0 :             };
     145               0 : 
     146               0 :             // Don't set anything but host and port! This config will be cached.
     147               0 :             // We'll set username and such later using the startup message.
     148               0 :             // TODO: add more type safety (in progress).
     149               0 :             let mut config = compute::ConnCfg::new();
     150               0 :             config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
     151               0 : 
     152               0 :             let node = NodeInfo {
     153               0 :                 config,
     154               0 :                 aux: body.aux,
     155               0 :                 allow_self_signed_compute: false,
     156               0 :             };
     157               0 : 
     158               0 :             Ok(node)
     159               0 :         }
     160               0 :         .map_err(crate::error::log_error)
     161               0 :         .instrument(info_span!("http", id = request_id))
     162               0 :         .await
     163               0 :     }
     164                 : }
     165                 : 
     166                 : #[async_trait]
     167                 : impl super::Api for Api {
     168               0 :     #[tracing::instrument(skip_all)]
     169                 :     async fn get_role_secret(
     170                 :         &self,
     171                 :         ctx: &mut RequestMonitoring,
     172                 :         creds: &ComputeUserInfo,
     173               0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     174               0 :         let ep = creds.endpoint.clone();
     175               0 :         let user = creds.inner.user.clone();
     176               0 :         if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
     177               0 :             return Ok(role_secret);
     178               0 :         }
     179               0 :         let auth_info = self.do_get_auth_info(ctx, creds).await?;
     180               0 :         let (_, secret) = self
     181               0 :             .caches
     182               0 :             .role_secret
     183               0 :             .insert((ep.clone(), user), auth_info.secret.clone());
     184               0 :         self.caches
     185               0 :             .allowed_ips
     186               0 :             .insert(ep, Arc::new(auth_info.allowed_ips));
     187               0 :         Ok(secret)
     188               0 :     }
     189                 : 
     190 CBC           4 :     async fn get_allowed_ips(
     191               4 :         &self,
     192               4 :         ctx: &mut RequestMonitoring,
     193               4 :         creds: &ComputeUserInfo,
     194               4 :     ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
     195               4 :         if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
     196 UBC           0 :             ALLOWED_IPS_BY_CACHE_OUTCOME
     197               0 :                 .with_label_values(&["hit"])
     198               0 :                 .inc();
     199               0 :             return Ok(Arc::new(allowed_ips.to_vec()));
     200 CBC           4 :         }
     201               4 :         ALLOWED_IPS_BY_CACHE_OUTCOME
     202               4 :             .with_label_values(&["miss"])
     203               4 :             .inc();
     204              12 :         let auth_info = self.do_get_auth_info(ctx, creds).await?;
     205 UBC           0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     206               0 :         let ep = creds.endpoint.clone();
     207               0 :         let user = creds.inner.user.clone();
     208               0 :         self.caches
     209               0 :             .role_secret
     210               0 :             .insert((ep.clone(), user), auth_info.secret);
     211               0 :         self.caches.allowed_ips.insert(ep, allowed_ips.clone());
     212               0 :         Ok(allowed_ips)
     213 CBC           8 :     }
     214                 : 
     215 UBC           0 :     #[tracing::instrument(skip_all)]
     216                 :     async fn wake_compute(
     217                 :         &self,
     218                 :         ctx: &mut RequestMonitoring,
     219                 :         extra: &ConsoleReqExtra,
     220                 :         creds: &ComputeUserInfo,
     221               0 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     222               0 :         let key: &str = &creds.inner.cache_key;
     223                 : 
     224                 :         // Every time we do a wakeup http request, the compute node will stay up
     225                 :         // for some time (highly depends on the console's scale-to-zero policy);
     226                 :         // The connection info remains the same during that period of time,
     227                 :         // which means that we might cache it to reduce the load and latency.
     228               0 :         if let Some(cached) = self.caches.node_info.get(key) {
     229               0 :             info!(key = key, "found cached compute node info");
     230               0 :             return Ok(cached);
     231               0 :         }
     232               0 : 
     233               0 :         let key: Arc<str> = key.into();
     234                 : 
     235               0 :         let permit = self.locks.get_wake_compute_permit(&key).await?;
     236                 : 
     237                 :         // after getting back a permit - it's possible the cache was filled
     238                 :         // double check
     239               0 :         if permit.should_check_cache() {
     240               0 :             if let Some(cached) = self.caches.node_info.get(&key) {
     241               0 :                 info!(key = &*key, "found cached compute node info");
     242               0 :                 return Ok(cached);
     243               0 :             }
     244               0 :         }
     245                 : 
     246               0 :         let node = self.do_wake_compute(ctx, extra, creds).await?;
     247               0 :         let (_, cached) = self.caches.node_info.insert(key.clone(), node);
     248               0 :         info!(key = &*key, "created a cache entry for compute node info");
     249                 : 
     250               0 :         Ok(cached)
     251               0 :     }
     252                 : }
     253                 : 
     254                 : /// Parse http response body, taking status code into account.
     255 CBC           3 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     256               3 :     response: http::Response,
     257               3 : ) -> Result<T, ApiError> {
     258               3 :     let status = response.status();
     259               3 :     if status.is_success() {
     260                 :         // We shouldn't log raw body because it may contain secrets.
     261               1 :         info!("request succeeded, processing the body");
     262               1 :         return Ok(response.json().await?);
     263               2 :     }
     264                 : 
     265                 :     // Don't throw an error here because it's not as important
     266                 :     // as the fact that the request itself has failed.
     267               2 :     let body = response.json().await.unwrap_or_else(|e| {
     268               2 :         warn!("failed to parse error body: {e}");
     269               2 :         ConsoleError {
     270               2 :             error: "reason unclear (malformed error message)".into(),
     271               2 :         }
     272               2 :     });
     273               2 : 
     274               2 :     let text = body.error;
     275               2 :     error!("console responded with an error ({status}): {text}");
     276               2 :     Err(ApiError::Console { status, text })
     277               3 : }
     278                 : 
     279               3 : fn parse_host_port(input: &str) -> Option<(&str, u16)> {
     280               3 :     let (host, port) = input.rsplit_once(':')?;
     281               3 :     let ipv6_brackets: &[_] = &['[', ']'];
     282               3 :     Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
     283               3 : }
     284                 : 
     285                 : #[cfg(test)]
     286                 : mod tests {
     287                 :     use super::*;
     288                 : 
     289               1 :     #[test]
     290               1 :     fn test_parse_host_port_v4() {
     291               1 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     292               1 :         assert_eq!(host, "127.0.0.1");
     293               1 :         assert_eq!(port, 5432);
     294               1 :     }
     295                 : 
     296               1 :     #[test]
     297               1 :     fn test_parse_host_port_v6() {
     298               1 :         let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
     299               1 :         assert_eq!(host, "2001:db8::1");
     300               1 :         assert_eq!(port, 5432);
     301               1 :     }
     302                 : 
     303               1 :     #[test]
     304               1 :     fn test_parse_host_port_url() {
     305               1 :         let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
     306               1 :             .expect("failed to parse");
     307               1 :         assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
     308               1 :         assert_eq!(port, 5432);
     309               1 :     }
     310                 : }
        

Generated by: LCOV version 2.1-beta