Line data Source code
1 : use std::borrow::Cow;
2 : use std::collections::HashMap;
3 : use std::convert::Infallible;
4 : use std::sync::Arc;
5 :
6 : use bytes::Bytes;
7 : use http::Method;
8 : use http::header::{
9 : ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
10 : ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
11 : AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN, VARY,
12 : };
13 : use http_body_util::combinators::BoxBody;
14 : use http_body_util::{BodyExt, Empty, Full};
15 : use http_utils::error::ApiError;
16 : use hyper::body::Incoming;
17 : use hyper::http::response::Builder;
18 : use hyper::http::{HeaderMap, HeaderName, HeaderValue};
19 : use hyper::{Request, Response, StatusCode};
20 : use indexmap::IndexMap;
21 : use moka::sync::Cache;
22 : use ouroboros::self_referencing;
23 : use serde::de::DeserializeOwned;
24 : use serde::{Deserialize, Deserializer};
25 : use serde_json::Value as JsonValue;
26 : use serde_json::value::RawValue;
27 : use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV};
28 : use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update};
29 : use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates};
30 : use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal};
31 : use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key};
32 : use subzero_core::dynamic_statement::{JoinIterator, param, sql};
33 : use subzero_core::error::Error::{
34 : self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
35 : JsonDeserialize, JwtTokenInvalid, NotFound,
36 : };
37 : use subzero_core::error::pg_error_to_status_code;
38 : use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned};
39 : use subzero_core::formatter::postgresql::{fmt_main_query, generate};
40 : use subzero_core::formatter::{Param, Snippet, SqlParam};
41 : use subzero_core::parser::postgrest::parse;
42 : use subzero_core::permissions::{check_safe_functions, replace_select_star};
43 : use subzero_core::schema::{
44 : DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query,
45 : };
46 : use subzero_core::{content_range_header, content_range_status};
47 : use tokio_util::sync::CancellationToken;
48 : use tracing::{error, info};
49 : use typed_json::json;
50 : use url::form_urlencoded;
51 :
52 : use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
53 : use super::conn_pool::AuthData;
54 : use super::conn_pool_lib::ConnInfo;
55 : use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
56 : use super::http_conn_pool::{self, LocalProxyClient};
57 : use super::http_util::{
58 : ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
59 : get_conn_info, json_response, uuid_to_header_value,
60 : };
61 : use super::json::JsonConversionError;
62 : use crate::auth::backend::ComputeCredentialKeys;
63 : use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener};
64 : use crate::config::ProxyConfig;
65 : use crate::context::RequestContext;
66 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
67 : use crate::http::read_body_with_limit;
68 : use crate::metrics::{CacheKind, Metrics};
69 : use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
70 : use crate::types::EndpointCacheKey;
71 : use crate::util::deserialize_json_string;
72 :
73 : static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
74 : const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
75 : const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
76 : // CORS headers values
77 : const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
78 : HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
79 : const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
80 : const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
81 : "Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
82 : );
83 : const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
84 : const ACCESS_CONTROL_VARY_VALUE: HeaderValue = HeaderValue::from_static("Origin");
85 :
86 : // A wrapper around the DbSchema that allows for self-referencing
87 : #[self_referencing]
88 : pub struct DbSchemaOwned {
89 : schema_string: String,
90 : #[covariant]
91 : #[borrows(schema_string)]
92 : schema: DbSchema<'this>,
93 : }
94 :
95 : impl<'de> Deserialize<'de> for DbSchemaOwned {
96 0 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
97 0 : where
98 0 : D: Deserializer<'de>,
99 : {
100 0 : let s = String::deserialize(deserializer)?;
101 0 : DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
102 0 : .map_err(<D::Error as serde::de::Error>::custom)
103 0 : }
104 : }
105 :
106 0 : fn split_comma_separated(s: &str) -> Vec<String> {
107 0 : s.split(',').map(|s| s.trim().to_string()).collect()
108 0 : }
109 :
110 0 : fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
111 0 : where
112 0 : D: Deserializer<'de>,
113 : {
114 0 : let s = String::deserialize(deserializer)?;
115 0 : Ok(split_comma_separated(&s))
116 0 : }
117 :
118 0 : fn deserialize_comma_separated_option<'de, D>(
119 0 : deserializer: D,
120 0 : ) -> Result<Option<Vec<String>>, D::Error>
121 0 : where
122 0 : D: Deserializer<'de>,
123 : {
124 0 : let opt = Option::<String>::deserialize(deserializer)?;
125 0 : if let Some(s) = &opt {
126 0 : let trimmed = s.trim();
127 0 : if trimmed.is_empty() {
128 0 : return Ok(None);
129 0 : }
130 0 : return Ok(Some(split_comma_separated(trimmed)));
131 0 : }
132 0 : Ok(None)
133 0 : }
134 :
135 : // The ApiConfig is the configuration for the API per endpoint
136 : // The configuration is read from the database and cached in the DbSchemaCache
137 : #[derive(Deserialize, Debug)]
138 : pub struct ApiConfig {
139 : #[serde(
140 : default = "db_schemas",
141 : deserialize_with = "deserialize_comma_separated"
142 : )]
143 : pub db_schemas: Vec<String>,
144 : pub db_anon_role: Option<String>,
145 : pub db_max_rows: Option<String>,
146 : #[serde(default = "db_allowed_select_functions")]
147 : pub db_allowed_select_functions: Vec<String>,
148 : // #[serde(deserialize_with = "to_tuple", default)]
149 : // pub db_pre_request: Option<(String, String)>,
150 : #[allow(dead_code)]
151 : #[serde(default = "role_claim_key")]
152 : pub role_claim_key: String,
153 : #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
154 : pub db_extra_search_path: Option<Vec<String>>,
155 : #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
156 : pub server_cors_allowed_origins: Option<Vec<String>>,
157 : }
158 :
159 : // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
160 : pub(crate) struct DbSchemaCache(Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
161 : impl DbSchemaCache {
162 0 : pub fn new(config: crate::config::CacheOptions) -> Self {
163 0 : let builder = Cache::builder().name("schema");
164 0 : let builder = config.moka(builder);
165 :
166 0 : let metrics = &Metrics::get().cache;
167 0 : if let Some(size) = config.size {
168 0 : metrics.capacity.set(CacheKind::Schema, size as i64);
169 0 : }
170 :
171 0 : let builder =
172 0 : builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause));
173 :
174 0 : Self(builder.build())
175 0 : }
176 :
177 0 : pub async fn maintain(&self) -> Result<Infallible, anyhow::Error> {
178 0 : let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
179 : loop {
180 0 : ticker.tick().await;
181 0 : self.0.run_pending_tasks();
182 : }
183 : }
184 :
185 0 : pub fn get_cached(
186 0 : &self,
187 0 : endpoint_id: &EndpointCacheKey,
188 0 : ) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
189 0 : count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
190 0 : }
191 0 : pub async fn get_remote(
192 0 : &self,
193 0 : endpoint_id: &EndpointCacheKey,
194 0 : auth_header: &HeaderValue,
195 0 : connection_string: &str,
196 0 : client: &mut http_conn_pool::Client<LocalProxyClient>,
197 0 : ctx: &RequestContext,
198 0 : config: &'static ProxyConfig,
199 0 : ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
200 0 : info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
201 0 : let remote_value = self
202 0 : .internal_get_remote(auth_header, connection_string, client, ctx, config)
203 0 : .await;
204 0 : let (api_config, schema_owned) = match remote_value {
205 0 : Ok((api_config, schema_owned)) => (api_config, schema_owned),
206 0 : Err(e @ RestError::SchemaTooLarge) => {
207 : // for the case where the schema is too large, we cache an empty dummy value
208 : // all the other requests will fail without triggering the introspection query
209 0 : let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
210 0 : .map_err(|e| JsonDeserialize { source: e })?;
211 :
212 0 : let api_config = ApiConfig {
213 0 : db_schemas: vec![],
214 0 : db_anon_role: None,
215 0 : db_max_rows: None,
216 0 : db_allowed_select_functions: vec![],
217 0 : role_claim_key: String::new(),
218 0 : db_extra_search_path: None,
219 0 : server_cors_allowed_origins: None,
220 0 : };
221 0 : let value = Arc::new((api_config, schema_owned));
222 0 : count_cache_insert(CacheKind::Schema);
223 0 : self.0.insert(endpoint_id.clone(), value);
224 0 : return Err(e);
225 : }
226 0 : Err(e) => {
227 0 : return Err(e);
228 : }
229 : };
230 0 : let value = Arc::new((api_config, schema_owned));
231 0 : count_cache_insert(CacheKind::Schema);
232 0 : self.0.insert(endpoint_id.clone(), value.clone());
233 0 : Ok(value)
234 0 : }
235 0 : async fn internal_get_remote(
236 0 : &self,
237 0 : auth_header: &HeaderValue,
238 0 : connection_string: &str,
239 0 : client: &mut http_conn_pool::Client<LocalProxyClient>,
240 0 : ctx: &RequestContext,
241 0 : config: &'static ProxyConfig,
242 0 : ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
243 0 : #[derive(Deserialize)]
244 : struct SingleRow<Row> {
245 : rows: [Row; 1],
246 : }
247 :
248 : #[derive(Deserialize)]
249 : struct ConfigRow {
250 : #[serde(deserialize_with = "deserialize_json_string")]
251 : config: ApiConfig,
252 : }
253 :
254 0 : #[derive(Deserialize)]
255 : struct SchemaRow {
256 : json_schema: DbSchemaOwned,
257 : }
258 :
259 0 : let headers = vec![
260 0 : (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
261 0 : (
262 0 : &CONN_STRING,
263 0 : HeaderValue::from_str(connection_string).expect(
264 0 : "connection string came from a header, so it must be a valid headervalue",
265 0 : ),
266 0 : ),
267 0 : (&AUTHORIZATION, auth_header.clone()),
268 0 : (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
269 : ];
270 :
271 0 : let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
272 : let SingleRow {
273 0 : rows: [ConfigRow { config: api_config }],
274 0 : } = make_local_proxy_request(
275 0 : client,
276 0 : headers.iter().cloned(),
277 0 : QueryData {
278 0 : query: Cow::Owned(query),
279 0 : params: vec![],
280 0 : },
281 0 : config.rest_config.max_schema_size,
282 0 : )
283 0 : .await
284 0 : .map_err(|e| match e {
285 : RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
286 0 : RestError::SchemaTooLarge
287 : }
288 0 : e => e,
289 0 : })?;
290 :
291 : // now that we have the api_config let's run the second INTROSPECTION_SQL query
292 : let SingleRow {
293 0 : rows: [SchemaRow { json_schema }],
294 0 : } = make_local_proxy_request(
295 0 : client,
296 0 : headers,
297 0 : QueryData {
298 0 : query: INTROSPECTION_SQL.into(),
299 0 : params: vec![
300 0 : serde_json::to_value(&api_config.db_schemas)
301 0 : .expect("Vec<String> is always valid to encode as JSON"),
302 0 : JsonValue::Bool(false), // include_roles_with_login
303 0 : JsonValue::Bool(false), // use_internal_permissions
304 0 : ],
305 0 : },
306 0 : config.rest_config.max_schema_size,
307 0 : )
308 0 : .await
309 0 : .map_err(|e| match e {
310 : RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
311 0 : RestError::SchemaTooLarge
312 : }
313 0 : e => e,
314 0 : })?;
315 :
316 0 : Ok((api_config, json_schema))
317 0 : }
318 : }
319 :
320 : // A type to represent a postgresql errors
321 : // we use our own type (instead of postgres_client::Error) because we get the error from the json response
322 0 : #[derive(Debug, thiserror::Error, Deserialize)]
323 : pub(crate) struct PostgresError {
324 : pub code: String,
325 : pub message: String,
326 : pub detail: Option<String>,
327 : pub hint: Option<String>,
328 : }
329 : impl HttpCodeError for PostgresError {
330 0 : fn get_http_status_code(&self) -> StatusCode {
331 0 : let status = pg_error_to_status_code(&self.code, true);
332 0 : StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
333 0 : }
334 : }
335 : impl ReportableError for PostgresError {
336 0 : fn get_error_kind(&self) -> ErrorKind {
337 0 : ErrorKind::User
338 0 : }
339 : }
340 : impl UserFacingError for PostgresError {
341 0 : fn to_string_client(&self) -> String {
342 0 : if self.code.starts_with("PT") {
343 0 : "Postgres error".to_string()
344 : } else {
345 0 : self.message.clone()
346 : }
347 0 : }
348 : }
349 : impl std::fmt::Display for PostgresError {
350 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 0 : write!(f, "{}", self.message)
352 0 : }
353 : }
354 :
355 : // A type to represent errors that can occur in the rest broker
356 : #[derive(Debug, thiserror::Error)]
357 : pub(crate) enum RestError {
358 : #[error(transparent)]
359 : ReadPayload(#[from] ReadPayloadError),
360 : #[error(transparent)]
361 : ConnectCompute(#[from] HttpConnError),
362 : #[error(transparent)]
363 : ConnInfo(#[from] ConnInfoError),
364 : #[error(transparent)]
365 : Postgres(#[from] PostgresError),
366 : #[error(transparent)]
367 : JsonConversion(#[from] JsonConversionError),
368 : #[error(transparent)]
369 : SubzeroCore(#[from] SubzeroCoreError),
370 : #[error("schema is too large")]
371 : SchemaTooLarge,
372 : }
373 : impl ReportableError for RestError {
374 0 : fn get_error_kind(&self) -> ErrorKind {
375 0 : match self {
376 0 : RestError::ReadPayload(e) => e.get_error_kind(),
377 0 : RestError::ConnectCompute(e) => e.get_error_kind(),
378 0 : RestError::ConnInfo(e) => e.get_error_kind(),
379 0 : RestError::Postgres(_) => ErrorKind::Postgres,
380 0 : RestError::JsonConversion(_) => ErrorKind::Postgres,
381 0 : RestError::SubzeroCore(_) => ErrorKind::User,
382 0 : RestError::SchemaTooLarge => ErrorKind::User,
383 : }
384 0 : }
385 : }
386 : impl UserFacingError for RestError {
387 0 : fn to_string_client(&self) -> String {
388 0 : match self {
389 0 : RestError::ReadPayload(p) => p.to_string(),
390 0 : RestError::ConnectCompute(c) => c.to_string_client(),
391 0 : RestError::ConnInfo(c) => c.to_string_client(),
392 0 : RestError::SchemaTooLarge => self.to_string(),
393 0 : RestError::Postgres(p) => p.to_string_client(),
394 0 : RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
395 0 : RestError::SubzeroCore(s) => {
396 : // TODO: this is a hack to get the message from the json body
397 0 : let json = s.json_body();
398 0 : let default_message = "Unknown error".to_string();
399 :
400 0 : json.get("message")
401 0 : .map_or(default_message.clone(), |m| match m {
402 0 : JsonValue::String(s) => s.clone(),
403 0 : _ => default_message,
404 0 : })
405 : }
406 : }
407 0 : }
408 : }
409 : impl HttpCodeError for RestError {
410 0 : fn get_http_status_code(&self) -> StatusCode {
411 0 : match self {
412 0 : RestError::ReadPayload(e) => e.get_http_status_code(),
413 0 : RestError::ConnectCompute(h) => match h.get_error_kind() {
414 0 : ErrorKind::User => StatusCode::BAD_REQUEST,
415 0 : _ => StatusCode::INTERNAL_SERVER_ERROR,
416 : },
417 0 : RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
418 0 : RestError::Postgres(e) => e.get_http_status_code(),
419 0 : RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
420 0 : RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
421 0 : RestError::SubzeroCore(e) => {
422 0 : let status = e.status_code();
423 0 : StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
424 : }
425 : }
426 0 : }
427 : }
428 :
429 : // Helper functions for the rest broker
430 :
431 0 : fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
432 : "select "
433 0 : + if env.is_empty() {
434 0 : sql("null")
435 : } else {
436 0 : env.iter()
437 0 : .map(|(k, v)| {
438 0 : "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
439 0 : })
440 0 : .join(",")
441 : }
442 0 : }
443 :
444 : // TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
445 0 : fn to_sql_param(p: &Param) -> JsonValue {
446 0 : match p {
447 0 : SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
448 0 : Str(v) => JsonValue::String((*v).to_string()),
449 0 : StrOwned(v) => JsonValue::String((*v).clone()),
450 0 : PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
451 0 : LV(ListVal(v, ..)) => {
452 0 : if v.is_empty() {
453 0 : JsonValue::String(r"{}".to_string())
454 : } else {
455 0 : JsonValue::String(format!(
456 0 : "{{\"{}\"}}",
457 0 : v.iter()
458 0 : .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
459 0 : .collect::<Vec<_>>()
460 0 : .join("\",\"")
461 : ))
462 : }
463 : }
464 : }
465 0 : }
466 :
467 : #[derive(serde::Serialize)]
468 : struct QueryData<'a> {
469 : query: Cow<'a, str>,
470 : params: Vec<JsonValue>,
471 : }
472 :
473 : #[derive(serde::Serialize)]
474 : struct BatchQueryData<'a> {
475 : queries: Vec<QueryData<'a>>,
476 : }
477 :
478 0 : async fn make_local_proxy_request<S: DeserializeOwned>(
479 0 : client: &mut http_conn_pool::Client<LocalProxyClient>,
480 0 : headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
481 0 : body: QueryData<'_>,
482 0 : max_len: usize,
483 0 : ) -> Result<S, RestError> {
484 0 : let body_string = serde_json::to_string(&body)
485 0 : .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
486 :
487 0 : let response = make_raw_local_proxy_request(client, headers, body_string).await?;
488 :
489 0 : let response_status = response.status();
490 :
491 0 : if response_status != StatusCode::OK {
492 0 : return Err(RestError::SubzeroCore(InternalError {
493 0 : message: "Failed to get endpoint schema".to_string(),
494 0 : }));
495 0 : }
496 :
497 : // Capture the response body
498 0 : let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
499 0 : .await
500 0 : .map_err(ReadPayloadError::from)?;
501 :
502 : // Parse the JSON response
503 0 : let response_json: S = serde_json::from_slice(&response_body)
504 0 : .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
505 :
506 0 : Ok(response_json)
507 0 : }
508 :
509 0 : async fn make_raw_local_proxy_request(
510 0 : client: &mut http_conn_pool::Client<LocalProxyClient>,
511 0 : headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
512 0 : body: String,
513 0 : ) -> Result<Response<Incoming>, RestError> {
514 0 : let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
515 0 : let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
516 0 : let req_headers = req.headers_mut().expect("failed to get headers");
517 : // Add all provided headers to the request
518 0 : for (header_name, header_value) in headers {
519 0 : req_headers.insert(header_name, header_value.clone());
520 0 : }
521 :
522 0 : let body_boxed = Full::new(Bytes::from(body))
523 0 : .map_err(|never| match never {}) // Convert Infallible to hyper::Error
524 0 : .boxed();
525 :
526 0 : let req = req.body(body_boxed).map_err(|_| {
527 0 : RestError::SubzeroCore(InternalError {
528 0 : message: "Failed to build request".to_string(),
529 0 : })
530 0 : })?;
531 :
532 : // Send the request to the local proxy
533 0 : client
534 0 : .inner
535 0 : .inner
536 0 : .send_request(req)
537 0 : .await
538 0 : .map_err(LocalProxyConnError::from)
539 0 : .map_err(HttpConnError::from)
540 0 : .map_err(RestError::from)
541 0 : }
542 :
543 0 : pub(crate) async fn handle(
544 0 : config: &'static ProxyConfig,
545 0 : ctx: RequestContext,
546 0 : request: Request<Incoming>,
547 0 : backend: Arc<PoolingBackend>,
548 0 : cancel: CancellationToken,
549 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
550 0 : let result = handle_inner(cancel, config, &ctx, request, backend).await;
551 :
552 0 : let response = match result {
553 0 : Ok(r) => {
554 0 : ctx.set_success();
555 :
556 : // Handling the error response from local proxy here
557 0 : if r.status().is_server_error() {
558 0 : let status = r.status();
559 :
560 0 : let body_bytes = r
561 0 : .collect()
562 0 : .await
563 0 : .map_err(|e| {
564 0 : ApiError::InternalServerError(anyhow::Error::msg(format!(
565 0 : "could not collect http body: {e}"
566 0 : )))
567 0 : })?
568 0 : .to_bytes();
569 :
570 0 : if let Ok(mut json_map) =
571 0 : serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
572 : {
573 0 : let message = json_map.get("message");
574 0 : if let Some(message) = message {
575 0 : let msg: String = match serde_json::from_str(message.get()) {
576 0 : Ok(msg) => msg,
577 : Err(_) => {
578 0 : "Unable to parse the response message from server".to_string()
579 : }
580 : };
581 :
582 0 : error!("Error response from local_proxy: {status} {msg}");
583 :
584 0 : json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
585 :
586 0 : let resp_json = serde_json::to_string(&json_map)
587 0 : .unwrap_or("failed to serialize the response message".to_string());
588 :
589 0 : return json_response(status, resp_json);
590 0 : }
591 0 : }
592 :
593 0 : error!("Unable to parse the response message from local_proxy");
594 0 : return json_response(
595 0 : status,
596 0 : json!({ "message": "Unable to parse the response message from server".to_string() }),
597 : );
598 0 : }
599 0 : r
600 : }
601 0 : Err(e @ RestError::SubzeroCore(_)) => {
602 0 : let error_kind = e.get_error_kind();
603 0 : ctx.set_error_kind(error_kind);
604 :
605 0 : tracing::info!(
606 0 : kind=error_kind.to_metric_label(),
607 : error=%e,
608 : msg="subzero core error",
609 0 : "forwarding error to user"
610 : );
611 :
612 0 : let RestError::SubzeroCore(subzero_err) = e else {
613 0 : panic!("expected subzero core error")
614 : };
615 :
616 0 : let json_body = subzero_err.json_body();
617 0 : let status_code = StatusCode::from_u16(subzero_err.status_code())
618 0 : .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
619 :
620 0 : json_response(status_code, json_body)?
621 : }
622 0 : Err(e) => {
623 0 : let error_kind = e.get_error_kind();
624 0 : ctx.set_error_kind(error_kind);
625 :
626 0 : let message = e.to_string_client();
627 0 : let status_code = e.get_http_status_code();
628 :
629 0 : tracing::info!(
630 0 : kind=error_kind.to_metric_label(),
631 : error=%e,
632 : msg=message,
633 0 : "forwarding error to user"
634 : );
635 :
636 0 : let (code, detail, hint) = match e {
637 0 : RestError::Postgres(e) => (
638 0 : if e.code.starts_with("PT") {
639 0 : None
640 : } else {
641 0 : Some(e.code)
642 : },
643 0 : e.detail,
644 0 : e.hint,
645 : ),
646 0 : _ => (None, None, None),
647 : };
648 :
649 0 : json_response(
650 0 : status_code,
651 0 : json!({
652 0 : "message": message,
653 0 : "code": code,
654 0 : "detail": detail,
655 0 : "hint": hint,
656 : }),
657 0 : )?
658 : }
659 : };
660 :
661 0 : Ok(response)
662 0 : }
663 :
664 0 : async fn handle_inner(
665 0 : _cancel: CancellationToken,
666 0 : config: &'static ProxyConfig,
667 0 : ctx: &RequestContext,
668 0 : request: Request<Incoming>,
669 0 : backend: Arc<PoolingBackend>,
670 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
671 0 : let _requeset_gauge = Metrics::get()
672 0 : .proxy
673 0 : .connection_requests
674 0 : .guard(ctx.protocol());
675 0 : info!(
676 0 : protocol = %ctx.protocol(),
677 0 : "handling interactive connection from client"
678 : );
679 :
680 : // Read host from Host, then URI host as fallback
681 : // TODO: will this be a problem if behind a load balancer?
682 : // TODO: can we use the x-forwarded-host header?
683 0 : let host = request
684 0 : .headers()
685 0 : .get(HOST)
686 0 : .and_then(|v| v.to_str().ok())
687 0 : .unwrap_or_else(|| request.uri().host().unwrap_or(""));
688 :
689 : // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
690 0 : let database_name = request
691 0 : .uri()
692 0 : .path()
693 0 : .split('/')
694 0 : .nth(1)
695 0 : .ok_or(RestError::SubzeroCore(NotFound {
696 0 : target: request.uri().path().to_string(),
697 0 : }))?;
698 :
699 : // we always use the authenticator role to connect to the database
700 0 : let authenticator_role = "authenticator";
701 :
702 : // Strip the hostname prefix from the host to get the database hostname
703 0 : let database_host = host.replace(&config.rest_config.hostname_prefix, "");
704 :
705 0 : let connection_string =
706 0 : format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
707 :
708 0 : let conn_info = get_conn_info(
709 0 : &config.authentication_config,
710 0 : ctx,
711 0 : Some(&connection_string),
712 0 : request.headers(),
713 0 : )?;
714 0 : info!(
715 0 : user = conn_info.conn_info.user_info.user.as_str(),
716 0 : "credentials"
717 : );
718 :
719 0 : match conn_info.auth {
720 0 : AuthData::Jwt(jwt) => {
721 0 : let api_prefix = format!("/{database_name}/rest/v1/");
722 0 : handle_rest_inner(
723 0 : config,
724 0 : ctx,
725 0 : &api_prefix,
726 0 : request,
727 0 : &connection_string,
728 0 : conn_info.conn_info,
729 0 : jwt,
730 0 : backend,
731 0 : )
732 0 : .await
733 : }
734 0 : AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
735 0 : Credentials::BearerJwt,
736 0 : ))),
737 : }
738 0 : }
739 :
740 0 : fn apply_common_cors_headers(
741 0 : response: &mut Builder,
742 0 : request_headers: &HeaderMap,
743 0 : allowed_origins: Option<&Vec<String>>,
744 0 : ) {
745 0 : let request_origin = request_headers
746 0 : .get(ORIGIN)
747 0 : .map(|v| v.to_str().unwrap_or(""));
748 :
749 0 : let response_allow_origin = match (request_origin, allowed_origins) {
750 0 : (Some(or), Some(allowed_origins)) => {
751 0 : if allowed_origins.iter().any(|o| o == or) {
752 0 : Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
753 : } else {
754 0 : None
755 : }
756 : }
757 : // FIXME!: on the first request, when we don't have a cached entry for the config,
758 : // allowed_origins is None, so we allow all origins for now but we should fix this
759 0 : (Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
760 0 : _ => None,
761 : };
762 0 : if let Some(h) = response.headers_mut() {
763 0 : h.insert(
764 0 : ACCESS_CONTROL_EXPOSE_HEADERS,
765 0 : ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
766 : );
767 0 : if let Some(origin) = response_allow_origin {
768 0 : if origin != HEADER_VALUE_ALLOW_ALL_ORIGINS {
769 0 : h.insert(VARY, ACCESS_CONTROL_VARY_VALUE);
770 0 : }
771 0 : h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
772 0 : }
773 0 : }
774 0 : }
775 :
776 : #[allow(clippy::too_many_arguments)]
777 0 : async fn handle_rest_inner(
778 0 : config: &'static ProxyConfig,
779 0 : ctx: &RequestContext,
780 0 : api_prefix: &str,
781 0 : request: Request<Incoming>,
782 0 : connection_string: &str,
783 0 : conn_info: ConnInfo,
784 0 : jwt: String,
785 0 : backend: Arc<PoolingBackend>,
786 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
787 0 : let db_schema_cache =
788 0 : config
789 0 : .rest_config
790 0 : .db_schema_cache
791 0 : .as_ref()
792 0 : .ok_or(RestError::SubzeroCore(InternalError {
793 0 : message: "DB schema cache is not configured".to_string(),
794 0 : }))?;
795 :
796 0 : let endpoint_cache_key = conn_info
797 0 : .endpoint_cache_key()
798 0 : .ok_or(RestError::SubzeroCore(InternalError {
799 0 : message: "Failed to get endpoint cache key".to_string(),
800 0 : }))?;
801 :
802 0 : let (parts, originial_body) = request.into_parts();
803 :
804 : // try and get the cached entry for this endpoint
805 : // it contains the api config and the introspected db schema
806 0 : let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
807 :
808 0 : let allowed_origins = cached_entry
809 0 : .as_ref()
810 0 : .and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
811 :
812 0 : let mut response = Response::builder();
813 0 : apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
814 :
815 : // handle the OPTIONS request
816 0 : if parts.method == Method::OPTIONS {
817 0 : let allowed_headers = parts
818 0 : .headers
819 0 : .get(ACCESS_CONTROL_REQUEST_HEADERS)
820 0 : .and_then(|a| a.to_str().ok())
821 0 : .filter(|v| !v.is_empty())
822 0 : .map_or_else(
823 0 : || "Authorization".to_string(),
824 0 : |v| format!("{v}, Authorization"),
825 : );
826 0 : return response
827 0 : .status(StatusCode::OK)
828 0 : .header(
829 0 : ACCESS_CONTROL_ALLOW_METHODS,
830 0 : ACCESS_CONTROL_ALLOW_METHODS_VALUE,
831 : )
832 0 : .header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
833 0 : .header(
834 0 : ACCESS_CONTROL_ALLOW_HEADERS,
835 0 : HeaderValue::from_str(&allowed_headers)
836 0 : .unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
837 : )
838 0 : .header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
839 0 : .body(Empty::new().map_err(|x| match x {}).boxed())
840 0 : .map_err(|e| {
841 0 : RestError::SubzeroCore(InternalError {
842 0 : message: e.to_string(),
843 0 : })
844 0 : });
845 0 : }
846 :
847 : // validate the jwt token
848 0 : let jwt_parsed = backend
849 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
850 0 : .await
851 0 : .map_err(HttpConnError::from)?;
852 :
853 0 : let auth_header = parts
854 0 : .headers
855 0 : .get(AUTHORIZATION)
856 0 : .ok_or(RestError::SubzeroCore(InternalError {
857 0 : message: "Authorization header is required".to_string(),
858 0 : }))?;
859 0 : let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
860 :
861 0 : let entry = match cached_entry {
862 0 : Some(e) => e,
863 : None => {
864 : // if not cached, get the remote entry (will run the introspection query)
865 0 : db_schema_cache
866 0 : .get_remote(
867 0 : &endpoint_cache_key,
868 0 : auth_header,
869 0 : connection_string,
870 0 : &mut client,
871 0 : ctx,
872 0 : config,
873 0 : )
874 0 : .await?
875 : }
876 : };
877 0 : let (api_config, db_schema_owned) = entry.as_ref();
878 :
879 0 : let db_schema = db_schema_owned.borrow_schema();
880 :
881 0 : let db_schemas = &api_config.db_schemas; // list of schemas available for the api
882 0 : let db_extra_search_path = &api_config.db_extra_search_path;
883 : // TODO: use this when we get a replacement for jsonpath_lib
884 : // let role_claim_key = &api_config.role_claim_key;
885 : // let role_claim_path = format!("${role_claim_key}");
886 0 : let db_anon_role = &api_config.db_anon_role;
887 0 : let max_rows = api_config.db_max_rows.as_deref();
888 0 : let db_allowed_select_functions = api_config
889 0 : .db_allowed_select_functions
890 0 : .iter()
891 0 : .map(|s| s.as_str())
892 0 : .collect::<Vec<_>>();
893 :
894 : // extract the jwt claims (we'll need them later to set the role and env)
895 0 : let jwt_claims = match jwt_parsed.keys {
896 0 : ComputeCredentialKeys::JwtPayload(payload_bytes) => {
897 : // `payload_bytes` contains the raw JWT payload as Vec<u8>
898 : // You can deserialize it back to JSON or parse specific claims
899 0 : let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
900 0 : .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
901 0 : Some(payload)
902 : }
903 0 : ComputeCredentialKeys::AuthKeys(_) => None,
904 : };
905 :
906 : // read the role from the jwt claims (and set it to the "anon" role if not present)
907 0 : let (role, authenticated) = match &jwt_claims {
908 0 : Some(claims) => match claims.get("role") {
909 0 : Some(JsonValue::String(r)) => (Some(r), true),
910 0 : _ => (db_anon_role.as_ref(), true),
911 : },
912 0 : None => (db_anon_role.as_ref(), false),
913 : };
914 :
915 : // do not allow unauthenticated requests when there is no anonymous role setup
916 0 : if let (None, false) = (role, authenticated) {
917 0 : return Err(RestError::SubzeroCore(JwtTokenInvalid {
918 0 : message: "unauthenticated requests not allowed".to_string(),
919 0 : }));
920 0 : }
921 :
922 : // start deconstructing the request because subzero core mostly works with &str
923 0 : let method = parts.method;
924 0 : let method_str = method.as_str();
925 0 : let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
926 :
927 : // this is actually the table name (or rpc/function_name)
928 : // TODO: rename this to something more descriptive
929 0 : let root = match parts.uri.path().strip_prefix(api_prefix) {
930 0 : Some(p) => Ok(p),
931 0 : None => Err(RestError::SubzeroCore(NotFound {
932 0 : target: parts.uri.path().to_string(),
933 0 : })),
934 0 : }?;
935 :
936 : // pick the current schema from the headers (or the first one from config)
937 0 : let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
938 :
939 : // add the content-profile header to the response
940 0 : let mut response_headers = vec![];
941 0 : if db_schemas.len() > 1 {
942 0 : response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
943 0 : }
944 :
945 : // parse the query string into a Vec<(&str, &str)>
946 0 : let query = match parts.uri.query() {
947 0 : Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
948 0 : None => vec![],
949 : };
950 0 : let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
951 :
952 : // convert the headers map to a HashMap<&str, &str>
953 0 : let headers: HashMap<&str, &str> = parts
954 0 : .headers
955 0 : .iter()
956 0 : .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
957 0 : .collect();
958 :
959 0 : let cookies = HashMap::new(); // TODO: add cookies
960 :
961 : // Read the request body (skip for GET requests)
962 0 : let body_as_string: Option<String> = if method == Method::GET {
963 0 : None
964 : } else {
965 0 : let body_bytes =
966 0 : read_body_with_limit(originial_body, config.http_config.max_request_size_bytes)
967 0 : .await
968 0 : .map_err(ReadPayloadError::from)?;
969 0 : if body_bytes.is_empty() {
970 0 : None
971 : } else {
972 0 : Some(String::from_utf8_lossy(&body_bytes).into_owned())
973 : }
974 : };
975 :
976 : // parse the request into an ApiRequest struct
977 0 : let mut api_request = parse(
978 0 : schema_name,
979 0 : root,
980 0 : db_schema,
981 0 : method_str,
982 0 : path,
983 0 : get,
984 0 : body_as_string.as_deref(),
985 0 : headers,
986 0 : cookies,
987 0 : max_rows,
988 : )
989 0 : .map_err(RestError::SubzeroCore)?;
990 :
991 0 : let role_str = match role {
992 0 : Some(r) => r,
993 0 : None => "",
994 : };
995 :
996 0 : replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
997 :
998 : // TODO: this is not relevant when acting as PostgREST but will be useful
999 : // in the context of DBX where they need internal permissions
1000 : // if !disable_internal_permissions {
1001 : // check_privileges(db_schema, schema_name, role_str, &api_request)?;
1002 : // }
1003 :
1004 0 : check_safe_functions(&api_request, &db_allowed_select_functions)?;
1005 :
1006 : // TODO: this is not relevant when acting as PostgREST but will be useful
1007 : // in the context of DBX where they need internal permissions
1008 : // if !disable_internal_permissions {
1009 : // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
1010 : // }
1011 :
1012 0 : let env_role = Some(role_str);
1013 :
1014 : // construct the env (passed in to the sql context as GUCs)
1015 0 : let empty_json = "{}".to_string();
1016 0 : let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
1017 0 : let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
1018 0 : let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
1019 0 : let jwt_claims_env = jwt_claims
1020 0 : .as_ref()
1021 0 : .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
1022 0 : .unwrap_or(if let Some(r) = env_role {
1023 0 : let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
1024 0 : serde_json::to_string(&claims).unwrap_or(empty_json.clone())
1025 : } else {
1026 0 : empty_json.clone()
1027 : });
1028 0 : let mut search_path = vec![api_request.schema_name];
1029 0 : if let Some(extra) = &db_extra_search_path {
1030 0 : search_path.extend(extra.iter().map(|s| s.as_str()));
1031 0 : }
1032 0 : let search_path_str = search_path
1033 0 : .into_iter()
1034 0 : .filter(|s| !s.is_empty())
1035 0 : .collect::<Vec<_>>()
1036 0 : .join(",");
1037 0 : let mut env: HashMap<&str, &str> = HashMap::from([
1038 0 : ("request.method", api_request.method),
1039 0 : ("request.path", api_request.path),
1040 0 : ("request.headers", &headers_env),
1041 0 : ("request.cookies", &cookies_env),
1042 0 : ("request.get", &get_env),
1043 0 : ("request.jwt.claims", &jwt_claims_env),
1044 0 : ("search_path", &search_path_str),
1045 0 : ]);
1046 0 : if let Some(r) = env_role {
1047 0 : env.insert("role", r);
1048 0 : }
1049 :
1050 : // generate the sql statements
1051 0 : let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
1052 0 : let (main_statement, main_parameters, _) = generate(fmt_main_query(
1053 0 : db_schema,
1054 0 : api_request.schema_name,
1055 0 : &api_request,
1056 0 : &env,
1057 0 : )?);
1058 :
1059 0 : let mut headers = vec![
1060 0 : (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
1061 0 : (
1062 0 : &CONN_STRING,
1063 0 : HeaderValue::from_str(connection_string).expect("invalid connection string"),
1064 0 : ),
1065 0 : (&AUTHORIZATION, auth_header.clone()),
1066 0 : (
1067 0 : &TXN_ISOLATION_LEVEL,
1068 0 : HeaderValue::from_static("ReadCommitted"),
1069 0 : ),
1070 0 : (&ALLOW_POOL, HEADER_VALUE_TRUE),
1071 : ];
1072 :
1073 0 : if api_request.read_only {
1074 0 : headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
1075 0 : }
1076 :
1077 : // convert the parameters from subzero core representation to the local proxy repr.
1078 0 : let req_body = serde_json::to_string(&BatchQueryData {
1079 0 : queries: vec![
1080 : QueryData {
1081 0 : query: env_statement.into(),
1082 0 : params: env_parameters
1083 0 : .iter()
1084 0 : .map(|p| to_sql_param(&p.to_param()))
1085 0 : .collect(),
1086 : },
1087 : QueryData {
1088 0 : query: main_statement.into(),
1089 0 : params: main_parameters
1090 0 : .iter()
1091 0 : .map(|p| to_sql_param(&p.to_param()))
1092 0 : .collect(),
1093 : },
1094 : ],
1095 : })
1096 0 : .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
1097 :
1098 : // todo: map body to count egress
1099 0 : let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
1100 :
1101 : // send the request to the local proxy
1102 0 : let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
1103 0 : let (response_parts, body) = proxy_response.into_parts();
1104 :
1105 0 : let max_response = config.http_config.max_response_size_bytes;
1106 0 : let bytes = read_body_with_limit(body, max_response)
1107 0 : .await
1108 0 : .map_err(ReadPayloadError::from)?;
1109 :
1110 : // if the response status is greater than 399, then it is an error
1111 : // FIXME: check if there are other error codes or shapes of the response
1112 0 : if response_parts.status.as_u16() > 399 {
1113 : // turn this postgres error from the json into PostgresError
1114 0 : let postgres_error = serde_json::from_slice(&bytes)
1115 0 : .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
1116 :
1117 0 : return Err(RestError::Postgres(postgres_error));
1118 0 : }
1119 :
1120 0 : #[derive(Deserialize)]
1121 : struct QueryResults {
1122 : /// we run two queries, so we want only two results.
1123 : results: (EnvRows, MainRows),
1124 : }
1125 :
1126 : /// `env_statement` returns nothing of interest to us
1127 0 : #[derive(Deserialize)]
1128 : struct EnvRows {}
1129 :
1130 0 : #[derive(Deserialize)]
1131 : struct MainRows {
1132 : /// `main_statement` only returns a single row.
1133 : rows: [MainRow; 1],
1134 : }
1135 :
1136 0 : #[derive(Deserialize)]
1137 : struct MainRow {
1138 : body: String,
1139 : page_total: Option<String>,
1140 : total_result_set: Option<String>,
1141 : response_headers: Option<String>,
1142 : response_status: Option<String>,
1143 : }
1144 :
1145 0 : let results: QueryResults = serde_json::from_slice(&bytes)
1146 0 : .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
1147 :
1148 : let QueryResults {
1149 0 : results: (_, MainRows { rows: [row] }),
1150 0 : } = results;
1151 :
1152 : // build the intermediate response object
1153 0 : let api_response = ApiResponse {
1154 0 : page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
1155 0 : total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
1156 : top_level_offset: 0, // FIXME: check why this is 0
1157 0 : response_headers: row.response_headers,
1158 0 : response_status: row.response_status,
1159 0 : body: row.body,
1160 : };
1161 :
1162 : // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
1163 : // we can not do this in the context of proxy for now
1164 : // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
1165 : // // rollback the transaction here
1166 : // return Err(RestError::SubzeroCore(SingularityError {
1167 : // count: api_response.page_total,
1168 : // content_type: "application/vnd.pgrst.object+json".to_string(),
1169 : // }));
1170 : // }
1171 :
1172 : // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
1173 : // we can not do this in the context of proxy for now
1174 : // if api_request.method == Method::PUT && api_response.page_total != 1 {
1175 : // // Makes sure the querystring pk matches the payload pk
1176 : // // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
1177 : // // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
1178 : // // If this condition is not satisfied then nothing is inserted,
1179 : // // rollback the transaction here
1180 : // return Err(RestError::SubzeroCore(PutMatchingPkError));
1181 : // }
1182 :
1183 : // create and return the response to the client
1184 : // this section mostly deals with setting the right headers according to PostgREST specs
1185 0 : let page_total = api_response.page_total;
1186 0 : let total_result_set = api_response.total_result_set;
1187 0 : let top_level_offset = api_response.top_level_offset;
1188 0 : let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
1189 : (SingularJSON, _)
1190 : | (
1191 : _,
1192 : FunctionCall {
1193 : returns_single: true,
1194 : is_scalar: false,
1195 : ..
1196 : },
1197 0 : ) => SingularJSON,
1198 0 : (TextCSV, _) => TextCSV,
1199 0 : _ => ApplicationJSON,
1200 : };
1201 :
1202 : // check if the SQL env set some response headers (happens when we called a rpc function)
1203 0 : if let Some(response_headers_str) = api_response.response_headers {
1204 0 : let Ok(headers_json) =
1205 0 : serde_json::from_str::<Vec<Vec<(String, String)>>>(response_headers_str.as_str())
1206 : else {
1207 0 : return Err(RestError::SubzeroCore(GucHeadersError));
1208 : };
1209 :
1210 0 : response_headers.extend(headers_json.into_iter().flatten());
1211 0 : }
1212 :
1213 : // calculate and set the content range header
1214 0 : let lower = top_level_offset as i64;
1215 0 : let upper = top_level_offset as i64 + page_total as i64 - 1;
1216 0 : let total = total_result_set.map(|t| t as i64);
1217 0 : let content_range = match (&method, &api_request.query.node) {
1218 0 : (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
1219 0 : (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
1220 0 : _ => content_range_header(lower, upper, total),
1221 : };
1222 0 : response_headers.push(("Content-Range".to_string(), content_range));
1223 :
1224 : // calculate the status code
1225 : #[rustfmt::skip]
1226 0 : let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
1227 0 : (&Method::POST, Insert { .. }, ..) => 201,
1228 0 : (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
1229 0 : (&Method::DELETE, Delete { .. }, ..) => 204,
1230 0 : (&Method::PATCH, Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
1231 0 : (&Method::PATCH, Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
1232 0 : (&Method::PATCH, Update { .. }, ..) => 204,
1233 0 : (&Method::PUT, Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
1234 0 : (&Method::PUT, Insert { .. }, ..) => 204,
1235 0 : _ => content_range_status(lower, upper, total),
1236 : };
1237 :
1238 : // add the preference-applied header
1239 : if let Some(Preferences {
1240 0 : resolution: Some(r),
1241 : ..
1242 0 : }) = api_request.preferences
1243 : {
1244 0 : response_headers.push((
1245 0 : "Preference-Applied".to_string(),
1246 0 : match r {
1247 0 : MergeDuplicates => "resolution=merge-duplicates".to_string(),
1248 0 : IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
1249 : },
1250 : ));
1251 0 : }
1252 :
1253 : // check if the SQL env set some response status (happens when we called a rpc function)
1254 0 : if let Some(response_status_str) = api_response.response_status {
1255 0 : status = response_status_str
1256 0 : .parse::<u16>()
1257 0 : .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
1258 0 : }
1259 :
1260 : // set the content type header
1261 : // TODO: move this to a subzero function
1262 : // as_header_value(&self) -> Option<&str>
1263 0 : let http_content_type = match response_content_type {
1264 0 : SingularJSON => Ok("application/vnd.pgrst.object+json"),
1265 0 : TextCSV => Ok("text/csv"),
1266 0 : ApplicationJSON => Ok("application/json"),
1267 0 : Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
1268 0 : message: format!("None of these Content-Types are available: {t}"),
1269 0 : })),
1270 0 : }?;
1271 :
1272 : // build the response body
1273 0 : let response_body = Full::new(Bytes::from(api_response.body))
1274 0 : .map_err(|never| match never {})
1275 0 : .boxed();
1276 :
1277 : // build the response
1278 0 : response = response
1279 0 : .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
1280 0 : .header(CONTENT_TYPE, http_content_type);
1281 :
1282 : // Add all headers from response_headers vector
1283 0 : for (header_name, header_value) in response_headers {
1284 0 : response = response.header(header_name, header_value);
1285 0 : }
1286 :
1287 : // add the body and return the response
1288 0 : response.body(response_body).map_err(|_| {
1289 0 : RestError::SubzeroCore(InternalError {
1290 0 : message: "Failed to build response".to_string(),
1291 0 : })
1292 0 : })
1293 0 : }
|