LCOV - code coverage report
Current view: top level - proxy/src/serverless - http_util.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 154 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 20 0

            Line data    Source code
       1              : //! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
       2              : //! Will merge back in at some point in the future.
       3              : 
       4              : use anyhow::Context;
       5              : use bytes::Bytes;
       6              : use http::header::AUTHORIZATION;
       7              : use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
       8              : use http_body_util::combinators::BoxBody;
       9              : use http_body_util::{BodyExt, Full};
      10              : use http_utils::error::ApiError;
      11              : use serde::Serialize;
      12              : use url::Url;
      13              : use uuid::Uuid;
      14              : 
      15              : use super::conn_pool::{AuthData, ConnInfoWithAuth};
      16              : use super::conn_pool_lib::ConnInfo;
      17              : use super::error::{ConnInfoError, Credentials};
      18              : use crate::auth::backend::ComputeUserInfo;
      19              : use crate::config::AuthenticationConfig;
      20              : use crate::context::RequestContext;
      21              : use crate::metrics::{Metrics, SniGroup, SniKind};
      22              : use crate::pqproto::StartupMessageParams;
      23              : use crate::proxy::NeonOptions;
      24              : use crate::types::{DbName, EndpointId, RoleName};
      25              : 
      26              : // Common header names used across serverless modules
      27              : pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
      28              : pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
      29              : pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
      30              : pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
      31              : pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
      32              : pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
      33              :     HeaderName::from_static("neon-batch-isolation-level");
      34              : pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
      35              : pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
      36              : 
      37            0 : pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
      38            0 :     let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
      39            0 :     HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
      40            0 :         .expect("uuid hyphenated format should be all valid header characters")
      41            0 : }
      42              : 
      43              : /// Like [`ApiError::into_response`]
      44            0 : pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
      45            0 :     match this {
      46            0 :         ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
      47            0 :             format!("{err:#?}"), // use debug printing so that we give the cause
      48              :             StatusCode::BAD_REQUEST,
      49              :         ),
      50              :         ApiError::Forbidden(_) => {
      51            0 :             HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
      52              :         }
      53              :         ApiError::Unauthorized(_) => {
      54            0 :             HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
      55              :         }
      56              :         ApiError::NotFound(_) => {
      57            0 :             HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
      58              :         }
      59              :         ApiError::Conflict(_) => {
      60            0 :             HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
      61              :         }
      62            0 :         ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
      63            0 :             this.to_string(),
      64              :             StatusCode::PRECONDITION_FAILED,
      65              :         ),
      66            0 :         ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
      67            0 :             "Shutting down".to_string(),
      68              :             StatusCode::SERVICE_UNAVAILABLE,
      69              :         ),
      70            0 :         ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
      71            0 :             err.to_string(),
      72              :             StatusCode::SERVICE_UNAVAILABLE,
      73              :         ),
      74            0 :         ApiError::TooManyRequests(err) => HttpErrorBody::response_from_msg_and_status(
      75            0 :             err.to_string(),
      76              :             StatusCode::TOO_MANY_REQUESTS,
      77              :         ),
      78            0 :         ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
      79            0 :             err.to_string(),
      80              :             StatusCode::REQUEST_TIMEOUT,
      81              :         ),
      82            0 :         ApiError::Cancelled => HttpErrorBody::response_from_msg_and_status(
      83            0 :             this.to_string(),
      84              :             StatusCode::INTERNAL_SERVER_ERROR,
      85              :         ),
      86            0 :         ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
      87            0 :             err.to_string(),
      88              :             StatusCode::INTERNAL_SERVER_ERROR,
      89              :         ),
      90              :     }
      91            0 : }
      92              : 
      93              : /// Same as [`http_utils::error::HttpErrorBody`]
      94              : #[derive(Serialize)]
      95              : struct HttpErrorBody {
      96              :     pub(crate) msg: String,
      97              : }
      98              : 
      99              : impl HttpErrorBody {
     100              :     /// Same as [`http_utils::error::HttpErrorBody::response_from_msg_and_status`]
     101            0 :     fn response_from_msg_and_status(
     102            0 :         msg: String,
     103            0 :         status: StatusCode,
     104            0 :     ) -> Response<BoxBody<Bytes, hyper::Error>> {
     105            0 :         HttpErrorBody { msg }.to_response(status)
     106            0 :     }
     107              : 
     108              :     /// Same as [`http_utils::error::HttpErrorBody::to_response`]
     109            0 :     fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
     110            0 :         Response::builder()
     111            0 :             .status(status)
     112            0 :             .header(http::header::CONTENT_TYPE, "application/json")
     113              :             // we do not have nested maps with non string keys so serialization shouldn't fail
     114            0 :             .body(
     115            0 :                 Full::new(Bytes::from(
     116            0 :                     serde_json::to_string(self)
     117            0 :                         .expect("serialising HttpErrorBody should never fail"),
     118              :                 ))
     119            0 :                 .map_err(|x| match x {})
     120            0 :                 .boxed(),
     121              :             )
     122            0 :             .expect("content-type header should be valid")
     123            0 :     }
     124              : }
     125              : 
     126              : /// Same as [`http_utils::json::json_response`]
     127            0 : pub(crate) fn json_response<T: Serialize>(
     128            0 :     status: StatusCode,
     129            0 :     data: T,
     130            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     131            0 :     let json = serde_json::to_string(&data)
     132            0 :         .context("Failed to serialize JSON response")
     133            0 :         .map_err(ApiError::InternalServerError)?;
     134            0 :     let response = Response::builder()
     135            0 :         .status(status)
     136            0 :         .header(http::header::CONTENT_TYPE, "application/json")
     137            0 :         .body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
     138            0 :         .map_err(|e| ApiError::InternalServerError(e.into()))?;
     139            0 :     Ok(response)
     140            0 : }
     141              : 
     142            0 : pub(crate) fn get_conn_info(
     143            0 :     config: &'static AuthenticationConfig,
     144            0 :     ctx: &RequestContext,
     145            0 :     connection_string: Option<&str>,
     146            0 :     headers: &HeaderMap,
     147            0 : ) -> Result<ConnInfoWithAuth, ConnInfoError> {
     148            0 :     let connection_url = match connection_string {
     149            0 :         Some(connection_string) => Url::parse(connection_string)?,
     150              :         None => {
     151            0 :             let connection_string = headers
     152            0 :                 .get(&CONN_STRING)
     153            0 :                 .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
     154            0 :                 .to_str()
     155            0 :                 .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
     156            0 :             Url::parse(connection_string)?
     157              :         }
     158              :     };
     159              : 
     160            0 :     let protocol = connection_url.scheme();
     161            0 :     if protocol != "postgres" && protocol != "postgresql" {
     162            0 :         return Err(ConnInfoError::IncorrectScheme);
     163            0 :     }
     164              : 
     165            0 :     let mut url_path = connection_url
     166            0 :         .path_segments()
     167            0 :         .ok_or(ConnInfoError::MissingDbName)?;
     168              : 
     169            0 :     let dbname: DbName =
     170            0 :         urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
     171            0 :     ctx.set_dbname(dbname.clone());
     172              : 
     173            0 :     let username = RoleName::from(urlencoding::decode(connection_url.username())?);
     174            0 :     if username.is_empty() {
     175            0 :         return Err(ConnInfoError::MissingUsername);
     176            0 :     }
     177            0 :     ctx.set_user(username.clone());
     178              :     // TODO: make sure this is right in the context of rest broker
     179            0 :     let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
     180            0 :         if !config.accept_jwts {
     181            0 :             return Err(ConnInfoError::MissingCredentials(Credentials::Password));
     182            0 :         }
     183              : 
     184            0 :         let auth = auth
     185            0 :             .to_str()
     186            0 :             .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
     187              :         AuthData::Jwt(
     188            0 :             auth.strip_prefix("Bearer ")
     189            0 :                 .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
     190            0 :                 .into(),
     191              :         )
     192            0 :     } else if let Some(pass) = connection_url.password() {
     193              :         // wrong credentials provided
     194            0 :         if config.accept_jwts {
     195            0 :             return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
     196            0 :         }
     197              : 
     198            0 :         AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
     199            0 :             std::borrow::Cow::Borrowed(b) => b.into(),
     200            0 :             std::borrow::Cow::Owned(b) => b.into(),
     201              :         })
     202            0 :     } else if config.accept_jwts {
     203            0 :         return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
     204              :     } else {
     205            0 :         return Err(ConnInfoError::MissingCredentials(Credentials::Password));
     206              :     };
     207            0 :     let endpoint: EndpointId = match connection_url.host() {
     208            0 :         Some(url::Host::Domain(hostname)) => hostname
     209            0 :             .split_once('.')
     210            0 :             .map_or(hostname, |(prefix, _)| prefix)
     211            0 :             .into(),
     212              :         Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
     213            0 :             return Err(ConnInfoError::MissingHostname);
     214              :         }
     215              :     };
     216            0 :     ctx.set_endpoint_id(endpoint.clone());
     217              : 
     218            0 :     let pairs = connection_url.query_pairs();
     219              : 
     220            0 :     let mut options = Option::None;
     221              : 
     222            0 :     let mut params = StartupMessageParams::default();
     223            0 :     params.insert("user", &username);
     224            0 :     params.insert("database", &dbname);
     225            0 :     for (key, value) in pairs {
     226            0 :         params.insert(&key, &value);
     227            0 :         if key == "options" {
     228            0 :             options = Some(NeonOptions::parse_options_raw(&value));
     229            0 :         }
     230              :     }
     231              : 
     232              :     // check the URL that was used, for metrics
     233              :     {
     234            0 :         let host_endpoint = headers
     235              :             // get the host header
     236            0 :             .get("host")
     237              :             // extract the domain
     238            0 :             .and_then(|h| {
     239            0 :                 let (host, _port) = h.to_str().ok()?.split_once(':')?;
     240            0 :                 Some(host)
     241            0 :             })
     242              :             // get the endpoint prefix
     243            0 :             .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
     244              : 
     245            0 :         let kind = if host_endpoint == Some(&*endpoint) {
     246            0 :             SniKind::Sni
     247              :         } else {
     248            0 :             SniKind::NoSni
     249              :         };
     250              : 
     251            0 :         let protocol = ctx.protocol();
     252            0 :         Metrics::get()
     253            0 :             .proxy
     254            0 :             .accepted_connections_by_sni
     255            0 :             .inc(SniGroup { protocol, kind });
     256              :     }
     257              : 
     258            0 :     ctx.set_user_agent(
     259            0 :         headers
     260            0 :             .get(hyper::header::USER_AGENT)
     261            0 :             .and_then(|h| h.to_str().ok())
     262            0 :             .map(Into::into),
     263              :     );
     264              : 
     265            0 :     let user_info = ComputeUserInfo {
     266            0 :         endpoint,
     267            0 :         user: username,
     268            0 :         options: options.unwrap_or_default(),
     269            0 :     };
     270              : 
     271            0 :     let conn_info = ConnInfo { user_info, dbname };
     272            0 :     Ok(ConnInfoWithAuth { conn_info, auth })
     273            0 : }
        

Generated by: LCOV version 2.1-beta