LCOV - code coverage report
Current view: top level - proxy/src/serverless - rest.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 687 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 87 0

            Line data    Source code
       1              : use std::borrow::Cow;
       2              : use std::collections::HashMap;
       3              : use std::sync::Arc;
       4              : 
       5              : use bytes::Bytes;
       6              : use http::Method;
       7              : use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
       8              : use http_body_util::combinators::BoxBody;
       9              : use http_body_util::{BodyExt, Full};
      10              : use http_utils::error::ApiError;
      11              : use hyper::body::Incoming;
      12              : use hyper::http::{HeaderName, HeaderValue};
      13              : use hyper::{Request, Response, StatusCode};
      14              : use indexmap::IndexMap;
      15              : use moka::sync::Cache;
      16              : use ouroboros::self_referencing;
      17              : use serde::de::DeserializeOwned;
      18              : use serde::{Deserialize, Deserializer};
      19              : use serde_json::Value as JsonValue;
      20              : use serde_json::value::RawValue;
      21              : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
      22              : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
      23              : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
      24              : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
      25              : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
      26              : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
      27              : use subzero_core::error::Error::{
      28              :     self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
      29              :     JsonDeserialize, JwtTokenInvalid, NotFound,
      30              : };
      31              : use subzero_core::error::pg_error_to_status_code;
      32              : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
      33              : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
      34              : use subzero_core::formatter::{Param, Snippet, SqlParam};
      35              : use subzero_core::parser::postgrest::parse;
      36              : use subzero_core::permissions::{check_safe_functions, replace_select_star};
      37              : use subzero_core::schema::{
      38              :     DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
      39              : };
      40              : use subzero_core::{content_range_header, content_range_status};
      41              : use tokio_util::sync::CancellationToken;
      42              : use tracing::{error, info};
      43              : use typed_json::json;
      44              : use url::form_urlencoded;
      45              : 
      46              : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
      47              : use super::conn_pool::AuthData;
      48              : use super::conn_pool_lib::ConnInfo;
      49              : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
      50              : use super::http_conn_pool::{self, LocalProxyClient};
      51              : use super::http_util::{
      52              :     ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
      53              :     get_conn_info, json_response, uuid_to_header_value,
      54              : };
      55              : use super::json::JsonConversionError;
      56              : use crate::auth::backend::ComputeCredentialKeys;
      57              : use crate::config::ProxyConfig;
      58              : use crate::context::RequestContext;
      59              : use crate::error::{ErrorKind, ReportableError, UserFacingError};
      60              : use crate::http::read_body_with_limit;
      61              : use crate::metrics::Metrics;
      62              : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
      63              : use crate::types::EndpointCacheKey;
      64              : use crate::util::deserialize_json_string;
      65              : 
// JSON text of a schema document that lists no schemas. It is the source of the
// placeholder schema that DbSchemaCache::get_cached_or_remote builds when the
// real schema is reported as too large.
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
// Local name for subzero's PostgreSQL introspection query, sent by DbSchemaCache::get_remote.
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
      68              : 
// A wrapper around the DbSchema that allows for self-referencing:
// the struct owns the JSON text (`schema_string`) together with the `DbSchema`
// that borrows from that text, so the pair can be stored and moved as one value.
#[self_referencing]
pub struct DbSchemaOwned {
    // The JSON text that `schema` borrows from.
    schema_string: String,
    // The schema value; `#[borrows(schema_string)]` ties its lifetime to the field above.
    #[covariant]
    #[borrows(schema_string)]
    schema: DbSchema<'this>,
}
      77              : 
      78              : impl<'de> Deserialize<'de> for DbSchemaOwned {
      79            0 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      80            0 :     where
      81            0 :         D: Deserializer<'de>,
      82              :     {
      83            0 :         let s = String::deserialize(deserializer)?;
      84            0 :         DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
      85            0 :             .map_err(<D::Error as serde::de::Error>::custom)
      86            0 :     }
      87              : }
      88              : 
      89            0 : fn split_comma_separated(s: &str) -> Vec<String> {
      90            0 :     s.split(',').map(|s| s.trim().to_string()).collect()
      91            0 : }
      92              : 
      93            0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
      94            0 : where
      95            0 :     D: Deserializer<'de>,
      96              : {
      97            0 :     let s = String::deserialize(deserializer)?;
      98            0 :     Ok(split_comma_separated(&s))
      99            0 : }
     100              : 
     101            0 : fn deserialize_comma_separated_option<'de, D>(
     102            0 :     deserializer: D,
     103            0 : ) -> Result<Option<Vec<String>>, D::Error>
     104            0 : where
     105            0 :     D: Deserializer<'de>,
     106              : {
     107            0 :     let opt = Option::<String>::deserialize(deserializer)?;
     108            0 :     if let Some(s) = &opt {
     109            0 :         let trimmed = s.trim();
     110            0 :         if trimmed.is_empty() {
     111            0 :             return Ok(None);
     112            0 :         }
     113            0 :         return Ok(Some(split_comma_separated(trimmed)));
     114            0 :     }
     115            0 :     Ok(None)
     116            0 : }
     117              : 
