use std::pin::pin;
use std::sync::Arc;

use bytes::Bytes;
use futures::future::{Either, select, try_join};
use futures::{StreamExt, TryFutureExt};
use http::Method;
use http::header::AUTHORIZATION;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
use typed_json::json;

use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::AuthData;
use super::conn_pool_lib::{self, ConnInfo};
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
    ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
    TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
    query: String,
    #[serde(deserialize_with = "bytes_to_pg_text")]
    #[serde(default)]
    params: Vec<Option<String>>,
    #[serde(default)]
    array_mode: Option<bool>,
}

#[derive(serde::Deserialize)]
struct BatchQueryData {
    queries: Vec<QueryData>,
}

#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
    Single(QueryData),
    Batch(BatchQueryData),
}
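
// Example payloads accepted here (mirroring the tests at the bottom of this
// file); the untagged `Payload` enum tries `Single` first, then `Batch`:
//   {"query": "SELECT 1"}
//   {"query": "SELECT * FROM users WHERE name = ?", "params": ["test"], "arrayMode": true}
//   {"queries": [{"query": "SELECT 1"}, {"query": "SELECT 2"}]}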

pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");

fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    // TODO: consider avoiding the allocation here.
    let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
    Ok(json_to_pg_text(json))
}
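
// A sketch of what `json_to_pg_text` (defined in super::json) does with the
// deserialized parameters: each JSON value is rendered in Postgres text
// format, with `null` mapping to `None`. For example, `[1, "a", null]` would
// come out as `[Some("1"), Some("a"), None]`.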

pub(crate) async fn handle(
    config: &'static ProxyConfig,
    ctx: RequestContext,
    request: Request<Incoming>,
    backend: Arc<PoolingBackend>,
    cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
    let result = handle_inner(cancel, config, &ctx, request, backend).await;

    let mut response = match result {
        Ok(r) => {
            ctx.set_success();

            // Handle error responses from the local proxy here.
            if config.authentication_config.is_auth_broker && r.status().is_server_error() {
                let status = r.status();

                let body_bytes = r
                    .collect()
                    .await
                    .map_err(|e| {
                        ApiError::InternalServerError(anyhow::Error::msg(format!(
                            "could not collect http body: {e}"
                        )))
                    })?
                    .to_bytes();

                if let Ok(mut json_map) =
                    serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
                {
                    let message = json_map.get("message");
                    if let Some(message) = message {
                        let msg: String = match serde_json::from_str(message.get()) {
                            Ok(msg) => msg,
                            Err(_) => {
                                "Unable to parse the response message from server".to_string()
                            }
                        };

                        error!("Error response from local_proxy: {status} {msg}");

                        // Remove all the neon-related keys.
                        json_map.retain(|key, _| !key.starts_with("neon:"));

                        let resp_json = serde_json::to_string(&json_map)
                            .unwrap_or("failed to serialize the response message".to_string());

                        return json_response(status, resp_json);
                    }
                }

                error!("Unable to parse the response message from local_proxy");
                return json_response(
                    status,
                    json!({ "message": "Unable to parse the response message from server".to_string() }),
                );
            }
            r
        }
        Err(e @ SqlOverHttpError::Cancelled(_)) => {
            let error_kind = e.get_error_kind();
            ctx.set_error_kind(error_kind);

            let message = "Query cancelled, connection was terminated";

            tracing::info!(
                kind=error_kind.to_metric_label(),
                error=%e,
                msg=message,
                "forwarding error to user"
            );

            json_response(
                StatusCode::BAD_REQUEST,
                json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
            )?
        }
        Err(e) => {
            let error_kind = e.get_error_kind();
            ctx.set_error_kind(error_kind);

            let mut message = e.to_string_client();
            let db_error = match &e {
                SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
                | SqlOverHttpError::Postgres(e) => e.as_db_error(),
                _ => None,
            };
            fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
                db.map(x).unwrap_or_default()
            }

            if let Some(db_error) = db_error {
                db_error.message().clone_into(&mut message);
            }

            let position = db_error.and_then(|db| db.position());
            let (position, internal_position, internal_query) = match position {
                Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
                Some(ErrorPosition::Internal { position, query }) => {
                    (None, Some(position.to_string()), Some(query.clone()))
                }
                None => (None, None, None),
            };

            let code = get(db_error, |db| db.code().code());
            let severity = get(db_error, |db| db.severity());
            let detail = get(db_error, |db| db.detail());
            let hint = get(db_error, |db| db.hint());
            let where_ = get(db_error, |db| db.where_());
            let table = get(db_error, |db| db.table());
            let column = get(db_error, |db| db.column());
            let schema = get(db_error, |db| db.schema());
            let datatype = get(db_error, |db| db.datatype());
            let constraint = get(db_error, |db| db.constraint());
            let file = get(db_error, |db| db.file());
            let line = get(db_error, |db| db.line().map(|l| l.to_string()));
            let routine = get(db_error, |db| db.routine());

            match &e {
                SqlOverHttpError::Postgres(e)
                    if e.as_db_error().is_some() && error_kind == ErrorKind::User =>
                {
                    // this error contains too much info, and it's not an error we care about.
                    if tracing::enabled!(Level::DEBUG) {
                        tracing::debug!(
                            kind=error_kind.to_metric_label(),
                            error=%e,
                            msg=message,
                            "forwarding error to user"
                        );
                    } else {
                        tracing::info!(
                            kind = error_kind.to_metric_label(),
                            error = "bad query",
                            "forwarding error to user"
                        );
                    }
                }
                _ => {
                    tracing::info!(
                        kind=error_kind.to_metric_label(),
                        error=%e,
                        msg=message,
                        "forwarding error to user"
                    );
                }
            }

            json_response(
                e.get_http_status_code(),
                json!({
                    "message": message,
                    "code": code,
                    "detail": detail,
                    "hint": hint,
                    "position": position,
                    "internalPosition": internal_position,
                    "internalQuery": internal_query,
                    "severity": severity,
                    "where": where_,
                    "table": table,
                    "column": column,
                    "schema": schema,
                    "dataType": datatype,
                    "constraint": constraint,
                    "file": file,
                    "line": line,
                    "routine": routine,
                }),
            )?
        }
    };

    response
        .headers_mut()
        .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
    Ok(response)
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpError {
    #[error("{0}")]
    ReadPayload(#[from] ReadPayloadError),
    #[error("{0}")]
    ConnectCompute(#[from] HttpConnError),
    #[error("{0}")]
    ConnInfo(#[from] ConnInfoError),
    #[error("response is too large (max is {0} bytes)")]
    ResponseTooLarge(usize),
    #[error("invalid isolation level")]
    InvalidIsolationLevel,
    /// for queries our customers choose to run
    #[error("{0}")]
    Postgres(#[source] postgres_client::Error),
    /// for queries we choose to run
    #[error("{0}")]
    InternalPostgres(#[source] postgres_client::Error),
    #[error("{0}")]
    JsonConversion(#[from] JsonConversionError),
    #[error("{0}")]
    Cancelled(SqlOverHttpCancel),
}

impl ReportableError for SqlOverHttpError {
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
            SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
            SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
            SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
            SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
            // customer initiated SQL errors.
            SqlOverHttpError::Postgres(p) => {
                if p.as_db_error().is_some() {
                    ErrorKind::User
                } else {
                    ErrorKind::Compute
                }
            }
            // proxy initiated SQL errors.
            SqlOverHttpError::InternalPostgres(p) => {
                if p.as_db_error().is_some() {
                    ErrorKind::Service
                } else {
                    ErrorKind::Compute
                }
            }
            // postgres returned a bad row format that we couldn't parse.
            SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
            SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
        }
    }
}

impl UserFacingError for SqlOverHttpError {
    fn to_string_client(&self) -> String {
        match self {
            SqlOverHttpError::ReadPayload(p) => p.to_string(),
            SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
            SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
            SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
            SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
            SqlOverHttpError::Postgres(p) => p.to_string(),
            SqlOverHttpError::InternalPostgres(p) => p.to_string(),
            SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
            SqlOverHttpError::Cancelled(_) => self.to_string(),
        }
    }
}

impl HttpCodeError for SqlOverHttpError {
    fn get_http_status_code(&self) -> StatusCode {
        match self {
            SqlOverHttpError::ReadPayload(e) => e.get_http_status_code(),
            SqlOverHttpError::ConnectCompute(h) => match h.get_error_kind() {
                ErrorKind::User => StatusCode::BAD_REQUEST,
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            },
            SqlOverHttpError::ConnInfo(_) => StatusCode::BAD_REQUEST,
            SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
            SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
            SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST,
            SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
            SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
            SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
    #[error("query was cancelled")]
    Postgres,
    #[error("query was cancelled while stuck trying to connect to the database")]
    Connect,
}

impl ReportableError for SqlOverHttpCancel {
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
            SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
        }
    }
}

#[derive(Clone, Copy, Debug)]
struct HttpHeaders {
    raw_output: bool,
    default_array_mode: bool,
    txn_isolation_level: Option<IsolationLevel>,
    txn_read_only: bool,
    txn_deferrable: bool,
}
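
// For illustration only (the authoritative header names live in
// super::http_util): clients opt into these behaviours with request headers
// along the lines of `Neon-Raw-Text-Output: true`, `Neon-Array-Mode: true`,
// `Neon-Batch-Isolation-Level: ReadCommitted`, `Neon-Batch-Read-Only: true`
// and `Neon-Batch-Deferrable: true`, which `HttpHeaders::try_parse` folds
// into this struct.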

impl HttpHeaders {
    fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
        // Determine the output options. The default behaviour is 'false';
        // anything that is not strictly 'true' is assumed to be false.
        let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
        let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

        // isolation level, read only and deferrable
        let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
            Some(x) => Some(
                map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
            ),
            None => None,
        };

        let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
        let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

        Ok(Self {
            raw_output,
            default_array_mode,
            txn_isolation_level,
            txn_read_only,
            txn_deferrable,
        })
    }
}

fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
    match level.as_bytes() {
        b"Serializable" => Some(IsolationLevel::Serializable),
        b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
        b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
        b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
        _ => None,
    }
}

fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
    match level {
        IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
        IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
        IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
        IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
        _ => None,
    }
}
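
// The two maps above are inverses for the four supported isolation levels, so
// the level requested via the header can be echoed back verbatim on batch
// responses (see the `Payload::Batch` branch of `handle_db_inner`).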

async fn handle_inner(
    cancel: CancellationToken,
    config: &'static ProxyConfig,
    ctx: &RequestContext,
    request: Request<Incoming>,
    backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
    let _request_gauge = Metrics::get()
        .proxy
        .connection_requests
        .guard(ctx.protocol());
    info!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
    info!(
        user = conn_info.conn_info.user_info.user.as_str(),
        "credentials"
    );

    match conn_info.auth {
        AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
            handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
        }
        auth => {
            handle_db_inner(
                cancel,
                config,
                ctx,
                request,
                conn_info.conn_info,
                auth,
                backend,
            )
            .await
        }
    }
}

async fn handle_db_inner(
    cancel: CancellationToken,
    config: &'static ProxyConfig,
    ctx: &RequestContext,
    request: Request<Incoming>,
    conn_info: ConnInfo,
    auth: AuthData,
    backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
    //
    // Determine the destination and connection params
    //
    let headers = request.headers();

    // Allow connection pooling only if explicitly requested
    // or if we have decided that http pool is no longer opt-in.
    let allow_pool = !config.http_config.pool_options.opt_in
        || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    let parsed_headers = HttpHeaders::try_parse(headers)?;

    let mut request_len = 0;
    let fetch_and_process_request = Box::pin(
        async {
            let body = read_body_with_limit(
                request.into_body(),
                config.http_config.max_request_size_bytes,
            )
            .await?;

            request_len = body.len();

            Metrics::get()
                .proxy
                .http_conn_content_length_bytes
                .observe(HttpDirection::Request, body.len() as f64);

            debug!(length = body.len(), "request payload read");
            let payload: Payload = serde_json::from_slice(&body)?;
            Ok::<Payload, ReadPayloadError>(payload)
        }
        .map_err(SqlOverHttpError::from),
    );

    let authenticate_and_connect = Box::pin(
        async {
            let keys = match auth {
                AuthData::Password(pw) => backend
                    .authenticate_with_password(ctx, &conn_info.user_info, &pw)
                    .await
                    .map_err(HttpConnError::AuthError)?,
                AuthData::Jwt(jwt) => backend
                    .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
                    .await
                    .map_err(HttpConnError::AuthError)?,
            };

            let client = match keys.keys {
                ComputeCredentialKeys::JwtPayload(payload)
                    if backend.auth_backend.is_local_proxy() =>
                {
                    #[cfg(feature = "testing")]
                    let disable_pg_session_jwt = config.disable_pg_session_jwt;
                    #[cfg(not(feature = "testing"))]
                    let disable_pg_session_jwt = false;
                    let mut client = backend
                        .connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
                        .await?;
                    if !disable_pg_session_jwt {
                        let (cli_inner, _dsc) = client.client_inner();
                        cli_inner.set_jwt_session(&payload).await?;
                    }
                    Client::Local(client)
                }
                _ => {
                    let client = backend
                        .connect_to_compute(ctx, conn_info, keys, !allow_pool)
                        .await?;
                    Client::Remote(client)
                }
            };

            // Not strictly necessary to mark success here, but it's insurance
            // in case we forget to do it somewhere else.
            ctx.success();
            Ok::<_, SqlOverHttpError>(client)
        }
        .map_err(SqlOverHttpError::from),
    );

    let (payload, mut client) = match run_until_cancelled(
        // Run both operations in parallel.
        try_join(
            pin!(fetch_and_process_request),
            pin!(authenticate_and_connect),
        ),
        &cancel,
    )
    .await
    {
        Some(result) => result?,
        None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
    };

    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/json");

    // Now execute the query and return the result.
    let json_output = match payload {
        Payload::Single(stmt) => {
            stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
                .await?
        }
        Payload::Batch(statements) => {
            if parsed_headers.txn_read_only {
                response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
            }
            if parsed_headers.txn_deferrable {
                response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
            }
            if let Some(txn_isolation_level) = parsed_headers
                .txn_isolation_level
                .and_then(map_isolation_level_to_headers)
            {
                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
            }

            statements
                .process(&config.http_config, cancel, &mut client, parsed_headers)
                .await?
        }
    };

    let metrics = client.metrics(ctx);

    let len = json_output.len();
    let response = response
        .body(
            Full::new(Bytes::from(json_output))
                .map_err(|x| match x {})
                .boxed(),
        )
        // This only fails if an invalid status code or invalid headers/values
        // are given. These are not user configurable, so it cannot fail
        // dynamically.
        .expect("building response payload should not fail");

    // Count the egress bytes. We miss the TLS and header overhead here, but
    // moving this later in the stack would be a lot of effort.
    metrics.record_egress(len as u64);
    metrics.record_ingress(request_len as u64);

    Metrics::get()
        .proxy
        .http_conn_content_length_bytes
        .observe(HttpDirection::Response, len as f64);

    Ok(response)
}

static HEADERS_TO_FORWARD: &[&HeaderName] = &[
    &AUTHORIZATION,
    &CONN_STRING,
    &RAW_TEXT_OUTPUT,
    &ARRAY_MODE,
    &TXN_ISOLATION_LEVEL,
    &TXN_READ_ONLY,
    &TXN_DEFERRABLE,
];

async fn handle_auth_broker_inner(
    ctx: &RequestContext,
    request: Request<Incoming>,
    conn_info: ConnInfo,
    jwt: String,
    backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
    backend
        .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
        .await
        .map_err(HttpConnError::from)?;

    let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;

    let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");

    let (mut parts, body) = request.into_parts();
    let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);

    // todo(conradludgate): maybe auth-broker should parse these and
    // re-serialize them instead, just to ensure they remain normalised.
    for &h in HEADERS_TO_FORWARD {
        if let Some(hv) = parts.headers.remove(h) {
            req = req.header(h, hv);
        }
    }
    req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));

    let req = req
        .body(body.map_err(|e| e).boxed()) // TODO: is there a potential for a regression here?
        .expect("all headers and params received via hyper should be valid for request");

    // todo: map body to count egress
    let _metrics = client.metrics(ctx);

    Ok(client
        .inner
        .inner
        .send_request(req)
        .await
        .map_err(LocalProxyConnError::from)
        .map_err(HttpConnError::from)?
        .map(|b| b.boxed()))
}

impl QueryData {
    async fn process(
        self,
        config: &'static HttpConfig,
        cancel: CancellationToken,
        client: &mut Client,
        parsed_headers: HttpHeaders,
    ) -> Result<String, SqlOverHttpError> {
        let (inner, mut discard) = client.inner();
        let cancel_token = inner.cancel_token();

        let mut json_buf = vec![];

        let batch_result = match select(
            pin!(query_to_json(
                config,
                &mut *inner,
                self,
                json::ValueSer::new(&mut json_buf),
                parsed_headers
            )),
            pin!(cancel.cancelled()),
        )
        .await
        {
            Either::Left((res, _not_yet_cancelled)) => res,
            Either::Right((_cancelled, query)) => {
                tracing::info!("cancelling query");
                if let Err(err) = cancel_token.cancel_query(NoTls).await {
                    tracing::warn!(?err, "could not cancel query");
                }
                // Wait for the query cancellation.
                match time::timeout(time::Duration::from_millis(100), query).await {
                    // The query succeeded before it was cancelled.
                    Ok(Ok(status)) => Ok(status),
                    // The query failed or was cancelled.
                    Ok(Err(error)) => {
                        let db_error = match &error {
                            SqlOverHttpError::ConnectCompute(
                                HttpConnError::PostgresConnectionError(e),
                            )
                            | SqlOverHttpError::Postgres(e) => e.as_db_error(),
                            _ => None,
                        };

                        // If it errored for some other reason, it might not be
                        // safe to return the connection to the pool.
                        if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
                            discard.discard();
                        }

                        return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
                    }
                    Err(_timeout) => {
                        discard.discard();
                        return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
                    }
                }
            }
        };

        match batch_result {
            // The query successfully completed.
            Ok(_) => {
                let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
                Ok(json_output)
            }
            // The query failed with an error.
            Err(e) => {
                discard.discard();
                Err(e)
            }
        }
    }
}

impl BatchQueryData {
    async fn process(
        self,
        config: &'static HttpConfig,
        cancel: CancellationToken,
        client: &mut Client,
        parsed_headers: HttpHeaders,
    ) -> Result<String, SqlOverHttpError> {
        info!("starting transaction");
        let (inner, mut discard) = client.inner();
        let cancel_token = inner.cancel_token();
        let mut builder = inner.build_transaction();
        if let Some(isolation_level) = parsed_headers.txn_isolation_level {
            builder = builder.isolation_level(isolation_level);
        }
        if parsed_headers.txn_read_only {
            builder = builder.read_only(true);
        }
        if parsed_headers.txn_deferrable {
            builder = builder.deferrable(true);
        }

        let mut transaction = builder
            .start()
            .await
            .inspect_err(|_| {
                // If we cannot start a transaction, return immediately and do
                // not return the connection to the pool: it is clearly broken.
                discard.discard();
            })
            .map_err(SqlOverHttpError::Postgres)?;

        let json_output = match query_batch_to_json(
            config,
            cancel.child_token(),
            &mut transaction,
            self,
            parsed_headers,
        )
        .await
        {
            Ok(json_output) => {
                info!("commit");
                transaction
                    .commit()
                    .await
                    .inspect_err(|_| {
                        // If we cannot commit, don't return the connection to
                        // the pool for now.
                        // TODO: get a query status from the error
                        discard.discard();
                    })
                    .map_err(SqlOverHttpError::Postgres)?;
                json_output
            }
            Err(SqlOverHttpError::Cancelled(_)) => {
                if let Err(err) = cancel_token.cancel_query(NoTls).await {
                    tracing::warn!(?err, "could not cancel query");
                }
                // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
                discard.discard();

                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
            }
            Err(err) => {
                return Err(err);
            }
        };

        Ok(json_output)
    }
}

async fn query_batch(
    config: &'static HttpConfig,
    cancel: CancellationToken,
    transaction: &mut Transaction<'_>,
    queries: BatchQueryData,
    parsed_headers: HttpHeaders,
    results: &mut json::ListSer<'_>,
) -> Result<(), SqlOverHttpError> {
    for stmt in queries.queries {
        let query = pin!(query_to_json(
            config,
            transaction,
            stmt,
            results.entry(),
            parsed_headers,
        ));
        let cancelled = pin!(cancel.cancelled());
        let res = select(query, cancelled).await;
        match res {
            // TODO: maybe we should check that the transaction bit is set here
            Either::Left((Ok(_), _cancelled)) => {}
            Either::Left((Err(e), _cancelled)) => {
                return Err(e);
            }
            Either::Right((_cancelled, _)) => {
                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
            }
        }
    }

    Ok(())
}

async fn query_batch_to_json(
    config: &'static HttpConfig,
    cancel: CancellationToken,
    tx: &mut Transaction<'_>,
    queries: BatchQueryData,
    headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
    let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
        let results = obj.key("results");
        json::value_as_list!(|results| {
            query_batch(config, cancel, tx, queries, headers, results).await?;
        });
    }));

    Ok(json_output)
}
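
// The batch response produced above has the shape
// `{"results": [<one object per query, as produced by query_to_json>]}`.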

async fn query_to_json<T: GenericClient>(
    config: &'static HttpConfig,
    client: &mut T,
    data: QueryData,
    output: json::ValueSer<'_>,
    parsed_headers: HttpHeaders,
) -> Result<ReadyForQueryStatus, SqlOverHttpError> {
    let query_start = Instant::now();

    let mut output = json::ObjectSer::new(output);
    let mut row_stream = client
        .query_raw_txt(&data.query, data.params)
        .await
        .map_err(SqlOverHttpError::Postgres)?;
    let query_acknowledged = Instant::now();

    let mut json_fields = output.key("fields").list();
    for c in row_stream.statement.columns() {
        let json_field = json_fields.entry();
        json::value_as_object!(|json_field| {
            json_field.entry("name", c.name());
            json_field.entry("dataTypeID", c.type_().oid());
            json_field.entry("tableID", c.table_oid());
            json_field.entry("columnID", c.column_id());
            json_field.entry("dataTypeSize", c.type_size());
            json_field.entry("dataTypeModifier", c.type_modifier());
            json_field.entry("format", "text");
        });
    }
    json_fields.finish();

    let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
    let raw_output = parsed_headers.raw_output;

    // Manually drain the stream into the JSON buffer, leaving row_stream
    // around so we can read the command tag afterwards. Also check that the
    // response is not too big.
    let mut rows = 0;
    let mut json_rows = output.key("rows").list();
    while let Some(row) = row_stream.next().await {
        let row = row.map_err(SqlOverHttpError::Postgres)?;

        // We don't have streaming response support yet, so this check prevents
        // OOM from a malicious query (e.g. a cross join).
        if json_rows.as_buffer().len() > config.max_response_size_bytes {
            return Err(SqlOverHttpError::ResponseTooLarge(
                config.max_response_size_bytes,
            ));
        }

        pg_text_row_to_json(json_rows.entry(), &row, raw_output, array_mode)?;
        rows += 1;

        // Assumption: parsing pg text and converting to json takes CPU time.
        // Let's assume it is slightly expensive, so we should consume some
        // cooperative budget. Especially considering that `RowStream::next`
        // might be pulling from a batch of rows and never hit the tokio mpsc
        // for a long time (although unlikely).
        tokio::task::consume_budget().await;
    }
    json_rows.finish();

    let query_resp_end = Instant::now();

    let ready = row_stream.status;

    // Grab the command tag and number of rows affected.
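    // Postgres command tags look like `SELECT 7` or `UPDATE 3`; INSERT is the
    // odd one out with `INSERT <oid> <rows>` (e.g. `INSERT 0 42`), and tags
    // such as `BEGIN` carry no row count, leaving `command_tag_count` as None.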
    let command_tag = row_stream.command_tag.unwrap_or_default();
    let mut command_tag_split = command_tag.split(' ');
    let command_tag_name = command_tag_split.next().unwrap_or_default();
    let command_tag_count = if command_tag_name == "INSERT" {
        // INSERT returns OID first and then number of rows
        command_tag_split.nth(1)
    } else {
        // other commands return number of rows (if any)
        command_tag_split.next()
    }
    .and_then(|s| s.parse::<i64>().ok());

    info!(
        rows,
        ?ready,
        command_tag,
        acknowledgement = ?(query_acknowledged - query_start),
        response = ?(query_resp_end - query_start),
        "finished executing query"
    );

    output.entry("command", command_tag_name);
    output.entry("rowCount", command_tag_count);
    output.entry("rowAsArray", array_mode);

    output.finish();
    Ok(ready)
}
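
// For reference, the single-query response object assembled by `query_to_json`
// looks like:
//   {"fields": [{"name": ..., "dataTypeID": ..., ...}, ...],
//    "rows": [...],
//    "command": "SELECT",
//    "rowCount": 1,
//    "rowAsArray": false}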

enum Client {
    Remote(conn_pool_lib::Client<postgres_client::Client>),
    Local(conn_pool_lib::Client<postgres_client::Client>),
}

enum Discard<'a> {
    Remote(conn_pool_lib::Discard<'a, postgres_client::Client>),
    Local(conn_pool_lib::Discard<'a, postgres_client::Client>),
}

impl Client {
    fn metrics(&self, ctx: &RequestContext) -> Arc<MetricCounter> {
        match self {
            Client::Remote(client) => client.metrics(ctx),
            Client::Local(local_client) => local_client.metrics(ctx),
        }
    }

    fn inner(&mut self) -> (&mut postgres_client::Client, Discard<'_>) {
        match self {
            Client::Remote(client) => {
                let (c, d) = client.inner();
                (c, Discard::Remote(d))
            }
            Client::Local(local_client) => {
                let (c, d) = local_client.inner();
                (c, Discard::Local(d))
            }
        }
    }
}

impl Discard<'_> {
    fn discard(&mut self) {
        match self {
            Discard::Remote(discard) => discard.discard(),
            Discard::Local(discard) => discard.discard(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_payload() {
        let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}";
        let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();

        match deserialized_payload {
            Payload::Single(QueryData {
                query,
                params,
                array_mode,
            }) => {
                assert_eq!(query, "SELECT * FROM users WHERE name = ?");
                assert_eq!(params, vec![Some(String::from("test"))]);
                assert!(array_mode.unwrap());
            }
            Payload::Batch(_) => {
                panic!("deserialization failed: case with single query, one param, and array mode")
            }
        }

        let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
        let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();

        match deserialized_payload {
            Payload::Batch(BatchQueryData { queries }) => {
                assert_eq!(queries.len(), 2);
                for (i, query) in queries.into_iter().enumerate() {
                    assert_eq!(
                        query.query,
                        format!("SELECT * FROM users{i} WHERE name = ?")
                    );
                    assert_eq!(query.params, vec![Some(format!("test{i}"))]);
                    assert_eq!(query.array_mode.unwrap(), i > 0);
                }
            }
            Payload::Single(_) => panic!("deserialization failed: case with multiple queries"),
        }

        let payload = "{\"query\":\"SELECT 1\"}";
        let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();

        match deserialized_payload {
            Payload::Single(QueryData {
                query,
                params,
                array_mode,
            }) => {
                assert_eq!(query, "SELECT 1");
                assert_eq!(params, vec![]);
                assert!(array_mode.is_none());
            }
            Payload::Batch(_) => panic!("deserialization failed: case with only one query"),
        }
    }
}