Line data Source code
//! Helpers copied from `libs/utils/src/http` and adapted for hyper 1.0 compatibility.
//! They should be merged back upstream at some point in the future.
3 :
4 : use anyhow::Context;
5 : use bytes::Bytes;
6 : use http::header::AUTHORIZATION;
7 : use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
8 : use http_body_util::combinators::BoxBody;
9 : use http_body_util::{BodyExt, Full};
10 : use http_utils::error::ApiError;
11 : use serde::Serialize;
12 : use url::Url;
13 : use uuid::Uuid;
14 :
15 : use super::conn_pool::{AuthData, ConnInfoWithAuth};
16 : use super::conn_pool_lib::ConnInfo;
17 : use super::error::{ConnInfoError, Credentials};
18 : use crate::auth::backend::ComputeUserInfo;
19 : use crate::config::AuthenticationConfig;
20 : use crate::context::RequestContext;
21 : use crate::metrics::{Metrics, SniGroup, SniKind};
22 : use crate::pqproto::StartupMessageParams;
23 : use crate::proxy::NeonOptions;
24 : use crate::types::{DbName, EndpointId, RoleName};
25 :
26 : // Common header names used across serverless modules
27 : pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
28 : pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
29 : pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
30 : pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
31 : pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
32 : pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
33 : HeaderName::from_static("neon-batch-isolation-level");
34 : pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
35 : pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
36 :
37 0 : pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
38 0 : let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
39 0 : HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
40 0 : .expect("uuid hyphenated format should be all valid header characters")
41 0 : }
42 :
43 : /// Like [`ApiError::into_response`]
44 0 : pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
45 0 : match this {
46 0 : ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
47 0 : format!("{err:#?}"), // use debug printing so that we give the cause
48 : StatusCode::BAD_REQUEST,
49 : ),
50 : ApiError::Forbidden(_) => {
51 0 : HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
52 : }
53 : ApiError::Unauthorized(_) => {
54 0 : HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
55 : }
56 : ApiError::NotFound(_) => {
57 0 : HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
58 : }
59 : ApiError::Conflict(_) => {
60 0 : HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
61 : }
62 0 : ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
63 0 : this.to_string(),
64 : StatusCode::PRECONDITION_FAILED,
65 : ),
66 0 : ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
67 0 : "Shutting down".to_string(),
68 : StatusCode::SERVICE_UNAVAILABLE,
69 : ),
70 0 : ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
71 0 : err.to_string(),
72 : StatusCode::SERVICE_UNAVAILABLE,
73 : ),
74 0 : ApiError::TooManyRequests(err) => HttpErrorBody::response_from_msg_and_status(
75 0 : err.to_string(),
76 : StatusCode::TOO_MANY_REQUESTS,
77 : ),
78 0 : ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
79 0 : err.to_string(),
80 : StatusCode::REQUEST_TIMEOUT,
81 : ),
82 0 : ApiError::Cancelled => HttpErrorBody::response_from_msg_and_status(
83 0 : this.to_string(),
84 : StatusCode::INTERNAL_SERVER_ERROR,
85 : ),
86 0 : ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
87 0 : err.to_string(),
88 : StatusCode::INTERNAL_SERVER_ERROR,
89 : ),
90 : }
91 0 : }
92 :
93 : /// Same as [`http_utils::error::HttpErrorBody`]
94 : #[derive(Serialize)]
95 : struct HttpErrorBody {
96 : pub(crate) msg: String,
97 : }
98 :
99 : impl HttpErrorBody {
100 : /// Same as [`http_utils::error::HttpErrorBody::response_from_msg_and_status`]
101 0 : fn response_from_msg_and_status(
102 0 : msg: String,
103 0 : status: StatusCode,
104 0 : ) -> Response<BoxBody<Bytes, hyper::Error>> {
105 0 : HttpErrorBody { msg }.to_response(status)
106 0 : }
107 :
108 : /// Same as [`http_utils::error::HttpErrorBody::to_response`]
109 0 : fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
110 0 : Response::builder()
111 0 : .status(status)
112 0 : .header(http::header::CONTENT_TYPE, "application/json")
113 : // we do not have nested maps with non string keys so serialization shouldn't fail
114 0 : .body(
115 0 : Full::new(Bytes::from(
116 0 : serde_json::to_string(self)
117 0 : .expect("serialising HttpErrorBody should never fail"),
118 : ))
119 0 : .map_err(|x| match x {})
120 0 : .boxed(),
121 : )
122 0 : .expect("content-type header should be valid")
123 0 : }
124 : }
125 :
126 : /// Same as [`http_utils::json::json_response`]
127 0 : pub(crate) fn json_response<T: Serialize>(
128 0 : status: StatusCode,
129 0 : data: T,
130 0 : ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
131 0 : let json = serde_json::to_string(&data)
132 0 : .context("Failed to serialize JSON response")
133 0 : .map_err(ApiError::InternalServerError)?;
134 0 : let response = Response::builder()
135 0 : .status(status)
136 0 : .header(http::header::CONTENT_TYPE, "application/json")
137 0 : .body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
138 0 : .map_err(|e| ApiError::InternalServerError(e.into()))?;
139 0 : Ok(response)
140 0 : }
141 :
142 0 : pub(crate) fn get_conn_info(
143 0 : config: &'static AuthenticationConfig,
144 0 : ctx: &RequestContext,
145 0 : connection_string: Option<&str>,
146 0 : headers: &HeaderMap,
147 0 : ) -> Result<ConnInfoWithAuth, ConnInfoError> {
148 0 : let connection_url = match connection_string {
149 0 : Some(connection_string) => Url::parse(connection_string)?,
150 : None => {
151 0 : let connection_string = headers
152 0 : .get(&CONN_STRING)
153 0 : .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
154 0 : .to_str()
155 0 : .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
156 0 : Url::parse(connection_string)?
157 : }
158 : };
159 :
160 0 : let protocol = connection_url.scheme();
161 0 : if protocol != "postgres" && protocol != "postgresql" {
162 0 : return Err(ConnInfoError::IncorrectScheme);
163 0 : }
164 :
165 0 : let mut url_path = connection_url
166 0 : .path_segments()
167 0 : .ok_or(ConnInfoError::MissingDbName)?;
168 :
169 0 : let dbname: DbName =
170 0 : urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
171 0 : ctx.set_dbname(dbname.clone());
172 :
173 0 : let username = RoleName::from(urlencoding::decode(connection_url.username())?);
174 0 : if username.is_empty() {
175 0 : return Err(ConnInfoError::MissingUsername);
176 0 : }
177 0 : ctx.set_user(username.clone());
178 : // TODO: make sure this is right in the context of rest broker
179 0 : let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
180 0 : if !config.accept_jwts {
181 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
182 0 : }
183 :
184 0 : let auth = auth
185 0 : .to_str()
186 0 : .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
187 : AuthData::Jwt(
188 0 : auth.strip_prefix("Bearer ")
189 0 : .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
190 0 : .into(),
191 : )
192 0 : } else if let Some(pass) = connection_url.password() {
193 : // wrong credentials provided
194 0 : if config.accept_jwts {
195 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
196 0 : }
197 :
198 0 : AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
199 0 : std::borrow::Cow::Borrowed(b) => b.into(),
200 0 : std::borrow::Cow::Owned(b) => b.into(),
201 : })
202 0 : } else if config.accept_jwts {
203 0 : return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
204 : } else {
205 0 : return Err(ConnInfoError::MissingCredentials(Credentials::Password));
206 : };
207 0 : let endpoint: EndpointId = match connection_url.host() {
208 0 : Some(url::Host::Domain(hostname)) => hostname
209 0 : .split_once('.')
210 0 : .map_or(hostname, |(prefix, _)| prefix)
211 0 : .into(),
212 : Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
213 0 : return Err(ConnInfoError::MissingHostname);
214 : }
215 : };
216 0 : ctx.set_endpoint_id(endpoint.clone());
217 :
218 0 : let pairs = connection_url.query_pairs();
219 :
220 0 : let mut options = Option::None;
221 :
222 0 : let mut params = StartupMessageParams::default();
223 0 : params.insert("user", &username);
224 0 : params.insert("database", &dbname);
225 0 : for (key, value) in pairs {
226 0 : params.insert(&key, &value);
227 0 : if key == "options" {
228 0 : options = Some(NeonOptions::parse_options_raw(&value));
229 0 : }
230 : }
231 :
232 : // check the URL that was used, for metrics
233 : {
234 0 : let host_endpoint = headers
235 : // get the host header
236 0 : .get("host")
237 : // extract the domain
238 0 : .and_then(|h| {
239 0 : let (host, _port) = h.to_str().ok()?.split_once(':')?;
240 0 : Some(host)
241 0 : })
242 : // get the endpoint prefix
243 0 : .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
244 :
245 0 : let kind = if host_endpoint == Some(&*endpoint) {
246 0 : SniKind::Sni
247 : } else {
248 0 : SniKind::NoSni
249 : };
250 :
251 0 : let protocol = ctx.protocol();
252 0 : Metrics::get()
253 0 : .proxy
254 0 : .accepted_connections_by_sni
255 0 : .inc(SniGroup { protocol, kind });
256 : }
257 :
258 0 : ctx.set_user_agent(
259 0 : headers
260 0 : .get(hyper::header::USER_AGENT)
261 0 : .and_then(|h| h.to_str().ok())
262 0 : .map(Into::into),
263 : );
264 :
265 0 : let user_info = ComputeUserInfo {
266 0 : endpoint,
267 0 : user: username,
268 0 : options: options.unwrap_or_default(),
269 0 : };
270 :
271 0 : let conn_info = ConnInfo { user_info, dbname };
272 0 : Ok(ConnInfoWithAuth { conn_info, auth })
273 0 : }
|