Line data Source code
1 : use std::pin::pin;
2 : use std::sync::Arc;
3 :
4 : use futures::future::select;
5 : use futures::future::try_join;
6 : use futures::future::Either;
7 : use futures::StreamExt;
8 : use futures::TryFutureExt;
9 : use hyper::body::HttpBody;
10 : use hyper::header;
11 : use hyper::http::HeaderName;
12 : use hyper::http::HeaderValue;
13 : use hyper::Response;
14 : use hyper::StatusCode;
15 : use hyper::{Body, HeaderMap, Request};
16 : use serde_json::json;
17 : use serde_json::Value;
18 : use tokio::time;
19 : use tokio_postgres::error::DbError;
20 : use tokio_postgres::error::ErrorPosition;
21 : use tokio_postgres::error::SqlState;
22 : use tokio_postgres::GenericClient;
23 : use tokio_postgres::IsolationLevel;
24 : use tokio_postgres::NoTls;
25 : use tokio_postgres::ReadyForQueryStatus;
26 : use tokio_postgres::Transaction;
27 : use tokio_util::sync::CancellationToken;
28 : use tracing::error;
29 : use tracing::info;
30 : use url::Url;
31 : use utils::http::error::ApiError;
32 : use utils::http::json::json_response;
33 :
34 : use crate::auth::backend::ComputeUserInfo;
35 : use crate::auth::endpoint_sni;
36 : use crate::auth::ComputeUserInfoParseError;
37 : use crate::config::ProxyConfig;
38 : use crate::config::TlsConfig;
39 : use crate::context::RequestMonitoring;
40 : use crate::error::ErrorKind;
41 : use crate::error::ReportableError;
42 : use crate::error::UserFacingError;
43 : use crate::metrics::HTTP_CONTENT_LENGTH;
44 : use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
45 : use crate::proxy::run_until_cancelled;
46 : use crate::proxy::NeonOptions;
47 : use crate::serverless::backend::HttpConnError;
48 : use crate::usage_metrics::MetricCounterRecorder;
49 : use crate::DbName;
50 : use crate::RoleName;
51 :
52 : use super::backend::PoolingBackend;
53 : use super::conn_pool::Client;
54 : use super::conn_pool::ConnInfo;
55 : use super::json::json_to_pg_text;
56 : use super::json::pg_text_row_to_json;
57 : use super::json::JsonConversionError;
58 :
59 0 : #[derive(serde::Deserialize)]
60 : #[serde(rename_all = "camelCase")]
61 : struct QueryData {
62 : query: String,
63 : #[serde(deserialize_with = "bytes_to_pg_text")]
64 : params: Vec<Option<String>>,
65 : #[serde(default)]
66 : array_mode: Option<bool>,
67 : }
68 :
69 0 : #[derive(serde::Deserialize)]
70 : struct BatchQueryData {
71 : queries: Vec<QueryData>,
72 : }
73 :
74 : #[derive(serde::Deserialize)]
75 : #[serde(untagged)]
76 : enum Payload {
77 : Single(QueryData),
78 : Batch(BatchQueryData),
79 : }
80 :
/// Cap on the accumulated row bytes returned for one request: 10 MiB.
const MAX_RESPONSE_SIZE: usize = 10 << 20;
/// Cap on the declared size of a request body: 10 MiB.
const MAX_REQUEST_SIZE: u64 = 10 << 20;
83 :
84 : static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
85 : static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
86 : static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
87 : static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
88 : static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
89 : static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
90 :
91 : static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
92 :
93 0 : fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
94 0 : where
95 0 : D: serde::de::Deserializer<'de>,
96 0 : {
97 : // TODO: consider avoiding the allocation here.
98 0 : let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
99 0 : Ok(json_to_pg_text(json))
100 0 : }
101 :
102 0 : #[derive(Debug, thiserror::Error)]
103 : pub enum ConnInfoError {
104 : #[error("invalid header: {0}")]
105 : InvalidHeader(&'static str),
106 : #[error("invalid connection string: {0}")]
107 : UrlParseError(#[from] url::ParseError),
108 : #[error("incorrect scheme")]
109 : IncorrectScheme,
110 : #[error("missing database name")]
111 : MissingDbName,
112 : #[error("invalid database name")]
113 : InvalidDbName,
114 : #[error("missing username")]
115 : MissingUsername,
116 : #[error("invalid username: {0}")]
117 : InvalidUsername(#[from] std::string::FromUtf8Error),
118 : #[error("missing password")]
119 : MissingPassword,
120 : #[error("missing hostname")]
121 : MissingHostname,
122 : #[error("invalid hostname: {0}")]
123 : InvalidEndpoint(#[from] ComputeUserInfoParseError),
124 : #[error("malformed endpoint")]
125 : MalformedEndpoint,
126 : }
127 :
128 : impl ReportableError for ConnInfoError {
129 0 : fn get_error_kind(&self) -> ErrorKind {
130 0 : ErrorKind::User
131 0 : }
132 : }
133 :
134 : impl UserFacingError for ConnInfoError {
135 0 : fn to_string_client(&self) -> String {
136 0 : self.to_string()
137 0 : }
138 : }
139 :
140 0 : fn get_conn_info(
141 0 : ctx: &mut RequestMonitoring,
142 0 : headers: &HeaderMap,
143 0 : tls: &TlsConfig,
144 0 : ) -> Result<ConnInfo, ConnInfoError> {
145 0 : // HTTP only uses cleartext (for now and likely always)
146 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
147 :
148 0 : let connection_string = headers
149 0 : .get("Neon-Connection-String")
150 0 : .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
151 0 : .to_str()
152 0 : .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
153 :
154 0 : let connection_url = Url::parse(connection_string)?;
155 :
156 0 : let protocol = connection_url.scheme();
157 0 : if protocol != "postgres" && protocol != "postgresql" {
158 0 : return Err(ConnInfoError::IncorrectScheme);
159 0 : }
160 :
161 0 : let mut url_path = connection_url
162 0 : .path_segments()
163 0 : .ok_or(ConnInfoError::MissingDbName)?;
164 :
165 0 : let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
166 0 : ctx.set_dbname(dbname.clone());
167 :
168 0 : let username = RoleName::from(urlencoding::decode(connection_url.username())?);
169 0 : if username.is_empty() {
170 0 : return Err(ConnInfoError::MissingUsername);
171 0 : }
172 0 : ctx.set_user(username.clone());
173 :
174 0 : let password = connection_url
175 0 : .password()
176 0 : .ok_or(ConnInfoError::MissingPassword)?;
177 0 : let password = urlencoding::decode_binary(password.as_bytes());
178 :
179 0 : let hostname = connection_url
180 0 : .host_str()
181 0 : .ok_or(ConnInfoError::MissingHostname)?;
182 :
183 0 : let endpoint =
184 0 : endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
185 0 : ctx.set_endpoint_id(endpoint.clone());
186 0 :
187 0 : let pairs = connection_url.query_pairs();
188 0 :
189 0 : let mut options = Option::None;
190 :
191 0 : for (key, value) in pairs {
192 0 : match &*key {
193 0 : "options" => {
194 0 : options = Some(NeonOptions::parse_options_raw(&value));
195 0 : }
196 0 : "application_name" => ctx.set_application(Some(value.into())),
197 0 : _ => {}
198 : }
199 : }
200 :
201 0 : let user_info = ComputeUserInfo {
202 0 : endpoint,
203 0 : user: username,
204 0 : options: options.unwrap_or_default(),
205 0 : };
206 0 :
207 0 : Ok(ConnInfo {
208 0 : user_info,
209 0 : dbname,
210 0 : password: match password {
211 0 : std::borrow::Cow::Borrowed(b) => b.into(),
212 0 : std::borrow::Cow::Owned(b) => b.into(),
213 : },
214 : })
215 0 : }
216 :
217 : // TODO: return different http error codes
218 0 : pub async fn handle(
219 0 : config: &'static ProxyConfig,
220 0 : mut ctx: RequestMonitoring,
221 0 : request: Request<Body>,
222 0 : backend: Arc<PoolingBackend>,
223 0 : cancel: CancellationToken,
224 0 : ) -> Result<Response<Body>, ApiError> {
225 0 : let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
226 :
227 0 : let mut response = match result {
228 0 : Ok(r) => {
229 0 : ctx.set_success();
230 0 : r
231 : }
232 0 : Err(e @ SqlOverHttpError::Cancelled(_)) => {
233 0 : let error_kind = e.get_error_kind();
234 0 : ctx.set_error_kind(error_kind);
235 0 :
236 0 : let message = "Query cancelled, connection was terminated";
237 0 :
238 0 : tracing::info!(
239 0 : kind=error_kind.to_metric_label(),
240 0 : error=%e,
241 0 : msg=message,
242 0 : "forwarding error to user"
243 0 : );
244 :
245 0 : json_response(
246 0 : StatusCode::BAD_REQUEST,
247 0 : json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
248 0 : )?
249 : }
250 0 : Err(e) => {
251 0 : let error_kind = e.get_error_kind();
252 0 : ctx.set_error_kind(error_kind);
253 0 :
254 0 : let mut message = e.to_string_client();
255 0 : let db_error = match &e {
256 0 : SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
257 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
258 0 : _ => None,
259 : };
260 0 : fn get<'a, T: serde::Serialize>(
261 0 : db: Option<&'a DbError>,
262 0 : x: impl FnOnce(&'a DbError) -> T,
263 0 : ) -> Value {
264 0 : db.map(x)
265 0 : .and_then(|t| serde_json::to_value(t).ok())
266 0 : .unwrap_or_default()
267 0 : }
268 :
269 0 : if let Some(db_error) = db_error {
270 0 : db_error.message().clone_into(&mut message);
271 0 : }
272 :
273 0 : let position = db_error.and_then(|db| db.position());
274 0 : let (position, internal_position, internal_query) = match position {
275 0 : Some(ErrorPosition::Original(position)) => (
276 0 : Value::String(position.to_string()),
277 0 : Value::Null,
278 0 : Value::Null,
279 0 : ),
280 0 : Some(ErrorPosition::Internal { position, query }) => (
281 0 : Value::Null,
282 0 : Value::String(position.to_string()),
283 0 : Value::String(query.clone()),
284 0 : ),
285 0 : None => (Value::Null, Value::Null, Value::Null),
286 : };
287 :
288 0 : let code = get(db_error, |db| db.code().code());
289 0 : let severity = get(db_error, |db| db.severity());
290 0 : let detail = get(db_error, |db| db.detail());
291 0 : let hint = get(db_error, |db| db.hint());
292 0 : let where_ = get(db_error, |db| db.where_());
293 0 : let table = get(db_error, |db| db.table());
294 0 : let column = get(db_error, |db| db.column());
295 0 : let schema = get(db_error, |db| db.schema());
296 0 : let datatype = get(db_error, |db| db.datatype());
297 0 : let constraint = get(db_error, |db| db.constraint());
298 0 : let file = get(db_error, |db| db.file());
299 0 : let line = get(db_error, |db| db.line().map(|l| l.to_string()));
300 0 : let routine = get(db_error, |db| db.routine());
301 0 :
302 0 : tracing::info!(
303 0 : kind=error_kind.to_metric_label(),
304 0 : error=%e,
305 0 : msg=message,
306 0 : "forwarding error to user"
307 0 : );
308 :
309 : // TODO: this shouldn't always be bad request.
310 0 : json_response(
311 0 : StatusCode::BAD_REQUEST,
312 0 : json!({
313 0 : "message": message,
314 0 : "code": code,
315 0 : "detail": detail,
316 0 : "hint": hint,
317 0 : "position": position,
318 0 : "internalPosition": internal_position,
319 0 : "internalQuery": internal_query,
320 0 : "severity": severity,
321 0 : "where": where_,
322 0 : "table": table,
323 0 : "column": column,
324 0 : "schema": schema,
325 0 : "dataType": datatype,
326 0 : "constraint": constraint,
327 0 : "file": file,
328 0 : "line": line,
329 0 : "routine": routine,
330 0 : }),
331 0 : )?
332 : }
333 : };
334 :
335 0 : response.headers_mut().insert(
336 0 : "Access-Control-Allow-Origin",
337 0 : hyper::http::HeaderValue::from_static("*"),
338 0 : );
339 0 : Ok(response)
340 0 : }
341 :
342 0 : #[derive(Debug, thiserror::Error)]
343 : pub enum SqlOverHttpError {
344 : #[error("{0}")]
345 : ReadPayload(#[from] ReadPayloadError),
346 : #[error("{0}")]
347 : ConnectCompute(#[from] HttpConnError),
348 : #[error("{0}")]
349 : ConnInfo(#[from] ConnInfoError),
350 : #[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
351 : RequestTooLarge,
352 : #[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
353 : ResponseTooLarge,
354 : #[error("invalid isolation level")]
355 : InvalidIsolationLevel,
356 : #[error("{0}")]
357 : Postgres(#[from] tokio_postgres::Error),
358 : #[error("{0}")]
359 : JsonConversion(#[from] JsonConversionError),
360 : #[error("{0}")]
361 : Cancelled(SqlOverHttpCancel),
362 : }
363 :
364 : impl ReportableError for SqlOverHttpError {
365 0 : fn get_error_kind(&self) -> ErrorKind {
366 0 : match self {
367 0 : SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
368 0 : SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
369 0 : SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
370 0 : SqlOverHttpError::RequestTooLarge => ErrorKind::User,
371 0 : SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
372 0 : SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
373 0 : SqlOverHttpError::Postgres(p) => p.get_error_kind(),
374 0 : SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
375 0 : SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
376 : }
377 0 : }
378 : }
379 :
380 : impl UserFacingError for SqlOverHttpError {
381 0 : fn to_string_client(&self) -> String {
382 0 : match self {
383 0 : SqlOverHttpError::ReadPayload(p) => p.to_string(),
384 0 : SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
385 0 : SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
386 0 : SqlOverHttpError::RequestTooLarge => self.to_string(),
387 0 : SqlOverHttpError::ResponseTooLarge => self.to_string(),
388 0 : SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
389 0 : SqlOverHttpError::Postgres(p) => p.to_string(),
390 0 : SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
391 0 : SqlOverHttpError::Cancelled(_) => self.to_string(),
392 : }
393 0 : }
394 : }
395 :
396 0 : #[derive(Debug, thiserror::Error)]
397 : pub enum ReadPayloadError {
398 : #[error("could not read the HTTP request body: {0}")]
399 : Read(#[from] hyper::Error),
400 : #[error("could not parse the HTTP request body: {0}")]
401 : Parse(#[from] serde_json::Error),
402 : }
403 :
404 : impl ReportableError for ReadPayloadError {
405 0 : fn get_error_kind(&self) -> ErrorKind {
406 0 : match self {
407 0 : ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
408 0 : ReadPayloadError::Parse(_) => ErrorKind::User,
409 : }
410 0 : }
411 : }
412 :
413 0 : #[derive(Debug, thiserror::Error)]
414 : pub enum SqlOverHttpCancel {
415 : #[error("query was cancelled")]
416 : Postgres,
417 : #[error("query was cancelled while stuck trying to connect to the database")]
418 : Connect,
419 : }
420 :
421 : impl ReportableError for SqlOverHttpCancel {
422 0 : fn get_error_kind(&self) -> ErrorKind {
423 0 : match self {
424 0 : SqlOverHttpCancel::Postgres => ErrorKind::RateLimit,
425 0 : SqlOverHttpCancel::Connect => ErrorKind::ServiceRateLimit,
426 : }
427 0 : }
428 : }
429 :
430 : #[derive(Clone, Copy, Debug)]
431 : struct HttpHeaders {
432 : raw_output: bool,
433 : default_array_mode: bool,
434 : txn_isolation_level: Option<IsolationLevel>,
435 : txn_read_only: bool,
436 : txn_deferrable: bool,
437 : }
438 :
439 : impl HttpHeaders {
440 0 : fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
441 0 : // Determine the output options. Default behaviour is 'false'. Anything that is not
442 0 : // strictly 'true' assumed to be false.
443 0 : let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
444 0 : let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
445 :
446 : // isolation level, read only and deferrable
447 0 : let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
448 0 : Some(x) => Some(
449 0 : map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
450 : ),
451 0 : None => None,
452 : };
453 :
454 0 : let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
455 0 : let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
456 0 :
457 0 : Ok(Self {
458 0 : raw_output,
459 0 : default_array_mode,
460 0 : txn_isolation_level,
461 0 : txn_read_only,
462 0 : txn_deferrable,
463 0 : })
464 0 : }
465 : }
466 :
467 0 : fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
468 0 : match level.as_bytes() {
469 0 : b"Serializable" => Some(IsolationLevel::Serializable),
470 0 : b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
471 0 : b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
472 0 : b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
473 0 : _ => None,
474 : }
475 0 : }
476 :
477 0 : fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
478 0 : match level {
479 0 : IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
480 0 : IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
481 0 : IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
482 0 : IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
483 0 : _ => None,
484 : }
485 0 : }
486 :
487 0 : async fn handle_inner(
488 0 : cancel: CancellationToken,
489 0 : config: &'static ProxyConfig,
490 0 : ctx: &mut RequestMonitoring,
491 0 : request: Request<Body>,
492 0 : backend: Arc<PoolingBackend>,
493 0 : ) -> Result<Response<Body>, SqlOverHttpError> {
494 0 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
495 0 : .with_label_values(&[ctx.protocol])
496 0 : .guard();
497 0 : info!("handling interactive connection from client");
498 :
499 : //
500 : // Determine the destination and connection params
501 : //
502 0 : let headers = request.headers();
503 :
504 : // TLS config should be there.
505 0 : let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
506 0 : info!(user = conn_info.user_info.user.as_str(), "credentials");
507 :
508 : // Allow connection pooling only if explicitly requested
509 : // or if we have decided that http pool is no longer opt-in
510 0 : let allow_pool = !config.http_config.pool_options.opt_in
511 0 : || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
512 :
513 0 : let parsed_headers = HttpHeaders::try_parse(headers)?;
514 :
515 0 : let request_content_length = match request.body().size_hint().upper() {
516 0 : Some(v) => v,
517 0 : None => MAX_REQUEST_SIZE + 1,
518 : };
519 0 : info!(request_content_length, "request size in bytes");
520 0 : HTTP_CONTENT_LENGTH
521 0 : .with_label_values(&["request"])
522 0 : .observe(request_content_length as f64);
523 0 :
524 0 : // we don't have a streaming request support yet so this is to prevent OOM
525 0 : // from a malicious user sending an extremely large request body
526 0 : if request_content_length > MAX_REQUEST_SIZE {
527 0 : return Err(SqlOverHttpError::RequestTooLarge);
528 0 : }
529 0 :
530 0 : let fetch_and_process_request = async {
531 0 : let body = hyper::body::to_bytes(request.into_body()).await?;
532 0 : info!(length = body.len(), "request payload read");
533 0 : let payload: Payload = serde_json::from_slice(&body)?;
534 0 : Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
535 0 : }
536 0 : .map_err(SqlOverHttpError::from);
537 0 :
538 0 : let authenticate_and_connect = async {
539 0 : let keys = backend.authenticate(ctx, &conn_info).await?;
540 0 : let client = backend
541 0 : .connect_to_compute(ctx, conn_info, keys, !allow_pool)
542 0 : .await?;
543 : // not strictly necessary to mark success here,
544 : // but it's just insurance for if we forget it somewhere else
545 0 : ctx.latency_timer.success();
546 0 : Ok::<_, HttpConnError>(client)
547 0 : }
548 0 : .map_err(SqlOverHttpError::from);
549 :
550 0 : let (payload, mut client) = match run_until_cancelled(
551 0 : // Run both operations in parallel
552 0 : try_join(
553 0 : pin!(fetch_and_process_request),
554 0 : pin!(authenticate_and_connect),
555 0 : ),
556 0 : &cancel,
557 0 : )
558 0 : .await
559 : {
560 0 : Some(result) => result?,
561 0 : None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
562 : };
563 :
564 0 : let mut response = Response::builder()
565 0 : .status(StatusCode::OK)
566 0 : .header(header::CONTENT_TYPE, "application/json");
567 :
568 : //
569 : // Now execute the query and return the result
570 : //
571 0 : let result = match payload {
572 0 : Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
573 0 : Payload::Batch(statements) => {
574 0 : if parsed_headers.txn_read_only {
575 0 : response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
576 0 : }
577 0 : if parsed_headers.txn_deferrable {
578 0 : response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
579 0 : }
580 0 : if let Some(txn_isolation_level) = parsed_headers
581 0 : .txn_isolation_level
582 0 : .and_then(map_isolation_level_to_headers)
583 0 : {
584 0 : response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
585 0 : }
586 :
587 0 : statements
588 0 : .process(cancel, &mut client, parsed_headers)
589 0 : .await?
590 : }
591 : };
592 :
593 0 : let metrics = client.metrics();
594 0 :
595 0 : // how could this possibly fail
596 0 : let body = serde_json::to_string(&result).expect("json serialization should not fail");
597 0 : let len = body.len();
598 0 : let response = response
599 0 : .body(Body::from(body))
600 0 : // only fails if invalid status code or invalid header/values are given.
601 0 : // these are not user configurable so it cannot fail dynamically
602 0 : .expect("building response payload should not fail");
603 0 :
604 0 : // count the egress bytes - we miss the TLS and header overhead but oh well...
605 0 : // moving this later in the stack is going to be a lot of effort and ehhhh
606 0 : metrics.record_egress(len as u64);
607 0 : HTTP_CONTENT_LENGTH
608 0 : .with_label_values(&["response"])
609 0 : .observe(len as f64);
610 0 :
611 0 : Ok(response)
612 0 : }
613 :
614 : impl QueryData {
615 0 : async fn process(
616 0 : self,
617 0 : cancel: CancellationToken,
618 0 : client: &mut Client<tokio_postgres::Client>,
619 0 : parsed_headers: HttpHeaders,
620 0 : ) -> Result<Value, SqlOverHttpError> {
621 0 : let (inner, mut discard) = client.inner();
622 0 : let cancel_token = inner.cancel_token();
623 :
624 0 : let res = match select(
625 0 : pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
626 0 : pin!(cancel.cancelled()),
627 0 : )
628 0 : .await
629 : {
630 : // The query successfully completed.
631 0 : Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
632 0 : discard.check_idle(status);
633 0 : Ok(results)
634 : }
635 : // The query failed with an error
636 0 : Either::Left((Err(e), __not_yet_cancelled)) => {
637 0 : discard.discard();
638 0 : return Err(e);
639 : }
640 : // The query was cancelled.
641 0 : Either::Right((_cancelled, query)) => {
642 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
643 0 : tracing::error!(?err, "could not cancel query");
644 0 : }
645 : // wait for the query cancellation
646 0 : match time::timeout(time::Duration::from_millis(100), query).await {
647 : // query successed before it was cancelled.
648 0 : Ok(Ok((status, results))) => {
649 0 : discard.check_idle(status);
650 0 : Ok(results)
651 : }
652 : // query failed or was cancelled.
653 0 : Ok(Err(error)) => {
654 0 : let db_error = match &error {
655 0 : SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
656 0 : | SqlOverHttpError::Postgres(e) => e.as_db_error(),
657 0 : _ => None,
658 : };
659 :
660 : // if errored for some other reason, it might not be safe to return
661 0 : if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
662 0 : discard.discard();
663 0 : }
664 :
665 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
666 : }
667 0 : Err(_timeout) => {
668 0 : discard.discard();
669 0 : Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
670 : }
671 : }
672 : }
673 : };
674 0 : res
675 0 : }
676 : }
677 :
678 : impl BatchQueryData {
679 0 : async fn process(
680 0 : self,
681 0 : cancel: CancellationToken,
682 0 : client: &mut Client<tokio_postgres::Client>,
683 0 : parsed_headers: HttpHeaders,
684 0 : ) -> Result<Value, SqlOverHttpError> {
685 0 : info!("starting transaction");
686 0 : let (inner, mut discard) = client.inner();
687 0 : let cancel_token = inner.cancel_token();
688 0 : let mut builder = inner.build_transaction();
689 0 : if let Some(isolation_level) = parsed_headers.txn_isolation_level {
690 0 : builder = builder.isolation_level(isolation_level);
691 0 : }
692 0 : if parsed_headers.txn_read_only {
693 0 : builder = builder.read_only(true);
694 0 : }
695 0 : if parsed_headers.txn_deferrable {
696 0 : builder = builder.deferrable(true);
697 0 : }
698 :
699 0 : let transaction = builder.start().await.map_err(|e| {
700 0 : // if we cannot start a transaction, we should return immediately
701 0 : // and not return to the pool. connection is clearly broken
702 0 : discard.discard();
703 0 : e
704 0 : })?;
705 :
706 0 : let results =
707 0 : match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
708 0 : Ok(results) => {
709 0 : info!("commit");
710 0 : let status = transaction.commit().await.map_err(|e| {
711 0 : // if we cannot commit - for now don't return connection to pool
712 0 : // TODO: get a query status from the error
713 0 : discard.discard();
714 0 : e
715 0 : })?;
716 0 : discard.check_idle(status);
717 0 : results
718 : }
719 : Err(SqlOverHttpError::Cancelled(_)) => {
720 0 : if let Err(err) = cancel_token.cancel_query(NoTls).await {
721 0 : tracing::error!(?err, "could not cancel query");
722 0 : }
723 : // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
724 0 : discard.discard();
725 0 :
726 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
727 : }
728 0 : Err(err) => {
729 0 : info!("rollback");
730 0 : let status = transaction.rollback().await.map_err(|e| {
731 0 : // if we cannot rollback - for now don't return connection to pool
732 0 : // TODO: get a query status from the error
733 0 : discard.discard();
734 0 : e
735 0 : })?;
736 0 : discard.check_idle(status);
737 0 : return Err(err);
738 : }
739 : };
740 :
741 0 : Ok(json!({ "results": results }))
742 0 : }
743 : }
744 :
745 0 : async fn query_batch(
746 0 : cancel: CancellationToken,
747 0 : transaction: &Transaction<'_>,
748 0 : queries: BatchQueryData,
749 0 : parsed_headers: HttpHeaders,
750 0 : ) -> Result<Vec<Value>, SqlOverHttpError> {
751 0 : let mut results = Vec::with_capacity(queries.queries.len());
752 0 : let mut current_size = 0;
753 0 : for stmt in queries.queries {
754 0 : let query = pin!(query_to_json(
755 0 : transaction,
756 0 : stmt,
757 0 : &mut current_size,
758 0 : parsed_headers,
759 0 : ));
760 0 : let cancelled = pin!(cancel.cancelled());
761 0 : let res = select(query, cancelled).await;
762 0 : match res {
763 : // TODO: maybe we should check that the transaction bit is set here
764 0 : Either::Left((Ok((_, values)), _cancelled)) => {
765 0 : results.push(values);
766 0 : }
767 0 : Either::Left((Err(e), _cancelled)) => {
768 0 : return Err(e);
769 : }
770 0 : Either::Right((_cancelled, _)) => {
771 0 : return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
772 : }
773 : }
774 : }
775 0 : Ok(results)
776 0 : }
777 :
778 0 : async fn query_to_json<T: GenericClient>(
779 0 : client: &T,
780 0 : data: QueryData,
781 0 : current_size: &mut usize,
782 0 : parsed_headers: HttpHeaders,
783 0 : ) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
784 0 : info!("executing query");
785 0 : let query_params = data.params;
786 0 : let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
787 0 : info!("finished executing query");
788 :
789 : // Manually drain the stream into a vector to leave row_stream hanging
790 : // around to get a command tag. Also check that the response is not too
791 : // big.
792 0 : let mut rows: Vec<tokio_postgres::Row> = Vec::new();
793 0 : while let Some(row) = row_stream.next().await {
794 0 : let row = row?;
795 0 : *current_size += row.body_len();
796 0 : rows.push(row);
797 0 : // we don't have a streaming response support yet so this is to prevent OOM
798 0 : // from a malicious query (eg a cross join)
799 0 : if *current_size > MAX_RESPONSE_SIZE {
800 0 : return Err(SqlOverHttpError::ResponseTooLarge);
801 0 : }
802 : }
803 :
804 0 : let ready = row_stream.ready_status();
805 0 :
806 0 : // grab the command tag and number of rows affected
807 0 : let command_tag = row_stream.command_tag().unwrap_or_default();
808 0 : let mut command_tag_split = command_tag.split(' ');
809 0 : let command_tag_name = command_tag_split.next().unwrap_or_default();
810 0 : let command_tag_count = if command_tag_name == "INSERT" {
811 : // INSERT returns OID first and then number of rows
812 0 : command_tag_split.nth(1)
813 : } else {
814 : // other commands return number of rows (if any)
815 0 : command_tag_split.next()
816 : }
817 0 : .and_then(|s| s.parse::<i64>().ok());
818 0 :
819 0 : info!(
820 0 : rows = rows.len(),
821 0 : ?ready,
822 0 : command_tag,
823 0 : "finished reading rows"
824 0 : );
825 :
826 0 : let mut fields = vec![];
827 0 : let mut columns = vec![];
828 :
829 0 : for c in row_stream.columns() {
830 0 : fields.push(json!({
831 0 : "name": Value::String(c.name().to_owned()),
832 0 : "dataTypeID": Value::Number(c.type_().oid().into()),
833 0 : "tableID": c.table_oid(),
834 0 : "columnID": c.column_id(),
835 0 : "dataTypeSize": c.type_size(),
836 0 : "dataTypeModifier": c.type_modifier(),
837 0 : "format": "text",
838 0 : }));
839 0 : columns.push(client.get_type(c.type_oid()).await?);
840 : }
841 :
842 0 : let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
843 :
844 : // convert rows to JSON
845 0 : let rows = rows
846 0 : .iter()
847 0 : .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
848 0 : .collect::<Result<Vec<_>, _>>()?;
849 :
850 : // resulting JSON format is based on the format of node-postgres result
851 0 : Ok((
852 0 : ready,
853 0 : json!({
854 0 : "command": command_tag_name,
855 0 : "rowCount": command_tag_count,
856 0 : "rows": rows,
857 0 : "fields": fields,
858 0 : "rowAsArray": array_mode,
859 0 : }),
860 0 : ))
861 0 : }
|