Line data Source code
1 : use std::pin::pin;
2 : use std::sync::Arc;
3 :
4 : use bytes::Bytes;
5 : use futures::future::select;
6 : use futures::future::try_join;
7 : use futures::future::Either;
8 : use futures::StreamExt;
9 : use futures::TryFutureExt;
10 : use http::header::AUTHORIZATION;
11 : use http_body_util::BodyExt;
12 : use http_body_util::Full;
13 : use hyper1::body::Body;
14 : use hyper1::body::Incoming;
15 : use hyper1::header;
16 : use hyper1::http::HeaderName;
17 : use hyper1::http::HeaderValue;
18 : use hyper1::Response;
19 : use hyper1::StatusCode;
20 : use hyper1::{HeaderMap, Request};
21 : use pq_proto::StartupMessageParamsBuilder;
22 : use serde::Serialize;
23 : use serde_json::Value;
24 : use tokio::time;
25 : use tokio_postgres::error::DbError;
26 : use tokio_postgres::error::ErrorPosition;
27 : use tokio_postgres::error::SqlState;
28 : use tokio_postgres::GenericClient;
29 : use tokio_postgres::IsolationLevel;
30 : use tokio_postgres::NoTls;
31 : use tokio_postgres::ReadyForQueryStatus;
32 : use tokio_postgres::Transaction;
33 : use tokio_util::sync::CancellationToken;
34 : use tracing::error;
35 : use tracing::info;
36 : use typed_json::json;
37 : use url::Url;
38 : use urlencoding;
39 : use utils::http::error::ApiError;
40 :
41 : use crate::auth::backend::ComputeUserInfo;
42 : use crate::auth::endpoint_sni;
43 : use crate::auth::ComputeUserInfoParseError;
44 : use crate::config::ProxyConfig;
45 : use crate::config::TlsConfig;
46 : use crate::context::RequestMonitoring;
47 : use crate::error::ErrorKind;
48 : use crate::error::ReportableError;
49 : use crate::error::UserFacingError;
50 : use crate::metrics::HttpDirection;
51 : use crate::metrics::Metrics;
52 : use crate::proxy::run_until_cancelled;
53 : use crate::proxy::NeonOptions;
54 : use crate::serverless::backend::HttpConnError;
55 : use crate::usage_metrics::MetricCounterRecorder;
56 : use crate::DbName;
57 : use crate::RoleName;
58 :
59 : use super::backend::PoolingBackend;
60 : use super::conn_pool::AuthData;
61 : use super::conn_pool::Client;
62 : use super::conn_pool::ConnInfo;
63 : use super::conn_pool::ConnInfoWithAuth;
64 : use super::http_util::json_response;
65 : use super::json::json_to_pg_text;
66 : use super::json::pg_text_row_to_json;
67 : use super::json::JsonConversionError;
68 :
/// One SQL statement taken from the request body.
///
/// JSON keys are camelCase (`query`, `params`, `arrayMode`).
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
    /// SQL text to execute.
    query: String,
    /// Positional query parameters. The raw JSON values are converted to their
    /// text form by `bytes_to_pg_text`; a `None` entry presumably stands for
    /// SQL NULL (see `json_to_pg_text` to confirm).
    #[serde(deserialize_with = "bytes_to_pg_text")]
    params: Vec<Option<String>>,
    /// Per-statement array-mode override; when absent, the value of the
    /// `neon-array-mode` header is used instead.
    #[serde(default)]
    array_mode: Option<bool>,
}

/// A list of statements that are executed together inside one transaction.
#[derive(serde::Deserialize)]
struct BatchQueryData {
    queries: Vec<QueryData>,
}

/// The request body: either a single statement or a batch of statements.
///
/// Because the enum is `untagged`, serde tries the variants in declaration
/// order: a body matching `QueryData` is `Single`, otherwise it falls back to
/// `Batch`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
    Single(QueryData),
    Batch(BatchQueryData),
}
90 :
// Request headers read by the SQL-over-HTTP handler. They are `static` items so
// that `&CONN_STRING` and friends can serve as the `&'static HeaderName` carried
// by `ConnInfoError::InvalidHeader`.

