TLA Line data Source code
1 : use std::sync::Arc;
2 :
3 : use anyhow::bail;
4 : use futures::pin_mut;
5 : use futures::StreamExt;
6 : use hyper::body::HttpBody;
7 : use hyper::header;
8 : use hyper::http::HeaderName;
9 : use hyper::http::HeaderValue;
10 : use hyper::Response;
11 : use hyper::StatusCode;
12 : use hyper::{Body, HeaderMap, Request};
13 : use serde_json::json;
14 : use serde_json::Map;
15 : use serde_json::Value;
16 : use tokio_postgres::types::Kind;
17 : use tokio_postgres::types::Type;
18 : use tokio_postgres::GenericClient;
19 : use tokio_postgres::IsolationLevel;
20 : use tokio_postgres::ReadyForQueryStatus;
21 : use tokio_postgres::Row;
22 : use tokio_postgres::Transaction;
23 : use tracing::error;
24 : use tracing::instrument;
25 : use url::Url;
26 : use utils::http::error::ApiError;
27 : use utils::http::json::json_response;
28 :
29 : use crate::config::HttpConfig;
30 : use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
31 :
32 : use super::conn_pool::ConnInfo;
33 : use super::conn_pool::GlobalConnPool;
34 :
35 CBC 201 : #[derive(serde::Deserialize)]
36 : struct QueryData {
37 : query: String,
38 : params: Vec<serde_json::Value>,
39 : }
40 :
41 6 : #[derive(serde::Deserialize)]
42 : struct BatchQueryData {
43 : queries: Vec<QueryData>,
44 : }
45 :
46 32 : #[derive(serde::Deserialize)]
47 : #[serde(untagged)]
48 : enum Payload {
49 : Single(QueryData),
50 : Batch(BatchQueryData),
51 : }
52 :
53 : const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
54 : const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
55 :
56 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
57 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
58 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
59 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
60 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
61 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
62 :
63 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
64 :
65 : //
66 : // Convert json non-string types to strings, so that they can be passed to Postgres
67 : // as parameters.
68 : //
69 44 : fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
70 44 : json.iter()
71 44 : .map(|value| {
72 18 : match value {
73 : // special care for nulls
74 1 : Value::Null => None,
75 :
76 : // convert to text with escaping
77 9 : v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
78 :
79 : // avoid escaping here, as we pass this as a parameter
80 1 : Value::String(s) => Some(s.to_string()),
81 :
82 : // special care for arrays
83 7 : Value::Array(_) => json_array_to_pg_array(value),
84 : }
85 44 : })
86 44 : .collect()
87 44 : }
88 :
89 : //
90 : // Serialize a JSON array to a Postgres array. Contrary to the strings in the params
91 : // in the array we need to escape the strings. Postgres is okay with arrays of form
92 : // '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
93 : // it for Postgres to check.
94 : //
95 : // Example of the same escaping in node-postgres: packages/pg/lib/utils.js
96 : //
97 39 : fn json_array_to_pg_array(value: &Value) -> Option<String> {
98 39 : match value {
99 : // special care for nulls
100 2 : Value::Null => None,
101 :
102 : // convert to text with escaping
103 : // here string needs to be escaped, as it is part of the array
104 22 : v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
105 5 : v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
106 :
107 : // recurse into array
108 10 : Value::Array(arr) => {
109 10 : let vals = arr
110 10 : .iter()
111 10 : .map(json_array_to_pg_array)
112 27 : .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
113 10 : .collect::<Vec<_>>()
114 10 : .join(",");
115 10 :
116 10 : Some(format!("{{{}}}", vals))
117 : }
118 : }
119 39 : }
120 :
121 30 : fn get_conn_info(
122 30 : headers: &HeaderMap,
123 30 : sni_hostname: Option<String>,
124 30 : ) -> Result<ConnInfo, anyhow::Error> {
125 30 : let connection_string = headers
126 30 : .get("Neon-Connection-String")
127 30 : .ok_or(anyhow::anyhow!("missing connection string"))?
128 30 : .to_str()?;
129 :
130 30 : let connection_url = Url::parse(connection_string)?;
131 :
132 30 : let protocol = connection_url.scheme();
133 30 : if protocol != "postgres" && protocol != "postgresql" {
134 UBC 0 : return Err(anyhow::anyhow!(
135 0 : "connection string must start with postgres: or postgresql:"
136 0 : ));
137 CBC 30 : }
138 :
139 30 : let mut url_path = connection_url
140 30 : .path_segments()
141 30 : .ok_or(anyhow::anyhow!("missing database name"))?;
142 :
143 30 : let dbname = url_path
144 30 : .next()
145 30 : .ok_or(anyhow::anyhow!("invalid database name"))?;
146 :
147 30 : let username = connection_url.username();
148 30 : if username.is_empty() {
149 UBC 0 : return Err(anyhow::anyhow!("missing username"));
150 CBC 30 : }
151 :
152 30 : let password = connection_url
153 30 : .password()
154 30 : .ok_or(anyhow::anyhow!("no password"))?;
155 :
156 : // TLS certificate selector now based on SNI hostname, so if we are running here
157 : // we are sure that SNI hostname is set to one of the configured domain names.
158 30 : let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
159 :
160 30 : let hostname = connection_url
161 30 : .host_str()
162 30 : .ok_or(anyhow::anyhow!("no host"))?;
163 :
164 30 : let host_header = headers
165 30 : .get("host")
166 30 : .and_then(|h| h.to_str().ok())
167 30 : .and_then(|h| h.split(':').next());
168 30 :
169 30 : if hostname != sni_hostname {
170 UBC 0 : return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
171 CBC 30 : } else if let Some(h) = host_header {
172 30 : if h != hostname {
173 UBC 0 : return Err(anyhow::anyhow!("mismatched host header and hostname"));
174 CBC 30 : }
175 UBC 0 : }
176 :
177 CBC 30 : Ok(ConnInfo {
178 30 : username: username.to_owned(),
179 30 : dbname: dbname.to_owned(),
180 30 : hostname: hostname.to_owned(),
181 30 : password: password.to_owned(),
182 30 : })
183 30 : }
184 :
185 : // TODO: return different http error codes
186 30 : pub async fn handle(
187 30 : request: Request<Body>,
188 30 : sni_hostname: Option<String>,
189 30 : conn_pool: Arc<GlobalConnPool>,
190 30 : session_id: uuid::Uuid,
191 30 : config: &'static HttpConfig,
192 30 : ) -> Result<Response<Body>, ApiError> {
193 30 : let result = tokio::time::timeout(
194 30 : config.sql_over_http_timeout,
195 30 : handle_inner(request, sni_hostname, conn_pool, session_id),
196 30 : )
197 233 : .await;
198 30 : let mut response = match result {
199 30 : Ok(r) => match r {
200 27 : Ok(r) => r,
201 3 : Err(e) => {
202 3 : let message = format!("{:?}", e);
203 3 : let code = e.downcast_ref::<tokio_postgres::Error>().and_then(|e| {
204 3 : e.code()
205 3 : .map(|s| serde_json::to_value(s.code()).unwrap_or_default())
206 3 : });
207 3 : let code = match code {
208 3 : Some(c) => c,
209 UBC 0 : None => Value::Null,
210 : };
211 CBC 3 : error!(
212 3 : ?code,
213 3 : "sql-over-http per-client task finished with an error: {e:#}"
214 3 : );
215 : // TODO: this shouldn't always be bad request.
216 3 : json_response(
217 3 : StatusCode::BAD_REQUEST,
218 3 : json!({ "message": message, "code": code }),
219 3 : )?
220 : }
221 : },
222 : Err(_) => {
223 UBC 0 : let message = format!(
224 0 : "HTTP-Connection timed out, execution time exeeded {} seconds",
225 0 : config.sql_over_http_timeout.as_secs()
226 0 : );
227 0 : error!(message);
228 0 : json_response(
229 0 : StatusCode::GATEWAY_TIMEOUT,
230 0 : json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
231 0 : )?
232 : }
233 : };
234 CBC 30 : response.headers_mut().insert(
235 30 : "Access-Control-Allow-Origin",
236 30 : hyper::http::HeaderValue::from_static("*"),
237 30 : );
238 30 : Ok(response)
239 30 : }
240 :
241 90 : #[instrument(name = "sql-over-http", skip_all)]
242 : async fn handle_inner(
243 : request: Request<Body>,
244 : sni_hostname: Option<String>,
245 : conn_pool: Arc<GlobalConnPool>,
246 : session_id: uuid::Uuid,
247 : ) -> anyhow::Result<Response<Body>> {
248 : NUM_CONNECTIONS_ACCEPTED_COUNTER
249 : .with_label_values(&["http"])
250 : .inc();
251 30 : scopeguard::defer! {
252 30 : NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
253 30 : }
254 :
255 : //
256 : // Determine the destination and connection params
257 : //
258 : let headers = request.headers();
259 : let conn_info = get_conn_info(headers, sni_hostname)?;
260 :
261 : // Determine the output options. Default behaviour is 'false'. Anything that is not
262 : // strictly 'true' assumed to be false.
263 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
264 : let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
265 :
266 : // Allow connection pooling only if explicitly requested
267 : let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
268 :
269 : // isolation level, read only and deferrable
270 :
271 : let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
272 : let txn_isolation_level = match txn_isolation_level_raw {
273 : Some(ref x) => Some(match x.as_bytes() {
274 : b"Serializable" => IsolationLevel::Serializable,
275 : b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
276 : b"ReadCommitted" => IsolationLevel::ReadCommitted,
277 : b"RepeatableRead" => IsolationLevel::RepeatableRead,
278 : _ => bail!("invalid isolation level"),
279 : }),
280 : None => None,
281 : };
282 :
283 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
284 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
285 :
286 : let request_content_length = match request.body().size_hint().upper() {
287 : Some(v) => v,
288 : None => MAX_REQUEST_SIZE + 1,
289 : };
290 :
291 : // we don't have a streaming request support yet so this is to prevent OOM
292 : // from a malicious user sending an extremely large request body
293 : if request_content_length > MAX_REQUEST_SIZE {
294 : return Err(anyhow::anyhow!(
295 : "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
296 : ));
297 : }
298 :
299 : //
300 : // Read the query and query params from the request body
301 : //
302 : let body = hyper::body::to_bytes(request.into_body()).await?;
303 : let payload: Payload = serde_json::from_slice(&body)?;
304 :
305 : let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?;
306 :
307 : let mut response = Response::builder()
308 : .status(StatusCode::OK)
309 : .header(header::CONTENT_TYPE, "application/json");
310 :
311 : //
312 : // Now execute the query and return the result
313 : //
314 : let mut size = 0;
315 : let result =
316 : match payload {
317 : Payload::Single(stmt) => {
318 : let (status, results) =
319 : query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
320 : .await
321 1 : .map_err(|e| {
322 1 : client.discard();
323 1 : e
324 1 : })?;
325 : client.check_idle(status);
326 : results
327 : }
328 : Payload::Batch(statements) => {
329 : let (inner, mut discard) = client.inner();
330 : let mut builder = inner.build_transaction();
331 : if let Some(isolation_level) = txn_isolation_level {
332 : builder = builder.isolation_level(isolation_level);
333 : }
334 : if txn_read_only {
335 : builder = builder.read_only(true);
336 : }
337 : if txn_deferrable {
338 : builder = builder.deferrable(true);
339 : }
340 :
341 UBC 0 : let transaction = builder.start().await.map_err(|e| {
342 0 : // if we cannot start a transaction, we should return immediately
343 0 : // and not return to the pool. connection is clearly broken
344 0 : discard.discard();
345 0 : e
346 0 : })?;
347 :
348 : let results =
349 : match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
350 : .await
351 : {
352 : Ok(results) => {
353 0 : let status = transaction.commit().await.map_err(|e| {
354 0 : // if we cannot commit - for now don't return connection to pool
355 0 : // TODO: get a query status from the error
356 0 : discard.discard();
357 0 : e
358 0 : })?;
359 : discard.check_idle(status);
360 : results
361 : }
362 : Err(err) => {
363 0 : let status = transaction.rollback().await.map_err(|e| {
364 0 : // if we cannot rollback - for now don't return connection to pool
365 0 : // TODO: get a query status from the error
366 0 : discard.discard();
367 0 : e
368 0 : })?;
369 : discard.check_idle(status);
370 : return Err(err);
371 : }
372 : };
373 :
374 : if txn_read_only {
375 : response = response.header(
376 : TXN_READ_ONLY.clone(),
377 : HeaderValue::try_from(txn_read_only.to_string())?,
378 : );
379 : }
380 : if txn_deferrable {
381 : response = response.header(
382 : TXN_DEFERRABLE.clone(),
383 : HeaderValue::try_from(txn_deferrable.to_string())?,
384 : );
385 : }
386 : if let Some(txn_isolation_level) = txn_isolation_level_raw {
387 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
388 : }
389 : json!({ "results": results })
390 : }
391 : };
392 :
393 : let metrics = client.metrics();
394 :
395 : // how could this possibly fail
396 : let body = serde_json::to_string(&result).expect("json serialization should not fail");
397 : let len = body.len();
398 : let response = response
399 : .body(Body::from(body))
400 : // only fails if invalid status code or invalid header/values are given.
401 : // these are not user configurable so it cannot fail dynamically
402 : .expect("building response payload should not fail");
403 :
404 : // count the egress bytes - we miss the TLS and header overhead but oh well...
405 : // moving this later in the stack is going to be a lot of effort and ehhhh
406 : metrics.record_egress(len as u64);
407 :
408 : Ok(response)
409 : }
410 :
411 CBC 2 : async fn query_batch(
412 2 : transaction: &Transaction<'_>,
413 2 : queries: BatchQueryData,
414 2 : total_size: &mut usize,
415 2 : raw_output: bool,
416 2 : array_mode: bool,
417 2 : ) -> anyhow::Result<Vec<Value>> {
418 2 : let mut results = Vec::with_capacity(queries.queries.len());
419 2 : let mut current_size = 0;
420 13 : for stmt in queries.queries {
421 11 : // TODO: maybe we should check that the transaction bit is set here
422 11 : let (_, values) =
423 22 : query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
424 11 : results.push(values);
425 : }
426 2 : *total_size += current_size;
427 2 : Ok(results)
428 2 : }
429 :
430 37 : async fn query_to_json<T: GenericClient>(
431 37 : client: &T,
432 37 : data: QueryData,
433 37 : current_size: &mut usize,
434 37 : raw_output: bool,
435 37 : array_mode: bool,
436 37 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
437 37 : let query_params = json_to_pg_text(data.params);
438 73 : let row_stream = client.query_raw_txt(&data.query, query_params).await?;
439 :
440 : // Manually drain the stream into a vector to leave row_stream hanging
441 : // around to get a command tag. Also check that the response is not too
442 : // big.
443 36 : pin_mut!(row_stream);
444 36 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
445 74 : while let Some(row) = row_stream.next().await {
446 38 : let row = row?;
447 38 : *current_size += row.body_len();
448 38 : rows.push(row);
449 38 : // we don't have a streaming response support yet so this is to prevent OOM
450 38 : // from a malicious query (eg a cross join)
451 38 : if *current_size > MAX_RESPONSE_SIZE {
452 UBC 0 : return Err(anyhow::anyhow!(
453 0 : "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
454 0 : ));
455 CBC 38 : }
456 : }
457 :
458 36 : let ready = row_stream.ready_status();
459 36 :
460 36 : // grab the command tag and number of rows affected
461 36 : let command_tag = row_stream.command_tag().unwrap_or_default();
462 36 : let mut command_tag_split = command_tag.split(' ');
463 36 : let command_tag_name = command_tag_split.next().unwrap_or_default();
464 36 : let command_tag_count = if command_tag_name == "INSERT" {
465 : // INSERT returns OID first and then number of rows
466 2 : command_tag_split.nth(1)
467 : } else {
468 : // other commands return number of rows (if any)
469 34 : command_tag_split.next()
470 : }
471 36 : .and_then(|s| s.parse::<i64>().ok());
472 :
473 36 : let fields = if !rows.is_empty() {
474 30 : rows[0]
475 30 : .columns()
476 30 : .iter()
477 104 : .map(|c| {
478 104 : json!({
479 104 : "name": Value::String(c.name().to_owned()),
480 104 : "dataTypeID": Value::Number(c.type_().oid().into()),
481 104 : "tableID": c.table_oid(),
482 104 : "columnID": c.column_id(),
483 104 : "dataTypeSize": c.type_size(),
484 104 : "dataTypeModifier": c.type_modifier(),
485 104 : "format": "text",
486 104 : })
487 104 : })
488 30 : .collect::<Vec<_>>()
489 : } else {
490 6 : Vec::new()
491 : };
492 :
493 : // convert rows to JSON
494 36 : let rows = rows
495 36 : .iter()
496 38 : .map(|row| pg_text_row_to_json(row, raw_output, array_mode))
497 36 : .collect::<Result<Vec<_>, _>>()?;
498 :
499 : // resulting JSON format is based on the format of node-postgres result
500 36 : Ok((
501 36 : ready,
502 36 : json!({
503 36 : "command": command_tag_name,
504 36 : "rowCount": command_tag_count,
505 36 : "rows": rows,
506 36 : "fields": fields,
507 36 : "rowAsArray": array_mode,
508 36 : }),
509 36 : ))
510 37 : }
511 :
512 : //
513 : // Convert postgres row with text-encoded values to JSON object
514 : //
515 38 : pub fn pg_text_row_to_json(
516 38 : row: &Row,
517 38 : raw_output: bool,
518 38 : array_mode: bool,
519 38 : ) -> Result<Value, anyhow::Error> {
520 116 : let iter = row.columns().iter().enumerate().map(|(i, column)| {
521 116 : let name = column.name();
522 116 : let pg_value = row.as_text(i)?;
523 116 : let json_value = if raw_output {
524 6 : match pg_value {
525 6 : Some(v) => Value::String(v.to_string()),
526 UBC 0 : None => Value::Null,
527 : }
528 : } else {
529 CBC 110 : pg_text_to_json(pg_value, column.type_())?
530 : };
531 116 : Ok((name.to_string(), json_value))
532 116 : });
533 38 :
534 38 : if array_mode {
535 : // drop keys and aggregate into array
536 2 : let arr = iter
537 6 : .map(|r| r.map(|(_key, val)| val))
538 2 : .collect::<Result<Vec<Value>, anyhow::Error>>()?;
539 2 : Ok(Value::Array(arr))
540 : } else {
541 36 : let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
542 36 : Ok(Value::Object(obj))
543 : }
544 38 : }
545 :
546 : //
547 : // Convert postgres text-encoded value to JSON value
548 : //
549 : pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
550 211 : if let Some(val) = pg_value {
551 199 : if let Kind::Array(elem_type) = pg_type.kind() {
552 8 : return pg_array_parse(val, elem_type);
553 191 : }
554 191 :
555 191 : match *pg_type {
556 30 : Type::BOOL => Ok(Value::Bool(val == "t")),
557 : Type::INT2 | Type::INT4 => {
558 66 : let val = val.parse::<i32>()?;
559 66 : Ok(Value::Number(serde_json::Number::from(val)))
560 : }
561 : Type::FLOAT4 | Type::FLOAT8 => {
562 25 : let fval = val.parse::<f64>()?;
563 25 : let num = serde_json::Number::from_f64(fval);
564 25 : if let Some(num) = num {
565 16 : Ok(Value::Number(num))
566 : } else {
567 : // Pass Nan, Inf, -Inf as strings
568 : // JS JSON.stringify() does converts them to null, but we
569 : // want to preserve them, so we pass them as strings
570 9 : Ok(Value::String(val.to_string()))
571 : }
572 : }
573 12 : Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
574 58 : _ => Ok(Value::String(val.to_string())),
575 : }
576 : } else {
577 12 : Ok(Value::Null)
578 : }
579 211 : }
580 :
581 : //
582 : // Parse postgres array into JSON array.
583 : //
584 : // This is a bit involved because we need to handle nested arrays and quoted
585 : // values. Unlike postgres we don't check that all nested arrays have the same
586 : // dimensions, we just return them as is.
587 : //
588 30 : fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
589 30 : _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
590 30 : }
591 :
592 47 : fn _pg_array_parse(
593 47 : pg_array: &str,
594 47 : elem_type: &Type,
595 47 : nested: bool,
596 47 : ) -> Result<(Value, usize), anyhow::Error> {
597 47 : let mut pg_array_chr = pg_array.char_indices();
598 47 : let mut level = 0;
599 47 : let mut quote = false;
600 47 : let mut entries: Vec<Value> = Vec::new();
601 47 : let mut entry = String::new();
602 47 :
603 47 : // skip bounds decoration
604 47 : if let Some('[') = pg_array.chars().next() {
605 18 : for (_, c) in pg_array_chr.by_ref() {
606 18 : if c == '=' {
607 1 : break;
608 17 : }
609 : }
610 46 : }
611 :
612 107 : fn push_checked(
613 107 : entry: &mut String,
614 107 : entries: &mut Vec<Value>,
615 107 : elem_type: &Type,
616 107 : ) -> Result<(), anyhow::Error> {
617 107 : if !entry.is_empty() {
618 : // While in usual postgres response we get nulls as None and everything else
619 : // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
620 : // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
621 : // here while we have quotation info and convert them to None.
622 72 : if entry == "NULL" {
623 7 : entries.push(pg_text_to_json(None, elem_type)?);
624 : } else {
625 65 : entries.push(pg_text_to_json(Some(entry), elem_type)?);
626 : }
627 72 : entry.clear();
628 35 : }
629 :
630 107 : Ok(())
631 107 : }
632 :
633 556 : while let Some((mut i, mut c)) = pg_array_chr.next() {
634 526 : let mut escaped = false;
635 526 :
636 526 : if c == '\\' {
637 19 : escaped = true;
638 19 : (i, c) = pg_array_chr.next().unwrap();
639 507 : }
640 :
641 261 : match c {
642 64 : '{' if !quote => {
643 64 : level += 1;
644 64 : if level > 1 {
645 17 : let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
646 17 : entries.push(res);
647 195 : for _ in 0..off - 1 {
648 195 : pg_array_chr.next();
649 195 : }
650 47 : }
651 : }
652 64 : '}' if !quote => {
653 64 : level -= 1;
654 64 : if level == 0 {
655 47 : push_checked(&mut entry, &mut entries, elem_type)?;
656 47 : if nested {
657 17 : return Ok((Value::Array(entries), i));
658 30 : }
659 17 : }
660 : }
661 36 : '"' if !escaped => {
662 36 : if quote {
663 : // end of quoted string, so push it manually without any checks
664 : // for emptiness or nulls
665 18 : entries.push(pg_text_to_json(Some(&entry), elem_type)?);
666 18 : entry.clear();
667 18 : }
668 36 : quote = !quote;
669 : }
670 60 : ',' if !quote => {
671 60 : push_checked(&mut entry, &mut entries, elem_type)?;
672 : }
673 302 : _ => {
674 302 : entry.push(c);
675 302 : }
676 : }
677 : }
678 :
679 30 : if level != 0 {
680 UBC 0 : return Err(anyhow::anyhow!("unbalanced array"));
681 CBC 30 : }
682 30 :
683 30 : Ok((Value::Array(entries), 0))
684 47 : }
685 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Renders each JSON document as a single query parameter and compares
    /// the result with the expected Postgres text.
    fn assert_param_cases(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            let json: Value = serde_json::from_str(input).unwrap();
            assert_eq!(json_to_pg_text(vec![json]), vec![Some(expected.to_owned())]);
        }
    }

    /// Parses each Postgres array literal with `elem_type` and compares the
    /// result with the expected JSON.
    fn assert_array_cases(elem_type: &Type, cases: &[(&str, Value)]) {
        for (literal, expected) in cases {
            assert_eq!(&pg_array_parse(literal, elem_type).unwrap(), expected);
        }
    }

    #[test]
    fn test_atomic_types_to_pg_params() {
        assert_eq!(
            json_to_pg_text(vec![Value::Bool(true), Value::Bool(false)]),
            vec![Some("true".to_owned()), Some("false".to_owned())]
        );
        assert_eq!(
            json_to_pg_text(vec![Value::Number(serde_json::Number::from(42))]),
            vec![Some("42".to_owned())]
        );
        // top-level strings are passed verbatim
        assert_eq!(
            json_to_pg_text(vec![Value::String("foo\"".to_string())]),
            vec![Some("foo\"".to_owned())]
        );
        assert_eq!(json_to_pg_text(vec![Value::Null]), vec![None]);
    }

    #[test]
    fn test_json_array_to_pg_array() {
        assert_param_cases(&[
            // atoms and escaping
            (
                r#"[true, false, null, "NULL", 42, "foo", "bar\"-\\"]"#,
                r#"{true,false,NULL,"NULL",42,"foo","bar\"-\\"}"#,
            ),
            // nested arrays
            (
                r#"[[true, false], [null, 42], ["foo", "bar\"-\\"]]"#,
                r#"{{true,false},{NULL,42},{"foo","bar\"-\\"}}"#,
            ),
            // array of objects
            (
                r#"[{"foo": 1},{"bar": 2}]"#,
                r#"{"{\"foo\":1}","{\"bar\":2}"}"#,
            ),
        ]);
    }

    #[test]
    fn test_atomic_types_parse() {
        let cases = vec![
            (Some("foo"), Type::TEXT, json!("foo")),
            (None, Type::TEXT, json!(null)),
            (Some("42"), Type::INT4, json!(42)),
            (Some("42"), Type::INT2, json!(42)),
            (Some("42"), Type::INT8, json!("42")),
            (Some("42.42"), Type::FLOAT8, json!(42.42)),
            (Some("42.42"), Type::FLOAT4, json!(42.42)),
            (Some("NaN"), Type::FLOAT4, json!("NaN")),
            (Some("Infinity"), Type::FLOAT4, json!("Infinity")),
            (Some("-Infinity"), Type::FLOAT4, json!("-Infinity")),
        ];
        for (text, pg_type, expected) in cases {
            assert_eq!(pg_text_to_json(text, &pg_type).unwrap(), expected);
        }

        let json: Value =
            serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
                .unwrap();
        assert_eq!(
            pg_text_to_json(
                Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
                &Type::JSONB
            )
            .unwrap(),
            json
        );
    }

    #[test]
    fn test_pg_array_parse_text() {
        assert_array_cases(
            &Type::TEXT,
            &[
                (
                    r#"{"aa\"\\\,a",cha,"bbbb"}"#,
                    json!(["aa\"\\,a", "cha", "bbbb"]),
                ),
                (
                    r#"{{"foo","bar"},{"bee","bop"}}"#,
                    json!([["foo", "bar"], ["bee", "bop"]]),
                ),
                (
                    r#"{{{{"foo",NULL,"bop",bup}}}}"#,
                    json!([[[["foo", null, "bop", "bup"]]]]),
                ),
                (
                    r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#,
                    json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]),
                ),
            ],
        );
    }

    #[test]
    fn test_pg_array_parse_bool() {
        assert_array_cases(
            &Type::BOOL,
            &[
                (r#"{t,f,t}"#, json!([true, false, true])),
                (r#"{{t,f,t}}"#, json!([[true, false, true]])),
                (r#"{{t,f},{f,t}}"#, json!([[true, false], [false, true]])),
                (
                    r#"{{t,NULL},{NULL,f}}"#,
                    json!([[true, null], [null, false]]),
                ),
            ],
        );
    }

    #[test]
    fn test_pg_array_parse_numbers() {
        let cases = vec![
            (r#"{1,2,3}"#, Type::INT4, json!([1, 2, 3])),
            (r#"{1,2,3}"#, Type::INT2, json!([1, 2, 3])),
            (r#"{1,2,3}"#, Type::INT8, json!(["1", "2", "3"])),
            (r#"{1,2,3}"#, Type::FLOAT4, json!([1.0, 2.0, 3.0])),
            (r#"{1,2,3}"#, Type::FLOAT8, json!([1.0, 2.0, 3.0])),
            (r#"{1.1,2.2,3.3}"#, Type::FLOAT4, json!([1.1, 2.2, 3.3])),
            (r#"{1.1,2.2,3.3}"#, Type::FLOAT8, json!([1.1, 2.2, 3.3])),
            (
                r#"{NaN,Infinity,-Infinity}"#,
                Type::FLOAT4,
                json!(["NaN", "Infinity", "-Infinity"]),
            ),
            (
                r#"{NaN,Infinity,-Infinity}"#,
                Type::FLOAT8,
                json!(["NaN", "Infinity", "-Infinity"]),
            ),
        ];
        for (literal, elem_type, expected) in cases {
            assert_eq!(pg_array_parse(literal, &elem_type).unwrap(), expected);
        }
    }

    #[test]
    fn test_pg_array_with_decoration() {
        assert_array_cases(
            &Type::INT2,
            &[(
                r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#,
                json!([[[1, 2, 3], [4, 5, 6]]]),
            )],
        );
    }

    #[test]
    fn test_pg_array_parse_json() {
        assert_array_cases(
            &Type::JSONB,
            &[
                (r#"{"{}"}"#, json!([{}])),
                (
                    r#"{"{\"foo\": 1, \"bar\": 2}"}"#,
                    json!([{"foo": 1, "bar": 2}]),
                ),
                (
                    r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#,
                    json!([{"foo": 1}, {"bar": 2}]),
                ),
                (
                    r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#,
                    json!([[{"foo": 1}, {"bar": 2}]]),
                ),
            ],
        );
    }
}
|