Line data Source code
1 : //! Production console backend.
2 :
3 : use std::net::IpAddr;
4 : use std::str::FromStr;
5 : use std::sync::Arc;
6 : use std::time::Duration;
7 :
8 : use ::http::HeaderName;
9 : use ::http::header::AUTHORIZATION;
10 : use bytes::Bytes;
11 : use futures::TryFutureExt;
12 : use hyper::StatusCode;
13 : use postgres_client::config::SslMode;
14 : use tokio::time::Instant;
15 : use tracing::{Instrument, debug, info, info_span, warn};
16 :
17 : use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
18 : use crate::auth::backend::ComputeUserInfo;
19 : use crate::auth::backend::jwt::AuthRule;
20 : use crate::context::RequestContext;
21 : use crate::control_plane::caches::ApiCaches;
22 : use crate::control_plane::errors::{
23 : ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
24 : };
25 : use crate::control_plane::locks::ApiLocks;
26 : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
27 : use crate::control_plane::{
28 : AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
29 : RoleAccessControl,
30 : };
31 : use crate::metrics::Metrics;
32 : use crate::rate_limiter::WakeComputeRateLimiter;
33 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
34 : use crate::{compute, http, scram};
35 :
/// Name of the `x-request-id` HTTP header.
///
/// The control plane requests built in this file set it to the session id
/// of the current request context.
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
37 :
/// Client for the production control plane ("console") HTTP API.
///
/// Cloning is cheap with respect to the shared state: the caches and locks
/// are `'static` references and the rate limiter and token are behind `Arc`.
#[derive(Clone)]
pub struct NeonControlPlaneClient {
    /// HTTP endpoint used to build and execute every control plane request.
    endpoint: http::Endpoint,
    /// Shared caches consulted before (and filled after) control plane calls.
    pub caches: &'static ApiCaches,
    /// Per-cache-key locks; a permit is acquired before a `wake_compute` request.
    pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
    /// Rate limiter checked per endpoint before a `wake_compute` request is sent.
    pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
    // put in a shared ref so we don't copy secrets all over in memory
    /// Token sent as `Authorization: Bearer <jwt>` on control plane requests.
    jwt: Arc<str>,
}
47 :
48 : impl NeonControlPlaneClient {
49 : /// Construct an API object containing the auth parameters.
50 0 : pub fn new(
51 0 : endpoint: http::Endpoint,
52 0 : jwt: Arc<str>,
53 0 : caches: &'static ApiCaches,
54 0 : locks: &'static ApiLocks<EndpointCacheKey>,
55 0 : wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
56 0 : ) -> Self {
57 0 : Self {
58 0 : endpoint,
59 0 : caches,
60 0 : locks,
61 0 : wake_compute_endpoint_rate_limiter,
62 0 : jwt,
63 0 : }
64 0 : }
65 :
66 0 : pub(crate) fn url(&self) -> &str {
67 0 : self.endpoint.url().as_str()
68 0 : }
69 :
70 0 : async fn do_get_auth_req(
71 0 : &self,
72 0 : ctx: &RequestContext,
73 0 : endpoint: &EndpointId,
74 0 : role: &RoleName,
75 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
76 0 : async {
77 0 : let response = {
78 0 : let request = self
79 0 : .endpoint
80 0 : .get_path("get_endpoint_access_control")
81 0 : .header(X_REQUEST_ID, ctx.session_id().to_string())
82 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
83 0 : .query(&[("session_id", ctx.session_id())])
84 0 : .query(&[
85 0 : ("application_name", ctx.console_application_name().as_str()),
86 0 : ("endpointish", endpoint.as_str()),
87 0 : ("role", role.as_str()),
88 0 : ])
89 0 : .build()?;
90 :
91 0 : debug!(url = request.url().as_str(), "sending http request");
92 0 : let start = Instant::now();
93 0 : let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
94 0 : let response = self.endpoint.execute(request).await?;
95 :
96 0 : info!(duration = ?start.elapsed(), "received http response");
97 :
98 0 : response
99 : };
100 :
101 0 : let body = match parse_body::<GetEndpointAccessControl>(
102 0 : response.status(),
103 0 : response.bytes().await?,
104 : ) {
105 0 : Ok(body) => body,
106 : // Error 404 is special: it's ok not to have a secret.
107 : // TODO(anna): retry
108 0 : Err(e) => {
109 0 : return if e.get_reason().is_not_found() {
110 : // TODO: refactor this because it's weird
111 : // this is a failure to authenticate but we return Ok.
112 0 : Ok(AuthInfo::default())
113 : } else {
114 0 : Err(e.into())
115 : };
116 : }
117 : };
118 :
119 0 : let secret = if body.role_secret.is_empty() {
120 0 : None
121 : } else {
122 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
123 0 : .map(AuthSecret::Scram)
124 0 : .ok_or(GetAuthInfoError::BadSecret)?;
125 0 : Some(secret)
126 : };
127 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
128 0 : Metrics::get()
129 0 : .proxy
130 0 : .allowed_ips_number
131 0 : .observe(allowed_ips.len() as f64);
132 0 : let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
133 0 : Metrics::get()
134 0 : .proxy
135 0 : .allowed_vpc_endpoint_ids
136 0 : .observe(allowed_vpc_endpoint_ids.len() as f64);
137 0 : let block_public_connections = body.block_public_connections.unwrap_or_default();
138 0 : let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
139 0 : Ok(AuthInfo {
140 0 : secret,
141 0 : allowed_ips,
142 0 : allowed_vpc_endpoint_ids,
143 0 : project_id: body.project_id,
144 0 : account_id: body.account_id,
145 0 : access_blocker_flags: AccessBlockerFlags {
146 0 : public_access_blocked: block_public_connections,
147 0 : vpc_access_blocked: block_vpc_connections,
148 0 : },
149 0 : rate_limits: body.rate_limits,
150 0 : })
151 0 : }
152 0 : .inspect_err(|e| tracing::debug!(error = ?e))
153 0 : .instrument(info_span!("do_get_auth_info"))
154 0 : .await
155 0 : }
156 :
157 0 : async fn do_get_endpoint_jwks(
158 0 : &self,
159 0 : ctx: &RequestContext,
160 0 : endpoint: &EndpointId,
161 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
162 0 : if !self
163 0 : .caches
164 0 : .endpoints_cache
165 0 : .is_valid(ctx, &endpoint.normalize())
166 : {
167 0 : return Err(GetEndpointJwksError::EndpointNotFound);
168 0 : }
169 0 : let request_id = ctx.session_id().to_string();
170 0 : async {
171 0 : let request = self
172 0 : .endpoint
173 0 : .get_with_url(|url| {
174 0 : url.path_segments_mut()
175 0 : .push("endpoints")
176 0 : .push(endpoint.as_str())
177 0 : .push("jwks");
178 0 : })
179 0 : .header(X_REQUEST_ID, &request_id)
180 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
181 0 : .query(&[("session_id", ctx.session_id())])
182 0 : .build()
183 0 : .map_err(GetEndpointJwksError::RequestBuild)?;
184 :
185 0 : debug!(url = request.url().as_str(), "sending http request");
186 0 : let start = Instant::now();
187 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
188 0 : let response = self
189 0 : .endpoint
190 0 : .execute(request)
191 0 : .await
192 0 : .map_err(GetEndpointJwksError::RequestExecute)?;
193 0 : drop(pause);
194 0 : info!(duration = ?start.elapsed(), "received http response");
195 :
196 0 : let body = parse_body::<EndpointJwksResponse>(
197 0 : response.status(),
198 0 : response.bytes().await.map_err(ControlPlaneError::from)?,
199 0 : )?;
200 :
201 0 : let rules = body
202 0 : .jwks
203 0 : .into_iter()
204 0 : .map(|jwks| AuthRule {
205 0 : id: jwks.id,
206 0 : jwks_url: jwks.jwks_url,
207 0 : audience: jwks.jwt_audience,
208 0 : role_names: jwks.role_names,
209 0 : })
210 0 : .collect();
211 0 :
212 0 : Ok(rules)
213 0 : }
214 0 : .inspect_err(|e| tracing::debug!(error = ?e))
215 0 : .instrument(info_span!("do_get_endpoint_jwks"))
216 0 : .await
217 0 : }
218 :
219 0 : async fn do_wake_compute(
220 0 : &self,
221 0 : ctx: &RequestContext,
222 0 : user_info: &ComputeUserInfo,
223 0 : ) -> Result<NodeInfo, WakeComputeError> {
224 0 : let request_id = ctx.session_id().to_string();
225 0 : let application_name = ctx.console_application_name();
226 0 : async {
227 0 : let mut request_builder = self
228 0 : .endpoint
229 0 : .get_path("wake_compute")
230 0 : .header("X-Request-ID", &request_id)
231 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
232 0 : .query(&[("session_id", ctx.session_id())])
233 0 : .query(&[
234 0 : ("application_name", application_name.as_str()),
235 0 : ("endpointish", user_info.endpoint.as_str()),
236 0 : ]);
237 0 :
238 0 : let options = user_info.options.to_deep_object();
239 0 : if !options.is_empty() {
240 0 : request_builder = request_builder.query(&options);
241 0 : }
242 :
243 0 : let request = request_builder.build()?;
244 :
245 0 : debug!(url = request.url().as_str(), "sending http request");
246 0 : let start = Instant::now();
247 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
248 0 : let response = self.endpoint.execute(request).await?;
249 0 : drop(pause);
250 0 : info!(duration = ?start.elapsed(), "received http response");
251 0 : let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
252 :
253 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
254 0 : let (host, port) = match parse_host_port(&body.address) {
255 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
256 0 : Some(x) => x,
257 0 : };
258 0 :
259 0 : let host_addr = IpAddr::from_str(host).ok();
260 :
261 0 : let ssl_mode = match &body.server_name {
262 0 : Some(_) => SslMode::Require,
263 0 : None => SslMode::Disable,
264 : };
265 0 : let host = match body.server_name {
266 0 : Some(host) => host.into(),
267 0 : None => host.into(),
268 : };
269 :
270 0 : let node = NodeInfo {
271 0 : conn_info: compute::ConnectInfo {
272 0 : host_addr,
273 0 : host,
274 0 : port,
275 0 : ssl_mode,
276 0 : },
277 0 : aux: body.aux,
278 0 : };
279 0 :
280 0 : Ok(node)
281 0 : }
282 0 : .inspect_err(|e| tracing::debug!(error = ?e))
283 0 : .instrument(info_span!("do_wake_compute"))
284 0 : .await
285 0 : }
286 : }
287 :
288 : impl super::ControlPlaneApi for NeonControlPlaneClient {
289 : #[tracing::instrument(skip_all)]
290 : async fn get_role_access_control(
291 : &self,
292 : ctx: &RequestContext,
293 : endpoint: &EndpointId,
294 : role: &RoleName,
295 : ) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
296 : let normalized_ep = &endpoint.normalize();
297 : if let Some(secret) = self
298 : .caches
299 : .project_info
300 : .get_role_secret(normalized_ep, role)
301 : {
302 : return Ok(secret);
303 : }
304 :
305 : if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
306 : info!("endpoint is not valid, skipping the request");
307 : return Err(GetAuthInfoError::UnknownEndpoint);
308 : }
309 :
310 : let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
311 :
312 : let control = EndpointAccessControl {
313 : allowed_ips: Arc::new(auth_info.allowed_ips),
314 : allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
315 : flags: auth_info.access_blocker_flags,
316 : rate_limits: auth_info.rate_limits,
317 : };
318 : let role_control = RoleAccessControl {
319 : secret: auth_info.secret,
320 : };
321 :
322 : if let Some(project_id) = auth_info.project_id {
323 : let normalized_ep_int = normalized_ep.into();
324 :
325 : self.caches.project_info.insert_endpoint_access(
326 : auth_info.account_id,
327 : project_id,
328 : normalized_ep_int,
329 : role.into(),
330 : control,
331 : role_control.clone(),
332 : );
333 : ctx.set_project_id(project_id);
334 : }
335 :
336 : Ok(role_control)
337 : }
338 :
339 : #[tracing::instrument(skip_all)]
340 : async fn get_endpoint_access_control(
341 : &self,
342 : ctx: &RequestContext,
343 : endpoint: &EndpointId,
344 : role: &RoleName,
345 : ) -> Result<EndpointAccessControl, GetAuthInfoError> {
346 : let normalized_ep = &endpoint.normalize();
347 : if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
348 : return Ok(control);
349 : }
350 :
351 : if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
352 : info!("endpoint is not valid, skipping the request");
353 : return Err(GetAuthInfoError::UnknownEndpoint);
354 : }
355 :
356 : let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
357 :
358 : let control = EndpointAccessControl {
359 : allowed_ips: Arc::new(auth_info.allowed_ips),
360 : allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
361 : flags: auth_info.access_blocker_flags,
362 : rate_limits: auth_info.rate_limits,
363 : };
364 : let role_control = RoleAccessControl {
365 : secret: auth_info.secret,
366 : };
367 :
368 : if let Some(project_id) = auth_info.project_id {
369 : let normalized_ep_int = normalized_ep.into();
370 :
371 : self.caches.project_info.insert_endpoint_access(
372 : auth_info.account_id,
373 : project_id,
374 : normalized_ep_int,
375 : role.into(),
376 : control.clone(),
377 : role_control,
378 : );
379 : ctx.set_project_id(project_id);
380 : }
381 :
382 : Ok(control)
383 : }
384 :
385 : #[tracing::instrument(skip_all)]
386 : async fn get_endpoint_jwks(
387 : &self,
388 : ctx: &RequestContext,
389 : endpoint: &EndpointId,
390 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
391 : self.do_get_endpoint_jwks(ctx, endpoint).await
392 : }
393 :
394 : #[tracing::instrument(skip_all)]
395 : async fn wake_compute(
396 : &self,
397 : ctx: &RequestContext,
398 : user_info: &ComputeUserInfo,
399 : ) -> Result<CachedNodeInfo, WakeComputeError> {
400 : let key = user_info.endpoint_cache_key();
401 :
402 : macro_rules! check_cache {
403 : () => {
404 : if let Some(cached) = self.caches.node_info.get(&key) {
405 : let (cached, info) = cached.take_value();
406 0 : let info = info.map_err(|c| {
407 0 : info!(key = &*key, "found cached wake_compute error");
408 0 : WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
409 0 : })?;
410 :
411 : debug!(key = &*key, "found cached compute node info");
412 : ctx.set_project(info.aux.clone());
413 0 : return Ok(cached.map(|()| info));
414 : }
415 : };
416 : }
417 :
418 : // Every time we do a wakeup http request, the compute node will stay up
419 : // for some time (highly depends on the console's scale-to-zero policy);
420 : // The connection info remains the same during that period of time,
421 : // which means that we might cache it to reduce the load and latency.
422 : check_cache!();
423 :
424 : let permit = self.locks.get_permit(&key).await?;
425 :
426 : // after getting back a permit - it's possible the cache was filled
427 : // double check
428 : if permit.should_check_cache() {
429 : // TODO: if there is something in the cache, mark the permit as success.
430 : check_cache!();
431 : }
432 :
433 : // check rate limit
434 : if !self
435 : .wake_compute_endpoint_rate_limiter
436 : .check(user_info.endpoint.normalize_intern(), 1)
437 : {
438 : return Err(WakeComputeError::TooManyConnections);
439 : }
440 :
441 : let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
442 : match node {
443 : Ok(node) => {
444 : ctx.set_project(node.aux.clone());
445 : debug!(key = &*key, "created a cache entry for woken compute node");
446 :
447 : let mut stored_node = node.clone();
448 : // store the cached node as 'warm_cached'
449 : stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
450 :
451 : let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
452 :
453 0 : Ok(cached.map(|()| node))
454 : }
455 : Err(err) => match err {
456 : WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
457 : let Some(status) = &err.status else {
458 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
459 : err,
460 : )));
461 : };
462 :
463 : let reason = status
464 : .details
465 : .error_info
466 0 : .map_or(Reason::Unknown, |x| x.reason);
467 :
468 : // if we can retry this error, do not cache it.
469 : if reason.can_retry() {
470 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
471 : err,
472 : )));
473 : }
474 :
475 : // at this point, we should only have quota errors.
476 : debug!(
477 : key = &*key,
478 : "created a cache entry for the wake compute error"
479 : );
480 :
481 : self.caches.node_info.insert_ttl(
482 : key,
483 : Err(err.clone()),
484 : Duration::from_secs(30),
485 : );
486 :
487 : Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
488 : err,
489 : )))
490 : }
491 : err => return Err(err),
492 : },
493 : }
494 : }
495 : }
496 :
497 : /// Parse http response body, taking status code into account.
498 0 : fn parse_body<T: for<'a> serde::Deserialize<'a>>(
499 0 : status: StatusCode,
500 0 : body: Bytes,
501 0 : ) -> Result<T, ControlPlaneError> {
502 0 : if status.is_success() {
503 : // We shouldn't log raw body because it may contain secrets.
504 0 : info!("request succeeded, processing the body");
505 0 : return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
506 0 : }
507 0 :
508 0 : // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
509 0 : info!("response_error plaintext: {:?}", body);
510 :
511 : // Don't throw an error here because it's not as important
512 : // as the fact that the request itself has failed.
513 0 : let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
514 0 : warn!("failed to parse error body: {e}");
515 0 : Box::new(ControlPlaneErrorMessage {
516 0 : error: "reason unclear (malformed error message)".into(),
517 0 : http_status_code: status,
518 0 : status: None,
519 0 : })
520 0 : });
521 0 : body.http_status_code = status;
522 0 :
523 0 : warn!("console responded with an error ({status}): {body:?}");
524 0 : Err(ControlPlaneError::Message(body))
525 0 : }
526 :
/// Split a `host:port` string into its host and port.
///
/// The split happens at the last `:`, so a bracketed IPv6 literal such as
/// `[2001:db8::1]:5432` is supported; exactly one matching pair of square
/// brackets is removed from the host.
///
/// Returns `None` when:
/// * there is no `:` separator or the port is not a valid `u16`;
/// * the host is empty;
/// * the brackets are unbalanced, nested, or otherwise misplaced
///   (e.g. `[::1:5432`, `[[::1]]:5432`).
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port = port.parse().ok()?;

    // Strip one matched pair of brackets; an opening bracket without its
    // closing counterpart makes the address malformed.
    let host = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']')?,
        None => host,
    };

    // Brackets are never valid inside a host name or an IP literal.
    let has_stray_bracket = host.contains(|c: char| c == '[' || c == ']');
    if host.is_empty() || has_stray_bracket {
        return None;
    }

    Some((host, port))
}
532 :
#[cfg(test)]
mod tests {
    use super::*;

    /// An IPv4 literal is returned unchanged next to its port.
    #[test]
    fn test_parse_host_port_v4() {
        let parsed = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(parsed, ("127.0.0.1", 5432));
    }

    /// The brackets around an IPv6 literal are removed.
    #[test]
    fn test_parse_host_port_v6() {
        let parsed = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(parsed, ("2001:db8::1", 5432));
    }

    /// A DNS name is returned unchanged next to its port.
    #[test]
    fn test_parse_host_port_url() {
        let address = "compute-foo-bar-1234.default.svc.cluster.local:5432";
        let parsed = parse_host_port(address).expect("failed to parse");
        assert_eq!(
            parsed,
            ("compute-foo-bar-1234.default.svc.cluster.local", 5432)
        );
    }
}
|