LCOV - code coverage report
Current view: top level - proxy/src/console/provider - neon.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 7.0 % 143 10
Test Date: 2023-09-06 10:18:01 Functions: 8.3 % 36 3

            Line data    Source code
       1              : //! Production console backend.
       2              : 
       3              : use super::{
       4              :     super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
       5              :     errors::{ApiError, GetAuthInfoError, WakeComputeError},
       6              :     ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
       7              : };
       8              : use crate::{auth::ClientCredentials, compute, http, scram};
       9              : use async_trait::async_trait;
      10              : use futures::TryFutureExt;
      11              : use std::net::SocketAddr;
      12              : use tokio::time::Instant;
      13              : use tokio_postgres::config::SslMode;
      14              : use tracing::{error, info, info_span, warn, Instrument};
      15              : 
      16            0 : #[derive(Clone)]
      17              : pub struct Api {
      18              :     endpoint: http::Endpoint,
      19              :     caches: &'static ApiCaches,
      20              :     jwt: String,
      21              : }
      22              : 
      23              : impl Api {
      24              :     /// Construct an API object containing the auth parameters.
      25            0 :     pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
      26            0 :         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
      27            0 :             Ok(v) => v,
      28            0 :             Err(_) => "".to_string(),
      29              :         };
      30            0 :         Self {
      31            0 :             endpoint,
      32            0 :             caches,
      33            0 :             jwt,
      34            0 :         }
      35            0 :     }
      36              : 
      37            0 :     pub fn url(&self) -> &str {
      38            0 :         self.endpoint.url().as_str()
      39            0 :     }
      40              : 
      41            0 :     async fn do_get_auth_info(
      42            0 :         &self,
      43            0 :         extra: &ConsoleReqExtra<'_>,
      44            0 :         creds: &ClientCredentials<'_>,
      45            0 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
      46            0 :         let request_id = uuid::Uuid::new_v4().to_string();
      47            0 :         async {
      48            0 :             let request = self
      49            0 :                 .endpoint
      50            0 :                 .get("proxy_get_role_secret")
      51            0 :                 .header("X-Request-ID", &request_id)
      52            0 :                 .header("Authorization", &self.jwt)
      53            0 :                 .query(&[("session_id", extra.session_id)])
      54            0 :                 .query(&[
      55            0 :                     ("application_name", extra.application_name),
      56            0 :                     ("project", Some(creds.project().expect("impossible"))),
      57            0 :                     ("role", Some(creds.user)),
      58            0 :                 ])
      59            0 :                 .build()?;
      60              : 
      61            0 :             info!(url = request.url().as_str(), "sending http request");
      62            0 :             let start = Instant::now();
      63            0 :             let response = self.endpoint.execute(request).await?;
      64            0 :             info!(duration = ?start.elapsed(), "received http response");
      65            0 :             let body = match parse_body::<GetRoleSecret>(response).await {
      66            0 :                 Ok(body) => body,
      67              :                 // Error 404 is special: it's ok not to have a secret.
      68            0 :                 Err(e) => match e.http_status_code() {
      69            0 :                     Some(http::StatusCode::NOT_FOUND) => return Ok(None),
      70            0 :                     _otherwise => return Err(e.into()),
      71              :                 },
      72              :             };
      73              : 
      74            0 :             let secret = scram::ServerSecret::parse(&body.role_secret)
      75            0 :                 .map(AuthInfo::Scram)
      76            0 :                 .ok_or(GetAuthInfoError::BadSecret)?;
      77              : 
      78            0 :             Ok(Some(secret))
      79            0 :         }
      80            0 :         .map_err(crate::error::log_error)
      81            0 :         .instrument(info_span!("http", id = request_id))
      82            0 :         .await
      83            0 :     }
      84              : 
      85            0 :     async fn do_wake_compute(
      86            0 :         &self,
      87            0 :         extra: &ConsoleReqExtra<'_>,
      88            0 :         creds: &ClientCredentials<'_>,
      89            0 :     ) -> Result<NodeInfo, WakeComputeError> {
      90            0 :         let project = creds.project().expect("impossible");
      91            0 :         let request_id = uuid::Uuid::new_v4().to_string();
      92            0 :         async {
      93            0 :             let request = self
      94            0 :                 .endpoint
      95            0 :                 .get("proxy_wake_compute")
      96            0 :                 .header("X-Request-ID", &request_id)
      97            0 :                 .header("Authorization", &self.jwt)
      98            0 :                 .query(&[("session_id", extra.session_id)])
      99            0 :                 .query(&[
     100            0 :                     ("application_name", extra.application_name),
     101            0 :                     ("project", Some(project)),
     102            0 :                 ])
     103            0 :                 .build()?;
     104              : 
     105            0 :             info!(url = request.url().as_str(), "sending http request");
     106            0 :             let start = Instant::now();
     107            0 :             let response = self.endpoint.execute(request).await?;
     108            0 :             info!(duration = ?start.elapsed(), "received http response");
     109            0 :             let body = parse_body::<WakeCompute>(response).await?;
     110              : 
     111              :             // Unfortunately, ownership won't let us use `Option::ok_or` here.
     112            0 :             let (host, port) = match parse_host_port(&body.address) {
     113            0 :                 None => return Err(WakeComputeError::BadComputeAddress(body.address)),
     114            0 :                 Some(x) => x,
     115            0 :             };
     116            0 : 
     117            0 :             // Don't set anything but host and port! This config will be cached.
     118            0 :             // We'll set username and such later using the startup message.
     119            0 :             // TODO: add more type safety (in progress).
     120            0 :             let mut config = compute::ConnCfg::new();
     121            0 :             config.host(&host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
     122            0 : 
     123            0 :             let node = NodeInfo {
     124            0 :                 config,
     125            0 :                 aux: body.aux.into(),
     126            0 :                 allow_self_signed_compute: false,
     127            0 :             };
     128            0 : 
     129            0 :             Ok(node)
     130            0 :         }
     131            0 :         .map_err(crate::error::log_error)
     132            0 :         .instrument(info_span!("http", id = request_id))
     133            0 :         .await
     134            0 :     }
     135              : }
     136              : 
     137              : #[async_trait]
     138              : impl super::Api for Api {
     139            0 :     #[tracing::instrument(skip_all)]
     140              :     async fn get_auth_info(
     141              :         &self,
     142              :         extra: &ConsoleReqExtra<'_>,
     143              :         creds: &ClientCredentials,
     144            0 :     ) -> Result<Option<AuthInfo>, GetAuthInfoError> {
     145            0 :         self.do_get_auth_info(extra, creds).await
     146            0 :     }
     147              : 
     148            0 :     #[tracing::instrument(skip_all)]
     149              :     async fn wake_compute(
     150              :         &self,
     151              :         extra: &ConsoleReqExtra<'_>,
     152              :         creds: &ClientCredentials,
     153            0 :     ) -> Result<CachedNodeInfo, WakeComputeError> {
     154            0 :         let key = creds.project().expect("impossible");
     155              : 
     156              :         // Every time we do a wakeup http request, the compute node will stay up
     157              :         // for some time (highly depends on the console's scale-to-zero policy);
     158              :         // The connection info remains the same during that period of time,
     159              :         // which means that we might cache it to reduce the load and latency.
     160            0 :         if let Some(cached) = self.caches.node_info.get(key) {
     161            0 :             info!(key = key, "found cached compute node info");
     162            0 :             return Ok(cached);
     163            0 :         }
     164              : 
     165            0 :         let node = self.do_wake_compute(extra, creds).await?;
     166            0 :         let (_, cached) = self.caches.node_info.insert(key.into(), node);
     167            0 :         info!(key = key, "created a cache entry for compute node info");
     168              : 
     169            0 :         Ok(cached)
     170            0 :     }
     171              : }
     172              : 
     173              : /// Parse http response body, taking status code into account.
     174            0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
     175            0 :     response: http::Response,
     176            0 : ) -> Result<T, ApiError> {
     177            0 :     let status = response.status();
     178            0 :     if status.is_success() {
     179              :         // We shouldn't log raw body because it may contain secrets.
     180            0 :         info!("request succeeded, processing the body");
     181            0 :         return Ok(response.json().await?);
     182            0 :     }
     183              : 
     184              :     // Don't throw an error here because it's not as important
     185              :     // as the fact that the request itself has failed.
     186            0 :     let body = response.json().await.unwrap_or_else(|e| {
     187            0 :         warn!("failed to parse error body: {e}");
     188            0 :         ConsoleError {
     189            0 :             error: "reason unclear (malformed error message)".into(),
     190            0 :         }
     191            0 :     });
     192            0 : 
     193            0 :     let text = body.error;
     194            0 :     error!("console responded with an error ({status}): {text}");
     195            0 :     Err(ApiError::Console { status, text })
     196            0 : }
     197              : 
     198            1 : fn parse_host_port(input: &str) -> Option<(String, u16)> {
     199            1 :     let parsed: SocketAddr = input.parse().ok()?;
     200            1 :     Some((parsed.ip().to_string(), parsed.port()))
     201            1 : }
     202              : 
     203              : #[cfg(test)]
     204              : mod tests {
     205              :     use super::*;
     206              : 
     207            1 :     #[test]
     208            1 :     fn test_parse_host_port() {
     209            1 :         let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
     210            1 :         assert_eq!(host, "127.0.0.1");
     211            1 :         assert_eq!(port, 5432);
     212            1 :     }
     213              : }
        

Generated by: LCOV version 2.1-beta