// The ApiConfig is the configuration for the API per endpoint
// The configuration is read from the database and cached in the DbSchemaCache
#[derive(Deserialize, Debug)]
pub struct ApiConfig {
    // Schema names, given as one comma-separated string in the serialized form.
    // When the key is missing the value comes from subzero's `db_schemas()` default.
    #[serde(
        default = "db_schemas",
        deserialize_with = "deserialize_comma_separated"
    )]
    pub db_schemas: Vec<String>,
    // Optional; no default is declared, so a missing key deserializes to `None`.
    pub db_anon_role: Option<String>,
    // Optional and kept as a string exactly as it was read.
    pub db_max_rows: Option<String>,
    // Falls back to subzero's `db_allowed_select_functions()` when the key is missing.
    #[serde(default = "db_allowed_select_functions")]
    pub db_allowed_select_functions: Vec<String>,
    // #[serde(deserialize_with = "to_tuple", default)]
    // pub db_pre_request: Option<(String, String)>,
    // Deserialized (default: subzero's `role_claim_key()`) but not read anywhere in
    // this crate yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[serde(default = "role_claim_key")]
    pub role_claim_key: String,
    // Optional comma-separated list; missing, null, or blank values become `None`.
    #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
    pub db_extra_search_path: Option<Vec<String>>,
}
     139              : 
     140              : // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
     141              : pub(crate) struct DbSchemaCache(pub Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
     142              : impl DbSchemaCache {
     143            0 :     pub fn new(config: crate::config::CacheOptions) -> Self {
     144            0 :         let builder = Cache::builder().name("db_schema_cache");
     145            0 :         let builder = config.moka(builder);
     146              : 
     147            0 :         Self(builder.build())
     148            0 :     }
     149              : 
     150            0 :     pub async fn get_cached_or_remote(
     151            0 :         &self,
     152            0 :         endpoint_id: &EndpointCacheKey,
     153            0 :         auth_header: &HeaderValue,
     154            0 :         connection_string: &str,
     155            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     156            0 :         ctx: &RequestContext,
     157            0 :         config: &'static ProxyConfig,
     158            0 :     ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
     159            0 :         match self.0.get(endpoint_id) {
     160            0 :             Some(v) => Ok(v),
     161              :             None => {
     162            0 :                 info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
     163            0 :                 let remote_value = self
     164            0 :                     .get_remote(auth_header, connection_string, client, ctx, config)
     165            0 :                     .await;
     166            0 :                 let (api_config, schema_owned) = match remote_value {
     167            0 :                     Ok((api_config, schema_owned)) => (api_config, schema_owned),
     168            0 :                     Err(e @ RestError::SchemaTooLarge) => {
     169              :                         // for the case where the schema is too large, we cache an empty dummy value
     170              :                         // all the other requests will fail without triggering the introspection query
     171            0 :                         let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
     172            0 :                             .map_err(|e| JsonDeserialize { source: e })?;
     173              : 
     174            0 :                         let api_config = ApiConfig {
     175            0 :                             db_schemas: vec![],
     176            0 :                             db_anon_role: None,
     177            0 :                             db_max_rows: None,
     178            0 :                             db_allowed_select_functions: vec![],
     179            0 :                             role_claim_key: String::new(),
     180            0 :                             db_extra_search_path: None,
     181            0 :                         };
     182            0 :                         let value = Arc::new((api_config, schema_owned));
     183            0 :                         self.0.insert(endpoint_id.clone(), value);
     184            0 :                         return Err(e);
     185              :                     }
     186            0 :                     Err(e) => {
     187            0 :                         return Err(e);
     188              :                     }
     189              :                 };
     190            0 :                 let value = Arc::new((api_config, schema_owned));
     191            0 :                 self.0.insert(endpoint_id.clone(), value.clone());
     192            0 :                 Ok(value)
     193              :             }
     194              :         }
     195            0 :     }
     196            0 :     pub async fn get_remote(
     197            0 :         &self,
     198            0 :         auth_header: &HeaderValue,
     199            0 :         connection_string: &str,
     200            0 :         client: &mut http_conn_pool::Client<LocalProxyClient>,
     201            0 :         ctx: &RequestContext,
     202            0 :         config: &'static ProxyConfig,
     203            0 :     ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
     204            0 :         #[derive(Deserialize)]
     205              :         struct SingleRow<Row> {
     206              :             rows: [Row; 1],
     207              :         }
     208              : 
     209              :         #[derive(Deserialize)]
     210              :         struct ConfigRow {
     211              :             #[serde(deserialize_with = "deserialize_json_string")]
     212              :             config: ApiConfig,
     213              :         }
     214              : 
     215            0 :         #[derive(Deserialize)]
     216              :         struct SchemaRow {
     217              :             json_schema: DbSchemaOwned,
     218              :         }
     219              : 
     220            0 :         let headers = vec![
     221            0 :             (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     222            0 :             (
     223            0 :                 &CONN_STRING,
     224            0 :                 HeaderValue::from_str(connection_string).expect(
     225            0 :                     "connection string came from a header, so it must be a valid headervalue",
     226            0 :                 ),
     227            0 :             ),
     228            0 :             (&AUTHORIZATION, auth_header.clone()),
     229            0 :             (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
     230              :         ];
     231              : 
     232            0 :         let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
     233              :         let SingleRow {
     234            0 :             rows: [ConfigRow { config: api_config }],
     235            0 :         } = make_local_proxy_request(
     236            0 :             client,
     237            0 :             headers.iter().cloned(),
     238            0 :             QueryData {
     239            0 :                 query: Cow::Owned(query),
     240            0 :                 params: vec![],
     241            0 :             },
     242            0 :             config.rest_config.max_schema_size,
     243            0 :         )
     244            0 :         .await
     245            0 :         .map_err(|e| match e {
     246              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     247            0 :                 RestError::SchemaTooLarge
     248              :             }
     249            0 :             e => e,
     250            0 :         })?;
     251              : 
     252              :         // now that we have the api_config let's run the second INTROSPECTION_SQL query
     253              :         let SingleRow {
     254            0 :             rows: [SchemaRow { json_schema }],
     255            0 :         } = make_local_proxy_request(
     256            0 :             client,
     257            0 :             headers,
     258            0 :             QueryData {
     259            0 :                 query: INTROSPECTION_SQL.into(),
     260            0 :                 params: vec![
     261            0 :                     serde_json::to_value(&api_config.db_schemas)
     262            0 :                         .expect("Vec<String> is always valid to encode as JSON"),
     263            0 :                     JsonValue::Bool(false), // include_roles_with_login
     264            0 :                     JsonValue::Bool(false), // use_internal_permissions
     265            0 :                 ],
     266            0 :             },
     267            0 :             config.rest_config.max_schema_size,
     268            0 :         )
     269            0 :         .await
     270            0 :         .map_err(|e| match e {
     271              :             RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
     272            0 :                 RestError::SchemaTooLarge
     273              :             }
     274            0 :             e => e,
     275            0 :         })?;
     276              : 
     277            0 :         Ok((api_config, json_schema))
     278            0 :     }
     279              : }
     280              : 
// A type to represent a PostgreSQL error.
// We use our own type (instead of postgres_client::Error) because we get the error from the json response
#[derive(Debug, thiserror::Error, Deserialize)]
pub(crate) struct PostgresError {
    // Error code; used to pick the HTTP status and to decide whether the message
    // may be shown to the client.
    pub code: String,
    // Error message; this is also what `Display` prints.
    pub message: String,
    // Optional extra fields, carried along as deserialized.
    pub detail: Option<String>,
    pub hint: Option<String>,
}
     290              : impl HttpCodeError for PostgresError {
     291            0 :     fn get_http_status_code(&self) -> StatusCode {
     292            0 :         let status = pg_error_to_status_code(&self.code, true);
     293            0 :         StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     294            0 :     }
     295              : }
impl ReportableError for PostgresError {
    // Every error of this type is reported with the `User` kind.
    fn get_error_kind(&self) -> ErrorKind {
        ErrorKind::User
    }
}
     301              : impl UserFacingError for PostgresError {
     302            0 :     fn to_string_client(&self) -> String {
     303            0 :         if self.code.starts_with("PT") {
     304            0 :             "Postgres error".to_string()
     305              :         } else {
     306            0 :             self.message.clone()
     307              :         }
     308            0 :     }
     309              : }
     310              : impl std::fmt::Display for PostgresError {
     311            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     312            0 :         write!(f, "{}", self.message)
     313            0 :     }
     314              : }
     315              : 
