LCOV - code coverage report
Current view: top level - proxy/src/serverless - rest.rs (source / functions) Coverage Total Hit
Test: 5994c4f679fd099a9a8c180070d049e614bb817d.info Lines: 0.0 % 763 0
Test Date: 2025-07-31 15:05:35 Functions: 0.0 % 100 0

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::HashMap;
       3              : use std::convert::Infallible;
       4              : use std::sync::Arc;
       5              : 
       6              : use bytes::Bytes;
       7              : use http::Method;
       8              : use http::header::{
       9              :     ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
      10              :     ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
      11              :     AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN, VARY,
      12              : };
      13              : use http_body_util::combinators::BoxBody;
      14              : use http_body_util::{BodyExt, Empty, Full};
      15              : use http_utils::error::ApiError;
      16              : use hyper::body::Incoming;
      17              : use hyper::http::response::Builder;
      18              : use hyper::http::{HeaderMap, HeaderName, HeaderValue};
      19              : use hyper::{Request, Response, StatusCode};
      20              : use indexmap::IndexMap;
      21              : use moka::sync::Cache;
      22              : use ouroboros::self_referencing;
      23              : use serde::de::DeserializeOwned;
      24              : use serde::{Deserialize, Deserializer};
      25              : use serde_json::Value as JsonValue;
      26              : use serde_json::value::RawValue;
      27              : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
      28              : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
      29              : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
      30              : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
      31              : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
      32              : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
      33              : use subzero_core::error::Error::{
      34              :     self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
      35              :     JsonDeserialize, JwtTokenInvalid, NotFound,
      36              : };
      37              : use subzero_core::error::pg_error_to_status_code;
      38              : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
      39              : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
      40              : use subzero_core::formatter::{Param, Snippet, SqlParam};
      41              : use subzero_core::parser::postgrest::parse;
      42              : use subzero_core::permissions::{check_safe_functions, replace_select_star};
      43              : use subzero_core::schema::{
      44              :     DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
      45              : };
      46              : use subzero_core::{content_range_header, content_range_status};
      47              : use tokio_util::sync::CancellationToken;
      48              : use tracing::{error, info};
      49              : use typed_json::json;
      50              : use url::form_urlencoded;
      51              : 
      52              : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
      53              : use super::conn_pool::AuthData;
      54              : use super::conn_pool_lib::ConnInfo;
      55              : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
      56              : use super::http_conn_pool::{self, LocalProxyClient};
      57              : use super::http_util::{
      58              :     ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
      59              :     get_conn_info, json_response, uuid_to_header_value,
      60              : };
      61              : use super::json::JsonConversionError;
      62              : use crate::auth::backend::ComputeCredentialKeys;
      63              : use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener};
      64              : use crate::config::ProxyConfig;
      65              : use crate::context::RequestContext;
      66              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      67              : use crate::http::read_body_with_limit;
      68              : use crate::metrics::{CacheKind, Metrics};
      69              : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
      70              : use crate::types::EndpointCacheKey;
      71              : use crate::util::deserialize_json_string;
      72              : 
      73              : static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
      74              : const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
      75              : const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
      76              : // CORS headers values
      77              : const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
      78              :     HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
      79              : const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
      80              : const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
      81              :     "Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
      82              : );
      83              : const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
      84              : const ACCESS_CONTROL_VARY_VALUE: HeaderValue = HeaderValue::from_static("Origin");
      85              : 
      86              : // A wrapper around the DbSchema that allows for self-referencing
      87              : #[self_referencing]
      88              : pub struct DbSchemaOwned {
      89              :     schema_string: String,
      90              :     #[covariant]
      91              :     #[borrows(schema_string)]
      92              :     schema: DbSchema<'this>,
      93              : }
      94              : 
      95              : impl<'de> Deserialize<'de> for DbSchemaOwned {
      96            0 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      97            0 :     where
      98            0 :         D: Deserializer<'de>,
      99              :     {
     100            0 :         let s = String::deserialize(deserializer)?;
     101            0 :         DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
     102            0 :             .map_err(<D::Error as serde::de::Error>::custom)
     103            0 :     }
     104              : }
     105              : 
     106            0 : fn split_comma_separated(s: &str) -> Vec<String> {
     107            0 :     s.split(',').map(|s| s.trim().to_string()).collect()
     108            0 : }
     109              : 
     110            0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
     111            0 : where
     112            0 :     D: Deserializer<'de>,
     113              : {
     114            0 :     let s = String::deserialize(deserializer)?;
     115            0 :     Ok(split_comma_separated(&s))
     116            0 : }
     117              : 
     118            0 : fn deserialize_comma_separated_option<'de, D>(
     119            0 :     deserializer: D,
     120            0 : ) -> Result<Option<Vec<String>>, D::Error>
     121            0 : where
     122            0 :     D: Deserializer<'de>,
     123              : {
     124            0 :     let opt = Option::<String>::deserialize(deserializer)?;
     125            0 :     if let Some(s) = &opt {
     126            0 :         let trimmed = s.trim();
     127            0 :         if trimmed.is_empty() {
     128            0 :             return Ok(None);
     129            0 :         }
     130            0 :         return Ok(Some(split_comma_separated(trimmed)));
     131            0 :     }
     132            0 :     Ok(None)
     133            0 : }
     134              : 
     135              : // The ApiConfig is the configuration for the API per endpoint
     136              : // The configuration is read from the database and cached in the DbSchemaCache
     137              : #[derive(Deserialize, Debug)]
     138              : pub struct ApiConfig {
     139              :     #[serde(
     140              :         default = "db_schemas",
     141              :         deserialize_with = "deserialize_comma_separated"
     142              :     )]
     143              :     pub db_schemas: Vec<String>,
     144              :     pub db_anon_role: Option<String>,
     145              :     pub db_max_rows: Option<String>,
     146              :     #[serde(default = "db_allowed_select_functions")]
     147              :     pub db_allowed_select_functions: Vec<String>,
     148              :     // #[serde(deserialize_with = "to_tuple", default)]
     149              :     // pub db_pre_request: Option<(String, String)>,
     150              :     #[allow(dead_code)]
     151              :     #[serde(default = "role_claim_key")]
     152              :     pub role_claim_key: String,
     153              :     #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
     154              :     pub db_extra_search_path: Option<Vec<String>>,
     155              :     #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
     156              :     pub server_cors_allowed_origins: Option<Vec<String>>,
     157              : }
     158              : 
     159              : // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
     160              : pub(crate) struct DbSchemaCache(Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
     161              : impl DbSchemaCache {
     162            0 :     pub fn new(config: crate::config::CacheOptions) -> Self {
     163            0 :         let builder = Cache::builder().name("schema");
     164            0 :         let builder = config.moka(builder);
     165              : 
     166            0 :         let metrics = &Metrics::get().cache;
     167            0 :         if let Some(size) = config.size {
     168            0 :             metrics.capacity.set(CacheKind::Schema, size as i64);
     169            0 :         }
     170              : 
     171            0 :         let builder =
     172            0 :             builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause));
     173              : 
     174            0 :         Self(builder.build())
     175            0 :     }
     176              : 
     177            0 :     pub async fn maintain(&self) -> Result<Infallible, anyhow::Error> {
     178            0 :         let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
     179              :         loop {
     180            0 :             ticker.tick().await;
     181            0 :             self.0.run_pending_tasks();
     182              :         }
     183              :     }
     184              : 
     185            0 :     pub fn get_cached(
     186            0 :         &self,
     187            0 :         endpoint_id: &EndpointCacheKey,
     188            0 :     ) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
     189            0 :         count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
     190            0 :     }
     191            0 :     pub async fn get_remote(
     192            0 :         &self,
     193            0 :         endpoint_id: &EndpointCacheKey,
     194            0 :         auth_header: &HeaderValue,
     195            0 :         connection_string: &str,
     196            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     197            0 :         ctx: &RequestContext,
     198            0 :         config: &'static ProxyConfig,
     199            0 :     ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
     200            0 :         info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
     201            0 :         let remote_value = self
     202            0 :             .internal_get_remote(auth_header, connection_string, client, ctx, config)
     203            0 :             .await;
     204            0 :         let (api_config, schema_owned) = match remote_value {
     205            0 :             Ok((api_config, schema_owned)) => (api_config, schema_owned),
     206            0 :             Err(e @ RestError::SchemaTooLarge) => {
     207              :                 // for the case where the schema is too large, we cache an empty dummy value
     208              :                 // all the other requests will fail without triggering the introspection query
     209            0 :                 let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
     210            0 :                     .map_err(|e| JsonDeserialize { source: e })?;
     211              : 
     212            0 :                 let api_config = ApiConfig {
     213            0 :                     db_schemas: vec![],
     214            0 :                     db_anon_role: None,
     215            0 :                     db_max_rows: None,
     216            0 :                     db_allowed_select_functions: vec![],
     217            0 :                     role_claim_key: String::new(),
     218            0 :                     db_extra_search_path: None,
     219            0 :                     server_cors_allowed_origins: None,
     220            0 :                 };
     221            0 :                 let value = Arc::new((api_config, schema_owned));
     222            0 :                 count_cache_insert(CacheKind::Schema);
     223            0 :                 self.0.insert(endpoint_id.clone(), value);
     224            0 :                 return Err(e);
     225              :             }
     226            0 :             Err(e) => {
     227            0 :                 return Err(e);
     228              :             }
     229              :         };
     230            0 :         let value = Arc::new((api_config, schema_owned));
     231            0 :         count_cache_insert(CacheKind::Schema);
     232            0 :         self.0.insert(endpoint_id.clone(), value.clone());
     233            0 :         Ok(value)
     234            0 :     }
     235            0 :     async fn internal_get_remote(
     236            0 :         &self,
     237            0 :         auth_header: &HeaderValue,
     238            0 :         connection_string: &str,
     239            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     240            0 :         ctx: &RequestContext,
     241            0 :         config: &'static ProxyConfig,
     242            0 :     ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
     243            0 :         #[derive(Deserialize)]
     244              :         struct SingleRow<Row> {
     245              :             rows: [Row; 1],
     246              :         }
     247              : 
     248              :         #[derive(Deserialize)]
     249              :         struct ConfigRow {
     250              :             #[serde(deserialize_with = "deserialize_json_string")]
     251              :             config: ApiConfig,
     252              :         }
     253              : 
     254            0 :         #[derive(Deserialize)]
     255              :         struct SchemaRow {
     256              :             json_schema: DbSchemaOwned,
     257              :         }
     258              : 
     259            0 :         let headers = vec![
     260            0 :             (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     261            0 :             (
     262            0 :                 &CONN_STRING,
     263            0 :                 HeaderValue::from_str(connection_string).expect(
     264            0 :                     "connection string came from a header, so it must be a valid headervalue",
     265            0 :                 ),
     266            0 :             ),
     267            0 :             (&AUTHORIZATION, auth_header.clone()),
     268            0 :             (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
     269              :         ];
     270              : 
     271            0 :         let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
     272              :         let SingleRow {
     273            0 :             rows: [ConfigRow { config: api_config }],
     274            0 :         } = make_local_proxy_request(
     275            0 :             client,
     276            0 :             headers.iter().cloned(),
     277            0 :             QueryData {
     278            0 :                 query: Cow::Owned(query),
     279            0 :                 params: vec![],
     280            0 :             },
     281            0 :             config.rest_config.max_schema_size,
     282            0 :         )
     283            0 :         .await
     284            0 :         .map_err(|e| match e {
     285              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     286            0 :                 RestError::SchemaTooLarge
     287              :             }
     288            0 :             e => e,
     289            0 :         })?;
     290              : 
     291              :         // now that we have the api_config let's run the second INTROSPECTION_SQL query
     292              :         let SingleRow {
     293            0 :             rows: [SchemaRow { json_schema }],
     294            0 :         } = make_local_proxy_request(
     295            0 :             client,
     296            0 :             headers,
     297            0 :             QueryData {
     298            0 :                 query: INTROSPECTION_SQL.into(),
     299            0 :                 params: vec![
     300            0 :                     serde_json::to_value(&api_config.db_schemas)
     301            0 :                         .expect("Vec<String> is always valid to encode as JSON"),
     302            0 :                     JsonValue::Bool(false), // include_roles_with_login
     303            0 :                     JsonValue::Bool(false), // use_internal_permissions
     304            0 :                 ],
     305            0 :             },
     306            0 :             config.rest_config.max_schema_size,
     307            0 :         )
     308            0 :         .await
     309            0 :         .map_err(|e| match e {
     310              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     311            0 :                 RestError::SchemaTooLarge
     312              :             }
     313            0 :             e => e,
     314            0 :         })?;
     315              : 
     316            0 :         Ok((api_config, json_schema))
     317            0 :     }
     318              : }
     319              : 
     320              : // A type to represent a postgresql errors
     321              : // we use our own type (instead of postgres_client::Error) because we get the error from the json response
     322            0 : #[derive(Debug, thiserror::Error, Deserialize)]
     323              : pub(crate) struct PostgresError {
     324              :     pub code: String,
     325              :     pub message: String,
     326              :     pub detail: Option<String>,
     327              :     pub hint: Option<String>,
     328              : }
     329              : impl HttpCodeError for PostgresError {
     330            0 :     fn get_http_status_code(&self) -> StatusCode {
     331            0 :         let status = pg_error_to_status_code(&self.code, true);
     332            0 :         StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     333            0 :     }
     334              : }
     335              : impl ReportableError for PostgresError {
     336            0 :     fn get_error_kind(&self) -> ErrorKind {
     337            0 :         ErrorKind::User
     338            0 :     }
     339              : }
     340              : impl UserFacingError for PostgresError {
     341            0 :     fn to_string_client(&self) -> String {
     342            0 :         if self.code.starts_with("PT") {
     343            0 :             "Postgres error".to_string()
     344              :         } else {
     345            0 :             self.message.clone()
     346              :         }
     347            0 :     }
     348              : }
     349              : impl std::fmt::Display for PostgresError {
     350            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     351            0 :         write!(f, "{}", self.message)
     352            0 :     }
     353              : }
     354              : 
     355              : // A type to represent errors that can occur in the rest broker
     356              : #[derive(Debug, thiserror::Error)]
     357              : pub(crate) enum RestError {
     358              :     #[error(transparent)]
     359              :     ReadPayload(#[from] ReadPayloadError),
     360              :     #[error(transparent)]
     361              :     ConnectCompute(#[from] HttpConnError),
     362              :     #[error(transparent)]
     363              :     ConnInfo(#[from] ConnInfoError),
     364              :     #[error(transparent)]
     365              :     Postgres(#[from] PostgresError),
     366              :     #[error(transparent)]
     367              :     JsonConversion(#[from] JsonConversionError),
     368              :     #[error(transparent)]
     369              :     SubzeroCore(#[from] SubzeroCoreError),
     370              :     #[error("schema is too large")]
     371              :     SchemaTooLarge,
     372              : }
     373              : impl ReportableError for RestError {
     374            0 :     fn get_error_kind(&self) -> ErrorKind {
     375            0 :         match self {
     376            0 :             RestError::ReadPayload(e) => e.get_error_kind(),
     377            0 :             RestError::ConnectCompute(e) => e.get_error_kind(),
     378            0 :             RestError::ConnInfo(e) => e.get_error_kind(),
     379            0 :             RestError::Postgres(_) => ErrorKind::Postgres,
     380            0 :             RestError::JsonConversion(_) => ErrorKind::Postgres,
     381            0 :             RestError::SubzeroCore(_) => ErrorKind::User,
     382            0 :             RestError::SchemaTooLarge => ErrorKind::User,
     383              :         }
     384            0 :     }
     385              : }
     386              : impl UserFacingError for RestError {
     387            0 :     fn to_string_client(&self) -> String {
     388            0 :         match self {
     389            0 :             RestError::ReadPayload(p) => p.to_string(),
     390            0 :             RestError::ConnectCompute(c) => c.to_string_client(),
     391            0 :             RestError::ConnInfo(c) => c.to_string_client(),
     392            0 :             RestError::SchemaTooLarge => self.to_string(),
     393            0 :             RestError::Postgres(p) => p.to_string_client(),
     394            0 :             RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
     395            0 :             RestError::SubzeroCore(s) => {
     396              :                 // TODO: this is a hack to get the message from the json body
     397            0 :                 let json = s.json_body();
     398            0 :                 let default_message = "Unknown error".to_string();
     399              : 
     400            0 :                 json.get("message")
     401            0 :                     .map_or(default_message.clone(), |m| match m {
     402            0 :                         JsonValue::String(s) => s.clone(),
     403            0 :                         _ => default_message,
     404            0 :                     })
     405              :             }
     406              :         }
     407            0 :     }
     408              : }
     409              : impl HttpCodeError for RestError {
     410            0 :     fn get_http_status_code(&self) -> StatusCode {
     411            0 :         match self {
     412            0 :             RestError::ReadPayload(e) => e.get_http_status_code(),
     413            0 :             RestError::ConnectCompute(h) => match h.get_error_kind() {
     414            0 :                 ErrorKind::User => StatusCode::BAD_REQUEST,
     415            0 :                 _ => StatusCode::INTERNAL_SERVER_ERROR,
     416              :             },
     417            0 :             RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
     418            0 :             RestError::Postgres(e) => e.get_http_status_code(),
     419            0 :             RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
     420            0 :             RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
     421            0 :             RestError::SubzeroCore(e) => {
     422            0 :                 let status = e.status_code();
     423            0 :                 StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     424              :             }
     425              :         }
     426            0 :     }
     427              : }
     428              : 
     429              : // Helper functions for the rest broker
     430              : 
     431            0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
     432              :     "select "
     433            0 :         + if env.is_empty() {
     434            0 :             sql("null")
     435              :         } else {
     436            0 :             env.iter()
     437            0 :                 .map(|(k, v)| {
     438            0 :                     "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
     439            0 :                 })
     440            0 :                 .join(",")
     441              :         }
     442            0 : }
     443              : 
     444              : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
     445            0 : fn to_sql_param(p: &Param) -> JsonValue {
     446            0 :     match p {
     447            0 :         SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
     448            0 :         Str(v) => JsonValue::String((*v).to_string()),
     449            0 :         StrOwned(v) => JsonValue::String((*v).clone()),
     450            0 :         PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
     451            0 :         LV(ListVal(v, ..)) => {
     452            0 :             if v.is_empty() {
     453            0 :                 JsonValue::String(r"{}".to_string())
     454              :             } else {
     455            0 :                 JsonValue::String(format!(
     456            0 :                     "{{\"{}\"}}",
     457            0 :                     v.iter()
     458            0 :                         .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
     459            0 :                         .collect::<Vec<_>>()
     460            0 :                         .join("\",\"")
     461              :                 ))
     462              :             }
     463              :         }
     464              :     }
     465            0 : }
     466              : 
     467              : #[derive(serde::Serialize)]
     468              : struct QueryData<'a> {
     469              :     query: Cow<'a, str>,
     470              :     params: Vec<JsonValue>,
     471              : }
     472              : 
     473              : #[derive(serde::Serialize)]
     474              : struct BatchQueryData<'a> {
     475              :     queries: Vec<QueryData<'a>>,
     476              : }
     477              : 
     478            0 : async fn make_local_proxy_request<S: DeserializeOwned>(
     479            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     480            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     481            0 :     body: QueryData<'_>,
     482            0 :     max_len: usize,
     483            0 : ) -> Result<S, RestError> {
     484            0 :     let body_string = serde_json::to_string(&body)
     485            0 :         .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     486              : 
     487            0 :     let response = make_raw_local_proxy_request(client, headers, body_string).await?;
     488              : 
     489            0 :     let response_status = response.status();
     490              : 
     491            0 :     if response_status != StatusCode::OK {
     492            0 :         return Err(RestError::SubzeroCore(InternalError {
     493            0 :             message: "Failed to get endpoint schema".to_string(),
     494            0 :         }));
     495            0 :     }
     496              : 
     497              :     // Capture the response body
     498            0 :     let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
     499            0 :         .await
     500            0 :         .map_err(ReadPayloadError::from)?;
     501              : 
     502              :     // Parse the JSON response
     503            0 :     let response_json: S = serde_json::from_slice(&response_body)
     504            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     505              : 
     506            0 :     Ok(response_json)
     507            0 : }
     508              : 
     509            0 : async fn make_raw_local_proxy_request(
     510            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     511            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     512            0 :     body: String,
     513            0 : ) -> Result<Response<Incoming>, RestError> {
     514            0 :     let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
     515            0 :     let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
     516            0 :     let req_headers = req.headers_mut().expect("failed to get headers");
     517              :     // Add all provided headers to the request
     518            0 :     for (header_name, header_value) in headers {
     519            0 :         req_headers.insert(header_name, header_value.clone());
     520            0 :     }
     521              : 
     522            0 :     let body_boxed = Full::new(Bytes::from(body))
     523            0 :         .map_err(|never| match never {}) // Convert Infallible to hyper::Error
     524            0 :         .boxed();
     525              : 
     526            0 :     let req = req.body(body_boxed).map_err(|_| {
     527            0 :         RestError::SubzeroCore(InternalError {
     528            0 :             message: "Failed to build request".to_string(),
     529            0 :         })
     530            0 :     })?;
     531              : 
     532              :     // Send the request to the local proxy
     533            0 :     client
     534            0 :         .inner
     535            0 :         .inner
     536            0 :         .send_request(req)
     537            0 :         .await
     538            0 :         .map_err(LocalProxyConnError::from)
     539            0 :         .map_err(HttpConnError::from)
     540            0 :         .map_err(RestError::from)
     541            0 : }
     542              : 
     543            0 : pub(crate) async fn handle(
     544            0 :     config: &'static ProxyConfig,
     545            0 :     ctx: RequestContext,
     546            0 :     request: Request<Incoming>,
     547            0 :     backend: Arc<PoolingBackend>,
     548            0 :     cancel: CancellationToken,
     549            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     550            0 :     let result = handle_inner(cancel, config, &ctx, request, backend).await;
     551              : 
     552            0 :     let response = match result {
     553            0 :         Ok(r) => {
     554            0 :             ctx.set_success();
     555              : 
     556              :             // Handling the error response from local proxy here
     557            0 :             if r.status().is_server_error() {
     558            0 :                 let status = r.status();
     559              : 
     560            0 :                 let body_bytes = r
     561            0 :                     .collect()
     562            0 :                     .await
     563            0 :                     .map_err(|e| {
     564            0 :                         ApiError::InternalServerError(anyhow::Error::msg(format!(
     565            0 :                             "could not collect http body: {e}"
     566            0 :                         )))
     567            0 :                     })?
     568            0 :                     .to_bytes();
     569              : 
     570            0 :                 if let Ok(mut json_map) =
     571            0 :                     serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
     572              :                 {
     573            0 :                     let message = json_map.get("message");
     574            0 :                     if let Some(message) = message {
     575            0 :                         let msg: String = match serde_json::from_str(message.get()) {
     576            0 :                             Ok(msg) => msg,
     577              :                             Err(_) => {
     578            0 :                                 "Unable to parse the response message from server".to_string()
     579              :                             }
     580              :                         };
     581              : 
     582            0 :                         error!("Error response from local_proxy: {status} {msg}");
     583              : 
     584            0 :                         json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
     585              : 
     586            0 :                         let resp_json = serde_json::to_string(&json_map)
     587            0 :                             .unwrap_or("failed to serialize the response message".to_string());
     588              : 
     589            0 :                         return json_response(status, resp_json);
     590            0 :                     }
     591            0 :                 }
     592              : 
     593            0 :                 error!("Unable to parse the response message from local_proxy");
     594            0 :                 return json_response(
     595            0 :                     status,
     596            0 :                     json!({ "message": "Unable to parse the response message from server".to_string() }),
     597              :                 );
     598            0 :             }
     599            0 :             r
     600              :         }
     601            0 :         Err(e @ RestError::SubzeroCore(_)) => {
     602            0 :             let error_kind = e.get_error_kind();
     603            0 :             ctx.set_error_kind(error_kind);
     604              : 
     605            0 :             tracing::info!(
     606            0 :                 kind=error_kind.to_metric_label(),
     607              :                 error=%e,
     608              :                 msg="subzero core error",
     609            0 :                 "forwarding error to user"
     610              :             );
     611              : 
     612            0 :             let RestError::SubzeroCore(subzero_err) = e else {
     613            0 :                 panic!("expected subzero core error")
     614              :             };
     615              : 
     616            0 :             let json_body = subzero_err.json_body();
     617            0 :             let status_code = StatusCode::from_u16(subzero_err.status_code())
     618            0 :                 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
     619              : 
     620            0 :             json_response(status_code, json_body)?
     621              :         }
     622            0 :         Err(e) => {
     623            0 :             let error_kind = e.get_error_kind();
     624            0 :             ctx.set_error_kind(error_kind);
     625              : 
     626            0 :             let message = e.to_string_client();
     627            0 :             let status_code = e.get_http_status_code();
     628              : 
     629            0 :             tracing::info!(
     630            0 :                 kind=error_kind.to_metric_label(),
     631              :                 error=%e,
     632              :                 msg=message,
     633            0 :                 "forwarding error to user"
     634              :             );
     635              : 
     636            0 :             let (code, detail, hint) = match e {
     637            0 :                 RestError::Postgres(e) => (
     638            0 :                     if e.code.starts_with("PT") {
     639            0 :                         None
     640              :                     } else {
     641            0 :                         Some(e.code)
     642              :                     },
     643            0 :                     e.detail,
     644            0 :                     e.hint,
     645              :                 ),
     646            0 :                 _ => (None, None, None),
     647              :             };
     648              : 
     649            0 :             json_response(
     650            0 :                 status_code,
     651            0 :                 json!({
     652            0 :                     "message": message,
     653            0 :                     "code": code,
     654            0 :                     "detail": detail,
     655            0 :                     "hint": hint,
     656              :                 }),
     657            0 :             )?
     658              :         }
     659              :     };
     660              : 
     661            0 :     Ok(response)
     662            0 : }
     663              : 
     664            0 : async fn handle_inner(
     665            0 :     _cancel: CancellationToken,
     666            0 :     config: &'static ProxyConfig,
     667            0 :     ctx: &RequestContext,
     668            0 :     request: Request<Incoming>,
     669            0 :     backend: Arc<PoolingBackend>,
     670            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     671            0 :     let _requeset_gauge = Metrics::get()
     672            0 :         .proxy
     673            0 :         .connection_requests
     674            0 :         .guard(ctx.protocol());
     675            0 :     info!(
     676            0 :         protocol = %ctx.protocol(),
     677            0 :         "handling interactive connection from client"
     678              :     );
     679              : 
     680              :     // Read host from Host, then URI host as fallback
     681              :     // TODO: will this be a problem if behind a load balancer?
     682              :     // TODO: can we use the x-forwarded-host header?
     683            0 :     let host = request
     684            0 :         .headers()
     685            0 :         .get(HOST)
     686            0 :         .and_then(|v| v.to_str().ok())
     687            0 :         .unwrap_or_else(|| request.uri().host().unwrap_or(""));
     688              : 
     689              :     // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
     690            0 :     let database_name = request
     691            0 :         .uri()
     692            0 :         .path()
     693            0 :         .split('/')
     694            0 :         .nth(1)
     695            0 :         .ok_or(RestError::SubzeroCore(NotFound {
     696            0 :             target: request.uri().path().to_string(),
     697            0 :         }))?;
     698              : 
     699              :     // we always use the authenticator role to connect to the database
     700            0 :     let authenticator_role = "authenticator";
     701              : 
     702              :     // Strip the hostname prefix from the host to get the database hostname
     703            0 :     let database_host = host.replace(&config.rest_config.hostname_prefix, "");
     704              : 
     705            0 :     let connection_string =
     706            0 :         format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
     707              : 
     708            0 :     let conn_info = get_conn_info(
     709            0 :         &config.authentication_config,
     710            0 :         ctx,
     711            0 :         Some(&connection_string),
     712            0 :         request.headers(),
     713            0 :     )?;
     714            0 :     info!(
     715            0 :         user = conn_info.conn_info.user_info.user.as_str(),
     716            0 :         "credentials"
     717              :     );
     718              : 
     719            0 :     match conn_info.auth {
     720            0 :         AuthData::Jwt(jwt) => {
     721            0 :             let api_prefix = format!("/{database_name}/rest/v1/");
     722            0 :             handle_rest_inner(
     723            0 :                 config,
     724            0 :                 ctx,
     725            0 :                 &api_prefix,
     726            0 :                 request,
     727            0 :                 &connection_string,
     728            0 :                 conn_info.conn_info,
     729            0 :                 jwt,
     730            0 :                 backend,
     731            0 :             )
     732            0 :             .await
     733              :         }
     734            0 :         AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
     735            0 :             Credentials::BearerJwt,
     736            0 :         ))),
     737              :     }
     738            0 : }
     739              : 
     740            0 : fn apply_common_cors_headers(
     741            0 :     response: &mut Builder,
     742            0 :     request_headers: &HeaderMap,
     743            0 :     allowed_origins: Option<&Vec<String>>,
     744            0 : ) {
     745            0 :     let request_origin = request_headers
     746            0 :         .get(ORIGIN)
     747            0 :         .map(|v| v.to_str().unwrap_or(""));
     748              : 
     749            0 :     let response_allow_origin = match (request_origin, allowed_origins) {
     750            0 :         (Some(or), Some(allowed_origins)) => {
     751            0 :             if allowed_origins.iter().any(|o| o == or) {
     752            0 :                 Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
     753              :             } else {
     754            0 :                 None
     755              :             }
     756              :         }
     757              :         // FIXME!: on the first request, when we don't have a cached entry for the config,
     758              :         // allowed_origins is None, so we allow all origins for now but we should fix this
     759            0 :         (Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
     760            0 :         _ => None,
     761              :     };
     762            0 :     if let Some(h) = response.headers_mut() {
     763            0 :         h.insert(
     764            0 :             ACCESS_CONTROL_EXPOSE_HEADERS,
     765            0 :             ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
     766              :         );
     767            0 :         if let Some(origin) = response_allow_origin {
     768            0 :             if origin != HEADER_VALUE_ALLOW_ALL_ORIGINS {
     769            0 :                 h.insert(VARY, ACCESS_CONTROL_VARY_VALUE);
     770            0 :             }
     771            0 :             h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
     772            0 :         }
     773            0 :     }
     774            0 : }
     775              : 
     776              : #[allow(clippy::too_many_arguments)]
     777            0 : async fn handle_rest_inner(
     778            0 :     config: &'static ProxyConfig,
     779            0 :     ctx: &RequestContext,
     780            0 :     api_prefix: &str,
     781            0 :     request: Request<Incoming>,
     782            0 :     connection_string: &str,
     783            0 :     conn_info: ConnInfo,
     784            0 :     jwt: String,
     785            0 :     backend: Arc<PoolingBackend>,
     786            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     787            0 :     let db_schema_cache =
     788            0 :         config
     789            0 :             .rest_config
     790            0 :             .db_schema_cache
     791            0 :             .as_ref()
     792            0 :             .ok_or(RestError::SubzeroCore(InternalError {
     793            0 :                 message: "DB schema cache is not configured".to_string(),
     794            0 :             }))?;
     795              : 
     796            0 :     let endpoint_cache_key = conn_info
     797            0 :         .endpoint_cache_key()
     798            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     799            0 :             message: "Failed to get endpoint cache key".to_string(),
     800            0 :         }))?;
     801              : 
     802            0 :     let (parts, originial_body) = request.into_parts();
     803              : 
     804              :     // try and get the cached entry for this endpoint
     805              :     // it contains the api config and the introspected db schema
     806            0 :     let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
     807              : 
     808            0 :     let allowed_origins = cached_entry
     809            0 :         .as_ref()
     810            0 :         .and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
     811              : 
     812            0 :     let mut response = Response::builder();
     813            0 :     apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
     814              : 
     815              :     // handle the OPTIONS request
     816            0 :     if parts.method == Method::OPTIONS {
     817            0 :         let allowed_headers = parts
     818            0 :             .headers
     819            0 :             .get(ACCESS_CONTROL_REQUEST_HEADERS)
     820            0 :             .and_then(|a| a.to_str().ok())
     821            0 :             .filter(|v| !v.is_empty())
     822            0 :             .map_or_else(
     823            0 :                 || "Authorization".to_string(),
     824            0 :                 |v| format!("{v}, Authorization"),
     825              :             );
     826            0 :         return response
     827            0 :             .status(StatusCode::OK)
     828            0 :             .header(
     829            0 :                 ACCESS_CONTROL_ALLOW_METHODS,
     830            0 :                 ACCESS_CONTROL_ALLOW_METHODS_VALUE,
     831              :             )
     832            0 :             .header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
     833            0 :             .header(
     834            0 :                 ACCESS_CONTROL_ALLOW_HEADERS,
     835            0 :                 HeaderValue::from_str(&allowed_headers)
     836            0 :                     .unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
     837              :             )
     838            0 :             .header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
     839            0 :             .body(Empty::new().map_err(|x| match x {}).boxed())
     840            0 :             .map_err(|e| {
     841            0 :                 RestError::SubzeroCore(InternalError {
     842            0 :                     message: e.to_string(),
     843            0 :                 })
     844            0 :             });
     845            0 :     }
     846              : 
     847              :     // validate the jwt token
     848            0 :     let jwt_parsed = backend
     849            0 :         .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
     850            0 :         .await
     851            0 :         .map_err(HttpConnError::from)?;
     852              : 
     853            0 :     let auth_header = parts
     854            0 :         .headers
     855            0 :         .get(AUTHORIZATION)
     856            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     857            0 :             message: "Authorization header is required".to_string(),
     858            0 :         }))?;
     859            0 :     let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
     860              : 
     861            0 :     let entry = match cached_entry {
     862            0 :         Some(e) => e,
     863              :         None => {
     864              :             // if not cached, get the remote entry (will run the introspection query)
     865            0 :             db_schema_cache
     866            0 :                 .get_remote(
     867            0 :                     &endpoint_cache_key,
     868            0 :                     auth_header,
     869            0 :                     connection_string,
     870            0 :                     &mut client,
     871            0 :                     ctx,
     872            0 :                     config,
     873            0 :                 )
     874            0 :                 .await?
     875              :         }
     876              :     };
     877            0 :     let (api_config, db_schema_owned) = entry.as_ref();
     878              : 
     879            0 :     let db_schema = db_schema_owned.borrow_schema();
     880              : 
     881            0 :     let db_schemas = &api_config.db_schemas; // list of schemas available for the api
     882            0 :     let db_extra_search_path = &api_config.db_extra_search_path;
     883              :     // TODO: use this when we get a replacement for jsonpath_lib
     884              :     // let role_claim_key = &api_config.role_claim_key;
     885              :     // let role_claim_path = format!("${role_claim_key}");
     886            0 :     let db_anon_role = &api_config.db_anon_role;
     887            0 :     let max_rows = api_config.db_max_rows.as_deref();
     888            0 :     let db_allowed_select_functions = api_config
     889            0 :         .db_allowed_select_functions
     890            0 :         .iter()
     891            0 :         .map(|s| s.as_str())
     892            0 :         .collect::<Vec<_>>();
     893              : 
     894              :     // extract the jwt claims (we'll need them later to set the role and env)
     895            0 :     let jwt_claims = match jwt_parsed.keys {
     896            0 :         ComputeCredentialKeys::JwtPayload(payload_bytes) => {
     897              :             // `payload_bytes` contains the raw JWT payload as Vec<u8>
     898              :             // You can deserialize it back to JSON or parse specific claims
     899            0 :             let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
     900            0 :                 .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     901            0 :             Some(payload)
     902              :         }
     903            0 :         ComputeCredentialKeys::AuthKeys(_) => None,
     904              :     };
     905              : 
     906              :     // read the role from the jwt claims (and set it to the "anon" role if not present)
     907            0 :     let (role, authenticated) = match &jwt_claims {
     908            0 :         Some(claims) => match claims.get("role") {
     909            0 :             Some(JsonValue::String(r)) => (Some(r), true),
     910            0 :             _ => (db_anon_role.as_ref(), true),
     911              :         },
     912            0 :         None => (db_anon_role.as_ref(), false),
     913              :     };
     914              : 
     915              :     // do not allow unauthenticated requests when there is no anonymous role setup
     916            0 :     if let (None, false) = (role, authenticated) {
     917            0 :         return Err(RestError::SubzeroCore(JwtTokenInvalid {
     918            0 :             message: "unauthenticated requests not allowed".to_string(),
     919            0 :         }));
     920            0 :     }
     921              : 
     922              :     // start deconstructing the request because subzero core mostly works with &str
     923            0 :     let method = parts.method;
     924            0 :     let method_str = method.as_str();
     925            0 :     let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
     926              : 
     927              :     // this is actually the table name (or rpc/function_name)
     928              :     // TODO: rename this to something more descriptive
     929            0 :     let root = match parts.uri.path().strip_prefix(api_prefix) {
     930            0 :         Some(p) => Ok(p),
     931            0 :         None => Err(RestError::SubzeroCore(NotFound {
     932            0 :             target: parts.uri.path().to_string(),
     933            0 :         })),
     934            0 :     }?;
     935              : 
     936              :     // pick the current schema from the headers (or the first one from config)
     937            0 :     let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
     938              : 
     939              :     // add the content-profile header to the response
     940            0 :     let mut response_headers = vec![];
     941            0 :     if db_schemas.len() > 1 {
     942            0 :         response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
     943            0 :     }
     944              : 
     945              :     // parse the query string into a Vec<(&str, &str)>
     946            0 :     let query = match parts.uri.query() {
     947            0 :         Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
     948            0 :         None => vec![],
     949              :     };
     950            0 :     let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
     951              : 
     952              :     // convert the headers map to a HashMap<&str, &str>
     953            0 :     let headers: HashMap<&str, &str> = parts
     954            0 :         .headers
     955            0 :         .iter()
     956            0 :         .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
     957            0 :         .collect();
     958              : 
     959            0 :     let cookies = HashMap::new(); // TODO: add cookies
     960              : 
     961              :     // Read the request body (skip for GET requests)
     962            0 :     let body_as_string: Option<String> = if method == Method::GET {
     963            0 :         None
     964              :     } else {
     965            0 :         let body_bytes =
     966            0 :             read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
     967            0 :                 .await
     968            0 :                 .map_err(ReadPayloadError::from)?;
     969            0 :         if body_bytes.is_empty() {
     970            0 :             None
     971              :         } else {
     972            0 :             Some(String::from_utf8_lossy(&body_bytes).into_owned())
     973              :         }
     974              :     };
     975              : 
     976              :     // parse the request into an ApiRequest struct
     977            0 :     let mut api_request = parse(
     978            0 :         schema_name,
     979            0 :         root,
     980            0 :         db_schema,
     981            0 :         method_str,
     982            0 :         path,
     983            0 :         get,
     984            0 :         body_as_string.as_deref(),
     985            0 :         headers,
     986            0 :         cookies,
     987            0 :         max_rows,
     988              :     )
     989            0 :     .map_err(RestError::SubzeroCore)?;
     990              : 
     991            0 :     let role_str = match role {
     992            0 :         Some(r) => r,
     993            0 :         None => "",
     994              :     };
     995              : 
     996            0 :     replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
     997              : 
     998              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     999              :     // in the context of DBX where they need internal permissions
    1000              :     // if !disable_internal_permissions {
    1001              :     //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
    1002              :     // }
    1003              : 
    1004            0 :     check_safe_functions(&api_request, &db_allowed_select_functions)?;
    1005              : 
    1006              :     // TODO: this is not relevant when acting as PostgREST but will be useful
    1007              :     // in the context of DBX where they need internal permissions
    1008              :     // if !disable_internal_permissions {
    1009              :     //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
    1010              :     // }
    1011              : 
    1012            0 :     let env_role = Some(role_str);
    1013              : 
    1014              :     // construct the env (passed in to the sql context as GUCs)
    1015            0 :     let empty_json = "{}".to_string();
    1016            0 :     let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
    1017            0 :     let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
    1018            0 :     let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
    1019            0 :     let jwt_claims_env = jwt_claims
    1020            0 :         .as_ref()
    1021            0 :         .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
    1022            0 :         .unwrap_or(if let Some(r) = env_role {
    1023            0 :             let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
    1024            0 :             serde_json::to_string(&claims).unwrap_or(empty_json.clone())
    1025              :         } else {
    1026            0 :             empty_json.clone()
    1027              :         });
    1028            0 :     let mut search_path = vec![api_request.schema_name];
    1029            0 :     if let Some(extra) = &db_extra_search_path {
    1030            0 :         search_path.extend(extra.iter().map(|s| s.as_str()));
    1031            0 :     }
    1032            0 :     let search_path_str = search_path
    1033            0 :         .into_iter()
    1034            0 :         .filter(|s| !s.is_empty())
    1035            0 :         .collect::<Vec<_>>()
    1036            0 :         .join(",");
    1037            0 :     let mut env: HashMap<&str, &str> = HashMap::from([
    1038            0 :         ("request.method", api_request.method),
    1039            0 :         ("request.path", api_request.path),
    1040            0 :         ("request.headers", &headers_env),
    1041            0 :         ("request.cookies", &cookies_env),
    1042            0 :         ("request.get", &get_env),
    1043            0 :         ("request.jwt.claims", &jwt_claims_env),
    1044            0 :         ("search_path", &search_path_str),
    1045            0 :     ]);
    1046            0 :     if let Some(r) = env_role {
    1047            0 :         env.insert("role", r);
    1048            0 :     }
    1049              : 
    1050              :     // generate the sql statements
    1051            0 :     let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
    1052            0 :     let (main_statement, main_parameters, _) = generate(fmt_main_query(
    1053            0 :         db_schema,
    1054            0 :         api_request.schema_name,
    1055            0 :         &api_request,
    1056            0 :         &env,
    1057            0 :     )?);
    1058              : 
    1059            0 :     let mut headers = vec![
    1060            0 :         (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
    1061            0 :         (
    1062            0 :             &CONN_STRING,
    1063            0 :             HeaderValue::from_str(connection_string).expect("invalid connection string"),
    1064            0 :         ),
    1065            0 :         (&AUTHORIZATION, auth_header.clone()),
    1066            0 :         (
    1067            0 :             &TXN_ISOLATION_LEVEL,
    1068            0 :             HeaderValue::from_static("ReadCommitted"),
    1069            0 :         ),
    1070            0 :         (&ALLOW_POOL, HEADER_VALUE_TRUE),
    1071              :     ];
    1072              : 
    1073            0 :     if api_request.read_only {
    1074            0 :         headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
    1075            0 :     }
    1076              : 
    1077              :     // convert the parameters from subzero core representation to the local proxy repr.
    1078            0 :     let req_body = serde_json::to_string(&BatchQueryData {
    1079            0 :         queries: vec![
    1080              :             QueryData {
    1081            0 :                 query: env_statement.into(),
    1082            0 :                 params: env_parameters
    1083            0 :                     .iter()
    1084            0 :                     .map(|p| to_sql_param(&p.to_param()))
    1085            0 :                     .collect(),
    1086              :             },
    1087              :             QueryData {
    1088            0 :                 query: main_statement.into(),
    1089            0 :                 params: main_parameters
    1090            0 :                     .iter()
    1091            0 :                     .map(|p| to_sql_param(&p.to_param()))
    1092            0 :                     .collect(),
    1093              :             },
    1094              :         ],
    1095              :     })
    1096            0 :     .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
    1097              : 
    1098              :     // todo: map body to count egress
    1099            0 :     let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
    1100              : 
    1101              :     // send the request to the local proxy
    1102            0 :     let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
    1103            0 :     let (response_parts, body) = proxy_response.into_parts();
    1104              : 
    1105            0 :     let max_response = config.http_config.max_response_size_bytes;
    1106            0 :     let bytes = read_body_with_limit(body, max_response)
    1107            0 :         .await
    1108            0 :         .map_err(ReadPayloadError::from)?;
    1109              : 
    1110              :     // if the response status is greater than 399, then it is an error
    1111              :     // FIXME: check if there are other error codes or shapes of the response
    1112            0 :     if response_parts.status.as_u16() > 399 {
    1113              :         // turn this postgres error from the json into PostgresError
    1114            0 :         let postgres_error = serde_json::from_slice(&bytes)
    1115            0 :             .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1116              : 
    1117            0 :         return Err(RestError::Postgres(postgres_error));
    1118            0 :     }
    1119              : 
    1120            0 :     #[derive(Deserialize)]
    1121              :     struct QueryResults {
    1122              :         /// we run two queries, so we want only two results.
    1123              :         results: (EnvRows, MainRows),
    1124              :     }
    1125              : 
    1126              :     /// `env_statement` returns nothing of interest to us
    1127            0 :     #[derive(Deserialize)]
    1128              :     struct EnvRows {}
    1129              : 
    1130            0 :     #[derive(Deserialize)]
    1131              :     struct MainRows {
    1132              :         /// `main_statement` only returns a single row.
    1133              :         rows: [MainRow; 1],
    1134              :     }
    1135              : 
    1136            0 :     #[derive(Deserialize)]
    1137              :     struct MainRow {
    1138              :         body: String,
    1139              :         page_total: Option<String>,
    1140              :         total_result_set: Option<String>,
    1141              :         response_headers: Option<String>,
    1142              :         response_status: Option<String>,
    1143              :     }
    1144              : 
    1145            0 :     let results: QueryResults = serde_json::from_slice(&bytes)
    1146            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1147              : 
    1148              :     let QueryResults {
    1149            0 :         results: (_, MainRows { rows: [row] }),
    1150            0 :     } = results;
    1151              : 
    1152              :     // build the intermediate response object
    1153            0 :     let api_response = ApiResponse {
    1154            0 :         page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
    1155            0 :         total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
    1156              :         top_level_offset: 0, // FIXME: check why this is 0
    1157            0 :         response_headers: row.response_headers,
    1158            0 :         response_status: row.response_status,
    1159            0 :         body: row.body,
    1160              :     };
    1161              : 
    1162              :     // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
    1163              :     // we can not do this in the context of proxy for now
    1164              :     // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
    1165              :     //     // rollback the transaction here
    1166              :     //     return Err(RestError::SubzeroCore(SingularityError {
    1167              :     //         count: api_response.page_total,
    1168              :     //         content_type: "application/vnd.pgrst.object+json".to_string(),
    1169              :     //     }));
    1170              :     // }
    1171              : 
    1172              :     // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
    1173              :     // we can not do this in the context of proxy for now
    1174              :     // if api_request.method == Method::PUT && api_response.page_total != 1 {
    1175              :     //     // Makes sure the querystring pk matches the payload pk
    1176              :     //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
    1177              :     //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
    1178              :     //     // If this condition is not satisfied then nothing is inserted,
    1179              :     //     // rollback the transaction here
    1180              :     //     return Err(RestError::SubzeroCore(PutMatchingPkError));
    1181              :     // }
    1182              : 
    1183              :     // create and return the response to the client
    1184              :     // this section mostly deals with setting the right headers according to PostgREST specs
    1185            0 :     let page_total = api_response.page_total;
    1186            0 :     let total_result_set = api_response.total_result_set;
    1187            0 :     let top_level_offset = api_response.top_level_offset;
    1188            0 :     let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
    1189              :         (SingularJSON, _)
    1190              :         | (
    1191              :             _,
    1192              :             FunctionCall {
    1193              :                 returns_single: true,
    1194              :                 is_scalar: false,
    1195              :                 ..
    1196              :             },
    1197            0 :         ) => SingularJSON,
    1198            0 :         (TextCSV, _) => TextCSV,
    1199            0 :         _ => ApplicationJSON,
    1200              :     };
    1201              : 
    1202              :     // check if the SQL env set some response headers (happens when we called a rpc function)
    1203            0 :     if let Some(response_headers_str) = api_response.response_headers {
    1204            0 :         let Ok(headers_json) =
    1205            0 :             serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
    1206              :         else {
    1207            0 :             return Err(RestError::SubzeroCore(GucHeadersError));
    1208              :         };
    1209              : 
    1210            0 :         response_headers.extend(headers_json.into_iter().flatten());
    1211            0 :     }
    1212              : 
    1213              :     // calculate and set the content range header
    1214            0 :     let lower = top_level_offset as i64;
    1215            0 :     let upper = top_level_offset as i64 + page_total as i64 - 1;
    1216            0 :     let total = total_result_set.map(|t| t as i64);
    1217            0 :     let content_range = match (&method, &api_request.query.node) {
    1218            0 :         (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
    1219            0 :         (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
    1220            0 :         _ => content_range_header(lower, upper, total),
    1221              :     };
    1222            0 :     response_headers.push(("Content-Range".to_string(), content_range));
    1223              : 
    1224              :     // calculate the status code
    1225              :     #[rustfmt::skip]
    1226            0 :     let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
    1227            0 :         (&Method::POST,   Insert { .. }, ..) => 201,
    1228            0 :         (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1229            0 :         (&Method::DELETE, Delete { .. }, ..) => 204,
    1230            0 :         (&Method::PATCH,  Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
    1231            0 :         (&Method::PATCH,  Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1232            0 :         (&Method::PATCH,  Update { .. }, ..) => 204,
    1233            0 :         (&Method::PUT,    Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1234            0 :         (&Method::PUT,    Insert { .. }, ..) => 204,
    1235            0 :         _ => content_range_status(lower, upper, total),
    1236              :     };
    1237              : 
    1238              :     // add the preference-applied header
    1239              :     if let Some(Preferences {
    1240            0 :         resolution: Some(r),
    1241              :         ..
    1242            0 :     }) = api_request.preferences
    1243              :     {
    1244            0 :         response_headers.push((
    1245            0 :             "Preference-Applied".to_string(),
    1246            0 :             match r {
    1247            0 :                 MergeDuplicates => "resolution=merge-duplicates".to_string(),
    1248            0 :                 IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
    1249              :             },
    1250              :         ));
    1251            0 :     }
    1252              : 
    1253              :     // check if the SQL env set some response status (happens when we called a rpc function)
    1254            0 :     if let Some(response_status_str) = api_response.response_status {
    1255            0 :         status = response_status_str
    1256            0 :             .parse::<u16>()
    1257            0 :             .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
    1258            0 :     }
    1259              : 
    1260              :     // set the content type header
    1261              :     // TODO: move this to a subzero function
    1262              :     // as_header_value(&self) -> Option<&str>
    1263            0 :     let http_content_type = match response_content_type {
    1264            0 :         SingularJSON => Ok("application/vnd.pgrst.object+json"),
    1265            0 :         TextCSV => Ok("text/csv"),
    1266            0 :         ApplicationJSON => Ok("application/json"),
    1267            0 :         Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
    1268            0 :             message: format!("None of these Content-Types are available: {t}"),
    1269            0 :         })),
    1270            0 :     }?;
    1271              : 
    1272              :     // build the response body
    1273            0 :     let response_body = Full::new(Bytes::from(api_response.body))
    1274            0 :         .map_err(|never| match never {})
    1275            0 :         .boxed();
    1276              : 
    1277              :     // build the response
    1278            0 :     response = response
    1279            0 :         .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
    1280            0 :         .header(CONTENT_TYPE, http_content_type);
    1281              : 
    1282              :     // Add all headers from response_headers vector
    1283            0 :     for (header_name, header_value) in response_headers {
    1284            0 :         response = response.header(header_name, header_value);
    1285            0 :     }
    1286              : 
    1287              :     // add the body and return the response
    1288            0 :     response.body(response_body).map_err(|_| {
    1289            0 :         RestError::SubzeroCore(InternalError {
    1290            0 :             message: "Failed to build response".to_string(),
    1291            0 :         })
    1292            0 :     })
    1293            0 : }
        

Generated by: LCOV version 2.1-beta