LCOV - code coverage report
Current view: top level - proxy/src/serverless - rest.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 0.0 % 682 0
Test Date: 2025-07-22 17:50:06 Functions: 0.0 % 86 0

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::HashMap;
       3              : use std::sync::Arc;
       4              : 
       5              : use bytes::Bytes;
       6              : use http::Method;
       7              : use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
       8              : use http_body_util::combinators::BoxBody;
       9              : use http_body_util::{BodyExt, Full};
      10              : use http_utils::error::ApiError;
      11              : use hyper::body::Incoming;
      12              : use hyper::http::{HeaderName, HeaderValue};
      13              : use hyper::{Request, Response, StatusCode};
      14              : use indexmap::IndexMap;
      15              : use ouroboros::self_referencing;
      16              : use serde::de::DeserializeOwned;
      17              : use serde::{Deserialize, Deserializer};
      18              : use serde_json::Value as JsonValue;
      19              : use serde_json::value::RawValue;
      20              : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
      21              : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
      22              : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
      23              : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
      24              : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
      25              : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
      26              : use subzero_core::error::Error::{
      27              :     self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
      28              :     JsonDeserialize, JwtTokenInvalid, NotFound,
      29              : };
      30              : use subzero_core::error::pg_error_to_status_code;
      31              : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
      32              : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
      33              : use subzero_core::formatter::{Param, Snippet, SqlParam};
      34              : use subzero_core::parser::postgrest::parse;
      35              : use subzero_core::permissions::{check_safe_functions, replace_select_star};
      36              : use subzero_core::schema::{
      37              :     DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
      38              : };
      39              : use subzero_core::{content_range_header, content_range_status};
      40              : use tokio_util::sync::CancellationToken;
      41              : use tracing::{error, info};
      42              : use typed_json::json;
      43              : use url::form_urlencoded;
      44              : 
      45              : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
      46              : use super::conn_pool::AuthData;
      47              : use super::conn_pool_lib::ConnInfo;
      48              : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
      49              : use super::http_conn_pool::{self, Send};
      50              : use super::http_util::{
      51              :     ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
      52              :     get_conn_info, json_response, uuid_to_header_value,
      53              : };
      54              : use super::json::JsonConversionError;
      55              : use crate::auth::backend::ComputeCredentialKeys;
      56              : use crate::cache::{Cached, TimedLru};
      57              : use crate::config::ProxyConfig;
      58              : use crate::context::RequestContext;
      59              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      60              : use crate::http::read_body_with_limit;
      61              : use crate::metrics::Metrics;
      62              : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
      63              : use crate::types::EndpointCacheKey;
      64              : use crate::util::deserialize_json_string;
      65              : 
// JSON text of a schema document whose `schemas` list is empty.
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
// Local name for the introspection query text exported by subzero_core.
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
      68              : 
// A wrapper around the DbSchema that allows for self-referencing:
// the struct owns the JSON text (`schema_string`) and the `DbSchema`
// that borrows from that text, so both can be moved and stored together.
#[self_referencing]
pub struct DbSchemaOwned {
    // Owned JSON text the parsed schema borrows from.
    schema_string: String,
    // Parsed schema; its lifetime is tied to `schema_string` above.
    #[covariant]
    #[borrows(schema_string)]
    schema: DbSchema<'this>,
}
      77              : 
      78              : impl<'de> Deserialize<'de> for DbSchemaOwned {
      79            0 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      80            0 :     where
      81            0 :         D: Deserializer<'de>,
      82              :     {
      83            0 :         let s = String::deserialize(deserializer)?;
      84            0 :         DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
      85            0 :             .map_err(<D::Error as serde::de::Error>::custom)
      86            0 :     }
      87              : }
      88              : 
      89            0 : fn split_comma_separated(s: &str) -> Vec<String> {
      90            0 :     s.split(',').map(|s| s.trim().to_string()).collect()
      91            0 : }
      92              : 
      93            0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
      94            0 : where
      95            0 :     D: Deserializer<'de>,
      96              : {
      97            0 :     let s = String::deserialize(deserializer)?;
      98            0 :     Ok(split_comma_separated(&s))
      99            0 : }
     100              : 
     101            0 : fn deserialize_comma_separated_option<'de, D>(
     102            0 :     deserializer: D,
     103            0 : ) -> Result<Option<Vec<String>>, D::Error>
     104            0 : where
     105            0 :     D: Deserializer<'de>,
     106              : {
     107            0 :     let opt = Option::<String>::deserialize(deserializer)?;
     108            0 :     if let Some(s) = &opt {
     109            0 :         let trimmed = s.trim();
     110            0 :         if trimmed.is_empty() {
     111            0 :             return Ok(None);
     112            0 :         }
     113            0 :         return Ok(Some(split_comma_separated(trimmed)));
     114            0 :     }
     115            0 :     Ok(None)
     116            0 : }
     117              : 
