Line data Source code
1 : use std::pin::pin;
2 : use std::sync::Arc;
3 :
4 : use bytes::Bytes;
5 : use futures::future::{select, try_join, Either};
6 : use futures::{StreamExt, TryFutureExt};
7 : use http::header::AUTHORIZATION;
8 : use http::Method;
9 : use http_body_util::combinators::BoxBody;
10 : use http_body_util::{BodyExt, Full};
11 : use hyper::body::{Body, Incoming};
12 : use hyper::http::{HeaderName, HeaderValue};
13 : use hyper::{header, HeaderMap, Request, Response, StatusCode};
14 : use pq_proto::StartupMessageParamsBuilder;
15 : use serde::Serialize;
16 : use serde_json::Value;
17 : use tokio::time;
18 : use tokio_postgres::error::{DbError, ErrorPosition, SqlState};
19 : use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
20 : use tokio_util::sync::CancellationToken;
21 : use tracing::{error, info};
22 : use typed_json::json;
23 : use url::Url;
24 : use urlencoding;
25 : use utils::http::error::ApiError;
26 :
27 : use super::backend::{LocalProxyConnError, PoolingBackend};
28 : use super::conn_pool::{AuthData, ConnInfoWithAuth};
29 : use super::conn_pool_lib::{self, ConnInfo};
30 : use super::http_util::json_response;
31 : use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
32 : use super::local_conn_pool;
33 : use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
34 : use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
35 : use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
36 : use crate::context::RequestMonitoring;
37 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
38 : use crate::metrics::{HttpDirection, Metrics};
39 : use crate::proxy::{run_until_cancelled, NeonOptions};
40 : use crate::serverless::backend::HttpConnError;
41 : use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
42 : use crate::{DbName, RoleName};
43 :
44 0 : #[derive(serde::Deserialize)]
45 : #[serde(rename_all = "camelCase")]
46 : struct QueryData {
47 : query: String,
48 : #[serde(deserialize_with = "bytes_to_pg_text")]
49 : params: Vec<Option<String>>,
50 : #[serde(default)]
51 : array_mode: Option<bool>,
52 : }
53 :
54 0 : #[derive(serde::Deserialize)]
55 : struct BatchQueryData {
56 : queries: Vec<QueryData>,
57 : }
58 :
59 : #[derive(serde::Deserialize)]
60 : #[serde(untagged)]
61 : enum Payload {
62 : Single(QueryData),
63 : Batch(BatchQueryData),
64 : }
65 :
66 : static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
67 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
68 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
69 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
70 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
71 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
72 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
73 :
74 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
75 :
76 0 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
77 0 : where
78 0 : D: serde::de::Deserializer<'de>,
79 0 : {
80 : // TODO: consider avoiding the allocation here.
81 0 : let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
82 0 : Ok(json_to_pg_text(json))
83 0 : }
84 :
85 0 : #[derive(Debug, thiserror::Error)]
86 : pub(crate) enum ConnInfoError {
87 : #[error("invalid header: {0}")]
88 : InvalidHeader(&'static HeaderName),
89 : #[error("invalid connection string: {0}")]
90 : UrlParseError(#[from] url::ParseError),
91 : #[error("incorrect scheme")]
92 : IncorrectScheme,
93 : #[error("missing database name")]
94 : MissingDbName,
95 : #[error("invalid database name")]
96 : InvalidDbName,
97 : #[error("missing username")]
98 : MissingUsername,
99 : #[error("invalid username: {0}")]
100 : InvalidUsername(#[from] std::string::FromUtf8Error),
101 : #[error("missing authentication credentials: {0}")]
102 : MissingCredentials(Credentials),
103 : #[error("missing hostname")]
104 : MissingHostname,
105 : #[error("invalid hostname: {0}")]
106 : InvalidEndpoint(#[from] ComputeUserInfoParseError),
107 : #[error("malformed endpoint")]
108 : MalformedEndpoint,
109 : }
110 :
111 0 : #[derive(Debug, thiserror::Error)]
112 : pub(crate) enum Credentials {
113 : #[error("required password")]
114 : Password,
115 : #[error("required authorization bearer token in JWT format")]
116 : BearerJwt,
117 : }
118 :
119 : impl ReportableError for ConnInfoError {
120 0 : fn get_error_kind(&self) -> ErrorKind {
121 0 : ErrorKind::User
122 0 : }
123 : }
124 :
125 : impl UserFacingError for ConnInfoError {
126 0 : fn to_string_client(&self) -> String {
127 0 : self.to_string()
128 0 : }
129 : }
130 :
131 0 : fn get_conn_info(
132 0 : config: &'static AuthenticationConfig,
133 0 : ctx: &RequestMonitoring,
134 0 : headers: &HeaderMap,
135 0 : tls: Option<&TlsConfig>,
136 0 : ) -> Result<ConnInfoWithAuth, ConnInfoError> {
137 0 : // HTTP only uses cleartext (for now and likely always)
138 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
139 :
140 0 : let connection_string = headers
141 0 : .get(&CONN_STRING)
142 0 : .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
143 0 : .to_str()
144 0 : .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
145 :
146 0 : let connection_url = Url::parse(connection_string)?;
147 :
148 0 : let protocol = connection_url.scheme();
149 0 : if protocol != "postgres" && protocol != "postgresql" {
150 0 : return Err(ConnInfoError::IncorrectScheme);
151 0 : }
152 :
153 0 : let mut url_path = connection_url
154 0 : .path_segments()
155 0 : .ok_or(ConnInfoError::MissingDbName)?;
156 :
157 0 : let dbname: DbName =
158 0 : urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
159 0 : ctx.set_dbname(dbname.clone());
160 :
161 0 : let username = RoleName::from(urlencoding::decode(connection_url.username())?);
162 0 : if username.is_empty() {
163 0 : return Err(ConnInfoError::MissingUsername);
164 0 : }
165 0 : ctx.set_user(username.clone());
166 :
167 0 : let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
168 0 : if !config.accept_jwts {
169 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
170 0 : }
171 :
172 0 : let auth = auth
173 0 : .to_str()
174 0 : .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
175 : AuthData::Jwt(
176 0 : auth.strip_prefix("Bearer ")
177 0 : .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
178 0 : .into(),
179 : )
180 0 : } else if let Some(pass) = connection_url.password() {
181 : // wrong credentials provided
182 0 : if config.accept_jwts {
183 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
184 0 : }
185 0 :
186 0 : AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
187 0 : std::borrow::Cow::Borrowed(b) => b.into(),
188 0 : std::borrow::Cow::Owned(b) => b.into(),
189 : })
190 0 : } else if config.accept_jwts {
191 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
192 : } else {
193 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
194 : };
195 :
196 0 : let endpoint = match connection_url.host() {
197 0 : Some(url::Host::Domain(hostname)) => {
198 0 : if let Some(tls) = tls {
199 0 : endpoint_sni(hostname, &tls.common_names)?
200 0 : .ok_or(ConnInfoError::MalformedEndpoint)?
201 : } else {
202 0 : hostname
203 0 : .split_once('.')
204 0 : .map_or(hostname, |(prefix, _)| prefix)
205 0 : .into()
206 : }
207 : }
208 : Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
209 0 : return Err(ConnInfoError::MissingHostname)
210 : }
211 : };
212 0 : ctx.set_endpoint_id(endpoint.clone());
213 0 :
214 0 : let pairs = connection_url.query_pairs();
215 0 :
216 0 : let mut options = Option::None;
217 0 :
218 0 : let mut params = StartupMessageParamsBuilder::default();
219 0 : params.insert("user", &username);
220 0 : params.insert("database", &dbname);
221 0 : for (key, value) in pairs {
222 0 : params.insert(&key, &value);
223 0 : if key == "options" {
224 0 : options = Some(NeonOptions::parse_options_raw(&value));
225 0 : }
226 : }
227 :
228 0 : let user_info = ComputeUserInfo {
229 0 : endpoint,
230 0 : user: username,
231 0 : options: options.unwrap_or_default(),
232 0 : };
233 0 :
234 0 : let conn_info = ConnInfo { user_info, dbname };
235 0 : Ok(ConnInfoWithAuth { conn_info, auth })
236 0 : }
237 :
238 : // TODO: return different http error codes
239 0 : pub(crate) async fn handle(
240 0 : config: &'static ProxyConfig,
241 0 : ctx: RequestMonitoring,
242 0 : request: Request<Incoming>,
243 0 : backend: Arc<PoolingBackend>,
244 0 : cancel: CancellationToken,
245 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
246 0 : let result = handle_inner(cancel, config, &ctx, request, backend).await;
247 :
248 0 : let mut response = match result {
249 0 : Ok(r) => {
250 0 : ctx.set_success();
251 0 : r
252 : }
253 0 : Err(e @ SqlOverHttpError::Cancelled(_)) => {
254 0 : let error_kind = e.get_error_kind();
255 0 : ctx.set_error_kind(error_kind);
256 0 :
257 0 : let message = "Query cancelled, connection was terminated";
258 0 :
259 0 : tracing::info!(
260 0 : kind=error_kind.to_metric_label(),
261 0 : error=%e,
262 0 : msg=message,
263 0 : "forwarding error to user"
264 : );
265 :
266 0 : json_response(
267 0 : StatusCode::BAD_REQUEST,
268 0 : json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
269 0 : )?
270 : }
271 0 : Err(e) => {
272 0 : let error_kind = e.get_error_kind();
273 0 : ctx.set_error_kind(error_kind);
274 0 :
275 0 : let mut message = e.to_string_client();
276 0 : let db_error = match &e {
277 0 : SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
278 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
279 0 : _ => None,
280 : };
281 0 : fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
282 0 : db.map(x).unwrap_or_default()
283 0 : }
284 :
285 0 : if let Some(db_error) = db_error {
286 0 : db_error.message().clone_into(&mut message);
287 0 : }
288 :
289 0 : let position = db_error.and_then(|db| db.position());
290 0 : let (position, internal_position, internal_query) = match position {
291 0 : Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
292 0 : Some(ErrorPosition::Internal { position, query }) => {
293 0 : (None, Some(position.to_string()), Some(query.clone()))
294 : }
295 0 : None => (None, None, None),
296 : };
297 :
298 0 : let code = get(db_error, |db| db.code().code());
299 0 : let severity = get(db_error, |db| db.severity());
300 0 : let detail = get(db_error, |db| db.detail());
301 0 : let hint = get(db_error, |db| db.hint());
302 0 : let where_ = get(db_error, |db| db.where_());
303 0 : let table = get(db_error, |db| db.table());
304 0 : let column = get(db_error, |db| db.column());
305 0 : let schema = get(db_error, |db| db.schema());
306 0 : let datatype = get(db_error, |db| db.datatype());
307 0 : let constraint = get(db_error, |db| db.constraint());
308 0 : let file = get(db_error, |db| db.file());
309 0 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
310 0 : let routine = get(db_error, |db| db.routine());
311 0 :
312 0 : tracing::info!(
313 0 : kind=error_kind.to_metric_label(),
314 0 : error=%e,
315 0 : msg=message,
316 0 : "forwarding error to user"
317 : );
318 :
319 : // TODO: this shouldn't always be bad request.
320 0 : json_response(
321 0 : StatusCode::BAD_REQUEST,
322 0 : json!({
323 0 : "message": message,
324 0 : "code": code,
325 0 : "detail": detail,
326 0 : "hint": hint,
327 0 : "position": position,
328 0 : "internalPosition": internal_position,
329 0 : "internalQuery": internal_query,
330 0 : "severity": severity,
331 0 : "where": where_,
332 0 : "table": table,
333 0 : "column": column,
334 0 : "schema": schema,
335 0 : "dataType": datatype,
336 0 : "constraint": constraint,
337 0 : "file": file,
338 0 : "line": line,
339 0 : "routine": routine,
340 0 : }),
341 0 : )?
342 : }
343 : };
344 :
345 0 : response
346 0 : .headers_mut()
347 0 : .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
348 0 : Ok(response)
349 0 : }
350 :
351 0 : #[derive(Debug, thiserror::Error)]
352 : pub(crate) enum SqlOverHttpError {
353 : #[error("{0}")]
354 : ReadPayload(#[from] ReadPayloadError),
355 : #[error("{0}")]
356 : ConnectCompute(#[from] HttpConnError),
357 : #[error("{0}")]
358 : ConnInfo(#[from] ConnInfoError),
359 : #[error("request is too large (max is {0} bytes)")]
360 : RequestTooLarge(u64),
361 : #[error("response is too large (max is {0} bytes)")]
362 : ResponseTooLarge(usize),
363 : #[error("invalid isolation level")]
364 : InvalidIsolationLevel,
365 : #[error("{0}")]
366 : Postgres(#[from] tokio_postgres::Error),
367 : #[error("{0}")]
368 : JsonConversion(#[from] JsonConversionError),
369 : #[error("{0}")]
370 : Cancelled(SqlOverHttpCancel),
371 : }
372 :
373 : impl ReportableError for SqlOverHttpError {
374 0 : fn get_error_kind(&self) -> ErrorKind {
375 0 : match self {
376 0 : SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
377 0 : SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
378 0 : SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
379 0 : SqlOverHttpError::RequestTooLarge(_) => ErrorKind::User,
380 0 : SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
381 0 : SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
382 0 : SqlOverHttpError::Postgres(p) => p.get_error_kind(),
383 0 : SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
384 0 : SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
385 : }
386 0 : }
387 : }
388 :
389 : impl UserFacingError for SqlOverHttpError {
390 0 : fn to_string_client(&self) -> String {
391 0 : match self {
392 0 : SqlOverHttpError::ReadPayload(p) => p.to_string(),
393 0 : SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
394 0 : SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
395 0 : SqlOverHttpError::RequestTooLarge(_) => self.to_string(),
396 0 : SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
397 0 : SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
398 0 : SqlOverHttpError::Postgres(p) => p.to_string(),
399 0 : SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
400 0 : SqlOverHttpError::Cancelled(_) => self.to_string(),
401 : }
402 0 : }
403 : }
404 :
405 0 : #[derive(Debug, thiserror::Error)]
406 : pub(crate) enum ReadPayloadError {
407 : #[error("could not read the HTTP request body: {0}")]
408 : Read(#[from] hyper::Error),
409 : #[error("could not parse the HTTP request body: {0}")]
410 : Parse(#[from] serde_json::Error),
411 : }
412 :
413 : impl ReportableError for ReadPayloadError {
414 0 : fn get_error_kind(&self) -> ErrorKind {
415 0 : match self {
416 0 : ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
417 0 : ReadPayloadError::Parse(_) => ErrorKind::User,
418 : }
419 0 : }
420 : }
421 :
422 0 : #[derive(Debug, thiserror::Error)]
423 : pub(crate) enum SqlOverHttpCancel {
424 : #[error("query was cancelled")]
425 : Postgres,
426 : #[error("query was cancelled while stuck trying to connect to the database")]
427 : Connect,
428 : }
429 :
430 : impl ReportableError for SqlOverHttpCancel {
431 0 : fn get_error_kind(&self) -> ErrorKind {
432 0 : match self {
433 0 : SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
434 0 : SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
435 : }
436 0 : }
437 : }
438 :
439 : #[derive(Clone, Copy, Debug)]
440 : struct HttpHeaders {
441 : raw_output: bool,
442 : default_array_mode: bool,
443 : txn_isolation_level: Option<IsolationLevel>,
444 : txn_read_only: bool,
445 : txn_deferrable: bool,
446 : }
447 :
448 : impl HttpHeaders {
449 0 : fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
450 0 : // Determine the output options. Default behaviour is 'false'. Anything that is not
451 0 : // strictly 'true' assumed to be false.
452 0 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
453 0 : let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
454 :
455 : // isolation level, read only and deferrable
456 0 : let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
457 0 : Some(x) => Some(
458 0 : map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
459 : ),
460 0 : None => None,
461 : };
462 :
463 0 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
464 0 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
465 0 :
466 0 : Ok(Self {
467 0 : raw_output,
468 0 : default_array_mode,
469 0 : txn_isolation_level,
470 0 : txn_read_only,
471 0 : txn_deferrable,
472 0 : })
473 0 : }
474 : }
475 :
476 0 : fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
477 0 : match level.as_bytes() {
478 0 : b"Serializable" => Some(IsolationLevel::Serializable),
479 0 : b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
480 0 : b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
481 0 : b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
482 0 : _ => None,
483 : }
484 0 : }
485 :
486 0 : fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
487 0 : match level {
488 0 : IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
489 0 : IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
490 0 : IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
491 0 : IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
492 0 : _ => None,
493 : }
494 0 : }
495 :
496 0 : async fn handle_inner(
497 0 : cancel: CancellationToken,
498 0 : config: &'static ProxyConfig,
499 0 : ctx: &RequestMonitoring,
500 0 : request: Request<Incoming>,
501 0 : backend: Arc<PoolingBackend>,
502 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
503 0 : let _requeset_gauge = Metrics::get()
504 0 : .proxy
505 0 : .connection_requests
506 0 : .guard(ctx.protocol());
507 0 : info!(
508 0 : protocol = %ctx.protocol(),
509 0 : "handling interactive connection from client"
510 : );
511 :
512 0 : let conn_info = get_conn_info(
513 0 : &config.authentication_config,
514 0 : ctx,
515 0 : request.headers(),
516 0 : config.tls_config.as_ref(),
517 0 : )?;
518 0 : info!(
519 0 : user = conn_info.conn_info.user_info.user.as_str(),
520 0 : "credentials"
521 : );
522 :
523 0 : match conn_info.auth {
524 0 : AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
525 0 : handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
526 : }
527 0 : auth => {
528 0 : handle_db_inner(
529 0 : cancel,
530 0 : config,
531 0 : ctx,
532 0 : request,
533 0 : conn_info.conn_info,
534 0 : auth,
535 0 : backend,
536 0 : )
537 0 : .await
538 : }
539 : }
540 0 : }
541 :
542 0 : async fn handle_db_inner(
543 0 : cancel: CancellationToken,
544 0 : config: &'static ProxyConfig,
545 0 : ctx: &RequestMonitoring,
546 0 : request: Request<Incoming>,
547 0 : conn_info: ConnInfo,
548 0 : auth: AuthData,
549 0 : backend: Arc<PoolingBackend>,
550 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
551 0 : //
552 0 : // Determine the destination and connection params
553 0 : //
554 0 : let headers = request.headers();
555 :
556 : // Allow connection pooling only if explicitly requested
557 : // or if we have decided that http pool is no longer opt-in
558 0 : let allow_pool = !config.http_config.pool_options.opt_in
559 0 : || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
560 :
561 0 : let parsed_headers = HttpHeaders::try_parse(headers)?;
562 :
563 0 : let request_content_length = match request.body().size_hint().upper() {
564 0 : Some(v) => v,
565 0 : None => config.http_config.max_request_size_bytes + 1,
566 : };
567 0 : info!(request_content_length, "request size in bytes");
568 0 : Metrics::get()
569 0 : .proxy
570 0 : .http_conn_content_length_bytes
571 0 : .observe(HttpDirection::Request, request_content_length as f64);
572 0 :
573 0 : // we don't have a streaming request support yet so this is to prevent OOM
574 0 : // from a malicious user sending an extremely large request body
575 0 : if request_content_length > config.http_config.max_request_size_bytes {
576 0 : return Err(SqlOverHttpError::RequestTooLarge(
577 0 : config.http_config.max_request_size_bytes,
578 0 : ));
579 0 : }
580 0 :
581 0 : let fetch_and_process_request = Box::pin(
582 0 : async {
583 0 : let body = request.into_body().collect().await?.to_bytes();
584 0 : info!(length = body.len(), "request payload read");
585 0 : let payload: Payload = serde_json::from_slice(&body)?;
586 0 : Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
587 0 : }
588 0 : .map_err(SqlOverHttpError::from),
589 0 : );
590 0 :
591 0 : let authenticate_and_connect = Box::pin(
592 0 : async {
593 0 : let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
594 :
595 0 : let keys = match auth {
596 0 : AuthData::Password(pw) => {
597 0 : backend
598 0 : .authenticate_with_password(ctx, &conn_info.user_info, &pw)
599 0 : .await?
600 : }
601 0 : AuthData::Jwt(jwt) => {
602 0 : backend
603 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
604 0 : .await?
605 : }
606 : };
607 :
608 0 : let client = match keys.keys {
609 0 : ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
610 0 : let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
611 0 : let (cli_inner, _dsc) = client.client_inner();
612 0 : cli_inner.set_jwt_session(&payload).await?;
613 0 : Client::Local(client)
614 : }
615 : _ => {
616 0 : let client = backend
617 0 : .connect_to_compute(ctx, conn_info, keys, !allow_pool)
618 0 : .await?;
619 0 : Client::Remote(client)
620 : }
621 : };
622 :
623 : // not strictly necessary to mark success here,
624 : // but it's just insurance for if we forget it somewhere else
625 0 : ctx.success();
626 0 : Ok::<_, HttpConnError>(client)
627 0 : }
628 0 : .map_err(SqlOverHttpError::from),
629 0 : );
630 :
631 0 : let (payload, mut client) = match run_until_cancelled(
632 0 : // Run both operations in parallel
633 0 : try_join(
634 0 : pin!(fetch_and_process_request),
635 0 : pin!(authenticate_and_connect),
636 0 : ),
637 0 : &cancel,
638 0 : )
639 0 : .await
640 : {
641 0 : Some(result) => result?,
642 0 : None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
643 : };
644 :
645 0 : let mut response = Response::builder()
646 0 : .status(StatusCode::OK)
647 0 : .header(header::CONTENT_TYPE, "application/json");
648 :
649 : // Now execute the query and return the result.
650 0 : let json_output = match payload {
651 0 : Payload::Single(stmt) => {
652 0 : stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
653 0 : .await?
654 : }
655 0 : Payload::Batch(statements) => {
656 0 : if parsed_headers.txn_read_only {
657 0 : response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
658 0 : }
659 0 : if parsed_headers.txn_deferrable {
660 0 : response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
661 0 : }
662 0 : if let Some(txn_isolation_level) = parsed_headers
663 0 : .txn_isolation_level
664 0 : .and_then(map_isolation_level_to_headers)
665 0 : {
666 0 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
667 0 : }
668 :
669 0 : statements
670 0 : .process(&config.http_config, cancel, &mut client, parsed_headers)
671 0 : .await?
672 : }
673 : };
674 :
675 0 : let metrics = client.metrics();
676 0 :
677 0 : let len = json_output.len();
678 0 : let response = response
679 0 : .body(
680 0 : Full::new(Bytes::from(json_output))
681 0 : .map_err(|x| match x {})
682 0 : .boxed(),
683 0 : )
684 0 : // only fails if invalid status code or invalid header/values are given.
685 0 : // these are not user configurable so it cannot fail dynamically
686 0 : .expect("building response payload should not fail");
687 0 :
688 0 : // count the egress bytes - we miss the TLS and header overhead but oh well...
689 0 : // moving this later in the stack is going to be a lot of effort and ehhhh
690 0 : metrics.record_egress(len as u64);
691 0 : Metrics::get()
692 0 : .proxy
693 0 : .http_conn_content_length_bytes
694 0 : .observe(HttpDirection::Response, len as f64);
695 0 :
696 0 : Ok(response)
697 0 : }
698 :
699 : static HEADERS_TO_FORWARD: &[&HeaderName] = &[
700 : &AUTHORIZATION,
701 : &CONN_STRING,
702 : &RAW_TEXT_OUTPUT,
703 : &ARRAY_MODE,
704 : &TXN_ISOLATION_LEVEL,
705 : &TXN_READ_ONLY,
706 : &TXN_DEFERRABLE,
707 : ];
708 :
709 0 : async fn handle_auth_broker_inner(
710 0 : ctx: &RequestMonitoring,
711 0 : request: Request<Incoming>,
712 0 : conn_info: ConnInfo,
713 0 : jwt: String,
714 0 : backend: Arc<PoolingBackend>,
715 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
716 0 : backend
717 0 : .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
718 0 : .await
719 0 : .map_err(HttpConnError::from)?;
720 :
721 0 : let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
722 :
723 0 : let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
724 0 :
725 0 : let (mut parts, body) = request.into_parts();
726 0 : let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
727 :
728 : // todo(conradludgate): maybe auth-broker should parse these and re-serialize
729 : // these instead just to ensure they remain normalised.
730 0 : for &h in HEADERS_TO_FORWARD {
731 0 : if let Some(hv) = parts.headers.remove(h) {
732 0 : req = req.header(h, hv);
733 0 : }
734 : }
735 :
736 0 : let req = req
737 0 : .body(body)
738 0 : .expect("all headers and params received via hyper should be valid for request");
739 0 :
740 0 : // todo: map body to count egress
741 0 : let _metrics = client.metrics();
742 0 :
743 0 : Ok(client
744 0 : .inner
745 0 : .send_request(req)
746 0 : .await
747 0 : .map_err(LocalProxyConnError::from)
748 0 : .map_err(HttpConnError::from)?
749 0 : .map(|b| b.boxed()))
750 0 : }
751 :
752 : impl QueryData {
753 0 : async fn process(
754 0 : self,
755 0 : config: &'static HttpConfig,
756 0 : cancel: CancellationToken,
757 0 : client: &mut Client,
758 0 : parsed_headers: HttpHeaders,
759 0 : ) -> Result<String, SqlOverHttpError> {
760 0 : let (inner, mut discard) = client.inner();
761 0 : let cancel_token = inner.cancel_token();
762 :
763 0 : let res = match select(
764 0 : pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
765 0 : pin!(cancel.cancelled()),
766 0 : )
767 0 : .await
768 : {
769 : // The query successfully completed.
770 0 : Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
771 0 : discard.check_idle(status);
772 0 :
773 0 : let json_output =
774 0 : serde_json::to_string(&results).expect("json serialization should not fail");
775 0 : Ok(json_output)
776 : }
777 : // The query failed with an error
778 0 : Either::Left((Err(e), __not_yet_cancelled)) => {
779 0 : discard.discard();
780 0 : return Err(e);
781 : }
782 : // The query was cancelled.
783 0 : Either::Right((_cancelled, query)) => {
784 0 : tracing::info!("cancelling query");
785 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
786 0 : tracing::warn!(?err, "could not cancel query");
787 0 : }
788 : // wait for the query cancellation
789 0 : match time::timeout(time::Duration::from_millis(100), query).await {
790 : // query successed before it was cancelled.
791 0 : Ok(Ok((status, results))) => {
792 0 : discard.check_idle(status);
793 0 :
794 0 : let json_output = serde_json::to_string(&results)
795 0 : .expect("json serialization should not fail");
796 0 : Ok(json_output)
797 : }
798 : // query failed or was cancelled.
799 0 : Ok(Err(error)) => {
800 0 : let db_error = match &error {
801 : SqlOverHttpError::ConnectCompute(
802 0 : HttpConnError::PostgresConnectionError(e),
803 : )
804 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
805 0 : _ => None,
806 : };
807 :
808 : // if errored for some other reason, it might not be safe to return
809 0 : if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
810 0 : discard.discard();
811 0 : }
812 :
813 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
814 : }
815 0 : Err(_timeout) => {
816 0 : discard.discard();
817 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
818 : }
819 : }
820 : }
821 : };
822 0 : res
823 0 : }
824 : }
825 :
826 : impl BatchQueryData {
827 0 : async fn process(
828 0 : self,
829 0 : config: &'static HttpConfig,
830 0 : cancel: CancellationToken,
831 0 : client: &mut Client,
832 0 : parsed_headers: HttpHeaders,
833 0 : ) -> Result<String, SqlOverHttpError> {
834 0 : info!("starting transaction");
835 0 : let (inner, mut discard) = client.inner();
836 0 : let cancel_token = inner.cancel_token();
837 0 : let mut builder = inner.build_transaction();
838 0 : if let Some(isolation_level) = parsed_headers.txn_isolation_level {
839 0 : builder = builder.isolation_level(isolation_level);
840 0 : }
841 0 : if parsed_headers.txn_read_only {
842 0 : builder = builder.read_only(true);
843 0 : }
844 0 : if parsed_headers.txn_deferrable {
845 0 : builder = builder.deferrable(true);
846 0 : }
847 :
848 0 : let transaction = builder.start().await.inspect_err(|_| {
849 0 : // if we cannot start a transaction, we should return immediately
850 0 : // and not return to the pool. connection is clearly broken
851 0 : discard.discard();
852 0 : })?;
853 :
854 0 : let json_output = match query_batch(
855 0 : config,
856 0 : cancel.child_token(),
857 0 : &transaction,
858 0 : self,
859 0 : parsed_headers,
860 0 : )
861 0 : .await
862 : {
863 0 : Ok(json_output) => {
864 0 : info!("commit");
865 0 : let status = transaction.commit().await.inspect_err(|_| {
866 0 : // if we cannot commit - for now don't return connection to pool
867 0 : // TODO: get a query status from the error
868 0 : discard.discard();
869 0 : })?;
870 0 : discard.check_idle(status);
871 0 : json_output
872 : }
873 : Err(SqlOverHttpError::Cancelled(_)) => {
874 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
875 0 : tracing::warn!(?err, "could not cancel query");
876 0 : }
877 : // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
878 0 : discard.discard();
879 0 :
880 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
881 : }
882 0 : Err(err) => {
883 0 : info!("rollback");
884 0 : let status = transaction.rollback().await.inspect_err(|_| {
885 0 : // if we cannot rollback - for now don't return connection to pool
886 0 : // TODO: get a query status from the error
887 0 : discard.discard();
888 0 : })?;
889 0 : discard.check_idle(status);
890 0 : return Err(err);
891 : }
892 : };
893 :
894 0 : Ok(json_output)
895 0 : }
896 : }
897 :
898 0 : async fn query_batch(
899 0 : config: &'static HttpConfig,
900 0 : cancel: CancellationToken,
901 0 : transaction: &Transaction<'_>,
902 0 : queries: BatchQueryData,
903 0 : parsed_headers: HttpHeaders,
904 0 : ) -> Result<String, SqlOverHttpError> {
905 0 : let mut results = Vec::with_capacity(queries.queries.len());
906 0 : let mut current_size = 0;
907 0 : for stmt in queries.queries {
908 0 : let query = pin!(query_to_json(
909 0 : config,
910 0 : transaction,
911 0 : stmt,
912 0 : &mut current_size,
913 0 : parsed_headers,
914 0 : ));
915 0 : let cancelled = pin!(cancel.cancelled());
916 0 : let res = select(query, cancelled).await;
917 0 : match res {
918 : // TODO: maybe we should check that the transaction bit is set here
919 0 : Either::Left((Ok((_, values)), _cancelled)) => {
920 0 : results.push(values);
921 0 : }
922 0 : Either::Left((Err(e), _cancelled)) => {
923 0 : return Err(e);
924 : }
925 0 : Either::Right((_cancelled, _)) => {
926 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
927 : }
928 : }
929 : }
930 :
931 0 : let results = json!({ "results": results });
932 0 : let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
933 0 :
934 0 : Ok(json_output)
935 0 : }
936 :
937 0 : async fn query_to_json<T: GenericClient>(
938 0 : config: &'static HttpConfig,
939 0 : client: &T,
940 0 : data: QueryData,
941 0 : current_size: &mut usize,
942 0 : parsed_headers: HttpHeaders,
943 0 : ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
944 0 : info!("executing query");
945 0 : let query_params = data.params;
946 0 : let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
947 0 : info!("finished executing query");
948 :
949 : // Manually drain the stream into a vector to leave row_stream hanging
950 : // around to get a command tag. Also check that the response is not too
951 : // big.
952 0 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
953 0 : while let Some(row) = row_stream.next().await {
954 0 : let row = row?;
955 0 : *current_size += row.body_len();
956 0 : rows.push(row);
957 0 : // we don't have a streaming response support yet so this is to prevent OOM
958 0 : // from a malicious query (eg a cross join)
959 0 : if *current_size > config.max_response_size_bytes {
960 0 : return Err(SqlOverHttpError::ResponseTooLarge(
961 0 : config.max_response_size_bytes,
962 0 : ));
963 0 : }
964 : }
965 :
966 0 : let ready = row_stream.ready_status();
967 0 :
968 0 : // grab the command tag and number of rows affected
969 0 : let command_tag = row_stream.command_tag().unwrap_or_default();
970 0 : let mut command_tag_split = command_tag.split(' ');
971 0 : let command_tag_name = command_tag_split.next().unwrap_or_default();
972 0 : let command_tag_count = if command_tag_name == "INSERT" {
973 : // INSERT returns OID first and then number of rows
974 0 : command_tag_split.nth(1)
975 : } else {
976 : // other commands return number of rows (if any)
977 0 : command_tag_split.next()
978 : }
979 0 : .and_then(|s| s.parse::<i64>().ok());
980 0 :
981 0 : info!(
982 0 : rows = rows.len(),
983 0 : ?ready,
984 0 : command_tag,
985 0 : "finished reading rows"
986 : );
987 :
988 0 : let columns_len = row_stream.columns().len();
989 0 : let mut fields = Vec::with_capacity(columns_len);
990 0 : let mut columns = Vec::with_capacity(columns_len);
991 :
992 0 : for c in row_stream.columns() {
993 0 : fields.push(json!({
994 0 : "name": c.name().to_owned(),
995 0 : "dataTypeID": c.type_().oid(),
996 0 : "tableID": c.table_oid(),
997 0 : "columnID": c.column_id(),
998 0 : "dataTypeSize": c.type_size(),
999 0 : "dataTypeModifier": c.type_modifier(),
1000 0 : "format": "text",
1001 0 : }));
1002 0 : columns.push(client.get_type(c.type_oid()).await?);
1003 : }
1004 :
1005 0 : let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
1006 :
1007 : // convert rows to JSON
1008 0 : let rows = rows
1009 0 : .iter()
1010 0 : .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
1011 0 : .collect::<Result<Vec<_>, _>>()?;
1012 :
1013 : // Resulting JSON format is based on the format of node-postgres result.
1014 0 : let results = json!({
1015 0 : "command": command_tag_name.to_string(),
1016 0 : "rowCount": command_tag_count,
1017 0 : "rows": rows,
1018 0 : "fields": fields,
1019 0 : "rowAsArray": array_mode,
1020 0 : });
1021 0 :
1022 0 : Ok((ready, results))
1023 0 : }
1024 :
1025 : enum Client {
1026 : Remote(conn_pool_lib::Client<tokio_postgres::Client>),
1027 : Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
1028 : }
1029 :
1030 : enum Discard<'a> {
1031 : Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
1032 : Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
1033 : }
1034 :
1035 : impl Client {
1036 0 : fn metrics(&self) -> Arc<MetricCounter> {
1037 0 : match self {
1038 0 : Client::Remote(client) => client.metrics(),
1039 0 : Client::Local(local_client) => local_client.metrics(),
1040 : }
1041 0 : }
1042 :
1043 0 : fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
1044 0 : match self {
1045 0 : Client::Remote(client) => {
1046 0 : let (c, d) = client.inner_mut();
1047 0 : (c, Discard::Remote(d))
1048 : }
1049 0 : Client::Local(local_client) => {
1050 0 : let (c, d) = local_client.inner();
1051 0 : (c, Discard::Local(d))
1052 : }
1053 : }
1054 0 : }
1055 : }
1056 :
1057 : impl Discard<'_> {
1058 0 : fn check_idle(&mut self, status: ReadyForQueryStatus) {
1059 0 : match self {
1060 0 : Discard::Remote(discard) => discard.check_idle(status),
1061 0 : Discard::Local(discard) => discard.check_idle(status),
1062 : }
1063 0 : }
1064 0 : fn discard(&mut self) {
1065 0 : match self {
1066 0 : Discard::Remote(discard) => discard.discard(),
1067 0 : Discard::Local(discard) => discard.discard(),
1068 : }
1069 0 : }
1070 : }
|