LCOV - code coverage report
Current view: top level - proxy/src/serverless - rest.rs (source / functions) Coverage Total Hit
Test: a14d6a1f0ccf210374e9eaed9918e97cd6f5d5ba.info Lines: 0.0 % 700 0
Test Date: 2025-08-04 14:37:31 Functions: 0.0 % 90 0

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::HashMap;
       3              : use std::convert::Infallible;
       4              : use std::sync::Arc;
       5              : 
       6              : use bytes::Bytes;
       7              : use http::Method;
       8              : use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
       9              : use http_body_util::combinators::BoxBody;
      10              : use http_body_util::{BodyExt, Full};
      11              : use http_utils::error::ApiError;
      12              : use hyper::body::Incoming;
      13              : use hyper::http::{HeaderName, HeaderValue};
      14              : use hyper::{Request, Response, StatusCode};
      15              : use indexmap::IndexMap;
      16              : use moka::sync::Cache;
      17              : use ouroboros::self_referencing;
      18              : use serde::de::DeserializeOwned;
      19              : use serde::{Deserialize, Deserializer};
      20              : use serde_json::Value as JsonValue;
      21              : use serde_json::value::RawValue;
      22              : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
      23              : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
      24              : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
      25              : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
      26              : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
      27              : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
      28              : use subzero_core::error::Error::{
      29              :     self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
      30              :     JsonDeserialize, JwtTokenInvalid, NotFound,
      31              : };
      32              : use subzero_core::error::pg_error_to_status_code;
      33              : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
      34              : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
      35              : use subzero_core::formatter::{Param, Snippet, SqlParam};
      36              : use subzero_core::parser::postgrest::parse;
      37              : use subzero_core::permissions::{check_safe_functions, replace_select_star};
      38              : use subzero_core::schema::{
      39              :     DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
      40              : };
      41              : use subzero_core::{content_range_header, content_range_status};
      42              : use tokio_util::sync::CancellationToken;
      43              : use tracing::{error, info};
      44              : use typed_json::json;
      45              : use url::form_urlencoded;
      46              : 
      47              : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
      48              : use super::conn_pool::AuthData;
      49              : use super::conn_pool_lib::ConnInfo;
      50              : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
      51              : use super::http_conn_pool::{self, LocalProxyClient};
      52              : use super::http_util::{
      53              :     ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
      54              :     get_conn_info, json_response, uuid_to_header_value,
      55              : };
      56              : use super::json::JsonConversionError;
      57              : use crate::auth::backend::ComputeCredentialKeys;
      58              : use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener};
      59              : use crate::config::ProxyConfig;
      60              : use crate::context::RequestContext;
      61              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      62              : use crate::http::read_body_with_limit;
      63              : use crate::metrics::{CacheKind, Metrics};
      64              : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
      65              : use crate::types::EndpointCacheKey;
      66              : use crate::util::deserialize_json_string;
      67              : 
      68              : static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
      69              : const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
      70              : 
      71              : // A wrapper around the DbSchema that allows for self-referencing
      72              : #[self_referencing]
      73              : pub struct DbSchemaOwned {
      74              :     schema_string: String,
      75              :     #[covariant]
      76              :     #[borrows(schema_string)]
      77              :     schema: DbSchema<'this>,
      78              : }
      79              : 
      80              : impl<'de> Deserialize<'de> for DbSchemaOwned {
      81            0 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      82            0 :     where
      83            0 :         D: Deserializer<'de>,
      84              :     {
      85            0 :         let s = String::deserialize(deserializer)?;
      86            0 :         DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
      87            0 :             .map_err(<D::Error as serde::de::Error>::custom)
      88            0 :     }
      89              : }
      90              : 
      91            0 : fn split_comma_separated(s: &str) -> Vec<String> {
      92            0 :     s.split(',').map(|s| s.trim().to_string()).collect()
      93            0 : }
      94              : 
      95            0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
      96            0 : where
      97            0 :     D: Deserializer<'de>,
      98              : {
      99            0 :     let s = String::deserialize(deserializer)?;
     100            0 :     Ok(split_comma_separated(&s))
     101            0 : }
     102              : 
     103            0 : fn deserialize_comma_separated_option<'de, D>(
     104            0 :     deserializer: D,
     105            0 : ) -> Result<Option<Vec<String>>, D::Error>
     106            0 : where
     107            0 :     D: Deserializer<'de>,
     108              : {
     109            0 :     let opt = Option::<String>::deserialize(deserializer)?;
     110            0 :     if let Some(s) = &opt {
     111            0 :         let trimmed = s.trim();
     112            0 :         if trimmed.is_empty() {
     113            0 :             return Ok(None);
     114            0 :         }
     115            0 :         return Ok(Some(split_comma_separated(trimmed)));
     116            0 :     }
     117            0 :     Ok(None)
     118            0 : }
     119              : 
     120              : // The ApiConfig is the configuration for the API per endpoint
     121              : // The configuration is read from the database and cached in the DbSchemaCache
     122              : #[derive(Deserialize, Debug)]
     123              : pub struct ApiConfig {
     124              :     #[serde(
     125              :         default = "db_schemas",
     126              :         deserialize_with = "deserialize_comma_separated"
     127              :     )]
     128              :     pub db_schemas: Vec<String>,
     129              :     pub db_anon_role: Option<String>,
     130              :     pub db_max_rows: Option<String>,
     131              :     #[serde(default = "db_allowed_select_functions")]
     132              :     pub db_allowed_select_functions: Vec<String>,
     133              :     // #[serde(deserialize_with = "to_tuple", default)]
     134              :     // pub db_pre_request: Option<(String, String)>,
     135              :     #[allow(dead_code)]
     136              :     #[serde(default = "role_claim_key")]
     137              :     pub role_claim_key: String,
     138              :     #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
     139              :     pub db_extra_search_path: Option<Vec<String>>,
     140              : }
     141              : 
     142              : // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
     143              : pub(crate) struct DbSchemaCache(Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
     144              : impl DbSchemaCache {
     145            0 :     pub fn new(config: crate::config::CacheOptions) -> Self {
     146            0 :         let builder = Cache::builder().name("schema");
     147            0 :         let builder = config.moka(builder);
     148              : 
     149            0 :         let metrics = &Metrics::get().cache;
     150            0 :         if let Some(size) = config.size {
     151            0 :             metrics.capacity.set(CacheKind::Schema, size as i64);
     152            0 :         }
     153              : 
     154            0 :         let builder =
     155            0 :             builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause));
     156              : 
     157            0 :         Self(builder.build())
     158            0 :     }
     159              : 
     160            0 :     pub async fn maintain(&self) -> Result<Infallible, anyhow::Error> {
     161            0 :         let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
     162              :         loop {
     163            0 :             ticker.tick().await;
     164            0 :             self.0.run_pending_tasks();
     165              :         }
     166              :     }
     167              : 
     168            0 :     pub async fn get_cached_or_remote(
     169            0 :         &self,
     170            0 :         endpoint_id: &EndpointCacheKey,
     171            0 :         auth_header: &HeaderValue,
     172            0 :         connection_string: &str,
     173            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     174            0 :         ctx: &RequestContext,
     175            0 :         config: &'static ProxyConfig,
     176            0 :     ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
     177            0 :         let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
     178            0 :         match cache_result {
     179            0 :             Some(v) => Ok(v),
     180              :             None => {
     181            0 :                 info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
     182            0 :                 let remote_value = self
     183            0 :                     .get_remote(auth_header, connection_string, client, ctx, config)
     184            0 :                     .await;
     185            0 :                 let (api_config, schema_owned) = match remote_value {
     186            0 :                     Ok((api_config, schema_owned)) => (api_config, schema_owned),
     187            0 :                     Err(e @ RestError::SchemaTooLarge) => {
     188              :                         // for the case where the schema is too large, we cache an empty dummy value
     189              :                         // all the other requests will fail without triggering the introspection query
     190            0 :                         let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
     191            0 :                             .map_err(|e| JsonDeserialize { source: e })?;
     192              : 
     193            0 :                         let api_config = ApiConfig {
     194            0 :                             db_schemas: vec![],
     195            0 :                             db_anon_role: None,
     196            0 :                             db_max_rows: None,
     197            0 :                             db_allowed_select_functions: vec![],
     198            0 :                             role_claim_key: String::new(),
     199            0 :                             db_extra_search_path: None,
     200            0 :                         };
     201            0 :                         let value = Arc::new((api_config, schema_owned));
     202            0 :                         count_cache_insert(CacheKind::Schema);
     203            0 :                         self.0.insert(endpoint_id.clone(), value);
     204            0 :                         return Err(e);
     205              :                     }
     206            0 :                     Err(e) => {
     207            0 :                         return Err(e);
     208              :                     }
     209              :                 };
     210            0 :                 let value = Arc::new((api_config, schema_owned));
     211            0 :                 count_cache_insert(CacheKind::Schema);
     212            0 :                 self.0.insert(endpoint_id.clone(), value.clone());
     213            0 :                 Ok(value)
     214              :             }
     215              :         }
     216            0 :     }
     217            0 :     pub async fn get_remote(
     218            0 :         &self,
     219            0 :         auth_header: &HeaderValue,
     220            0 :         connection_string: &str,
     221            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     222            0 :         ctx: &RequestContext,
     223            0 :         config: &'static ProxyConfig,
     224            0 :     ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
     225            0 :         #[derive(Deserialize)]
     226              :         struct SingleRow<Row> {
     227              :             rows: [Row; 1],
     228              :         }
     229              : 
     230              :         #[derive(Deserialize)]
     231              :         struct ConfigRow {
     232              :             #[serde(deserialize_with = "deserialize_json_string")]
     233              :             config: ApiConfig,
     234              :         }
     235              : 
     236            0 :         #[derive(Deserialize)]
     237              :         struct SchemaRow {
     238              :             json_schema: DbSchemaOwned,
     239              :         }
     240              : 
     241            0 :         let headers = vec![
     242            0 :             (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     243            0 :             (
     244            0 :                 &CONN_STRING,
     245            0 :                 HeaderValue::from_str(connection_string).expect(
     246            0 :                     "connection string came from a header, so it must be a valid headervalue",
     247            0 :                 ),
     248            0 :             ),
     249            0 :             (&AUTHORIZATION, auth_header.clone()),
     250            0 :             (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
     251              :         ];
     252              : 
     253            0 :         let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
     254              :         let SingleRow {
     255            0 :             rows: [ConfigRow { config: api_config }],
     256            0 :         } = make_local_proxy_request(
     257            0 :             client,
     258            0 :             headers.iter().cloned(),
     259            0 :             QueryData {
     260            0 :                 query: Cow::Owned(query),
     261            0 :                 params: vec![],
     262            0 :             },
     263            0 :             config.rest_config.max_schema_size,
     264            0 :         )
     265            0 :         .await
     266            0 :         .map_err(|e| match e {
     267              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     268            0 :                 RestError::SchemaTooLarge
     269              :             }
     270            0 :             e => e,
     271            0 :         })?;
     272              : 
     273              :         // now that we have the api_config let's run the second INTROSPECTION_SQL query
     274              :         let SingleRow {
     275            0 :             rows: [SchemaRow { json_schema }],
     276            0 :         } = make_local_proxy_request(
     277            0 :             client,
     278            0 :             headers,
     279            0 :             QueryData {
     280            0 :                 query: INTROSPECTION_SQL.into(),
     281            0 :                 params: vec![
     282            0 :                     serde_json::to_value(&api_config.db_schemas)
     283            0 :                         .expect("Vec<String> is always valid to encode as JSON"),
     284            0 :                     JsonValue::Bool(false), // include_roles_with_login
     285            0 :                     JsonValue::Bool(false), // use_internal_permissions
     286            0 :                 ],
     287            0 :             },
     288            0 :             config.rest_config.max_schema_size,
     289            0 :         )
     290            0 :         .await
     291            0 :         .map_err(|e| match e {
     292              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     293            0 :                 RestError::SchemaTooLarge
     294              :             }
     295            0 :             e => e,
     296            0 :         })?;
     297              : 
     298            0 :         Ok((api_config, json_schema))
     299            0 :     }
     300              : }
     301              : 
     302              : // A type to represent a postgresql errors
     303              : // we use our own type (instead of postgres_client::Error) because we get the error from the json response
     304            0 : #[derive(Debug, thiserror::Error, Deserialize)]
     305              : pub(crate) struct PostgresError {
     306              :     pub code: String,
     307              :     pub message: String,
     308              :     pub detail: Option<String>,
     309              :     pub hint: Option<String>,
     310              : }
     311              : impl HttpCodeError for PostgresError {
     312            0 :     fn get_http_status_code(&self) -> StatusCode {
     313            0 :         let status = pg_error_to_status_code(&self.code, true);
     314            0 :         StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     315            0 :     }
     316              : }
     317              : impl ReportableError for PostgresError {
     318            0 :     fn get_error_kind(&self) -> ErrorKind {
     319            0 :         ErrorKind::User
     320            0 :     }
     321              : }
     322              : impl UserFacingError for PostgresError {
     323            0 :     fn to_string_client(&self) -> String {
     324            0 :         if self.code.starts_with("PT") {
     325            0 :             "Postgres error".to_string()
     326              :         } else {
     327            0 :             self.message.clone()
     328              :         }
     329            0 :     }
     330              : }
     331              : impl std::fmt::Display for PostgresError {
     332            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     333            0 :         write!(f, "{}", self.message)
     334            0 :     }
     335              : }
     336              : 
     337              : // A type to represent errors that can occur in the rest broker
     338              : #[derive(Debug, thiserror::Error)]
     339              : pub(crate) enum RestError {
     340              :     #[error(transparent)]
     341              :     ReadPayload(#[from] ReadPayloadError),
     342              :     #[error(transparent)]
     343              :     ConnectCompute(#[from] HttpConnError),
     344              :     #[error(transparent)]
     345              :     ConnInfo(#[from] ConnInfoError),
     346              :     #[error(transparent)]
     347              :     Postgres(#[from] PostgresError),
     348              :     #[error(transparent)]
     349              :     JsonConversion(#[from] JsonConversionError),
     350              :     #[error(transparent)]
     351              :     SubzeroCore(#[from] SubzeroCoreError),
     352              :     #[error("schema is too large")]
     353              :     SchemaTooLarge,
     354              : }
     355              : impl ReportableError for RestError {
     356            0 :     fn get_error_kind(&self) -> ErrorKind {
     357            0 :         match self {
     358            0 :             RestError::ReadPayload(e) => e.get_error_kind(),
     359            0 :             RestError::ConnectCompute(e) => e.get_error_kind(),
     360            0 :             RestError::ConnInfo(e) => e.get_error_kind(),
     361            0 :             RestError::Postgres(_) => ErrorKind::Postgres,
     362            0 :             RestError::JsonConversion(_) => ErrorKind::Postgres,
     363            0 :             RestError::SubzeroCore(_) => ErrorKind::User,
     364            0 :             RestError::SchemaTooLarge => ErrorKind::User,
     365              :         }
     366            0 :     }
     367              : }
     368              : impl UserFacingError for RestError {
     369            0 :     fn to_string_client(&self) -> String {
     370            0 :         match self {
     371            0 :             RestError::ReadPayload(p) => p.to_string(),
     372            0 :             RestError::ConnectCompute(c) => c.to_string_client(),
     373            0 :             RestError::ConnInfo(c) => c.to_string_client(),
     374            0 :             RestError::SchemaTooLarge => self.to_string(),
     375            0 :             RestError::Postgres(p) => p.to_string_client(),
     376            0 :             RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
     377            0 :             RestError::SubzeroCore(s) => {
     378              :                 // TODO: this is a hack to get the message from the json body
     379            0 :                 let json = s.json_body();
     380            0 :                 let default_message = "Unknown error".to_string();
     381              : 
     382            0 :                 json.get("message")
     383            0 :                     .map_or(default_message.clone(), |m| match m {
     384            0 :                         JsonValue::String(s) => s.clone(),
     385            0 :                         _ => default_message,
     386            0 :                     })
     387              :             }
     388              :         }
     389            0 :     }
     390              : }
     391              : impl HttpCodeError for RestError {
     392            0 :     fn get_http_status_code(&self) -> StatusCode {
     393            0 :         match self {
     394            0 :             RestError::ReadPayload(e) => e.get_http_status_code(),
     395            0 :             RestError::ConnectCompute(h) => match h.get_error_kind() {
     396            0 :                 ErrorKind::User => StatusCode::BAD_REQUEST,
     397            0 :                 _ => StatusCode::INTERNAL_SERVER_ERROR,
     398              :             },
     399            0 :             RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
     400            0 :             RestError::Postgres(e) => e.get_http_status_code(),
     401            0 :             RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
     402            0 :             RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
     403            0 :             RestError::SubzeroCore(e) => {
     404            0 :                 let status = e.status_code();
     405            0 :                 StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     406              :             }
     407              :         }
     408            0 :     }
     409              : }
     410              : 
     411              : // Helper functions for the rest broker
     412              : 
     413            0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
     414              :     "select "
     415            0 :         + if env.is_empty() {
     416            0 :             sql("null")
     417              :         } else {
     418            0 :             env.iter()
     419            0 :                 .map(|(k, v)| {
     420            0 :                     "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
     421            0 :                 })
     422            0 :                 .join(",")
     423              :         }
     424            0 : }
     425              : 
     426              : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
     427            0 : fn to_sql_param(p: &Param) -> JsonValue {
     428            0 :     match p {
     429            0 :         SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
     430            0 :         Str(v) => JsonValue::String((*v).to_string()),
     431            0 :         StrOwned(v) => JsonValue::String((*v).clone()),
     432            0 :         PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
     433            0 :         LV(ListVal(v, ..)) => {
     434            0 :             if v.is_empty() {
     435            0 :                 JsonValue::String(r"{}".to_string())
     436              :             } else {
     437            0 :                 JsonValue::String(format!(
     438            0 :                     "{{\"{}\"}}",
     439            0 :                     v.iter()
     440            0 :                         .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
     441            0 :                         .collect::<Vec<_>>()
     442            0 :                         .join("\",\"")
     443              :                 ))
     444              :             }
     445              :         }
     446              :     }
     447            0 : }
     448              : 
     449              : #[derive(serde::Serialize)]
     450              : struct QueryData<'a> {
     451              :     query: Cow<'a, str>,
     452              :     params: Vec<JsonValue>,
     453              : }
     454              : 
     455              : #[derive(serde::Serialize)]
     456              : struct BatchQueryData<'a> {
     457              :     queries: Vec<QueryData<'a>>,
     458              : }
     459              : 
     460            0 : async fn make_local_proxy_request<S: DeserializeOwned>(
     461            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     462            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     463            0 :     body: QueryData<'_>,
     464            0 :     max_len: usize,
     465            0 : ) -> Result<S, RestError> {
     466            0 :     let body_string = serde_json::to_string(&body)
     467            0 :         .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     468              : 
     469            0 :     let response = make_raw_local_proxy_request(client, headers, body_string).await?;
     470              : 
     471            0 :     let response_status = response.status();
     472              : 
     473            0 :     if response_status != StatusCode::OK {
     474            0 :         return Err(RestError::SubzeroCore(InternalError {
     475            0 :             message: "Failed to get endpoint schema".to_string(),
     476            0 :         }));
     477            0 :     }
     478              : 
     479              :     // Capture the response body
     480            0 :     let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
     481            0 :         .await
     482            0 :         .map_err(ReadPayloadError::from)?;
     483              : 
     484              :     // Parse the JSON response
     485            0 :     let response_json: S = serde_json::from_slice(&response_body)
     486            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     487              : 
     488            0 :     Ok(response_json)
     489            0 : }
     490              : 
     491            0 : async fn make_raw_local_proxy_request(
     492            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     493            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     494            0 :     body: String,
     495            0 : ) -> Result<Response<Incoming>, RestError> {
     496            0 :     let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
     497            0 :     let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
     498            0 :     let req_headers = req.headers_mut().expect("failed to get headers");
     499              :     // Add all provided headers to the request
     500            0 :     for (header_name, header_value) in headers {
     501            0 :         req_headers.insert(header_name, header_value.clone());
     502            0 :     }
     503              : 
     504            0 :     let body_boxed = Full::new(Bytes::from(body))
     505            0 :         .map_err(|never| match never {}) // Convert Infallible to hyper::Error
     506            0 :         .boxed();
     507              : 
     508            0 :     let req = req.body(body_boxed).map_err(|_| {
     509            0 :         RestError::SubzeroCore(InternalError {
     510            0 :             message: "Failed to build request".to_string(),
     511            0 :         })
     512            0 :     })?;
     513              : 
     514              :     // Send the request to the local proxy
     515            0 :     client
     516            0 :         .inner
     517            0 :         .inner
     518            0 :         .send_request(req)
     519            0 :         .await
     520            0 :         .map_err(LocalProxyConnError::from)
     521            0 :         .map_err(HttpConnError::from)
     522            0 :         .map_err(RestError::from)
     523            0 : }
     524              : 
     525            0 : pub(crate) async fn handle(
     526            0 :     config: &'static ProxyConfig,
     527            0 :     ctx: RequestContext,
     528            0 :     request: Request<Incoming>,
     529            0 :     backend: Arc<PoolingBackend>,
     530            0 :     cancel: CancellationToken,
     531            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     532            0 :     let result = handle_inner(cancel, config, &ctx, request, backend).await;
     533              : 
     534            0 :     let mut response = match result {
     535            0 :         Ok(r) => {
     536            0 :             ctx.set_success();
     537              : 
     538              :             // Handling the error response from local proxy here
     539            0 :             if r.status().is_server_error() {
     540            0 :                 let status = r.status();
     541              : 
     542            0 :                 let body_bytes = r
     543            0 :                     .collect()
     544            0 :                     .await
     545            0 :                     .map_err(|e| {
     546            0 :                         ApiError::InternalServerError(anyhow::Error::msg(format!(
     547            0 :                             "could not collect http body: {e}"
     548            0 :                         )))
     549            0 :                     })?
     550            0 :                     .to_bytes();
     551              : 
     552            0 :                 if let Ok(mut json_map) =
     553            0 :                     serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
     554              :                 {
     555            0 :                     let message = json_map.get("message");
     556            0 :                     if let Some(message) = message {
     557            0 :                         let msg: String = match serde_json::from_str(message.get()) {
     558            0 :                             Ok(msg) => msg,
     559              :                             Err(_) => {
     560            0 :                                 "Unable to parse the response message from server".to_string()
     561              :                             }
     562              :                         };
     563              : 
     564            0 :                         error!("Error response from local_proxy: {status} {msg}");
     565              : 
     566            0 :                         json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
     567              : 
     568            0 :                         let resp_json = serde_json::to_string(&json_map)
     569            0 :                             .unwrap_or("failed to serialize the response message".to_string());
     570              : 
     571            0 :                         return json_response(status, resp_json);
     572            0 :                     }
     573            0 :                 }
     574              : 
     575            0 :                 error!("Unable to parse the response message from local_proxy");
     576            0 :                 return json_response(
     577            0 :                     status,
     578            0 :                     json!({ "message": "Unable to parse the response message from server".to_string() }),
     579              :                 );
     580            0 :             }
     581            0 :             r
     582              :         }
     583            0 :         Err(e @ RestError::SubzeroCore(_)) => {
     584            0 :             let error_kind = e.get_error_kind();
     585            0 :             ctx.set_error_kind(error_kind);
     586              : 
     587            0 :             tracing::info!(
     588            0 :                 kind=error_kind.to_metric_label(),
     589              :                 error=%e,
     590              :                 msg="subzero core error",
     591            0 :                 "forwarding error to user"
     592              :             );
     593              : 
     594            0 :             let RestError::SubzeroCore(subzero_err) = e else {
     595            0 :                 panic!("expected subzero core error")
     596              :             };
     597              : 
     598            0 :             let json_body = subzero_err.json_body();
     599            0 :             let status_code = StatusCode::from_u16(subzero_err.status_code())
     600            0 :                 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
     601              : 
     602            0 :             json_response(status_code, json_body)?
     603              :         }
     604            0 :         Err(e) => {
     605            0 :             let error_kind = e.get_error_kind();
     606            0 :             ctx.set_error_kind(error_kind);
     607              : 
     608            0 :             let message = e.to_string_client();
     609            0 :             let status_code = e.get_http_status_code();
     610              : 
     611            0 :             tracing::info!(
     612            0 :                 kind=error_kind.to_metric_label(),
     613              :                 error=%e,
     614              :                 msg=message,
     615            0 :                 "forwarding error to user"
     616              :             );
     617              : 
     618            0 :             let (code, detail, hint) = match e {
     619            0 :                 RestError::Postgres(e) => (
     620            0 :                     if e.code.starts_with("PT") {
     621            0 :                         None
     622              :                     } else {
     623            0 :                         Some(e.code)
     624              :                     },
     625            0 :                     e.detail,
     626            0 :                     e.hint,
     627              :                 ),
     628            0 :                 _ => (None, None, None),
     629              :             };
     630              : 
     631            0 :             json_response(
     632            0 :                 status_code,
     633            0 :                 json!({
     634            0 :                     "message": message,
     635            0 :                     "code": code,
     636            0 :                     "detail": detail,
     637            0 :                     "hint": hint,
     638              :                 }),
     639            0 :             )?
     640              :         }
     641              :     };
     642              : 
     643            0 :     response
     644            0 :         .headers_mut()
     645            0 :         .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
     646            0 :     Ok(response)
     647            0 : }
     648              : 
     649            0 : async fn handle_inner(
     650            0 :     _cancel: CancellationToken,
     651            0 :     config: &'static ProxyConfig,
     652            0 :     ctx: &RequestContext,
     653            0 :     request: Request<Incoming>,
     654            0 :     backend: Arc<PoolingBackend>,
     655            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     656            0 :     let _requeset_gauge = Metrics::get()
     657            0 :         .proxy
     658            0 :         .connection_requests
     659            0 :         .guard(ctx.protocol());
     660            0 :     info!(
     661            0 :         protocol = %ctx.protocol(),
     662            0 :         "handling interactive connection from client"
     663              :     );
     664              : 
     665              :     // Read host from Host, then URI host as fallback
     666              :     // TODO: will this be a problem if behind a load balancer?
     667              :     // TODO: can we use the x-forwarded-host header?
     668            0 :     let host = request
     669            0 :         .headers()
     670            0 :         .get(HOST)
     671            0 :         .and_then(|v| v.to_str().ok())
     672            0 :         .unwrap_or_else(|| request.uri().host().unwrap_or(""));
     673              : 
     674              :     // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
     675            0 :     let database_name = request
     676            0 :         .uri()
     677            0 :         .path()
     678            0 :         .split('/')
     679            0 :         .nth(1)
     680            0 :         .ok_or(RestError::SubzeroCore(NotFound {
     681            0 :             target: request.uri().path().to_string(),
     682            0 :         }))?;
     683              : 
     684              :     // we always use the authenticator role to connect to the database
     685            0 :     let authenticator_role = "authenticator";
     686              : 
     687              :     // Strip the hostname prefix from the host to get the database hostname
     688            0 :     let database_host = host.replace(&config.rest_config.hostname_prefix, "");
     689              : 
     690            0 :     let connection_string =
     691            0 :         format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
     692              : 
     693            0 :     let conn_info = get_conn_info(
     694            0 :         &config.authentication_config,
     695            0 :         ctx,
     696            0 :         Some(&connection_string),
     697            0 :         request.headers(),
     698            0 :     )?;
     699            0 :     info!(
     700            0 :         user = conn_info.conn_info.user_info.user.as_str(),
     701            0 :         "credentials"
     702              :     );
     703              : 
     704            0 :     match conn_info.auth {
     705            0 :         AuthData::Jwt(jwt) => {
     706            0 :             let api_prefix = format!("/{database_name}/rest/v1/");
     707            0 :             handle_rest_inner(
     708            0 :                 config,
     709            0 :                 ctx,
     710            0 :                 &api_prefix,
     711            0 :                 request,
     712            0 :                 &connection_string,
     713            0 :                 conn_info.conn_info,
     714            0 :                 jwt,
     715            0 :                 backend,
     716            0 :             )
     717            0 :             .await
     718              :         }
     719            0 :         AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
     720            0 :             Credentials::BearerJwt,
     721            0 :         ))),
     722              :     }
     723            0 : }
     724              : 
     725              : #[allow(clippy::too_many_arguments)]
     726            0 : async fn handle_rest_inner(
     727            0 :     config: &'static ProxyConfig,
     728            0 :     ctx: &RequestContext,
     729            0 :     api_prefix: &str,
     730            0 :     request: Request<Incoming>,
     731            0 :     connection_string: &str,
     732            0 :     conn_info: ConnInfo,
     733            0 :     jwt: String,
     734            0 :     backend: Arc<PoolingBackend>,
     735            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     736              :     // validate the jwt token
     737            0 :     let jwt_parsed = backend
     738            0 :         .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
     739            0 :         .await
     740            0 :         .map_err(HttpConnError::from)?;
     741              : 
     742            0 :     let db_schema_cache =
     743            0 :         config
     744            0 :             .rest_config
     745            0 :             .db_schema_cache
     746            0 :             .as_ref()
     747            0 :             .ok_or(RestError::SubzeroCore(InternalError {
     748            0 :                 message: "DB schema cache is not configured".to_string(),
     749            0 :             }))?;
     750              : 
     751            0 :     let endpoint_cache_key = conn_info
     752            0 :         .endpoint_cache_key()
     753            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     754            0 :             message: "Failed to get endpoint cache key".to_string(),
     755            0 :         }))?;
     756              : 
     757            0 :     let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
     758              : 
     759            0 :     let (parts, originial_body) = request.into_parts();
     760              : 
     761            0 :     let auth_header = parts
     762            0 :         .headers
     763            0 :         .get(AUTHORIZATION)
     764            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     765            0 :             message: "Authorization header is required".to_string(),
     766            0 :         }))?;
     767              : 
     768            0 :     let entry = db_schema_cache
     769            0 :         .get_cached_or_remote(
     770            0 :             &endpoint_cache_key,
     771            0 :             auth_header,
     772            0 :             connection_string,
     773            0 :             &mut client,
     774            0 :             ctx,
     775            0 :             config,
     776            0 :         )
     777            0 :         .await?;
     778            0 :     let (api_config, db_schema_owned) = entry.as_ref();
     779            0 :     let db_schema = db_schema_owned.borrow_schema();
     780              : 
     781            0 :     let db_schemas = &api_config.db_schemas; // list of schemas available for the api
     782            0 :     let db_extra_search_path = &api_config.db_extra_search_path;
     783              :     // TODO: use this when we get a replacement for jsonpath_lib
     784              :     // let role_claim_key = &api_config.role_claim_key;
     785              :     // let role_claim_path = format!("${role_claim_key}");
     786            0 :     let db_anon_role = &api_config.db_anon_role;
     787            0 :     let max_rows = api_config.db_max_rows.as_deref();
     788            0 :     let db_allowed_select_functions = api_config
     789            0 :         .db_allowed_select_functions
     790            0 :         .iter()
     791            0 :         .map(|s| s.as_str())
     792            0 :         .collect::<Vec<_>>();
     793              : 
     794              :     // extract the jwt claims (we'll need them later to set the role and env)
     795            0 :     let jwt_claims = match jwt_parsed.keys {
     796            0 :         ComputeCredentialKeys::JwtPayload(payload_bytes) => {
     797              :             // `payload_bytes` contains the raw JWT payload as Vec<u8>
     798              :             // You can deserialize it back to JSON or parse specific claims
     799            0 :             let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
     800            0 :                 .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     801            0 :             Some(payload)
     802              :         }
     803            0 :         ComputeCredentialKeys::AuthKeys(_) => None,
     804              :     };
     805              : 
     806              :     // read the role from the jwt claims (and set it to the "anon" role if not present)
     807            0 :     let (role, authenticated) = match &jwt_claims {
     808            0 :         Some(claims) => match claims.get("role") {
     809            0 :             Some(JsonValue::String(r)) => (Some(r), true),
     810            0 :             _ => (db_anon_role.as_ref(), true),
     811              :         },
     812            0 :         None => (db_anon_role.as_ref(), false),
     813              :     };
     814              : 
     815              :     // do not allow unauthenticated requests when there is no anonymous role setup
     816            0 :     if let (None, false) = (role, authenticated) {
     817            0 :         return Err(RestError::SubzeroCore(JwtTokenInvalid {
     818            0 :             message: "unauthenticated requests not allowed".to_string(),
     819            0 :         }));
     820            0 :     }
     821              : 
     822              :     // start deconstructing the request because subzero core mostly works with &str
     823            0 :     let method = parts.method;
     824            0 :     let method_str = method.as_str();
     825            0 :     let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
     826              : 
     827              :     // this is actually the table name (or rpc/function_name)
     828              :     // TODO: rename this to something more descriptive
     829            0 :     let root = match parts.uri.path().strip_prefix(api_prefix) {
     830            0 :         Some(p) => Ok(p),
     831            0 :         None => Err(RestError::SubzeroCore(NotFound {
     832            0 :             target: parts.uri.path().to_string(),
     833            0 :         })),
     834            0 :     }?;
     835              : 
     836              :     // pick the current schema from the headers (or the first one from config)
     837            0 :     let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
     838              : 
     839              :     // add the content-profile header to the response
     840            0 :     let mut response_headers = vec![];
     841            0 :     if db_schemas.len() > 1 {
     842            0 :         response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
     843            0 :     }
     844              : 
     845              :     // parse the query string into a Vec<(&str, &str)>
     846            0 :     let query = match parts.uri.query() {
     847            0 :         Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
     848            0 :         None => vec![],
     849              :     };
     850            0 :     let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
     851              : 
     852              :     // convert the headers map to a HashMap<&str, &str>
     853            0 :     let headers: HashMap<&str, &str> = parts
     854            0 :         .headers
     855            0 :         .iter()
     856            0 :         .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
     857            0 :         .collect();
     858              : 
     859            0 :     let cookies = HashMap::new(); // TODO: add cookies
     860              : 
     861              :     // Read the request body (skip for GET requests)
     862            0 :     let body_as_string: Option<String> = if method == Method::GET {
     863            0 :         None
     864              :     } else {
     865            0 :         let body_bytes =
     866            0 :             read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
     867            0 :                 .await
     868            0 :                 .map_err(ReadPayloadError::from)?;
     869            0 :         if body_bytes.is_empty() {
     870            0 :             None
     871              :         } else {
     872            0 :             Some(String::from_utf8_lossy(&body_bytes).into_owned())
     873              :         }
     874              :     };
     875              : 
     876              :     // parse the request into an ApiRequest struct
     877            0 :     let mut api_request = parse(
     878            0 :         schema_name,
     879            0 :         root,
     880            0 :         db_schema,
     881            0 :         method_str,
     882            0 :         path,
     883            0 :         get,
     884            0 :         body_as_string.as_deref(),
     885            0 :         headers,
     886            0 :         cookies,
     887            0 :         max_rows,
     888              :     )
     889            0 :     .map_err(RestError::SubzeroCore)?;
     890              : 
     891            0 :     let role_str = match role {
     892            0 :         Some(r) => r,
     893            0 :         None => "",
     894              :     };
     895              : 
     896            0 :     replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
     897              : 
     898              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     899              :     // in the context of DBX where they need internal permissions
     900              :     // if !disable_internal_permissions {
     901              :     //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
     902              :     // }
     903              : 
     904            0 :     check_safe_functions(&api_request, &db_allowed_select_functions)?;
     905              : 
     906              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     907              :     // in the context of DBX where they need internal permissions
     908              :     // if !disable_internal_permissions {
     909              :     //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
     910              :     // }
     911              : 
     912            0 :     let env_role = Some(role_str);
     913              : 
     914              :     // construct the env (passed in to the sql context as GUCs)
     915            0 :     let empty_json = "{}".to_string();
     916            0 :     let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
     917            0 :     let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
     918            0 :     let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
     919            0 :     let jwt_claims_env = jwt_claims
     920            0 :         .as_ref()
     921            0 :         .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
     922            0 :         .unwrap_or(if let Some(r) = env_role {
     923            0 :             let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
     924            0 :             serde_json::to_string(&claims).unwrap_or(empty_json.clone())
     925              :         } else {
     926            0 :             empty_json.clone()
     927              :         });
     928            0 :     let mut search_path = vec![api_request.schema_name];
     929            0 :     if let Some(extra) = &db_extra_search_path {
     930            0 :         search_path.extend(extra.iter().map(|s| s.as_str()));
     931            0 :     }
     932            0 :     let search_path_str = search_path
     933            0 :         .into_iter()
     934            0 :         .filter(|s| !s.is_empty())
     935            0 :         .collect::<Vec<_>>()
     936            0 :         .join(",");
     937            0 :     let mut env: HashMap<&str, &str> = HashMap::from([
     938            0 :         ("request.method", api_request.method),
     939            0 :         ("request.path", api_request.path),
     940            0 :         ("request.headers", &headers_env),
     941            0 :         ("request.cookies", &cookies_env),
     942            0 :         ("request.get", &get_env),
     943            0 :         ("request.jwt.claims", &jwt_claims_env),
     944            0 :         ("search_path", &search_path_str),
     945            0 :     ]);
     946            0 :     if let Some(r) = env_role {
     947            0 :         env.insert("role", r);
     948            0 :     }
     949              : 
     950              :     // generate the sql statements
     951            0 :     let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
     952            0 :     let (main_statement, main_parameters, _) = generate(fmt_main_query(
     953            0 :         db_schema,
     954            0 :         api_request.schema_name,
     955            0 :         &api_request,
     956            0 :         &env,
     957            0 :     )?);
     958              : 
     959            0 :     let mut headers = vec![
     960            0 :         (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     961            0 :         (
     962            0 :             &CONN_STRING,
     963            0 :             HeaderValue::from_str(connection_string).expect("invalid connection string"),
     964            0 :         ),
     965            0 :         (&AUTHORIZATION, auth_header.clone()),
     966            0 :         (
     967            0 :             &TXN_ISOLATION_LEVEL,
     968            0 :             HeaderValue::from_static("ReadCommitted"),
     969            0 :         ),
     970            0 :         (&ALLOW_POOL, HEADER_VALUE_TRUE),
     971              :     ];
     972              : 
     973            0 :     if api_request.read_only {
     974            0 :         headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
     975            0 :     }
     976              : 
     977              :     // convert the parameters from subzero core representation to the local proxy repr.
     978            0 :     let req_body = serde_json::to_string(&BatchQueryData {
     979            0 :         queries: vec![
     980              :             QueryData {
     981            0 :                 query: env_statement.into(),
     982            0 :                 params: env_parameters
     983            0 :                     .iter()
     984            0 :                     .map(|p| to_sql_param(&p.to_param()))
     985            0 :                     .collect(),
     986              :             },
     987              :             QueryData {
     988            0 :                 query: main_statement.into(),
     989            0 :                 params: main_parameters
     990            0 :                     .iter()
     991            0 :                     .map(|p| to_sql_param(&p.to_param()))
     992            0 :                     .collect(),
     993              :             },
     994              :         ],
     995              :     })
     996            0 :     .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     997              : 
     998              :     // todo: map body to count egress
     999            0 :     let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
    1000              : 
    1001              :     // send the request to the local proxy
    1002            0 :     let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
    1003            0 :     let (parts, body) = response.into_parts();
    1004              : 
    1005            0 :     let max_response = config.http_config.max_response_size_bytes;
    1006            0 :     let bytes = read_body_with_limit(body, max_response)
    1007            0 :         .await
    1008            0 :         .map_err(ReadPayloadError::from)?;
    1009              : 
    1010              :     // if the response status is greater than 399, then it is an error
    1011              :     // FIXME: check if there are other error codes or shapes of the response
    1012            0 :     if parts.status.as_u16() > 399 {
    1013              :         // turn this postgres error from the json into PostgresError
    1014            0 :         let postgres_error = serde_json::from_slice(&bytes)
    1015            0 :             .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1016              : 
    1017            0 :         return Err(RestError::Postgres(postgres_error));
    1018            0 :     }
    1019              : 
    1020            0 :     #[derive(Deserialize)]
    1021              :     struct QueryResults {
    1022              :         /// we run two queries, so we want only two results.
    1023              :         results: (EnvRows, MainRows),
    1024              :     }
    1025              : 
    1026              :     /// `env_statement` returns nothing of interest to us
    1027            0 :     #[derive(Deserialize)]
    1028              :     struct EnvRows {}
    1029              : 
    1030            0 :     #[derive(Deserialize)]
    1031              :     struct MainRows {
    1032              :         /// `main_statement` only returns a single row.
    1033              :         rows: [MainRow; 1],
    1034              :     }
    1035              : 
    1036            0 :     #[derive(Deserialize)]
    1037              :     struct MainRow {
    1038              :         body: String,
    1039              :         page_total: Option<String>,
    1040              :         total_result_set: Option<String>,
    1041              :         response_headers: Option<String>,
    1042              :         response_status: Option<String>,
    1043              :     }
    1044              : 
    1045            0 :     let results: QueryResults = serde_json::from_slice(&bytes)
    1046            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1047              : 
    1048              :     let QueryResults {
    1049            0 :         results: (_, MainRows { rows: [row] }),
    1050            0 :     } = results;
    1051              : 
    1052              :     // build the intermediate response object
    1053            0 :     let api_response = ApiResponse {
    1054            0 :         page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
    1055            0 :         total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
    1056              :         top_level_offset: 0, // FIXME: check why this is 0
    1057            0 :         response_headers: row.response_headers,
    1058            0 :         response_status: row.response_status,
    1059            0 :         body: row.body,
    1060              :     };
    1061              : 
    1062              :     // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
    1063              :     // we can not do this in the context of proxy for now
    1064              :     // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
    1065              :     //     // rollback the transaction here
    1066              :     //     return Err(RestError::SubzeroCore(SingularityError {
    1067              :     //         count: api_response.page_total,
    1068              :     //         content_type: "application/vnd.pgrst.object+json".to_string(),
    1069              :     //     }));
    1070              :     // }
    1071              : 
    1072              :     // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
    1073              :     // we can not do this in the context of proxy for now
    1074              :     // if api_request.method == Method::PUT && api_response.page_total != 1 {
    1075              :     //     // Makes sure the querystring pk matches the payload pk
    1076              :     //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
    1077              :     //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
    1078              :     //     // If this condition is not satisfied then nothing is inserted,
    1079              :     //     // rollback the transaction here
    1080              :     //     return Err(RestError::SubzeroCore(PutMatchingPkError));
    1081              :     // }
    1082              : 
    1083              :     // create and return the response to the client
    1084              :     // this section mostly deals with setting the right headers according to PostgREST specs
    1085            0 :     let page_total = api_response.page_total;
    1086            0 :     let total_result_set = api_response.total_result_set;
    1087            0 :     let top_level_offset = api_response.top_level_offset;
    1088            0 :     let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
    1089              :         (SingularJSON, _)
    1090              :         | (
    1091              :             _,
    1092              :             FunctionCall {
    1093              :                 returns_single: true,
    1094              :                 is_scalar: false,
    1095              :                 ..
    1096              :             },
    1097            0 :         ) => SingularJSON,
    1098            0 :         (TextCSV, _) => TextCSV,
    1099            0 :         _ => ApplicationJSON,
    1100              :     };
    1101              : 
    1102              :     // check if the SQL env set some response headers (happens when we called a rpc function)
    1103            0 :     if let Some(response_headers_str) = api_response.response_headers {
    1104            0 :         let Ok(headers_json) =
    1105            0 :             serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
    1106              :         else {
    1107            0 :             return Err(RestError::SubzeroCore(GucHeadersError));
    1108              :         };
    1109              : 
    1110            0 :         response_headers.extend(headers_json.into_iter().flatten());
    1111            0 :     }
    1112              : 
    1113              :     // calculate and set the content range header
    1114            0 :     let lower = top_level_offset as i64;
    1115            0 :     let upper = top_level_offset as i64 + page_total as i64 - 1;
    1116            0 :     let total = total_result_set.map(|t| t as i64);
    1117            0 :     let content_range = match (&method, &api_request.query.node) {
    1118            0 :         (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
    1119            0 :         (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
    1120            0 :         _ => content_range_header(lower, upper, total),
    1121              :     };
    1122            0 :     response_headers.push(("Content-Range".to_string(), content_range));
    1123              : 
    1124              :     // calculate the status code
    1125              :     #[rustfmt::skip]
    1126            0 :     let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
    1127            0 :         (&Method::POST,   Insert { .. }, ..) => 201,
    1128            0 :         (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1129            0 :         (&Method::DELETE, Delete { .. }, ..) => 204,
    1130            0 :         (&Method::PATCH,  Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
    1131            0 :         (&Method::PATCH,  Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1132            0 :         (&Method::PATCH,  Update { .. }, ..) => 204,
    1133            0 :         (&Method::PUT,    Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1134            0 :         (&Method::PUT,    Insert { .. }, ..) => 204,
    1135            0 :         _ => content_range_status(lower, upper, total),
    1136              :     };
    1137              : 
    1138              :     // add the preference-applied header
    1139              :     if let Some(Preferences {
    1140            0 :         resolution: Some(r),
    1141              :         ..
    1142            0 :     }) = api_request.preferences
    1143              :     {
    1144            0 :         response_headers.push((
    1145            0 :             "Preference-Applied".to_string(),
    1146            0 :             match r {
    1147            0 :                 MergeDuplicates => "resolution=merge-duplicates".to_string(),
    1148            0 :                 IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
    1149              :             },
    1150              :         ));
    1151            0 :     }
    1152              : 
    1153              :     // check if the SQL env set some response status (happens when we called a rpc function)
    1154            0 :     if let Some(response_status_str) = api_response.response_status {
    1155            0 :         status = response_status_str
    1156            0 :             .parse::<u16>()
    1157            0 :             .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
    1158            0 :     }
    1159              : 
    1160              :     // set the content type header
    1161              :     // TODO: move this to a subzero function
    1162              :     // as_header_value(&self) -> Option<&str>
    1163            0 :     let http_content_type = match response_content_type {
    1164            0 :         SingularJSON => Ok("application/vnd.pgrst.object+json"),
    1165            0 :         TextCSV => Ok("text/csv"),
    1166            0 :         ApplicationJSON => Ok("application/json"),
    1167            0 :         Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
    1168            0 :             message: format!("None of these Content-Types are available: {t}"),
    1169            0 :         })),
    1170            0 :     }?;
    1171              : 
    1172              :     // build the response body
    1173            0 :     let response_body = Full::new(Bytes::from(api_response.body))
    1174            0 :         .map_err(|never| match never {})
    1175            0 :         .boxed();
    1176              : 
    1177              :     // build the response
    1178            0 :     let mut response = Response::builder()
    1179            0 :         .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
    1180            0 :         .header(CONTENT_TYPE, http_content_type);
    1181              : 
    1182              :     // Add all headers from response_headers vector
    1183            0 :     for (header_name, header_value) in response_headers {
    1184            0 :         response = response.header(header_name, header_value);
    1185            0 :     }
    1186              : 
    1187              :     // add the body and return the response
    1188            0 :     response.body(response_body).map_err(|_| {
    1189            0 :         RestError::SubzeroCore(InternalError {
    1190            0 :             message: "Failed to build response".to_string(),
    1191            0 :         })
    1192            0 :     })
    1193            0 : }
        

Generated by: LCOV version 2.1-beta