LCOV - code coverage report
Current view: top level - proxy/src/console/provider - neon.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 44.5 % 236 105
Test Date: 2024-02-07 07:37:29 Functions: 55.0 % 40 22

            Line data    Source code
       1              : //! Production console backend.
       2              : 
       3              : use super::{
       4              :     super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
       5              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       6              :     ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
       7              :     NodeInfo,
       8              : };
       9              : use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
      10              : use crate::{
      11              :     cache::Cached,
      12              :     context::RequestMonitoring,
      13              :     metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
      14              : };
      15              : use async_trait::async_trait;
      16              : use futures::TryFutureExt;
      17              : use std::sync::Arc;
      18              : use tokio::time::Instant;
      19              : use tokio_postgres::config::SslMode;
      20              : use tracing::{error, info, info_span, warn, Instrument};
      21              : 
      22              : pub struct Api {
      23              :     endpoint: http::Endpoint,
      24              :     pub caches: &'static ApiCaches,
      25              :     locks: &'static ApiLocks,
      26              :     jwt: String,
      27              : }
      28              : 
      29              : impl Api {
      30              :     /// Construct an API object containing the auth parameters.
      31            1 :     pub fn new(
      32            1 :         endpoint: http::Endpoint,
      33            1 :         caches: &'static ApiCaches,
      34            1 :         locks: &'static ApiLocks,
      35            1 :     ) -> Self {
      36            1 :         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
      37            0 :             Ok(v) => v,
      38            1 :             Err(_) => "".to_string(),
      39              :         };
      40            1 :         Self {
      41            1 :             endpoint,
      42            1 :             caches,
      43            1 :             locks,
      44            1 :             jwt,
      45            1 :         }
      46            1 :     }
      47              : 
      48            1 :     pub fn url(&self) -> &str {
      49            1 :         self.endpoint.url().as_str()
      50            1 :     }
      51              : 
      52            4 :     async fn do_get_auth_info(
      53            4 :         &self,
      54            4 :         ctx: &mut RequestMonitoring,
      55            4 :         user_info: &ComputeUserInfo,
      56            4 :     ) -> Result<AuthInfo, GetAuthInfoError> {
      57            4 :         let request_id = uuid::Uuid::new_v4().to_string();
      58            4 :         let application_name = ctx.console_application_name();
      59            4 :         async {
      60            4 :             let request = self
      61            4 :                 .endpoint
      62            4 :                 .get("proxy_get_role_secret")
      63            4 :                 .header("X-Request-ID", &request_id)
      64            4 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
      65            4 :                 .query(&[("session_id", ctx.session_id)])
      66            4 :                 .query(&[
      67            4 :                     ("application_name", application_name.as_str()),
      68            4 :                     ("project", user_info.endpoint.as_str()),
      69            4 :                     ("role", user_info.user.as_str()),
      70            4 :                 ])
      71            4 :                 .build()?;
      72              : 
      73            4 :             info!(url = request.url().as_str(), "sending http request");
      74            4 :             let start = Instant::now();
      75           12 :             let response = self.endpoint.execute(request).await?;
      76            3 :             info!(duration = ?start.elapsed(), "received http response");
      77            3 :             let body = match parse_body::<GetRoleSecret>(response).await {
      78            0 :                 Ok(body) => body,
      79              :                 // Error 404 is special: it's ok not to have a secret.
      80            3 :                 Err(e) => match e.http_status_code() {
      81            0 :                     Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
      82            3 :                     _otherwise => return Err(e.into()),
      83              :                 },
      84              :             };
      85              : 
      86            0 :             let secret = if body.role_secret.is_empty() {
      87            0 :                 None
      88              :             } else {
      89            0 :                 let secret = scram::ServerSecret::parse(&body.role_secret)
      90            0 :                     .map(AuthSecret::Scram)
      91            0 :                     .ok_or(GetAuthInfoError::BadSecret)?;
      92            0 :                 Some(secret)
      93              :             };
      94            0 :             let allowed_ips = body.allowed_ips.unwrap_or_default();
      95            0 :             ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
      96            0 :             Ok(AuthInfo {
      97            0 :                 secret,
      98            0 :                 allowed_ips,
      99            0 :                 project_id: body.project_id,
     100            0 :             })
     101            4 :         }
     102            4 :         .map_err(crate::error::log_error)
     103            4 :         .instrument(info_span!("http", id = request_id))
     104           12 :         .await
     105            4 :     }
     106              : 
     107            0 :     async fn do_wake_compute(
     108            0 :         &self,
     109            0 :         ctx: &mut RequestMonitoring,
     110            0 :         user_info: &ComputeUserInfo,
     111            0 :     ) -> Result<NodeInfo, WakeComputeError> {
     112            0 :         let request_id = uuid::Uuid::new_v4().to_string();
     113            0 :         let application_name = ctx.console_application_name();
     114            0 :         async {
     115            0 :             let mut request_builder = self
     116            0 :                 .endpoint
     117            0 :                 .get("proxy_wake_compute")
     118            0 :                 .header("X-Request-ID", &request_id)
     119            0 :                 .header("Authorization", format!("Bearer {}", &self.jwt))
     120            0 :                 .query(&[("session_id", ctx.session_id)])
     121            0 :                 .query(&[
     122            0 :                     ("application_name", application_name.as_str()),
     123            0 :                     ("project", user_info.endpoint.as_str()),
     124            0 :                 ]);
     125            0 : 
     126            0 :             let options = user_info.options.to_deep_object();
     127            0 :             if !options.is_empty() {
     128            0 :                 request_builder = request_builder.query(&options);
     129            0 :             }
     130              : 
     131            0 :             let request = request_builder.build()?;
     132              : 
     133            0 :             info!(url = request.url().as_str(), "sending http request");
     134            0 :             let start = Instant::now();
     135            0 :             let response = self.endpoint.execute(request).await?;
     136            0 :             info!(duration = ?start.elapsed(), "received http response");
     137            0 :             let body = parse_body::<WakeCompute>(response).await?;
     138              : 
     139              :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     140            0 :             let (host, port) = match parse_host_port(&body.address) {
     141            0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     142            0 :                 Some(x) => x,
     143            0 :             };
     144            0 : 
     145            0 :             // Don't set anything but host and port! This config will be cached.
     146            0 :             // We'll set username and such later using the startup message.
     147            0 :             // TODO: add more type safety (in progress).
     148            0 :             let mut config = compute::ConnCfg::new();
     149            0 :             config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
     150            0 : 
     151            0 :             let node = NodeInfo {
     152            0 :                 config,
     153            0 :                 aux: body.aux,
     154            0 :                 allow_self_signed_compute: false,
     155            0 :             };
     156            0 : 
     157            0 :             Ok(node)
     158            0 :         }
     159            0 :         .map_err(crate::error::log_error)
     160            0 :         .instrument(info_span!("http", id = request_id))
     161            0 :         .await
     162            0 :     }
     163              : }
     164              : 
     165              : #[async_trait]
     166              : impl super::Api for Api {
     167            0 :     #[tracing::instrument(skip_all)]
     168              :     async fn get_role_secret(
     169              :         &self,
     170              :         ctx: &mut RequestMonitoring,
     171              :         user_info: &ComputeUserInfo,
     172            0 :     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
     173            0 :         let ep = &user_info.endpoint;
     174            0 :         let user = &user_info.user;
     175            0 :         if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
     176            0 :             return Ok(role_secret);
     177            0 :         }
     178            0 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     179            0 :         if let Some(project_id) = auth_info.project_id {
     180            0 :             self.caches.project_info.insert_role_secret(
     181            0 :                 &project_id,
     182            0 :                 ep,
     183            0 :                 user,
     184            0 :                 auth_info.secret.clone(),
     185            0 :             );
     186            0 :             self.caches.project_info.insert_allowed_ips(
     187            0 :                 &project_id,
     188            0 :                 ep,
     189            0 :                 Arc::new(auth_info.allowed_ips),
     190            0 :             );
     191            0 :         }
     192              :         // When we just got a secret, we don't need to invalidate it.
     193            0 :         Ok(Cached::new_uncached(auth_info.secret))
     194            0 :     }
     195              : 
     196            4 :     async fn get_allowed_ips_and_secret(
     197            4 :         &self,
     198            4 :         ctx: &mut RequestMonitoring,
     199            4 :         user_info: &ComputeUserInfo,
     200            4 :     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
     201            4 :         let ep = &user_info.endpoint;
     202            4 :         if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
     203            0 :             ALLOWED_IPS_BY_CACHE_OUTCOME
     204            0 :                 .with_label_values(&["hit"])
     205            0 :                 .inc();
     206            0 :             return Ok((allowed_ips, None));
     207            4 :         }
     208            4 :         ALLOWED_IPS_BY_CACHE_OUTCOME
     209            4 :             .with_label_values(&["miss"])
     210            4 :             .inc();
     211           12 :         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
     212            0 :         let allowed_ips = Arc::new(auth_info.allowed_ips);
     213            0 :         let user = &user_info.user;
     214            0 :         if let Some(project_id) = auth_info.project_id {
     215            0 :             self.caches.project_info.insert_role_secret(
     216            0 :                 &project_id,
     217            0 :                 ep,
     218            0 :                 user,
     219            0 :                 auth_info.secret.clone(),
     220            0 :             );
     221            0 :             self.caches
     222            0 :                 .project_info
     223            0 :                 .insert_allowed_ips(&project_id, ep, allowed_ips.clone());
     224            0 :         }
     225            0 :         Ok((
     226            0 :             Cached::new_uncached(allowed_ips),
     227            0 :             Some(Cached::new_uncached(auth_info.secret)),
     228            0 :         ))
     229            8 :     }
     230              : 
     231            0 :     #[tracing::instrument(skip_all)]
     232              :     async fn wake_compute(
     233              :         &self,
     234              :         ctx: &mut RequestMonitoring,
     235              :         user_info: &ComputeUserInfo,
     236            0 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     237            0 :         let key = user_info.endpoint_cache_key();
     238              : 
     239              :         // Every time we do a wakeup http request, the compute node will stay up
     240              :         // for some time (highly depends on the console's scale-to-zero policy);
     241              :         // The connection info remains the same during that period of time,
     242              :         // which means that we might cache it to reduce the load and latency.
     243            0 :         if let Some(cached) = self.caches.node_info.get(&key) {
     244            0 :             info!(key = &*key, "found cached compute node info");
     245            0 :             return Ok(cached);
     246            0 :         }
     247              : 
     248            0 :         let permit = self.locks.get_wake_compute_permit(&key).await?;
     249              : 
     250              :         // after getting back a permit - it's possible the cache was filled
     251              :         // double check
     252            0 :         if permit.should_check_cache() {
     253            0 :             if let Some(cached) = self.caches.node_info.get(&key) {
     254            0 :                 info!(key = &*key, "found cached compute node info");
     255            0 :                 return Ok(cached);
     256            0 :             }
     257            0 :         }
     258              : 
     259            0 :         let node = self.do_wake_compute(ctx, user_info).await?;
     260            0 :         let (_, cached) = self.caches.node_info.insert(key.clone(), node);
     261            0 :         info!(key = &*key, "created a cache entry for compute node info");
     262              : 
     263            0 :         Ok(cached)
     264            0 :     }
     265              : }
     266              : 
     267              : /// Parse http response body, taking status code into account.
     268            3 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     269            3 :     response: http::Response,
     270            3 : ) -> Result<T, ApiError> {
     271            3 :     let status = response.status();
     272            3 :     if status.is_success() {
     273              :         // We shouldn't log raw body because it may contain secrets.
     274            1 :         info!("request succeeded, processing the body");
     275            1 :         return Ok(response.json().await?);
     276            2 :     }
     277              : 
     278              :     // Don't throw an error here because it's not as important
     279              :     // as the fact that the request itself has failed.
     280            2 :     let body = response.json().await.unwrap_or_else(|e| {
     281            2 :         warn!("failed to parse error body: {e}");
     282            2 :         ConsoleError {
     283            2 :             error: "reason unclear (malformed error message)".into(),
     284            2 :         }
     285            2 :     });
     286            2 : 
     287            2 :     let text = body.error;
     288            2 :     error!("console responded with an error ({status}): {text}");
     289            2 :     Err(ApiError::Console { status, text })
     290            3 : }
     291              : 
     292            6 : fn parse_host_port(input: &str) -> Option<(&str, u16)> {
     293            6 :     let (host, port) = input.rsplit_once(':')?;
     294            6 :     let ipv6_brackets: &[_] = &['[', ']'];
     295            6 :     Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
     296            6 : }
     297              : 
     298              : #[cfg(test)]
     299              : mod tests {
     300              :     use super::*;
     301              : 
     302            2 :     #[test]
     303            2 :     fn test_parse_host_port_v4() {
     304            2 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     305            2 :         assert_eq!(host, "127.0.0.1");
     306            2 :         assert_eq!(port, 5432);
     307            2 :     }
     308              : 
     309            2 :     #[test]
     310            2 :     fn test_parse_host_port_v6() {
     311            2 :         let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
     312            2 :         assert_eq!(host, "2001:db8::1");
     313            2 :         assert_eq!(port, 5432);
     314            2 :     }
     315              : 
     316            2 :     #[test]
     317            2 :     fn test_parse_host_port_url() {
     318            2 :         let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
     319            2 :             .expect("failed to parse");
     320            2 :         assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
     321            2 :         assert_eq!(port, 5432);
     322            2 :     }
     323              : }
        

Generated by: LCOV version 2.1-beta