LCOV - differential code coverage report
Current view: top level - proxy/src/serverless - sql_over_http.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 93.0 % 587 546 41 546
Current Date: 2024-01-09 02:06:09 Functions: 73.2 % 138 101 37 101
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : use std::sync::Arc;
       2                 : 
       3                 : use anyhow::bail;
       4                 : use futures::pin_mut;
       5                 : use futures::StreamExt;
       6                 : use hyper::body::HttpBody;
       7                 : use hyper::header;
       8                 : use hyper::http::HeaderName;
       9                 : use hyper::http::HeaderValue;
      10                 : use hyper::Response;
      11                 : use hyper::StatusCode;
      12                 : use hyper::{Body, HeaderMap, Request};
      13                 : use serde_json::json;
      14                 : use serde_json::Map;
      15                 : use serde_json::Value;
      16                 : use smol_str::SmolStr;
      17                 : use tokio_postgres::error::DbError;
      18                 : use tokio_postgres::types::Kind;
      19                 : use tokio_postgres::types::Type;
      20                 : use tokio_postgres::GenericClient;
      21                 : use tokio_postgres::IsolationLevel;
      22                 : use tokio_postgres::ReadyForQueryStatus;
      23                 : use tokio_postgres::Row;
      24                 : use tokio_postgres::Transaction;
      25                 : use tracing::error;
      26                 : use tracing::instrument;
      27                 : use url::Url;
      28                 : use utils::http::error::ApiError;
      29                 : use utils::http::json::json_response;
      30                 : 
      31                 : use crate::config::HttpConfig;
      32                 : use crate::context::RequestMonitoring;
      33                 : use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
      34                 : 
      35                 : use super::conn_pool::ConnInfo;
      36                 : use super::conn_pool::GlobalConnPool;
      37                 : 
      38 CBC         276 : #[derive(serde::Deserialize)]
      39                 : struct QueryData {
      40                 :     query: String,
      41                 :     params: Vec<serde_json::Value>,
      42                 : }
      43                 : 
      44               6 : #[derive(serde::Deserialize)]
      45                 : struct BatchQueryData {
      46                 :     queries: Vec<QueryData>,
      47                 : }
      48                 : 
      49              47 : #[derive(serde::Deserialize)]
      50                 : #[serde(untagged)]
      51                 : enum Payload {
      52                 :     Single(QueryData),
      53                 :     Batch(BatchQueryData),
      54                 : }
      55                 : 
      56                 : const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
      57                 : const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
      58                 : 
      59                 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
      60                 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
      61                 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
      62                 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
      63                 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
      64                 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
      65                 : 
      66                 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
      67                 : 
      68                 : //
      69                 : // Convert json non-string types to strings, so that they can be passed to Postgres
      70                 : // as parameters.
      71                 : //
      72              58 : fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
      73              58 :     json.iter()
      74              58 :         .map(|value| {
      75              18 :             match value {
      76                 :                 // special care for nulls
      77               1 :                 Value::Null => None,
      78                 : 
      79                 :                 // convert to text with escaping
      80               9 :                 v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
      81                 : 
      82                 :                 // avoid escaping here, as we pass this as a parameter
      83               1 :                 Value::String(s) => Some(s.to_string()),
      84                 : 
      85                 :                 // special care for arrays
      86               7 :                 Value::Array(_) => json_array_to_pg_array(value),
      87                 :             }
      88              58 :         })
      89              58 :         .collect()
      90              58 : }
      91                 : 
      92                 : //
      93                 : // Serialize a JSON array to a Postgres array. Contrary to the strings in the params
      94                 : // in the array we need to escape the strings. Postgres is okay with arrays of form
      95                 : // '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
      96                 : // it for Postgres to check.
      97                 : //
      98                 : // Example of the same escaping in node-postgres: packages/pg/lib/utils.js
      99                 : //
     100              39 : fn json_array_to_pg_array(value: &Value) -> Option<String> {
     101              39 :     match value {
     102                 :         // special care for nulls
     103               2 :         Value::Null => None,
     104                 : 
     105                 :         // convert to text with escaping
     106                 :         // here string needs to be escaped, as it is part of the array
     107              22 :         v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
     108               5 :         v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
     109                 : 
     110                 :         // recurse into array
     111              10 :         Value::Array(arr) => {
     112              10 :             let vals = arr
     113              10 :                 .iter()
     114              10 :                 .map(json_array_to_pg_array)
     115              27 :                 .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
     116              10 :                 .collect::<Vec<_>>()
     117              10 :                 .join(",");
     118              10 : 
     119              10 :             Some(format!("{{{}}}", vals))
     120                 :         }
     121                 :     }
     122              39 : }
     123                 : 
     124              45 : fn get_conn_info(
     125              45 :     ctx: &mut RequestMonitoring,
     126              45 :     headers: &HeaderMap,
     127              45 :     sni_hostname: Option<String>,
     128              45 : ) -> Result<ConnInfo, anyhow::Error> {
     129              45 :     let connection_string = headers
     130              45 :         .get("Neon-Connection-String")
     131              45 :         .ok_or(anyhow::anyhow!("missing connection string"))?
     132              45 :         .to_str()?;
     133                 : 
     134              45 :     let connection_url = Url::parse(connection_string)?;
     135                 : 
     136              45 :     let protocol = connection_url.scheme();
     137              45 :     if protocol != "postgres" && protocol != "postgresql" {
     138 UBC           0 :         return Err(anyhow::anyhow!(
     139               0 :             "connection string must start with postgres: or postgresql:"
     140               0 :         ));
     141 CBC          45 :     }
     142                 : 
     143              45 :     let mut url_path = connection_url
     144              45 :         .path_segments()
     145              45 :         .ok_or(anyhow::anyhow!("missing database name"))?;
     146                 : 
     147              45 :     let dbname = url_path
     148              45 :         .next()
     149              45 :         .ok_or(anyhow::anyhow!("invalid database name"))?;
     150                 : 
     151              45 :     let username = SmolStr::from(connection_url.username());
     152              45 :     if username.is_empty() {
     153 UBC           0 :         return Err(anyhow::anyhow!("missing username"));
     154 CBC          45 :     }
     155              45 :     ctx.set_user(username.clone());
     156                 : 
     157              45 :     let password = connection_url
     158              45 :         .password()
     159              45 :         .ok_or(anyhow::anyhow!("no password"))?;
     160                 : 
     161                 :     // TLS certificate selector now based on SNI hostname, so if we are running here
     162                 :     // we are sure that SNI hostname is set to one of the configured domain names.
     163              45 :     let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
     164                 : 
     165              45 :     let hostname = connection_url
     166              45 :         .host_str()
     167              45 :         .ok_or(anyhow::anyhow!("no host"))?;
     168                 : 
     169              45 :     let host_header = headers
     170              45 :         .get("host")
     171              45 :         .and_then(|h| h.to_str().ok())
     172              45 :         .and_then(|h| h.split(':').next());
     173              45 : 
     174              45 :     if hostname != sni_hostname {
     175 UBC           0 :         return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
     176 CBC          45 :     } else if let Some(h) = host_header {
     177              45 :         if h != hostname {
     178 UBC           0 :             return Err(anyhow::anyhow!("mismatched host header and hostname"));
     179 CBC          45 :         }
     180 UBC           0 :     }
     181                 : 
     182 CBC          45 :     let hostname: SmolStr = hostname.into();
     183              45 :     ctx.set_endpoint_id(Some(hostname.clone()));
     184              45 : 
     185              45 :     let pairs = connection_url.query_pairs();
     186              45 : 
     187              45 :     let mut options = Option::None;
     188                 : 
     189              45 :     for (key, value) in pairs {
     190 UBC           0 :         if key == "options" {
     191               0 :             options = Some(value.into());
     192               0 :             break;
     193               0 :         }
     194                 :     }
     195                 : 
     196 CBC          45 :     Ok(ConnInfo {
     197              45 :         username,
     198              45 :         dbname: dbname.into(),
     199              45 :         hostname,
     200              45 :         password: password.into(),
     201              45 :         options,
     202              45 :     })
     203              45 : }
     204                 : 
     205                 : // TODO: return different http error codes
     206              45 : pub async fn handle(
     207              45 :     config: &'static HttpConfig,
     208              45 :     ctx: &mut RequestMonitoring,
     209              45 :     request: Request<Body>,
     210              45 :     sni_hostname: Option<String>,
     211              45 :     conn_pool: Arc<GlobalConnPool>,
     212              45 : ) -> Result<Response<Body>, ApiError> {
     213              45 :     let result = tokio::time::timeout(
     214              45 :         config.request_timeout,
     215              45 :         handle_inner(config, ctx, request, sni_hostname, conn_pool),
     216              45 :     )
     217             727 :     .await;
     218              45 :     let mut response = match result {
     219              45 :         Ok(r) => match r {
     220              40 :             Ok(r) => r,
     221               5 :             Err(e) => {
     222               5 :                 let message = format!("{:?}", e);
     223               5 :                 let db_error = e
     224               5 :                     .downcast_ref::<tokio_postgres::Error>()
     225               5 :                     .and_then(|e| e.as_db_error());
     226              65 :                 fn get<'a, T: serde::Serialize>(
     227              65 :                     db: Option<&'a DbError>,
     228              65 :                     x: impl FnOnce(&'a DbError) -> T,
     229              65 :                 ) -> Value {
     230              65 :                     db.map(x)
     231              65 :                         .and_then(|t| serde_json::to_value(t).ok())
     232              65 :                         .unwrap_or_default()
     233              65 :                 }
     234               5 : 
     235               5 :                 // TODO(conrad): db_error.position()
     236               5 :                 let code = get(db_error, |db| db.code().code());
     237               5 :                 let severity = get(db_error, |db| db.severity());
     238               5 :                 let detail = get(db_error, |db| db.detail());
     239               5 :                 let hint = get(db_error, |db| db.hint());
     240               5 :                 let where_ = get(db_error, |db| db.where_());
     241               5 :                 let table = get(db_error, |db| db.table());
     242               5 :                 let column = get(db_error, |db| db.column());
     243               5 :                 let schema = get(db_error, |db| db.schema());
     244               5 :                 let datatype = get(db_error, |db| db.datatype());
     245               5 :                 let constraint = get(db_error, |db| db.constraint());
     246               5 :                 let file = get(db_error, |db| db.file());
     247               5 :                 let line = get(db_error, |db| db.line());
     248               5 :                 let routine = get(db_error, |db| db.routine());
     249                 : 
     250               5 :                 error!(
     251               5 :                     ?code,
     252               5 :                     "sql-over-http per-client task finished with an error: {e:#}"
     253               5 :                 );
     254                 :                 // TODO: this shouldn't always be bad request.
     255               5 :                 json_response(
     256               5 :                     StatusCode::BAD_REQUEST,
     257               5 :                     json!({
     258               5 :                         "message": message,
     259               5 :                         "code": code,
     260               5 :                         "detail": detail,
     261               5 :                         "hint": hint,
     262               5 :                         "severity": severity,
     263               5 :                         "where": where_,
     264               5 :                         "table": table,
     265               5 :                         "column": column,
     266               5 :                         "schema": schema,
     267               5 :                         "datatype": datatype,
     268               5 :                         "constraint": constraint,
     269               5 :                         "file": file,
     270               5 :                         "line": line,
     271               5 :                         "routine": routine,
     272               5 :                     }),
     273               5 :                 )?
     274                 :             }
     275                 :         },
     276                 :         Err(_) => {
     277 UBC           0 :             let message = format!(
     278               0 :                 "HTTP-Connection timed out, execution time exeeded {} seconds",
     279               0 :                 config.request_timeout.as_secs()
     280               0 :             );
     281               0 :             error!(message);
     282               0 :             json_response(
     283               0 :                 StatusCode::GATEWAY_TIMEOUT,
     284               0 :                 json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
     285               0 :             )?
     286                 :         }
     287                 :     };
     288 CBC          45 :     response.headers_mut().insert(
     289              45 :         "Access-Control-Allow-Origin",
     290              45 :         hyper::http::HeaderValue::from_static("*"),
     291              45 :     );
     292              45 :     Ok(response)
     293              45 : }
     294                 : 
     295 UBC           0 : #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
     296                 : async fn handle_inner(
     297                 :     config: &'static HttpConfig,
     298                 :     ctx: &mut RequestMonitoring,
     299                 :     request: Request<Body>,
     300                 :     sni_hostname: Option<String>,
     301                 :     conn_pool: Arc<GlobalConnPool>,
     302                 : ) -> anyhow::Result<Response<Body>> {
     303                 :     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
     304                 :         .with_label_values(&["http"])
     305                 :         .guard();
     306                 : 
     307                 :     //
     308                 :     // Determine the destination and connection params
     309                 :     //
     310                 :     let headers = request.headers();
     311                 :     let conn_info = get_conn_info(ctx, headers, sni_hostname)?;
     312                 : 
     313                 :     // Determine the output options. Default behaviour is 'false'. Anything that is not
     314                 :     // strictly 'true' assumed to be false.
     315                 :     let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
     316                 :     let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
     317                 : 
     318                 :     // Allow connection pooling only if explicitly requested
     319                 :     // or if we have decided that http pool is no longer opt-in
     320                 :     let allow_pool =
     321                 :         !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
     322                 : 
     323                 :     // isolation level, read only and deferrable
     324                 : 
     325                 :     let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
     326                 :     let txn_isolation_level = match txn_isolation_level_raw {
     327                 :         Some(ref x) => Some(match x.as_bytes() {
     328                 :             b"Serializable" => IsolationLevel::Serializable,
     329                 :             b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
     330                 :             b"ReadCommitted" => IsolationLevel::ReadCommitted,
     331                 :             b"RepeatableRead" => IsolationLevel::RepeatableRead,
     332                 :             _ => bail!("invalid isolation level"),
     333                 :         }),
     334                 :         None => None,
     335                 :     };
     336                 : 
     337                 :     let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
     338                 :     let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
     339                 : 
     340                 :     let paused = ctx.latency_timer.pause();
     341                 :     let request_content_length = match request.body().size_hint().upper() {
     342                 :         Some(v) => v,
     343                 :         None => MAX_REQUEST_SIZE + 1,
     344                 :     };
     345                 :     drop(paused);
     346                 : 
     347                 :     // we don't have a streaming request support yet so this is to prevent OOM
     348                 :     // from a malicious user sending an extremely large request body
     349                 :     if request_content_length > MAX_REQUEST_SIZE {
     350                 :         return Err(anyhow::anyhow!(
     351                 :             "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
     352                 :         ));
     353                 :     }
     354                 : 
     355                 :     //
     356                 :     // Read the query and query params from the request body
     357                 :     //
     358                 :     let body = hyper::body::to_bytes(request.into_body()).await?;
     359                 :     let payload: Payload = serde_json::from_slice(&body)?;
     360                 : 
     361                 :     let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
     362                 : 
     363                 :     let mut response = Response::builder()
     364                 :         .status(StatusCode::OK)
     365                 :         .header(header::CONTENT_TYPE, "application/json");
     366                 : 
     367                 :     //
     368                 :     // Now execute the query and return the result
     369                 :     //
     370                 :     let mut size = 0;
     371                 :     let result =
     372                 :         match payload {
     373                 :             Payload::Single(stmt) => {
     374                 :                 let (status, results) =
     375                 :                     query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
     376                 :                         .await
     377 CBC           2 :                         .map_err(|e| {
     378               2 :                             client.discard();
     379               2 :                             e
     380               2 :                         })?;
     381                 :                 client.check_idle(status);
     382                 :                 results
     383                 :             }
     384                 :             Payload::Batch(statements) => {
     385                 :                 let (inner, mut discard) = client.inner();
     386                 :                 let mut builder = inner.build_transaction();
     387                 :                 if let Some(isolation_level) = txn_isolation_level {
     388                 :                     builder = builder.isolation_level(isolation_level);
     389                 :                 }
     390                 :                 if txn_read_only {
     391                 :                     builder = builder.read_only(true);
     392                 :                 }
     393                 :                 if txn_deferrable {
     394                 :                     builder = builder.deferrable(true);
     395                 :                 }
     396                 : 
     397 UBC           0 :                 let transaction = builder.start().await.map_err(|e| {
     398               0 :                     // if we cannot start a transaction, we should return immediately
     399               0 :                     // and not return to the pool. connection is clearly broken
     400               0 :                     discard.discard();
     401               0 :                     e
     402               0 :                 })?;
     403                 : 
     404                 :                 let results =
     405                 :                     match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
     406                 :                         .await
     407                 :                     {
     408                 :                         Ok(results) => {
     409               0 :                             let status = transaction.commit().await.map_err(|e| {
     410               0 :                                 // if we cannot commit - for now don't return connection to pool
     411               0 :                                 // TODO: get a query status from the error
     412               0 :                                 discard.discard();
     413               0 :                                 e
     414               0 :                             })?;
     415                 :                             discard.check_idle(status);
     416                 :                             results
     417                 :                         }
     418                 :                         Err(err) => {
     419               0 :                             let status = transaction.rollback().await.map_err(|e| {
     420               0 :                                 // if we cannot rollback - for now don't return connection to pool
     421               0 :                                 // TODO: get a query status from the error
     422               0 :                                 discard.discard();
     423               0 :                                 e
     424               0 :                             })?;
     425                 :                             discard.check_idle(status);
     426                 :                             return Err(err);
     427                 :                         }
     428                 :                     };
     429                 : 
     430                 :                 if txn_read_only {
     431                 :                     response = response.header(
     432                 :                         TXN_READ_ONLY.clone(),
     433                 :                         HeaderValue::try_from(txn_read_only.to_string())?,
     434                 :                     );
     435                 :                 }
     436                 :                 if txn_deferrable {
     437                 :                     response = response.header(
     438                 :                         TXN_DEFERRABLE.clone(),
     439                 :                         HeaderValue::try_from(txn_deferrable.to_string())?,
     440                 :                     );
     441                 :                 }
     442                 :                 if let Some(txn_isolation_level) = txn_isolation_level_raw {
     443                 :                     response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
     444                 :                 }
     445                 :                 json!({ "results": results })
     446                 :             }
     447                 :         };
     448                 : 
     449                 :     ctx.log();
     450                 :     let metrics = client.metrics();
     451                 : 
     452                 :     // how could this possibly fail
     453                 :     let body = serde_json::to_string(&result).expect("json serialization should not fail");
     454                 :     let len = body.len();
     455                 :     let response = response
     456                 :         .body(Body::from(body))
     457                 :         // only fails if invalid status code or invalid header/values are given.
     458                 :         // these are not user configurable so it cannot fail dynamically
     459                 :         .expect("building response payload should not fail");
     460                 : 
     461                 :     // count the egress bytes - we miss the TLS and header overhead but oh well...
     462                 :     // moving this later in the stack is going to be a lot of effort and ehhhh
     463                 :     metrics.record_egress(len as u64);
     464                 : 
     465                 :     Ok(response)
     466                 : }
     467                 : 
     468 CBC           2 : async fn query_batch(
     469               2 :     transaction: &Transaction<'_>,
     470               2 :     queries: BatchQueryData,
     471               2 :     total_size: &mut usize,
     472               2 :     raw_output: bool,
     473               2 :     array_mode: bool,
     474               2 : ) -> anyhow::Result<Vec<Value>> {
     475               2 :     let mut results = Vec::with_capacity(queries.queries.len());
     476               2 :     let mut current_size = 0;
     477              13 :     for stmt in queries.queries {
     478              11 :         // TODO: maybe we should check that the transaction bit is set here
     479              11 :         let (_, values) =
     480              11 :             query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
     481              11 :         results.push(values);
     482                 :     }
     483               2 :     *total_size += current_size;
     484               2 :     Ok(results)
     485               2 : }
     486                 : 
     487              51 : async fn query_to_json<T: GenericClient>(
     488              51 :     client: &T,
     489              51 :     data: QueryData,
     490              51 :     current_size: &mut usize,
     491              51 :     raw_output: bool,
     492              51 :     array_mode: bool,
     493              51 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
     494              51 :     let query_params = json_to_pg_text(data.params);
     495              51 :     let row_stream = client.query_raw_txt(&data.query, query_params).await?;
     496                 : 
     497                 :     // Manually drain the stream into a vector to leave row_stream hanging
     498                 :     // around to get a command tag. Also check that the response is not too
     499                 :     // big.
     500              50 :     pin_mut!(row_stream);
     501              50 :     let mut rows: Vec<tokio_postgres::Row> = Vec::new();
     502          506404 :     while let Some(row) = row_stream.next().await {
     503          506355 :         let row = row?;
     504          506355 :         *current_size += row.body_len();
     505          506355 :         rows.push(row);
     506          506355 :         // we don't have a streaming response support yet so this is to prevent OOM
     507          506355 :         // from a malicious query (eg a cross join)
     508          506355 :         if *current_size > MAX_RESPONSE_SIZE {
     509               1 :             return Err(anyhow::anyhow!(
     510               1 :                 "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
     511               1 :             ));
     512          506354 :         }
     513                 :     }
     514                 : 
     515              49 :     let ready = row_stream.ready_status();
     516              49 : 
     517              49 :     // grab the command tag and number of rows affected
     518              49 :     let command_tag = row_stream.command_tag().unwrap_or_default();
     519              49 :     let mut command_tag_split = command_tag.split(' ');
     520              49 :     let command_tag_name = command_tag_split.next().unwrap_or_default();
     521              49 :     let command_tag_count = if command_tag_name == "INSERT" {
     522                 :         // INSERT returns OID first and then number of rows
     523               2 :         command_tag_split.nth(1)
     524                 :     } else {
     525                 :         // other commands return number of rows (if any)
     526              47 :         command_tag_split.next()
     527                 :     }
     528              49 :     .and_then(|s| s.parse::<i64>().ok());
     529              49 : 
     530              49 :     let mut fields = vec![];
     531              49 :     let mut columns = vec![];
     532                 : 
     533             117 :     for c in row_stream.columns() {
     534             117 :         fields.push(json!({
     535             117 :             "name": Value::String(c.name().to_owned()),
     536             117 :             "dataTypeID": Value::Number(c.type_().oid().into()),
     537             117 :             "tableID": c.table_oid(),
     538             117 :             "columnID": c.column_id(),
     539             117 :             "dataTypeSize": c.type_size(),
     540             117 :             "dataTypeModifier": c.type_modifier(),
     541             117 :             "format": "text",
     542             117 :         }));
     543             117 :         columns.push(client.get_type(c.type_oid()).await?);
     544                 :     }
     545                 : 
     546                 :     // convert rows to JSON
     547              49 :     let rows = rows
     548              49 :         .iter()
     549              51 :         .map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
     550              49 :         .collect::<Result<Vec<_>, _>>()?;
     551                 : 
     552                 :     // resulting JSON format is based on the format of node-postgres result
     553              49 :     Ok((
     554              49 :         ready,
     555              49 :         json!({
     556              49 :             "command": command_tag_name,
     557              49 :             "rowCount": command_tag_count,
     558              49 :             "rows": rows,
     559              49 :             "fields": fields,
     560              49 :             "rowAsArray": array_mode,
     561              49 :         }),
     562              49 :     ))
     563              51 : }
     564                 : 
     565                 : //
     566                 : // Convert postgres row with text-encoded values to JSON object
     567                 : //
     568              51 : pub fn pg_text_row_to_json(
     569              51 :     row: &Row,
     570              51 :     columns: &[Type],
     571              51 :     raw_output: bool,
     572              51 :     array_mode: bool,
     573              51 : ) -> Result<Value, anyhow::Error> {
     574              51 :     let iter = row
     575              51 :         .columns()
     576              51 :         .iter()
     577              51 :         .zip(columns)
     578              51 :         .enumerate()
     579             129 :         .map(|(i, (column, typ))| {
     580             129 :             let name = column.name();
     581             129 :             let pg_value = row.as_text(i)?;
     582             129 :             let json_value = if raw_output {
     583               6 :                 match pg_value {
     584               6 :                     Some(v) => Value::String(v.to_string()),
     585 UBC           0 :                     None => Value::Null,
     586                 :                 }
     587                 :             } else {
     588 CBC         123 :                 pg_text_to_json(pg_value, typ)?
     589                 :             };
     590             129 :             Ok((name.to_string(), json_value))
     591             129 :         });
     592              51 : 
     593              51 :     if array_mode {
     594                 :         // drop keys and aggregate into array
     595               2 :         let arr = iter
     596               6 :             .map(|r| r.map(|(_key, val)| val))
     597               2 :             .collect::<Result<Vec<Value>, anyhow::Error>>()?;
     598               2 :         Ok(Value::Array(arr))
     599                 :     } else {
     600              49 :         let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
     601              49 :         Ok(Value::Object(obj))
     602                 :     }
     603              51 : }
     604                 : 
     605                 : //
     606                 : // Convert postgres text-encoded value to JSON value
     607                 : //
     608             225 : pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
     609             225 :     if let Some(val) = pg_value {
     610             211 :         if let Kind::Array(elem_type) = pg_type.kind() {
     611               7 :             return pg_array_parse(val, elem_type);
     612             204 :         }
     613             204 : 
     614             204 :         match *pg_type {
     615              30 :             Type::BOOL => Ok(Value::Bool(val == "t")),
     616                 :             Type::INT2 | Type::INT4 => {
     617              67 :                 let val = val.parse::<i32>()?;
     618              67 :                 Ok(Value::Number(serde_json::Number::from(val)))
     619                 :             }
     620                 :             Type::FLOAT4 | Type::FLOAT8 => {
     621              25 :                 let fval = val.parse::<f64>()?;
     622              25 :                 let num = serde_json::Number::from_f64(fval);
     623              25 :                 if let Some(num) = num {
     624              16 :                     Ok(Value::Number(num))
     625                 :                 } else {
     626                 :                     // Pass Nan, Inf, -Inf as strings
     627                 :                     // JS JSON.stringify() does converts them to null, but we
     628                 :                     // want to preserve them, so we pass them as strings
     629               9 :                     Ok(Value::String(val.to_string()))
     630                 :                 }
     631                 :             }
     632              12 :             Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
     633              70 :             _ => Ok(Value::String(val.to_string())),
     634                 :         }
     635                 :     } else {
     636              14 :         Ok(Value::Null)
     637                 :     }
     638             225 : }
     639                 : 
     640                 : //
     641                 : // Parse postgres array into JSON array.
     642                 : //
     643                 : // This is a bit involved because we need to handle nested arrays and quoted
     644                 : // values. Unlike postgres we don't check that all nested arrays have the same
     645                 : // dimensions, we just return them as is.
     646                 : //
     647              29 : fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
     648              29 :     _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
     649              29 : }
     650                 : 
     651              46 : fn _pg_array_parse(
     652              46 :     pg_array: &str,
     653              46 :     elem_type: &Type,
     654              46 :     nested: bool,
     655              46 : ) -> Result<(Value, usize), anyhow::Error> {
     656              46 :     let mut pg_array_chr = pg_array.char_indices();
     657              46 :     let mut level = 0;
     658              46 :     let mut quote = false;
     659              46 :     let mut entries: Vec<Value> = Vec::new();
     660              46 :     let mut entry = String::new();
     661              46 : 
     662              46 :     // skip bounds decoration
     663              46 :     if let Some('[') = pg_array.chars().next() {
     664              18 :         for (_, c) in pg_array_chr.by_ref() {
     665              18 :             if c == '=' {
     666               1 :                 break;
     667              17 :             }
     668                 :         }
     669              45 :     }
     670                 : 
     671             108 :     fn push_checked(
     672             108 :         entry: &mut String,
     673             108 :         entries: &mut Vec<Value>,
     674             108 :         elem_type: &Type,
     675             108 :     ) -> Result<(), anyhow::Error> {
     676             108 :         if !entry.is_empty() {
     677                 :             // While in usual postgres response we get nulls as None and everything else
     678                 :             // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
     679                 :             // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
     680                 :             // here while we have quotation info and convert them to None.
     681              73 :             if entry == "NULL" {
     682               7 :                 entries.push(pg_text_to_json(None, elem_type)?);
     683                 :             } else {
     684              66 :                 entries.push(pg_text_to_json(Some(entry), elem_type)?);
     685                 :             }
     686              73 :             entry.clear();
     687              35 :         }
     688                 : 
     689             108 :         Ok(())
     690             108 :     }
     691                 : 
     692             518 :     while let Some((mut i, mut c)) = pg_array_chr.next() {
     693             489 :         let mut escaped = false;
     694             489 : 
     695             489 :         if c == '\\' {
     696              19 :             escaped = true;
     697              19 :             (i, c) = pg_array_chr.next().unwrap();
     698             470 :         }
     699                 : 
     700             261 :         match c {
     701              63 :             '{' if !quote => {
     702              63 :                 level += 1;
     703              63 :                 if level > 1 {
     704              17 :                     let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
     705              17 :                     entries.push(res);
     706             195 :                     for _ in 0..off - 1 {
     707             195 :                         pg_array_chr.next();
     708             195 :                     }
     709              46 :                 }
     710                 :             }
     711              63 :             '}' if !quote => {
     712              63 :                 level -= 1;
     713              63 :                 if level == 0 {
     714              46 :                     push_checked(&mut entry, &mut entries, elem_type)?;
     715              46 :                     if nested {
     716              17 :                         return Ok((Value::Array(entries), i));
     717              29 :                     }
     718              17 :                 }
     719                 :             }
     720              36 :             '"' if !escaped => {
     721              36 :                 if quote {
     722                 :                     // end of quoted string, so push it manually without any checks
     723                 :                     // for emptiness or nulls
     724              18 :                     entries.push(pg_text_to_json(Some(&entry), elem_type)?);
     725              18 :                     entry.clear();
     726              18 :                 }
     727              36 :                 quote = !quote;
     728                 :             }
     729              62 :             ',' if !quote => {
     730              62 :                 push_checked(&mut entry, &mut entries, elem_type)?;
     731                 :             }
     732             265 :             _ => {
     733             265 :                 entry.push(c);
     734             265 :             }
     735                 :         }
     736                 :     }
     737                 : 
     738              29 :     if level != 0 {
     739 UBC           0 :         return Err(anyhow::anyhow!("unbalanced array"));
     740 CBC          29 :     }
     741              29 : 
     742              29 :     Ok((Value::Array(entries), 0))
     743              46 : }
     744                 : 
     745                 : #[cfg(test)]
     746                 : mod tests {
     747                 :     use super::*;
     748                 :     use serde_json::json;
     749                 : 
     750               1 :     #[test]
     751               1 :     fn test_atomic_types_to_pg_params() {
     752               1 :         let json = vec![Value::Bool(true), Value::Bool(false)];
     753               1 :         let pg_params = json_to_pg_text(json);
     754               1 :         assert_eq!(
     755               1 :             pg_params,
     756               1 :             vec![Some("true".to_owned()), Some("false".to_owned())]
     757               1 :         );
     758                 : 
     759               1 :         let json = vec![Value::Number(serde_json::Number::from(42))];
     760               1 :         let pg_params = json_to_pg_text(json);
     761               1 :         assert_eq!(pg_params, vec![Some("42".to_owned())]);
     762                 : 
     763               1 :         let json = vec![Value::String("foo\"".to_string())];
     764               1 :         let pg_params = json_to_pg_text(json);
     765               1 :         assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
     766                 : 
     767               1 :         let json = vec![Value::Null];
     768               1 :         let pg_params = json_to_pg_text(json);
     769               1 :         assert_eq!(pg_params, vec![None]);
     770               1 :     }
     771                 : 
     772               1 :     #[test]
     773               1 :     fn test_json_array_to_pg_array() {
     774               1 :         // atoms and escaping
     775               1 :         let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
     776               1 :         let json: Value = serde_json::from_str(json).unwrap();
     777               1 :         let pg_params = json_to_pg_text(vec![json]);
     778               1 :         assert_eq!(
     779               1 :             pg_params,
     780               1 :             vec![Some(
     781               1 :                 "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
     782               1 :             )]
     783               1 :         );
     784                 : 
     785                 :         // nested arrays
     786               1 :         let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
     787               1 :         let json: Value = serde_json::from_str(json).unwrap();
     788               1 :         let pg_params = json_to_pg_text(vec![json]);
     789               1 :         assert_eq!(
     790               1 :             pg_params,
     791               1 :             vec![Some(
     792               1 :                 "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
     793               1 :             )]
     794               1 :         );
     795                 :         // array of objects
     796               1 :         let json = r#"[{"foo": 1},{"bar": 2}]"#;
     797               1 :         let json: Value = serde_json::from_str(json).unwrap();
     798               1 :         let pg_params = json_to_pg_text(vec![json]);
     799               1 :         assert_eq!(
     800               1 :             pg_params,
     801               1 :             vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
     802               1 :         );
     803               1 :     }
     804                 : 
     805               1 :     #[test]
     806               1 :     fn test_atomic_types_parse() {
     807               1 :         assert_eq!(
     808               1 :             pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
     809               1 :             json!("foo")
     810               1 :         );
     811               1 :         assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
     812               1 :         assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
     813               1 :         assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
     814               1 :         assert_eq!(
     815               1 :             pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
     816               1 :             json!("42")
     817               1 :         );
     818               1 :         assert_eq!(
     819               1 :             pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
     820               1 :             json!(42.42)
     821               1 :         );
     822               1 :         assert_eq!(
     823               1 :             pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
     824               1 :             json!(42.42)
     825               1 :         );
     826               1 :         assert_eq!(
     827               1 :             pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
     828               1 :             json!("NaN")
     829               1 :         );
     830               1 :         assert_eq!(
     831               1 :             pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
     832               1 :             json!("Infinity")
     833               1 :         );
     834               1 :         assert_eq!(
     835               1 :             pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
     836               1 :             json!("-Infinity")
     837               1 :         );
     838                 : 
     839               1 :         let json: Value =
     840               1 :             serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
     841               1 :                 .unwrap();
     842               1 :         assert_eq!(
     843               1 :             pg_text_to_json(
     844               1 :                 Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
     845               1 :                 &Type::JSONB
     846               1 :             )
     847               1 :             .unwrap(),
     848               1 :             json
     849               1 :         );
     850               1 :     }
     851                 : 
     852               1 :     #[test]
     853               1 :     fn test_pg_array_parse_text() {
     854               4 :         fn pt(pg_arr: &str) -> Value {
     855               4 :             pg_array_parse(pg_arr, &Type::TEXT).unwrap()
     856               4 :         }
     857               1 :         assert_eq!(
     858               1 :             pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
     859               1 :             json!(["aa\"\\,a", "cha", "bbbb"])
     860               1 :         );
     861               1 :         assert_eq!(
     862               1 :             pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
     863               1 :             json!([["foo", "bar"], ["bee", "bop"]])
     864               1 :         );
     865               1 :         assert_eq!(
     866               1 :             pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
     867               1 :             json!([[[["foo", null, "bop", "bup"]]]])
     868               1 :         );
     869               1 :         assert_eq!(
     870               1 :             pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
     871               1 :             json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
     872               1 :         );
     873               1 :     }
     874                 : 
     875               1 :     #[test]
     876               1 :     fn test_pg_array_parse_bool() {
     877               4 :         fn pb(pg_arr: &str) -> Value {
     878               4 :             pg_array_parse(pg_arr, &Type::BOOL).unwrap()
     879               4 :         }
     880               1 :         assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
     881               1 :         assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
     882               1 :         assert_eq!(
     883               1 :             pb(r#"{{t,f},{f,t}}"#),
     884               1 :             json!([[true, false], [false, true]])
     885               1 :         );
     886               1 :         assert_eq!(
     887               1 :             pb(r#"{{t,NULL},{NULL,f}}"#),
     888               1 :             json!([[true, null], [null, false]])
     889               1 :         );
     890               1 :     }
     891                 : 
     892               1 :     #[test]
     893               1 :     fn test_pg_array_parse_numbers() {
     894               9 :         fn pn(pg_arr: &str, ty: &Type) -> Value {
     895               9 :             pg_array_parse(pg_arr, ty).unwrap()
     896               9 :         }
     897               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
     898               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
     899               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
     900               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
     901               1 :         assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
     902               1 :         assert_eq!(
     903               1 :             pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
     904               1 :             json!([1.1, 2.2, 3.3])
     905               1 :         );
     906               1 :         assert_eq!(
     907               1 :             pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
     908               1 :             json!([1.1, 2.2, 3.3])
     909               1 :         );
     910               1 :         assert_eq!(
     911               1 :             pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
     912               1 :             json!(["NaN", "Infinity", "-Infinity"])
     913               1 :         );
     914               1 :         assert_eq!(
     915               1 :             pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
     916               1 :             json!(["NaN", "Infinity", "-Infinity"])
     917               1 :         );
     918               1 :     }
     919                 : 
     920               1 :     #[test]
     921               1 :     fn test_pg_array_with_decoration() {
     922               1 :         fn p(pg_arr: &str) -> Value {
     923               1 :             pg_array_parse(pg_arr, &Type::INT2).unwrap()
     924               1 :         }
     925               1 :         assert_eq!(
     926               1 :             p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
     927               1 :             json!([[[1, 2, 3], [4, 5, 6]]])
     928               1 :         );
     929               1 :     }
     930               1 :     #[test]
     931               1 :     fn test_pg_array_parse_json() {
     932               4 :         fn pt(pg_arr: &str) -> Value {
     933               4 :             pg_array_parse(pg_arr, &Type::JSONB).unwrap()
     934               4 :         }
     935               1 :         assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
     936               1 :         assert_eq!(
     937               1 :             pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
     938               1 :             json!([{"foo": 1, "bar": 2}])
     939               1 :         );
     940               1 :         assert_eq!(
     941               1 :             pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
     942               1 :             json!([{"foo": 1}, {"bar": 2}])
     943               1 :         );
     944               1 :         assert_eq!(
     945               1 :             pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
     946               1 :             json!([[{"foo": 1}, {"bar": 2}]])
     947               1 :         );
     948               1 :     }
     949                 : }
        

Generated by: LCOV version 2.1-beta