// The ApiConfig is the configuration for the API per endpoint
// The configuration is read from the database and cached in the DbSchemaCache
#[derive(Deserialize, Debug)]
pub struct ApiConfig {
    // Schema names; supplied as one comma-separated string and split on read.
    // Falls back to subzero's `db_schemas` default when the key is missing.
    #[serde(
        default = "db_schemas",
        deserialize_with = "deserialize_comma_separated"
    )]
    pub db_schemas: Vec<String>,
    // Optional role name; presumably the role used for unauthenticated
    // requests -- TODO confirm against the request handler.
    pub db_anon_role: Option<String>,
    // Optional row limit, kept in its textual form as delivered.
    pub db_max_rows: Option<String>,
    // Falls back to subzero's `db_allowed_select_functions` default when missing.
    #[serde(default = "db_allowed_select_functions")]
    pub db_allowed_select_functions: Vec<String>,
    // #[serde(deserialize_with = "to_tuple", default)]
    // pub db_pre_request: Option<(String, String)>,
    // Falls back to subzero's `role_claim_key` default when missing.
    #[allow(dead_code)]
    #[serde(default = "role_claim_key")]
    pub role_claim_key: String,
    // Optional comma-separated list; a missing or blank value becomes `None`.
    #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
    pub db_extra_search_path: Option<Vec<String>>,
}
     139              : 
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint.
// Entries are keyed by `EndpointCacheKey` and shared behind an `Arc`.
pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
     142              : impl DbSchemaCache {
     143            0 :     pub async fn get_cached_or_remote(
     144            0 :         &self,
     145            0 :         endpoint_id: &EndpointCacheKey,
     146            0 :         auth_header: &HeaderValue,
     147            0 :         connection_string: &str,
     148            0 :         client: &mut http_conn_pool::Client<Send>,
     149            0 :         ctx: &RequestContext,
     150            0 :         config: &'static ProxyConfig,
     151            0 :     ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
     152            0 :         match self.get_with_created_at(endpoint_id) {
     153            0 :             Some(Cached { value: (v, _), .. }) => Ok(v),
     154              :             None => {
     155            0 :                 info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
     156            0 :                 let remote_value = self
     157            0 :                     .get_remote(auth_header, connection_string, client, ctx, config)
     158            0 :                     .await;
     159            0 :                 let (api_config, schema_owned) = match remote_value {
     160            0 :                     Ok((api_config, schema_owned)) => (api_config, schema_owned),
     161            0 :                     Err(e @ RestError::SchemaTooLarge) => {
     162              :                         // for the case where the schema is too large, we cache an empty dummy value
     163              :                         // all the other requests will fail without triggering the introspection query
     164            0 :                         let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
     165            0 :                             .map_err(|e| JsonDeserialize { source: e })?;
     166              : 
     167            0 :                         let api_config = ApiConfig {
     168            0 :                             db_schemas: vec![],
     169            0 :                             db_anon_role: None,
     170            0 :                             db_max_rows: None,
     171            0 :                             db_allowed_select_functions: vec![],
     172            0 :                             role_claim_key: String::new(),
     173            0 :                             db_extra_search_path: None,
     174            0 :                         };
     175            0 :                         let value = Arc::new((api_config, schema_owned));
     176            0 :                         self.insert(endpoint_id.clone(), value);
     177            0 :                         return Err(e);
     178              :                     }
     179            0 :                     Err(e) => {
     180            0 :                         return Err(e);
     181              :                     }
     182              :                 };
     183            0 :                 let value = Arc::new((api_config, schema_owned));
     184            0 :                 self.insert(endpoint_id.clone(), value.clone());
     185            0 :                 Ok(value)
     186              :             }
     187              :         }
     188            0 :     }
     189            0 :     pub async fn get_remote(
     190            0 :         &self,
     191            0 :         auth_header: &HeaderValue,
     192            0 :         connection_string: &str,
     193            0 :         client: &mut http_conn_pool::Client<Send>,
     194            0 :         ctx: &RequestContext,
     195            0 :         config: &'static ProxyConfig,
     196            0 :     ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
     197            0 :         #[derive(Deserialize)]
     198              :         struct SingleRow<Row> {
     199              :             rows: [Row; 1],
     200              :         }
     201              : 
     202              :         #[derive(Deserialize)]
     203              :         struct ConfigRow {
     204              :             #[serde(deserialize_with = "deserialize_json_string")]
     205              :             config: ApiConfig,
     206              :         }
     207              : 
     208            0 :         #[derive(Deserialize)]
     209              :         struct SchemaRow {
     210              :             json_schema: DbSchemaOwned,
     211              :         }
     212              : 
     213            0 :         let headers = vec![
     214            0 :             (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     215            0 :             (
     216            0 :                 &CONN_STRING,
     217            0 :                 HeaderValue::from_str(connection_string).expect(
     218            0 :                     "connection string came from a header, so it must be a valid headervalue",
     219            0 :                 ),
     220            0 :             ),
     221            0 :             (&AUTHORIZATION, auth_header.clone()),
     222            0 :             (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
     223              :         ];
     224              : 
     225            0 :         let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
     226              :         let SingleRow {
     227            0 :             rows: [ConfigRow { config: api_config }],
     228            0 :         } = make_local_proxy_request(
     229            0 :             client,
     230            0 :             headers.iter().cloned(),
     231            0 :             QueryData {
     232            0 :                 query: Cow::Owned(query),
     233            0 :                 params: vec![],
     234            0 :             },
     235            0 :             config.rest_config.max_schema_size,
     236            0 :         )
     237            0 :         .await
     238            0 :         .map_err(|e| match e {
     239              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     240            0 :                 RestError::SchemaTooLarge
     241              :             }
     242            0 :             e => e,
     243            0 :         })?;
     244              : 
     245              :         // now that we have the api_config let's run the second INTROSPECTION_SQL query
     246              :         let SingleRow {
     247            0 :             rows: [SchemaRow { json_schema }],
     248            0 :         } = make_local_proxy_request(
     249            0 :             client,
     250            0 :             headers,
     251            0 :             QueryData {
     252            0 :                 query: INTROSPECTION_SQL.into(),
     253            0 :                 params: vec![
     254            0 :                     serde_json::to_value(&api_config.db_schemas)
     255            0 :                         .expect("Vec<String> is always valid to encode as JSON"),
     256            0 :                     JsonValue::Bool(false), // include_roles_with_login
     257            0 :                     JsonValue::Bool(false), // use_internal_permissions
     258            0 :                 ],
     259            0 :             },
     260            0 :             config.rest_config.max_schema_size,
     261            0 :         )
     262            0 :         .await
     263            0 :         .map_err(|e| match e {
     264              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     265            0 :                 RestError::SchemaTooLarge
     266              :             }
     267            0 :             e => e,
     268            0 :         })?;
     269              : 
     270            0 :         Ok((api_config, json_schema))
     271            0 :     }
     272              : }
     273              : 
// A type to represent a PostgreSQL error.
// We use our own type (instead of postgres_client::Error) because we get the error from the JSON response.
#[derive(Debug, thiserror::Error, Deserialize)]
pub(crate) struct PostgresError {
    /// Error code; mapped to an HTTP status by `get_http_status_code`.
    pub code: String,
    /// Primary error message; this is what `Display` prints.
    pub message: String,
    /// Optional detail text.
    pub detail: Option<String>,
    /// Optional hint text.
    pub hint: Option<String>,
}
     283              : impl HttpCodeError for PostgresError {
     284            0 :     fn get_http_status_code(&self) -> StatusCode {
     285            0 :         let status = pg_error_to_status_code(&self.code, true);
     286            0 :         StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     287            0 :     }
     288              : }
