Line data Source code
1 : use std::sync::Arc;
2 :
3 : use anyhow::bail;
4 : use anyhow::Context;
5 : use futures::pin_mut;
6 : use futures::StreamExt;
7 : use hyper::body::HttpBody;
8 : use hyper::header;
9 : use hyper::http::HeaderName;
10 : use hyper::http::HeaderValue;
11 : use hyper::Response;
12 : use hyper::StatusCode;
13 : use hyper::{Body, HeaderMap, Request};
14 : use serde_json::json;
15 : use serde_json::Value;
16 : use tokio_postgres::error::DbError;
17 : use tokio_postgres::error::ErrorPosition;
18 : use tokio_postgres::GenericClient;
19 : use tokio_postgres::IsolationLevel;
20 : use tokio_postgres::ReadyForQueryStatus;
21 : use tokio_postgres::Transaction;
22 : use tracing::error;
23 : use tracing::instrument;
24 : use url::Url;
25 : use utils::http::error::ApiError;
26 : use utils::http::json::json_response;
27 :
28 : use crate::auth::backend::ComputeUserInfo;
29 : use crate::auth::endpoint_sni;
30 : use crate::config::HttpConfig;
31 : use crate::config::TlsConfig;
32 : use crate::context::RequestMonitoring;
33 : use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
34 : use crate::proxy::NeonOptions;
35 : use crate::RoleName;
36 :
37 : use super::conn_pool::ConnInfo;
38 : use super::conn_pool::GlobalConnPool;
39 : use super::json::{json_to_pg_text, pg_text_row_to_json};
40 : use super::SERVERLESS_DRIVER_SNI;
41 :
42 281 : #[derive(serde::Deserialize)]
43 : struct QueryData {
44 : query: String,
45 : params: Vec<serde_json::Value>,
46 : }
47 :
48 6 : #[derive(serde::Deserialize)]
49 : struct BatchQueryData {
50 : queries: Vec<QueryData>,
51 : }
52 :
53 48 : #[derive(serde::Deserialize)]
54 : #[serde(untagged)]
55 : enum Payload {
56 : Single(QueryData),
57 : Batch(BatchQueryData),
58 : }
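// Example request bodies (illustrative only; `Payload` is untagged, so either
// shape deserializes):
//   {"query": "SELECT $1::int", "params": ["1"]}          -> Payload::Single
//   {"queries": [{"query": "SELECT 1", "params": []}]}    -> Payload::Batch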
59 :
60 : const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
61 : const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
62 :
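// Custom request headers understood by this endpoint: output encoding
// (raw text / array mode), connection-pool opt-in, and per-batch
// transaction options (isolation level, read-only, deferrable).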
63 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
64 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
65 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
66 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
67 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
68 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
69 :
70 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
71 :
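/// Extract the connection parameters (role, password, database, endpoint and
/// options) from the `Neon-Connection-String` header, validating the URL host
/// against the TLS SNI hostname and the HTTP `Host` header.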
72 46 : fn get_conn_info(
73 46 : ctx: &mut RequestMonitoring,
74 46 : headers: &HeaderMap,
75 46 : sni_hostname: Option<String>,
76 46 : tls: &TlsConfig,
77 46 : ) -> Result<ConnInfo, anyhow::Error> {
78 46 : let connection_string = headers
79 46 : .get("Neon-Connection-String")
80 46 : .ok_or(anyhow::anyhow!("missing connection string"))?
81 46 : .to_str()?;
82 :
83 46 : let connection_url = Url::parse(connection_string)?;
84 :
85 46 : let protocol = connection_url.scheme();
86 46 : if protocol != "postgres" && protocol != "postgresql" {
87 0 : return Err(anyhow::anyhow!(
88 0 : "connection string must start with postgres: or postgresql:"
89 0 : ));
90 46 : }
91 :
92 46 : let mut url_path = connection_url
93 46 : .path_segments()
94 46 : .ok_or(anyhow::anyhow!("missing database name"))?;
95 :
96 46 : let dbname = url_path
97 46 : .next()
98 46 : .ok_or(anyhow::anyhow!("invalid database name"))?;
99 :
100 46 : let username = RoleName::from(connection_url.username());
101 46 : if username.is_empty() {
102 0 : return Err(anyhow::anyhow!("missing username"));
103 46 : }
104 46 : ctx.set_user(username.clone());
105 :
106 46 : let password = connection_url
107 46 : .password()
108 46 : .ok_or(anyhow::anyhow!("no password"))?;
109 :
110 : // TLS certificate selection is based on the SNI hostname, so if we got here
111 : // we know that the SNI hostname is set to one of the configured domain names.
112 46 : let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
113 :
114 46 : let hostname = connection_url
115 46 : .host_str()
116 46 : .ok_or(anyhow::anyhow!("no host"))?;
117 :
118 46 : let host_header = headers
119 46 : .get("host")
120 46 : .and_then(|h| h.to_str().ok())
121 46 : .and_then(|h| h.split(':').next());
122 46 :
123 46 : // sni_hostname has to be either the same as hostname or the one used by the serverless driver.
124 46 : if !check_matches(&sni_hostname, hostname)? {
125 0 : return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
126 46 : } else if let Some(h) = host_header {
127 46 : if h != sni_hostname {
128 0 : return Err(anyhow::anyhow!("mismatched host header and hostname"));
129 46 : }
130 0 : }
131 :
132 46 : let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
133 46 : ctx.set_endpoint_id(endpoint.clone());
134 46 :
135 46 : let pairs = connection_url.query_pairs();
136 46 :
137 46 : let mut options = Option::None;
138 :
139 46 : for (key, value) in pairs {
140 0 : if key == "options" {
141 0 : options = Some(NeonOptions::parse_options_raw(&value));
142 0 : break;
143 0 : }
144 : }
145 :
146 46 : let user_info = ComputeUserInfo {
147 46 : endpoint,
148 46 : user: username,
149 46 : options: options.unwrap_or_default(),
150 46 : };
151 46 :
152 46 : Ok(ConnInfo {
153 46 : user_info,
154 46 : dbname: dbname.into(),
155 46 : password: password.into(),
156 46 : })
157 46 : }
158 :
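/// The SNI hostname matches either when it equals the connection-string
/// hostname, or when the two differ only in the first label and the SNI
/// hostname's first label is `SERVERLESS_DRIVER_SNI`.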
159 46 : fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
160 46 : if sni_hostname == hostname {
161 45 : return Ok(true);
162 1 : }
163 1 : let (sni_hostname_first, sni_hostname_rest) = sni_hostname
164 1 : .split_once('.')
165 1 : .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
166 1 : let (_, hostname_rest) = hostname
167 1 : .split_once('.')
168 1 : .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
169 1 : Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
170 46 : }
171 :
172 : // TODO: return different HTTP error codes
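/// Entry point for a single SQL-over-HTTP request: runs `handle_inner` under
/// `config.request_timeout`, converts errors (including Postgres `DbError`
/// details) into a JSON error response, and attaches the CORS header.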
173 46 : pub async fn handle(
174 46 : tls: &'static TlsConfig,
175 46 : config: &'static HttpConfig,
176 46 : ctx: &mut RequestMonitoring,
177 46 : request: Request<Body>,
178 46 : sni_hostname: Option<String>,
179 46 : conn_pool: Arc<GlobalConnPool>,
180 46 : ) -> Result<Response<Body>, ApiError> {
181 46 : let result = tokio::time::timeout(
182 46 : config.request_timeout,
183 46 : handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
184 46 : )
185 832 : .await;
186 46 : let mut response = match result {
187 46 : Ok(r) => match r {
188 41 : Ok(r) => r,
189 5 : Err(e) => {
190 5 : let mut message = format!("{:?}", e);
191 5 : let db_error = e
192 5 : .downcast_ref::<tokio_postgres::Error>()
193 5 : .and_then(|e| e.as_db_error());
194 65 : fn get<'a, T: serde::Serialize>(
195 65 : db: Option<&'a DbError>,
196 65 : x: impl FnOnce(&'a DbError) -> T,
197 65 : ) -> Value {
198 65 : db.map(x)
199 65 : .and_then(|t| serde_json::to_value(t).ok())
200 65 : .unwrap_or_default()
201 65 : }
202 :
203 5 : if let Some(db_error) = db_error {
204 3 : db_error.message().clone_into(&mut message);
205 3 : }
206 :
207 5 : let position = db_error.and_then(|db| db.position());
208 5 : let (position, internal_position, internal_query) = match position {
209 1 : Some(ErrorPosition::Original(position)) => (
210 1 : Value::String(position.to_string()),
211 1 : Value::Null,
212 1 : Value::Null,
213 1 : ),
214 0 : Some(ErrorPosition::Internal { position, query }) => (
215 0 : Value::Null,
216 0 : Value::String(position.to_string()),
217 0 : Value::String(query.clone()),
218 0 : ),
219 4 : None => (Value::Null, Value::Null, Value::Null),
220 : };
221 :
222 5 : let code = get(db_error, |db| db.code().code());
223 5 : let severity = get(db_error, |db| db.severity());
224 5 : let detail = get(db_error, |db| db.detail());
225 5 : let hint = get(db_error, |db| db.hint());
226 5 : let where_ = get(db_error, |db| db.where_());
227 5 : let table = get(db_error, |db| db.table());
228 5 : let column = get(db_error, |db| db.column());
229 5 : let schema = get(db_error, |db| db.schema());
230 5 : let datatype = get(db_error, |db| db.datatype());
231 5 : let constraint = get(db_error, |db| db.constraint());
232 5 : let file = get(db_error, |db| db.file());
233 5 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
234 5 : let routine = get(db_error, |db| db.routine());
235 :
236 5 : error!(
237 5 : ?code,
238 5 : "sql-over-http per-client task finished with an error: {e:#}"
239 5 : );
240 : // TODO: this shouldn't always be a bad request.
241 5 : json_response(
242 5 : StatusCode::BAD_REQUEST,
243 5 : json!({
244 5 : "message": message,
245 5 : "code": code,
246 5 : "detail": detail,
247 5 : "hint": hint,
248 5 : "position": position,
249 5 : "internalPosition": internal_position,
250 5 : "internalQuery": internal_query,
251 5 : "severity": severity,
252 5 : "where": where_,
253 5 : "table": table,
254 5 : "column": column,
255 5 : "schema": schema,
256 5 : "dataType": datatype,
257 5 : "constraint": constraint,
258 5 : "file": file,
259 5 : "line": line,
260 5 : "routine": routine,
261 5 : }),
262 5 : )?
263 : }
264 : },
265 : Err(_) => {
266 0 : let message = format!(
267 0 : "HTTP connection timed out, execution time exceeded {} seconds",
268 0 : config.request_timeout.as_secs()
269 0 : );
270 0 : error!(message);
271 0 : json_response(
272 0 : StatusCode::GATEWAY_TIMEOUT,
273 0 : json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
274 0 : )?
275 : }
276 : };
277 46 : response.headers_mut().insert(
278 46 : "Access-Control-Allow-Origin",
279 46 : hyper::http::HeaderValue::from_static("*"),
280 46 : );
281 46 : Ok(response)
282 46 : }
283 :
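/// Resolve the connection info, read the (size-limited) JSON payload, fetch a
/// connection from the pool and execute either a single query or a batch
/// wrapped in a transaction, returning the node-postgres style JSON result.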
284 0 : #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
285 : async fn handle_inner(
286 : tls: &'static TlsConfig,
287 : config: &'static HttpConfig,
288 : ctx: &mut RequestMonitoring,
289 : request: Request<Body>,
290 : sni_hostname: Option<String>,
291 : conn_pool: Arc<GlobalConnPool>,
292 : ) -> anyhow::Result<Response<Body>> {
293 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
294 : .with_label_values(&["http"])
295 : .guard();
296 :
297 : //
298 : // Determine the destination and connection params
299 : //
300 : let headers = request.headers();
301 : let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;
302 :
303 : // Determine the output options. Default behaviour is 'false'. Anything that is not
304 : // strictly 'true' is assumed to be false.
305 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
306 : let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
307 :
308 : // Allow connection pooling only if explicitly requested
309 : // or if we have decided that the http pool is no longer opt-in
310 : let allow_pool =
311 : !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
312 :
313 : // isolation level, read only and deferrable
314 :
315 : let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
316 : let txn_isolation_level = match txn_isolation_level_raw {
317 : Some(ref x) => Some(match x.as_bytes() {
318 : b"Serializable" => IsolationLevel::Serializable,
319 : b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
320 : b"ReadCommitted" => IsolationLevel::ReadCommitted,
321 : b"RepeatableRead" => IsolationLevel::RepeatableRead,
322 : _ => bail!("invalid isolation level"),
323 : }),
324 : None => None,
325 : };
326 :
327 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
328 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
329 :
330 : let paused = ctx.latency_timer.pause();
331 : let request_content_length = match request.body().size_hint().upper() {
332 : Some(v) => v,
333 : None => MAX_REQUEST_SIZE + 1,
334 : };
335 : drop(paused);
336 :
337 : // we don't have streaming request support yet, so this is to prevent OOM
338 : // from a malicious user sending an extremely large request body
339 : if request_content_length > MAX_REQUEST_SIZE {
340 : return Err(anyhow::anyhow!(
341 : "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
342 : ));
343 : }
344 :
345 : //
346 : // Read the query and query params from the request body
347 : //
348 : let body = hyper::body::to_bytes(request.into_body()).await?;
349 : let payload: Payload = serde_json::from_slice(&body)?;
350 :
351 : let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
352 :
353 : let mut response = Response::builder()
354 : .status(StatusCode::OK)
355 : .header(header::CONTENT_TYPE, "application/json");
356 :
357 : //
358 : // Now execute the query and return the result
359 : //
360 : let mut size = 0;
361 : let result =
362 : match payload {
363 : Payload::Single(stmt) => {
364 : let (status, results) =
365 : query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
366 : .await
367 2 : .map_err(|e| {
368 2 : client.discard();
369 2 : e
370 2 : })?;
371 : client.check_idle(status);
372 : results
373 : }
374 : Payload::Batch(statements) => {
375 : let (inner, mut discard) = client.inner();
376 : let mut builder = inner.build_transaction();
377 : if let Some(isolation_level) = txn_isolation_level {
378 : builder = builder.isolation_level(isolation_level);
379 : }
380 : if txn_read_only {
381 : builder = builder.read_only(true);
382 : }
383 : if txn_deferrable {
384 : builder = builder.deferrable(true);
385 : }
386 :
387 0 : let transaction = builder.start().await.map_err(|e| {
388 0 : // if we cannot start a transaction, we should return immediately
389 0 : // and not return it to the pool; the connection is clearly broken
390 0 : discard.discard();
391 0 : e
392 0 : })?;
393 :
394 : let results =
395 : match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
396 : .await
397 : {
398 : Ok(results) => {
399 0 : let status = transaction.commit().await.map_err(|e| {
400 0 : // if we cannot commit - for now don't return connection to pool
401 0 : // TODO: get a query status from the error
402 0 : discard.discard();
403 0 : e
404 0 : })?;
405 : discard.check_idle(status);
406 : results
407 : }
408 : Err(err) => {
409 0 : let status = transaction.rollback().await.map_err(|e| {
410 0 : // if we cannot rollback - for now don't return connection to pool
411 0 : // TODO: get a query status from the error
412 0 : discard.discard();
413 0 : e
414 0 : })?;
415 : discard.check_idle(status);
416 : return Err(err);
417 : }
418 : };
419 :
420 : if txn_read_only {
421 : response = response.header(
422 : TXN_READ_ONLY.clone(),
423 : HeaderValue::try_from(txn_read_only.to_string())?,
424 : );
425 : }
426 : if txn_deferrable {
427 : response = response.header(
428 : TXN_DEFERRABLE.clone(),
429 : HeaderValue::try_from(txn_deferrable.to_string())?,
430 : );
431 : }
432 : if let Some(txn_isolation_level) = txn_isolation_level_raw {
433 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
434 : }
435 : json!({ "results": results })
436 : }
437 : };
438 :
439 : ctx.set_success();
440 : ctx.log();
441 : let metrics = client.metrics();
442 :
443 : // how could this possibly fail
444 : let body = serde_json::to_string(&result).expect("json serialization should not fail");
445 : let len = body.len();
446 : let response = response
447 : .body(Body::from(body))
448 : // only fails if invalid status code or invalid header/values are given.
449 : // these are not user configurable so it cannot fail dynamically
450 : .expect("building response payload should not fail");
451 :
452 : // count the egress bytes - we miss the TLS and header overhead but oh well...
453 : // moving this later in the stack is going to be a lot of effort and ehhhh
454 : metrics.record_egress(len as u64);
455 :
456 : Ok(response)
457 : }
458 :
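/// Run every query of a batch inside the given transaction, accumulating the
/// total response size across statements.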
459 2 : async fn query_batch(
460 2 : transaction: &Transaction<'_>,
461 2 : queries: BatchQueryData,
462 2 : total_size: &mut usize,
463 2 : raw_output: bool,
464 2 : array_mode: bool,
465 2 : ) -> anyhow::Result<Vec<Value>> {
466 2 : let mut results = Vec::with_capacity(queries.queries.len());
467 2 : let mut current_size = 0;
468 13 : for stmt in queries.queries {
469 11 : // TODO: maybe we should check that the transaction bit is set here
470 11 : let (_, values) =
471 11 : query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
472 11 : results.push(values);
473 : }
474 2 : *total_size += current_size;
475 2 : Ok(results)
476 2 : }
477 :
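/// Execute a single statement and convert its rows to JSON in the
/// node-postgres result format, enforcing `MAX_RESPONSE_SIZE` while the row
/// stream is drained.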
478 52 : async fn query_to_json<T: GenericClient>(
479 52 : client: &T,
480 52 : data: QueryData,
481 52 : current_size: &mut usize,
482 52 : raw_output: bool,
483 52 : array_mode: bool,
484 52 : ) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
485 52 : let query_params = json_to_pg_text(data.params);
486 52 : let row_stream = client.query_raw_txt(&data.query, query_params).await?;
487 :
488 : // Manually drain the stream into a vector to leave row_stream hanging
489 : // around to get a command tag. Also check that the response is not too
490 : // big.
491 51 : pin_mut!(row_stream);
492 51 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
493 506406 : while let Some(row) = row_stream.next().await {
494 506356 : let row = row?;
495 506356 : *current_size += row.body_len();
496 506356 : rows.push(row);
497 506356 : // we don't have streaming response support yet, so this is to prevent OOM
498 506356 : // from a malicious query (e.g. a cross join)
499 506356 : if *current_size > MAX_RESPONSE_SIZE {
500 1 : return Err(anyhow::anyhow!(
501 1 : "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
502 1 : ));
503 506355 : }
504 : }
505 :
506 50 : let ready = row_stream.ready_status();
507 50 :
508 50 : // grab the command tag and number of rows affected
509 50 : let command_tag = row_stream.command_tag().unwrap_or_default();
510 50 : let mut command_tag_split = command_tag.split(' ');
511 50 : let command_tag_name = command_tag_split.next().unwrap_or_default();
512 50 : let command_tag_count = if command_tag_name == "INSERT" {
513 : // INSERT returns OID first and then number of rows
514 2 : command_tag_split.nth(1)
515 : } else {
516 : // other commands return number of rows (if any)
517 48 : command_tag_split.next()
518 : }
519 50 : .and_then(|s| s.parse::<i64>().ok());
520 50 :
521 50 : let mut fields = vec![];
522 50 : let mut columns = vec![];
523 :
524 118 : for c in row_stream.columns() {
525 118 : fields.push(json!({
526 118 : "name": Value::String(c.name().to_owned()),
527 118 : "dataTypeID": Value::Number(c.type_().oid().into()),
528 118 : "tableID": c.table_oid(),
529 118 : "columnID": c.column_id(),
530 118 : "dataTypeSize": c.type_size(),
531 118 : "dataTypeModifier": c.type_modifier(),
532 118 : "format": "text",
533 118 : }));
534 118 : columns.push(client.get_type(c.type_oid()).await?);
535 : }
536 :
537 : // convert rows to JSON
538 50 : let rows = rows
539 50 : .iter()
540 52 : .map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
541 50 : .collect::<Result<Vec<_>, _>>()?;
542 :
543 : // resulting JSON format is based on the format of node-postgres result
544 50 : Ok((
545 50 : ready,
546 50 : json!({
547 50 : "command": command_tag_name,
548 50 : "rowCount": command_tag_count,
549 50 : "rows": rows,
550 50 : "fields": fields,
551 50 : "rowAsArray": array_mode,
552 50 : }),
553 50 : ))
554 52 : }