LCOV - code coverage report
Current view: top level - proxy/src/serverless - rest.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 0.0 % 760 0
Test Date: 2025-07-31 15:59:03 Functions: 0.0 % 100 0

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::HashMap;
       3              : use std::convert::Infallible;
       4              : use std::sync::Arc;
       5              : 
       6              : use bytes::Bytes;
       7              : use http::Method;
       8              : use http::header::{
       9              :     ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
      10              :     ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
      11              :     AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
      12              : };
      13              : use http_body_util::combinators::BoxBody;
      14              : use http_body_util::{BodyExt, Empty, Full};
      15              : use http_utils::error::ApiError;
      16              : use hyper::body::Incoming;
      17              : use hyper::http::response::Builder;
      18              : use hyper::http::{HeaderMap, HeaderName, HeaderValue};
      19              : use hyper::{Request, Response, StatusCode};
      20              : use indexmap::IndexMap;
      21              : use moka::sync::Cache;
      22              : use ouroboros::self_referencing;
      23              : use serde::de::DeserializeOwned;
      24              : use serde::{Deserialize, Deserializer};
      25              : use serde_json::Value as JsonValue;
      26              : use serde_json::value::RawValue;
      27              : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
      28              : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
      29              : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
      30              : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
      31              : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
      32              : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
      33              : use subzero_core::error::Error::{
      34              :     self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
      35              :     JsonDeserialize, JwtTokenInvalid, NotFound,
      36              : };
      37              : use subzero_core::error::pg_error_to_status_code;
      38              : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
      39              : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
      40              : use subzero_core::formatter::{Param, Snippet, SqlParam};
      41              : use subzero_core::parser::postgrest::parse;
      42              : use subzero_core::permissions::{check_safe_functions, replace_select_star};
      43              : use subzero_core::schema::{
      44              :     DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
      45              : };
      46              : use subzero_core::{content_range_header, content_range_status};
      47              : use tokio_util::sync::CancellationToken;
      48              : use tracing::{error, info};
      49              : use typed_json::json;
      50              : use url::form_urlencoded;
      51              : 
      52              : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
      53              : use super::conn_pool::AuthData;
      54              : use super::conn_pool_lib::ConnInfo;
      55              : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
      56              : use super::http_conn_pool::{self, LocalProxyClient};
      57              : use super::http_util::{
      58              :     ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
      59              :     get_conn_info, json_response, uuid_to_header_value,
      60              : };
      61              : use super::json::JsonConversionError;
      62              : use crate::auth::backend::ComputeCredentialKeys;
      63              : use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener};
      64              : use crate::config::ProxyConfig;
      65              : use crate::context::RequestContext;
      66              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      67              : use crate::http::read_body_with_limit;
      68              : use crate::metrics::{CacheKind, Metrics};
      69              : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
      70              : use crate::types::EndpointCacheKey;
      71              : use crate::util::deserialize_json_string;
      72              : 
      73              : static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
      74              : const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
      75              : const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
      76              : // CORS headers values
      77              : const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
      78              :     HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
      79              : const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
      80              : const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
      81              :     "Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
      82              : );
      83              : const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
      84              : 
      85              : // A wrapper around the DbSchema that allows for self-referencing
      86              : #[self_referencing]
      87              : pub struct DbSchemaOwned {
      88              :     schema_string: String,
      89              :     #[covariant]
      90              :     #[borrows(schema_string)]
      91              :     schema: DbSchema<'this>,
      92              : }
      93              : 
      94              : impl<'de> Deserialize<'de> for DbSchemaOwned {
      95            0 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      96            0 :     where
      97            0 :         D: Deserializer<'de>,
      98              :     {
      99            0 :         let s = String::deserialize(deserializer)?;
     100            0 :         DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
     101            0 :             .map_err(<D::Error as serde::de::Error>::custom)
     102            0 :     }
     103              : }
     104              : 
     105            0 : fn split_comma_separated(s: &str) -> Vec<String> {
     106            0 :     s.split(',').map(|s| s.trim().to_string()).collect()
     107            0 : }
     108              : 
     109            0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
     110            0 : where
     111            0 :     D: Deserializer<'de>,
     112              : {
     113            0 :     let s = String::deserialize(deserializer)?;
     114            0 :     Ok(split_comma_separated(&s))
     115            0 : }
     116              : 
     117            0 : fn deserialize_comma_separated_option<'de, D>(
     118            0 :     deserializer: D,
     119            0 : ) -> Result<Option<Vec<String>>, D::Error>
     120            0 : where
     121            0 :     D: Deserializer<'de>,
     122              : {
     123            0 :     let opt = Option::<String>::deserialize(deserializer)?;
     124            0 :     if let Some(s) = &opt {
     125            0 :         let trimmed = s.trim();
     126            0 :         if trimmed.is_empty() {
     127            0 :             return Ok(None);
     128            0 :         }
     129            0 :         return Ok(Some(split_comma_separated(trimmed)));
     130            0 :     }
     131            0 :     Ok(None)
     132            0 : }
     133              : 
     134              : // The ApiConfig is the configuration for the API per endpoint
     135              : // The configuration is read from the database and cached in the DbSchemaCache
     136              : #[derive(Deserialize, Debug)]
     137              : pub struct ApiConfig {
     138              :     #[serde(
     139              :         default = "db_schemas",
     140              :         deserialize_with = "deserialize_comma_separated"
     141              :     )]
     142              :     pub db_schemas: Vec<String>,
     143              :     pub db_anon_role: Option<String>,
     144              :     pub db_max_rows: Option<String>,
     145              :     #[serde(default = "db_allowed_select_functions")]
     146              :     pub db_allowed_select_functions: Vec<String>,
     147              :     // #[serde(deserialize_with = "to_tuple", default)]
     148              :     // pub db_pre_request: Option<(String, String)>,
     149              :     #[allow(dead_code)]
     150              :     #[serde(default = "role_claim_key")]
     151              :     pub role_claim_key: String,
     152              :     #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
     153              :     pub db_extra_search_path: Option<Vec<String>>,
     154              :     #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
     155              :     pub server_cors_allowed_origins: Option<Vec<String>>,
     156              : }
     157              : 
     158              : // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
     159              : pub(crate) struct DbSchemaCache(Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
     160              : impl DbSchemaCache {
     161            0 :     pub fn new(config: crate::config::CacheOptions) -> Self {
     162            0 :         let builder = Cache::builder().name("schema");
     163            0 :         let builder = config.moka(builder);
     164              : 
     165            0 :         let metrics = &Metrics::get().cache;
     166            0 :         if let Some(size) = config.size {
     167            0 :             metrics.capacity.set(CacheKind::Schema, size as i64);
     168            0 :         }
     169              : 
     170            0 :         let builder =
     171            0 :             builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause));
     172              : 
     173            0 :         Self(builder.build())
     174            0 :     }
     175              : 
     176            0 :     pub async fn maintain(&self) -> Result<Infallible, anyhow::Error> {
     177            0 :         let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
     178              :         loop {
     179            0 :             ticker.tick().await;
     180            0 :             self.0.run_pending_tasks();
     181              :         }
     182              :     }
     183              : 
     184            0 :     pub fn get_cached(
     185            0 :         &self,
     186            0 :         endpoint_id: &EndpointCacheKey,
     187            0 :     ) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
     188            0 :         count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
     189            0 :     }
     190            0 :     pub async fn get_remote(
     191            0 :         &self,
     192            0 :         endpoint_id: &EndpointCacheKey,
     193            0 :         auth_header: &HeaderValue,
     194            0 :         connection_string: &str,
     195            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     196            0 :         ctx: &RequestContext,
     197            0 :         config: &'static ProxyConfig,
     198            0 :     ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
     199            0 :         info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
     200            0 :         let remote_value = self
     201            0 :             .internal_get_remote(auth_header, connection_string, client, ctx, config)
     202            0 :             .await;
     203            0 :         let (api_config, schema_owned) = match remote_value {
     204            0 :             Ok((api_config, schema_owned)) => (api_config, schema_owned),
     205            0 :             Err(e @ RestError::SchemaTooLarge) => {
     206              :                 // for the case where the schema is too large, we cache an empty dummy value
     207              :                 // all the other requests will fail without triggering the introspection query
     208            0 :                 let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
     209            0 :                     .map_err(|e| JsonDeserialize { source: e })?;
     210              : 
     211            0 :                 let api_config = ApiConfig {
     212            0 :                     db_schemas: vec![],
     213            0 :                     db_anon_role: None,
     214            0 :                     db_max_rows: None,
     215            0 :                     db_allowed_select_functions: vec![],
     216            0 :                     role_claim_key: String::new(),
     217            0 :                     db_extra_search_path: None,
     218            0 :                     server_cors_allowed_origins: None,
     219            0 :                 };
     220            0 :                 let value = Arc::new((api_config, schema_owned));
     221            0 :                 count_cache_insert(CacheKind::Schema);
     222            0 :                 self.0.insert(endpoint_id.clone(), value);
     223            0 :                 return Err(e);
     224              :             }
     225            0 :             Err(e) => {
     226            0 :                 return Err(e);
     227              :             }
     228              :         };
     229            0 :         let value = Arc::new((api_config, schema_owned));
     230            0 :         count_cache_insert(CacheKind::Schema);
     231            0 :         self.0.insert(endpoint_id.clone(), value.clone());
     232            0 :         Ok(value)
     233            0 :     }
     234            0 :     async fn internal_get_remote(
     235            0 :         &self,
     236            0 :         auth_header: &HeaderValue,
     237            0 :         connection_string: &str,
     238            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     239            0 :         ctx: &RequestContext,
     240            0 :         config: &'static ProxyConfig,
     241            0 :     ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
     242            0 :         #[derive(Deserialize)]
     243              :         struct SingleRow<Row> {
     244              :             rows: [Row; 1],
     245              :         }
     246              : 
     247              :         #[derive(Deserialize)]
     248              :         struct ConfigRow {
     249              :             #[serde(deserialize_with = "deserialize_json_string")]
     250              :             config: ApiConfig,
     251              :         }
     252              : 
     253            0 :         #[derive(Deserialize)]
     254              :         struct SchemaRow {
     255              :             json_schema: DbSchemaOwned,
     256              :         }
     257              : 
     258            0 :         let headers = vec![
     259            0 :             (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     260            0 :             (
     261            0 :                 &CONN_STRING,
     262            0 :                 HeaderValue::from_str(connection_string).expect(
     263            0 :                     "connection string came from a header, so it must be a valid headervalue",
     264            0 :                 ),
     265            0 :             ),
     266            0 :             (&AUTHORIZATION, auth_header.clone()),
     267            0 :             (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
     268              :         ];
     269              : 
     270            0 :         let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
     271              :         let SingleRow {
     272            0 :             rows: [ConfigRow { config: api_config }],
     273            0 :         } = make_local_proxy_request(
     274            0 :             client,
     275            0 :             headers.iter().cloned(),
     276            0 :             QueryData {
     277            0 :                 query: Cow::Owned(query),
     278            0 :                 params: vec![],
     279            0 :             },
     280            0 :             config.rest_config.max_schema_size,
     281            0 :         )
     282            0 :         .await
     283            0 :         .map_err(|e| match e {
     284              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     285            0 :                 RestError::SchemaTooLarge
     286              :             }
     287            0 :             e => e,
     288            0 :         })?;
     289              : 
     290              :         // now that we have the api_config let's run the second INTROSPECTION_SQL query
     291              :         let SingleRow {
     292            0 :             rows: [SchemaRow { json_schema }],
     293            0 :         } = make_local_proxy_request(
     294            0 :             client,
     295            0 :             headers,
     296            0 :             QueryData {
     297            0 :                 query: INTROSPECTION_SQL.into(),
     298            0 :                 params: vec![
     299            0 :                     serde_json::to_value(&api_config.db_schemas)
     300            0 :                         .expect("Vec<String> is always valid to encode as JSON"),
     301            0 :                     JsonValue::Bool(false), // include_roles_with_login
     302            0 :                     JsonValue::Bool(false), // use_internal_permissions
     303            0 :                 ],
     304            0 :             },
     305            0 :             config.rest_config.max_schema_size,
     306            0 :         )
     307            0 :         .await
     308            0 :         .map_err(|e| match e {
     309              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     310            0 :                 RestError::SchemaTooLarge
     311              :             }
     312            0 :             e => e,
     313            0 :         })?;
     314              : 
     315            0 :         Ok((api_config, json_schema))
     316            0 :     }
     317              : }
     318              : 
     319              : // A type to represent a postgresql errors
     320              : // we use our own type (instead of postgres_client::Error) because we get the error from the json response
     321            0 : #[derive(Debug, thiserror::Error, Deserialize)]
     322              : pub(crate) struct PostgresError {
     323              :     pub code: String,
     324              :     pub message: String,
     325              :     pub detail: Option<String>,
     326              :     pub hint: Option<String>,
     327              : }
     328              : impl HttpCodeError for PostgresError {
     329            0 :     fn get_http_status_code(&self) -> StatusCode {
     330            0 :         let status = pg_error_to_status_code(&self.code, true);
     331            0 :         StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     332            0 :     }
     333              : }
     334              : impl ReportableError for PostgresError {
     335            0 :     fn get_error_kind(&self) -> ErrorKind {
     336            0 :         ErrorKind::User
     337            0 :     }
     338              : }
     339              : impl UserFacingError for PostgresError {
     340            0 :     fn to_string_client(&self) -> String {
     341            0 :         if self.code.starts_with("PT") {
     342            0 :             "Postgres error".to_string()
     343              :         } else {
     344            0 :             self.message.clone()
     345              :         }
     346            0 :     }
     347              : }
     348              : impl std::fmt::Display for PostgresError {
     349            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     350            0 :         write!(f, "{}", self.message)
     351            0 :     }
     352              : }
     353              : 
     354              : // A type to represent errors that can occur in the rest broker
     355              : #[derive(Debug, thiserror::Error)]
     356              : pub(crate) enum RestError {
     357              :     #[error(transparent)]
     358              :     ReadPayload(#[from] ReadPayloadError),
     359              :     #[error(transparent)]
     360              :     ConnectCompute(#[from] HttpConnError),
     361              :     #[error(transparent)]
     362              :     ConnInfo(#[from] ConnInfoError),
     363              :     #[error(transparent)]
     364              :     Postgres(#[from] PostgresError),
     365              :     #[error(transparent)]
     366              :     JsonConversion(#[from] JsonConversionError),
     367              :     #[error(transparent)]
     368              :     SubzeroCore(#[from] SubzeroCoreError),
     369              :     #[error("schema is too large")]
     370              :     SchemaTooLarge,
     371              : }
     372              : impl ReportableError for RestError {
     373            0 :     fn get_error_kind(&self) -> ErrorKind {
     374            0 :         match self {
     375            0 :             RestError::ReadPayload(e) => e.get_error_kind(),
     376            0 :             RestError::ConnectCompute(e) => e.get_error_kind(),
     377            0 :             RestError::ConnInfo(e) => e.get_error_kind(),
     378            0 :             RestError::Postgres(_) => ErrorKind::Postgres,
     379            0 :             RestError::JsonConversion(_) => ErrorKind::Postgres,
     380            0 :             RestError::SubzeroCore(_) => ErrorKind::User,
     381            0 :             RestError::SchemaTooLarge => ErrorKind::User,
     382              :         }
     383            0 :     }
     384              : }
     385              : impl UserFacingError for RestError {
     386            0 :     fn to_string_client(&self) -> String {
     387            0 :         match self {
     388            0 :             RestError::ReadPayload(p) => p.to_string(),
     389            0 :             RestError::ConnectCompute(c) => c.to_string_client(),
     390            0 :             RestError::ConnInfo(c) => c.to_string_client(),
     391            0 :             RestError::SchemaTooLarge => self.to_string(),
     392            0 :             RestError::Postgres(p) => p.to_string_client(),
     393            0 :             RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
     394            0 :             RestError::SubzeroCore(s) => {
     395              :                 // TODO: this is a hack to get the message from the json body
     396            0 :                 let json = s.json_body();
     397            0 :                 let default_message = "Unknown error".to_string();
     398              : 
     399            0 :                 json.get("message")
     400            0 :                     .map_or(default_message.clone(), |m| match m {
     401            0 :                         JsonValue::String(s) => s.clone(),
     402            0 :                         _ => default_message,
     403            0 :                     })
     404              :             }
     405              :         }
     406            0 :     }
     407              : }
     408              : impl HttpCodeError for RestError {
     409            0 :     fn get_http_status_code(&self) -> StatusCode {
     410            0 :         match self {
     411            0 :             RestError::ReadPayload(e) => e.get_http_status_code(),
     412            0 :             RestError::ConnectCompute(h) => match h.get_error_kind() {
     413            0 :                 ErrorKind::User => StatusCode::BAD_REQUEST,
     414            0 :                 _ => StatusCode::INTERNAL_SERVER_ERROR,
     415              :             },
     416            0 :             RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
     417            0 :             RestError::Postgres(e) => e.get_http_status_code(),
     418            0 :             RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
     419            0 :             RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
     420            0 :             RestError::SubzeroCore(e) => {
     421            0 :                 let status = e.status_code();
     422            0 :                 StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     423              :             }
     424              :         }
     425            0 :     }
     426              : }
     427              : 
     428              : // Helper functions for the rest broker
     429              : 
     430            0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
     431              :     "select "
     432            0 :         + if env.is_empty() {
     433            0 :             sql("null")
     434              :         } else {
     435            0 :             env.iter()
     436            0 :                 .map(|(k, v)| {
     437            0 :                     "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
     438            0 :                 })
     439            0 :                 .join(",")
     440              :         }
     441            0 : }
     442              : 
     443              : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
     444            0 : fn to_sql_param(p: &Param) -> JsonValue {
     445            0 :     match p {
     446            0 :         SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
     447            0 :         Str(v) => JsonValue::String((*v).to_string()),
     448            0 :         StrOwned(v) => JsonValue::String((*v).clone()),
     449            0 :         PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
     450            0 :         LV(ListVal(v, ..)) => {
     451            0 :             if v.is_empty() {
     452            0 :                 JsonValue::String(r"{}".to_string())
     453              :             } else {
     454            0 :                 JsonValue::String(format!(
     455            0 :                     "{{\"{}\"}}",
     456            0 :                     v.iter()
     457            0 :                         .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
     458            0 :                         .collect::<Vec<_>>()
     459            0 :                         .join("\",\"")
     460              :                 ))
     461              :             }
     462              :         }
     463              :     }
     464            0 : }
     465              : 
     466              : #[derive(serde::Serialize)]
     467              : struct QueryData<'a> {
     468              :     query: Cow<'a, str>,
     469              :     params: Vec<JsonValue>,
     470              : }
     471              : 
     472              : #[derive(serde::Serialize)]
     473              : struct BatchQueryData<'a> {
     474              :     queries: Vec<QueryData<'a>>,
     475              : }
     476              : 
     477            0 : async fn make_local_proxy_request<S: DeserializeOwned>(
     478            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     479            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     480            0 :     body: QueryData<'_>,
     481            0 :     max_len: usize,
     482            0 : ) -> Result<S, RestError> {
     483            0 :     let body_string = serde_json::to_string(&body)
     484            0 :         .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     485              : 
     486            0 :     let response = make_raw_local_proxy_request(client, headers, body_string).await?;
     487              : 
     488            0 :     let response_status = response.status();
     489              : 
     490            0 :     if response_status != StatusCode::OK {
     491            0 :         return Err(RestError::SubzeroCore(InternalError {
     492            0 :             message: "Failed to get endpoint schema".to_string(),
     493            0 :         }));
     494            0 :     }
     495              : 
     496              :     // Capture the response body
     497            0 :     let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
     498            0 :         .await
     499            0 :         .map_err(ReadPayloadError::from)?;
     500              : 
     501              :     // Parse the JSON response
     502            0 :     let response_json: S = serde_json::from_slice(&response_body)
     503            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     504              : 
     505            0 :     Ok(response_json)
     506            0 : }
     507              : 
     508            0 : async fn make_raw_local_proxy_request(
     509            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     510            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     511            0 :     body: String,
     512            0 : ) -> Result<Response<Incoming>, RestError> {
     513            0 :     let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
     514            0 :     let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
     515            0 :     let req_headers = req.headers_mut().expect("failed to get headers");
     516              :     // Add all provided headers to the request
     517            0 :     for (header_name, header_value) in headers {
     518            0 :         req_headers.insert(header_name, header_value.clone());
     519            0 :     }
     520              : 
     521            0 :     let body_boxed = Full::new(Bytes::from(body))
     522            0 :         .map_err(|never| match never {}) // Convert Infallible to hyper::Error
     523            0 :         .boxed();
     524              : 
     525            0 :     let req = req.body(body_boxed).map_err(|_| {
     526            0 :         RestError::SubzeroCore(InternalError {
     527            0 :             message: "Failed to build request".to_string(),
     528            0 :         })
     529            0 :     })?;
     530              : 
     531              :     // Send the request to the local proxy
     532            0 :     client
     533            0 :         .inner
     534            0 :         .inner
     535            0 :         .send_request(req)
     536            0 :         .await
     537            0 :         .map_err(LocalProxyConnError::from)
     538            0 :         .map_err(HttpConnError::from)
     539            0 :         .map_err(RestError::from)
     540            0 : }
     541              : 
     542            0 : pub(crate) async fn handle(
     543            0 :     config: &'static ProxyConfig,
     544            0 :     ctx: RequestContext,
     545            0 :     request: Request<Incoming>,
     546            0 :     backend: Arc<PoolingBackend>,
     547            0 :     cancel: CancellationToken,
     548            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     549            0 :     let result = handle_inner(cancel, config, &ctx, request, backend).await;
     550              : 
     551            0 :     let response = match result {
     552            0 :         Ok(r) => {
     553            0 :             ctx.set_success();
     554              : 
     555              :             // Handling the error response from local proxy here
     556            0 :             if r.status().is_server_error() {
     557            0 :                 let status = r.status();
     558              : 
     559            0 :                 let body_bytes = r
     560            0 :                     .collect()
     561            0 :                     .await
     562            0 :                     .map_err(|e| {
     563            0 :                         ApiError::InternalServerError(anyhow::Error::msg(format!(
     564            0 :                             "could not collect http body: {e}"
     565            0 :                         )))
     566            0 :                     })?
     567            0 :                     .to_bytes();
     568              : 
     569            0 :                 if let Ok(mut json_map) =
     570            0 :                     serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
     571              :                 {
     572            0 :                     let message = json_map.get("message");
     573            0 :                     if let Some(message) = message {
     574            0 :                         let msg: String = match serde_json::from_str(message.get()) {
     575            0 :                             Ok(msg) => msg,
     576              :                             Err(_) => {
     577            0 :                                 "Unable to parse the response message from server".to_string()
     578              :                             }
     579              :                         };
     580              : 
     581            0 :                         error!("Error response from local_proxy: {status} {msg}");
     582              : 
     583            0 :                         json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
     584              : 
     585            0 :                         let resp_json = serde_json::to_string(&json_map)
     586            0 :                             .unwrap_or("failed to serialize the response message".to_string());
     587              : 
     588            0 :                         return json_response(status, resp_json);
     589            0 :                     }
     590            0 :                 }
     591              : 
     592            0 :                 error!("Unable to parse the response message from local_proxy");
     593            0 :                 return json_response(
     594            0 :                     status,
     595            0 :                     json!({ "message": "Unable to parse the response message from server".to_string() }),
     596              :                 );
     597            0 :             }
     598            0 :             r
     599              :         }
     600            0 :         Err(e @ RestError::SubzeroCore(_)) => {
     601            0 :             let error_kind = e.get_error_kind();
     602            0 :             ctx.set_error_kind(error_kind);
     603              : 
     604            0 :             tracing::info!(
     605            0 :                 kind=error_kind.to_metric_label(),
     606              :                 error=%e,
     607              :                 msg="subzero core error",
     608            0 :                 "forwarding error to user"
     609              :             );
     610              : 
     611            0 :             let RestError::SubzeroCore(subzero_err) = e else {
     612            0 :                 panic!("expected subzero core error")
     613              :             };
     614              : 
     615            0 :             let json_body = subzero_err.json_body();
     616            0 :             let status_code = StatusCode::from_u16(subzero_err.status_code())
     617            0 :                 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
     618              : 
     619            0 :             json_response(status_code, json_body)?
     620              :         }
     621            0 :         Err(e) => {
     622            0 :             let error_kind = e.get_error_kind();
     623            0 :             ctx.set_error_kind(error_kind);
     624              : 
     625            0 :             let message = e.to_string_client();
     626            0 :             let status_code = e.get_http_status_code();
     627              : 
     628            0 :             tracing::info!(
     629            0 :                 kind=error_kind.to_metric_label(),
     630              :                 error=%e,
     631              :                 msg=message,
     632            0 :                 "forwarding error to user"
     633              :             );
     634              : 
     635            0 :             let (code, detail, hint) = match e {
     636            0 :                 RestError::Postgres(e) => (
     637            0 :                     if e.code.starts_with("PT") {
     638            0 :                         None
     639              :                     } else {
     640            0 :                         Some(e.code)
     641              :                     },
     642            0 :                     e.detail,
     643            0 :                     e.hint,
     644              :                 ),
     645            0 :                 _ => (None, None, None),
     646              :             };
     647              : 
     648            0 :             json_response(
     649            0 :                 status_code,
     650            0 :                 json!({
     651            0 :                     "message": message,
     652            0 :                     "code": code,
     653            0 :                     "detail": detail,
     654            0 :                     "hint": hint,
     655              :                 }),
     656            0 :             )?
     657              :         }
     658              :     };
     659              : 
     660            0 :     Ok(response)
     661            0 : }
     662              : 
     663            0 : async fn handle_inner(
     664            0 :     _cancel: CancellationToken,
     665            0 :     config: &'static ProxyConfig,
     666            0 :     ctx: &RequestContext,
     667            0 :     request: Request<Incoming>,
     668            0 :     backend: Arc<PoolingBackend>,
     669            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     670            0 :     let _requeset_gauge = Metrics::get()
     671            0 :         .proxy
     672            0 :         .connection_requests
     673            0 :         .guard(ctx.protocol());
     674            0 :     info!(
     675            0 :         protocol = %ctx.protocol(),
     676            0 :         "handling interactive connection from client"
     677              :     );
     678              : 
     679              :     // Read host from Host, then URI host as fallback
     680              :     // TODO: will this be a problem if behind a load balancer?
     681              :     // TODO: can we use the x-forwarded-host header?
     682            0 :     let host = request
     683            0 :         .headers()
     684            0 :         .get(HOST)
     685            0 :         .and_then(|v| v.to_str().ok())
     686            0 :         .unwrap_or_else(|| request.uri().host().unwrap_or(""));
     687              : 
     688              :     // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
     689            0 :     let database_name = request
     690            0 :         .uri()
     691            0 :         .path()
     692            0 :         .split('/')
     693            0 :         .nth(1)
     694            0 :         .ok_or(RestError::SubzeroCore(NotFound {
     695            0 :             target: request.uri().path().to_string(),
     696            0 :         }))?;
     697              : 
     698              :     // we always use the authenticator role to connect to the database
     699            0 :     let authenticator_role = "authenticator";
     700              : 
     701              :     // Strip the hostname prefix from the host to get the database hostname
     702            0 :     let database_host = host.replace(&config.rest_config.hostname_prefix, "");
     703              : 
     704            0 :     let connection_string =
     705            0 :         format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
     706              : 
     707            0 :     let conn_info = get_conn_info(
     708            0 :         &config.authentication_config,
     709            0 :         ctx,
     710            0 :         Some(&connection_string),
     711            0 :         request.headers(),
     712            0 :     )?;
     713            0 :     info!(
     714            0 :         user = conn_info.conn_info.user_info.user.as_str(),
     715            0 :         "credentials"
     716              :     );
     717              : 
     718            0 :     match conn_info.auth {
     719            0 :         AuthData::Jwt(jwt) => {
     720            0 :             let api_prefix = format!("/{database_name}/rest/v1/");
     721            0 :             handle_rest_inner(
     722            0 :                 config,
     723            0 :                 ctx,
     724            0 :                 &api_prefix,
     725            0 :                 request,
     726            0 :                 &connection_string,
     727            0 :                 conn_info.conn_info,
     728            0 :                 jwt,
     729            0 :                 backend,
     730            0 :             )
     731            0 :             .await
     732              :         }
     733            0 :         AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
     734            0 :             Credentials::BearerJwt,
     735            0 :         ))),
     736              :     }
     737            0 : }
     738              : 
     739            0 : fn apply_common_cors_headers(
     740            0 :     response: &mut Builder,
     741            0 :     request_headers: &HeaderMap,
     742            0 :     allowed_origins: Option<&Vec<String>>,
     743            0 : ) {
     744            0 :     let request_origin = request_headers
     745            0 :         .get(ORIGIN)
     746            0 :         .map(|v| v.to_str().unwrap_or(""));
     747              : 
     748            0 :     let response_allow_origin = match (request_origin, allowed_origins) {
     749            0 :         (Some(or), Some(allowed_origins)) => {
     750            0 :             if allowed_origins.iter().any(|o| o == or) {
     751            0 :                 Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
     752              :             } else {
     753            0 :                 None
     754              :             }
     755              :         }
     756            0 :         (Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
     757            0 :         _ => None,
     758              :     };
     759            0 :     if let Some(h) = response.headers_mut() {
     760            0 :         h.insert(
     761            0 :             ACCESS_CONTROL_EXPOSE_HEADERS,
     762            0 :             ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
     763              :         );
     764            0 :         if let Some(origin) = response_allow_origin {
     765            0 :             h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
     766            0 :         }
     767            0 :     }
     768            0 : }
     769              : 
     770              : #[allow(clippy::too_many_arguments)]
     771            0 : async fn handle_rest_inner(
     772            0 :     config: &'static ProxyConfig,
     773            0 :     ctx: &RequestContext,
     774            0 :     api_prefix: &str,
     775            0 :     request: Request<Incoming>,
     776            0 :     connection_string: &str,
     777            0 :     conn_info: ConnInfo,
     778            0 :     jwt: String,
     779            0 :     backend: Arc<PoolingBackend>,
     780            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     781            0 :     let db_schema_cache =
     782            0 :         config
     783            0 :             .rest_config
     784            0 :             .db_schema_cache
     785            0 :             .as_ref()
     786            0 :             .ok_or(RestError::SubzeroCore(InternalError {
     787            0 :                 message: "DB schema cache is not configured".to_string(),
     788            0 :             }))?;
     789              : 
     790            0 :     let endpoint_cache_key = conn_info
     791            0 :         .endpoint_cache_key()
     792            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     793            0 :             message: "Failed to get endpoint cache key".to_string(),
     794            0 :         }))?;
     795              : 
     796            0 :     let (parts, originial_body) = request.into_parts();
     797              : 
     798              :     // try and get the cached entry for this endpoint
     799              :     // it contains the api config and the introspected db schema
     800            0 :     let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
     801              : 
     802            0 :     let allowed_origins = cached_entry
     803            0 :         .as_ref()
     804            0 :         .and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
     805              : 
     806            0 :     let mut response = Response::builder();
     807            0 :     apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
     808              : 
     809              :     // handle the OPTIONS request
     810            0 :     if parts.method == Method::OPTIONS {
     811            0 :         let allowed_headers = parts
     812            0 :             .headers
     813            0 :             .get(ACCESS_CONTROL_REQUEST_HEADERS)
     814            0 :             .and_then(|a| a.to_str().ok())
     815            0 :             .filter(|v| !v.is_empty())
     816            0 :             .map_or_else(
     817            0 :                 || "Authorization".to_string(),
     818            0 :                 |v| format!("{v}, Authorization"),
     819              :             );
     820            0 :         return response
     821            0 :             .status(StatusCode::OK)
     822            0 :             .header(
     823            0 :                 ACCESS_CONTROL_ALLOW_METHODS,
     824            0 :                 ACCESS_CONTROL_ALLOW_METHODS_VALUE,
     825              :             )
     826            0 :             .header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
     827            0 :             .header(
     828            0 :                 ACCESS_CONTROL_ALLOW_HEADERS,
     829            0 :                 HeaderValue::from_str(&allowed_headers)
     830            0 :                     .unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
     831              :             )
     832            0 :             .header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
     833            0 :             .body(Empty::new().map_err(|x| match x {}).boxed())
     834            0 :             .map_err(|e| {
     835            0 :                 RestError::SubzeroCore(InternalError {
     836            0 :                     message: e.to_string(),
     837            0 :                 })
     838            0 :             });
     839            0 :     }
     840              : 
     841              :     // validate the jwt token
     842            0 :     let jwt_parsed = backend
     843            0 :         .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
     844            0 :         .await
     845            0 :         .map_err(HttpConnError::from)?;
     846              : 
     847            0 :     let auth_header = parts
     848            0 :         .headers
     849            0 :         .get(AUTHORIZATION)
     850            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     851            0 :             message: "Authorization header is required".to_string(),
     852            0 :         }))?;
     853            0 :     let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
     854              : 
     855            0 :     let entry = match cached_entry {
     856            0 :         Some(e) => e,
     857              :         None => {
     858              :             // if not cached, get the remote entry (will run the introspection query)
     859            0 :             db_schema_cache
     860            0 :                 .get_remote(
     861            0 :                     &endpoint_cache_key,
     862            0 :                     auth_header,
     863            0 :                     connection_string,
     864            0 :                     &mut client,
     865            0 :                     ctx,
     866            0 :                     config,
     867            0 :                 )
     868            0 :                 .await?
     869              :         }
     870              :     };
     871            0 :     let (api_config, db_schema_owned) = entry.as_ref();
     872              : 
     873            0 :     let db_schema = db_schema_owned.borrow_schema();
     874              : 
     875            0 :     let db_schemas = &api_config.db_schemas; // list of schemas available for the api
     876            0 :     let db_extra_search_path = &api_config.db_extra_search_path;
     877              :     // TODO: use this when we get a replacement for jsonpath_lib
     878              :     // let role_claim_key = &api_config.role_claim_key;
     879              :     // let role_claim_path = format!("${role_claim_key}");
     880            0 :     let db_anon_role = &api_config.db_anon_role;
     881            0 :     let max_rows = api_config.db_max_rows.as_deref();
     882            0 :     let db_allowed_select_functions = api_config
     883            0 :         .db_allowed_select_functions
     884            0 :         .iter()
     885            0 :         .map(|s| s.as_str())
     886            0 :         .collect::<Vec<_>>();
     887              : 
     888              :     // extract the jwt claims (we'll need them later to set the role and env)
     889            0 :     let jwt_claims = match jwt_parsed.keys {
     890            0 :         ComputeCredentialKeys::JwtPayload(payload_bytes) => {
     891              :             // `payload_bytes` contains the raw JWT payload as Vec<u8>
     892              :             // You can deserialize it back to JSON or parse specific claims
     893            0 :             let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
     894            0 :                 .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     895            0 :             Some(payload)
     896              :         }
     897            0 :         ComputeCredentialKeys::AuthKeys(_) => None,
     898              :     };
     899              : 
     900              :     // read the role from the jwt claims (and set it to the "anon" role if not present)
     901            0 :     let (role, authenticated) = match &jwt_claims {
     902            0 :         Some(claims) => match claims.get("role") {
     903            0 :             Some(JsonValue::String(r)) => (Some(r), true),
     904            0 :             _ => (db_anon_role.as_ref(), true),
     905              :         },
     906            0 :         None => (db_anon_role.as_ref(), false),
     907              :     };
     908              : 
     909              :     // do not allow unauthenticated requests when there is no anonymous role setup
     910            0 :     if let (None, false) = (role, authenticated) {
     911            0 :         return Err(RestError::SubzeroCore(JwtTokenInvalid {
     912            0 :             message: "unauthenticated requests not allowed".to_string(),
     913            0 :         }));
     914            0 :     }
     915              : 
     916              :     // start deconstructing the request because subzero core mostly works with &str
     917            0 :     let method = parts.method;
     918            0 :     let method_str = method.as_str();
     919            0 :     let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
     920              : 
     921              :     // this is actually the table name (or rpc/function_name)
     922              :     // TODO: rename this to something more descriptive
     923            0 :     let root = match parts.uri.path().strip_prefix(api_prefix) {
     924            0 :         Some(p) => Ok(p),
     925            0 :         None => Err(RestError::SubzeroCore(NotFound {
     926            0 :             target: parts.uri.path().to_string(),
     927            0 :         })),
     928            0 :     }?;
     929              : 
     930              :     // pick the current schema from the headers (or the first one from config)
     931            0 :     let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
     932              : 
     933              :     // add the content-profile header to the response
     934            0 :     let mut response_headers = vec![];
     935            0 :     if db_schemas.len() > 1 {
     936            0 :         response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
     937            0 :     }
     938              : 
     939              :     // parse the query string into a Vec<(&str, &str)>
     940            0 :     let query = match parts.uri.query() {
     941            0 :         Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
     942            0 :         None => vec![],
     943              :     };
     944            0 :     let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
     945              : 
     946              :     // convert the headers map to a HashMap<&str, &str>
     947            0 :     let headers: HashMap<&str, &str> = parts
     948            0 :         .headers
     949            0 :         .iter()
     950            0 :         .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
     951            0 :         .collect();
     952              : 
     953            0 :     let cookies = HashMap::new(); // TODO: add cookies
     954              : 
     955              :     // Read the request body (skip for GET requests)
     956            0 :     let body_as_string: Option<String> = if method == Method::GET {
     957            0 :         None
     958              :     } else {
     959            0 :         let body_bytes =
     960            0 :             read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
     961            0 :                 .await
     962            0 :                 .map_err(ReadPayloadError::from)?;
     963            0 :         if body_bytes.is_empty() {
     964            0 :             None
     965              :         } else {
     966            0 :             Some(String::from_utf8_lossy(&body_bytes).into_owned())
     967              :         }
     968              :     };
     969              : 
     970              :     // parse the request into an ApiRequest struct
     971            0 :     let mut api_request = parse(
     972            0 :         schema_name,
     973            0 :         root,
     974            0 :         db_schema,
     975            0 :         method_str,
     976            0 :         path,
     977            0 :         get,
     978            0 :         body_as_string.as_deref(),
     979            0 :         headers,
     980            0 :         cookies,
     981            0 :         max_rows,
     982              :     )
     983            0 :     .map_err(RestError::SubzeroCore)?;
     984              : 
     985            0 :     let role_str = match role {
     986            0 :         Some(r) => r,
     987            0 :         None => "",
     988              :     };
     989              : 
     990            0 :     replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
     991              : 
     992              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     993              :     // in the context of DBX where they need internal permissions
     994              :     // if !disable_internal_permissions {
     995              :     //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
     996              :     // }
     997              : 
     998            0 :     check_safe_functions(&api_request, &db_allowed_select_functions)?;
     999              : 
    1000              :     // TODO: this is not relevant when acting as PostgREST but will be useful
    1001              :     // in the context of DBX where they need internal permissions
    1002              :     // if !disable_internal_permissions {
    1003              :     //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
    1004              :     // }
    1005              : 
    1006            0 :     let env_role = Some(role_str);
    1007              : 
    1008              :     // construct the env (passed in to the sql context as GUCs)
    1009            0 :     let empty_json = "{}".to_string();
    1010            0 :     let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
    1011            0 :     let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
    1012            0 :     let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
    1013            0 :     let jwt_claims_env = jwt_claims
    1014            0 :         .as_ref()
    1015            0 :         .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
    1016            0 :         .unwrap_or(if let Some(r) = env_role {
    1017            0 :             let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
    1018            0 :             serde_json::to_string(&claims).unwrap_or(empty_json.clone())
    1019              :         } else {
    1020            0 :             empty_json.clone()
    1021              :         });
    1022            0 :     let mut search_path = vec![api_request.schema_name];
    1023            0 :     if let Some(extra) = &db_extra_search_path {
    1024            0 :         search_path.extend(extra.iter().map(|s| s.as_str()));
    1025            0 :     }
    1026            0 :     let search_path_str = search_path
    1027            0 :         .into_iter()
    1028            0 :         .filter(|s| !s.is_empty())
    1029            0 :         .collect::<Vec<_>>()
    1030            0 :         .join(",");
    1031            0 :     let mut env: HashMap<&str, &str> = HashMap::from([
    1032            0 :         ("request.method", api_request.method),
    1033            0 :         ("request.path", api_request.path),
    1034            0 :         ("request.headers", &headers_env),
    1035            0 :         ("request.cookies", &cookies_env),
    1036            0 :         ("request.get", &get_env),
    1037            0 :         ("request.jwt.claims", &jwt_claims_env),
    1038            0 :         ("search_path", &search_path_str),
    1039            0 :     ]);
    1040            0 :     if let Some(r) = env_role {
    1041            0 :         env.insert("role", r);
    1042            0 :     }
    1043              : 
    1044              :     // generate the sql statements
    1045            0 :     let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
    1046            0 :     let (main_statement, main_parameters, _) = generate(fmt_main_query(
    1047            0 :         db_schema,
    1048            0 :         api_request.schema_name,
    1049            0 :         &api_request,
    1050            0 :         &env,
    1051            0 :     )?);
    1052              : 
    1053            0 :     let mut headers = vec![
    1054            0 :         (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
    1055            0 :         (
    1056            0 :             &CONN_STRING,
    1057            0 :             HeaderValue::from_str(connection_string).expect("invalid connection string"),
    1058            0 :         ),
    1059            0 :         (&AUTHORIZATION, auth_header.clone()),
    1060            0 :         (
    1061            0 :             &TXN_ISOLATION_LEVEL,
    1062            0 :             HeaderValue::from_static("ReadCommitted"),
    1063            0 :         ),
    1064            0 :         (&ALLOW_POOL, HEADER_VALUE_TRUE),
    1065              :     ];
    1066              : 
    1067            0 :     if api_request.read_only {
    1068            0 :         headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
    1069            0 :     }
    1070              : 
    1071              :     // convert the parameters from subzero core representation to the local proxy repr.
    1072            0 :     let req_body = serde_json::to_string(&BatchQueryData {
    1073            0 :         queries: vec![
    1074              :             QueryData {
    1075            0 :                 query: env_statement.into(),
    1076            0 :                 params: env_parameters
    1077            0 :                     .iter()
    1078            0 :                     .map(|p| to_sql_param(&p.to_param()))
    1079            0 :                     .collect(),
    1080              :             },
    1081              :             QueryData {
    1082            0 :                 query: main_statement.into(),
    1083            0 :                 params: main_parameters
    1084            0 :                     .iter()
    1085            0 :                     .map(|p| to_sql_param(&p.to_param()))
    1086            0 :                     .collect(),
    1087              :             },
    1088              :         ],
    1089              :     })
    1090            0 :     .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
    1091              : 
    1092              :     // todo: map body to count egress
    1093            0 :     let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
    1094              : 
    1095              :     // send the request to the local proxy
    1096            0 :     let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
    1097            0 :     let (response_parts, body) = proxy_response.into_parts();
    1098              : 
    1099            0 :     let max_response = config.http_config.max_response_size_bytes;
    1100            0 :     let bytes = read_body_with_limit(body, max_response)
    1101            0 :         .await
    1102            0 :         .map_err(ReadPayloadError::from)?;
    1103              : 
    1104              :     // if the response status is greater than 399, then it is an error
    1105              :     // FIXME: check if there are other error codes or shapes of the response
    1106            0 :     if response_parts.status.as_u16() > 399 {
    1107              :         // turn this postgres error from the json into PostgresError
    1108            0 :         let postgres_error = serde_json::from_slice(&bytes)
    1109            0 :             .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1110              : 
    1111            0 :         return Err(RestError::Postgres(postgres_error));
    1112            0 :     }
    1113              : 
    1114            0 :     #[derive(Deserialize)]
    1115              :     struct QueryResults {
    1116              :         /// we run two queries, so we want only two results.
    1117              :         results: (EnvRows, MainRows),
    1118              :     }
    1119              : 
    1120              :     /// `env_statement` returns nothing of interest to us
    1121            0 :     #[derive(Deserialize)]
    1122              :     struct EnvRows {}
    1123              : 
    1124            0 :     #[derive(Deserialize)]
    1125              :     struct MainRows {
    1126              :         /// `main_statement` only returns a single row.
    1127              :         rows: [MainRow; 1],
    1128              :     }
    1129              : 
    1130            0 :     #[derive(Deserialize)]
    1131              :     struct MainRow {
    1132              :         body: String,
    1133              :         page_total: Option<String>,
    1134              :         total_result_set: Option<String>,
    1135              :         response_headers: Option<String>,
    1136              :         response_status: Option<String>,
    1137              :     }
    1138              : 
    1139            0 :     let results: QueryResults = serde_json::from_slice(&bytes)
    1140            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1141              : 
    1142              :     let QueryResults {
    1143            0 :         results: (_, MainRows { rows: [row] }),
    1144            0 :     } = results;
    1145              : 
    1146              :     // build the intermediate response object
    1147            0 :     let api_response = ApiResponse {
    1148            0 :         page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
    1149            0 :         total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
    1150              :         top_level_offset: 0, // FIXME: check why this is 0
    1151            0 :         response_headers: row.response_headers,
    1152            0 :         response_status: row.response_status,
    1153            0 :         body: row.body,
    1154              :     };
    1155              : 
    1156              :     // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
    1157              :     // we can not do this in the context of proxy for now
    1158              :     // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
    1159              :     //     // rollback the transaction here
    1160              :     //     return Err(RestError::SubzeroCore(SingularityError {
    1161              :     //         count: api_response.page_total,
    1162              :     //         content_type: "application/vnd.pgrst.object+json".to_string(),
    1163              :     //     }));
    1164              :     // }
    1165              : 
    1166              :     // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
    1167              :     // we can not do this in the context of proxy for now
    1168              :     // if api_request.method == Method::PUT && api_response.page_total != 1 {
    1169              :     //     // Makes sure the querystring pk matches the payload pk
    1170              :     //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
    1171              :     //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
    1172              :     //     // If this condition is not satisfied then nothing is inserted,
    1173              :     //     // rollback the transaction here
    1174              :     //     return Err(RestError::SubzeroCore(PutMatchingPkError));
    1175              :     // }
    1176              : 
    1177              :     // create and return the response to the client
    1178              :     // this section mostly deals with setting the right headers according to PostgREST specs
    1179            0 :     let page_total = api_response.page_total;
    1180            0 :     let total_result_set = api_response.total_result_set;
    1181            0 :     let top_level_offset = api_response.top_level_offset;
    1182            0 :     let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
    1183              :         (SingularJSON, _)
    1184              :         | (
    1185              :             _,
    1186              :             FunctionCall {
    1187              :                 returns_single: true,
    1188              :                 is_scalar: false,
    1189              :                 ..
    1190              :             },
    1191            0 :         ) => SingularJSON,
    1192            0 :         (TextCSV, _) => TextCSV,
    1193            0 :         _ => ApplicationJSON,
    1194              :     };
    1195              : 
    1196              :     // check if the SQL env set some response headers (happens when we called a rpc function)
    1197            0 :     if let Some(response_headers_str) = api_response.response_headers {
    1198            0 :         let Ok(headers_json) =
    1199            0 :             serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
    1200              :         else {
    1201            0 :             return Err(RestError::SubzeroCore(GucHeadersError));
    1202              :         };
    1203              : 
    1204            0 :         response_headers.extend(headers_json.into_iter().flatten());
    1205            0 :     }
    1206              : 
    1207              :     // calculate and set the content range header
    1208            0 :     let lower = top_level_offset as i64;
    1209            0 :     let upper = top_level_offset as i64 + page_total as i64 - 1;
    1210            0 :     let total = total_result_set.map(|t| t as i64);
    1211            0 :     let content_range = match (&method, &api_request.query.node) {
    1212            0 :         (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
    1213            0 :         (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
    1214            0 :         _ => content_range_header(lower, upper, total),
    1215              :     };
    1216            0 :     response_headers.push(("Content-Range".to_string(), content_range));
    1217              : 
    1218              :     // calculate the status code
    1219              :     #[rustfmt::skip]
    1220            0 :     let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
    1221            0 :         (&Method::POST,   Insert { .. }, ..) => 201,
    1222            0 :         (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1223            0 :         (&Method::DELETE, Delete { .. }, ..) => 204,
    1224            0 :         (&Method::PATCH,  Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
    1225            0 :         (&Method::PATCH,  Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1226            0 :         (&Method::PATCH,  Update { .. }, ..) => 204,
    1227            0 :         (&Method::PUT,    Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1228            0 :         (&Method::PUT,    Insert { .. }, ..) => 204,
    1229            0 :         _ => content_range_status(lower, upper, total),
    1230              :     };
    1231              : 
    1232              :     // add the preference-applied header
    1233              :     if let Some(Preferences {
    1234            0 :         resolution: Some(r),
    1235              :         ..
    1236            0 :     }) = api_request.preferences
    1237              :     {
    1238            0 :         response_headers.push((
    1239            0 :             "Preference-Applied".to_string(),
    1240            0 :             match r {
    1241            0 :                 MergeDuplicates => "resolution=merge-duplicates".to_string(),
    1242            0 :                 IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
    1243              :             },
    1244              :         ));
    1245            0 :     }
    1246              : 
    1247              :     // check if the SQL env set some response status (happens when we called a rpc function)
    1248            0 :     if let Some(response_status_str) = api_response.response_status {
    1249            0 :         status = response_status_str
    1250            0 :             .parse::<u16>()
    1251            0 :             .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
    1252            0 :     }
    1253              : 
    1254              :     // set the content type header
    1255              :     // TODO: move this to a subzero function
    1256              :     // as_header_value(&self) -> Option<&str>
    1257            0 :     let http_content_type = match response_content_type {
    1258            0 :         SingularJSON => Ok("application/vnd.pgrst.object+json"),
    1259            0 :         TextCSV => Ok("text/csv"),
    1260            0 :         ApplicationJSON => Ok("application/json"),
    1261            0 :         Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
    1262            0 :             message: format!("None of these Content-Types are available: {t}"),
    1263            0 :         })),
    1264            0 :     }?;
    1265              : 
    1266              :     // build the response body
    1267            0 :     let response_body = Full::new(Bytes::from(api_response.body))
    1268            0 :         .map_err(|never| match never {})
    1269            0 :         .boxed();
    1270              : 
    1271              :     // build the response
    1272            0 :     response = response
    1273            0 :         .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
    1274            0 :         .header(CONTENT_TYPE, http_content_type);
    1275              : 
    1276              :     // Add all headers from response_headers vector
    1277            0 :     for (header_name, header_value) in response_headers {
    1278            0 :         response = response.header(header_name, header_value);
    1279            0 :     }
    1280              : 
    1281              :     // add the body and return the response
    1282            0 :     response.body(response_body).map_err(|_| {
    1283            0 :         RestError::SubzeroCore(InternalError {
    1284            0 :             message: "Failed to build response".to_string(),
    1285            0 :         })
    1286            0 :     })
    1287            0 : }
        

Generated by: LCOV version 2.1-beta