//! Production console backend.

use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
use bytes::Bytes;
use futures::TryFutureExt;
use hyper::StatusCode;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};

use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
    ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse};
use crate::control_plane::{
    AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
    RoleAccessControl,
};
use crate::metrics::Metrics;
use crate::proxy::retry::CouldRetry;
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, http, scram};

pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");

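/// HTTP client for the production console (control-plane) API.
///
/// Holds the API endpoint and auth token, plus shared caches, per-endpoint
/// locks, and a wake-compute rate limiter used to reduce and bound calls to
/// the control plane.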
#[derive(Clone)]
pub struct NeonControlPlaneClient {
    endpoint: http::Endpoint,
    pub caches: &'static ApiCaches,
    pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
    pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    // Kept in a shared ref so we don't copy the secret all over memory.
    jwt: Arc<str>,
}

impl NeonControlPlaneClient {
    /// Construct an API object containing the auth parameters.
    pub fn new(
        endpoint: http::Endpoint,
        jwt: Arc<str>,
        caches: &'static ApiCaches,
        locks: &'static ApiLocks<EndpointCacheKey>,
        wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    ) -> Self {
        Self {
            endpoint,
            caches,
            locks,
            wake_compute_endpoint_rate_limiter,
            jwt,
        }
    }

    pub(crate) fn url(&self) -> &str {
        self.endpoint.url().as_str()
    }

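    /// Fetch access-control info for an endpoint/role pair from the control
    /// plane's `get_endpoint_access_control` API: the role's SCRAM secret,
    /// allowed IPs and VPC endpoints, access-blocker flags, and rate limits.
    ///
    /// A 404 response is treated as "no secret" and yields `AuthInfo::default()`.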
    async fn do_get_auth_req(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<AuthInfo, GetAuthInfoError> {
        async {
            let response = {
                let request = self
                    .endpoint
                    .get_path("get_endpoint_access_control")
                    .header(X_REQUEST_ID, ctx.session_id().to_string())
                    .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                    .query(&[("session_id", ctx.session_id())])
                    .query(&[
                        ("application_name", ctx.console_application_name().as_str()),
                        ("endpointish", endpoint.as_str()),
                        ("role", role.as_str()),
                    ])
                    .build()?;

                debug!(url = request.url().as_str(), "sending http request");
                let start = Instant::now();
                let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
                let response = self.endpoint.execute(request).await?;

                info!(duration = ?start.elapsed(), "received http response");

                response
            };

            let body = match parse_body::<GetEndpointAccessControl>(
                response.status(),
                response.bytes().await?,
            ) {
                Ok(body) => body,
                // Error 404 is special: it's ok not to have a secret.
                // TODO(anna): retry
                Err(e) => {
                    return if e.get_reason().is_not_found() {
                        // TODO: refactor this because it's weird
                        // this is a failure to authenticate but we return Ok.
                        Ok(AuthInfo::default())
                    } else {
                        Err(e.into())
                    };
                }
            };

            let secret = if body.role_secret.is_empty() {
                None
            } else {
                let secret = scram::ServerSecret::parse(&body.role_secret)
                    .map(AuthSecret::Scram)
                    .ok_or(GetAuthInfoError::BadSecret)?;
                Some(secret)
            };
            let allowed_ips = body.allowed_ips.unwrap_or_default();
            Metrics::get()
                .proxy
                .allowed_ips_number
                .observe(allowed_ips.len() as f64);
            let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
            Metrics::get()
                .proxy
                .allowed_vpc_endpoint_ids
                .observe(allowed_vpc_endpoint_ids.len() as f64);
            let block_public_connections = body.block_public_connections.unwrap_or_default();
            let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
            Ok(AuthInfo {
                secret,
                allowed_ips,
                allowed_vpc_endpoint_ids,
                project_id: body.project_id,
                account_id: body.account_id,
                access_blocker_flags: AccessBlockerFlags {
                    public_access_blocked: block_public_connections,
                    vpc_access_blocked: block_vpc_connections,
                },
                rate_limits: body.rate_limits,
            })
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_get_auth_info"))
        .await
    }

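    /// Fetch the JWKS auth rules configured for an endpoint by appending
    /// `endpoints/{endpoint_id}/jwks` to the control-plane URL.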
    async fn do_get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        let request_id = ctx.session_id().to_string();
        async {
            let request = self
                .endpoint
                .get_with_url(|url| {
                    url.path_segments_mut()
                        .push("endpoints")
                        .push(endpoint.as_str())
                        .push("jwks");
                })
                .header(X_REQUEST_ID, &request_id)
                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .build()
                .map_err(GetEndpointJwksError::RequestBuild)?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self
                .endpoint
                .execute(request)
                .await
                .map_err(GetEndpointJwksError::RequestExecute)?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");

            let body = parse_body::<EndpointJwksResponse>(
                response.status(),
                response.bytes().await.map_err(ControlPlaneError::from)?,
            )?;

            let rules = body
                .jwks
                .into_iter()
                .map(|jwks| AuthRule {
                    id: jwks.id,
                    jwks_url: jwks.jwks_url,
                    audience: jwks.jwt_audience,
                    role_names: jwks.role_names,
                })
                .collect();

            Ok(rules)
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_get_endpoint_jwks"))
        .await
    }

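    /// Ask the control plane to wake the compute node for this endpoint
    /// (`wake_compute`) and parse the returned connection info into a
    /// `NodeInfo` (host, port, and TLS mode).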
    async fn do_wake_compute(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<NodeInfo, WakeComputeError> {
        let request_id = ctx.session_id().to_string();
        let application_name = ctx.console_application_name();
        async {
            let mut request_builder = self
                .endpoint
                .get_path("wake_compute")
                .header("X-Request-ID", &request_id)
                .header("Authorization", format!("Bearer {}", &self.jwt))
                .query(&[("session_id", ctx.session_id())])
                .query(&[
                    ("application_name", application_name.as_str()),
                    ("endpointish", user_info.endpoint.as_str()),
                ]);

            let options = user_info.options.to_deep_object();
            if !options.is_empty() {
                request_builder = request_builder.query(&options);
            }

            let request = request_builder.build()?;

            debug!(url = request.url().as_str(), "sending http request");
            let start = Instant::now();
            let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
            let response = self.endpoint.execute(request).await?;
            drop(pause);
            info!(duration = ?start.elapsed(), "received http response");
            let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;

            let Some((host, port)) = parse_host_port(&body.address) else {
                return Err(WakeComputeError::BadComputeAddress(body.address));
            };

            let host_addr = IpAddr::from_str(host).ok();

            let ssl_mode = match &body.server_name {
                Some(_) => SslMode::Require,
                None => SslMode::Disable,
            };
            let host = match body.server_name {
                Some(host) => host.into(),
                None => host.into(),
            };

            let node = NodeInfo {
                conn_info: compute::ConnectInfo {
                    host_addr,
                    host,
                    port,
                    ssl_mode,
                },
                aux: body.aux,
            };

            Ok(node)
        }
        .inspect_err(|e| tracing::debug!(error = ?e))
        .instrument(info_span!("do_wake_compute"))
        .await
    }
}

impl super::ControlPlaneApi for NeonControlPlaneClient {
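    /// Return the role-level access control (the SCRAM secret), served from
    /// the project-info cache when possible; cache misses go through
    /// `do_get_auth_req` and populate the cache.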
    #[tracing::instrument(skip_all)]
    async fn get_role_access_control(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
        let normalized_ep = &endpoint.normalize();
        if let Some(secret) = self
            .caches
            .project_info
            .get_role_secret(normalized_ep, role)
        {
            return Ok(secret);
        }

        let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;

        let control = EndpointAccessControl {
            allowed_ips: Arc::new(auth_info.allowed_ips),
            allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
            flags: auth_info.access_blocker_flags,
            rate_limits: auth_info.rate_limits,
        };
        let role_control = RoleAccessControl {
            secret: auth_info.secret,
        };

        if let Some(project_id) = auth_info.project_id {
            let normalized_ep_int = normalized_ep.into();

            self.caches.project_info.insert_endpoint_access(
                auth_info.account_id,
                project_id,
                normalized_ep_int,
                role.into(),
                control,
                role_control.clone(),
            );
            ctx.set_project_id(project_id);
        }

        Ok(role_control)
    }

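    /// Return the endpoint-level access control (allowed IPs, VPC endpoints,
    /// blocker flags, rate limits), using the project-info cache when possible
    /// and populating it on a miss.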
    #[tracing::instrument(skip_all)]
    async fn get_endpoint_access_control(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
        role: &RoleName,
    ) -> Result<EndpointAccessControl, GetAuthInfoError> {
        let normalized_ep = &endpoint.normalize();
        if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
            return Ok(control);
        }

        let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;

        let control = EndpointAccessControl {
            allowed_ips: Arc::new(auth_info.allowed_ips),
            allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
            flags: auth_info.access_blocker_flags,
            rate_limits: auth_info.rate_limits,
        };
        let role_control = RoleAccessControl {
            secret: auth_info.secret,
        };

        if let Some(project_id) = auth_info.project_id {
            let normalized_ep_int = normalized_ep.into();

            self.caches.project_info.insert_endpoint_access(
                auth_info.account_id,
                project_id,
                normalized_ep_int,
                role.into(),
                control.clone(),
                role_control,
            );
            ctx.set_project_id(project_id);
        }

        Ok(control)
    }

    #[tracing::instrument(skip_all)]
    async fn get_endpoint_jwks(
        &self,
        ctx: &RequestContext,
        endpoint: &EndpointId,
    ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
        self.do_get_endpoint_jwks(ctx, endpoint).await
    }

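    /// Wake the compute node for this endpoint.
    ///
    /// Successful results and control-plane errors are cached (errors with a
    /// TTL derived from their retry delay); a per-endpoint permit and a rate
    /// limiter bound concurrent wake requests to the control plane.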
    #[tracing::instrument(skip_all)]
    async fn wake_compute(
        &self,
        ctx: &RequestContext,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        let key = user_info.endpoint_cache_key();

        macro_rules! check_cache {
            () => {
                if let Some(cached) = self.caches.node_info.get_with_created_at(&key) {
                    let (cached, (info, created_at)) = cached.take_value();
                    return match info {
                        Err(mut msg) => {
                            info!(key = &*key, "found cached wake_compute error");

                            // If retry_delay_ms is set, reduce it by the time
                            // the entry has already spent in the cache.
                            if let Some(status) = &mut msg.status {
                                if let Some(retry_info) = &mut status.details.retry_info {
                                    retry_info.retry_delay_ms = retry_info
                                        .retry_delay_ms
                                        .saturating_sub(created_at.elapsed().as_millis() as u64)
                                }
                            }

                            Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                                msg,
                            )))
                        }
                        Ok(info) => {
                            debug!(key = &*key, "found cached compute node info");
                            ctx.set_project(info.aux.clone());
                            Ok(cached.map(|()| info))
                        }
                    };
                }
            };
        }

        // Every time we make a wake-compute HTTP request, the compute node stays
        // up for some time (depending on the console's scale-to-zero policy).
        // The connection info stays the same during that window, so we can cache
        // it to reduce load and latency.
        check_cache!();

        let permit = self.locks.get_permit(&key).await?;

        // While we were waiting for the permit, another request may have filled
        // the cache, so double-check it.
        if permit.should_check_cache() {
            // TODO: if there is something in the cache, mark the permit as success.
            check_cache!();
        }

        // check rate limit
        if !self
            .wake_compute_endpoint_rate_limiter
            .check(user_info.endpoint.normalize_intern(), 1)
        {
            return Err(WakeComputeError::TooManyConnections);
        }

        let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
        match node {
            Ok(node) => {
                ctx.set_project(node.aux.clone());
                debug!(key = &*key, "created a cache entry for woken compute node");

                let mut stored_node = node.clone();
                // store the cached node as 'warm_cached'
                stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;

                let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));

                Ok(cached.map(|()| node))
            }
            Err(err) => match err {
                WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => {
                    let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);

                    // If we can retry this error, do not cache it,
                    // unless we were given a retry delay.
                    if msg.could_retry() && retry_info.is_none() {
                        return Err(err);
                    }

                    debug!(
                        key = &*key,
                        "created a cache entry for the wake compute error"
                    );

                    let ttl = retry_info.map_or(Duration::from_secs(30), |r| {
                        Duration::from_millis(r.retry_delay_ms)
                    });

                    self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl);

                    Err(err)
                }
                err => Err(err),
            },
        }
    }
}

/// Parse an HTTP response body, taking the status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    status: StatusCode,
    body: Bytes,
) -> Result<T, ControlPlaneError> {
    if status.is_success() {
        // We shouldn't log the raw body because it may contain secrets.
        info!("request succeeded, processing the body");
        return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
    }

    // Log the plaintext body so we can detect cases not covered by the error struct.
    info!("response_error plaintext: {:?}", body);

    // Don't throw an error here because it's not as important
    // as the fact that the request itself has failed.
    let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
        warn!("failed to parse error body: {e}");
        Box::new(ControlPlaneErrorMessage {
            error: "reason unclear (malformed error message)".into(),
            http_status_code: status,
            status: None,
        })
    });
    body.http_status_code = status;

    warn!("console responded with an error ({status}): {body:?}");
    Err(ControlPlaneError::Message(body))
}

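/// Split a `host:port` string into its parts, trimming square brackets from an
/// IPv6 host. Returns `None` if there is no `:` separator or the port is not a
/// valid `u16`.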
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let ipv6_brackets: &[_] = &['[', ']'];
    Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }
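
    // A small addition to the original suite: a sketch of the rejection paths,
    // based on the `Option`-returning contract above (missing separator,
    // non-numeric port).
    #[test]
    fn test_parse_host_port_invalid() {
        assert!(parse_host_port("127.0.0.1").is_none());
        assert!(parse_host_port("localhost:not-a-port").is_none());
    }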
}