Line data Source code
1 : use std::pin::pin;
2 : use std::sync::Arc;
3 :
4 : use bytes::Bytes;
5 : use futures::future::{Either, select, try_join};
6 : use futures::{StreamExt, TryFutureExt};
7 : use http::Method;
8 : use http::header::AUTHORIZATION;
9 : use http_body_util::combinators::BoxBody;
10 : use http_body_util::{BodyExt, Full};
11 : use http_utils::error::ApiError;
12 : use hyper::body::Incoming;
13 : use hyper::http::{HeaderName, HeaderValue};
14 : use hyper::{Request, Response, StatusCode, header};
15 : use indexmap::IndexMap;
16 : use postgres_client::error::{DbError, ErrorPosition, SqlState};
17 : use postgres_client::{
18 : GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
19 : };
20 : use serde::Serialize;
21 : use serde_json::Value;
22 : use serde_json::value::RawValue;
23 : use tokio::time::{self, Instant};
24 : use tokio_util::sync::CancellationToken;
25 : use tracing::{Level, debug, error, info};
26 : use typed_json::json;
27 :
28 : use super::backend::{LocalProxyConnError, PoolingBackend};
29 : use super::conn_pool::AuthData;
30 : use super::conn_pool_lib::{self, ConnInfo};
31 : use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
32 : use super::http_util::{
33 : ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
34 : TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
35 : };
36 : use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
37 : use crate::auth::backend::ComputeCredentialKeys;
38 : use crate::config::{HttpConfig, ProxyConfig};
39 : use crate::context::RequestContext;
40 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
41 : use crate::http::read_body_with_limit;
42 : use crate::metrics::{HttpDirection, Metrics};
43 : use crate::serverless::backend::HttpConnError;
44 : use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
45 : use crate::util::run_until_cancelled;
46 :
47 : #[derive(serde::Deserialize)]
48 : #[serde(rename_all = "camelCase")]
49 : struct QueryData {
50 : query: String,
51 : #[serde(deserialize_with = "bytes_to_pg_text")]
52 : #[serde(default)]
53 : params: Vec<Option<String>>,
54 : #[serde(default)]
55 : array_mode: Option<bool>,
56 : }
57 :
58 0 : #[derive(serde::Deserialize)]
59 : struct BatchQueryData {
60 : queries: Vec<QueryData>,
61 : }
62 :
63 : #[derive(serde::Deserialize)]
64 : #[serde(untagged)]
65 : enum Payload {
66 : Single(QueryData),
67 : Batch(BatchQueryData),
68 : }
69 :
70 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
71 :
72 3 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
73 3 : where
74 3 : D: serde::de::Deserializer<'de>,
75 : {
76 : // TODO: consider avoiding the allocation here.
77 3 : let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
78 3 : Ok(json_to_pg_text(json))
79 3 : }
80 :
81 0 : pub(crate) async fn handle(
82 0 : config: &'static ProxyConfig,
83 0 : ctx: RequestContext,
84 0 : request: Request<Incoming>,
85 0 : backend: Arc<PoolingBackend>,
86 0 : cancel: CancellationToken,
87 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
88 0 : let result = handle_inner(cancel, config, &ctx, request, backend).await;
89 :
90 0 : let mut response = match result {
91 0 : Ok(r) => {
92 0 : ctx.set_success();
93 :
94 : // Handling the error response from local proxy here
95 0 : if config.authentication_config.is_auth_broker && r.status().is_server_error() {
96 0 : let status = r.status();
97 :
98 0 : let body_bytes = r
99 0 : .collect()
100 0 : .await
101 0 : .map_err(|e| {
102 0 : ApiError::InternalServerError(anyhow::Error::msg(format!(
103 0 : "could not collect http body: {e}"
104 0 : )))
105 0 : })?
106 0 : .to_bytes();
107 :
108 0 : if let Ok(mut json_map) =
109 0 : serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
110 : {
111 0 : let message = json_map.get("message");
112 0 : if let Some(message) = message {
113 0 : let msg: String = match serde_json::from_str(message.get()) {
114 0 : Ok(msg) => msg,
115 : Err(_) => {
116 0 : "Unable to parse the response message from server".to_string()
117 : }
118 : };
119 :
120 0 : error!("Error response from local_proxy: {status} {msg}");
121 :
122 0 : json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
123 :
124 0 : let resp_json = serde_json::to_string(&json_map)
125 0 : .unwrap_or("failed to serialize the response message".to_string());
126 :
127 0 : return json_response(status, resp_json);
128 0 : }
129 0 : }
130 :
131 0 : error!("Unable to parse the response message from local_proxy");
132 0 : return json_response(
133 0 : status,
134 0 : json!({ "message": "Unable to parse the response message from server".to_string() }),
135 : );
136 0 : }
137 0 : r
138 : }
139 0 : Err(e @ SqlOverHttpError::Cancelled(_)) => {
140 0 : let error_kind = e.get_error_kind();
141 0 : ctx.set_error_kind(error_kind);
142 :
143 0 : let message = "Query cancelled, connection was terminated";
144 :
145 0 : tracing::info!(
146 0 : kind=error_kind.to_metric_label(),
147 : error=%e,
148 : msg=message,
149 0 : "forwarding error to user"
150 : );
151 :
152 0 : json_response(
153 : StatusCode::BAD_REQUEST,
154 0 : json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
155 0 : )?
156 : }
157 0 : Err(e) => {
158 0 : let error_kind = e.get_error_kind();
159 0 : ctx.set_error_kind(error_kind);
160 :
161 0 : let mut message = e.to_string_client();
162 0 : let db_error = match &e {
163 0 : SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
164 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
165 0 : _ => None,
166 : };
167 0 : fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
168 0 : db.map(x).unwrap_or_default()
169 0 : }
170 :
171 0 : if let Some(db_error) = db_error {
172 0 : db_error.message().clone_into(&mut message);
173 0 : }
174 :
175 0 : let position = db_error.and_then(|db| db.position());
176 0 : let (position, internal_position, internal_query) = match position {
177 0 : Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
178 0 : Some(ErrorPosition::Internal { position, query }) => {
179 0 : (None, Some(position.to_string()), Some(query.clone()))
180 : }
181 0 : None => (None, None, None),
182 : };
183 :
184 0 : let code = get(db_error, |db| db.code().code());
185 0 : let severity = get(db_error, |db| db.severity());
186 0 : let detail = get(db_error, |db| db.detail());
187 0 : let hint = get(db_error, |db| db.hint());
188 0 : let where_ = get(db_error, |db| db.where_());
189 0 : let table = get(db_error, |db| db.table());
190 0 : let column = get(db_error, |db| db.column());
191 0 : let schema = get(db_error, |db| db.schema());
192 0 : let datatype = get(db_error, |db| db.datatype());
193 0 : let constraint = get(db_error, |db| db.constraint());
194 0 : let file = get(db_error, |db| db.file());
195 0 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
196 0 : let routine = get(db_error, |db| db.routine());
197 :
198 0 : match &e {
199 0 : SqlOverHttpError::Postgres(e)
200 0 : if e.as_db_error().is_some() && error_kind == ErrorKind::User =>
201 : {
202 : // this error contains too much info, and it's not an error we care about.
203 0 : if tracing::enabled!(Level::DEBUG) {
204 0 : tracing::debug!(
205 0 : kind=error_kind.to_metric_label(),
206 : error=%e,
207 : msg=message,
208 0 : "forwarding error to user"
209 : );
210 : } else {
211 0 : tracing::info!(
212 0 : kind = error_kind.to_metric_label(),
213 : error = "bad query",
214 0 : "forwarding error to user"
215 : );
216 : }
217 : }
218 : _ => {
219 0 : tracing::info!(
220 0 : kind=error_kind.to_metric_label(),
221 : error=%e,
222 : msg=message,
223 0 : "forwarding error to user"
224 : );
225 : }
226 : }
227 :
228 0 : json_response(
229 0 : e.get_http_status_code(),
230 0 : json!({
231 0 : "message": message,
232 0 : "code": code,
233 0 : "detail": detail,
234 0 : "hint": hint,
235 0 : "position": position,
236 0 : "internalPosition": internal_position,
237 0 : "internalQuery": internal_query,
238 0 : "severity": severity,
239 0 : "where": where_,
240 0 : "table": table,
241 0 : "column": column,
242 0 : "schema": schema,
243 0 : "dataType": datatype,
244 0 : "constraint": constraint,
245 0 : "file": file,
246 0 : "line": line,
247 0 : "routine": routine,
248 : }),
249 0 : )?
250 : }
251 : };
252 :
253 0 : response
254 0 : .headers_mut()
255 0 : .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
256 0 : Ok(response)
257 0 : }
258 :
259 : #[derive(Debug, thiserror::Error)]
260 : pub(crate) enum SqlOverHttpError {
261 : #[error("{0}")]
262 : ReadPayload(#[from] ReadPayloadError),
263 : #[error("{0}")]
264 : ConnectCompute(#[from] HttpConnError),
265 : #[error("{0}")]
266 : ConnInfo(#[from] ConnInfoError),
267 : #[error("response is too large (max is {0} bytes)")]
268 : ResponseTooLarge(usize),
269 : #[error("invalid isolation level")]
270 : InvalidIsolationLevel,
271 : /// for queries our customers choose to run
272 : #[error("{0}")]
273 : Postgres(#[source] postgres_client::Error),
274 : /// for queries we choose to run
275 : #[error("{0}")]
276 : InternalPostgres(#[source] postgres_client::Error),
277 : #[error("{0}")]
278 : JsonConversion(#[from] JsonConversionError),
279 : #[error("{0}")]
280 : Cancelled(SqlOverHttpCancel),
281 : }
282 :
283 : impl ReportableError for SqlOverHttpError {
284 0 : fn get_error_kind(&self) -> ErrorKind {
285 0 : match self {
286 0 : SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
287 0 : SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
288 0 : SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
289 0 : SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
290 0 : SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
291 : // customer initiated SQL errors.
292 0 : SqlOverHttpError::Postgres(p) => {
293 0 : if p.as_db_error().is_some() {
294 0 : ErrorKind::User
295 : } else {
296 0 : ErrorKind::Compute
297 : }
298 : }
299 : // proxy initiated SQL errors.
300 0 : SqlOverHttpError::InternalPostgres(p) => {
301 0 : if p.as_db_error().is_some() {
302 0 : ErrorKind::Service
303 : } else {
304 0 : ErrorKind::Compute
305 : }
306 : }
307 : // postgres returned a bad row format that we couldn't parse.
308 0 : SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
309 0 : SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
310 : }
311 0 : }
312 : }
313 :
314 : impl UserFacingError for SqlOverHttpError {
315 0 : fn to_string_client(&self) -> String {
316 0 : match self {
317 0 : SqlOverHttpError::ReadPayload(p) => p.to_string(),
318 0 : SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
319 0 : SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
320 0 : SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
321 0 : SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
322 0 : SqlOverHttpError::Postgres(p) => p.to_string(),
323 0 : SqlOverHttpError::InternalPostgres(p) => p.to_string(),
324 0 : SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
325 0 : SqlOverHttpError::Cancelled(_) => self.to_string(),
326 : }
327 0 : }
328 : }
329 :
330 : impl HttpCodeError for SqlOverHttpError {
331 0 : fn get_http_status_code(&self) -> StatusCode {
332 0 : match self {
333 0 : SqlOverHttpError::ReadPayload(e) => e.get_http_status_code(),
334 0 : SqlOverHttpError::ConnectCompute(h) => match h.get_error_kind() {
335 0 : ErrorKind::User => StatusCode::BAD_REQUEST,
336 0 : _ => StatusCode::INTERNAL_SERVER_ERROR,
337 : },
338 0 : SqlOverHttpError::ConnInfo(_) => StatusCode::BAD_REQUEST,
339 0 : SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
340 0 : SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
341 0 : SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST,
342 0 : SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
343 0 : SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
344 0 : SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
345 : }
346 0 : }
347 : }
348 :
349 : #[derive(Debug, thiserror::Error)]
350 : pub(crate) enum SqlOverHttpCancel {
351 : #[error("query was cancelled")]
352 : Postgres,
353 : #[error("query was cancelled while stuck trying to connect to the database")]
354 : Connect,
355 : }
356 :
357 : impl ReportableError for SqlOverHttpCancel {
358 0 : fn get_error_kind(&self) -> ErrorKind {
359 0 : match self {
360 0 : SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
361 0 : SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
362 : }
363 0 : }
364 : }
365 :
366 : #[derive(Clone, Copy, Debug)]
367 : struct HttpHeaders {
368 : raw_output: bool,
369 : default_array_mode: bool,
370 : txn_isolation_level: Option<IsolationLevel>,
371 : txn_read_only: bool,
372 : txn_deferrable: bool,
373 : }
374 :
375 : impl HttpHeaders {
376 0 : fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
377 : // Determine the output options. Default behaviour is 'false'. Anything that is not
378 : // strictly 'true' assumed to be false.
379 0 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
380 0 : let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
381 :
382 : // isolation level, read only and deferrable
383 0 : let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
384 0 : Some(x) => Some(
385 0 : map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
386 : ),
387 0 : None => None,
388 : };
389 :
390 0 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
391 0 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
392 :
393 0 : Ok(Self {
394 0 : raw_output,
395 0 : default_array_mode,
396 0 : txn_isolation_level,
397 0 : txn_read_only,
398 0 : txn_deferrable,
399 0 : })
400 0 : }
401 : }
402 :
403 0 : fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
404 0 : match level.as_bytes() {
405 0 : b"Serializable" => Some(IsolationLevel::Serializable),
406 0 : b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
407 0 : b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
408 0 : b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
409 0 : _ => None,
410 : }
411 0 : }
412 :
413 0 : fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
414 0 : match level {
415 0 : IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
416 0 : IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
417 0 : IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
418 0 : IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
419 0 : _ => None,
420 : }
421 0 : }
422 :
423 0 : async fn handle_inner(
424 0 : cancel: CancellationToken,
425 0 : config: &'static ProxyConfig,
426 0 : ctx: &RequestContext,
427 0 : request: Request<Incoming>,
428 0 : backend: Arc<PoolingBackend>,
429 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
430 0 : let _requeset_gauge = Metrics::get()
431 0 : .proxy
432 0 : .connection_requests
433 0 : .guard(ctx.protocol());
434 0 : info!(
435 0 : protocol = %ctx.protocol(),
436 0 : "handling interactive connection from client"
437 : );
438 :
439 0 : let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
440 0 : info!(
441 0 : user = conn_info.conn_info.user_info.user.as_str(),
442 0 : "credentials"
443 : );
444 :
445 0 : match conn_info.auth {
446 0 : AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
447 0 : handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
448 : }
449 0 : auth => {
450 0 : handle_db_inner(
451 0 : cancel,
452 0 : config,
453 0 : ctx,
454 0 : request,
455 0 : conn_info.conn_info,
456 0 : auth,
457 0 : backend,
458 0 : )
459 0 : .await
460 : }
461 : }
462 0 : }
463 :
464 0 : async fn handle_db_inner(
465 0 : cancel: CancellationToken,
466 0 : config: &'static ProxyConfig,
467 0 : ctx: &RequestContext,
468 0 : request: Request<Incoming>,
469 0 : conn_info: ConnInfo,
470 0 : auth: AuthData,
471 0 : backend: Arc<PoolingBackend>,
472 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
473 : //
474 : // Determine the destination and connection params
475 : //
476 0 : let headers = request.headers();
477 :
478 : // Allow connection pooling only if explicitly requested
479 : // or if we have decided that http pool is no longer opt-in
480 0 : let allow_pool = !config.http_config.pool_options.opt_in
481 0 : || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
482 :
483 0 : let parsed_headers = HttpHeaders::try_parse(headers)?;
484 :
485 0 : let mut request_len = 0;
486 0 : let fetch_and_process_request = Box::pin(
487 0 : async {
488 0 : let body = read_body_with_limit(
489 0 : request.into_body(),
490 0 : config.http_config.max_request_size_bytes,
491 0 : )
492 0 : .await?;
493 :
494 0 : request_len = body.len();
495 :
496 0 : Metrics::get()
497 0 : .proxy
498 0 : .http_conn_content_length_bytes
499 0 : .observe(HttpDirection::Request, body.len() as f64);
500 :
501 0 : debug!(length = body.len(), "request payload read");
502 0 : let payload: Payload = serde_json::from_slice(&body)?;
503 0 : Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
504 0 : }
505 0 : .map_err(SqlOverHttpError::from),
506 : );
507 :
508 0 : let authenticate_and_connect = Box::pin(
509 0 : async {
510 0 : let keys = match auth {
511 0 : AuthData::Password(pw) => backend
512 0 : .authenticate_with_password(ctx, &conn_info.user_info, &pw)
513 0 : .await
514 0 : .map_err(HttpConnError::AuthError)?,
515 0 : AuthData::Jwt(jwt) => backend
516 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
517 0 : .await
518 0 : .map_err(HttpConnError::AuthError)?,
519 : };
520 :
521 0 : let client = match keys.keys {
522 0 : ComputeCredentialKeys::JwtPayload(payload)
523 0 : if backend.auth_backend.is_local_proxy() =>
524 : {
525 : #[cfg(feature = "testing")]
526 0 : let disable_pg_session_jwt = config.disable_pg_session_jwt;
527 : #[cfg(not(feature = "testing"))]
528 : let disable_pg_session_jwt = false;
529 0 : let mut client = backend
530 0 : .connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
531 0 : .await?;
532 0 : if !disable_pg_session_jwt {
533 0 : let (cli_inner, _dsc) = client.client_inner();
534 0 : cli_inner.set_jwt_session(&payload).await?;
535 0 : }
536 0 : Client::Local(client)
537 : }
538 : _ => {
539 0 : let client = backend
540 0 : .connect_to_compute(ctx, conn_info, keys, !allow_pool)
541 0 : .await?;
542 0 : Client::Remote(client)
543 : }
544 : };
545 :
546 : // not strictly necessary to mark success here,
547 : // but it's just insurance for if we forget it somewhere else
548 0 : ctx.success();
549 0 : Ok::<_, SqlOverHttpError>(client)
550 0 : }
551 0 : .map_err(SqlOverHttpError::from),
552 : );
553 :
554 0 : let (payload, mut client) = match run_until_cancelled(
555 : // Run both operations in parallel
556 0 : try_join(
557 0 : pin!(fetch_and_process_request),
558 0 : pin!(authenticate_and_connect),
559 : ),
560 0 : &cancel,
561 : )
562 0 : .await
563 : {
564 0 : Some(result) => result?,
565 0 : None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
566 : };
567 :
568 0 : let mut response = Response::builder()
569 0 : .status(StatusCode::OK)
570 0 : .header(header::CONTENT_TYPE, "application/json");
571 :
572 : // Now execute the query and return the result.
573 0 : let json_output = match payload {
574 0 : Payload::Single(stmt) => {
575 0 : stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
576 0 : .await?
577 : }
578 0 : Payload::Batch(statements) => {
579 0 : if parsed_headers.txn_read_only {
580 0 : response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
581 0 : }
582 0 : if parsed_headers.txn_deferrable {
583 0 : response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
584 0 : }
585 0 : if let Some(txn_isolation_level) = parsed_headers
586 0 : .txn_isolation_level
587 0 : .and_then(map_isolation_level_to_headers)
588 0 : {
589 0 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
590 0 : }
591 :
592 0 : statements
593 0 : .process(&config.http_config, cancel, &mut client, parsed_headers)
594 0 : .await?
595 : }
596 : };
597 :
598 0 : let metrics = client.metrics(ctx);
599 :
600 0 : let len = json_output.len();
601 0 : let response = response
602 0 : .body(
603 0 : Full::new(Bytes::from(json_output))
604 0 : .map_err(|x| match x {})
605 0 : .boxed(),
606 : )
607 : // only fails if invalid status code or invalid header/values are given.
608 : // these are not user configurable so it cannot fail dynamically
609 0 : .expect("building response payload should not fail");
610 :
611 : // count the egress bytes - we miss the TLS and header overhead but oh well...
612 : // moving this later in the stack is going to be a lot of effort and ehhhh
613 0 : metrics.record_egress(len as u64);
614 0 : metrics.record_ingress(request_len as u64);
615 :
616 0 : Metrics::get()
617 0 : .proxy
618 0 : .http_conn_content_length_bytes
619 0 : .observe(HttpDirection::Response, len as f64);
620 :
621 0 : Ok(response)
622 0 : }
623 :
624 : static HEADERS_TO_FORWARD: &[&HeaderName] = &[
625 : &AUTHORIZATION,
626 : &CONN_STRING,
627 : &RAW_TEXT_OUTPUT,
628 : &ARRAY_MODE,
629 : &TXN_ISOLATION_LEVEL,
630 : &TXN_READ_ONLY,
631 : &TXN_DEFERRABLE,
632 : ];
633 :
634 0 : async fn handle_auth_broker_inner(
635 0 : ctx: &RequestContext,
636 0 : request: Request<Incoming>,
637 0 : conn_info: ConnInfo,
638 0 : jwt: String,
639 0 : backend: Arc<PoolingBackend>,
640 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
641 0 : backend
642 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
643 0 : .await
644 0 : .map_err(HttpConnError::from)?;
645 :
646 0 : let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
647 :
648 0 : let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
649 :
650 0 : let (mut parts, body) = request.into_parts();
651 0 : let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
652 :
653 : // todo(conradludgate): maybe auth-broker should parse these and re-serialize
654 : // these instead just to ensure they remain normalised.
655 0 : for &h in HEADERS_TO_FORWARD {
656 0 : if let Some(hv) = parts.headers.remove(h) {
657 0 : req = req.header(h, hv);
658 0 : }
659 : }
660 0 : req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
661 :
662 0 : let req = req
663 0 : .body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
664 0 : .expect("all headers and params received via hyper should be valid for request");
665 :
666 : // todo: map body to count egress
667 0 : let _metrics = client.metrics(ctx);
668 :
669 0 : Ok(client
670 0 : .inner
671 0 : .inner
672 0 : .send_request(req)
673 0 : .await
674 0 : .map_err(LocalProxyConnError::from)
675 0 : .map_err(HttpConnError::from)?
676 0 : .map(|b| b.boxed()))
677 0 : }
678 :
679 : impl QueryData {
680 0 : async fn process(
681 0 : self,
682 0 : config: &'static HttpConfig,
683 0 : cancel: CancellationToken,
684 0 : client: &mut Client,
685 0 : parsed_headers: HttpHeaders,
686 0 : ) -> Result<String, SqlOverHttpError> {
687 0 : let (inner, mut discard) = client.inner();
688 0 : let cancel_token = inner.cancel_token();
689 :
690 0 : match select(
691 0 : pin!(query_to_json(
692 0 : config,
693 0 : &mut *inner,
694 0 : self,
695 0 : &mut 0,
696 0 : parsed_headers
697 : )),
698 0 : pin!(cancel.cancelled()),
699 : )
700 0 : .await
701 : {
702 : // The query successfully completed.
703 0 : Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
704 0 : discard.check_idle(status);
705 :
706 0 : let json_output =
707 0 : serde_json::to_string(&results).expect("json serialization should not fail");
708 0 : Ok(json_output)
709 : }
710 : // The query failed with an error
711 0 : Either::Left((Err(e), __not_yet_cancelled)) => {
712 0 : discard.discard();
713 0 : Err(e)
714 : }
715 : // The query was cancelled.
716 0 : Either::Right((_cancelled, query)) => {
717 0 : tracing::info!("cancelling query");
718 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
719 0 : tracing::warn!(?err, "could not cancel query");
720 0 : }
721 : // wait for the query cancellation
722 0 : match time::timeout(time::Duration::from_millis(100), query).await {
723 : // query successed before it was cancelled.
724 0 : Ok(Ok((status, results))) => {
725 0 : discard.check_idle(status);
726 :
727 0 : let json_output = serde_json::to_string(&results)
728 0 : .expect("json serialization should not fail");
729 0 : Ok(json_output)
730 : }
731 : // query failed or was cancelled.
732 0 : Ok(Err(error)) => {
733 0 : let db_error = match &error {
734 : SqlOverHttpError::ConnectCompute(
735 0 : HttpConnError::PostgresConnectionError(e),
736 : )
737 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
738 0 : _ => None,
739 : };
740 :
741 : // if errored for some other reason, it might not be safe to return
742 0 : if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
743 0 : discard.discard();
744 0 : }
745 :
746 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
747 : }
748 0 : Err(_timeout) => {
749 0 : discard.discard();
750 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
751 : }
752 : }
753 : }
754 : }
755 0 : }
756 : }
757 :
758 : impl BatchQueryData {
759 0 : async fn process(
760 0 : self,
761 0 : config: &'static HttpConfig,
762 0 : cancel: CancellationToken,
763 0 : client: &mut Client,
764 0 : parsed_headers: HttpHeaders,
765 0 : ) -> Result<String, SqlOverHttpError> {
766 0 : info!("starting transaction");
767 0 : let (inner, mut discard) = client.inner();
768 0 : let cancel_token = inner.cancel_token();
769 0 : let mut builder = inner.build_transaction();
770 0 : if let Some(isolation_level) = parsed_headers.txn_isolation_level {
771 0 : builder = builder.isolation_level(isolation_level);
772 0 : }
773 0 : if parsed_headers.txn_read_only {
774 0 : builder = builder.read_only(true);
775 0 : }
776 0 : if parsed_headers.txn_deferrable {
777 0 : builder = builder.deferrable(true);
778 0 : }
779 :
780 0 : let mut transaction = builder
781 0 : .start()
782 0 : .await
783 0 : .inspect_err(|_| {
784 : // if we cannot start a transaction, we should return immediately
785 : // and not return to the pool. connection is clearly broken
786 0 : discard.discard();
787 0 : })
788 0 : .map_err(SqlOverHttpError::Postgres)?;
789 :
790 0 : let json_output = match query_batch(
791 0 : config,
792 0 : cancel.child_token(),
793 0 : &mut transaction,
794 0 : self,
795 0 : parsed_headers,
796 : )
797 0 : .await
798 : {
799 0 : Ok(json_output) => {
800 0 : info!("commit");
801 0 : let status = transaction
802 0 : .commit()
803 0 : .await
804 0 : .inspect_err(|_| {
805 : // if we cannot commit - for now don't return connection to pool
806 : // TODO: get a query status from the error
807 0 : discard.discard();
808 0 : })
809 0 : .map_err(SqlOverHttpError::Postgres)?;
810 0 : discard.check_idle(status);
811 0 : json_output
812 : }
813 : Err(SqlOverHttpError::Cancelled(_)) => {
814 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
815 0 : tracing::warn!(?err, "could not cancel query");
816 0 : }
817 : // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
818 0 : discard.discard();
819 :
820 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
821 : }
822 0 : Err(err) => {
823 0 : info!("rollback");
824 0 : let status = transaction
825 0 : .rollback()
826 0 : .await
827 0 : .inspect_err(|_| {
828 : // if we cannot rollback - for now don't return connection to pool
829 : // TODO: get a query status from the error
830 0 : discard.discard();
831 0 : })
832 0 : .map_err(SqlOverHttpError::Postgres)?;
833 0 : discard.check_idle(status);
834 0 : return Err(err);
835 : }
836 : };
837 :
838 0 : Ok(json_output)
839 0 : }
840 : }
841 :
842 0 : async fn query_batch(
843 0 : config: &'static HttpConfig,
844 0 : cancel: CancellationToken,
845 0 : transaction: &mut Transaction<'_>,
846 0 : queries: BatchQueryData,
847 0 : parsed_headers: HttpHeaders,
848 0 : ) -> Result<String, SqlOverHttpError> {
849 0 : let mut results = Vec::with_capacity(queries.queries.len());
850 0 : let mut current_size = 0;
851 0 : for stmt in queries.queries {
852 0 : let query = pin!(query_to_json(
853 0 : config,
854 0 : transaction,
855 0 : stmt,
856 0 : &mut current_size,
857 0 : parsed_headers,
858 : ));
859 0 : let cancelled = pin!(cancel.cancelled());
860 0 : let res = select(query, cancelled).await;
861 0 : match res {
862 : // TODO: maybe we should check that the transaction bit is set here
863 0 : Either::Left((Ok((_, values)), _cancelled)) => {
864 0 : results.push(values);
865 0 : }
866 0 : Either::Left((Err(e), _cancelled)) => {
867 0 : return Err(e);
868 : }
869 0 : Either::Right((_cancelled, _)) => {
870 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
871 : }
872 : }
873 : }
874 :
875 0 : let results = json!({ "results": results });
876 0 : let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
877 :
878 0 : Ok(json_output)
879 0 : }
880 :
881 0 : async fn query_to_json<T: GenericClient>(
882 0 : config: &'static HttpConfig,
883 0 : client: &mut T,
884 0 : data: QueryData,
885 0 : current_size: &mut usize,
886 0 : parsed_headers: HttpHeaders,
887 0 : ) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
888 0 : let query_start = Instant::now();
889 :
890 0 : let query_params = data.params;
891 0 : let mut row_stream = client
892 0 : .query_raw_txt(&data.query, query_params)
893 0 : .await
894 0 : .map_err(SqlOverHttpError::Postgres)?;
895 0 : let query_acknowledged = Instant::now();
896 :
897 0 : let columns_len = row_stream.statement.columns().len();
898 0 : let mut fields = Vec::with_capacity(columns_len);
899 :
900 0 : for c in row_stream.statement.columns() {
901 0 : fields.push(json!({
902 0 : "name": c.name().to_owned(),
903 0 : "dataTypeID": c.type_().oid(),
904 0 : "tableID": c.table_oid(),
905 0 : "columnID": c.column_id(),
906 0 : "dataTypeSize": c.type_size(),
907 0 : "dataTypeModifier": c.type_modifier(),
908 0 : "format": "text",
909 0 : }));
910 0 : }
911 :
912 0 : let raw_output = parsed_headers.raw_output;
913 0 : let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
914 :
915 : // Manually drain the stream into a vector to leave row_stream hanging
916 : // around to get a command tag. Also check that the response is not too
917 : // big.
918 0 : let mut rows = Vec::new();
919 0 : while let Some(row) = row_stream.next().await {
920 0 : let row = row.map_err(SqlOverHttpError::Postgres)?;
921 0 : *current_size += row.body_len();
922 :
923 : // we don't have a streaming response support yet so this is to prevent OOM
924 : // from a malicious query (eg a cross join)
925 0 : if *current_size > config.max_response_size_bytes {
926 0 : return Err(SqlOverHttpError::ResponseTooLarge(
927 0 : config.max_response_size_bytes,
928 0 : ));
929 0 : }
930 :
931 0 : let row = pg_text_row_to_json(&row, raw_output, array_mode)?;
932 0 : rows.push(row);
933 :
934 : // assumption: parsing pg text and converting to json takes CPU time.
935 : // let's assume it is slightly expensive, so we should consume some cooperative budget.
936 : // Especially considering that `RowStream::next` might be pulling from a batch
937 : // of rows and never hit the tokio mpsc for a long time (although unlikely).
938 0 : tokio::task::consume_budget().await;
939 : }
940 :
941 0 : let query_resp_end = Instant::now();
942 : let RowStream {
943 0 : command_tag,
944 0 : status: ready,
945 : ..
946 0 : } = row_stream;
947 :
948 : // grab the command tag and number of rows affected
949 0 : let command_tag = command_tag.unwrap_or_default();
950 0 : let mut command_tag_split = command_tag.split(' ');
951 0 : let command_tag_name = command_tag_split.next().unwrap_or_default();
952 0 : let command_tag_count = if command_tag_name == "INSERT" {
953 : // INSERT returns OID first and then number of rows
954 0 : command_tag_split.nth(1)
955 : } else {
956 : // other commands return number of rows (if any)
957 0 : command_tag_split.next()
958 : }
959 0 : .and_then(|s| s.parse::<i64>().ok());
960 :
961 0 : info!(
962 0 : rows = rows.len(),
963 : ?ready,
964 : command_tag,
965 0 : acknowledgement = ?(query_acknowledged - query_start),
966 0 : response = ?(query_resp_end - query_start),
967 0 : "finished executing query"
968 : );
969 :
970 : // Resulting JSON format is based on the format of node-postgres result.
971 0 : let results = json!({
972 0 : "command": command_tag_name.to_string(),
973 0 : "rowCount": command_tag_count,
974 0 : "rows": rows,
975 0 : "fields": fields,
976 0 : "rowAsArray": array_mode,
977 : });
978 :
979 0 : Ok((ready, results))
980 0 : }
981 :
982 : enum Client {
983 : Remote(conn_pool_lib::Client<postgres_client::Client>),
984 : Local(conn_pool_lib::Client<postgres_client::Client>),
985 : }
986 :
987 : enum Discard<'a> {
988 : Remote(conn_pool_lib::Discard<'a, postgres_client::Client>),
989 : Local(conn_pool_lib::Discard<'a, postgres_client::Client>),
990 : }
991 :
992 : impl Client {
993 0 : fn metrics(&self, ctx: &RequestContext) -> Arc<MetricCounter> {
994 0 : match self {
995 0 : Client::Remote(client) => client.metrics(ctx),
996 0 : Client::Local(local_client) => local_client.metrics(ctx),
997 : }
998 0 : }
999 :
1000 0 : fn inner(&mut self) -> (&mut postgres_client::Client, Discard<'_>) {
1001 0 : match self {
1002 0 : Client::Remote(client) => {
1003 0 : let (c, d) = client.inner();
1004 0 : (c, Discard::Remote(d))
1005 : }
1006 0 : Client::Local(local_client) => {
1007 0 : let (c, d) = local_client.inner();
1008 0 : (c, Discard::Local(d))
1009 : }
1010 : }
1011 0 : }
1012 : }
1013 :
1014 : impl Discard<'_> {
1015 0 : fn check_idle(&mut self, status: ReadyForQueryStatus) {
1016 0 : match self {
1017 0 : Discard::Remote(discard) => discard.check_idle(status),
1018 0 : Discard::Local(discard) => discard.check_idle(status),
1019 : }
1020 0 : }
1021 0 : fn discard(&mut self) {
1022 0 : match self {
1023 0 : Discard::Remote(discard) => discard.discard(),
1024 0 : Discard::Local(discard) => discard.discard(),
1025 : }
1026 0 : }
1027 : }
1028 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_payload() {
        // Request-body deserialization must panic loudly on malformed test input.
        fn parse(raw: &str) -> Payload {
            serde_json::from_str(raw).unwrap()
        }

        // Case 1: a single query with one parameter and array mode switched on.
        const SINGLE: &str = r#"{"query":"SELECT * FROM users WHERE name = ?","params":["test"],"arrayMode":true}"#;
        let Payload::Single(QueryData {
            query,
            params,
            array_mode,
        }) = parse(SINGLE)
        else {
            panic!("deserialization failed: case with single query, one param, and array mode");
        };
        assert_eq!(query, "SELECT * FROM users WHERE name = ?");
        assert_eq!(params, vec![Some(String::from("test"))]);
        assert!(array_mode.unwrap());

        // Case 2: a batch of two queries whose array mode differs per entry.
        const BATCH: &str = r#"{"queries":[{"query":"SELECT * FROM users0 WHERE name = ?","params":["test0"], "arrayMode":false},{"query":"SELECT * FROM users1 WHERE name = ?","params":["test1"],"arrayMode":true}]}"#;
        let Payload::Batch(BatchQueryData { queries }) = parse(BATCH) else {
            panic!("deserialization failed: case with multiple queries");
        };
        assert_eq!(queries.len(), 2);
        for (idx, entry) in queries.iter().enumerate() {
            assert_eq!(
                entry.query,
                format!("SELECT * FROM users{idx} WHERE name = ?")
            );
            assert_eq!(entry.params, vec![Some(format!("test{idx}"))]);
            assert_eq!(entry.array_mode.unwrap(), idx > 0);
        }

        // Case 3: only a query. Params default to empty and array mode stays unset.
        const MINIMAL: &str = r#"{"query":"SELECT 1"}"#;
        let Payload::Single(QueryData {
            query,
            params,
            array_mode,
        }) = parse(MINIMAL)
        else {
            panic!("deserialization failed: case with only one query");
        };
        assert_eq!(query, "SELECT 1");
        assert!(params.is_empty());
        assert!(array_mode.is_none());
    }
}
|