LCOV - code coverage report
Current view: top level - proxy/src/serverless - sql_over_http.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 86.3 % 313 270
Test Date: 2024-02-14 18:05:35 Functions: 61.6 % 138 85

            Line data    Source code
       1              : use std::sync::Arc;
       2              : 
       3              : use anyhow::bail;
       4              : use futures::pin_mut;
       5              : use futures::StreamExt;
       6              : use hyper::body::HttpBody;
       7              : use hyper::header;
       8              : use hyper::http::HeaderName;
       9              : use hyper::http::HeaderValue;
      10              : use hyper::Response;
      11              : use hyper::StatusCode;
      12              : use hyper::{Body, HeaderMap, Request};
      13              : use serde_json::json;
      14              : use serde_json::Value;
      15              : use tokio::join;
      16              : use tokio_postgres::error::DbError;
      17              : use tokio_postgres::error::ErrorPosition;
      18              : use tokio_postgres::GenericClient;
      19              : use tokio_postgres::IsolationLevel;
      20              : use tokio_postgres::ReadyForQueryStatus;
      21              : use tokio_postgres::Transaction;
      22              : use tracing::error;
      23              : use tracing::info;
      24              : use tracing::instrument;
      25              : use url::Url;
      26              : use utils::http::error::ApiError;
      27              : use utils::http::json::json_response;
      28              : 
      29              : use crate::auth::backend::ComputeUserInfo;
      30              : use crate::auth::endpoint_sni;
      31              : use crate::auth::ComputeUserInfoParseError;
      32              : use crate::config::ProxyConfig;
      33              : use crate::config::TlsConfig;
      34              : use crate::context::RequestMonitoring;
      35              : use crate::error::ReportableError;
      36              : use crate::metrics::HTTP_CONTENT_LENGTH;
      37              : use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
      38              : use crate::proxy::NeonOptions;
      39              : use crate::serverless::backend::HttpConnError;
      40              : use crate::DbName;
      41              : use crate::RoleName;
      42              : 
      43              : use super::backend::PoolingBackend;
      44              : use super::conn_pool::ConnInfo;
      45              : use super::json::json_to_pg_text;
      46              : use super::json::pg_text_row_to_json;
      47              : 
      /// One SQL statement as it appears in the JSON request body.
      ///
      /// Keys are read in camelCase because of `rename_all`, so `array_mode` is
      /// supplied as `arrayMode`.
      48          355 : #[derive(serde::Deserialize)]
      49              : #[serde(rename_all = "camelCase")]
      50              : struct QueryData {
      51              :     query: String, // SQL text handed to `query_raw_txt`
      52              :     #[serde(deserialize_with = "bytes_to_pg_text")]
      53              :     params: Vec<Option<String>>, // JSON array converted by `bytes_to_pg_text`
      54              :     #[serde(default)]
      55              :     array_mode: Option<bool>, // `None` when the key is absent
      56              : }
      57              : 
      /// Request body for a batch: a JSON object whose `queries` key holds the statements.
      ///
      /// `handle_inner` runs the statements of a batch inside one transaction.
      58            9 : #[derive(serde::Deserialize)]
      59              : struct BatchQueryData {
      60              :     queries: Vec<QueryData>,
      61              : }
      62              : 
      /// Request body accepted by the handler: either a single statement or a batch.
      ///
      /// `untagged` means the JSON carries no discriminator; per the serde documentation the
      /// variants are tried in declaration order, so `Single` is attempted before `Batch`.
      63           50 : #[derive(serde::Deserialize)]
      64              : #[serde(untagged)]
      65              : enum Payload {
      66              :     Single(QueryData),
      67              :     Batch(BatchQueryData),
      68              : }
      69              : 
      // Body size limits. `MAX_REQUEST_SIZE` is compared against the request body's size hint
      // in `handle_inner`; `MAX_RESPONSE_SIZE` bounds the accumulated row bytes in `query_to_json`.
      70              : const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
      71              : const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
      72              : 
      // Request headers read by `handle_inner`. The first three select raw text output, the
      // default array mode and connection pooling; the `neon-batch-*` headers configure the
      // transaction used for a batch and are echoed back on a successful batch response.
      73              : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
      74              : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
      75              : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
      76              : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
      77              : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
      78              : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
      79              : 
      // The only value that switches a boolean header on; compared for exact equality.
      80              : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
      81              : 
      /// Serde `deserialize_with` adapter used for `QueryData::params`.
      ///
      /// Deserializes the field as a list of JSON values and converts that list with
      /// `json_to_pg_text` into the `Vec<Option<String>>` stored in `params`.
      ///
      /// # Errors
      /// Returns the deserializer's error when the field cannot be read as `Vec<Value>`.
      82           57 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
      83           57 : where
      84           57 :     D: serde::de::Deserializer<'de>,
      85           57 : {
      86              :     // TODO: consider avoiding the allocation here.
      87           57 :     let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
      88           57 :     Ok(json_to_pg_text(json))
      89           57 : }
      90              : 
      /// Reasons `get_conn_info` can reject the `Neon-Connection-String` request header.
      ///
      /// The variants marked `#[from]` are produced by `?` on the underlying parser errors;
      /// the others are returned explicitly when a URL component is missing or unacceptable.
      91            0 : #[derive(Debug, thiserror::Error)]
      92              : pub enum ConnInfoError {
      93              :     #[error("invalid header: {0}")]
      94              :     InvalidHeader(&'static str), // header missing or not convertible with `to_str`
      95              :     #[error("invalid connection string: {0}")]
      96              :     UrlParseError(#[from] url::ParseError),
      97              :     #[error("incorrect scheme")]
      98              :     IncorrectScheme, // scheme is neither `postgres` nor `postgresql`
      99              :     #[error("missing database name")]
     100              :     MissingDbName,
     101              :     #[error("invalid database name")]
     102              :     InvalidDbName,
     103              :     #[error("missing username")]
     104              :     MissingUsername,
     105              :     #[error("invalid username: {0}")]
     106              :     InvalidUsername(#[from] std::string::FromUtf8Error),
     107              :     #[error("missing password")]
     108              :     MissingPassword,
     109              :     #[error("missing hostname")]
     110              :     MissingHostname,
     111              :     #[error("invalid hostname: {0}")]
     112              :     InvalidEndpoint(#[from] ComputeUserInfoParseError),
     113              :     #[error("malformed endpoint")]
     114              :     MalformedEndpoint,
     115              : }
     116              : 
     /// Builds a `ConnInfo` from the `Neon-Connection-String` request header.
     ///
     /// The header must hold a `postgres://` or `postgresql://` URL. The user name, password,
     /// host and first path segment (the database name) are extracted, the host is mapped to an
     /// endpoint by `endpoint_sni` using `tls.common_names`, and the first `options` query
     /// parameter, if present, is parsed with `NeonOptions::parse_options_raw`.
     ///
     /// As a side effect the auth method, database name, user and endpoint id are recorded
     /// on `ctx`.
     ///
     /// # Errors
     /// Returns the `ConnInfoError` variant for the first component that is missing or malformed.
     ///
     /// NOTE(review): the user name and password are percent-decoded, but the database name is
     /// used exactly as it appears in the URL path — confirm this is intended.
     /// NOTE(review): the `url` crate documents that `path_segments()` yields at least one
     /// (possibly empty) segment when it returns `Some`, so `InvalidDbName` looks unreachable
     /// and an empty database name would pass this function — TODO confirm.
     117           47 : fn get_conn_info(
     118           47 :     ctx: &mut RequestMonitoring,
     119           47 :     headers: &HeaderMap,
     120           47 :     tls: &TlsConfig,
     121           47 : ) -> Result<ConnInfo, ConnInfoError> {
     122           47 :     // HTTP only uses cleartext (for now and likely always)
     123           47 :     ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
     124              : 
     125           47 :     let connection_string = headers
     126           47 :         .get("Neon-Connection-String")
     127           47 :         .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
     128           47 :         .to_str()
     129           47 :         .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
     130              : 
     131           47 :     let connection_url = Url::parse(connection_string)?;
     132              : 
     133           47 :     let protocol = connection_url.scheme();
     134           47 :     if protocol != "postgres" && protocol != "postgresql" {
     135            0 :         return Err(ConnInfoError::IncorrectScheme);
     136           47 :     }
     137              : 
     138           47 :     let mut url_path = connection_url
     139           47 :         .path_segments()
     140           47 :         .ok_or(ConnInfoError::MissingDbName)?;
     141              : 
     142           47 :     let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into(); // first path segment only
     143           47 :     ctx.set_dbname(dbname.clone());
     144              : 
     145           47 :     let username = RoleName::from(urlencoding::decode(connection_url.username())?);
     146           47 :     if username.is_empty() {
     147            0 :         return Err(ConnInfoError::MissingUsername);
     148           47 :     }
     149           47 :     ctx.set_user(username.clone());
     150              : 
     151           47 :     let password = connection_url
     152           47 :         .password()
     153           47 :         .ok_or(ConnInfoError::MissingPassword)?;
     154           47 :     let password = urlencoding::decode_binary(password.as_bytes()); // decoded as bytes, not as a string
     155              : 
     156           47 :     let hostname = connection_url
     157           47 :         .host_str()
     158           47 :         .ok_or(ConnInfoError::MissingHostname)?;
     159              : 
     160           47 :     let endpoint =
     161           47 :         endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
     162           47 :     ctx.set_endpoint_id(endpoint.clone());
     163           47 : 
     164           47 :     let pairs = connection_url.query_pairs();
     165           47 : 
     166           47 :     let mut options = Option::None;
     167              : 
     168           47 :     for (key, value) in pairs {
     169            0 :         if key == "options" {
     170            0 :             options = Some(NeonOptions::parse_options_raw(&value));
     171            0 :             break; // only the first `options` pair is used
     172            0 :         }
     173              :     }
     174              : 
     175           47 :     let user_info = ComputeUserInfo {
     176           47 :         endpoint,
     177           47 :         user: username,
     178           47 :         options: options.unwrap_or_default(),
     179           47 :     };
     180           47 : 
     181           47 :     Ok(ConnInfo {
     182           47 :         user_info,
     183           47 :         dbname,
     184           47 :         password: match password {
     185           46 :             std::borrow::Cow::Borrowed(b) => b.into(),
     186            1 :             std::borrow::Cow::Owned(b) => b.into(),
     187              :         },
     188              :     })
     189           47 : }
     190              : 
     /// Entry point for one SQL-over-HTTP request.
     ///
     /// Runs `handle_inner` under `config.http_config.request_timeout` and turns the outcome
     /// into an HTTP response:
     /// - success: the inner response is returned and `ctx` is marked successful;
     /// - inner error: a `400 Bad Request` JSON body holding `message` plus the Postgres error
     ///   fields (`code`, `detail`, `hint`, ...), which are JSON null unless the error is a
     ///   `tokio_postgres` database error;
     /// - timeout: a `504 Gateway Timeout` JSON body with `message` and the numeric `code`.
     ///
     /// Every response gets the header `Access-Control-Allow-Origin: *`.
     ///
     /// # Errors
     /// Returns `ApiError` only when `json_response` fails while building an error body.
     ///
     /// NOTE(review): the timeout message below spells "exeeded"; it is a runtime string and is
     /// left untouched here.
     191              : // TODO: return different http error codes
     192           47 : pub async fn handle(
     193           47 :     config: &'static ProxyConfig,
     194           47 :     mut ctx: RequestMonitoring,
     195           47 :     request: Request<Body>,
     196           47 :     backend: Arc<PoolingBackend>,
     197           47 : ) -> Result<Response<Body>, ApiError> {
     198           47 :     let result = tokio::time::timeout(
     199           47 :         config.http_config.request_timeout,
     200           47 :         handle_inner(config, &mut ctx, request, backend),
     201           47 :     )
     202          980 :     .await;
     203           47 :     let mut response = match result {
     204           47 :         Ok(r) => match r {
     205           42 :             Ok(r) => {
     206           42 :                 ctx.set_success();
     207           42 :                 r
     208              :             }
     209            5 :             Err(e) => {
     210            5 :                 // TODO: ctx.set_error_kind(e.get_error_type());
     211            5 : 
     212            5 :                 let mut message = format!("{:?}", e); // fallback text; replaced below for database errors
     213            5 :                 let db_error = e
     214            5 :                     .downcast_ref::<tokio_postgres::Error>()
     215            5 :                     .and_then(|e| e.as_db_error());
     216           65 :                 fn get<'a, T: serde::Serialize>( // one DbError field as JSON; Null without a DbError
     217           65 :                     db: Option<&'a DbError>,
     218           65 :                     x: impl FnOnce(&'a DbError) -> T,
     219           65 :                 ) -> Value {
     220           65 :                     db.map(x)
     221           65 :                         .and_then(|t| serde_json::to_value(t).ok())
     222           65 :                         .unwrap_or_default()
     223           65 :                 }
     224              : 
     225            5 :                 if let Some(db_error) = db_error {
     226            1 :                     db_error.message().clone_into(&mut message);
     227            4 :                 }
     228              : 
     229            5 :                 let position = db_error.and_then(|db| db.position());
     230            5 :                 let (position, internal_position, internal_query) = match position {
     231            1 :                     Some(ErrorPosition::Original(position)) => (
     232            1 :                         Value::String(position.to_string()),
     233            1 :                         Value::Null,
     234            1 :                         Value::Null,
     235            1 :                     ),
     236            0 :                     Some(ErrorPosition::Internal { position, query }) => (
     237            0 :                         Value::Null,
     238            0 :                         Value::String(position.to_string()),
     239            0 :                         Value::String(query.clone()),
     240            0 :                     ),
     241            4 :                     None => (Value::Null, Value::Null, Value::Null),
     242              :                 };
     243              : 
     244            5 :                 let code = get(db_error, |db| db.code().code());
     245            5 :                 let severity = get(db_error, |db| db.severity());
     246            5 :                 let detail = get(db_error, |db| db.detail());
     247            5 :                 let hint = get(db_error, |db| db.hint());
     248            5 :                 let where_ = get(db_error, |db| db.where_());
     249            5 :                 let table = get(db_error, |db| db.table());
     250            5 :                 let column = get(db_error, |db| db.column());
     251            5 :                 let schema = get(db_error, |db| db.schema());
     252            5 :                 let datatype = get(db_error, |db| db.datatype());
     253            5 :                 let constraint = get(db_error, |db| db.constraint());
     254            5 :                 let file = get(db_error, |db| db.file());
     255            5 :                 let line = get(db_error, |db| db.line().map(|l| l.to_string()));
     256            5 :                 let routine = get(db_error, |db| db.routine());
     257              : 
     258            5 :                 error!(
     259            5 :                     ?code,
     260            5 :                     "sql-over-http per-client task finished with an error: {e:#}"
     261            5 :                 );
     262              :                 // TODO: this shouldn't always be bad request.
     263            5 :                 json_response(
     264            5 :                     StatusCode::BAD_REQUEST,
     265            5 :                     json!({
     266            5 :                         "message": message,
     267            5 :                         "code": code,
     268            5 :                         "detail": detail,
     269            5 :                         "hint": hint,
     270            5 :                         "position": position,
     271            5 :                         "internalPosition": internal_position,
     272            5 :                         "internalQuery": internal_query,
     273            5 :                         "severity": severity,
     274            5 :                         "where": where_,
     275            5 :                         "table": table,
     276            5 :                         "column": column,
     277            5 :                         "schema": schema,
     278            5 :                         "dataType": datatype,
     279            5 :                         "constraint": constraint,
     280            5 :                         "file": file,
     281            5 :                         "line": line,
     282            5 :                         "routine": routine,
     283            5 :                     }),
     284            5 :                 )?
     285              :             }
     286              :         },
     287            0 :         Err(e) => { // the timeout elapsed before `handle_inner` finished
     288            0 :             ctx.set_error_kind(e.get_error_kind());
     289            0 : 
     290            0 :             let message = format!(
     291            0 :                 "HTTP-Connection timed out, execution time exeeded {} seconds",
     292            0 :                 config.http_config.request_timeout.as_secs()
     293            0 :             );
     294            0 :             error!(message);
     295            0 :             json_response(
     296            0 :                 StatusCode::GATEWAY_TIMEOUT,
     297            0 :                 json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
     298            0 :             )?
     299              :         }
     300              :     };
     301              : 
     302           47 :     response.headers_mut().insert(
     303           47 :         "Access-Control-Allow-Origin",
     304           47 :         hyper::http::HeaderValue::from_static("*"),
     305           47 :     );
     306           47 :     Ok(response)
     307           47 : }
     308              : 
     /// Processes a single SQL-over-HTTP request and builds the success response.
     ///
     /// In order:
     /// 1. Parses the target connection from the request headers with `get_conn_info`.
     /// 2. Reads the option headers. A boolean option counts as set only when its header value is
     ///    exactly `true`; the isolation level must be one of `Serializable`, `ReadUncommitted`,
     ///    `ReadCommitted` or `RepeatableRead`.
     /// 3. Rejects the request when the body's upper size hint exceeds `MAX_REQUEST_SIZE`, or
     ///    when no upper bound is known.
     /// 4. Reads and parses the JSON body while authenticating and connecting to compute; both
     ///    futures are driven together by `join!`.
     /// 5. Runs a `Payload::Single` statement directly on the client, or a `Payload::Batch` inside
     ///    a transaction built from the `neon-batch-*` headers. The transaction is committed when
     ///    every statement succeeds and rolled back otherwise.
     /// 6. Serializes the result to JSON and records the body length as egress.
     ///
     /// # Errors
     /// Any failure above is returned as `anyhow::Error`; `handle` turns it into the error body.
     ///
     /// # Panics
     /// Panics if either `expect` on response construction fails.
     /// NOTE(review): `config.tls_config.as_ref().unwrap()` panics when no TLS config is set; the
     /// existing comment says it should always be present here — TODO confirm.
     ///
     /// NOTE(review): the single-statement path passes a fresh `&mut 0` as the size counter and
     /// `size` is never read in this function — confirm whether it is still needed.
     309           94 : #[instrument(
     310              :     name = "sql-over-http",
     311              :     skip_all,
     312              :     fields(
     313              :         pid = tracing::field::Empty,
     314              :         conn_id = tracing::field::Empty
     315              :     )
     316              : )]
     317              : async fn handle_inner(
     318              :     config: &'static ProxyConfig,
     319              :     ctx: &mut RequestMonitoring,
     320              :     request: Request<Body>,
     321              :     backend: Arc<PoolingBackend>,
     322              : ) -> anyhow::Result<Response<Body>> {
     323              :     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE // guard stays alive until this function returns
     324              :         .with_label_values(&[ctx.protocol])
     325              :         .guard();
     326           47 :     info!(
     327           47 :         protocol = ctx.protocol,
     328           47 :         "handling interactive connection from client"
     329           47 :     );
     330              : 
     331              :     //
     332              :     // Determine the destination and connection params
     333              :     //
     334              :     let headers = request.headers();
     335              :     // TLS config should be there.
     336              :     let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
     337           47 :     info!(
     338           47 :         user = conn_info.user_info.user.as_str(),
     339           47 :         project = conn_info.user_info.endpoint.as_str(),
     340           47 :         "credentials"
     341           47 :     );
     342              : 
     343              :     // Determine the output options. Default behaviour is 'false'. Anything that is not
     344              :     // strictly 'true' is assumed to be false.
     345              :     let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
     346              :     let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
     347              : 
     348              :     // Allow connection pooling only if explicitly requested
     349              :     // or if we have decided that http pool is no longer opt-in
     350              :     let allow_pool = !config.http_config.pool_options.opt_in
     351              :         || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
     352              : 
     353              :     // Batch transaction settings: isolation level, read only and deferrable
     354              : 
     355              :     let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned(); // raw value is echoed back on batch responses
     356              :     let txn_isolation_level = match txn_isolation_level_raw {
     357              :         Some(ref x) => Some(match x.as_bytes() {
     358              :             b"Serializable" => IsolationLevel::Serializable,
     359              :             b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
     360              :             b"ReadCommitted" => IsolationLevel::ReadCommitted,
     361              :             b"RepeatableRead" => IsolationLevel::RepeatableRead,
     362              :             _ => bail!("invalid isolation level"),
     363              :         }),
     364              :         None => None,
     365              :     };
     366              : 
     367              :     let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
     368              :     let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
     369              : 
     370              :     let request_content_length = match request.body().size_hint().upper() {
     371              :         Some(v) => v,
     372              :         None => MAX_REQUEST_SIZE + 1, // an unknown size is treated as over the limit
     373              :     };
     374           47 :     info!(request_content_length, "request size in bytes");
     375              :     HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
     376              : 
     377              :     // we don't have streaming request support yet, so this is to prevent OOM
     378              :     // from a malicious user sending an extremely large request body
     379              :     if request_content_length > MAX_REQUEST_SIZE {
     380              :         return Err(anyhow::anyhow!(
     381              :             "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
     382              :         ));
     383              :     }
     384              : 
     385           47 :     let fetch_and_process_request = async {
     386           47 :         let body = hyper::body::to_bytes(request.into_body())
     387            5 :             .await
     388           47 :             .map_err(anyhow::Error::from)?;
     389           47 :         info!(length = body.len(), "request payload read");
     390           47 :         let payload: Payload = serde_json::from_slice(&body)?;
     391           47 :         Ok::<Payload, anyhow::Error>(payload) // names the block's error type so `?` above converts into it
     392           47 :     };
     393              : 
     394           47 :     let authenticate_and_connect = async {
     395          692 :         let keys = backend.authenticate(ctx, &conn_info).await?;
     396           44 :         let client = backend
     397           44 :             .connect_to_compute(ctx, conn_info, keys, !allow_pool)
     398          121 :             .await?;
     399              :         // not strictly necessary to mark success here,
     400              :         // but it's just insurance for if we forget it somewhere else
     401           44 :         ctx.latency_timer.success();
     402           44 :         Ok::<_, HttpConnError>(client)
     403           47 :     };
     404              : 
     405              :     // Run both operations in parallel
     406              :     let (payload_result, auth_and_connect_result) =
     407          860 :         join!(fetch_and_process_request, authenticate_and_connect,);
     408              : 
     409              :     // Handle the results
     410              :     let payload = payload_result?; // the body could not be read or parsed as a `Payload`
     411              :     let mut client = auth_and_connect_result?; // authentication or connecting to compute failed
     412              : 
     413              :     let mut response = Response::builder()
     414              :         .status(StatusCode::OK)
     415              :         .header(header::CONTENT_TYPE, "application/json");
     416              : 
     417              :     //
     418              :     // Now execute the query and return the result
     419              :     //
     420              :     let mut size = 0;
     421              :     let result = match payload {
     422              :         Payload::Single(stmt) => {
     423              :             let (status, results) =
     424              :                 query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
     425              :                     .await
     426            2 :                     .map_err(|e| {
     427            2 :                         client.discard();
     428            2 :                         e
     429            2 :                     })?;
     430              :             client.check_idle(status);
     431              :             results
     432              :         }
     433              :         Payload::Batch(statements) => {
     434            3 :             info!("starting transaction");
     435              :             let (inner, mut discard) = client.inner();
     436              :             let mut builder = inner.build_transaction();
     437              :             if let Some(isolation_level) = txn_isolation_level {
     438              :                 builder = builder.isolation_level(isolation_level);
     439              :             }
     440              :             if txn_read_only {
     441              :                 builder = builder.read_only(true);
     442              :             }
     443              :             if txn_deferrable {
     444              :                 builder = builder.deferrable(true);
     445              :             }
     446              : 
     447            0 :             let transaction = builder.start().await.map_err(|e| {
     448            0 :                 // if we cannot start a transaction, we should return immediately
     449            0 :                 // and not return to the pool. connection is clearly broken
     450            0 :                 discard.discard();
     451            0 :                 e
     452            0 :             })?;
     453              : 
     454              :             let results = match query_batch(
     455              :                 &transaction,
     456              :                 statements,
     457              :                 &mut size,
     458              :                 raw_output,
     459              :                 default_array_mode,
     460              :             )
     461              :             .await
     462              :             {
     463              :                 Ok(results) => {
     464            3 :                     info!("commit");
     465            0 :                     let status = transaction.commit().await.map_err(|e| {
     466            0 :                         // if we cannot commit - for now don't return connection to pool
     467            0 :                         // TODO: get a query status from the error
     468            0 :                         discard.discard();
     469            0 :                         e
     470            0 :                     })?;
     471              :                     discard.check_idle(status);
     472              :                     results
     473              :                 }
     474              :                 Err(err) => {
     475            0 :                     info!("rollback");
     476            0 :                     let status = transaction.rollback().await.map_err(|e| {
     477            0 :                         // if we cannot rollback - for now don't return connection to pool
     478            0 :                         // TODO: get a query status from the error
     479            0 :                         discard.discard();
     480            0 :                         e
     481            0 :                     })?;
     482              :                     discard.check_idle(status);
     483              :                     return Err(err); // report the statement error, not the rollback outcome
     484              :                 }
     485              :             };
     486              : 
     487              :             if txn_read_only {
     488              :                 response = response.header(
     489              :                     TXN_READ_ONLY.clone(),
     490              :                     HeaderValue::try_from(txn_read_only.to_string())?,
     491              :                 );
     492              :             }
     493              :             if txn_deferrable {
     494              :                 response = response.header(
     495              :                     TXN_DEFERRABLE.clone(),
     496              :                     HeaderValue::try_from(txn_deferrable.to_string())?,
     497              :                 );
     498              :             }
     499              :             if let Some(txn_isolation_level) = txn_isolation_level_raw {
     500              :                 response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
     501              :             }
     502              :             json!({ "results": results })
     503              :         }
     504              :     };
     505              : 
     506              :     let metrics = client.metrics();
     507              : 
     508              :     // how could this possibly fail
     509              :     let body = serde_json::to_string(&result).expect("json serialization should not fail");
     510              :     let len = body.len();
     511              :     let response = response
     512              :         .body(Body::from(body))
     513              :         // only fails if invalid status code or invalid header/values are given.
     514              :         // these are not user configurable so it cannot fail dynamically
     515              :         .expect("building response payload should not fail");
     516              : 
     517              :     // count the egress bytes - we miss the TLS and header overhead but oh well...
     518              :     // moving this later in the stack is going to be a lot of effort and ehhhh
     519              :     metrics.record_egress(len as u64);
     520              : 
     521              :     Ok(response)
     522              : }
     523              : 
     /// Runs every statement of a batch, in order, on `transaction` and collects the JSON results.
     ///
     /// One size counter is shared by all statements, so the response-size check inside
     /// `query_to_json` applies to the batch as a whole. Execution stops at the first failing
     /// statement; committing or rolling back is left to the caller.
     ///
     /// `total_size` is increased by the accumulated size only when every statement succeeds.
     ///
     /// # Errors
     /// Returns the error of the first statement that fails.
     524            3 : async fn query_batch(
     525            3 :     transaction: &Transaction<'_>,
     526            3 :     queries: BatchQueryData,
     527            3 :     total_size: &mut usize,
     528            3 :     raw_output: bool,
     529            3 :     array_mode: bool,
     530            3 : ) -> anyhow::Result<Vec<Value>> {
     531            3 :     let mut results = Vec::with_capacity(queries.queries.len());
     532            3 :     let mut current_size = 0;
     533           16 :     for stmt in queries.queries {
     534           13 :         // TODO: maybe we should check that the transaction bit is set here
     535           13 :         let (_, values) = // the per-statement ready status is not used for batches
     536           13 :             query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
     537           13 :         results.push(values);
     538              :     }
     539            3 :     *total_size += current_size;
     540            3 :     Ok(results)
     541            3 : }
     542              : 
     543           54 : async fn query_to_json<T: GenericClient>(
     544           54 :     client: &T,
     545           54 :     data: QueryData,
     546           54 :     current_size: &mut usize,
     547           54 :     raw_output: bool,
     548           54 :     default_array_mode: bool,
     549           54 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
     550           54 :     info!("executing query");
     551           54 :     let query_params = data.params;
     552           54 :     let row_stream = client.query_raw_txt(&data.query, query_params).await?;
     553           53 :     info!("finished executing query");
     554              : 
     555              :     // Manually drain the stream into a vector to leave row_stream hanging
     556              :     // around to get a command tag. Also check that the response is not too
     557              :     // big.
     558           53 :     pin_mut!(row_stream);
     559           53 :     let mut rows: Vec<tokio_postgres::Row> = Vec::new();
     560       506410 :     while let Some(row) = row_stream.next().await {
     561       506358 :         let row = row?;
     562       506358 :         *current_size += row.body_len();
     563       506358 :         rows.push(row);
     564       506358 :         // we don't have a streaming response support yet so this is to prevent OOM
     565       506358 :         // from a malicious query (eg a cross join)
     566       506358 :         if *current_size > MAX_RESPONSE_SIZE {
     567            1 :             return Err(anyhow::anyhow!(
     568            1 :                 "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
     569            1 :             ));
     570       506357 :         }
     571              :     }
     572              : 
     573           52 :     let ready = row_stream.ready_status();
     574           52 : 
     575           52 :     // grab the command tag and number of rows affected
     576           52 :     let command_tag = row_stream.command_tag().unwrap_or_default();
     577           52 :     let mut command_tag_split = command_tag.split(' ');
     578           52 :     let command_tag_name = command_tag_split.next().unwrap_or_default();
     579           52 :     let command_tag_count = if command_tag_name == "INSERT" {
     580              :         // INSERT returns OID first and then number of rows
     581            2 :         command_tag_split.nth(1)
     582              :     } else {
     583              :         // other commands return number of rows (if any)
     584           50 :         command_tag_split.next()
     585              :     }
     586           52 :     .and_then(|s| s.parse::<i64>().ok());
     587              : 
     588           52 :     info!(
     589           52 :         rows = rows.len(),
     590           52 :         ?ready,
     591           52 :         command_tag,
     592           52 :         "finished reading rows"
     593           52 :     );
     594              : 
     595           52 :     let mut fields = vec![];
     596           52 :     let mut columns = vec![];
     597              : 
     598          120 :     for c in row_stream.columns() {
     599          120 :         fields.push(json!({
     600          120 :             "name": Value::String(c.name().to_owned()),
     601          120 :             "dataTypeID": Value::Number(c.type_().oid().into()),
     602          120 :             "tableID": c.table_oid(),
     603          120 :             "columnID": c.column_id(),
     604          120 :             "dataTypeSize": c.type_size(),
     605          120 :             "dataTypeModifier": c.type_modifier(),
     606          120 :             "format": "text",
     607          120 :         }));
     608          120 :         columns.push(client.get_type(c.type_oid()).await?);
     609              :     }
     610              : 
     611           52 :     let array_mode = data.array_mode.unwrap_or(default_array_mode);
     612              : 
     613              :     // convert rows to JSON
     614           52 :     let rows = rows
     615           52 :         .iter()
     616           54 :         .map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
     617           52 :         .collect::<Result<Vec<_>, _>>()?;
     618              : 
     619              :     // resulting JSON format is based on the format of node-postgres result
     620           52 :     Ok((
     621           52 :         ready,
     622           52 :         json!({
     623           52 :             "command": command_tag_name,
     624           52 :             "rowCount": command_tag_count,
     625           52 :             "rows": rows,
     626           52 :             "fields": fields,
     627           52 :             "rowAsArray": array_mode,
     628           52 :         }),
     629           52 :     ))
     630           54 : }
        

Generated by: LCOV version 2.1-beta