// A type to represent errors that can occur in the rest broker.
// Every variant except `SchemaTooLarge` wraps another error type transparently
// and can be created from it with `?` / `From`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum RestError {
    #[error(transparent)]
    ReadPayload(#[from] ReadPayloadError),
    #[error(transparent)]
    ConnectCompute(#[from] HttpConnError),
    #[error(transparent)]
    ConnInfo(#[from] ConnInfoError),
    #[error(transparent)]
    Postgres(#[from] PostgresError),
    #[error(transparent)]
    JsonConversion(#[from] JsonConversionError),
    #[error(transparent)]
    SubzeroCore(#[from] SubzeroCoreError),
    // Produced by DbSchemaCache::get_remote when a response body exceeds the
    // configured maximum schema size.
    #[error("schema is too large")]
    SchemaTooLarge,
}
     334              : impl ReportableError for RestError {
     335            0 :     fn get_error_kind(&self) -> ErrorKind {
     336            0 :         match self {
     337            0 :             RestError::ReadPayload(e) => e.get_error_kind(),
     338            0 :             RestError::ConnectCompute(e) => e.get_error_kind(),
     339            0 :             RestError::ConnInfo(e) => e.get_error_kind(),
     340            0 :             RestError::Postgres(_) => ErrorKind::Postgres,
     341            0 :             RestError::JsonConversion(_) => ErrorKind::Postgres,
     342            0 :             RestError::SubzeroCore(_) => ErrorKind::User,
     343            0 :             RestError::SchemaTooLarge => ErrorKind::User,
     344              :         }
     345            0 :     }
     346              : }
     347              : impl UserFacingError for RestError {
     348            0 :     fn to_string_client(&self) -> String {
     349            0 :         match self {
     350            0 :             RestError::ReadPayload(p) => p.to_string(),
     351            0 :             RestError::ConnectCompute(c) => c.to_string_client(),
     352            0 :             RestError::ConnInfo(c) => c.to_string_client(),
     353            0 :             RestError::SchemaTooLarge => self.to_string(),
     354            0 :             RestError::Postgres(p) => p.to_string_client(),
     355            0 :             RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
     356            0 :             RestError::SubzeroCore(s) => {
     357              :                 // TODO: this is a hack to get the message from the json body
     358            0 :                 let json = s.json_body();
     359            0 :                 let default_message = "Unknown error".to_string();
     360              : 
     361            0 :                 json.get("message")
     362            0 :                     .map_or(default_message.clone(), |m| match m {
     363            0 :                         JsonValue::String(s) => s.clone(),
     364            0 :                         _ => default_message,
     365            0 :                     })
     366              :             }
     367              :         }
     368            0 :     }
     369              : }
     370              : impl HttpCodeError for RestError {
     371            0 :     fn get_http_status_code(&self) -> StatusCode {
     372            0 :         match self {
     373            0 :             RestError::ReadPayload(e) => e.get_http_status_code(),
     374            0 :             RestError::ConnectCompute(h) => match h.get_error_kind() {
     375            0 :                 ErrorKind::User => StatusCode::BAD_REQUEST,
     376            0 :                 _ => StatusCode::INTERNAL_SERVER_ERROR,
     377              :             },
     378            0 :             RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
     379            0 :             RestError::Postgres(e) => e.get_http_status_code(),
     380            0 :             RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
     381            0 :             RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
     382            0 :             RestError::SubzeroCore(e) => {
     383            0 :                 let status = e.status_code();
     384            0 :                 StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
     385              :             }
     386              :         }
     387            0 :     }
     388              : }
     389              : 
     390              : // Helper functions for the rest broker
     391              : 
     392            0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
     393              :     "select "
     394            0 :         + if env.is_empty() {
     395            0 :             sql("null")
     396              :         } else {
     397            0 :             env.iter()
     398            0 :                 .map(|(k, v)| {
     399            0 :                     "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
     400            0 :                 })
     401            0 :                 .join(",")
     402              :         }
     403            0 : }
     404              : 
     405              : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
     406            0 : fn to_sql_param(p: &Param) -> JsonValue {
     407            0 :     match p {
     408            0 :         SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
     409            0 :         Str(v) => JsonValue::String((*v).to_string()),
     410            0 :         StrOwned(v) => JsonValue::String((*v).clone()),
     411            0 :         PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
     412            0 :         LV(ListVal(v, ..)) => {
     413            0 :             if v.is_empty() {
     414            0 :                 JsonValue::String(r"{}".to_string())
     415              :             } else {
     416            0 :                 JsonValue::String(format!(
     417            0 :                     "{{\"{}\"}}",
     418            0 :                     v.iter()
     419            0 :                         .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
     420            0 :                         .collect::<Vec<_>>()
     421            0 :                         .join("\",\"")
     422              :                 ))
     423              :             }
     424              :         }
     425              :     }
     426            0 : }
     427              : 
// Request body for a single query sent to the local proxy; serialized to JSON
// by make_local_proxy_request.
#[derive(serde::Serialize)]
struct QueryData<'a> {
    // SQL text, borrowed or owned.
    query: Cow<'a, str>,
    // Parameter values, already converted to JSON.
    params: Vec<JsonValue>,
}
     433              : 
// Serializable wrapper holding several queries under a single `queries` key.
#[derive(serde::Serialize)]
struct BatchQueryData<'a> {
    queries: Vec<QueryData<'a>>,
}
     438              : 
     439            0 : async fn make_local_proxy_request<S: DeserializeOwned>(
     440            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     441            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     442            0 :     body: QueryData<'_>,
     443            0 :     max_len: usize,
     444            0 : ) -> Result<S, RestError> {
     445            0 :     let body_string = serde_json::to_string(&body)
     446            0 :         .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     447              : 
     448            0 :     let response = make_raw_local_proxy_request(client, headers, body_string).await?;
     449              : 
     450            0 :     let response_status = response.status();
     451              : 
     452            0 :     if response_status != StatusCode::OK {
     453            0 :         return Err(RestError::SubzeroCore(InternalError {
     454            0 :             message: "Failed to get endpoint schema".to_string(),
     455            0 :         }));
     456            0 :     }
     457              : 
     458              :     // Capture the response body
     459            0 :     let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
     460            0 :         .await
     461            0 :         .map_err(ReadPayloadError::from)?;
     462              : 
     463              :     // Parse the JSON response
     464            0 :     let response_json: S = serde_json::from_slice(&response_body)
     465            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     466              : 
     467            0 :     Ok(response_json)
     468            0 : }
     469              : 
     470            0 : async fn make_raw_local_proxy_request(
     471            0 :     client: &mut http_conn_pool::Client<LocalProxyClient>,
     472            0 :     headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
     473            0 :     body: String,
     474            0 : ) -> Result<Response<Incoming>, RestError> {
     475            0 :     let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
     476            0 :     let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
     477            0 :     let req_headers = req.headers_mut().expect("failed to get headers");
     478              :     // Add all provided headers to the request
     479            0 :     for (header_name, header_value) in headers {
     480            0 :         req_headers.insert(header_name, header_value.clone());
     481            0 :     }
     482              : 
     483            0 :     let body_boxed = Full::new(Bytes::from(body))
     484            0 :         .map_err(|never| match never {}) // Convert Infallible to hyper::Error
     485            0 :         .boxed();
     486              : 
     487            0 :     let req = req.body(body_boxed).map_err(|_| {
     488            0 :         RestError::SubzeroCore(InternalError {
     489            0 :             message: "Failed to build request".to_string(),
     490            0 :         })
     491            0 :     })?;
     492              : 
     493              :     // Send the request to the local proxy
     494            0 :     client
     495            0 :         .inner
     496            0 :         .inner
     497            0 :         .send_request(req)
     498            0 :         .await
     499            0 :         .map_err(LocalProxyConnError::from)
     500            0 :         .map_err(HttpConnError::from)
     501            0 :         .map_err(RestError::from)
     502            0 : }
     503              : 
     504            0 : pub(crate) async fn handle(
     505            0 :     config: &'static ProxyConfig,
     506            0 :     ctx: RequestContext,
     507            0 :     request: Request<Incoming>,
     508            0 :     backend: Arc<PoolingBackend>,
     509            0 :     cancel: CancellationToken,
     510            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     511            0 :     let result = handle_inner(cancel, config, &ctx, request, backend).await;
     512              : 
     513            0 :     let mut response = match result {
     514            0 :         Ok(r) => {
     515            0 :             ctx.set_success();
     516              : 
     517              :             // Handling the error response from local proxy here
     518            0 :             if r.status().is_server_error() {
     519            0 :                 let status = r.status();
     520              : 
     521            0 :                 let body_bytes = r
     522            0 :                     .collect()
     523            0 :                     .await
     524            0 :                     .map_err(|e| {
     525            0 :                         ApiError::InternalServerError(anyhow::Error::msg(format!(
     526            0 :                             "could not collect http body: {e}"
     527            0 :                         )))
     528            0 :                     })?
     529            0 :                     .to_bytes();
     530              : 
     531            0 :                 if let Ok(mut json_map) =
     532            0 :                     serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
     533              :                 {
     534            0 :                     let message = json_map.get("message");
     535            0 :                     if let Some(message) = message {
     536            0 :                         let msg: String = match serde_json::from_str(message.get()) {
     537            0 :                             Ok(msg) => msg,
     538              :                             Err(_) => {
     539            0 :                                 "Unable to parse the response message from server".to_string()
     540              :                             }
     541              :                         };
     542              : 
     543            0 :                         error!("Error response from local_proxy: {status} {msg}");
     544              : 
     545            0 :                         json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
     546              : 
     547            0 :                         let resp_json = serde_json::to_string(&json_map)
     548            0 :                             .unwrap_or("failed to serialize the response message".to_string());
     549              : 
     550            0 :                         return json_response(status, resp_json);
     551            0 :                     }
     552            0 :                 }
     553              : 
     554            0 :                 error!("Unable to parse the response message from local_proxy");
     555            0 :                 return json_response(
     556            0 :                     status,
     557            0 :                     json!({ "message": "Unable to parse the response message from server".to_string() }),
     558              :                 );
     559            0 :             }
     560            0 :             r
     561              :         }
     562            0 :         Err(e @ RestError::SubzeroCore(_)) => {
     563            0 :             let error_kind = e.get_error_kind();
     564            0 :             ctx.set_error_kind(error_kind);
     565              : 
     566            0 :             tracing::info!(
     567            0 :                 kind=error_kind.to_metric_label(),
     568              :                 error=%e,
     569              :                 msg="subzero core error",
     570            0 :                 "forwarding error to user"
     571              :             );
     572              : 
     573            0 :             let RestError::SubzeroCore(subzero_err) = e else {
     574            0 :                 panic!("expected subzero core error")
     575              :             };
     576              : 
     577            0 :             let json_body = subzero_err.json_body();
     578            0 :             let status_code = StatusCode::from_u16(subzero_err.status_code())
     579            0 :                 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
     580              : 
     581            0 :             json_response(status_code, json_body)?
     582              :         }
     583            0 :         Err(e) => {
     584            0 :             let error_kind = e.get_error_kind();
     585            0 :             ctx.set_error_kind(error_kind);
     586              : 
     587            0 :             let message = e.to_string_client();
     588            0 :             let status_code = e.get_http_status_code();
     589              : 
     590            0 :             tracing::info!(
     591            0 :                 kind=error_kind.to_metric_label(),
     592              :                 error=%e,
     593              :                 msg=message,
     594            0 :                 "forwarding error to user"
     595              :             );
     596              : 
     597            0 :             let (code, detail, hint) = match e {
     598            0 :                 RestError::Postgres(e) => (
     599            0 :                     if e.code.starts_with("PT") {
     600            0 :                         None
     601              :                     } else {
     602            0 :                         Some(e.code)
     603              :                     },
     604            0 :                     e.detail,
     605            0 :                     e.hint,
     606              :                 ),
     607            0 :                 _ => (None, None, None),
     608              :             };
     609              : 
     610            0 :             json_response(
     611            0 :                 status_code,
     612            0 :                 json!({
     613            0 :                     "message": message,
     614            0 :                     "code": code,
     615            0 :                     "detail": detail,
     616            0 :                     "hint": hint,
     617              :                 }),
     618            0 :             )?
     619              :         }
     620              :     };
     621              : 
     622            0 :     response
     623            0 :         .headers_mut()
     624            0 :         .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
     625            0 :     Ok(response)
     626            0 : }
     627              : 
     628            0 : async fn handle_inner(
     629            0 :     _cancel: CancellationToken,
     630            0 :     config: &'static ProxyConfig,
     631            0 :     ctx: &RequestContext,
     632            0 :     request: Request<Incoming>,
     633            0 :     backend: Arc<PoolingBackend>,
     634            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     635            0 :     let _requeset_gauge = Metrics::get()
     636            0 :         .proxy
     637            0 :         .connection_requests
     638            0 :         .guard(ctx.protocol());
     639            0 :     info!(
     640            0 :         protocol = %ctx.protocol(),
     641            0 :         "handling interactive connection from client"
     642              :     );
     643              : 
     644              :     // Read host from Host, then URI host as fallback
     645              :     // TODO: will this be a problem if behind a load balancer?
     646              :     // TODO: can we use the x-forwarded-host header?
     647            0 :     let host = request
     648            0 :         .headers()
     649            0 :         .get(HOST)
     650            0 :         .and_then(|v| v.to_str().ok())
     651            0 :         .unwrap_or_else(|| request.uri().host().unwrap_or(""));
     652              : 
     653              :     // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
     654            0 :     let database_name = request
     655            0 :         .uri()
     656            0 :         .path()
     657            0 :         .split('/')
     658            0 :         .nth(1)
     659            0 :         .ok_or(RestError::SubzeroCore(NotFound {
     660            0 :             target: request.uri().path().to_string(),
     661            0 :         }))?;
     662              : 
     663              :     // we always use the authenticator role to connect to the database
     664            0 :     let authenticator_role = "authenticator";
     665              : 
     666              :     // Strip the hostname prefix from the host to get the database hostname
     667            0 :     let database_host = host.replace(&config.rest_config.hostname_prefix, "");
     668              : 
     669            0 :     let connection_string =
     670            0 :         format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
     671              : 
     672            0 :     let conn_info = get_conn_info(
     673            0 :         &config.authentication_config,
     674            0 :         ctx,
     675            0 :         Some(&connection_string),
     676            0 :         request.headers(),
     677            0 :     )?;
     678            0 :     info!(
     679            0 :         user = conn_info.conn_info.user_info.user.as_str(),
     680            0 :         "credentials"
     681              :     );
     682              : 
     683            0 :     match conn_info.auth {
     684            0 :         AuthData::Jwt(jwt) => {
     685            0 :             let api_prefix = format!("/{database_name}/rest/v1/");
     686            0 :             handle_rest_inner(
     687            0 :                 config,
     688            0 :                 ctx,
     689            0 :                 &api_prefix,
     690            0 :                 request,
     691            0 :                 &connection_string,
     692            0 :                 conn_info.conn_info,
     693            0 :                 jwt,
     694            0 :                 backend,
     695            0 :             )
     696            0 :             .await
     697              :         }
     698            0 :         AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
     699            0 :             Credentials::BearerJwt,
     700            0 :         ))),
     701              :     }
     702            0 : }
     703              : 
     704              : #[allow(clippy::too_many_arguments)]
     705            0 : async fn handle_rest_inner(
     706            0 :     config: &'static ProxyConfig,
     707            0 :     ctx: &RequestContext,
     708            0 :     api_prefix: &str,
     709            0 :     request: Request<Incoming>,
     710            0 :     connection_string: &str,
     711            0 :     conn_info: ConnInfo,
     712            0 :     jwt: String,
     713            0 :     backend: Arc<PoolingBackend>,
     714            0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
     715              :     // validate the jwt token
     716            0 :     let jwt_parsed = backend
     717            0 :         .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
     718            0 :         .await
     719            0 :         .map_err(HttpConnError::from)?;
     720              : 
     721            0 :     let db_schema_cache =
     722            0 :         config
     723            0 :             .rest_config
     724            0 :             .db_schema_cache
     725            0 :             .as_ref()
     726            0 :             .ok_or(RestError::SubzeroCore(InternalError {
     727            0 :                 message: "DB schema cache is not configured".to_string(),
     728            0 :             }))?;
     729              : 
     730            0 :     let endpoint_cache_key = conn_info
     731            0 :         .endpoint_cache_key()
     732            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     733            0 :             message: "Failed to get endpoint cache key".to_string(),
     734            0 :         }))?;
     735              : 
     736            0 :     let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
     737              : 
     738            0 :     let (parts, originial_body) = request.into_parts();
     739              : 
     740            0 :     let auth_header = parts
     741            0 :         .headers
     742            0 :         .get(AUTHORIZATION)
     743            0 :         .ok_or(RestError::SubzeroCore(InternalError {
     744            0 :             message: "Authorization header is required".to_string(),
     745            0 :         }))?;
     746              : 
     747            0 :     let entry = db_schema_cache
     748            0 :         .get_cached_or_remote(
     749            0 :             &endpoint_cache_key,
     750            0 :             auth_header,
     751            0 :             connection_string,
     752            0 :             &mut client,
     753            0 :             ctx,
     754            0 :             config,
     755            0 :         )
     756            0 :         .await?;
     757            0 :     let (api_config, db_schema_owned) = entry.as_ref();
     758            0 :     let db_schema = db_schema_owned.borrow_schema();
     759              : 
     760            0 :     let db_schemas = &api_config.db_schemas; // list of schemas available for the api
     761            0 :     let db_extra_search_path = &api_config.db_extra_search_path;
     762              :     // TODO: use this when we get a replacement for jsonpath_lib
     763              :     // let role_claim_key = &api_config.role_claim_key;
     764              :     // let role_claim_path = format!("${role_claim_key}");
     765            0 :     let db_anon_role = &api_config.db_anon_role;
     766            0 :     let max_rows = api_config.db_max_rows.as_deref();
     767            0 :     let db_allowed_select_functions = api_config
     768            0 :         .db_allowed_select_functions
     769            0 :         .iter()
     770            0 :         .map(|s| s.as_str())
     771            0 :         .collect::<Vec<_>>();
     772              : 
     773              :     // extract the jwt claims (we'll need them later to set the role and env)
     774            0 :     let jwt_claims = match jwt_parsed.keys {
     775            0 :         ComputeCredentialKeys::JwtPayload(payload_bytes) => {
     776              :             // `payload_bytes` contains the raw JWT payload as Vec<u8>
     777              :             // You can deserialize it back to JSON or parse specific claims
     778            0 :             let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
     779            0 :                 .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     780            0 :             Some(payload)
     781              :         }
     782            0 :         ComputeCredentialKeys::AuthKeys(_) => None,
     783              :     };
     784              : 
     785              :     // read the role from the jwt claims (and set it to the "anon" role if not present)
     786            0 :     let (role, authenticated) = match &jwt_claims {
     787            0 :         Some(claims) => match claims.get("role") {
     788            0 :             Some(JsonValue::String(r)) => (Some(r), true),
     789            0 :             _ => (db_anon_role.as_ref(), true),
     790              :         },
     791            0 :         None => (db_anon_role.as_ref(), false),
     792              :     };
     793              : 
     794              :     // do not allow unauthenticated requests when there is no anonymous role setup
     795            0 :     if let (None, false) = (role, authenticated) {
     796            0 :         return Err(RestError::SubzeroCore(JwtTokenInvalid {
     797            0 :             message: "unauthenticated requests not allowed".to_string(),
     798            0 :         }));
     799            0 :     }
     800              : 
     801              :     // start deconstructing the request because subzero core mostly works with &str
     802            0 :     let method = parts.method;
     803            0 :     let method_str = method.as_str();
     804            0 :     let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
     805              : 
     806              :     // this is actually the table name (or rpc/function_name)
     807              :     // TODO: rename this to something more descriptive
     808            0 :     let root = match parts.uri.path().strip_prefix(api_prefix) {
     809            0 :         Some(p) => Ok(p),
     810            0 :         None => Err(RestError::SubzeroCore(NotFound {
     811            0 :             target: parts.uri.path().to_string(),
     812            0 :         })),
     813            0 :     }?;
     814              : 
     815              :     // pick the current schema from the headers (or the first one from config)
     816            0 :     let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
     817              : 
     818              :     // add the content-profile header to the response
     819            0 :     let mut response_headers = vec![];
     820            0 :     if db_schemas.len() > 1 {
     821            0 :         response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
     822            0 :     }
     823              : 
     824              :     // parse the query string into a Vec<(&str, &str)>
     825            0 :     let query = match parts.uri.query() {
     826            0 :         Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
     827            0 :         None => vec![],
     828              :     };
     829            0 :     let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
     830              : 
     831              :     // convert the headers map to a HashMap<&str, &str>
     832            0 :     let headers: HashMap<&str, &str> = parts
     833            0 :         .headers
     834            0 :         .iter()
     835            0 :         .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
     836            0 :         .collect();
     837              : 
     838            0 :     let cookies = HashMap::new(); // TODO: add cookies
     839              : 
     840              :     // Read the request body (skip for GET requests)
     841            0 :     let body_as_string: Option<String> = if method == Method::GET {
     842            0 :         None
     843              :     } else {
     844            0 :         let body_bytes =
     845            0 :             read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
     846            0 :                 .await
     847            0 :                 .map_err(ReadPayloadError::from)?;
     848            0 :         if body_bytes.is_empty() {
     849            0 :             None
     850              :         } else {
     851            0 :             Some(String::from_utf8_lossy(&body_bytes).into_owned())
     852              :         }
     853              :     };
     854              : 
     855              :     // parse the request into an ApiRequest struct
     856            0 :     let mut api_request = parse(
     857            0 :         schema_name,
     858            0 :         root,
     859            0 :         db_schema,
     860            0 :         method_str,
     861            0 :         path,
     862            0 :         get,
     863            0 :         body_as_string.as_deref(),
     864            0 :         headers,
     865            0 :         cookies,
     866            0 :         max_rows,
     867              :     )
     868            0 :     .map_err(RestError::SubzeroCore)?;
     869              : 
     870            0 :     let role_str = match role {
     871            0 :         Some(r) => r,
     872            0 :         None => "",
     873              :     };
     874              : 
     875            0 :     replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
     876              : 
     877              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     878              :     // in the context of DBX where they need internal permissions
     879              :     // if !disable_internal_permissions {
     880              :     //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
     881              :     // }
     882              : 
     883            0 :     check_safe_functions(&api_request, &db_allowed_select_functions)?;
     884              : 
     885              :     // TODO: this is not relevant when acting as PostgREST but will be useful
     886              :     // in the context of DBX where they need internal permissions
     887              :     // if !disable_internal_permissions {
     888              :     //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
     889              :     // }
     890              : 
     891            0 :     let env_role = Some(role_str);
     892              : 
     893              :     // construct the env (passed in to the sql context as GUCs)
     894            0 :     let empty_json = "{}".to_string();
     895            0 :     let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
     896            0 :     let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
     897            0 :     let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
     898            0 :     let jwt_claims_env = jwt_claims
     899            0 :         .as_ref()
     900            0 :         .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
     901            0 :         .unwrap_or(if let Some(r) = env_role {
     902            0 :             let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
     903            0 :             serde_json::to_string(&claims).unwrap_or(empty_json.clone())
     904              :         } else {
     905            0 :             empty_json.clone()
     906              :         });
     907            0 :     let mut search_path = vec![api_request.schema_name];
     908            0 :     if let Some(extra) = &db_extra_search_path {
     909            0 :         search_path.extend(extra.iter().map(|s| s.as_str()));
     910            0 :     }
     911            0 :     let search_path_str = search_path
     912            0 :         .into_iter()
     913            0 :         .filter(|s| !s.is_empty())
     914            0 :         .collect::<Vec<_>>()
     915            0 :         .join(",");
     916            0 :     let mut env: HashMap<&str, &str> = HashMap::from([
     917            0 :         ("request.method", api_request.method),
     918            0 :         ("request.path", api_request.path),
     919            0 :         ("request.headers", &headers_env),
     920            0 :         ("request.cookies", &cookies_env),
     921            0 :         ("request.get", &get_env),
     922            0 :         ("request.jwt.claims", &jwt_claims_env),
     923            0 :         ("search_path", &search_path_str),
     924            0 :     ]);
     925            0 :     if let Some(r) = env_role {
     926            0 :         env.insert("role", r);
     927            0 :     }
     928              : 
     929              :     // generate the sql statements
     930            0 :     let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
     931            0 :     let (main_statement, main_parameters, _) = generate(fmt_main_query(
     932            0 :         db_schema,
     933            0 :         api_request.schema_name,
     934            0 :         &api_request,
     935            0 :         &env,
     936            0 :     )?);
     937              : 
     938            0 :     let mut headers = vec![
     939            0 :         (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
     940            0 :         (
     941            0 :             &CONN_STRING,
     942            0 :             HeaderValue::from_str(connection_string).expect("invalid connection string"),
     943            0 :         ),
     944            0 :         (&AUTHORIZATION, auth_header.clone()),
     945            0 :         (
     946            0 :             &TXN_ISOLATION_LEVEL,
     947            0 :             HeaderValue::from_static("ReadCommitted"),
     948            0 :         ),
     949            0 :         (&ALLOW_POOL, HEADER_VALUE_TRUE),
     950              :     ];
     951              : 
     952            0 :     if api_request.read_only {
     953            0 :         headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
     954            0 :     }
     955              : 
     956              :     // convert the parameters from subzero core representation to the local proxy repr.
     957            0 :     let req_body = serde_json::to_string(&BatchQueryData {
     958            0 :         queries: vec![
     959              :             QueryData {
     960            0 :                 query: env_statement.into(),
     961            0 :                 params: env_parameters
     962            0 :                     .iter()
     963            0 :                     .map(|p| to_sql_param(&p.to_param()))
     964            0 :                     .collect(),
     965              :             },
     966              :             QueryData {
     967            0 :                 query: main_statement.into(),
     968            0 :                 params: main_parameters
     969            0 :                     .iter()
     970            0 :                     .map(|p| to_sql_param(&p.to_param()))
     971            0 :                     .collect(),
     972              :             },
     973              :         ],
     974              :     })
     975            0 :     .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
     976              : 
     977              :     // todo: map body to count egress
     978            0 :     let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
     979              : 
     980              :     // send the request to the local proxy
     981            0 :     let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
     982            0 :     let (parts, body) = response.into_parts();
     983              : 
     984            0 :     let max_response = config.http_config.max_response_size_bytes;
     985            0 :     let bytes = read_body_with_limit(body, max_response)
     986            0 :         .await
     987            0 :         .map_err(ReadPayloadError::from)?;
     988              : 
     989              :     // if the response status is greater than 399, then it is an error
     990              :     // FIXME: check if there are other error codes or shapes of the response
     991            0 :     if parts.status.as_u16() > 399 {
     992              :         // turn this postgres error from the json into PostgresError
     993            0 :         let postgres_error = serde_json::from_slice(&bytes)
     994            0 :             .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
     995              : 
     996            0 :         return Err(RestError::Postgres(postgres_error));
     997            0 :     }
     998              : 
     999            0 :     #[derive(Deserialize)]
    1000              :     struct QueryResults {
    1001              :         /// we run two queries, so we want only two results.
    1002              :         results: (EnvRows, MainRows),
    1003              :     }
    1004              : 
    1005              :     /// `env_statement` returns nothing of interest to us
    1006            0 :     #[derive(Deserialize)]
    1007              :     struct EnvRows {}
    1008              : 
    1009            0 :     #[derive(Deserialize)]
    1010              :     struct MainRows {
    1011              :         /// `main_statement` only returns a single row.
    1012              :         rows: [MainRow; 1],
    1013              :     }
    1014              : 
    1015            0 :     #[derive(Deserialize)]
    1016              :     struct MainRow {
    1017              :         body: String,
    1018              :         page_total: Option<String>,
    1019              :         total_result_set: Option<String>,
    1020              :         response_headers: Option<String>,
    1021              :         response_status: Option<String>,
    1022              :     }
    1023              : 
    1024            0 :     let results: QueryResults = serde_json::from_slice(&bytes)
    1025            0 :         .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
    1026              : 
    1027              :     let QueryResults {
    1028            0 :         results: (_, MainRows { rows: [row] }),
    1029            0 :     } = results;
    1030              : 
    1031              :     // build the intermediate response object
    1032            0 :     let api_response = ApiResponse {
    1033            0 :         page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
    1034            0 :         total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
    1035              :         top_level_offset: 0, // FIXME: check why this is 0
    1036            0 :         response_headers: row.response_headers,
    1037            0 :         response_status: row.response_status,
    1038            0 :         body: row.body,
    1039              :     };
    1040              : 
    1041              :     // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
    1042              :     // we can not do this in the context of proxy for now
    1043              :     // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
    1044              :     //     // rollback the transaction here
    1045              :     //     return Err(RestError::SubzeroCore(SingularityError {
    1046              :     //         count: api_response.page_total,
    1047              :     //         content_type: "application/vnd.pgrst.object+json".to_string(),
    1048              :     //     }));
    1049              :     // }
    1050              : 
    1051              :     // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
    1052              :     // we can not do this in the context of proxy for now
    1053              :     // if api_request.method == Method::PUT && api_response.page_total != 1 {
    1054              :     //     // Makes sure the querystring pk matches the payload pk
    1055              :     //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
    1056              :     //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
    1057              :     //     // If this condition is not satisfied then nothing is inserted,
    1058              :     //     // rollback the transaction here
    1059              :     //     return Err(RestError::SubzeroCore(PutMatchingPkError));
    1060              :     // }
    1061              : 
    1062              :     // create and return the response to the client
    1063              :     // this section mostly deals with setting the right headers according to PostgREST specs
    1064            0 :     let page_total = api_response.page_total;
    1065            0 :     let total_result_set = api_response.total_result_set;
    1066            0 :     let top_level_offset = api_response.top_level_offset;
    1067            0 :     let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
    1068              :         (SingularJSON, _)
    1069              :         | (
    1070              :             _,
    1071              :             FunctionCall {
    1072              :                 returns_single: true,
    1073              :                 is_scalar: false,
    1074              :                 ..
    1075              :             },
    1076            0 :         ) => SingularJSON,
    1077            0 :         (TextCSV, _) => TextCSV,
    1078            0 :         _ => ApplicationJSON,
    1079              :     };
    1080              : 
    1081              :     // check if the SQL env set some response headers (happens when we called a rpc function)
    1082            0 :     if let Some(response_headers_str) = api_response.response_headers {
    1083            0 :         let Ok(headers_json) =
    1084            0 :             serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
    1085              :         else {
    1086            0 :             return Err(RestError::SubzeroCore(GucHeadersError));
    1087              :         };
    1088              : 
    1089            0 :         response_headers.extend(headers_json.into_iter().flatten());
    1090            0 :     }
    1091              : 
    1092              :     // calculate and set the content range header
    1093            0 :     let lower = top_level_offset as i64;
    1094            0 :     let upper = top_level_offset as i64 + page_total as i64 - 1;
    1095            0 :     let total = total_result_set.map(|t| t as i64);
    1096            0 :     let content_range = match (&method, &api_request.query.node) {
    1097            0 :         (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
    1098            0 :         (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
    1099            0 :         _ => content_range_header(lower, upper, total),
    1100              :     };
    1101            0 :     response_headers.push(("Content-Range".to_string(), content_range));
    1102              : 
    1103              :     // calculate the status code
    1104              :     #[rustfmt::skip]
    1105            0 :     let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
    1106            0 :         (&Method::POST,   Insert { .. }, ..) => 201,
    1107            0 :         (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1108            0 :         (&Method::DELETE, Delete { .. }, ..) => 204,
    1109            0 :         (&Method::PATCH,  Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
    1110            0 :         (&Method::PATCH,  Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1111            0 :         (&Method::PATCH,  Update { .. }, ..) => 204,
    1112            0 :         (&Method::PUT,    Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
    1113            0 :         (&Method::PUT,    Insert { .. }, ..) => 204,
    1114            0 :         _ => content_range_status(lower, upper, total),
    1115              :     };
    1116              : 
    1117              :     // add the preference-applied header
    1118              :     if let Some(Preferences {
    1119            0 :         resolution: Some(r),
    1120              :         ..
    1121            0 :     }) = api_request.preferences
    1122              :     {
    1123            0 :         response_headers.push((
    1124            0 :             "Preference-Applied".to_string(),
    1125            0 :             match r {
    1126            0 :                 MergeDuplicates => "resolution=merge-duplicates".to_string(),
    1127            0 :                 IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
    1128              :             },
    1129              :         ));
    1130            0 :     }
    1131              : 
    1132              :     // check if the SQL env set some response status (happens when we called a rpc function)
    1133            0 :     if let Some(response_status_str) = api_response.response_status {
    1134            0 :         status = response_status_str
    1135            0 :             .parse::<u16>()
    1136            0 :             .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
    1137            0 :     }
    1138              : 
    1139              :     // set the content type header
    1140              :     // TODO: move this to a subzero function
    1141              :     // as_header_value(&self) -> Option<&str>
    1142            0 :     let http_content_type = match response_content_type {
    1143            0 :         SingularJSON => Ok("application/vnd.pgrst.object+json"),
    1144            0 :         TextCSV => Ok("text/csv"),
    1145            0 :         ApplicationJSON => Ok("application/json"),
    1146            0 :         Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
    1147            0 :             message: format!("None of these Content-Types are available: {t}"),
    1148            0 :         })),
    1149            0 :     }?;
    1150              : 
    1151              :     // build the response body
    1152            0 :     let response_body = Full::new(Bytes::from(api_response.body))
    1153            0 :         .map_err(|never| match never {})
    1154            0 :         .boxed();
    1155              : 
    1156              :     // build the response
    1157            0 :     let mut response = Response::builder()
    1158            0 :         .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
    1159            0 :         .header(CONTENT_TYPE, http_content_type);
    1160              : 
    1161              :     // Add all headers from response_headers vector
    1162            0 :     for (header_name, header_value) in response_headers {
    1163            0 :         response = response.header(header_name, header_value);
    1164            0 :     }
    1165              : 
    1166              :     // add the body and return the response
    1167            0 :     response.body(response_body).map_err(|_| {
    1168            0 :         RestError::SubzeroCore(InternalError {
    1169            0 :             message: "Failed to build response".to_string(),
    1170            0 :         })
    1171            0 :     })
    1172            0 : }
        

Generated by: LCOV version 2.1-beta