Line data Source code
1 : use std::pin::pin;
2 : use std::sync::Arc;
3 :
4 : use bytes::Bytes;
5 : use futures::future::{Either, select, try_join};
6 : use futures::{StreamExt, TryFutureExt};
7 : use http::Method;
8 : use http::header::AUTHORIZATION;
9 : use http_body_util::combinators::BoxBody;
10 : use http_body_util::{BodyExt, Full};
11 : use http_utils::error::ApiError;
12 : use hyper::body::Incoming;
13 : use hyper::http::{HeaderName, HeaderValue};
14 : use hyper::{HeaderMap, Request, Response, StatusCode, header};
15 : use indexmap::IndexMap;
16 : use postgres_client::error::{DbError, ErrorPosition, SqlState};
17 : use postgres_client::{
18 : GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
19 : };
20 : use serde::Serialize;
21 : use serde_json::Value;
22 : use serde_json::value::RawValue;
23 : use tokio::time::{self, Instant};
24 : use tokio_util::sync::CancellationToken;
25 : use tracing::{debug, error, info};
26 : use typed_json::json;
27 : use url::Url;
28 : use uuid::Uuid;
29 :
30 : use super::backend::{LocalProxyConnError, PoolingBackend};
31 : use super::conn_pool::{AuthData, ConnInfoWithAuth};
32 : use super::conn_pool_lib::{self, ConnInfo};
33 : use super::error::HttpCodeError;
34 : use super::http_util::json_response;
35 : use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
36 : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
37 : use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
38 : use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
39 : use crate::context::RequestContext;
40 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
41 : use crate::http::{ReadBodyError, read_body_with_limit};
42 : use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
43 : use crate::pqproto::StartupMessageParams;
44 : use crate::proxy::NeonOptions;
45 : use crate::serverless::backend::HttpConnError;
46 : use crate::types::{DbName, RoleName};
47 : use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
48 : use crate::util::run_until_cancelled;
49 :
50 11 : #[derive(serde::Deserialize)]
51 : #[serde(rename_all = "camelCase")]
52 : struct QueryData {
53 : query: String,
54 : #[serde(deserialize_with = "bytes_to_pg_text")]
55 : #[serde(default)]
56 : params: Vec<Option<String>>,
57 : #[serde(default)]
58 : array_mode: Option<bool>,
59 : }
60 :
/// Several statements submitted together in one request, decoded from a JSON
/// object of the shape `{"queries": [ ... ]}`.
#[derive(serde::Deserialize)]
struct BatchQueryData {
    /// The statements, in the order they were submitted.
    queries: Vec<QueryData>,
}
65 :
/// The top-level request body: either a single statement or a batch.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// a body is first attempted as a [`QueryData`] (which requires a `query` key)
/// and only then as a [`BatchQueryData`]. Keep `Single` before `Batch`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
    Single(QueryData),
    Batch(BatchQueryData),
}
72 :
/// Header set on requests forwarded to the local proxy, carrying this
/// session's id (see `handle_auth_broker_inner`).
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");

/// Header holding the `postgres://` connection string for the request.
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
/// Boolean header (`true` to enable); parsed into `HttpHeaders::raw_output`.
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
/// Boolean header (`true` to enable); parsed into `HttpHeaders::default_array_mode`.
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
/// Boolean header opting in to connection pooling while pooling is opt-in.
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
/// Isolation level for batch transactions; accepted values are `Serializable`,
/// `ReadUncommitted`, `ReadCommitted` and `RepeatableRead`.
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
/// Boolean header requesting a read-only batch transaction.
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
/// Boolean header requesting a deferrable batch transaction.
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");

