//! Production console backend.

use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;

use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
use bytes::Bytes;
use futures::TryFutureExt;
use hyper::StatusCode;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};

use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
    ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse};
use crate::control_plane::{
    AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::metrics::Metrics;
use crate::proxy::retry::CouldRetry;
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, http, scram};

pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");

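/// Client for the Neon control plane ("console") HTTP API, with shared
/// caches, wake-compute locks, and a per-endpoint wake rate limiter.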
#[derive(Clone)]
pub struct NeonControlPlaneClient {
    endpoint: http::Endpoint,
    pub caches: &'static ApiCaches,
    pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
    pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    // put in a shared ref so we don't copy secrets all over in memory
    jwt: Arc<str>,
}

impl NeonControlPlaneClient {
    /// Construct an API object containing the auth parameters.
    pub fn new(
        endpoint: http::Endpoint,
        jwt: Arc<str>,
        caches: &'static ApiCaches,
        locks: &'static ApiLocks<EndpointCacheKey>,
        wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    ) -> Self {
        Self {
            endpoint,
            caches,
            locks,
            wake_compute_endpoint_rate_limiter,
            jwt,
        }
    }

    pub(crate) fn url(&self) -> &str {
        self.endpoint.url().as_str()
    }

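    /// Fetch access control info from the control plane, cache the result
    /// (successes as well as non-retryable errors), and return the part
    /// selected by `extract`.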
    async fn get_and_cache_auth_info<T>(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
        cache_key: &EndpointId,
        extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T,
    ) -> Result<T, GetAuthInfoError> {
        match self.do_get_auth_req(ctx, endpoint, role).await {
            Ok(auth_info) => {
                let control = EndpointAccessControl {
                    allowed_ips: Arc::new(auth_info.allowed_ips),
                    allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
                    flags: auth_info.access_blocker_flags,
                    rate_limits: auth_info.rate_limits,
                };
                let role_control = RoleAccessControl {
                    secret: auth_info.secret,
                };
                let res = extract(&control, &role_control);

                self.caches.project_info.insert_endpoint_access(
                    auth_info.account_id,
                    auth_info.project_id,
                    cache_key.into(),
                    role.into(),
                    control,
                    role_control,
                );

                if let Some(project_id) = auth_info.project_id {
                    ctx.set_project_id(project_id);
                }

                Ok(res)
            }
            Err(err) => match err {
                GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => {
                    let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);

                    // If we can retry this error, do not cache it,
                    // unless we were given a retry delay.
                    if msg.could_retry() && retry_info.is_none() {
                        return Err(err);
                    }

                    self.caches.project_info.insert_endpoint_access_err(
                        cache_key.into(),
                        role.into(),
                        msg.clone(),
                    );

                    Err(err)
                }
                err => Err(err),
            },
        }
    }

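    /// Perform the `get_endpoint_access_control` HTTP request and parse the
    /// response into [`AuthInfo`]. A 404 response yields a default (empty) `AuthInfo`.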
    async fn do_get_auth_req(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<AuthInfo, GetAuthInfoError> {
        async {
            let response = {
                let request = self
                    .endpoint
                    .get_path("get_endpoint_access_control")
                    .header(X_REQUEST_ID, ctx.session_id().to_string())
                    .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                    .query(&[("session_id", ctx.session_id())])
                    .query(&[
                        ("application_name", ctx.console_application_name().as_str()),
                        ("endpointish", endpoint.as_str()),
                        ("role", role.as_str()),
                    ])
                    .build()?;

                debug!(url = request.url().as_str(), "sending http request");
                let start = Instant::now();
                let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
                let response = self.endpoint.execute(request).await?;

                info!(duration = ?start.elapsed(), "received http response");

                response
            };

            let body = match parse_body::<GetEndpointAccessControl>(
                response.status(),
                response.bytes().await?,
            ) {
                Ok(body) => body,
                // Error 404 is special: it's ok not to have a secret.
                // TODO(anna): retry
                Err(e) => {
                    return if e.get_reason().is_not_found() {
                        // TODO: refactor this because it's weird
                        // this is a failure to authenticate but we return Ok.
                        Ok(AuthInfo::default())
                    } else {
                        Err(e.into())
                    };
                }
            };

            let secret = if body.role_secret.is_empty() {
                None
            } else {
                let secret = scram::ServerSecret::parse(&body.role_secret)
                    .map(AuthSecret::Scram)
                    .ok_or(GetAuthInfoError::BadSecret)?;
                Some(secret)
            };
            let allowed_ips = body.allowed_ips.unwrap_or_default();
            Metrics::get()
                .proxy
                .allowed_ips_number
                .observe(allowed_ips.len() as f64);
            let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
            Metrics::get()
                .proxy
                .allowed_vpc_endpoint_ids
                .observe(allowed_vpc_endpoint_ids.len() as f64);
            let block_public_connections = body.block_public_connections.unwrap_or_default();
            let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
            Ok(AuthInfo {
                secret,
                allowed_ips,
                allowed_vpc_endpoint_ids,
                project_id: body.project_id,
                account_id: body.account_id,
                access_blocker_flags: AccessBlockerFlags {
                    public_access_blocked: block_public_connections,
                    vpc_access_blocked: block_vpc_connections,
                },
                rate_limits: body.rate_limits,
            })
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_get_auth_info"))
        .await
    }

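    /// Fetch the JWKS auth rules configured for the given endpoint.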
    async fn do_get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        let request_id = ctx.session_id().to_string();
        async {
            let request = self
                .endpoint
                .get_with_url(|url| {
                    url.path_segments_mut()
                        .push("endpoints")
                        .push(endpoint.as_str())
                        .push("jwks");
                })
                .header(X_REQUEST_ID, &request_id)
                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .build()
                .map_err(GetEndpointJwksError::RequestBuild)?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self
                .endpoint
                .execute(request)
                .await
                .map_err(GetEndpointJwksError::RequestExecute)?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");

            let body = parse_body::<EndpointJwksResponse>(
                response.status(),
                response.bytes().await.map_err(ControlPlaneError::from)?,
            )?;

            let rules = body
                .jwks
                .into_iter()
                .map(|jwks| AuthRule {
                    id: jwks.id,
                    jwks_url: jwks.jwks_url,
                    audience: jwks.jwt_audience,
                    role_names: jwks.role_names,
                })
                .collect();

            Ok(rules)
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_get_endpoint_jwks"))
        .await
    }

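    /// Ask the control plane to wake the endpoint's compute node and return
    /// its connection info.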
    async fn do_wake_compute(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<NodeInfo, WakeComputeError> {
        let request_id = ctx.session_id().to_string();
        let application_name = ctx.console_application_name();
        async {
            let mut request_builder = self
                .endpoint
                .get_path("wake_compute")
                .header(X_REQUEST_ID, &request_id)
                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .query(&[
                    ("application_name", application_name.as_str()),
                    ("endpointish", user_info.endpoint.as_str()),
                ]);

            let options = user_info.options.to_deep_object();
            if !options.is_empty() {
                request_builder = request_builder.query(&options);
            }

            let request = request_builder.build()?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self.endpoint.execute(request).await?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");
            let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;

            let Some((host, port)) = parse_host_port(&body.address) else {
                return Err(WakeComputeError::BadComputeAddress(body.address));
            };

            let host_addr = IpAddr::from_str(host).ok();

            let ssl_mode = match &body.server_name {
                Some(_) => SslMode::Require,
                None => SslMode::Disable,
            };
            let host = match body.server_name {
                Some(host) => host.into(),
                None => host.into(),
            };

            let node = NodeInfo {
                conn_info: compute::ConnectInfo {
                    host_addr,
                    host,
                    port,
                    ssl_mode,
                },
                aux: body.aux,
            };

            Ok(node)
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_wake_compute"))
        .await
    }
}

impl super::ControlPlaneApi for NeonControlPlaneClient {
    #[tracing::instrument(skip_all)]
    async fn get_role_access_control(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<RoleAccessControl, GetAuthInfoError> {
        let key = endpoint.normalize();

        if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) {
            return match role_control {
                Err(msg) => {
                    info!(key = &*key, "found cached get_role_access_control error");

                    Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
                }
                Ok(role_control) => {
                    debug!(key = &*key, "found cached role access control");
                    Ok(role_control)
                }
            };
        }

        self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| {
            role_control.clone()
        })
        .await
    }

    #[tracing::instrument(skip_all)]
    async fn get_endpoint_access_control(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<EndpointAccessControl, GetAuthInfoError> {
        let key = endpoint.normalize();

        if let Some(control) = self.caches.project_info.get_endpoint_access(&key) {
            return match control {
                Err(msg) => {
                    info!(
                        key = &*key,
                        "found cached get_endpoint_access_control error"
                    );

                    Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
                }
                Ok(control) => {
                    debug!(key = &*key, "found cached endpoint access control");
                    Ok(control)
                }
            };
        }

        self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone())
            .await
    }

    #[tracing::instrument(skip_all)]
    async fn get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        self.do_get_endpoint_jwks(ctx, endpoint).await
    }

    #[tracing::instrument(skip_all)]
    async fn wake_compute(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        let key = user_info.endpoint_cache_key();

        macro_rules! check_cache {
            () => {
                if let Some(info) = self.caches.node_info.get_entry(&key) {
                    return match info {
                        Err(msg) => {
                            info!(key = &*key, "found cached wake_compute error");

                            Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                                msg,
                            )))
                        }
                        Ok(info) => {
                            debug!(key = &*key, "found cached compute node info");
                            ctx.set_project(info.aux.clone());
                            Ok(info)
                        }
                    };
                }
            };
        }

        // Every time we issue a wake_compute HTTP request, the compute node stays
        // up for some time (how long depends on the console's scale-to-zero policy).
        // The connection info remains the same during that period, so we can
        // cache it to reduce load and latency.
        check_cache!();

        let permit = self.locks.get_permit(&key).await?;

        // After getting a permit, the cache may have been filled by another
        // request, so double-check it.
        if permit.should_check_cache() {
            // TODO: if there is something in the cache, mark the permit as success.
            check_cache!();
        }

        // check rate limit
        if !self
            .wake_compute_endpoint_rate_limiter
            .check(user_info.endpoint.normalize_intern(), 1)
        {
            return Err(WakeComputeError::TooManyConnections);
        }

        let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
        match node {
            Ok(node) => {
                ctx.set_project(node.aux.clone());
                debug!(key = &*key, "created a cache entry for woken compute node");

                let mut stored_node = node.clone();
                // store the cached node as 'warm_cached'
                stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
                self.caches.node_info.insert(key.clone(), Ok(stored_node));

                Ok(Cached {
                    token: Some((&self.caches.node_info, key)),
                    value: node,
                })
            }
            Err(err) => match err {
                WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => {
                    let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);

                    // If we can retry this error, do not cache it,
                    // unless we were given a retry delay.
                    if msg.could_retry() && retry_info.is_none() {
                        return Err(err);
                    }

                    debug!(
                        key = &*key,
                        "created a cache entry for the wake compute error"
                    );

                    self.caches.node_info.insert(key, Err(msg.clone()));

                    Err(err)
                }
                err => Err(err),
            },
        }
    }
}

/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    status: StatusCode,
    body: Bytes,
) -> Result<T, ControlPlaneError> {
    if status.is_success() {
        // We shouldn't log the raw body because it may contain secrets.
        info!("request succeeded, processing the body");
        return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
    }

    // Log the plaintext body so we can detect cases not covered by the error struct.
    info!("response_error plaintext: {:?}", body);

    // Don't throw an error here because it's not as important
    // as the fact that the request itself has failed.
    let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
        warn!("failed to parse error body: {e}");
        Box::new(ControlPlaneErrorMessage {
            error: "reason unclear (malformed error message)".into(),
            http_status_code: status,
            status: None,
        })
    });
    body.http_status_code = status;

    warn!("console responded with an error ({status}): {body:?}");
    Err(ControlPlaneError::Message(body))
}

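/// Split a `host:port` string, trimming square brackets from an IPv6 host.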
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let ipv6_brackets: &[_] = &['[', ']'];
    Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }
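
    // Hypothetical extra coverage (not in the original tests): inputs without a
    // parseable port should yield `None`.
    #[test]
    fn test_parse_host_port_invalid() {
        assert!(parse_host_port("127.0.0.1").is_none());
        assert!(parse_host_port("localhost:not-a-port").is_none());
    }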
}