impl ReportableError for PostgresError {
    /// Always classifies the error as `ErrorKind::User`.
    // NOTE(review): `RestError::get_error_kind` reports `ErrorKind::Postgres`
    // for its `Postgres` variant instead of delegating here -- confirm which
    // classification is intended.
    fn get_error_kind(&self) -> ErrorKind {
        ErrorKind::User
    }
}
     294              : impl UserFacingError for PostgresError {
     295            0 :     fn to_string_client(&self) -> String {
     296            0 :         if self.code.starts_with("PT") {
     297            0 :             "Postgres error".to_string()
     298              :         } else {
     299            0 :             self.message.clone()
     300              :         }
     301            0 :     }
     302              : }
     303              : impl std::fmt::Display for PostgresError {
     304            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     305            0 :         write!(f, "{}", self.message)
     306            0 :     }
     307              : }
     308              : 
// A type to represent errors that can occur in the rest broker
#[derive(Debug, thiserror::Error)]
pub(crate) enum RestError {
    /// Reading a request/response body failed (including size-limit violations).
    #[error(transparent)]
    ReadPayload(#[from] ReadPayloadError),
    /// An `HttpConnError` raised while connecting to or talking with compute.
    #[error(transparent)]
    ConnectCompute(#[from] HttpConnError),
    /// A `ConnInfoError`; reported to the client as 400 Bad Request.
    #[error(transparent)]
    ConnInfo(#[from] ConnInfoError),
    /// An error reported by Postgres, decoded from a JSON response.
    #[error(transparent)]
    Postgres(#[from] PostgresError),
    /// JSON conversion failed.
    #[error(transparent)]
    JsonConversion(#[from] JsonConversionError),
    /// An error produced by subzero_core.
    #[error(transparent)]
    SubzeroCore(#[from] SubzeroCoreError),
    /// A schema-loading response exceeded the configured size limit.
    #[error("schema is too large")]
    SchemaTooLarge,
}
     327              : impl ReportableError for RestError {
     328            0 :     fn get_error_kind(&self) -> ErrorKind {
     329            0 :         match self {
     330            0 :             RestError::ReadPayload(e) => e.get_error_kind(),
     331            0 :             RestError::ConnectCompute(e) => e.get_error_kind(),
     332            0 :             RestError::ConnInfo(e) => e.get_error_kind(),
     333            0 :             RestError::Postgres(_) => ErrorKind::Postgres,
     334            0 :             RestError::JsonConversion(_) => ErrorKind::Postgres,
     335            0 :             RestError::SubzeroCore(_) => ErrorKind::User,
     336            0 :             RestError::SchemaTooLarge => ErrorKind::User,
     337              :         }
     338            0 :     }
     339              : }
     340              : impl UserFacingError for RestError {
     341            0 :     fn to_string_client(&self) -> String {
     342            0 :         match self {
     343            0 :             RestError::ReadPayload(p) => p.to_string(),
     344            0 :             RestError::ConnectCompute(c) => c.to_string_client(),
     345            0 :             RestError::ConnInfo(c) => c.to_string_client(),
     346            0 :             RestError::SchemaTooLarge => self.to_string(),
     347            0 :             RestError::Postgres(p) => p.to_string_client(),
     348            0 :             RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
     349            0 :             RestError::SubzeroCore(s) => {
     350              :                 // TODO: this is a hack to get the message from the json body
     351            0 :                 let json = s.json_body();
     352            0 :                 let default_message = "Unknown error".to_string();
     353              : 
     354            0 :                 json.get("message")
     355            0 :                     .map_or(default_message.clone(), |m| match m {
     356            0 :                         JsonValue::String(s) => s.clone(),
     357            0 :                         _ => default_message,
     358            0 :                     })
     359              :             }
     360              :         }
     361            0 :     }
     362              : }
     363              : impl HttpCodeError for RestError {
     364            0 :     fn get_http_status_code(&self) -> StatusCode {
     365            0 :         match self {
     366            0 :             RestError::ReadPayload(e) => e.get_http_status_code(),
     367            0 :             RestError::ConnectCompute(h) => match h.get_error_kind() {
     368            0 :                 ErrorKind::User => StatusCode::BAD_REQUEST,
     369            0 :                 _ => StatusCode::INTERNAL_SERVER_ERROR,
     370              :             },
     371            0 :             RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
     372            0 :             RestError::Postgres(e) => e.get_http_status_code(),
     373            0 :             RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
     374            0 :             RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
     375            0 :             RestError::SubzeroCore(e) => {
     376            0 :                 let status = e.status_code();
     377            0 :                 StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     378              :             }
     379              :         }
     380            0 :     }
     381              : }
     382              : 
     383              : // Helper functions for the rest broker
     384              : 
     385            0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
     386              :     "select "
     387            0 :         + if env.is_empty() {
     388            0 :             sql("null")
     389              :         } else {
     390            0 :             env.iter()
     391            0 :                 .map(|(k, v)| {
     392            0 :                     "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
     393            0 :                 })
     394            0 :                 .join(",")
     395              :         }
     396            0 : }
     397              : 
     398              : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
     399            0 : fn to_sql_param(p: &Param) -> JsonValue {
     400            0 :     match p {
     401            0 :         SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
     402            0 :         Str(v) => JsonValue::String((*v).to_string()),
     403            0 :         StrOwned(v) => JsonValue::String((*v).clone()),
     404            0 :         PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
     405            0 :         LV(ListVal(v, ..)) => {
     406            0 :             if v.is_empty() {
     407            0 :                 JsonValue::String(r"{}".to_string())
     408              :             } else {
     409            0 :                 JsonValue::String(format!(
     410            0 :                     "{{\"{}\"}}",
     411            0 :                     v.iter()
     412            0 :                         .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
     413            0 :                         .collect::<Vec<_>>()
     414            0 :                         .join("\",\"")
     415              :                 ))
     416              :             }
     417              :         }
     418              :     }
     419            0 : }
     420              : 
/// A single SQL statement with its parameters; serialized to JSON and used as
/// the request body by `make_local_proxy_request`.
#[derive(serde::Serialize)]
struct QueryData<'a> {
    /// SQL text, borrowed or owned.
    query: Cow<'a, str>,
    /// Parameter values for the statement, as JSON values.
    params: Vec<JsonValue>,
}
     426              : 
/// A list of `QueryData` statements serialized together under `queries`.
#[derive(serde::Serialize)]
struct BatchQueryData<'a> {
    /// The statements of the batch, in order.
    queries: Vec<QueryData<'a>>,
}
     431              : 
     432            0 : async fn make_local_proxy_request<S: DeserializeOwned>(
     433            0 :     client: &mut http_conn_pool::Client<Send>,
     434            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     435            0 :     body: QueryData<'_>,
     436            0 :     max_len: usize,
     437            0 : ) -> Result<S, RestError> {
     438            0 :     let body_string = serde_json::to_string(&body)
     439            0 :         .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     440              : 
     441            0 :     let response = make_raw_local_proxy_request(client, headers, body_string).await?;
     442              : 
     443            0 :     let response_status = response.status();
     444              : 
     445            0 :     if response_status != StatusCode::OK {
     446            0 :         return Err(RestError::SubzeroCore(InternalError {
     447            0 :             message: "Failed to get endpoint schema".to_string(),
     448            0 :         }));
     449            0 :     }
     450              : 
     451              :     // Capture the response body
     452            0 :     let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
     453            0 :         .await
     454            0 :         .map_err(ReadPayloadError::from)?;
     455              : 
     456              :     // Parse the JSON response
     457            0 :     let response_json: S = serde_json::from_slice(&response_body)
     458            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     459              : 
     460            0 :     Ok(response_json)
     461            0 : }
     462              : 
     463            0 : async fn make_raw_local_proxy_request(
     464            0 :     client: &mut http_conn_pool::Client<Send>,
     465            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     466            0 :     body: String,
     467            0 : ) -> Result<Response<Incoming>, RestError> {
     468            0 :     let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
     469            0 :     let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
     470            0 :     let req_headers = req.headers_mut().expect("failed to get headers");
     471              :     // Add all provided headers to the request
     472            0 :     for (header_name, header_value) in headers {
     473            0 :         req_headers.insert(header_name, header_value.clone());
     474            0 :     }
     475              : 
     476            0 :     let body_boxed = Full::new(Bytes::from(body))
     477            0 :         .map_err(|never| match never {}) // Convert Infallible to hyper::Error
     478            0 :         .boxed();
     479              : 
     480            0 :     let req = req.body(body_boxed).map_err(|_| {
     481            0 :         RestError::SubzeroCore(InternalError {
     482            0 :             message: "Failed to build request".to_string(),
     483            0 :         })
     484            0 :     })?;
     485              : 
     486              :     // Send the request to the local proxy
     487            0 :     client
     488            0 :         .inner
     489            0 :         .inner
     490            0 :         .send_request(req)
     491            0 :         .await
     492            0 :         .map_err(LocalProxyConnError::from)
     493            0 :         .map_err(HttpConnError::from)
     494            0 :         .map_err(RestError::from)
     495            0 : }
     496              : 
     497            0 : pub(crate) async fn handle(
     498            0 :     config: &'static ProxyConfig,
     499            0 :     ctx: RequestContext,
     500            0 :     request: Request<Incoming>,
     501            0 :     backend: Arc<PoolingBackend>,
     502            0 :     cancel: CancellationToken,
     503            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     504            0 :     let result = handle_inner(cancel, config, &ctx, request, backend).await;
     505              : 
     506            0 :     let mut response = match result {
     507            0 :         Ok(r) => {
     508            0 :             ctx.set_success();
     509              : 
     510              :             // Handling the error response from local proxy here
     511            0 :             if r.status().is_server_error() {
     512            0 :                 let status = r.status();
     513              : 
     514            0 :                 let body_bytes = r
     515            0 :                     .collect()
     516            0 :                     .await
     517            0 :                     .map_err(|e| {
     518            0 :                         ApiError::InternalServerError(anyhow::Error::msg(format!(
     519            0 :                             "could not collect http body: {e}"
     520            0 :                         )))
     521            0 :                     })?
     522            0 :                     .to_bytes();
     523              : 
     524            0 :                 if let Ok(mut json_map) =
     525            0 :                     serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
     526              :                 {
     527            0 :                     let message = json_map.get("message");
     528            0 :                     if let Some(message) = message {
     529            0 :                         let msg: String = match serde_json::from_str(message.get()) {
     530            0 :                             Ok(msg) => msg,
     531              :                             Err(_) => {
     532            0 :                                 "Unable to parse the response message from server".to_string()
     533              :                             }
     534              :                         };
     535              : 
     536            0 :                         error!("Error response from local_proxy: {status} {msg}");
     537              : 
     538            0 :                         json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
     539              : 
     540            0 :                         let resp_json = serde_json::to_string(&json_map)
     541            0 :                             .unwrap_or("failed to serialize the response message".to_string());
     542              : 
     543            0 :                         return json_response(status, resp_json);
     544            0 :                     }
     545            0 :                 }
     546              : 
     547            0 :                 error!("Unable to parse the response message from local_proxy");
     548            0 :                 return json_response(
     549            0 :                     status,
     550            0 :                     json!({ "message": "Unable to parse the response message from server".to_string() }),
     551              :                 );
     552            0 :             }
     553            0 :             r
     554              :         }
     555            0 :         Err(e @ RestError::SubzeroCore(_)) => {
     556            0 :             let error_kind = e.get_error_kind();
     557            0 :             ctx.set_error_kind(error_kind);
     558              : 
     559            0 :             tracing::info!(
     560            0 :                 kind=error_kind.to_metric_label(),
     561              :                 error=%e,
     562              :                 msg="subzero core error",
     563            0 :                 "forwarding error to user"
     564              :             );
     565              : 
     566            0 :             let RestError::SubzeroCore(subzero_err) = e else {
     567            0 :                 panic!("expected subzero core error")
     568              :             };
     569              : 
     570            0 :             let json_body = subzero_err.json_body();
     571            0 :             let status_code = StatusCode::from_u16(subzero_err.status_code())
     572            0 :                 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
     573              : 
     574            0 :             json_response(status_code, json_body)?
     575              :         }
     576            0 :         Err(e) => {
     577            0 :             let error_kind = e.get_error_kind();
     578            0 :             ctx.set_error_kind(error_kind);
     579              : 
     580            0 :             let message = e.to_string_client();
     581            0 :             let status_code = e.get_http_status_code();
     582              : 
     583            0 :             tracing::info!(
     584            0 :                 kind=error_kind.to_metric_label(),
     585              :                 error=%e,
     586              :                 msg=message,
     587            0 :                 "forwarding error to user"
     588              :             );
     589              : 
     590            0 :             let (code, detail, hint) = match e {
     591            0 :                 RestError::Postgres(e) => (
     592            0 :                     if e.code.starts_with("PT") {
     593            0 :                         None
     594              :                     } else {
     595            0 :                         Some(e.code)
     596              :                     },
     597            0 :                     e.detail,
     598            0 :                     e.hint,
     599              :                 ),
     600            0 :                 _ => (None, None, None),
     601              :             };
     602              : 
     603            0 :             json_response(
     604            0 :                 status_code,
     605            0 :                 json!({
     606            0 :                     "message": message,
     607            0 :                     "code": code,
     608            0 :                     "detail": detail,
     609            0 :                     "hint": hint,
     610              :                 }),
     611            0 :             )?
     612              :         }
     613              :     };
     614              : 
     615            0 :     response
     616            0 :         .headers_mut()
     617            0 :         .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
     618            0 :     Ok(response)
     619            0 : }
     620              : 
     621            0 : async fn handle_inner(
     622            0 :     _cancel: CancellationToken,
     623            0 :     config: &'static ProxyConfig,
     624            0 :     ctx: &RequestContext,
     625            0 :     request: Request<Incoming>,
     626            0 :     backend: Arc<PoolingBackend>,
     627            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     628            0 :     let _requeset_gauge = Metrics::get()
     629            0 :         .proxy
     630            0 :         .connection_requests
     631            0 :         .guard(ctx.protocol());
     632            0 :     info!(
     633            0 :         protocol = %ctx.protocol(),
     634            0 :         "handling interactive connection from client"
     635              :     );
     636              : 
     637              :     // Read host from Host, then URI host as fallback
     638              :     // TODO: will this be a problem if behind a load balancer?
     639              :     // TODO: can we use the x-forwarded-host header?
     640            0 :     let host = request
     641            0 :         .headers()
     642            0 :         .get(HOST)
     643            0 :         .and_then(|v| v.to_str().ok())
     644            0 :         .unwrap_or_else(|| request.uri().host().unwrap_or(""));
     645              : 
     646              :     // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
     647            0 :     let database_name = request
     648            0 :         .uri()
     649            0 :         .path()
     650            0 :         .split('/')
     651            0 :         .nth(1)
     652            0 :         .ok_or(RestError::SubzeroCore(NotFound {
     653            0 :             target: request.uri().path().to_string(),
     654            0 :         }))?;
     655              : 
     656              :     // we always use the authenticator role to connect to the database
     657            0 :     let authenticator_role = "authenticator";
     658              : 
     659              :     // Strip the hostname prefix from the host to get the database hostname
     660            0 :     let database_host = host.replace(&config.rest_config.hostname_prefix, "");
     661              : 
     662            0 :     let connection_string =
     663            0 :         format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
     664              : 
     665            0 :     let conn_info = get_conn_info(
     666            0 :         &config.authentication_config,
     667            0 :         ctx,
     668            0 :         Some(&connection_string),
     669            0 :         request.headers(),
     670            0 :     )?;
     671            0 :     info!(
     672            0 :         user = conn_info.conn_info.user_info.user.as_str(),
     673            0 :         "credentials"
     674              :     );
     675              : 
     676            0 :     match conn_info.auth {
     677            0 :         AuthData::Jwt(jwt) => {
     678            0 :             let api_prefix = format!("/{database_name}/rest/v1/");
     679            0 :             handle_rest_inner(
     680            0 :                 config,
     681            0 :                 ctx,
     682            0 :                 &api_prefix,
     683            0 :                 request,
     684            0 :                 &connection_string,
     685            0 :                 conn_info.conn_info,
     686            0 :                 jwt,
     687            0 :                 backend,
     688            0 :             )
     689            0 :             .await
     690              :         }
     691            0 :         AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
     692            0 :             Credentials::BearerJwt,
     693            0 :         ))),
     694              :     }
     695            0 : }
     696              : 
     697              : #[allow(clippy::too_many_arguments)]
     698            0 : async fn handle_rest_inner(
     699            0 :     config: &'static ProxyConfig,
     700            0 :     ctx: &RequestContext,
     701            0 :     api_prefix: &str,
     702            0 :     request: Request<Incoming>,
     703            0 :     connection_string: &str,
     704            0 :     conn_info: ConnInfo,
     705            0 :     jwt: String,
     706            0 :     backend: Arc<PoolingBackend>,
     707            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     708              :     // validate the jwt token
     709            0 :     let jwt_parsed = backend
     710            0 :         .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
     711            0 :         .await
     712            0 :         .map_err(HttpConnError::from)?;
     713              : 
     714            0 :     let db_schema_cache =
     715            0 :         config
     716            0 :             .rest_config
     717            0 :             .db_schema_cache
     718            0 :             .as_ref()
     719            0 :             .ok_or(RestError::SubzeroCore(InternalError {
     720            0 :                 message: "DB schema cache is not configured".to_string(),
     721            0 :             }))?;
     722              : 
     723            0 :     let endpoint_cache_key = conn_info
     724            0 :         .endpoint_cache_key()
     725            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     726            0 :             message: "Failed to get endpoint cache key".to_string(),
     727            0 :         }))?;
     728              : 
     729            0 :     let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
     730              : 
     731            0 :     let (parts, originial_body) = request.into_parts();
     732              : 
     733            0 :     let auth_header = parts
     734            0 :         .headers
     735            0 :         .get(AUTHORIZATION)
     736            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     737            0 :             message: "Authorization header is required".to_string(),
     738            0 :         }))?;
     739              : 
     740            0 :     let entry = db_schema_cache
     741            0 :         .get_cached_or_remote(
     742            0 :             &endpoint_cache_key,
     743            0 :             auth_header,
     744            0 :             connection_string,
     745            0 :             &mut client,
     746            0 :             ctx,
     747            0 :             config,
     748            0 :         )
     749            0 :         .await?;
     750            0 :     let (api_config, db_schema_owned) = entry.as_ref();
     751            0 :     let db_schema = db_schema_owned.borrow_schema();
     752              : 
     753            0 :     let db_schemas = &api_config.db_schemas; // list of schemas available for the api
     754            0 :     let db_extra_search_path = &api_config.db_extra_search_path;
     755              :     // TODO: use this when we get a replacement for jsonpath_lib
     756              :     // let role_claim_key = &api_config.role_claim_key;
     757              :     // let role_claim_path = format!("${role_claim_key}");
     758            0 :     let db_anon_role = &api_config.db_anon_role;
     759            0 :     let max_rows = api_config.db_max_rows.as_deref();
     760            0 :     let db_allowed_select_functions = api_config
     761            0 :         .db_allowed_select_functions
     762            0 :         .iter()
     763            0 :         .map(|s| s.as_str())
     764            0 :         .collect::<Vec<_>>();
     765              : 
     766              :     // extract the jwt claims (we'll need them later to set the role and env)
     767            0 :     let jwt_claims = match jwt_parsed.keys {
     768            0 :         ComputeCredentialKeys::JwtPayload(payload_bytes) => {
     769              :             // `payload_bytes` contains the raw JWT payload as Vec<u8>
     770              :             // You can deserialize it back to JSON or parse specific claims
     771            0 :             let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
     772            0 :                 .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     773            0 :             Some(payload)
     774              :         }
     775            0 :         ComputeCredentialKeys::AuthKeys(_) => None,
     776              :     };
     777              : 
     778              :     // read the role from the jwt claims (and set it to the "anon" role if not present)
     779            0 :     let (role, authenticated) = match &jwt_claims {
     780            0 :         Some(claims) => match claims.get("role") {
     781            0 :             Some(JsonValue::String(r)) => (Some(r), true),
     782            0 :             _ => (db_anon_role.as_ref(), true),
     783              :         },
     784            0 :         None => (db_anon_role.as_ref(), false),
     785              :     };
     786              : 
     787              :     // do not allow unauthenticated requests when there is no anonymous role setup
     788            0 :     if let (None, false) = (role, authenticated) {
     789            0 :         return Err(RestError::SubzeroCore(JwtTokenInvalid {
     790            0 :             message: "unauthenticated requests not allowed".to_string(),
     791            0 :         }));
     792            0 :     }
     793              : 
     794              :     // start deconstructing the request because subzero core mostly works with &str
     795            0 :     let method = parts.method;
     796            0 :     let method_str = method.as_str();
     797            0 :     let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
     798              : 
     799              :     // this is actually the table name (or rpc/function_name)
     800              :     // TODO: rename this to something more descriptive
     801            0 :     let root = match parts.uri.path().strip_prefix(api_prefix) {
     802            0 :         Some(p) => Ok(p),
     803            0 :         None => Err(RestError::SubzeroCore(NotFound {
     804            0 :             target: parts.uri.path().to_string(),
     805            0 :         })),
     806            0 :     }?;
     807              : 
     808              :     // pick the current schema from the headers (or the first one from config)
     809            0 :     let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
     810              : 
     811              :     // add the content-profile header to the response
     812            0 :     let mut response_headers = vec![];
     813            0 :     if db_schemas.len() > 1 {
     814            0 :         response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
     815            0 :     }
     816              : 
     817              :     // parse the query string into a Vec<(&str, &str)>
     818            0 :     let query = match parts.uri.query() {
     819            0 :         Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
     820            0 :         None => vec![],
     821              :     };
     822            0 :     let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
     823              : 
     824              :     // convert the headers map to a HashMap<&str, &str>
     825            0 :     let headers: HashMap<&str, &str> = parts
     826            0 :         .headers
     827            0 :         .iter()
     828            0 :         .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
     829            0 :         .collect();
     830              : 
     831            0 :     let cookies = HashMap::new(); // TODO: add cookies
     832              : 
     833              :     // Read the request body (skip for GET requests)
     834            0 :     let body_as_string: Option<String> = if method == Method::GET {
     835            0 :         None
     836              :     } else {
     837            0 :         let body_bytes =
     838            0 :             read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
     839            0 :                 .await
     840            0 :                 .map_err(ReadPayloadError::from)?;
     841            0 :         if body_bytes.is_empty() {
     842            0 :             None
     843              :         } else {
     844            0 :             Some(String::from_utf8_lossy(&body_bytes).into_owned())
     845              :         }
     846              :     };
     847              : 
     848              :     // parse the request into an ApiRequest struct
     849            0 :     let mut api_request = parse(
     850            0 :         schema_name,
     851            0 :         root,
     852            0 :         db_schema,
     853            0 :         method_str,
     854            0 :         path,
     855            0 :         get,
     856            0 :         body_as_string.as_deref(),
     857            0 :         headers,
     858            0 :         cookies,
     859            0 :         max_rows,
     860              :     )
     861            0 :     .map_err(RestError::SubzeroCore)?;
     862              : 
     863            0 :     let role_str = match role {
     864            0 :         Some(r) => r,
     865            0 :         None => "",
     866              :     };
     867              : 
     868            0 :     replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
     869              : 
     870              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     871              :     // in the context of DBX where they need internal permissions
     872              :     // if !disable_internal_permissions {
     873              :     //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
     874              :     // }
     875              : 
     876            0 :     check_safe_functions(&api_request, &db_allowed_select_functions)?;
     877              : 
     878              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     879              :     // in the context of DBX where they need internal permissions
     880              :     // if !disable_internal_permissions {
     881              :     //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
     882              :     // }
     883              : 
     884            0 :     let env_role = Some(role_str);
     885              : 
     886              :     // construct the env (passed in to the sql context as GUCs)
     887            0 :     let empty_json = "{}".to_string();
     888            0 :     let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
     889            0 :     let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
     890            0 :     let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
     891            0 :     let jwt_claims_env = jwt_claims
     892            0 :         .as_ref()
     893            0 :         .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
     894            0 :         .unwrap_or(if let Some(r) = env_role {
     895            0 :             let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
     896            0 :             serde_json::to_string(&claims).unwrap_or(empty_json.clone())
     897              :         } else {
     898            0 :             empty_json.clone()
     899              :         });
     900            0 :     let mut search_path = vec![api_request.schema_name];
     901            0 :     if let Some(extra) = &db_extra_search_path {
     902            0 :         search_path.extend(extra.iter().map(|s| s.as_str()));
     903            0 :     }
     904            0 :     let search_path_str = search_path
     905            0 :         .into_iter()
     906            0 :         .filter(|s| !s.is_empty())
     907            0 :         .collect::<Vec<_>>()
     908            0 :         .join(",");
     909            0 :     let mut env: HashMap<&str, &str> = HashMap::from([
     910            0 :         ("request.method", api_request.method),
     911            0 :         ("request.path", api_request.path),
     912            0 :         ("request.headers", &headers_env),
     913            0 :         ("request.cookies", &cookies_env),
     914            0 :         ("request.get", &get_env),
     915            0 :         ("request.jwt.claims", &jwt_claims_env),
     916            0 :         ("search_path", &search_path_str),
     917            0 :     ]);
     918            0 :     if let Some(r) = env_role {
     919            0 :         env.insert("role", r);
     920            0 :     }
     921              : 
     922              :     // generate the sql statements
     923            0 :     let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
     924            0 :     let (main_statement, main_parameters, _) = generate(fmt_main_query(
     925            0 :         db_schema,
     926            0 :         api_request.schema_name,
     927            0 :         &api_request,
     928            0 :         &env,
     929            0 :     )?);
     930              : 
     931            0 :     let mut headers = vec![
     932            0 :         (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     933            0 :         (
     934            0 :             &CONN_STRING,
     935            0 :             HeaderValue::from_str(connection_string).expect("invalid connection string"),
     936            0 :         ),
     937            0 :         (&AUTHORIZATION, auth_header.clone()),
     938            0 :         (
     939            0 :             &TXN_ISOLATION_LEVEL,
     940            0 :             HeaderValue::from_static("ReadCommitted"),
     941            0 :         ),
     942            0 :         (&ALLOW_POOL, HEADER_VALUE_TRUE),
     943              :     ];
     944              : 
     945            0 :     if api_request.read_only {
     946            0 :         headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
     947            0 :     }
     948              : 
     949              :     // convert the parameters from subzero core representation to the local proxy repr.
     950            0 :     let req_body = serde_json::to_string(&BatchQueryData {
     951            0 :         queries: vec![
     952              :             QueryData {
     953            0 :                 query: env_statement.into(),
     954            0 :                 params: env_parameters
     955            0 :                     .iter()
     956            0 :                     .map(|p| to_sql_param(&p.to_param()))
     957            0 :                     .collect(),
     958              :             },
     959              :             QueryData {
     960            0 :                 query: main_statement.into(),
     961            0 :                 params: main_parameters
     962            0 :                     .iter()
     963            0 :                     .map(|p| to_sql_param(&p.to_param()))
     964            0 :                     .collect(),
     965              :             },
     966              :         ],
     967              :     })
     968            0 :     .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     969              : 
     970              :     // todo: map body to count egress
     971            0 :     let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
     972              : 
     973              :     // send the request to the local proxy
     974            0 :     let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
     975            0 :     let (parts, body) = response.into_parts();
     976              : 
     977            0 :     let max_response = config.http_config.max_response_size_bytes;
     978            0 :     let bytes = read_body_with_limit(body, max_response)
     979            0 :         .await
     980            0 :         .map_err(ReadPayloadError::from)?;
     981              : 
     982              :     // if the response status is greater than 399, then it is an error
     983              :     // FIXME: check if there are other error codes or shapes of the response
     984            0 :     if parts.status.as_u16() > 399 {
     985              :         // turn this postgres error from the json into PostgresError
     986            0 :         let postgres_error = serde_json::from_slice(&bytes)
     987            0 :             .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     988              : 
     989            0 :         return Err(RestError::Postgres(postgres_error));
     990            0 :     }
     991              : 
     992            0 :     #[derive(Deserialize)]
     993              :     struct QueryResults {
     994              :         /// we run two queries, so we want only two results.
     995              :         results: (EnvRows, MainRows),
     996              :     }
     997              : 
     998              :     /// `env_statement` returns nothing of interest to us
     999            0 :     #[derive(Deserialize)]
    1000              :     struct EnvRows {}
    1001              : 
    1002            0 :     #[derive(Deserialize)]
    1003              :     struct MainRows {
    1004              :         /// `main_statement` only returns a single row.
    1005              :         rows: [MainRow; 1],
    1006              :     }
    1007              : 
    1008            0 :     #[derive(Deserialize)]
    1009              :     struct MainRow {
    1010              :         body: String,
    1011              :         page_total: Option<String>,
    1012              :         total_result_set: Option<String>,
    1013              :         response_headers: Option<String>,
    1014              :         response_status: Option<String>,
    1015              :     }
    1016              : 
    1017            0 :     let results: QueryResults = serde_json::from_slice(&bytes)
    1018            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1019              : 
    1020              :     let QueryResults {
    1021            0 :         results: (_, MainRows { rows: [row] }),
    1022            0 :     } = results;
    1023              : 
    1024              :     // build the intermediate response object
    1025            0 :     let api_response = ApiResponse {
    1026            0 :         page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
    1027            0 :         total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
    1028              :         top_level_offset: 0, // FIXME: check why this is 0
    1029            0 :         response_headers: row.response_headers,
    1030            0 :         response_status: row.response_status,
    1031            0 :         body: row.body,
    1032              :     };
    1033              : 
    1034              :     // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
    1035              :     // we can not do this in the context of proxy for now
    1036              :     // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
    1037              :     //     // rollback the transaction here
    1038              :     //     return Err(RestError::SubzeroCore(SingularityError {
    1039              :     //         count: api_response.page_total,
    1040              :     //         content_type: "application/vnd.pgrst.object+json".to_string(),
    1041              :     //     }));
    1042              :     // }
    1043              : 
    1044              :     // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
    1045              :     // we can not do this in the context of proxy for now
    1046              :     // if api_request.method == Method::PUT && api_response.page_total != 1 {
    1047              :     //     // Makes sure the querystring pk matches the payload pk
    1048              :     //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
    1049              :     //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
    1050              :     //     // If this condition is not satisfied then nothing is inserted,
    1051              :     //     // rollback the transaction here
    1052              :     //     return Err(RestError::SubzeroCore(PutMatchingPkError));
    1053              :     // }
    1054              : 
    1055              :     // create and return the response to the client
    1056              :     // this section mostly deals with setting the right headers according to PostgREST specs
    1057            0 :     let page_total = api_response.page_total;
    1058            0 :     let total_result_set = api_response.total_result_set;
    1059            0 :     let top_level_offset = api_response.top_level_offset;
    1060            0 :     let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
    1061              :         (SingularJSON, _)
    1062              :         | (
    1063              :             _,
    1064              :             FunctionCall {
    1065              :                 returns_single: true,
    1066              :                 is_scalar: false,
    1067              :                 ..
    1068              :             },
    1069            0 :         ) => SingularJSON,
    1070            0 :         (TextCSV, _) => TextCSV,
    1071            0 :         _ => ApplicationJSON,
    1072              :     };
    1073              : 
    1074              :     // check if the SQL env set some response headers (happens when we called a rpc function)
    1075            0 :     if let Some(response_headers_str) = api_response.response_headers {
    1076            0 :         let Ok(headers_json) =
    1077            0 :             serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
    1078              :         else {
    1079            0 :             return Err(RestError::SubzeroCore(GucHeadersError));
    1080              :         };
    1081              : 
    1082            0 :         response_headers.extend(headers_json.into_iter().flatten());
    1083            0 :     }
    1084              : 
    1085              :     // calculate and set the content range header
    1086            0 :     let lower = top_level_offset as i64;
    1087            0 :     let upper = top_level_offset as i64 + page_total as i64 - 1;
    1088            0 :     let total = total_result_set.map(|t| t as i64);
    1089            0 :     let content_range = match (&method, &api_request.query.node) {
    1090            0 :         (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
    1091            0 :         (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
    1092            0 :         _ => content_range_header(lower, upper, total),
    1093              :     };
    1094            0 :     response_headers.push(("Content-Range".to_string(), content_range));
    1095              : 
    1096              :     // calculate the status code
    1097              :     #[rustfmt::skip]
    1098            0 :     let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
    1099            0 :         (&Method::POST,   Insert { .. }, ..) => 201,
    1100            0 :         (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1101            0 :         (&Method::DELETE, Delete { .. }, ..) => 204,
    1102            0 :         (&Method::PATCH,  Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
    1103            0 :         (&Method::PATCH,  Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1104            0 :         (&Method::PATCH,  Update { .. }, ..) => 204,
    1105            0 :         (&Method::PUT,    Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1106            0 :         (&Method::PUT,    Insert { .. }, ..) => 204,
    1107            0 :         _ => content_range_status(lower, upper, total),
    1108              :     };
    1109              : 
    1110              :     // add the preference-applied header
    1111              :     if let Some(Preferences {
    1112            0 :         resolution: Some(r),
    1113              :         ..
    1114            0 :     }) = api_request.preferences
    1115              :     {
    1116            0 :         response_headers.push((
    1117            0 :             "Preference-Applied".to_string(),
    1118            0 :             match r {
    1119            0 :                 MergeDuplicates => "resolution=merge-duplicates".to_string(),
    1120            0 :                 IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
    1121              :             },
    1122              :         ));
    1123            0 :     }
    1124              : 
    1125              :     // check if the SQL env set some response status (happens when we called a rpc function)
    1126            0 :     if let Some(response_status_str) = api_response.response_status {
    1127            0 :         status = response_status_str
    1128            0 :             .parse::<u16>()
    1129            0 :             .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
    1130            0 :     }
    1131              : 
    1132              :     // set the content type header
    1133              :     // TODO: move this to a subzero function
    1134              :     // as_header_value(&self) -> Option<&str>
    1135            0 :     let http_content_type = match response_content_type {
    1136            0 :         SingularJSON => Ok("application/vnd.pgrst.object+json"),
    1137            0 :         TextCSV => Ok("text/csv"),
    1138            0 :         ApplicationJSON => Ok("application/json"),
    1139            0 :         Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
    1140            0 :             message: format!("None of these Content-Types are available: {t}"),
    1141            0 :         })),
    1142            0 :     }?;
    1143              : 
    1144              :     // build the response body
    1145            0 :     let response_body = Full::new(Bytes::from(api_response.body))
    1146            0 :         .map_err(|never| match never {})
    1147            0 :         .boxed();
    1148              : 
    1149              :     // build the response
    1150            0 :     let mut response = Response::builder()
    1151            0 :         .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
    1152            0 :         .header(CONTENT_TYPE, http_content_type);
    1153              : 
    1154              :     // Add all headers from response_headers vector
    1155            0 :     for (header_name, header_value) in response_headers {
    1156            0 :         response = response.header(header_name, header_value);
    1157            0 :     }
    1158              : 
    1159              :     // add the body and return the response
    1160            0 :     response.body(response_body).map_err(|_| {
    1161            0 :         RestError::SubzeroCore(InternalError {
    1162            0 :             message: "Failed to build response".to_string(),
    1163            0 :         })
    1164            0 :     })
    1165            0 : }
        

Generated by: LCOV version 2.1-beta