TLA Line data Source code
1 : use std::sync::Arc;
2 :
3 : use anyhow::bail;
4 : use futures::pin_mut;
5 : use futures::StreamExt;
6 : use hyper::body::HttpBody;
7 : use hyper::header;
8 : use hyper::http::HeaderName;
9 : use hyper::http::HeaderValue;
10 : use hyper::Response;
11 : use hyper::StatusCode;
12 : use hyper::{Body, HeaderMap, Request};
13 : use serde_json::json;
14 : use serde_json::Map;
15 : use serde_json::Value;
16 : use smol_str::SmolStr;
17 : use tokio_postgres::error::DbError;
18 : use tokio_postgres::types::Kind;
19 : use tokio_postgres::types::Type;
20 : use tokio_postgres::GenericClient;
21 : use tokio_postgres::IsolationLevel;
22 : use tokio_postgres::ReadyForQueryStatus;
23 : use tokio_postgres::Row;
24 : use tokio_postgres::Transaction;
25 : use tracing::error;
26 : use tracing::instrument;
27 : use url::Url;
28 : use utils::http::error::ApiError;
29 : use utils::http::json::json_response;
30 :
31 : use crate::config::HttpConfig;
32 : use crate::context::RequestMonitoring;
33 : use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
34 :
35 : use super::conn_pool::ConnInfo;
36 : use super::conn_pool::GlobalConnPool;
37 :
38 CBC 276 : #[derive(serde::Deserialize)]
39 : struct QueryData {
40 : query: String,
41 : params: Vec<serde_json::Value>,
42 : }
43 :
44 6 : #[derive(serde::Deserialize)]
45 : struct BatchQueryData {
46 : queries: Vec<QueryData>,
47 : }
48 :
49 47 : #[derive(serde::Deserialize)]
50 : #[serde(untagged)]
51 : enum Payload {
52 : Single(QueryData),
53 : Batch(BatchQueryData),
54 : }
55 :
/// Cap (10 MiB) on the accumulated row data buffered for one response; checked
/// row by row in `query_to_json` because responses are not streamed.
const MAX_RESPONSE_SIZE: usize = 10 << 20;
/// Cap (10 MiB) on the request body size; compared against the body's upper
/// size hint (a `u64`) before the body is read.
const MAX_REQUEST_SIZE: u64 = 10 << 20;
58 :
59 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
60 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
61 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
62 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
63 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
64 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
65 :
66 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
67 :
68 : //
69 : // Convert json non-string types to strings, so that they can be passed to Postgres
70 : // as parameters.
71 : //
72 58 : fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
73 58 : json.iter()
74 58 : .map(|value| {
75 18 : match value {
76 : // special care for nulls
77 1 : Value::Null => None,
78 :
79 : // convert to text with escaping
80 9 : v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
81 :
82 : // avoid escaping here, as we pass this as a parameter
83 1 : Value::String(s) => Some(s.to_string()),
84 :
85 : // special care for arrays
86 7 : Value::Array(_) => json_array_to_pg_array(value),
87 : }
88 58 : })
89 58 : .collect()
90 58 : }
91 :
92 : //
93 : // Serialize a JSON array to a Postgres array. Contrary to the strings in the params
94 : // in the array we need to escape the strings. Postgres is okay with arrays of form
95 : // '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
96 : // it for Postgres to check.
97 : //
98 : // Example of the same escaping in node-postgres: packages/pg/lib/utils.js
99 : //
100 39 : fn json_array_to_pg_array(value: &Value) -> Option<String> {
101 39 : match value {
102 : // special care for nulls
103 2 : Value::Null => None,
104 :
105 : // convert to text with escaping
106 : // here string needs to be escaped, as it is part of the array
107 22 : v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
108 5 : v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
109 :
110 : // recurse into array
111 10 : Value::Array(arr) => {
112 10 : let vals = arr
113 10 : .iter()
114 10 : .map(json_array_to_pg_array)
115 27 : .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
116 10 : .collect::<Vec<_>>()
117 10 : .join(",");
118 10 :
119 10 : Some(format!("{{{}}}", vals))
120 : }
121 : }
122 39 : }
123 :
124 45 : fn get_conn_info(
125 45 : ctx: &mut RequestMonitoring,
126 45 : headers: &HeaderMap,
127 45 : sni_hostname: Option<String>,
128 45 : ) -> Result<ConnInfo, anyhow::Error> {
129 45 : let connection_string = headers
130 45 : .get("Neon-Connection-String")
131 45 : .ok_or(anyhow::anyhow!("missing connection string"))?
132 45 : .to_str()?;
133 :
134 45 : let connection_url = Url::parse(connection_string)?;
135 :
136 45 : let protocol = connection_url.scheme();
137 45 : if protocol != "postgres" && protocol != "postgresql" {
138 UBC 0 : return Err(anyhow::anyhow!(
139 0 : "connection string must start with postgres: or postgresql:"
140 0 : ));
141 CBC 45 : }
142 :
143 45 : let mut url_path = connection_url
144 45 : .path_segments()
145 45 : .ok_or(anyhow::anyhow!("missing database name"))?;
146 :
147 45 : let dbname = url_path
148 45 : .next()
149 45 : .ok_or(anyhow::anyhow!("invalid database name"))?;
150 :
151 45 : let username = SmolStr::from(connection_url.username());
152 45 : if username.is_empty() {
153 UBC 0 : return Err(anyhow::anyhow!("missing username"));
154 CBC 45 : }
155 45 : ctx.set_user(username.clone());
156 :
157 45 : let password = connection_url
158 45 : .password()
159 45 : .ok_or(anyhow::anyhow!("no password"))?;
160 :
161 : // TLS certificate selector now based on SNI hostname, so if we are running here
162 : // we are sure that SNI hostname is set to one of the configured domain names.
163 45 : let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
164 :
165 45 : let hostname = connection_url
166 45 : .host_str()
167 45 : .ok_or(anyhow::anyhow!("no host"))?;
168 :
169 45 : let host_header = headers
170 45 : .get("host")
171 45 : .and_then(|h| h.to_str().ok())
172 45 : .and_then(|h| h.split(':').next());
173 45 :
174 45 : if hostname != sni_hostname {
175 UBC 0 : return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
176 CBC 45 : } else if let Some(h) = host_header {
177 45 : if h != hostname {
178 UBC 0 : return Err(anyhow::anyhow!("mismatched host header and hostname"));
179 CBC 45 : }
180 UBC 0 : }
181 :
182 CBC 45 : let hostname: SmolStr = hostname.into();
183 45 : ctx.set_endpoint_id(Some(hostname.clone()));
184 45 :
185 45 : let pairs = connection_url.query_pairs();
186 45 :
187 45 : let mut options = Option::None;
188 :
189 45 : for (key, value) in pairs {
190 UBC 0 : if key == "options" {
191 0 : options = Some(value.into());
192 0 : break;
193 0 : }
194 : }
195 :
196 CBC 45 : Ok(ConnInfo {
197 45 : username,
198 45 : dbname: dbname.into(),
199 45 : hostname,
200 45 : password: password.into(),
201 45 : options,
202 45 : })
203 45 : }
204 :
205 : // TODO: return different http error codes
206 45 : pub async fn handle(
207 45 : config: &'static HttpConfig,
208 45 : ctx: &mut RequestMonitoring,
209 45 : request: Request<Body>,
210 45 : sni_hostname: Option<String>,
211 45 : conn_pool: Arc<GlobalConnPool>,
212 45 : ) -> Result<Response<Body>, ApiError> {
213 45 : let result = tokio::time::timeout(
214 45 : config.request_timeout,
215 45 : handle_inner(config, ctx, request, sni_hostname, conn_pool),
216 45 : )
217 727 : .await;
218 45 : let mut response = match result {
219 45 : Ok(r) => match r {
220 40 : Ok(r) => r,
221 5 : Err(e) => {
222 5 : let message = format!("{:?}", e);
223 5 : let db_error = e
224 5 : .downcast_ref::<tokio_postgres::Error>()
225 5 : .and_then(|e| e.as_db_error());
226 65 : fn get<'a, T: serde::Serialize>(
227 65 : db: Option<&'a DbError>,
228 65 : x: impl FnOnce(&'a DbError) -> T,
229 65 : ) -> Value {
230 65 : db.map(x)
231 65 : .and_then(|t| serde_json::to_value(t).ok())
232 65 : .unwrap_or_default()
233 65 : }
234 5 :
235 5 : // TODO(conrad): db_error.position()
236 5 : let code = get(db_error, |db| db.code().code());
237 5 : let severity = get(db_error, |db| db.severity());
238 5 : let detail = get(db_error, |db| db.detail());
239 5 : let hint = get(db_error, |db| db.hint());
240 5 : let where_ = get(db_error, |db| db.where_());
241 5 : let table = get(db_error, |db| db.table());
242 5 : let column = get(db_error, |db| db.column());
243 5 : let schema = get(db_error, |db| db.schema());
244 5 : let datatype = get(db_error, |db| db.datatype());
245 5 : let constraint = get(db_error, |db| db.constraint());
246 5 : let file = get(db_error, |db| db.file());
247 5 : let line = get(db_error, |db| db.line());
248 5 : let routine = get(db_error, |db| db.routine());
249 :
250 5 : error!(
251 5 : ?code,
252 5 : "sql-over-http per-client task finished with an error: {e:#}"
253 5 : );
254 : // TODO: this shouldn't always be bad request.
255 5 : json_response(
256 5 : StatusCode::BAD_REQUEST,
257 5 : json!({
258 5 : "message": message,
259 5 : "code": code,
260 5 : "detail": detail,
261 5 : "hint": hint,
262 5 : "severity": severity,
263 5 : "where": where_,
264 5 : "table": table,
265 5 : "column": column,
266 5 : "schema": schema,
267 5 : "datatype": datatype,
268 5 : "constraint": constraint,
269 5 : "file": file,
270 5 : "line": line,
271 5 : "routine": routine,
272 5 : }),
273 5 : )?
274 : }
275 : },
276 : Err(_) => {
277 UBC 0 : let message = format!(
278 0 : "HTTP-Connection timed out, execution time exeeded {} seconds",
279 0 : config.request_timeout.as_secs()
280 0 : );
281 0 : error!(message);
282 0 : json_response(
283 0 : StatusCode::GATEWAY_TIMEOUT,
284 0 : json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
285 0 : )?
286 : }
287 : };
288 CBC 45 : response.headers_mut().insert(
289 45 : "Access-Control-Allow-Origin",
290 45 : hyper::http::HeaderValue::from_static("*"),
291 45 : );
292 45 : Ok(response)
293 45 : }
294 :
295 UBC 0 : #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
296 : async fn handle_inner(
297 : config: &'static HttpConfig,
298 : ctx: &mut RequestMonitoring,
299 : request: Request<Body>,
300 : sni_hostname: Option<String>,
301 : conn_pool: Arc<GlobalConnPool>,
302 : ) -> anyhow::Result<Response<Body>> {
303 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
304 : .with_label_values(&["http"])
305 : .guard();
306 :
307 : //
308 : // Determine the destination and connection params
309 : //
310 : let headers = request.headers();
311 : let conn_info = get_conn_info(ctx, headers, sni_hostname)?;
312 :
313 : // Determine the output options. Default behaviour is 'false'. Anything that is not
314 : // strictly 'true' assumed to be false.
315 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
316 : let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
317 :
318 : // Allow connection pooling only if explicitly requested
319 : // or if we have decided that http pool is no longer opt-in
320 : let allow_pool =
321 : !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
322 :
323 : // isolation level, read only and deferrable
324 :
325 : let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
326 : let txn_isolation_level = match txn_isolation_level_raw {
327 : Some(ref x) => Some(match x.as_bytes() {
328 : b"Serializable" => IsolationLevel::Serializable,
329 : b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
330 : b"ReadCommitted" => IsolationLevel::ReadCommitted,
331 : b"RepeatableRead" => IsolationLevel::RepeatableRead,
332 : _ => bail!("invalid isolation level"),
333 : }),
334 : None => None,
335 : };
336 :
337 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
338 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
339 :
340 : let paused = ctx.latency_timer.pause();
341 : let request_content_length = match request.body().size_hint().upper() {
342 : Some(v) => v,
343 : None => MAX_REQUEST_SIZE + 1,
344 : };
345 : drop(paused);
346 :
347 : // we don't have a streaming request support yet so this is to prevent OOM
348 : // from a malicious user sending an extremely large request body
349 : if request_content_length > MAX_REQUEST_SIZE {
350 : return Err(anyhow::anyhow!(
351 : "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
352 : ));
353 : }
354 :
355 : //
356 : // Read the query and query params from the request body
357 : //
358 : let body = hyper::body::to_bytes(request.into_body()).await?;
359 : let payload: Payload = serde_json::from_slice(&body)?;
360 :
361 : let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
362 :
363 : let mut response = Response::builder()
364 : .status(StatusCode::OK)
365 : .header(header::CONTENT_TYPE, "application/json");
366 :
367 : //
368 : // Now execute the query and return the result
369 : //
370 : let mut size = 0;
371 : let result =
372 : match payload {
373 : Payload::Single(stmt) => {
374 : let (status, results) =
375 : query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
376 : .await
377 CBC 2 : .map_err(|e| {
378 2 : client.discard();
379 2 : e
380 2 : })?;
381 : client.check_idle(status);
382 : results
383 : }
384 : Payload::Batch(statements) => {
385 : let (inner, mut discard) = client.inner();
386 : let mut builder = inner.build_transaction();
387 : if let Some(isolation_level) = txn_isolation_level {
388 : builder = builder.isolation_level(isolation_level);
389 : }
390 : if txn_read_only {
391 : builder = builder.read_only(true);
392 : }
393 : if txn_deferrable {
394 : builder = builder.deferrable(true);
395 : }
396 :
397 UBC 0 : let transaction = builder.start().await.map_err(|e| {
398 0 : // if we cannot start a transaction, we should return immediately
399 0 : // and not return to the pool. connection is clearly broken
400 0 : discard.discard();
401 0 : e
402 0 : })?;
403 :
404 : let results =
405 : match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
406 : .await
407 : {
408 : Ok(results) => {
409 0 : let status = transaction.commit().await.map_err(|e| {
410 0 : // if we cannot commit - for now don't return connection to pool
411 0 : // TODO: get a query status from the error
412 0 : discard.discard();
413 0 : e
414 0 : })?;
415 : discard.check_idle(status);
416 : results
417 : }
418 : Err(err) => {
419 0 : let status = transaction.rollback().await.map_err(|e| {
420 0 : // if we cannot rollback - for now don't return connection to pool
421 0 : // TODO: get a query status from the error
422 0 : discard.discard();
423 0 : e
424 0 : })?;
425 : discard.check_idle(status);
426 : return Err(err);
427 : }
428 : };
429 :
430 : if txn_read_only {
431 : response = response.header(
432 : TXN_READ_ONLY.clone(),
433 : HeaderValue::try_from(txn_read_only.to_string())?,
434 : );
435 : }
436 : if txn_deferrable {
437 : response = response.header(
438 : TXN_DEFERRABLE.clone(),
439 : HeaderValue::try_from(txn_deferrable.to_string())?,
440 : );
441 : }
442 : if let Some(txn_isolation_level) = txn_isolation_level_raw {
443 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
444 : }
445 : json!({ "results": results })
446 : }
447 : };
448 :
449 : ctx.log();
450 : let metrics = client.metrics();
451 :
452 : // how could this possibly fail
453 : let body = serde_json::to_string(&result).expect("json serialization should not fail");
454 : let len = body.len();
455 : let response = response
456 : .body(Body::from(body))
457 : // only fails if invalid status code or invalid header/values are given.
458 : // these are not user configurable so it cannot fail dynamically
459 : .expect("building response payload should not fail");
460 :
461 : // count the egress bytes - we miss the TLS and header overhead but oh well...
462 : // moving this later in the stack is going to be a lot of effort and ehhhh
463 : metrics.record_egress(len as u64);
464 :
465 : Ok(response)
466 : }
467 :
468 CBC 2 : async fn query_batch(
469 2 : transaction: &Transaction<'_>,
470 2 : queries: BatchQueryData,
471 2 : total_size: &mut usize,
472 2 : raw_output: bool,
473 2 : array_mode: bool,
474 2 : ) -> anyhow::Result<Vec<Value>> {
475 2 : let mut results = Vec::with_capacity(queries.queries.len());
476 2 : let mut current_size = 0;
477 13 : for stmt in queries.queries {
478 11 : // TODO: maybe we should check that the transaction bit is set here
479 11 : let (_, values) =
480 11 : query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
481 11 : results.push(values);
482 : }
483 2 : *total_size += current_size;
484 2 : Ok(results)
485 2 : }
486 :
487 51 : async fn query_to_json<T: GenericClient>(
488 51 : client: &T,
489 51 : data: QueryData,
490 51 : current_size: &mut usize,
491 51 : raw_output: bool,
492 51 : array_mode: bool,
493 51 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
494 51 : let query_params = json_to_pg_text(data.params);
495 51 : let row_stream = client.query_raw_txt(&data.query, query_params).await?;
496 :
497 : // Manually drain the stream into a vector to leave row_stream hanging
498 : // around to get a command tag. Also check that the response is not too
499 : // big.
500 50 : pin_mut!(row_stream);
501 50 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
502 506404 : while let Some(row) = row_stream.next().await {
503 506355 : let row = row?;
504 506355 : *current_size += row.body_len();
505 506355 : rows.push(row);
506 506355 : // we don't have a streaming response support yet so this is to prevent OOM
507 506355 : // from a malicious query (eg a cross join)
508 506355 : if *current_size > MAX_RESPONSE_SIZE {
509 1 : return Err(anyhow::anyhow!(
510 1 : "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
511 1 : ));
512 506354 : }
513 : }
514 :
515 49 : let ready = row_stream.ready_status();
516 49 :
517 49 : // grab the command tag and number of rows affected
518 49 : let command_tag = row_stream.command_tag().unwrap_or_default();
519 49 : let mut command_tag_split = command_tag.split(' ');
520 49 : let command_tag_name = command_tag_split.next().unwrap_or_default();
521 49 : let command_tag_count = if command_tag_name == "INSERT" {
522 : // INSERT returns OID first and then number of rows
523 2 : command_tag_split.nth(1)
524 : } else {
525 : // other commands return number of rows (if any)
526 47 : command_tag_split.next()
527 : }
528 49 : .and_then(|s| s.parse::<i64>().ok());
529 49 :
530 49 : let mut fields = vec![];
531 49 : let mut columns = vec![];
532 :
533 117 : for c in row_stream.columns() {
534 117 : fields.push(json!({
535 117 : "name": Value::String(c.name().to_owned()),
536 117 : "dataTypeID": Value::Number(c.type_().oid().into()),
537 117 : "tableID": c.table_oid(),
538 117 : "columnID": c.column_id(),
539 117 : "dataTypeSize": c.type_size(),
540 117 : "dataTypeModifier": c.type_modifier(),
541 117 : "format": "text",
542 117 : }));
543 117 : columns.push(client.get_type(c.type_oid()).await?);
544 : }
545 :
546 : // convert rows to JSON
547 49 : let rows = rows
548 49 : .iter()
549 51 : .map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
550 49 : .collect::<Result<Vec<_>, _>>()?;
551 :
552 : // resulting JSON format is based on the format of node-postgres result
553 49 : Ok((
554 49 : ready,
555 49 : json!({
556 49 : "command": command_tag_name,
557 49 : "rowCount": command_tag_count,
558 49 : "rows": rows,
559 49 : "fields": fields,
560 49 : "rowAsArray": array_mode,
561 49 : }),
562 49 : ))
563 51 : }
564 :
565 : //
566 : // Convert postgres row with text-encoded values to JSON object
567 : //
568 51 : pub fn pg_text_row_to_json(
569 51 : row: &Row,
570 51 : columns: &[Type],
571 51 : raw_output: bool,
572 51 : array_mode: bool,
573 51 : ) -> Result<Value, anyhow::Error> {
574 51 : let iter = row
575 51 : .columns()
576 51 : .iter()
577 51 : .zip(columns)
578 51 : .enumerate()
579 129 : .map(|(i, (column, typ))| {
580 129 : let name = column.name();
581 129 : let pg_value = row.as_text(i)?;
582 129 : let json_value = if raw_output {
583 6 : match pg_value {
584 6 : Some(v) => Value::String(v.to_string()),
585 UBC 0 : None => Value::Null,
586 : }
587 : } else {
588 CBC 123 : pg_text_to_json(pg_value, typ)?
589 : };
590 129 : Ok((name.to_string(), json_value))
591 129 : });
592 51 :
593 51 : if array_mode {
594 : // drop keys and aggregate into array
595 2 : let arr = iter
596 6 : .map(|r| r.map(|(_key, val)| val))
597 2 : .collect::<Result<Vec<Value>, anyhow::Error>>()?;
598 2 : Ok(Value::Array(arr))
599 : } else {
600 49 : let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
601 49 : Ok(Value::Object(obj))
602 : }
603 51 : }
604 :
605 : //
606 : // Convert postgres text-encoded value to JSON value
607 : //
608 225 : pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
609 225 : if let Some(val) = pg_value {
610 211 : if let Kind::Array(elem_type) = pg_type.kind() {
611 7 : return pg_array_parse(val, elem_type);
612 204 : }
613 204 :
614 204 : match *pg_type {
615 30 : Type::BOOL => Ok(Value::Bool(val == "t")),
616 : Type::INT2 | Type::INT4 => {
617 67 : let val = val.parse::<i32>()?;
618 67 : Ok(Value::Number(serde_json::Number::from(val)))
619 : }
620 : Type::FLOAT4 | Type::FLOAT8 => {
621 25 : let fval = val.parse::<f64>()?;
622 25 : let num = serde_json::Number::from_f64(fval);
623 25 : if let Some(num) = num {
624 16 : Ok(Value::Number(num))
625 : } else {
626 : // Pass Nan, Inf, -Inf as strings
627 : // JS JSON.stringify() does converts them to null, but we
628 : // want to preserve them, so we pass them as strings
629 9 : Ok(Value::String(val.to_string()))
630 : }
631 : }
632 12 : Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
633 70 : _ => Ok(Value::String(val.to_string())),
634 : }
635 : } else {
636 14 : Ok(Value::Null)
637 : }
638 225 : }
639 :
640 : //
641 : // Parse postgres array into JSON array.
642 : //
643 : // This is a bit involved because we need to handle nested arrays and quoted
644 : // values. Unlike postgres we don't check that all nested arrays have the same
645 : // dimensions, we just return them as is.
646 : //
647 29 : fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
648 29 : _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
649 29 : }
650 :
651 46 : fn _pg_array_parse(
652 46 : pg_array: &str,
653 46 : elem_type: &Type,
654 46 : nested: bool,
655 46 : ) -> Result<(Value, usize), anyhow::Error> {
656 46 : let mut pg_array_chr = pg_array.char_indices();
657 46 : let mut level = 0;
658 46 : let mut quote = false;
659 46 : let mut entries: Vec<Value> = Vec::new();
660 46 : let mut entry = String::new();
661 46 :
662 46 : // skip bounds decoration
663 46 : if let Some('[') = pg_array.chars().next() {
664 18 : for (_, c) in pg_array_chr.by_ref() {
665 18 : if c == '=' {
666 1 : break;
667 17 : }
668 : }
669 45 : }
670 :
671 108 : fn push_checked(
672 108 : entry: &mut String,
673 108 : entries: &mut Vec<Value>,
674 108 : elem_type: &Type,
675 108 : ) -> Result<(), anyhow::Error> {
676 108 : if !entry.is_empty() {
677 : // While in usual postgres response we get nulls as None and everything else
678 : // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
679 : // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
680 : // here while we have quotation info and convert them to None.
681 73 : if entry == "NULL" {
682 7 : entries.push(pg_text_to_json(None, elem_type)?);
683 : } else {
684 66 : entries.push(pg_text_to_json(Some(entry), elem_type)?);
685 : }
686 73 : entry.clear();
687 35 : }
688 :
689 108 : Ok(())
690 108 : }
691 :
692 518 : while let Some((mut i, mut c)) = pg_array_chr.next() {
693 489 : let mut escaped = false;
694 489 :
695 489 : if c == '\\' {
696 19 : escaped = true;
697 19 : (i, c) = pg_array_chr.next().unwrap();
698 470 : }
699 :
700 261 : match c {
701 63 : '{' if !quote => {
702 63 : level += 1;
703 63 : if level > 1 {
704 17 : let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
705 17 : entries.push(res);
706 195 : for _ in 0..off - 1 {
707 195 : pg_array_chr.next();
708 195 : }
709 46 : }
710 : }
711 63 : '}' if !quote => {
712 63 : level -= 1;
713 63 : if level == 0 {
714 46 : push_checked(&mut entry, &mut entries, elem_type)?;
715 46 : if nested {
716 17 : return Ok((Value::Array(entries), i));
717 29 : }
718 17 : }
719 : }
720 36 : '"' if !escaped => {
721 36 : if quote {
722 : // end of quoted string, so push it manually without any checks
723 : // for emptiness or nulls
724 18 : entries.push(pg_text_to_json(Some(&entry), elem_type)?);
725 18 : entry.clear();
726 18 : }
727 36 : quote = !quote;
728 : }
729 62 : ',' if !quote => {
730 62 : push_checked(&mut entry, &mut entries, elem_type)?;
731 : }
732 265 : _ => {
733 265 : entry.push(c);
734 265 : }
735 : }
736 : }
737 :
738 29 : if level != 0 {
739 UBC 0 : return Err(anyhow::anyhow!("unbalanced array"));
740 CBC 29 : }
741 29 :
742 29 : Ok((Value::Array(entries), 0))
743 46 : }
744 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Assert that each `(input, expected)` Postgres array literal parses to
    /// `expected` when its elements have type `elem_type`.
    fn assert_pg_arrays(elem_type: &Type, cases: Vec<(&str, Value)>) {
        for (input, expected) in cases {
            assert_eq!(
                pg_array_parse(input, elem_type).unwrap(),
                expected,
                "input: {input}"
            );
        }
    }

    #[test]
    fn test_atomic_types_to_pg_params() {
        let cases: Vec<(Vec<Value>, Vec<Option<String>>)> = vec![
            (
                vec![Value::Bool(true), Value::Bool(false)],
                vec![Some("true".to_owned()), Some("false".to_owned())],
            ),
            (
                vec![Value::Number(serde_json::Number::from(42))],
                vec![Some("42".to_owned())],
            ),
            (
                vec![Value::String("foo\"".to_string())],
                vec![Some("foo\"".to_owned())],
            ),
            (vec![Value::Null], vec![None]),
        ];
        for (params, expected) in cases {
            assert_eq!(json_to_pg_text(params), expected);
        }
    }

    #[test]
    fn test_json_array_to_pg_array() {
        let cases = [
            // atoms and escaping
            (
                "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]",
                "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}",
            ),
            // nested arrays
            (
                "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]",
                "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}",
            ),
            // array of objects
            (
                r#"[{"foo": 1},{"bar": 2}]"#,
                r#"{"{\"foo\":1}","{\"bar\":2}"}"#,
            ),
        ];
        for (json_text, expected) in cases {
            let json: Value = serde_json::from_str(json_text).unwrap();
            assert_eq!(json_to_pg_text(vec![json]), vec![Some(expected.to_owned())]);
        }
    }

    #[test]
    fn test_atomic_types_parse() {
        let jsonb_expected: Value =
            serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
                .unwrap();
        let cases: Vec<(Option<&str>, Type, Value)> = vec![
            (Some("foo"), Type::TEXT, json!("foo")),
            (None, Type::TEXT, json!(null)),
            (Some("42"), Type::INT4, json!(42)),
            (Some("42"), Type::INT2, json!(42)),
            (Some("42"), Type::INT8, json!("42")),
            (Some("42.42"), Type::FLOAT8, json!(42.42)),
            (Some("42.42"), Type::FLOAT4, json!(42.42)),
            (Some("NaN"), Type::FLOAT4, json!("NaN")),
            (Some("Infinity"), Type::FLOAT4, json!("Infinity")),
            (Some("-Infinity"), Type::FLOAT4, json!("-Infinity")),
            (
                Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
                Type::JSONB,
                jsonb_expected,
            ),
        ];
        for (text, pg_type, expected) in cases {
            assert_eq!(pg_text_to_json(text, &pg_type).unwrap(), expected);
        }
    }

    #[test]
    fn test_pg_array_parse_text() {
        assert_pg_arrays(
            &Type::TEXT,
            vec![
                (
                    r#"{"aa\"\\\,a",cha,"bbbb"}"#,
                    json!(["aa\"\\,a", "cha", "bbbb"]),
                ),
                (
                    r#"{{"foo","bar"},{"bee","bop"}}"#,
                    json!([["foo", "bar"], ["bee", "bop"]]),
                ),
                (
                    r#"{{{{"foo",NULL,"bop",bup}}}}"#,
                    json!([[[["foo", null, "bop", "bup"]]]]),
                ),
                (
                    r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#,
                    json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]),
                ),
            ],
        );
    }

    #[test]
    fn test_pg_array_parse_bool() {
        assert_pg_arrays(
            &Type::BOOL,
            vec![
                (r#"{t,f,t}"#, json!([true, false, true])),
                (r#"{{t,f,t}}"#, json!([[true, false, true]])),
                (r#"{{t,f},{f,t}}"#, json!([[true, false], [false, true]])),
                (r#"{{t,NULL},{NULL,f}}"#, json!([[true, null], [null, false]])),
            ],
        );
    }

    #[test]
    fn test_pg_array_parse_numbers() {
        let cases: Vec<(&str, Type, Value)> = vec![
            (r#"{1,2,3}"#, Type::INT4, json!([1, 2, 3])),
            (r#"{1,2,3}"#, Type::INT2, json!([1, 2, 3])),
            (r#"{1,2,3}"#, Type::INT8, json!(["1", "2", "3"])),
            (r#"{1,2,3}"#, Type::FLOAT4, json!([1.0, 2.0, 3.0])),
            (r#"{1,2,3}"#, Type::FLOAT8, json!([1.0, 2.0, 3.0])),
            (r#"{1.1,2.2,3.3}"#, Type::FLOAT4, json!([1.1, 2.2, 3.3])),
            (r#"{1.1,2.2,3.3}"#, Type::FLOAT8, json!([1.1, 2.2, 3.3])),
            (
                r#"{NaN,Infinity,-Infinity}"#,
                Type::FLOAT4,
                json!(["NaN", "Infinity", "-Infinity"]),
            ),
            (
                r#"{NaN,Infinity,-Infinity}"#,
                Type::FLOAT8,
                json!(["NaN", "Infinity", "-Infinity"]),
            ),
        ];
        for (input, elem_type, expected) in cases {
            assert_pg_arrays(&elem_type, vec![(input, expected)]);
        }
    }

    #[test]
    fn test_pg_array_with_decoration() {
        assert_pg_arrays(
            &Type::INT2,
            vec![(
                r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#,
                json!([[[1, 2, 3], [4, 5, 6]]]),
            )],
        );
    }

    #[test]
    fn test_pg_array_parse_json() {
        assert_pg_arrays(
            &Type::JSONB,
            vec![
                (r#"{"{}"}"#, json!([{}])),
                (
                    r#"{"{\"foo\": 1, \"bar\": 2}"}"#,
                    json!([{"foo": 1, "bar": 2}]),
                ),
                (
                    r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#,
                    json!([{"foo": 1}, {"bar": 2}]),
                ),
                (
                    r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#,
                    json!([[{"foo": 1}, {"bar": 2}]]),
                ),
            ],
        );
    }
}
|