LCOV - differential code coverage report
Current view: top level - proxy/src/http - sql_over_http.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 92.6 % 538 498 40 498
Current Date: 2023-10-19 02:04:12 Functions: 67.0 % 103 69 34 69
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use std::sync::Arc;
       2                 : 
       3                 : use anyhow::bail;
       4                 : use futures::pin_mut;
       5                 : use futures::StreamExt;
       6                 : use hyper::body::HttpBody;
       7                 : use hyper::header;
       8                 : use hyper::http::HeaderName;
       9                 : use hyper::http::HeaderValue;
      10                 : use hyper::Response;
      11                 : use hyper::StatusCode;
      12                 : use hyper::{Body, HeaderMap, Request};
      13                 : use serde_json::json;
      14                 : use serde_json::Map;
      15                 : use serde_json::Value;
      16                 : use tokio_postgres::types::Kind;
      17                 : use tokio_postgres::types::Type;
      18                 : use tokio_postgres::GenericClient;
      19                 : use tokio_postgres::IsolationLevel;
      20                 : use tokio_postgres::ReadyForQueryStatus;
      21                 : use tokio_postgres::Row;
      22                 : use tokio_postgres::Transaction;
      23                 : use tracing::error;
      24                 : use tracing::instrument;
      25                 : use url::Url;
      26                 : use utils::http::error::ApiError;
      27                 : use utils::http::json::json_response;
      28                 : 
      29                 : use crate::config::HttpConfig;
      30                 : use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
      31                 : 
      32                 : use super::conn_pool::ConnInfo;
      33                 : use super::conn_pool::GlobalConnPool;
      34                 : 
      35 CBC         201 : #[derive(serde::Deserialize)]
      36                 : struct QueryData {
      37                 :     query: String,
      38                 :     params: Vec<serde_json::Value>,
      39                 : }
      40                 : 
      41               6 : #[derive(serde::Deserialize)]
      42                 : struct BatchQueryData {
      43                 :     queries: Vec<QueryData>,
      44                 : }
      45                 : 
      46              32 : #[derive(serde::Deserialize)]
      47                 : #[serde(untagged)]
      48                 : enum Payload {
      49                 :     Single(QueryData),
      50                 :     Batch(BatchQueryData),
      51                 : }
      52                 : 
      53                 : const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
      54                 : const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
      55                 : 
      56                 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
      57                 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
      58                 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
      59                 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
      60                 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
      61                 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
      62                 : 
      63                 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
      64                 : 
      65                 : //
      66                 : // Convert json non-string types to strings, so that they can be passed to Postgres
      67                 : // as parameters.
      68                 : //
      69              44 : fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
      70              44 :     json.iter()
      71              44 :         .map(|value| {
      72              18 :             match value {
      73                 :                 // special care for nulls
      74               1 :                 Value::Null => None,
      75                 : 
      76                 :                 // convert to text with escaping
      77               9 :                 v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
      78                 : 
      79                 :                 // avoid escaping here, as we pass this as a parameter
      80               1 :                 Value::String(s) => Some(s.to_string()),
      81                 : 
      82                 :                 // special care for arrays
      83               7 :                 Value::Array(_) => json_array_to_pg_array(value),
      84                 :             }
      85              44 :         })
      86              44 :         .collect()
      87              44 : }
      88                 : 
      89                 : //
      90                 : // Serialize a JSON array to a Postgres array. Contrary to the strings in the params
      91                 : // in the array we need to escape the strings. Postgres is okay with arrays of form
      92                 : // '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
      93                 : // it for Postgres to check.
      94                 : //
      95                 : // Example of the same escaping in node-postgres: packages/pg/lib/utils.js
      96                 : //
      97              39 : fn json_array_to_pg_array(value: &Value) -> Option<String> {
      98              39 :     match value {
      99                 :         // special care for nulls
     100               2 :         Value::Null => None,
     101                 : 
     102                 :         // convert to text with escaping
     103                 :         // here string needs to be escaped, as it is part of the array
     104              22 :         v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
     105               5 :         v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
     106                 : 
     107                 :         // recurse into array
     108              10 :         Value::Array(arr) => {
     109              10 :             let vals = arr
     110              10 :                 .iter()
     111              10 :                 .map(json_array_to_pg_array)
     112              27 :                 .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
     113              10 :                 .collect::<Vec<_>>()
     114              10 :                 .join(",");
     115              10 : 
     116              10 :             Some(format!("{{{}}}", vals))
     117                 :         }
     118                 :     }
     119              39 : }
     120                 : 
     121              30 : fn get_conn_info(
     122              30 :     headers: &HeaderMap,
     123              30 :     sni_hostname: Option<String>,
     124              30 : ) -> Result<ConnInfo, anyhow::Error> {
     125              30 :     let connection_string = headers
     126              30 :         .get("Neon-Connection-String")
     127              30 :         .ok_or(anyhow::anyhow!("missing connection string"))?
     128              30 :         .to_str()?;
     129                 : 
     130              30 :     let connection_url = Url::parse(connection_string)?;
     131                 : 
     132              30 :     let protocol = connection_url.scheme();
     133              30 :     if protocol != "postgres" && protocol != "postgresql" {
     134 UBC           0 :         return Err(anyhow::anyhow!(
     135               0 :             "connection string must start with postgres: or postgresql:"
     136               0 :         ));
     137 CBC          30 :     }
     138                 : 
     139              30 :     let mut url_path = connection_url
     140              30 :         .path_segments()
     141              30 :         .ok_or(anyhow::anyhow!("missing database name"))?;
     142                 : 
     143              30 :     let dbname = url_path
     144              30 :         .next()
     145              30 :         .ok_or(anyhow::anyhow!("invalid database name"))?;
     146                 : 
     147              30 :     let username = connection_url.username();
     148              30 :     if username.is_empty() {
     149 UBC           0 :         return Err(anyhow::anyhow!("missing username"));
     150 CBC          30 :     }
     151                 : 
     152              30 :     let password = connection_url
     153              30 :         .password()
     154              30 :         .ok_or(anyhow::anyhow!("no password"))?;
     155                 : 
     156                 :     // TLS certificate selector now based on SNI hostname, so if we are running here
     157                 :     // we are sure that SNI hostname is set to one of the configured domain names.
     158              30 :     let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
     159                 : 
     160              30 :     let hostname = connection_url
     161              30 :         .host_str()
     162              30 :         .ok_or(anyhow::anyhow!("no host"))?;
     163                 : 
     164              30 :     let host_header = headers
     165              30 :         .get("host")
     166              30 :         .and_then(|h| h.to_str().ok())
     167              30 :         .and_then(|h| h.split(':').next());
     168              30 : 
     169              30 :     if hostname != sni_hostname {
     170 UBC           0 :         return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
     171 CBC          30 :     } else if let Some(h) = host_header {
     172              30 :         if h != hostname {
     173 UBC           0 :             return Err(anyhow::anyhow!("mismatched host header and hostname"));
     174 CBC          30 :         }
     175 UBC           0 :     }
     176                 : 
     177 CBC          30 :     Ok(ConnInfo {
     178              30 :         username: username.to_owned(),
     179              30 :         dbname: dbname.to_owned(),
     180              30 :         hostname: hostname.to_owned(),
     181              30 :         password: password.to_owned(),
     182              30 :     })
     183              30 : }
     184                 : 
     185                 : // TODO: return different http error codes
     186              30 : pub async fn handle(
     187              30 :     request: Request<Body>,
     188              30 :     sni_hostname: Option<String>,
     189              30 :     conn_pool: Arc<GlobalConnPool>,
     190              30 :     session_id: uuid::Uuid,
     191              30 :     config: &'static HttpConfig,
     192              30 : ) -> Result<Response<Body>, ApiError> {
     193              30 :     let result = tokio::time::timeout(
     194              30 :         config.sql_over_http_timeout,
     195              30 :         handle_inner(request, sni_hostname, conn_pool, session_id),
     196              30 :     )
     197             233 :     .await;
     198              30 :     let mut response = match result {
     199              30 :         Ok(r) => match r {
     200              27 :             Ok(r) => r,
     201               3 :             Err(e) => {
     202               3 :                 let message = format!("{:?}", e);
     203               3 :                 let code = e.downcast_ref::<tokio_postgres::Error>().and_then(|e| {
     204               3 :                     e.code()
     205               3 :                         .map(|s| serde_json::to_value(s.code()).unwrap_or_default())
     206               3 :                 });
     207               3 :                 let code = match code {
     208               3 :                     Some(c) => c,
     209 UBC           0 :                     None => Value::Null,
     210                 :                 };
     211 CBC           3 :                 error!(
     212               3 :                     ?code,
     213               3 :                     "sql-over-http per-client task finished with an error: {e:#}"
     214               3 :                 );
     215                 :                 // TODO: this shouldn't always be bad request.
     216               3 :                 json_response(
     217               3 :                     StatusCode::BAD_REQUEST,
     218               3 :                     json!({ "message": message, "code": code }),
     219               3 :                 )?
     220                 :             }
     221                 :         },
     222                 :         Err(_) => {
     223 UBC           0 :             let message = format!(
     224               0 :                 "HTTP-Connection timed out, execution time exeeded {} seconds",
     225               0 :                 config.sql_over_http_timeout.as_secs()
     226               0 :             );
     227               0 :             error!(message);
     228               0 :             json_response(
     229               0 :                 StatusCode::GATEWAY_TIMEOUT,
     230               0 :                 json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
     231               0 :             )?
     232                 :         }
     233                 :     };
     234 CBC          30 :     response.headers_mut().insert(
     235              30 :         "Access-Control-Allow-Origin",
     236              30 :         hyper::http::HeaderValue::from_static("*"),
     237              30 :     );
     238              30 :     Ok(response)
     239              30 : }
     240                 : 
     241              90 : #[instrument(name = "sql-over-http", skip_all)]
     242                 : async fn handle_inner(
     243                 :     request: Request<Body>,
     244                 :     sni_hostname: Option<String>,
     245                 :     conn_pool: Arc<GlobalConnPool>,
     246                 :     session_id: uuid::Uuid,
     247                 : ) -> anyhow::Result<Response<Body>> {
     248                 :     NUM_CONNECTIONS_ACCEPTED_COUNTER
     249                 :         .with_label_values(&["http"])
     250                 :         .inc();
     251              30 :     scopeguard::defer! {
     252              30 :         NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
     253              30 :     }
     254                 : 
     255                 :     //
     256                 :     // Determine the destination and connection params
     257                 :     //
     258                 :     let headers = request.headers();
     259                 :     let conn_info = get_conn_info(headers, sni_hostname)?;
     260                 : 
     261                 :     // Determine the output options. Default behaviour is 'false'. Anything that is not
     262                 :     // strictly 'true' assumed to be false.
     263                 :     let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
     264                 :     let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
     265                 : 
     266                 :     // Allow connection pooling only if explicitly requested
     267                 :     let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
     268                 : 
     269                 :     // isolation level, read only and deferrable
     270                 : 
     271                 :     let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
     272                 :     let txn_isolation_level = match txn_isolation_level_raw {
     273                 :         Some(ref x) => Some(match x.as_bytes() {
     274                 :             b"Serializable" => IsolationLevel::Serializable,
     275                 :             b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
     276                 :             b"ReadCommitted" => IsolationLevel::ReadCommitted,
     277                 :             b"RepeatableRead" => IsolationLevel::RepeatableRead,
     278                 :             _ => bail!("invalid isolation level"),
     279                 :         }),
     280                 :         None => None,
     281                 :     };
     282                 : 
     283                 :     let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
     284                 :     let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
     285                 : 
     286                 :     let request_content_length = match request.body().size_hint().upper() {
     287                 :         Some(v) => v,
     288                 :         None => MAX_REQUEST_SIZE + 1,
     289                 :     };
     290                 : 
     291                 :     // we don't have a streaming request support yet so this is to prevent OOM
     292                 :     // from a malicious user sending an extremely large request body
     293                 :     if request_content_length > MAX_REQUEST_SIZE {
     294                 :         return Err(anyhow::anyhow!(
     295                 :             "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
     296                 :         ));
     297                 :     }
     298                 : 
     299                 :     //
     300                 :     // Read the query and query params from the request body
     301                 :     //
     302                 :     let body = hyper::body::to_bytes(request.into_body()).await?;
     303                 :     let payload: Payload = serde_json::from_slice(&body)?;
     304                 : 
     305                 :     let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?;
     306                 : 
     307                 :     let mut response = Response::builder()
     308                 :         .status(StatusCode::OK)
     309                 :         .header(header::CONTENT_TYPE, "application/json");
     310                 : 
     311                 :     //
     312                 :     // Now execute the query and return the result
     313                 :     //
     314                 :     let mut size = 0;
     315                 :     let result =
     316                 :         match payload {
     317                 :             Payload::Single(stmt) => {
     318                 :                 let (status, results) =
     319                 :                     query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
     320                 :                         .await
     321               1 :                         .map_err(|e| {
     322               1 :                             client.discard();
     323               1 :                             e
     324               1 :                         })?;
     325                 :                 client.check_idle(status);
     326                 :                 results
     327                 :             }
     328                 :             Payload::Batch(statements) => {
     329                 :                 let (inner, mut discard) = client.inner();
     330                 :                 let mut builder = inner.build_transaction();
     331                 :                 if let Some(isolation_level) = txn_isolation_level {
     332                 :                     builder = builder.isolation_level(isolation_level);
     333                 :                 }
     334                 :                 if txn_read_only {
     335                 :                     builder = builder.read_only(true);
     336                 :                 }
     337                 :                 if txn_deferrable {
     338                 :                     builder = builder.deferrable(true);
     339                 :                 }
     340                 : 
     341 UBC           0 :                 let transaction = builder.start().await.map_err(|e| {
     342               0 :                     // if we cannot start a transaction, we should return immediately
     343               0 :                     // and not return to the pool. connection is clearly broken
     344               0 :                     discard.discard();
     345               0 :                     e
     346               0 :                 })?;
     347                 : 
     348                 :                 let results =
     349                 :                     match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
     350                 :                         .await
     351                 :                     {
     352                 :                         Ok(results) => {
     353               0 :                             let status = transaction.commit().await.map_err(|e| {
     354               0 :                                 // if we cannot commit - for now don't return connection to pool
     355               0 :                                 // TODO: get a query status from the error
     356               0 :                                 discard.discard();
     357               0 :                                 e
     358               0 :                             })?;
     359                 :                             discard.check_idle(status);
     360                 :                             results
     361                 :                         }
     362                 :                         Err(err) => {
     363               0 :                             let status = transaction.rollback().await.map_err(|e| {
     364               0 :                                 // if we cannot rollback - for now don't return connection to pool
     365               0 :                                 // TODO: get a query status from the error
     366               0 :                                 discard.discard();
     367               0 :                                 e
     368               0 :                             })?;
     369                 :                             discard.check_idle(status);
     370                 :                             return Err(err);
     371                 :                         }
     372                 :                     };
     373                 : 
     374                 :                 if txn_read_only {
     375                 :                     response = response.header(
     376                 :                         TXN_READ_ONLY.clone(),
     377                 :                         HeaderValue::try_from(txn_read_only.to_string())?,
     378                 :                     );
     379                 :                 }
     380                 :                 if txn_deferrable {
     381                 :                     response = response.header(
     382                 :                         TXN_DEFERRABLE.clone(),
     383                 :                         HeaderValue::try_from(txn_deferrable.to_string())?,
     384                 :                     );
     385                 :                 }
     386                 :                 if let Some(txn_isolation_level) = txn_isolation_level_raw {
     387                 :                     response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
     388                 :                 }
     389                 :                 json!({ "results": results })
     390                 :             }
     391                 :         };
     392                 : 
     393                 :     let metrics = client.metrics();
     394                 : 
     395                 :     // how could this possibly fail
     396                 :     let body = serde_json::to_string(&result).expect("json serialization should not fail");
     397                 :     let len = body.len();
     398                 :     let response = response
     399                 :         .body(Body::from(body))
     400                 :         // only fails if invalid status code or invalid header/values are given.
     401                 :         // these are not user configurable so it cannot fail dynamically
     402                 :         .expect("building response payload should not fail");
     403                 : 
     404                 :     // count the egress bytes - we miss the TLS and header overhead but oh well...
     405                 :     // moving this later in the stack is going to be a lot of effort and ehhhh
     406                 :     metrics.record_egress(len as u64);
     407                 : 
     408                 :     Ok(response)
     409                 : }
     410                 : 
     411 CBC           2 : async fn query_batch(
     412               2 :     transaction: &Transaction<'_>,
     413               2 :     queries: BatchQueryData,
     414               2 :     total_size: &mut usize,
     415               2 :     raw_output: bool,
     416               2 :     array_mode: bool,
     417               2 : ) -> anyhow::Result<Vec<Value>> {
     418               2 :     let mut results = Vec::with_capacity(queries.queries.len());
     419               2 :     let mut current_size = 0;
     420              13 :     for stmt in queries.queries {
     421              11 :         // TODO: maybe we should check that the transaction bit is set here
     422              11 :         let (_, values) =
     423              22 :             query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
     424              11 :         results.push(values);
     425                 :     }
     426               2 :     *total_size += current_size;
     427               2 :     Ok(results)
     428               2 : }
     429                 : 
     430              37 : async fn query_to_json<T: GenericClient>(
     431              37 :     client: &T,
     432              37 :     data: QueryData,
     433              37 :     current_size: &mut usize,
     434              37 :     raw_output: bool,
     435              37 :     array_mode: bool,
     436              37 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
     437              37 :     let query_params = json_to_pg_text(data.params);
     438              73 :     let row_stream = client.query_raw_txt(&data.query, query_params).await?;
     439                 : 
     440                 :     // Manually drain the stream into a vector to leave row_stream hanging
     441                 :     // around to get a command tag. Also check that the response is not too
     442                 :     // big.
     443              36 :     pin_mut!(row_stream);
     444              36 :     let mut rows: Vec<tokio_postgres::Row> = Vec::new();
     445              74 :     while let Some(row) = row_stream.next().await {
     446              38 :         let row = row?;
     447              38 :         *current_size += row.body_len();
     448              38 :         rows.push(row);
     449              38 :         // we don't have a streaming response support yet so this is to prevent OOM
     450              38 :         // from a malicious query (eg a cross join)
     451              38 :         if *current_size > MAX_RESPONSE_SIZE {
     452 UBC           0 :             return Err(anyhow::anyhow!(
     453               0 :                 "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
     454               0 :             ));
     455 CBC          38 :         }
     456                 :     }
     457                 : 
     458              36 :     let ready = row_stream.ready_status();
     459              36 : 
     460              36 :     // grab the command tag and number of rows affected
     461              36 :     let command_tag = row_stream.command_tag().unwrap_or_default();
     462              36 :     let mut command_tag_split = command_tag.split(' ');
     463              36 :     let command_tag_name = command_tag_split.next().unwrap_or_default();
     464              36 :     let command_tag_count = if command_tag_name == "INSERT" {
     465                 :         // INSERT returns OID first and then number of rows
     466               2 :         command_tag_split.nth(1)
     467                 :     } else {
     468                 :         // other commands return number of rows (if any)
     469              34 :         command_tag_split.next()
     470                 :     }
     471              36 :     .and_then(|s| s.parse::<i64>().ok());
     472                 : 
     473              36 :     let fields = if !rows.is_empty() {
     474              30 :         rows[0]
     475              30 :             .columns()
     476              30 :             .iter()
     477             104 :             .map(|c| {
     478             104 :                 json!({
     479             104 :                     "name": Value::String(c.name().to_owned()),
     480             104 :                     "dataTypeID": Value::Number(c.type_().oid().into()),
     481             104 :                     "tableID": c.table_oid(),
     482             104 :                     "columnID": c.column_id(),
     483             104 :                     "dataTypeSize": c.type_size(),
     484             104 :                     "dataTypeModifier": c.type_modifier(),
     485             104 :                     "format": "text",
     486             104 :                 })
     487             104 :             })
     488              30 :             .collect::<Vec<_>>()
     489                 :     } else {
     490               6 :         Vec::new()
     491                 :     };
     492                 : 
     493                 :     // convert rows to JSON
     494              36 :     let rows = rows
     495              36 :         .iter()
     496              38 :         .map(|row| pg_text_row_to_json(row, raw_output, array_mode))
     497              36 :         .collect::<Result<Vec<_>, _>>()?;
     498                 : 
     499                 :     // resulting JSON format is based on the format of node-postgres result
     500              36 :     Ok((
     501              36 :         ready,
     502              36 :         json!({
     503              36 :             "command": command_tag_name,
     504              36 :             "rowCount": command_tag_count,
     505              36 :             "rows": rows,
     506              36 :             "fields": fields,
     507              36 :             "rowAsArray": array_mode,
     508              36 :         }),
     509              36 :     ))
     510              37 : }
     511                 : 
     512                 : //
     513                 : // Convert postgres row with text-encoded values to JSON object
     514                 : //
     515              38 : pub fn pg_text_row_to_json(
     516              38 :     row: &Row,
     517              38 :     raw_output: bool,
     518              38 :     array_mode: bool,
     519              38 : ) -> Result<Value, anyhow::Error> {
     520             116 :     let iter = row.columns().iter().enumerate().map(|(i, column)| {
     521             116 :         let name = column.name();
     522             116 :         let pg_value = row.as_text(i)?;
     523             116 :         let json_value = if raw_output {
     524               6 :             match pg_value {
     525               6 :                 Some(v) => Value::String(v.to_string()),
     526 UBC           0 :                 None => Value::Null,
     527                 :             }
     528                 :         } else {
     529 CBC         110 :             pg_text_to_json(pg_value, column.type_())?
     530                 :         };
     531             116 :         Ok((name.to_string(), json_value))
     532             116 :     });
     533              38 : 
     534              38 :     if array_mode {
     535                 :         // drop keys and aggregate into array
     536               2 :         let arr = iter
     537               6 :             .map(|r| r.map(|(_key, val)| val))
     538               2 :             .collect::<Result<Vec<Value>, anyhow::Error>>()?;
     539               2 :         Ok(Value::Array(arr))
     540                 :     } else {
     541              36 :         let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
     542              36 :         Ok(Value::Object(obj))
     543                 :     }
     544              38 : }
     545                 : 
     546                 : //
     547                 : // Convert postgres text-encoded value to JSON value
     548                 : //
     549                 : pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
     550             211 :     if let Some(val) = pg_value {
     551             199 :         if let Kind::Array(elem_type) = pg_type.kind() {
     552               8 :             return pg_array_parse(val, elem_type);
     553             191 :         }
     554             191 : 
     555             191 :         match *pg_type {
     556              30 :             Type::BOOL => Ok(Value::Bool(val == "t")),
     557                 :             Type::INT2 | Type::INT4 => {
     558              66 :                 let val = val.parse::<i32>()?;
     559              66 :                 Ok(Value::Number(serde_json::Number::from(val)))
     560                 :             }
     561                 :             Type::FLOAT4 | Type::FLOAT8 => {
     562              25 :                 let fval = val.parse::<f64>()?;
     563              25 :                 let num = serde_json::Number::from_f64(fval);
     564              25 :                 if let Some(num) = num {
     565              16 :                     Ok(Value::Number(num))
     566                 :                 } else {
     567                 :                     // Pass Nan, Inf, -Inf as strings
     568                 :                     // JS JSON.stringify() does converts them to null, but we
     569                 :                     // want to preserve them, so we pass them as strings
     570               9 :                     Ok(Value::String(val.to_string()))
     571                 :                 }
     572                 :             }
     573              12 :             Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
     574              58 :             _ => Ok(Value::String(val.to_string())),
     575                 :         }
     576                 :     } else {
     577              12 :         Ok(Value::Null)
     578                 :     }
     579             211 : }
     580                 : 
     581                 : //
     582                 : // Parse postgres array into JSON array.
     583                 : //
     584                 : // This is a bit involved because we need to handle nested arrays and quoted
     585                 : // values. Unlike postgres we don't check that all nested arrays have the same
     586                 : // dimensions, we just return them as is.
     587                 : //
     588              30 : fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
     589              30 :     _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
     590              30 : }
     591                 : 
     592              47 : fn _pg_array_parse(
     593              47 :     pg_array: &str,
     594              47 :     elem_type: &Type,
     595              47 :     nested: bool,
     596              47 : ) -> Result<(Value, usize), anyhow::Error> {
     597              47 :     let mut pg_array_chr = pg_array.char_indices();
     598              47 :     let mut level = 0;
     599              47 :     let mut quote = false;
     600              47 :     let mut entries: Vec<Value> = Vec::new();
     601              47 :     let mut entry = String::new();
     602              47 : 
     603              47 :     // skip bounds decoration
     604              47 :     if let Some('[') = pg_array.chars().next() {
     605              18 :         for (_, c) in pg_array_chr.by_ref() {
     606              18 :             if c == '=' {
     607               1 :                 break;
     608              17 :             }
     609                 :         }
     610              46 :     }
     611                 : 
     612             107 :     fn push_checked(
     613             107 :         entry: &mut String,
     614             107 :         entries: &mut Vec<Value>,
     615             107 :         elem_type: &Type,
     616             107 :     ) -> Result<(), anyhow::Error> {
     617             107 :         if !entry.is_empty() {
     618                 :             // While in usual postgres response we get nulls as None and everything else
     619                 :             // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
     620                 :             // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
     621                 :             // here while we have quotation info and convert them to None.
     622              72 :             if entry == "NULL" {
     623               7 :                 entries.push(pg_text_to_json(None, elem_type)?);
     624                 :             } else {
     625              65 :                 entries.push(pg_text_to_json(Some(entry), elem_type)?);
     626                 :             }
     627              72 :             entry.clear();
     628              35 :         }
     629                 : 
     630             107 :         Ok(())
     631             107 :     }
     632                 : 
     633             556 :     while let Some((mut i, mut c)) = pg_array_chr.next() {
     634             526 :         let mut escaped = false;
     635             526 : 
     636             526 :         if c == '\\' {
     637              19 :             escaped = true;
     638              19 :             (i, c) = pg_array_chr.next().unwrap();
     639             507 :         }
     640                 : 
     641             261 :         match c {
     642              64 :             '{' if !quote => {
     643              64 :                 level += 1;
     644              64 :                 if level > 1 {
     645              17 :                     let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
     646              17 :                     entries.push(res);
     647             195 :                     for _ in 0..off - 1 {
     648             195 :                         pg_array_chr.next();
     649             195 :                     }
     650              47 :                 }
     651                 :             }
     652              64 :             '}' if !quote => {
     653              64 :                 level -= 1;
     654              64 :                 if level == 0 {
     655              47 :                     push_checked(&mut entry, &mut entries, elem_type)?;
     656              47 :                     if nested {
     657              17 :                         return Ok((Value::Array(entries), i));
     658              30 :                     }
     659              17 :                 }
     660                 :             }
     661              36 :             '"' if !escaped => {
     662              36 :                 if quote {
     663                 :                     // end of quoted string, so push it manually without any checks
     664                 :                     // for emptiness or nulls
     665              18 :                     entries.push(pg_text_to_json(Some(&entry), elem_type)?);
     666              18 :                     entry.clear();
     667              18 :                 }
     668              36 :                 quote = !quote;
     669                 :             }
     670              60 :             ',' if !quote => {
     671              60 :                 push_checked(&mut entry, &mut entries, elem_type)?;
     672                 :             }
     673             302 :             _ => {
     674             302 :                 entry.push(c);
     675             302 :             }
     676                 :         }
     677                 :     }
     678                 : 
     679              30 :     if level != 0 {
     680 UBC           0 :         return Err(anyhow::anyhow!("unbalanced array"));
     681 CBC          30 :     }
     682              30 : 
     683              30 :     Ok((Value::Array(entries), 0))
     684              47 : }
     685                 : 
     686                 : #[cfg(test)]
     687                 : mod tests {
     688                 :     use super::*;
     689                 :     use serde_json::json;
     690                 : 
     691               1 :     #[test]
     692               1 :     fn test_atomic_types_to_pg_params() {
     693               1 :         let json = vec![Value::Bool(true), Value::Bool(false)];
     694               1 :         let pg_params = json_to_pg_text(json);
     695               1 :         assert_eq!(
     696               1 :             pg_params,
     697               1 :             vec![Some("true".to_owned()), Some("false".to_owned())]
     698               1 :         );
     699                 : 
     700               1 :         let json = vec![Value::Number(serde_json::Number::from(42))];
     701               1 :         let pg_params = json_to_pg_text(json);
     702               1 :         assert_eq!(pg_params, vec![Some("42".to_owned())]);
     703                 : 
     704               1 :         let json = vec![Value::String("foo\"".to_string())];
     705               1 :         let pg_params = json_to_pg_text(json);
     706               1 :         assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
     707                 : 
     708               1 :         let json = vec![Value::Null];
     709               1 :         let pg_params = json_to_pg_text(json);
     710               1 :         assert_eq!(pg_params, vec![None]);
     711               1 :     }
     712                 : 
     713               1 :     #[test]
     714               1 :     fn test_json_array_to_pg_array() {
     715               1 :         // atoms and escaping
     716               1 :         let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
     717               1 :         let json: Value = serde_json::from_str(json).unwrap();
     718               1 :         let pg_params = json_to_pg_text(vec![json]);
     719               1 :         assert_eq!(
     720               1 :             pg_params,
     721               1 :             vec![Some(
     722               1 :                 "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
     723               1 :             )]
     724               1 :         );
     725                 : 
     726                 :         // nested arrays
     727               1 :         let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
     728               1 :         let json: Value = serde_json::from_str(json).unwrap();
     729               1 :         let pg_params = json_to_pg_text(vec![json]);
     730               1 :         assert_eq!(
     731               1 :             pg_params,
     732               1 :             vec![Some(
     733               1 :                 "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
     734               1 :             )]
     735               1 :         );
     736                 :         // array of objects
     737               1 :         let json = r#"[{"foo": 1},{"bar": 2}]"#;
     738               1 :         let json: Value = serde_json::from_str(json).unwrap();
     739               1 :         let pg_params = json_to_pg_text(vec![json]);
     740               1 :         assert_eq!(
     741               1 :             pg_params,
     742               1 :             vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
     743               1 :         );
     744               1 :     }
     745                 : 
     746               1 :     #[test]
     747               1 :     fn test_atomic_types_parse() {
     748               1 :         assert_eq!(
     749               1 :             pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
     750               1 :             json!("foo")
     751               1 :         );
     752               1 :         assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
     753               1 :         assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
     754               1 :         assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
     755               1 :         assert_eq!(
     756               1 :             pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
     757               1 :             json!("42")
     758               1 :         );
     759               1 :         assert_eq!(
     760               1 :             pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
     761               1 :             json!(42.42)
     762               1 :         );
     763               1 :         assert_eq!(
     764               1 :             pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
     765               1 :             json!(42.42)
     766               1 :         );
     767               1 :         assert_eq!(
     768               1 :             pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
     769               1 :             json!("NaN")
     770               1 :         );
     771               1 :         assert_eq!(
     772               1 :             pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
     773               1 :             json!("Infinity")
     774               1 :         );
     775               1 :         assert_eq!(
     776               1 :             pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
     777               1 :             json!("-Infinity")
     778               1 :         );
     779                 : 
     780               1 :         let json: Value =
     781               1 :             serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
     782               1 :                 .unwrap();
     783               1 :         assert_eq!(
     784               1 :             pg_text_to_json(
     785               1 :                 Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
     786               1 :                 &Type::JSONB
     787               1 :             )
     788               1 :             .unwrap(),
     789               1 :             json
     790               1 :         );
     791               1 :     }
     792                 : 
     793               1 :     #[test]
     794               1 :     fn test_pg_array_parse_text() {
     795               4 :         fn pt(pg_arr: &str) -> Value {
     796               4 :             pg_array_parse(pg_arr, &Type::TEXT).unwrap()
     797               4 :         }
     798               1 :         assert_eq!(
     799               1 :             pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
     800               1 :             json!(["aa\"\\,a", "cha", "bbbb"])
     801               1 :         );
     802               1 :         assert_eq!(
     803               1 :             pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
     804               1 :             json!([["foo", "bar"], ["bee", "bop"]])
     805               1 :         );
     806               1 :         assert_eq!(
     807               1 :             pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
     808               1 :             json!([[[["foo", null, "bop", "bup"]]]])
     809               1 :         );
     810               1 :         assert_eq!(
     811               1 :             pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
     812               1 :             json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
     813               1 :         );
     814               1 :     }
     815                 : 
     816               1 :     #[test]
     817               1 :     fn test_pg_array_parse_bool() {
     818               4 :         fn pb(pg_arr: &str) -> Value {
     819               4 :             pg_array_parse(pg_arr, &Type::BOOL).unwrap()
     820               4 :         }
     821               1 :         assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
     822               1 :         assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
     823               1 :         assert_eq!(
     824               1 :             pb(r#"{{t,f},{f,t}}"#),
     825               1 :             json!([[true, false], [false, true]])
     826               1 :         );
     827               1 :         assert_eq!(
     828               1 :             pb(r#"{{t,NULL},{NULL,f}}"#),
     829               1 :             json!([[true, null], [null, false]])
     830               1 :         );
     831               1 :     }
     832                 : 
     833               1 :     #[test]
     834               1 :     fn test_pg_array_parse_numbers() {
     835               9 :         fn pn(pg_arr: &str, ty: &Type) -> Value {
     836               9 :             pg_array_parse(pg_arr, ty).unwrap()
     837               9 :         }
     838               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
     839               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
     840               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
     841               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
     842               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
     843               1 :         assert_eq!(
     844               1 :             pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
     845               1 :             json!([1.1, 2.2, 3.3])
     846               1 :         );
     847               1 :         assert_eq!(
     848               1 :             pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
     849               1 :             json!([1.1, 2.2, 3.3])
     850               1 :         );
     851               1 :         assert_eq!(
     852               1 :             pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
     853               1 :             json!(["NaN", "Infinity", "-Infinity"])
     854               1 :         );
     855               1 :         assert_eq!(
     856               1 :             pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
     857               1 :             json!(["NaN", "Infinity", "-Infinity"])
     858               1 :         );
     859               1 :     }
     860                 : 
     861               1 :     #[test]
     862               1 :     fn test_pg_array_with_decoration() {
     863               1 :         fn p(pg_arr: &str) -> Value {
     864               1 :             pg_array_parse(pg_arr, &Type::INT2).unwrap()
     865               1 :         }
     866               1 :         assert_eq!(
     867               1 :             p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
     868               1 :             json!([[[1, 2, 3], [4, 5, 6]]])
     869               1 :         );
     870               1 :     }
     871               1 :     #[test]
     872               1 :     fn test_pg_array_parse_json() {
     873               4 :         fn pt(pg_arr: &str) -> Value {
     874               4 :             pg_array_parse(pg_arr, &Type::JSONB).unwrap()
     875               4 :         }
     876               1 :         assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
     877               1 :         assert_eq!(
     878               1 :             pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
     879               1 :             json!([{"foo": 1, "bar": 2}])
     880               1 :         );
     881               1 :         assert_eq!(
     882               1 :             pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
     883               1 :             json!([{"foo": 1}, {"bar": 2}])
     884               1 :         );
     885               1 :         assert_eq!(
     886               1 :             pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
     887               1 :             json!([[{"foo": 1}, {"bar": 2}]])
     888               1 :         );
     889               1 :     }
     890                 : }
        

Generated by: LCOV version 2.1-beta