/// The only value that turns a boolean `neon-*` header on; anything else is
/// treated as `false`.
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
84 :
85 3 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
86 3 : where
87 3 : D: serde::de::Deserializer<'de>,
88 3 : {
89 : // TODO: consider avoiding the allocation here.
90 3 : let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
91 3 : Ok(json_to_pg_text(json))
92 3 : }
93 :
/// Reasons why the connection target or credentials could not be extracted
/// from an SQL-over-HTTP request (produced by `get_conn_info`).
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
    /// A required header is missing or could not be read as a string.
    #[error("invalid header: {0}")]
    InvalidHeader(&'static HeaderName),
    /// The connection string is not a parseable URL.
    #[error("invalid connection string: {0}")]
    UrlParseError(#[from] url::ParseError),
    /// The URL scheme is neither `postgres` nor `postgresql`.
    #[error("incorrect scheme")]
    IncorrectScheme,
    /// The URL cannot carry a path (`Url::path_segments` returned `None`).
    #[error("missing database name")]
    MissingDbName,
    /// The database name could not be taken from the URL path.
    #[error("invalid database name")]
    InvalidDbName,
    /// The URL has an empty username.
    #[error("missing username")]
    MissingUsername,
    /// A percent-decoded URL component is not valid UTF-8.
    ///
    /// NOTE(review): because of `#[from]`, *any* `FromUtf8Error` propagated
    /// with `?` becomes this variant, even one that did not come from the
    /// username.
    #[error("invalid username: {0}")]
    InvalidUsername(#[from] std::string::FromUtf8Error),
    /// The request carries no credentials of the kind the proxy accepts.
    #[error("missing authentication credentials: {0}")]
    MissingCredentials(Credentials),
    /// The URL host is absent or is an IP address rather than a domain name.
    #[error("missing hostname")]
    MissingHostname,
    /// Wraps a `ComputeUserInfoParseError` raised while deriving the endpoint.
    #[error("invalid hostname: {0}")]
    InvalidEndpoint(#[from] ComputeUserInfoParseError),
    /// No endpoint could be derived from the hostname using the configured TLS
    /// common names.
    #[error("malformed endpoint")]
    MalformedEndpoint,
}
119 :
/// The kind of credential the proxy expected but did not receive; used as the
/// payload of `ConnInfoError::MissingCredentials`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
    /// A password in the connection string was required (JWTs are not accepted).
    #[error("required password")]
    Password,
    /// An `Authorization: Bearer <jwt>` header was required (passwords are not accepted).
    #[error("required authorization bearer token in JWT format")]
    BearerJwt,
}
127 :
128 : impl ReportableError for ConnInfoError {
129 0 : fn get_error_kind(&self) -> ErrorKind {
130 0 : ErrorKind::User
131 0 : }
132 : }
133 :
134 : impl UserFacingError for ConnInfoError {
135 0 : fn to_string_client(&self) -> String {
136 0 : self.to_string()
137 0 : }
138 : }
139 :
140 0 : fn get_conn_info(
141 0 : config: &'static AuthenticationConfig,
142 0 : ctx: &RequestContext,
143 0 : headers: &HeaderMap,
144 0 : tls: Option<&TlsConfig>,
145 0 : ) -> Result<ConnInfoWithAuth, ConnInfoError> {
146 0 : let connection_string = headers
147 0 : .get(&CONN_STRING)
148 0 : .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
149 0 : .to_str()
150 0 : .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
151 :
152 0 : let connection_url = Url::parse(connection_string)?;
153 :
154 0 : let protocol = connection_url.scheme();
155 0 : if protocol != "postgres" && protocol != "postgresql" {
156 0 : return Err(ConnInfoError::IncorrectScheme);
157 0 : }
158 :
159 0 : let mut url_path = connection_url
160 0 : .path_segments()
161 0 : .ok_or(ConnInfoError::MissingDbName)?;
162 :
163 0 : let dbname: DbName =
164 0 : urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
165 0 : ctx.set_dbname(dbname.clone());
166 :
167 0 : let username = RoleName::from(urlencoding::decode(connection_url.username())?);
168 0 : if username.is_empty() {
169 0 : return Err(ConnInfoError::MissingUsername);
170 0 : }
171 0 : ctx.set_user(username.clone());
172 :
173 0 : let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
174 0 : if !config.accept_jwts {
175 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
176 0 : }
177 :
178 0 : let auth = auth
179 0 : .to_str()
180 0 : .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
181 : AuthData::Jwt(
182 0 : auth.strip_prefix("Bearer ")
183 0 : .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
184 0 : .into(),
185 : )
186 0 : } else if let Some(pass) = connection_url.password() {
187 : // wrong credentials provided
188 0 : if config.accept_jwts {
189 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
190 0 : }
191 0 :
192 0 : AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
193 0 : std::borrow::Cow::Borrowed(b) => b.into(),
194 0 : std::borrow::Cow::Owned(b) => b.into(),
195 : })
196 0 : } else if config.accept_jwts {
197 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
198 : } else {
199 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
200 : };
201 :
202 0 : let endpoint = match connection_url.host() {
203 0 : Some(url::Host::Domain(hostname)) => {
204 0 : if let Some(tls) = tls {
205 0 : endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
206 : } else {
207 0 : hostname
208 0 : .split_once('.')
209 0 : .map_or(hostname, |(prefix, _)| prefix)
210 0 : .into()
211 : }
212 : }
213 : Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
214 0 : return Err(ConnInfoError::MissingHostname);
215 : }
216 : };
217 0 : ctx.set_endpoint_id(endpoint.clone());
218 0 :
219 0 : let pairs = connection_url.query_pairs();
220 0 :
221 0 : let mut options = Option::None;
222 0 :
223 0 : let mut params = StartupMessageParams::default();
224 0 : params.insert("user", &username);
225 0 : params.insert("database", &dbname);
226 0 : for (key, value) in pairs {
227 0 : params.insert(&key, &value);
228 0 : if key == "options" {
229 0 : options = Some(NeonOptions::parse_options_raw(&value));
230 0 : }
231 : }
232 :
233 : // check the URL that was used, for metrics
234 : {
235 0 : let host_endpoint = headers
236 0 : // get the host header
237 0 : .get("host")
238 0 : // extract the domain
239 0 : .and_then(|h| {
240 0 : let (host, _port) = h.to_str().ok()?.split_once(':')?;
241 0 : Some(host)
242 0 : })
243 0 : // get the endpoint prefix
244 0 : .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
245 :
246 0 : let kind = if host_endpoint == Some(&*endpoint) {
247 0 : SniKind::Sni
248 : } else {
249 0 : SniKind::NoSni
250 : };
251 :
252 0 : let protocol = ctx.protocol();
253 0 : Metrics::get()
254 0 : .proxy
255 0 : .accepted_connections_by_sni
256 0 : .inc(SniGroup { protocol, kind });
257 0 : }
258 0 :
259 0 : ctx.set_user_agent(
260 0 : headers
261 0 : .get(hyper::header::USER_AGENT)
262 0 : .and_then(|h| h.to_str().ok())
263 0 : .map(Into::into),
264 0 : );
265 0 :
266 0 : let user_info = ComputeUserInfo {
267 0 : endpoint,
268 0 : user: username,
269 0 : options: options.unwrap_or_default(),
270 0 : };
271 0 :
272 0 : let conn_info = ConnInfo { user_info, dbname };
273 0 : Ok(ConnInfoWithAuth { conn_info, auth })
274 0 : }
275 :
276 0 : pub(crate) async fn handle(
277 0 : config: &'static ProxyConfig,
278 0 : ctx: RequestContext,
279 0 : request: Request<Incoming>,
280 0 : backend: Arc<PoolingBackend>,
281 0 : cancel: CancellationToken,
282 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
283 0 : let result = handle_inner(cancel, config, &ctx, request, backend).await;
284 :
285 0 : let mut response = match result {
286 0 : Ok(r) => {
287 0 : ctx.set_success();
288 0 :
289 0 : // Handling the error response from local proxy here
290 0 : if config.authentication_config.is_auth_broker && r.status().is_server_error() {
291 0 : let status = r.status();
292 :
293 0 : let body_bytes = r
294 0 : .collect()
295 0 : .await
296 0 : .map_err(|e| {
297 0 : ApiError::InternalServerError(anyhow::Error::msg(format!(
298 0 : "could not collect http body: {e}"
299 0 : )))
300 0 : })?
301 0 : .to_bytes();
302 :
303 0 : if let Ok(mut json_map) =
304 0 : serde_json::from_slice::<IndexMap<&str, &RawValue>>(&body_bytes)
305 : {
306 0 : let message = json_map.get("message");
307 0 : if let Some(message) = message {
308 0 : let msg: String = match serde_json::from_str(message.get()) {
309 0 : Ok(msg) => msg,
310 : Err(_) => {
311 0 : "Unable to parse the response message from server".to_string()
312 : }
313 : };
314 :
315 0 : error!("Error response from local_proxy: {status} {msg}");
316 :
317 0 : json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
318 0 :
319 0 : let resp_json = serde_json::to_string(&json_map)
320 0 : .unwrap_or("failed to serialize the response message".to_string());
321 0 :
322 0 : return json_response(status, resp_json);
323 0 : }
324 0 : }
325 :
326 0 : error!("Unable to parse the response message from local_proxy");
327 0 : return json_response(
328 0 : status,
329 0 : json!({ "message": "Unable to parse the response message from server".to_string() }),
330 0 : );
331 0 : }
332 0 : r
333 : }
334 0 : Err(e @ SqlOverHttpError::Cancelled(_)) => {
335 0 : let error_kind = e.get_error_kind();
336 0 : ctx.set_error_kind(error_kind);
337 0 :
338 0 : let message = "Query cancelled, connection was terminated";
339 0 :
340 0 : tracing::info!(
341 0 : kind=error_kind.to_metric_label(),
342 0 : error=%e,
343 0 : msg=message,
344 0 : "forwarding error to user"
345 : );
346 :
347 0 : json_response(
348 0 : StatusCode::BAD_REQUEST,
349 0 : json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
350 0 : )?
351 : }
352 0 : Err(e) => {
353 0 : let error_kind = e.get_error_kind();
354 0 : ctx.set_error_kind(error_kind);
355 0 :
356 0 : let mut message = e.to_string_client();
357 0 : let db_error = match &e {
358 0 : SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
359 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
360 0 : _ => None,
361 : };
362 0 : fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
363 0 : db.map(x).unwrap_or_default()
364 0 : }
365 :
366 0 : if let Some(db_error) = db_error {
367 0 : db_error.message().clone_into(&mut message);
368 0 : }
369 :
370 0 : let position = db_error.and_then(|db| db.position());
371 0 : let (position, internal_position, internal_query) = match position {
372 0 : Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
373 0 : Some(ErrorPosition::Internal { position, query }) => {
374 0 : (None, Some(position.to_string()), Some(query.clone()))
375 : }
376 0 : None => (None, None, None),
377 : };
378 :
379 0 : let code = get(db_error, |db| db.code().code());
380 0 : let severity = get(db_error, |db| db.severity());
381 0 : let detail = get(db_error, |db| db.detail());
382 0 : let hint = get(db_error, |db| db.hint());
383 0 : let where_ = get(db_error, |db| db.where_());
384 0 : let table = get(db_error, |db| db.table());
385 0 : let column = get(db_error, |db| db.column());
386 0 : let schema = get(db_error, |db| db.schema());
387 0 : let datatype = get(db_error, |db| db.datatype());
388 0 : let constraint = get(db_error, |db| db.constraint());
389 0 : let file = get(db_error, |db| db.file());
390 0 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
391 0 : let routine = get(db_error, |db| db.routine());
392 0 :
393 0 : tracing::info!(
394 0 : kind=error_kind.to_metric_label(),
395 0 : error=%e,
396 0 : msg=message,
397 0 : "forwarding error to user"
398 : );
399 :
400 0 : json_response(
401 0 : e.get_http_status_code(),
402 0 : json!({
403 0 : "message": message,
404 0 : "code": code,
405 0 : "detail": detail,
406 0 : "hint": hint,
407 0 : "position": position,
408 0 : "internalPosition": internal_position,
409 0 : "internalQuery": internal_query,
410 0 : "severity": severity,
411 0 : "where": where_,
412 0 : "table": table,
413 0 : "column": column,
414 0 : "schema": schema,
415 0 : "dataType": datatype,
416 0 : "constraint": constraint,
417 0 : "file": file,
418 0 : "line": line,
419 0 : "routine": routine,
420 0 : }),
421 0 : )?
422 : }
423 : };
424 :
425 0 : response
426 0 : .headers_mut()
427 0 : .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
428 0 : Ok(response)
429 0 : }
430 :
/// Every way an SQL-over-HTTP request can fail. The HTTP status, error kind
/// and client-facing text for each variant are defined by the trait impls
/// below.
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpError {
    /// The request body could not be read or decoded.
    #[error("{0}")]
    ReadPayload(#[from] ReadPayloadError),
    /// Authenticating or connecting to the database failed.
    #[error("{0}")]
    ConnectCompute(#[from] HttpConnError),
    /// The connection string or credential headers were unusable.
    #[error("{0}")]
    ConnInfo(#[from] ConnInfoError),
    /// The serialized result exceeded the configured limit (in bytes).
    #[error("response is too large (max is {0} bytes)")]
    ResponseTooLarge(usize),
    /// The `neon-batch-isolation-level` header held an unrecognised value.
    #[error("invalid isolation level")]
    InvalidIsolationLevel,
    /// for queries our customers choose to run
    #[error("{0}")]
    Postgres(#[source] postgres_client::Error),
    /// for queries we choose to run
    #[error("{0}")]
    InternalPostgres(#[source] postgres_client::Error),
    /// A Postgres result could not be converted to JSON.
    #[error("{0}")]
    JsonConversion(#[from] JsonConversionError),
    /// The request was cancelled before it completed.
    #[error("{0}")]
    Cancelled(SqlOverHttpCancel),
}
454 :
455 : impl ReportableError for SqlOverHttpError {
456 0 : fn get_error_kind(&self) -> ErrorKind {
457 0 : match self {
458 0 : SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
459 0 : SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
460 0 : SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
461 0 : SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
462 0 : SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
463 0 : SqlOverHttpError::Postgres(p) => p.get_error_kind(),
464 0 : SqlOverHttpError::InternalPostgres(p) => {
465 0 : if p.as_db_error().is_some() {
466 0 : ErrorKind::Service
467 : } else {
468 0 : ErrorKind::Compute
469 : }
470 : }
471 0 : SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
472 0 : SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
473 : }
474 0 : }
475 : }
476 :
477 : impl UserFacingError for SqlOverHttpError {
478 0 : fn to_string_client(&self) -> String {
479 0 : match self {
480 0 : SqlOverHttpError::ReadPayload(p) => p.to_string(),
481 0 : SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
482 0 : SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
483 0 : SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
484 0 : SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
485 0 : SqlOverHttpError::Postgres(p) => p.to_string(),
486 0 : SqlOverHttpError::InternalPostgres(p) => p.to_string(),
487 0 : SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
488 0 : SqlOverHttpError::Cancelled(_) => self.to_string(),
489 : }
490 0 : }
491 : }
492 :
493 : impl HttpCodeError for SqlOverHttpError {
494 0 : fn get_http_status_code(&self) -> StatusCode {
495 0 : match self {
496 0 : SqlOverHttpError::ReadPayload(e) => e.get_http_status_code(),
497 0 : SqlOverHttpError::ConnectCompute(h) => match h.get_error_kind() {
498 0 : ErrorKind::User => StatusCode::BAD_REQUEST,
499 0 : _ => StatusCode::INTERNAL_SERVER_ERROR,
500 : },
501 0 : SqlOverHttpError::ConnInfo(_) => StatusCode::BAD_REQUEST,
502 0 : SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
503 0 : SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
504 0 : SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST,
505 0 : SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
506 0 : SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
507 0 : SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
508 : }
509 0 : }
510 : }
511 :
/// Failures while reading or decoding the SQL-over-HTTP request body.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
    /// The body stream failed while being read.
    #[error("could not read the HTTP request body: {0}")]
    Read(#[from] hyper::Error),
    /// The body exceeded the configured size limit (in bytes).
    #[error("request is too large (max is {limit} bytes)")]
    BodyTooLarge { limit: usize },
    /// The body is not a valid JSON [`Payload`].
    #[error("could not parse the HTTP request body: {0}")]
    Parse(#[from] serde_json::Error),
}
521 :
522 : impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
523 0 : fn from(value: ReadBodyError<hyper::Error>) -> Self {
524 0 : match value {
525 0 : ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
526 0 : ReadBodyError::Read(e) => Self::Read(e),
527 : }
528 0 : }
529 : }
530 :
531 : impl ReportableError for ReadPayloadError {
532 0 : fn get_error_kind(&self) -> ErrorKind {
533 0 : match self {
534 0 : ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
535 0 : ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
536 0 : ReadPayloadError::Parse(_) => ErrorKind::User,
537 : }
538 0 : }
539 : }
540 :
541 : impl HttpCodeError for ReadPayloadError {
542 0 : fn get_http_status_code(&self) -> StatusCode {
543 0 : match self {
544 0 : ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
545 0 : ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
546 0 : ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
547 : }
548 0 : }
549 : }
550 :
/// The phase in which an SQL-over-HTTP request was cancelled.
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
    /// Cancelled while a query was executing.
    #[error("query was cancelled")]
    Postgres,
    /// Cancelled while reading the body / authenticating / connecting.
    #[error("query was cancelled while stuck trying to connect to the database")]
    Connect,
}
558 :
559 : impl ReportableError for SqlOverHttpCancel {
560 0 : fn get_error_kind(&self) -> ErrorKind {
561 0 : match self {
562 0 : SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
563 0 : SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
564 : }
565 0 : }
566 : }
567 :
/// Per-request options parsed from `neon-*` request headers by
/// `HttpHeaders::try_parse`. Each boolean is `true` only when its header value
/// is exactly `true`.
#[derive(Clone, Copy, Debug)]
struct HttpHeaders {
    /// From `neon-raw-text-output`.
    raw_output: bool,
    /// From `neon-array-mode`; presumably the fallback when a statement does not
    /// set `arrayMode` itself — TODO confirm in `query_to_json`.
    default_array_mode: bool,
    /// From `neon-batch-isolation-level`; `None` when the header is absent.
    txn_isolation_level: Option<IsolationLevel>,
    /// From `neon-batch-read-only`.
    txn_read_only: bool,
    /// From `neon-batch-deferrable`.
    txn_deferrable: bool,
}
576 :
577 : impl HttpHeaders {
578 0 : fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
579 0 : // Determine the output options. Default behaviour is 'false'. Anything that is not
580 0 : // strictly 'true' assumed to be false.
581 0 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
582 0 : let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
583 :
584 : // isolation level, read only and deferrable
585 0 : let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
586 0 : Some(x) => Some(
587 0 : map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
588 : ),
589 0 : None => None,
590 : };
591 :
592 0 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
593 0 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
594 0 :
595 0 : Ok(Self {
596 0 : raw_output,
597 0 : default_array_mode,
598 0 : txn_isolation_level,
599 0 : txn_read_only,
600 0 : txn_deferrable,
601 0 : })
602 0 : }
603 : }
604 :
605 0 : fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
606 0 : match level.as_bytes() {
607 0 : b"Serializable" => Some(IsolationLevel::Serializable),
608 0 : b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
609 0 : b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
610 0 : b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
611 0 : _ => None,
612 : }
613 0 : }
614 :
615 0 : fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
616 0 : match level {
617 0 : IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
618 0 : IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
619 0 : IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
620 0 : IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
621 0 : _ => None,
622 : }
623 0 : }
624 :
625 0 : async fn handle_inner(
626 0 : cancel: CancellationToken,
627 0 : config: &'static ProxyConfig,
628 0 : ctx: &RequestContext,
629 0 : request: Request<Incoming>,
630 0 : backend: Arc<PoolingBackend>,
631 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
632 0 : let _requeset_gauge = Metrics::get()
633 0 : .proxy
634 0 : .connection_requests
635 0 : .guard(ctx.protocol());
636 0 : info!(
637 0 : protocol = %ctx.protocol(),
638 0 : "handling interactive connection from client"
639 : );
640 :
641 0 : let conn_info = get_conn_info(
642 0 : &config.authentication_config,
643 0 : ctx,
644 0 : request.headers(),
645 0 : // todo: race condition?
646 0 : // we're unlikely to change the common names.
647 0 : config.tls_config.load().as_deref(),
648 0 : )?;
649 0 : info!(
650 0 : user = conn_info.conn_info.user_info.user.as_str(),
651 0 : "credentials"
652 : );
653 :
654 0 : match conn_info.auth {
655 0 : AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
656 0 : handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
657 : }
658 0 : auth => {
659 0 : handle_db_inner(
660 0 : cancel,
661 0 : config,
662 0 : ctx,
663 0 : request,
664 0 : conn_info.conn_info,
665 0 : auth,
666 0 : backend,
667 0 : )
668 0 : .await
669 : }
670 : }
671 0 : }
672 :
673 0 : async fn handle_db_inner(
674 0 : cancel: CancellationToken,
675 0 : config: &'static ProxyConfig,
676 0 : ctx: &RequestContext,
677 0 : request: Request<Incoming>,
678 0 : conn_info: ConnInfo,
679 0 : auth: AuthData,
680 0 : backend: Arc<PoolingBackend>,
681 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
682 0 : //
683 0 : // Determine the destination and connection params
684 0 : //
685 0 : let headers = request.headers();
686 :
687 : // Allow connection pooling only if explicitly requested
688 : // or if we have decided that http pool is no longer opt-in
689 0 : let allow_pool = !config.http_config.pool_options.opt_in
690 0 : || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
691 :
692 0 : let parsed_headers = HttpHeaders::try_parse(headers)?;
693 :
694 0 : let mut request_len = 0;
695 0 : let fetch_and_process_request = Box::pin(
696 0 : async {
697 0 : let body = read_body_with_limit(
698 0 : request.into_body(),
699 0 : config.http_config.max_request_size_bytes,
700 0 : )
701 0 : .await?;
702 :
703 0 : request_len = body.len();
704 0 :
705 0 : Metrics::get()
706 0 : .proxy
707 0 : .http_conn_content_length_bytes
708 0 : .observe(HttpDirection::Request, body.len() as f64);
709 0 :
710 0 : debug!(length = body.len(), "request payload read");
711 0 : let payload: Payload = serde_json::from_slice(&body)?;
712 0 : Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
713 0 : }
714 0 : .map_err(SqlOverHttpError::from),
715 0 : );
716 0 :
717 0 : let authenticate_and_connect = Box::pin(
718 0 : async {
719 0 : let keys = match auth {
720 0 : AuthData::Password(pw) => backend
721 0 : .authenticate_with_password(ctx, &conn_info.user_info, &pw)
722 0 : .await
723 0 : .map_err(HttpConnError::AuthError)?,
724 0 : AuthData::Jwt(jwt) => backend
725 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
726 0 : .await
727 0 : .map_err(HttpConnError::AuthError)?,
728 : };
729 :
730 0 : let client = match keys.keys {
731 0 : ComputeCredentialKeys::JwtPayload(payload)
732 0 : if backend.auth_backend.is_local_proxy() =>
733 : {
734 0 : let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
735 0 : let (cli_inner, _dsc) = client.client_inner();
736 0 : cli_inner.set_jwt_session(&payload).await?;
737 0 : Client::Local(client)
738 : }
739 : _ => {
740 0 : let client = backend
741 0 : .connect_to_compute(ctx, conn_info, keys, !allow_pool)
742 0 : .await?;
743 0 : Client::Remote(client)
744 : }
745 : };
746 :
747 : // not strictly necessary to mark success here,
748 : // but it's just insurance for if we forget it somewhere else
749 0 : ctx.success();
750 0 : Ok::<_, SqlOverHttpError>(client)
751 0 : }
752 0 : .map_err(SqlOverHttpError::from),
753 0 : );
754 :
755 0 : let (payload, mut client) = match run_until_cancelled(
756 0 : // Run both operations in parallel
757 0 : try_join(
758 0 : pin!(fetch_and_process_request),
759 0 : pin!(authenticate_and_connect),
760 0 : ),
761 0 : &cancel,
762 0 : )
763 0 : .await
764 : {
765 0 : Some(result) => result?,
766 0 : None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
767 : };
768 :
769 0 : let mut response = Response::builder()
770 0 : .status(StatusCode::OK)
771 0 : .header(header::CONTENT_TYPE, "application/json");
772 :
773 : // Now execute the query and return the result.
774 0 : let json_output = match payload {
775 0 : Payload::Single(stmt) => {
776 0 : stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
777 0 : .await?
778 : }
779 0 : Payload::Batch(statements) => {
780 0 : if parsed_headers.txn_read_only {
781 0 : response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
782 0 : }
783 0 : if parsed_headers.txn_deferrable {
784 0 : response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
785 0 : }
786 0 : if let Some(txn_isolation_level) = parsed_headers
787 0 : .txn_isolation_level
788 0 : .and_then(map_isolation_level_to_headers)
789 0 : {
790 0 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
791 0 : }
792 :
793 0 : statements
794 0 : .process(&config.http_config, cancel, &mut client, parsed_headers)
795 0 : .await?
796 : }
797 : };
798 :
799 0 : let metrics = client.metrics(ctx);
800 0 :
801 0 : let len = json_output.len();
802 0 : let response = response
803 0 : .body(
804 0 : Full::new(Bytes::from(json_output))
805 0 : .map_err(|x| match x {})
806 0 : .boxed(),
807 0 : )
808 0 : // only fails if invalid status code or invalid header/values are given.
809 0 : // these are not user configurable so it cannot fail dynamically
810 0 : .expect("building response payload should not fail");
811 0 :
812 0 : // count the egress bytes - we miss the TLS and header overhead but oh well...
813 0 : // moving this later in the stack is going to be a lot of effort and ehhhh
814 0 : metrics.record_egress(len as u64);
815 0 : metrics.record_ingress(request_len as u64);
816 0 :
817 0 : Metrics::get()
818 0 : .proxy
819 0 : .http_conn_content_length_bytes
820 0 : .observe(HttpDirection::Response, len as f64);
821 0 :
822 0 : Ok(response)
823 0 : }
824 :
/// Client headers copied onto the request that the auth broker forwards to the
/// local proxy (see `handle_auth_broker_inner`); no other client header is
/// forwarded.
///
/// NOTE(review): `neon-pool-opt-in` is not listed, so the pooling opt-in never
/// reaches the local proxy — confirm this is intentional.
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
    &AUTHORIZATION,
    &CONN_STRING,
    &RAW_TEXT_OUTPUT,
    &ARRAY_MODE,
    &TXN_ISOLATION_LEVEL,
    &TXN_READ_ONLY,
    &TXN_DEFERRABLE,
];
834 :
835 0 : pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
836 0 : let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
837 0 : HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
838 0 : .expect("uuid hyphenated format should be all valid header characters")
839 0 : }
840 :
841 0 : async fn handle_auth_broker_inner(
842 0 : ctx: &RequestContext,
843 0 : request: Request<Incoming>,
844 0 : conn_info: ConnInfo,
845 0 : jwt: String,
846 0 : backend: Arc<PoolingBackend>,
847 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
848 0 : backend
849 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
850 0 : .await
851 0 : .map_err(HttpConnError::from)?;
852 :
853 0 : let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
854 :
855 0 : let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
856 0 :
857 0 : let (mut parts, body) = request.into_parts();
858 0 : let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
859 :
860 : // todo(conradludgate): maybe auth-broker should parse these and re-serialize
861 : // these instead just to ensure they remain normalised.
862 0 : for &h in HEADERS_TO_FORWARD {
863 0 : if let Some(hv) = parts.headers.remove(h) {
864 0 : req = req.header(h, hv);
865 0 : }
866 : }
867 0 : req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
868 0 :
869 0 : let req = req
870 0 : .body(body)
871 0 : .expect("all headers and params received via hyper should be valid for request");
872 0 :
873 0 : // todo: map body to count egress
874 0 : let _metrics = client.metrics(ctx);
875 0 :
876 0 : Ok(client
877 0 : .inner
878 0 : .inner
879 0 : .send_request(req)
880 0 : .await
881 0 : .map_err(LocalProxyConnError::from)
882 0 : .map_err(HttpConnError::from)?
883 0 : .map(|b| b.boxed()))
884 0 : }
885 :
impl QueryData {
    /// Executes this single statement on `client`, racing it against `cancel`,
    /// and returns the result serialized as a JSON string.
    ///
    /// If `cancel` fires first, a cancel request is sent to Postgres and the
    /// query is given up to 100ms to finish. If it still completes successfully
    /// in that window, its result is returned; otherwise the outcome is
    /// `SqlOverHttpCancel::Postgres`.
    ///
    /// The connection is flagged for discard whenever its state cannot be
    /// trusted afterwards: on any query error, on a cancellation timeout, and on
    /// a post-cancel error whose SQLSTATE is not `QUERY_CANCELED`.
    ///
    /// NOTE(review): the pinned futures live as temporaries of the `select`
    /// scrutinee; their lifetimes and drop order are part of this logic.
    async fn process(
        self,
        config: &'static HttpConfig,
        cancel: CancellationToken,
        client: &mut Client,
        parsed_headers: HttpHeaders,
    ) -> Result<String, SqlOverHttpError> {
        let (inner, mut discard) = client.inner();
        // Taken up front: `inner` is mutably borrowed by the query future below.
        let cancel_token = inner.cancel_token();

        match select(
            pin!(query_to_json(
                config,
                &mut *inner,
                self,
                &mut 0,
                parsed_headers
            )),
            pin!(cancel.cancelled()),
        )
        .await
        {
            // The query successfully completed.
            Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
                // presumably keeps the connection only if it is idle — TODO confirm
                discard.check_idle(status);

                let json_output =
                    serde_json::to_string(&results).expect("json serialization should not fail");
                Ok(json_output)
            }
            // The query failed with an error
            Either::Left((Err(e), __not_yet_cancelled)) => {
                discard.discard();
                Err(e)
            }
            // The query was cancelled.
            Either::Right((_cancelled, query)) => {
                tracing::info!("cancelling query");
                if let Err(err) = cancel_token.cancel_query(NoTls).await {
                    tracing::warn!(?err, "could not cancel query");
                }
                // wait for the query cancellation
                match time::timeout(time::Duration::from_millis(100), query).await {
                    // query succeeded before it was cancelled.
                    Ok(Ok((status, results))) => {
                        discard.check_idle(status);

                        let json_output = serde_json::to_string(&results)
                            .expect("json serialization should not fail");
                        Ok(json_output)
                    }
                    // query failed or was cancelled.
                    Ok(Err(error)) => {
                        let db_error = match &error {
                            SqlOverHttpError::ConnectCompute(
                                HttpConnError::PostgresConnectionError(e),
                            )
                            | SqlOverHttpError::Postgres(e) => e.as_db_error(),
                            _ => None,
                        };

                        // if errored for some other reason, it might not be safe to return
                        if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
                            discard.discard();
                        }

                        Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
                    }
                    // the query did not wind down in time; its state is unknown
                    Err(_timeout) => {
                        discard.discard();
                        Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
                    }
                }
            }
        }
    }
}
964 :
965 : impl BatchQueryData {
966 0 : async fn process(
967 0 : self,
968 0 : config: &'static HttpConfig,
969 0 : cancel: CancellationToken,
970 0 : client: &mut Client,
971 0 : parsed_headers: HttpHeaders,
972 0 : ) -> Result<String, SqlOverHttpError> {
973 0 : info!("starting transaction");
974 0 : let (inner, mut discard) = client.inner();
975 0 : let cancel_token = inner.cancel_token();
976 0 : let mut builder = inner.build_transaction();
977 0 : if let Some(isolation_level) = parsed_headers.txn_isolation_level {
978 0 : builder = builder.isolation_level(isolation_level);
979 0 : }
980 0 : if parsed_headers.txn_read_only {
981 0 : builder = builder.read_only(true);
982 0 : }
983 0 : if parsed_headers.txn_deferrable {
984 0 : builder = builder.deferrable(true);
985 0 : }
986 :
987 0 : let mut transaction = builder
988 0 : .start()
989 0 : .await
990 0 : .inspect_err(|_| {
991 0 : // if we cannot start a transaction, we should return immediately
992 0 : // and not return to the pool. connection is clearly broken
993 0 : discard.discard();
994 0 : })
995 0 : .map_err(SqlOverHttpError::Postgres)?;
996 :
997 0 : let json_output = match query_batch(
998 0 : config,
999 0 : cancel.child_token(),
1000 0 : &mut transaction,
1001 0 : self,
1002 0 : parsed_headers,
1003 0 : )
1004 0 : .await
1005 : {
1006 0 : Ok(json_output) => {
1007 0 : info!("commit");
1008 0 : let status = transaction
1009 0 : .commit()
1010 0 : .await
1011 0 : .inspect_err(|_| {
1012 0 : // if we cannot commit - for now don't return connection to pool
1013 0 : // TODO: get a query status from the error
1014 0 : discard.discard();
1015 0 : })
1016 0 : .map_err(SqlOverHttpError::Postgres)?;
1017 0 : discard.check_idle(status);
1018 0 : json_output
1019 : }
1020 : Err(SqlOverHttpError::Cancelled(_)) => {
1021 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
1022 0 : tracing::warn!(?err, "could not cancel query");
1023 0 : }
1024 : // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
1025 0 : discard.discard();
1026 0 :
1027 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
1028 : }
1029 0 : Err(err) => {
1030 0 : info!("rollback");
1031 0 : let status = transaction
1032 0 : .rollback()
1033 0 : .await
1034 0 : .inspect_err(|_| {
1035 0 : // if we cannot rollback - for now don't return connection to pool
1036 0 : // TODO: get a query status from the error
1037 0 : discard.discard();
1038 0 : })
1039 0 : .map_err(SqlOverHttpError::Postgres)?;
1040 0 : discard.check_idle(status);
1041 0 : return Err(err);
1042 : }
1043 : };
1044 :
1045 0 : Ok(json_output)
1046 0 : }
1047 : }
1048 :
1049 0 : async fn query_batch(
1050 0 : config: &'static HttpConfig,
1051 0 : cancel: CancellationToken,
1052 0 : transaction: &mut Transaction<'_>,
1053 0 : queries: BatchQueryData,
1054 0 : parsed_headers: HttpHeaders,
1055 0 : ) -> Result<String, SqlOverHttpError> {
1056 0 : let mut results = Vec::with_capacity(queries.queries.len());
1057 0 : let mut current_size = 0;
1058 0 : for stmt in queries.queries {
1059 0 : let query = pin!(query_to_json(
1060 0 : config,
1061 0 : transaction,
1062 0 : stmt,
1063 0 : &mut current_size,
1064 0 : parsed_headers,
1065 0 : ));
1066 0 : let cancelled = pin!(cancel.cancelled());
1067 0 : let res = select(query, cancelled).await;
1068 0 : match res {
1069 : // TODO: maybe we should check that the transaction bit is set here
1070 0 : Either::Left((Ok((_, values)), _cancelled)) => {
1071 0 : results.push(values);
1072 0 : }
1073 0 : Either::Left((Err(e), _cancelled)) => {
1074 0 : return Err(e);
1075 : }
1076 0 : Either::Right((_cancelled, _)) => {
1077 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
1078 : }
1079 : }
1080 : }
1081 :
1082 0 : let results = json!({ "results": results });
1083 0 : let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
1084 0 :
1085 0 : Ok(json_output)
1086 0 : }
1087 :
1088 0 : async fn query_to_json<T: GenericClient>(
1089 0 : config: &'static HttpConfig,
1090 0 : client: &mut T,
1091 0 : data: QueryData,
1092 0 : current_size: &mut usize,
1093 0 : parsed_headers: HttpHeaders,
1094 0 : ) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
1095 0 : let query_start = Instant::now();
1096 0 :
1097 0 : let query_params = data.params;
1098 0 : let mut row_stream = client
1099 0 : .query_raw_txt(&data.query, query_params)
1100 0 : .await
1101 0 : .map_err(SqlOverHttpError::Postgres)?;
1102 0 : let query_acknowledged = Instant::now();
1103 0 :
1104 0 : let columns_len = row_stream.statement.columns().len();
1105 0 : let mut fields = Vec::with_capacity(columns_len);
1106 0 : let mut types = Vec::with_capacity(columns_len);
1107 :
1108 0 : for c in row_stream.statement.columns() {
1109 0 : fields.push(json!({
1110 0 : "name": c.name().to_owned(),
1111 0 : "dataTypeID": c.type_().oid(),
1112 0 : "tableID": c.table_oid(),
1113 0 : "columnID": c.column_id(),
1114 0 : "dataTypeSize": c.type_size(),
1115 0 : "dataTypeModifier": c.type_modifier(),
1116 0 : "format": "text",
1117 0 : }));
1118 0 :
1119 0 : types.push(c.type_().clone());
1120 0 : }
1121 :
1122 0 : let raw_output = parsed_headers.raw_output;
1123 0 : let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
1124 0 :
1125 0 : // Manually drain the stream into a vector to leave row_stream hanging
1126 0 : // around to get a command tag. Also check that the response is not too
1127 0 : // big.
1128 0 : let mut rows = Vec::new();
1129 0 : while let Some(row) = row_stream.next().await {
1130 0 : let row = row.map_err(SqlOverHttpError::Postgres)?;
1131 0 : *current_size += row.body_len();
1132 0 :
1133 0 : // we don't have a streaming response support yet so this is to prevent OOM
1134 0 : // from a malicious query (eg a cross join)
1135 0 : if *current_size > config.max_response_size_bytes {
1136 0 : return Err(SqlOverHttpError::ResponseTooLarge(
1137 0 : config.max_response_size_bytes,
1138 0 : ));
1139 0 : }
1140 :
1141 0 : let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?;
1142 0 : rows.push(row);
1143 0 :
1144 0 : // assumption: parsing pg text and converting to json takes CPU time.
1145 0 : // let's assume it is slightly expensive, so we should consume some cooperative budget.
1146 0 : // Especially considering that `RowStream::next` might be pulling from a batch
1147 0 : // of rows and never hit the tokio mpsc for a long time (although unlikely).
1148 0 : tokio::task::consume_budget().await;
1149 : }
1150 :
1151 0 : let query_resp_end = Instant::now();
1152 0 : let RowStream {
1153 0 : command_tag,
1154 0 : status: ready,
1155 0 : ..
1156 0 : } = row_stream;
1157 0 :
1158 0 : // grab the command tag and number of rows affected
1159 0 : let command_tag = command_tag.unwrap_or_default();
1160 0 : let mut command_tag_split = command_tag.split(' ');
1161 0 : let command_tag_name = command_tag_split.next().unwrap_or_default();
1162 0 : let command_tag_count = if command_tag_name == "INSERT" {
1163 : // INSERT returns OID first and then number of rows
1164 0 : command_tag_split.nth(1)
1165 : } else {
1166 : // other commands return number of rows (if any)
1167 0 : command_tag_split.next()
1168 : }
1169 0 : .and_then(|s| s.parse::<i64>().ok());
1170 0 :
1171 0 : info!(
1172 0 : rows = rows.len(),
1173 0 : ?ready,
1174 0 : command_tag,
1175 0 : acknowledgement = ?(query_acknowledged - query_start),
1176 0 : response = ?(query_resp_end - query_start),
1177 0 : "finished executing query"
1178 : );
1179 :
1180 : // Resulting JSON format is based on the format of node-postgres result.
1181 0 : let results = json!({
1182 0 : "command": command_tag_name.to_string(),
1183 0 : "rowCount": command_tag_count,
1184 0 : "rows": rows,
1185 0 : "fields": fields,
1186 0 : "rowAsArray": array_mode,
1187 0 : });
1188 0 :
1189 0 : Ok((ready, results))
1190 0 : }
1191 :
/// A pooled Postgres client used to serve SQL-over-HTTP queries.
///
/// Both variants wrap the same `conn_pool_lib::Client<postgres_client::Client>` type.
/// `Client::inner` tags the returned `Discard` handle with the matching variant.
/// NOTE(review): `Remote` vs `Local` presumably records which pool/backend produced the
/// connection. Confirm against the code that constructs `Client` values.
enum Client {
    Remote(conn_pool_lib::Client<postgres_client::Client>),
    Local(conn_pool_lib::Client<postgres_client::Client>),
}
1196 :
/// Handle returned by `Client::inner` next to the raw connection. It carries the same
/// `Remote`/`Local` tag as the client it came from.
///
/// Query handlers call `check_idle(status)` with the ready-for-query status after a query
/// or transaction finishes. They call `discard()` when the connection should not go back
/// to the pool.
/// NOTE(review): the exact pooling effect of each call lives in `conn_pool_lib::Discard`.
/// Confirm there.
enum Discard<'a> {
    Remote(conn_pool_lib::Discard<'a, postgres_client::Client>),
    Local(conn_pool_lib::Discard<'a, postgres_client::Client>),
}
1201 :
1202 : impl Client {
1203 0 : fn metrics(&self, ctx: &RequestContext) -> Arc<MetricCounter> {
1204 0 : match self {
1205 0 : Client::Remote(client) => client.metrics(ctx),
1206 0 : Client::Local(local_client) => local_client.metrics(ctx),
1207 : }
1208 0 : }
1209 :
1210 0 : fn inner(&mut self) -> (&mut postgres_client::Client, Discard<'_>) {
1211 0 : match self {
1212 0 : Client::Remote(client) => {
1213 0 : let (c, d) = client.inner();
1214 0 : (c, Discard::Remote(d))
1215 : }
1216 0 : Client::Local(local_client) => {
1217 0 : let (c, d) = local_client.inner();
1218 0 : (c, Discard::Local(d))
1219 : }
1220 : }
1221 0 : }
1222 : }
1223 :
1224 : impl Discard<'_> {
1225 0 : fn check_idle(&mut self, status: ReadyForQueryStatus) {
1226 0 : match self {
1227 0 : Discard::Remote(discard) => discard.check_idle(status),
1228 0 : Discard::Local(discard) => discard.check_idle(status),
1229 : }
1230 0 : }
1231 0 : fn discard(&mut self) {
1232 0 : match self {
1233 0 : Discard::Remote(discard) => discard.discard(),
1234 0 : Discard::Local(discard) => discard.discard(),
1235 : }
1236 0 : }
1237 : }
1238 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `Payload` deserializes single, batched and minimal request bodies.
    #[test]
    fn test_payload() {
        // Single query with one parameter and an explicit `arrayMode`.
        let payload = r#"{"query":"SELECT * FROM users WHERE name = ?","params":["test"],"arrayMode":true}"#;
        let Payload::Single(QueryData {
            query,
            params,
            array_mode,
        }) = serde_json::from_str::<Payload>(payload).unwrap()
        else {
            panic!("deserialization failed: case with single query, one param, and array mode")
        };
        assert_eq!(query, "SELECT * FROM users WHERE name = ?");
        assert_eq!(params, vec![Some(String::from("test"))]);
        // Compare against `Some(..)` so a missing field fails as an assertion, not an unwrap panic.
        assert_eq!(array_mode, Some(true));

        // Batch of two queries whose `arrayMode` differs per query.
        let payload = r#"{"queries":[{"query":"SELECT * FROM users0 WHERE name = ?","params":["test0"], "arrayMode":false},{"query":"SELECT * FROM users1 WHERE name = ?","params":["test1"],"arrayMode":true}]}"#;
        let Payload::Batch(BatchQueryData { queries }) =
            serde_json::from_str::<Payload>(payload).unwrap()
        else {
            panic!("deserialization failed: case with multiple queries")
        };
        assert_eq!(queries.len(), 2);
        for (i, query) in queries.into_iter().enumerate() {
            assert_eq!(
                query.query,
                format!("SELECT * FROM users{i} WHERE name = ?")
            );
            assert_eq!(query.params, vec![Some(format!("test{i}"))]);
            assert_eq!(query.array_mode, Some(i > 0));
        }

        // Minimal payload: `params` defaults to empty and `arrayMode` is absent.
        let payload = r#"{"query":"SELECT 1"}"#;
        let Payload::Single(QueryData {
            query,
            params,
            array_mode,
        }) = serde_json::from_str::<Payload>(payload).unwrap()
        else {
            panic!("deserialization failed: case with only one query")
        };
        assert_eq!(query, "SELECT 1");
        assert!(params.is_empty());
        assert_eq!(array_mode, None);
    }
}
|