// Connection URL, e.g. `postgres://user:password@host/dbname?options=...`.
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
// Forwarded to the row-to-JSON conversion as its `raw_output` flag.
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
// Default array mode for statements that do not set `arrayMode` themselves.
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
// Opts the request into connection pooling while pooling is configured as opt-in.
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
// Transaction options applied to batch requests and echoed back on batch responses.
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");

// The only header value that enables a boolean flag; any other value means false.
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
100 :
101 0 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
102 0 : where
103 0 : D: serde::de::Deserializer<'de>,
104 0 : {
105 : // TODO: consider avoiding the allocation here.
106 0 : let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
107 0 : Ok(json_to_pg_text(json))
108 0 : }
109 :
110 0 : #[derive(Debug, thiserror::Error)]
111 : pub(crate) enum ConnInfoError {
112 : #[error("invalid header: {0}")]
113 : InvalidHeader(&'static HeaderName),
114 : #[error("invalid connection string: {0}")]
115 : UrlParseError(#[from] url::ParseError),
116 : #[error("incorrect scheme")]
117 : IncorrectScheme,
118 : #[error("missing database name")]
119 : MissingDbName,
120 : #[error("invalid database name")]
121 : InvalidDbName,
122 : #[error("missing username")]
123 : MissingUsername,
124 : #[error("invalid username: {0}")]
125 : InvalidUsername(#[from] std::string::FromUtf8Error),
126 : #[error("missing password")]
127 : MissingPassword,
128 : #[error("missing hostname")]
129 : MissingHostname,
130 : #[error("invalid hostname: {0}")]
131 : InvalidEndpoint(#[from] ComputeUserInfoParseError),
132 : #[error("malformed endpoint")]
133 : MalformedEndpoint,
134 : }
135 :
136 : impl ReportableError for ConnInfoError {
137 0 : fn get_error_kind(&self) -> ErrorKind {
138 0 : ErrorKind::User
139 0 : }
140 : }
141 :
142 : impl UserFacingError for ConnInfoError {
143 0 : fn to_string_client(&self) -> String {
144 0 : self.to_string()
145 0 : }
146 : }
147 :
148 0 : fn get_conn_info(
149 0 : ctx: &RequestMonitoring,
150 0 : headers: &HeaderMap,
151 0 : tls: Option<&TlsConfig>,
152 0 : ) -> Result<ConnInfoWithAuth, ConnInfoError> {
153 0 : // HTTP only uses cleartext (for now and likely always)
154 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
155 :
156 0 : let connection_string = headers
157 0 : .get(&CONN_STRING)
158 0 : .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
159 0 : .to_str()
160 0 : .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
161 :
162 0 : let connection_url = Url::parse(connection_string)?;
163 :
164 0 : let protocol = connection_url.scheme();
165 0 : if protocol != "postgres" && protocol != "postgresql" {
166 0 : return Err(ConnInfoError::IncorrectScheme);
167 0 : }
168 :
169 0 : let mut url_path = connection_url
170 0 : .path_segments()
171 0 : .ok_or(ConnInfoError::MissingDbName)?;
172 :
173 0 : let dbname: DbName =
174 0 : urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
175 0 : ctx.set_dbname(dbname.clone());
176 :
177 0 : let username = RoleName::from(urlencoding::decode(connection_url.username())?);
178 0 : if username.is_empty() {
179 0 : return Err(ConnInfoError::MissingUsername);
180 0 : }
181 0 : ctx.set_user(username.clone());
182 :
183 0 : let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
184 0 : let auth = auth
185 0 : .to_str()
186 0 : .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
187 : AuthData::Jwt(
188 0 : auth.strip_prefix("Bearer ")
189 0 : .ok_or(ConnInfoError::MissingPassword)?
190 0 : .into(),
191 : )
192 0 : } else if let Some(pass) = connection_url.password() {
193 0 : AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
194 0 : std::borrow::Cow::Borrowed(b) => b.into(),
195 0 : std::borrow::Cow::Owned(b) => b.into(),
196 : })
197 : } else {
198 0 : return Err(ConnInfoError::MissingPassword);
199 : };
200 :
201 0 : let endpoint = match connection_url.host() {
202 0 : Some(url::Host::Domain(hostname)) => {
203 0 : if let Some(tls) = tls {
204 0 : endpoint_sni(hostname, &tls.common_names)?
205 0 : .ok_or(ConnInfoError::MalformedEndpoint)?
206 : } else {
207 0 : hostname
208 0 : .split_once('.')
209 0 : .map_or(hostname, |(prefix, _)| prefix)
210 0 : .into()
211 : }
212 : }
213 : Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
214 0 : return Err(ConnInfoError::MissingHostname)
215 : }
216 : };
217 0 : ctx.set_endpoint_id(endpoint.clone());
218 0 :
219 0 : let pairs = connection_url.query_pairs();
220 0 :
221 0 : let mut options = Option::None;
222 0 :
223 0 : let mut params = StartupMessageParamsBuilder::default();
224 0 : params.insert("user", &username);
225 0 : params.insert("database", &dbname);
226 0 : for (key, value) in pairs {
227 0 : params.insert(&key, &value);
228 0 : if key == "options" {
229 0 : options = Some(NeonOptions::parse_options_raw(&value));
230 0 : }
231 : }
232 :
233 0 : let user_info = ComputeUserInfo {
234 0 : endpoint,
235 0 : user: username,
236 0 : options: options.unwrap_or_default(),
237 0 : };
238 0 :
239 0 : let conn_info = ConnInfo { user_info, dbname };
240 0 : Ok(ConnInfoWithAuth { conn_info, auth })
241 0 : }
242 :
243 : // TODO: return different http error codes
244 0 : pub(crate) async fn handle(
245 0 : config: &'static ProxyConfig,
246 0 : ctx: RequestMonitoring,
247 0 : request: Request<Incoming>,
248 0 : backend: Arc<PoolingBackend>,
249 0 : cancel: CancellationToken,
250 0 : ) -> Result<Response<Full<Bytes>>, ApiError> {
251 0 : let result = handle_inner(cancel, config, &ctx, request, backend).await;
252 :
253 0 : let mut response = match result {
254 0 : Ok(r) => {
255 0 : ctx.set_success();
256 0 : r
257 : }
258 0 : Err(e @ SqlOverHttpError::Cancelled(_)) => {
259 0 : let error_kind = e.get_error_kind();
260 0 : ctx.set_error_kind(error_kind);
261 0 :
262 0 : let message = "Query cancelled, connection was terminated";
263 0 :
264 0 : tracing::info!(
265 0 : kind=error_kind.to_metric_label(),
266 0 : error=%e,
267 0 : msg=message,
268 0 : "forwarding error to user"
269 : );
270 :
271 0 : json_response(
272 0 : StatusCode::BAD_REQUEST,
273 0 : json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
274 0 : )?
275 : }
276 0 : Err(e) => {
277 0 : let error_kind = e.get_error_kind();
278 0 : ctx.set_error_kind(error_kind);
279 0 :
280 0 : let mut message = e.to_string_client();
281 0 : let db_error = match &e {
282 0 : SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
283 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
284 0 : _ => None,
285 : };
286 0 : fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
287 0 : db.map(x).unwrap_or_default()
288 0 : }
289 :
290 0 : if let Some(db_error) = db_error {
291 0 : db_error.message().clone_into(&mut message);
292 0 : }
293 :
294 0 : let position = db_error.and_then(|db| db.position());
295 0 : let (position, internal_position, internal_query) = match position {
296 0 : Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
297 0 : Some(ErrorPosition::Internal { position, query }) => {
298 0 : (None, Some(position.to_string()), Some(query.clone()))
299 : }
300 0 : None => (None, None, None),
301 : };
302 :
303 0 : let code = get(db_error, |db| db.code().code());
304 0 : let severity = get(db_error, |db| db.severity());
305 0 : let detail = get(db_error, |db| db.detail());
306 0 : let hint = get(db_error, |db| db.hint());
307 0 : let where_ = get(db_error, |db| db.where_());
308 0 : let table = get(db_error, |db| db.table());
309 0 : let column = get(db_error, |db| db.column());
310 0 : let schema = get(db_error, |db| db.schema());
311 0 : let datatype = get(db_error, |db| db.datatype());
312 0 : let constraint = get(db_error, |db| db.constraint());
313 0 : let file = get(db_error, |db| db.file());
314 0 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
315 0 : let routine = get(db_error, |db| db.routine());
316 0 :
317 0 : tracing::info!(
318 0 : kind=error_kind.to_metric_label(),
319 0 : error=%e,
320 0 : msg=message,
321 0 : "forwarding error to user"
322 : );
323 :
324 : // TODO: this shouldn't always be bad request.
325 0 : json_response(
326 0 : StatusCode::BAD_REQUEST,
327 0 : json!({
328 0 : "message": message,
329 0 : "code": code,
330 0 : "detail": detail,
331 0 : "hint": hint,
332 0 : "position": position,
333 0 : "internalPosition": internal_position,
334 0 : "internalQuery": internal_query,
335 0 : "severity": severity,
336 0 : "where": where_,
337 0 : "table": table,
338 0 : "column": column,
339 0 : "schema": schema,
340 0 : "dataType": datatype,
341 0 : "constraint": constraint,
342 0 : "file": file,
343 0 : "line": line,
344 0 : "routine": routine,
345 0 : }),
346 0 : )?
347 : }
348 : };
349 :
350 0 : response
351 0 : .headers_mut()
352 0 : .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
353 0 : Ok(response)
354 0 : }
355 :
356 0 : #[derive(Debug, thiserror::Error)]
357 : pub(crate) enum SqlOverHttpError {
358 : #[error("{0}")]
359 : ReadPayload(#[from] ReadPayloadError),
360 : #[error("{0}")]
361 : ConnectCompute(#[from] HttpConnError),
362 : #[error("{0}")]
363 : ConnInfo(#[from] ConnInfoError),
364 : #[error("request is too large (max is {0} bytes)")]
365 : RequestTooLarge(u64),
366 : #[error("response is too large (max is {0} bytes)")]
367 : ResponseTooLarge(usize),
368 : #[error("invalid isolation level")]
369 : InvalidIsolationLevel,
370 : #[error("{0}")]
371 : Postgres(#[from] tokio_postgres::Error),
372 : #[error("{0}")]
373 : JsonConversion(#[from] JsonConversionError),
374 : #[error("{0}")]
375 : Cancelled(SqlOverHttpCancel),
376 : }
377 :
378 : impl ReportableError for SqlOverHttpError {
379 0 : fn get_error_kind(&self) -> ErrorKind {
380 0 : match self {
381 0 : SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
382 0 : SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
383 0 : SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
384 0 : SqlOverHttpError::RequestTooLarge(_) => ErrorKind::User,
385 0 : SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
386 0 : SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
387 0 : SqlOverHttpError::Postgres(p) => p.get_error_kind(),
388 0 : SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
389 0 : SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
390 : }
391 0 : }
392 : }
393 :
394 : impl UserFacingError for SqlOverHttpError {
395 0 : fn to_string_client(&self) -> String {
396 0 : match self {
397 0 : SqlOverHttpError::ReadPayload(p) => p.to_string(),
398 0 : SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
399 0 : SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
400 0 : SqlOverHttpError::RequestTooLarge(_) => self.to_string(),
401 0 : SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
402 0 : SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
403 0 : SqlOverHttpError::Postgres(p) => p.to_string(),
404 0 : SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
405 0 : SqlOverHttpError::Cancelled(_) => self.to_string(),
406 : }
407 0 : }
408 : }
409 :
410 0 : #[derive(Debug, thiserror::Error)]
411 : pub(crate) enum ReadPayloadError {
412 : #[error("could not read the HTTP request body: {0}")]
413 : Read(#[from] hyper1::Error),
414 : #[error("could not parse the HTTP request body: {0}")]
415 : Parse(#[from] serde_json::Error),
416 : }
417 :
418 : impl ReportableError for ReadPayloadError {
419 0 : fn get_error_kind(&self) -> ErrorKind {
420 0 : match self {
421 0 : ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
422 0 : ReadPayloadError::Parse(_) => ErrorKind::User,
423 : }
424 0 : }
425 : }
426 :
427 0 : #[derive(Debug, thiserror::Error)]
428 : pub(crate) enum SqlOverHttpCancel {
429 : #[error("query was cancelled")]
430 : Postgres,
431 : #[error("query was cancelled while stuck trying to connect to the database")]
432 : Connect,
433 : }
434 :
435 : impl ReportableError for SqlOverHttpCancel {
436 0 : fn get_error_kind(&self) -> ErrorKind {
437 0 : match self {
438 0 : SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
439 0 : SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
440 : }
441 0 : }
442 : }
443 :
444 : #[derive(Clone, Copy, Debug)]
445 : struct HttpHeaders {
446 : raw_output: bool,
447 : default_array_mode: bool,
448 : txn_isolation_level: Option<IsolationLevel>,
449 : txn_read_only: bool,
450 : txn_deferrable: bool,
451 : }
452 :
453 : impl HttpHeaders {
454 0 : fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
455 0 : // Determine the output options. Default behaviour is 'false'. Anything that is not
456 0 : // strictly 'true' assumed to be false.
457 0 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
458 0 : let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
459 :
460 : // isolation level, read only and deferrable
461 0 : let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
462 0 : Some(x) => Some(
463 0 : map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
464 : ),
465 0 : None => None,
466 : };
467 :
468 0 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
469 0 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
470 0 :
471 0 : Ok(Self {
472 0 : raw_output,
473 0 : default_array_mode,
474 0 : txn_isolation_level,
475 0 : txn_read_only,
476 0 : txn_deferrable,
477 0 : })
478 0 : }
479 : }
480 :
481 0 : fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
482 0 : match level.as_bytes() {
483 0 : b"Serializable" => Some(IsolationLevel::Serializable),
484 0 : b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
485 0 : b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
486 0 : b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
487 0 : _ => None,
488 : }
489 0 : }
490 :
491 0 : fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
492 0 : match level {
493 0 : IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
494 0 : IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
495 0 : IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
496 0 : IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
497 0 : _ => None,
498 : }
499 0 : }
500 :
/// Serves one SQL-over-HTTP request; `handle` converts its errors into HTTP
/// responses.
///
/// Steps:
/// 1. parse the connection info and option headers;
/// 2. reject the request if its body size upper bound is unknown (treated as
///    `max_request_size_bytes + 1`) or above `max_request_size_bytes`;
/// 3. concurrently read/parse the JSON body and authenticate + connect to
///    compute, returning `SqlOverHttpCancel::Connect` if `cancel` fires first;
/// 4. execute the single statement or the batch (echoing the batch transaction
///    options as response headers) and return the JSON output;
/// 5. record egress bytes and the response-size metric.
async fn handle_inner(
    cancel: CancellationToken,
    config: &'static ProxyConfig,
    ctx: &RequestMonitoring,
    request: Request<Incoming>,
    backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
    // Held for the whole call so the gauge covers the request's lifetime.
    let _requeset_gauge = Metrics::get()
        .proxy
        .connection_requests
        .guard(ctx.protocol());
    info!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    //
    // Determine the destination and connection params
    //
    // This borrow of `request` must end before `request` is moved into the
    // body-reading future below.
    let headers = request.headers();

    // TLS config should be there.
    let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
    info!(
        user = conn_info.conn_info.user_info.user.as_str(),
        "credentials"
    );

    // Allow connection pooling only if explicitly requested
    // or if we have decided that http pool is no longer opt-in
    let allow_pool = !config.http_config.pool_options.opt_in
        || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    let parsed_headers = HttpHeaders::try_parse(headers)?;

    // A body with no known upper bound (e.g. no content length) is treated as
    // one byte over the limit, so it is rejected below.
    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,
        None => config.http_config.max_request_size_bytes + 1,
    };
    info!(request_content_length, "request size in bytes");
    Metrics::get()
        .proxy
        .http_conn_content_length_bytes
        .observe(HttpDirection::Request, request_content_length as f64);

    // we don't have a streaming request support yet so this is to prevent OOM
    // from a malicious user sending an extremely large request body
    if request_content_length > config.http_config.max_request_size_bytes {
        return Err(SqlOverHttpError::RequestTooLarge(
            config.http_config.max_request_size_bytes,
        ));
    }

    // Both futures are heap-allocated with `Box::pin` (presumably to keep this
    // function's own future small) and then run concurrently below.
    let fetch_and_process_request = Box::pin(
        async {
            let body = request.into_body().collect().await?.to_bytes();
            info!(length = body.len(), "request payload read");
            let payload: Payload = serde_json::from_slice(&body)?;
            Ok::<Payload, ReadPayloadError>(payload) // pins the error type used by `?` above
        }
        .map_err(SqlOverHttpError::from),
    );

    let authenticate_and_connect = Box::pin(
        async {
            let keys = match &conn_info.auth {
                AuthData::Password(pw) => {
                    backend
                        .authenticate_with_password(
                            ctx,
                            &config.authentication_config,
                            &conn_info.conn_info.user_info,
                            pw,
                        )
                        .await?
                }
                AuthData::Jwt(jwt) => {
                    backend
                        .authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
                        .await?
                }
            };

            // `!allow_pool` is passed through to `connect_to_compute`.
            let client = backend
                .connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
                .await?;
            // not strictly necessary to mark success here,
            // but it's just insurance for if we forget it somewhere else
            ctx.success();
            Ok::<_, HttpConnError>(client)
        }
        .map_err(SqlOverHttpError::from),
    );

    let (payload, mut client) = match run_until_cancelled(
        // Run both operations in parallel
        try_join(
            pin!(fetch_and_process_request),
            pin!(authenticate_and_connect),
        ),
        &cancel,
    )
    .await
    {
        Some(result) => result?,
        None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
    };

    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/json");

    // Now execute the query and return the result.
    let json_output = match payload {
        Payload::Single(stmt) => {
            stmt.process(config, cancel, &mut client, parsed_headers)
                .await?
        }
        Payload::Batch(statements) => {
            // Echo the effective transaction options back to the client.
            if parsed_headers.txn_read_only {
                response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
            }
            if parsed_headers.txn_deferrable {
                response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
            }
            if let Some(txn_isolation_level) = parsed_headers
                .txn_isolation_level
                .and_then(map_isolation_level_to_headers)
            {
                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
            }

            statements
                .process(config, cancel, &mut client, parsed_headers)
                .await?
        }
    };

    let metrics = client.metrics();

    let len = json_output.len();
    let response = response
        .body(Full::new(Bytes::from(json_output)))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

    // count the egress bytes - we miss the TLS and header overhead but oh well...
    // moving this later in the stack is going to be a lot of effort and ehhhh
    metrics.record_egress(len as u64);
    Metrics::get()
        .proxy
        .http_conn_content_length_bytes
        .observe(HttpDirection::Response, len as f64);

    Ok(response)
}
658 :
659 : impl QueryData {
660 0 : async fn process(
661 0 : self,
662 0 : config: &'static ProxyConfig,
663 0 : cancel: CancellationToken,
664 0 : client: &mut Client<tokio_postgres::Client>,
665 0 : parsed_headers: HttpHeaders,
666 0 : ) -> Result<String, SqlOverHttpError> {
667 0 : let (inner, mut discard) = client.inner();
668 0 : let cancel_token = inner.cancel_token();
669 :
670 0 : let res = match select(
671 0 : pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
672 0 : pin!(cancel.cancelled()),
673 0 : )
674 0 : .await
675 : {
676 : // The query successfully completed.
677 0 : Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
678 0 : discard.check_idle(status);
679 0 :
680 0 : let json_output =
681 0 : serde_json::to_string(&results).expect("json serialization should not fail");
682 0 : Ok(json_output)
683 : }
684 : // The query failed with an error
685 0 : Either::Left((Err(e), __not_yet_cancelled)) => {
686 0 : discard.discard();
687 0 : return Err(e);
688 : }
689 : // The query was cancelled.
690 0 : Either::Right((_cancelled, query)) => {
691 0 : tracing::info!("cancelling query");
692 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
693 0 : tracing::error!(?err, "could not cancel query");
694 0 : }
695 : // wait for the query cancellation
696 0 : match time::timeout(time::Duration::from_millis(100), query).await {
697 : // query successed before it was cancelled.
698 0 : Ok(Ok((status, results))) => {
699 0 : discard.check_idle(status);
700 0 :
701 0 : let json_output = serde_json::to_string(&results)
702 0 : .expect("json serialization should not fail");
703 0 : Ok(json_output)
704 : }
705 : // query failed or was cancelled.
706 0 : Ok(Err(error)) => {
707 0 : let db_error = match &error {
708 0 : SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
709 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
710 0 : _ => None,
711 : };
712 :
713 : // if errored for some other reason, it might not be safe to return
714 0 : if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
715 0 : discard.discard();
716 0 : }
717 :
718 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
719 : }
720 0 : Err(_timeout) => {
721 0 : discard.discard();
722 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
723 : }
724 : }
725 : }
726 : };
727 0 : res
728 0 : }
729 : }
730 :
731 : impl BatchQueryData {
732 0 : async fn process(
733 0 : self,
734 0 : config: &'static ProxyConfig,
735 0 : cancel: CancellationToken,
736 0 : client: &mut Client<tokio_postgres::Client>,
737 0 : parsed_headers: HttpHeaders,
738 0 : ) -> Result<String, SqlOverHttpError> {
739 0 : info!("starting transaction");
740 0 : let (inner, mut discard) = client.inner();
741 0 : let cancel_token = inner.cancel_token();
742 0 : let mut builder = inner.build_transaction();
743 0 : if let Some(isolation_level) = parsed_headers.txn_isolation_level {
744 0 : builder = builder.isolation_level(isolation_level);
745 0 : }
746 0 : if parsed_headers.txn_read_only {
747 0 : builder = builder.read_only(true);
748 0 : }
749 0 : if parsed_headers.txn_deferrable {
750 0 : builder = builder.deferrable(true);
751 0 : }
752 :
753 0 : let transaction = builder.start().await.inspect_err(|_| {
754 0 : // if we cannot start a transaction, we should return immediately
755 0 : // and not return to the pool. connection is clearly broken
756 0 : discard.discard();
757 0 : })?;
758 :
759 0 : let json_output = match query_batch(
760 0 : config,
761 0 : cancel.child_token(),
762 0 : &transaction,
763 0 : self,
764 0 : parsed_headers,
765 0 : )
766 0 : .await
767 : {
768 0 : Ok(json_output) => {
769 0 : info!("commit");
770 0 : let status = transaction.commit().await.inspect_err(|_| {
771 0 : // if we cannot commit - for now don't return connection to pool
772 0 : // TODO: get a query status from the error
773 0 : discard.discard();
774 0 : })?;
775 0 : discard.check_idle(status);
776 0 : json_output
777 : }
778 : Err(SqlOverHttpError::Cancelled(_)) => {
779 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
780 0 : tracing::error!(?err, "could not cancel query");
781 0 : }
782 : // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
783 0 : discard.discard();
784 0 :
785 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
786 : }
787 0 : Err(err) => {
788 0 : info!("rollback");
789 0 : let status = transaction.rollback().await.inspect_err(|_| {
790 0 : // if we cannot rollback - for now don't return connection to pool
791 0 : // TODO: get a query status from the error
792 0 : discard.discard();
793 0 : })?;
794 0 : discard.check_idle(status);
795 0 : return Err(err);
796 : }
797 : };
798 :
799 0 : Ok(json_output)
800 0 : }
801 : }
802 :
803 0 : async fn query_batch(
804 0 : config: &'static ProxyConfig,
805 0 : cancel: CancellationToken,
806 0 : transaction: &Transaction<'_>,
807 0 : queries: BatchQueryData,
808 0 : parsed_headers: HttpHeaders,
809 0 : ) -> Result<String, SqlOverHttpError> {
810 0 : let mut results = Vec::with_capacity(queries.queries.len());
811 0 : let mut current_size = 0;
812 0 : for stmt in queries.queries {
813 0 : let query = pin!(query_to_json(
814 0 : config,
815 0 : transaction,
816 0 : stmt,
817 0 : &mut current_size,
818 0 : parsed_headers,
819 0 : ));
820 0 : let cancelled = pin!(cancel.cancelled());
821 0 : let res = select(query, cancelled).await;
822 0 : match res {
823 : // TODO: maybe we should check that the transaction bit is set here
824 0 : Either::Left((Ok((_, values)), _cancelled)) => {
825 0 : results.push(values);
826 0 : }
827 0 : Either::Left((Err(e), _cancelled)) => {
828 0 : return Err(e);
829 : }
830 0 : Either::Right((_cancelled, _)) => {
831 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
832 : }
833 : }
834 : }
835 :
836 0 : let results = json!({ "results": results });
837 0 : let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
838 0 :
839 0 : Ok(json_output)
840 0 : }
841 :
842 0 : async fn query_to_json<T: GenericClient>(
843 0 : config: &'static ProxyConfig,
844 0 : client: &T,
845 0 : data: QueryData,
846 0 : current_size: &mut usize,
847 0 : parsed_headers: HttpHeaders,
848 0 : ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
849 0 : info!("executing query");
850 0 : let query_params = data.params;
851 0 : let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
852 0 : info!("finished executing query");
853 :
854 : // Manually drain the stream into a vector to leave row_stream hanging
855 : // around to get a command tag. Also check that the response is not too
856 : // big.
857 0 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
858 0 : while let Some(row) = row_stream.next().await {
859 0 : let row = row?;
860 0 : *current_size += row.body_len();
861 0 : rows.push(row);
862 0 : // we don't have a streaming response support yet so this is to prevent OOM
863 0 : // from a malicious query (eg a cross join)
864 0 : if *current_size > config.http_config.max_response_size_bytes {
865 0 : return Err(SqlOverHttpError::ResponseTooLarge(
866 0 : config.http_config.max_response_size_bytes,
867 0 : ));
868 0 : }
869 : }
870 :
871 0 : let ready = row_stream.ready_status();
872 0 :
873 0 : // grab the command tag and number of rows affected
874 0 : let command_tag = row_stream.command_tag().unwrap_or_default();
875 0 : let mut command_tag_split = command_tag.split(' ');
876 0 : let command_tag_name = command_tag_split.next().unwrap_or_default();
877 0 : let command_tag_count = if command_tag_name == "INSERT" {
878 : // INSERT returns OID first and then number of rows
879 0 : command_tag_split.nth(1)
880 : } else {
881 : // other commands return number of rows (if any)
882 0 : command_tag_split.next()
883 : }
884 0 : .and_then(|s| s.parse::<i64>().ok());
885 0 :
886 0 : info!(
887 0 : rows = rows.len(),
888 0 : ?ready,
889 0 : command_tag,
890 0 : "finished reading rows"
891 : );
892 :
893 0 : let columns_len = row_stream.columns().len();
894 0 : let mut fields = Vec::with_capacity(columns_len);
895 0 : let mut columns = Vec::with_capacity(columns_len);
896 :
897 0 : for c in row_stream.columns() {
898 0 : fields.push(json!({
899 0 : "name": c.name().to_owned(),
900 0 : "dataTypeID": c.type_().oid(),
901 0 : "tableID": c.table_oid(),
902 0 : "columnID": c.column_id(),
903 0 : "dataTypeSize": c.type_size(),
904 0 : "dataTypeModifier": c.type_modifier(),
905 0 : "format": "text",
906 0 : }));
907 0 : columns.push(client.get_type(c.type_oid()).await?);
908 : }
909 :
910 0 : let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
911 :
912 : // convert rows to JSON
913 0 : let rows = rows
914 0 : .iter()
915 0 : .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
916 0 : .collect::<Result<Vec<_>, _>>()?;
917 :
918 : // Resulting JSON format is based on the format of node-postgres result.
919 0 : let results = json!({
920 0 : "command": command_tag_name.to_string(),
921 0 : "rowCount": command_tag_count,
922 0 : "rows": rows,
923 0 : "fields": fields,
924 0 : "rowAsArray": array_mode,
925 0 : });
926 0 :
927 0 : Ok((ready, results))
928 0 : }
|