Line data Source code
1 : //! Production console backend.
2 :
3 : use std::net::IpAddr;
4 : use std::str::FromStr;
5 : use std::sync::Arc;
6 : use std::time::Duration;
7 :
8 : use ::http::HeaderName;
9 : use ::http::header::AUTHORIZATION;
10 : use futures::TryFutureExt;
11 : use postgres_client::config::SslMode;
12 : use tokio::time::Instant;
13 : use tracing::{Instrument, debug, info, info_span, warn};
14 :
15 : use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
16 : use crate::auth::backend::ComputeUserInfo;
17 : use crate::auth::backend::jwt::AuthRule;
18 : use crate::cache::Cached;
19 : use crate::context::RequestContext;
20 : use crate::control_plane::caches::ApiCaches;
21 : use crate::control_plane::errors::{
22 : ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
23 : };
24 : use crate::control_plane::locks::ApiLocks;
25 : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
26 : use crate::control_plane::{
27 : AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
28 : CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
29 : };
30 : use crate::metrics::{CacheOutcome, Metrics};
31 : use crate::rate_limiter::WakeComputeRateLimiter;
32 : use crate::types::{EndpointCacheKey, EndpointId};
33 : use crate::{compute, http, scram};
34 :
/// Header name under which the proxy sends its session id (as a request id)
/// on control-plane requests; presumably used for cross-service log
/// correlation — TODO confirm with the control-plane side.
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
36 :
37 : #[derive(Clone)]
38 : pub struct NeonControlPlaneClient {
39 : endpoint: http::Endpoint,
40 : pub caches: &'static ApiCaches,
41 : pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
42 : pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
43 : // put in a shared ref so we don't copy secrets all over in memory
44 : jwt: Arc<str>,
45 : }
46 :
47 : impl NeonControlPlaneClient {
48 : /// Construct an API object containing the auth parameters.
49 0 : pub fn new(
50 0 : endpoint: http::Endpoint,
51 0 : jwt: Arc<str>,
52 0 : caches: &'static ApiCaches,
53 0 : locks: &'static ApiLocks<EndpointCacheKey>,
54 0 : wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
55 0 : ) -> Self {
56 0 : Self {
57 0 : endpoint,
58 0 : caches,
59 0 : locks,
60 0 : wake_compute_endpoint_rate_limiter,
61 0 : jwt,
62 0 : }
63 0 : }
64 :
65 0 : pub(crate) fn url(&self) -> &str {
66 0 : self.endpoint.url().as_str()
67 0 : }
68 :
69 0 : async fn do_get_auth_info(
70 0 : &self,
71 0 : ctx: &RequestContext,
72 0 : user_info: &ComputeUserInfo,
73 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
74 0 : if !self
75 0 : .caches
76 0 : .endpoints_cache
77 0 : .is_valid(ctx, &user_info.endpoint.normalize())
78 : {
79 : // TODO: refactor this because it's weird
80 : // this is a failure to authenticate but we return Ok.
81 0 : info!("endpoint is not valid, skipping the request");
82 0 : return Ok(AuthInfo::default());
83 0 : }
84 0 : self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
85 0 : .await
86 0 : }
87 :
88 0 : async fn do_get_auth_req(
89 0 : &self,
90 0 : user_info: &ComputeUserInfo,
91 0 : session_id: &uuid::Uuid,
92 0 : ctx: Option<&RequestContext>,
93 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
94 0 : let request_id: String = session_id.to_string();
95 0 : let application_name = if let Some(ctx) = ctx {
96 0 : ctx.console_application_name()
97 : } else {
98 0 : "auth_cancellation".to_string()
99 : };
100 :
101 0 : async {
102 0 : let request = self
103 0 : .endpoint
104 0 : .get_path("get_endpoint_access_control")
105 0 : .header(X_REQUEST_ID, &request_id)
106 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
107 0 : .query(&[("session_id", session_id)])
108 0 : .query(&[
109 0 : ("application_name", application_name.as_str()),
110 0 : ("endpointish", user_info.endpoint.as_str()),
111 0 : ("role", user_info.user.as_str()),
112 0 : ])
113 0 : .build()?;
114 :
115 0 : debug!(url = request.url().as_str(), "sending http request");
116 0 : let start = Instant::now();
117 0 : let response = match ctx {
118 0 : Some(ctx) => {
119 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
120 0 : let rsp = self.endpoint.execute(request).await;
121 0 : drop(pause);
122 0 : rsp?
123 : }
124 0 : None => self.endpoint.execute(request).await?,
125 : };
126 :
127 0 : info!(duration = ?start.elapsed(), "received http response");
128 0 : let body = match parse_body::<GetEndpointAccessControl>(response).await {
129 0 : Ok(body) => body,
130 : // Error 404 is special: it's ok not to have a secret.
131 : // TODO(anna): retry
132 0 : Err(e) => {
133 0 : return if e.get_reason().is_not_found() {
134 : // TODO: refactor this because it's weird
135 : // this is a failure to authenticate but we return Ok.
136 0 : Ok(AuthInfo::default())
137 : } else {
138 0 : Err(e.into())
139 : };
140 : }
141 : };
142 :
143 0 : let secret = if body.role_secret.is_empty() {
144 0 : None
145 : } else {
146 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
147 0 : .map(AuthSecret::Scram)
148 0 : .ok_or(GetAuthInfoError::BadSecret)?;
149 0 : Some(secret)
150 : };
151 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
152 0 : Metrics::get()
153 0 : .proxy
154 0 : .allowed_ips_number
155 0 : .observe(allowed_ips.len() as f64);
156 0 : let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
157 0 : Metrics::get()
158 0 : .proxy
159 0 : .allowed_vpc_endpoint_ids
160 0 : .observe(allowed_vpc_endpoint_ids.len() as f64);
161 0 : let block_public_connections = body.block_public_connections.unwrap_or_default();
162 0 : let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
163 0 : Ok(AuthInfo {
164 0 : secret,
165 0 : allowed_ips,
166 0 : allowed_vpc_endpoint_ids,
167 0 : project_id: body.project_id,
168 0 : account_id: body.account_id,
169 0 : access_blocker_flags: AccessBlockerFlags {
170 0 : public_access_blocked: block_public_connections,
171 0 : vpc_access_blocked: block_vpc_connections,
172 0 : },
173 0 : })
174 0 : }
175 0 : .inspect_err(|e| tracing::debug!(error = ?e))
176 0 : .instrument(info_span!("do_get_auth_info"))
177 0 : .await
178 0 : }
179 :
180 0 : async fn do_get_endpoint_jwks(
181 0 : &self,
182 0 : ctx: &RequestContext,
183 0 : endpoint: EndpointId,
184 0 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
185 0 : if !self
186 0 : .caches
187 0 : .endpoints_cache
188 0 : .is_valid(ctx, &endpoint.normalize())
189 : {
190 0 : return Err(GetEndpointJwksError::EndpointNotFound);
191 0 : }
192 0 : let request_id = ctx.session_id().to_string();
193 0 : async {
194 0 : let request = self
195 0 : .endpoint
196 0 : .get_with_url(|url| {
197 0 : url.path_segments_mut()
198 0 : .push("endpoints")
199 0 : .push(endpoint.as_str())
200 0 : .push("jwks");
201 0 : })
202 0 : .header(X_REQUEST_ID, &request_id)
203 0 : .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
204 0 : .query(&[("session_id", ctx.session_id())])
205 0 : .build()
206 0 : .map_err(GetEndpointJwksError::RequestBuild)?;
207 :
208 0 : debug!(url = request.url().as_str(), "sending http request");
209 0 : let start = Instant::now();
210 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
211 0 : let response = self
212 0 : .endpoint
213 0 : .execute(request)
214 0 : .await
215 0 : .map_err(GetEndpointJwksError::RequestExecute)?;
216 0 : drop(pause);
217 0 : info!(duration = ?start.elapsed(), "received http response");
218 :
219 0 : let body = parse_body::<EndpointJwksResponse>(response).await?;
220 :
221 0 : let rules = body
222 0 : .jwks
223 0 : .into_iter()
224 0 : .map(|jwks| AuthRule {
225 0 : id: jwks.id,
226 0 : jwks_url: jwks.jwks_url,
227 0 : audience: jwks.jwt_audience,
228 0 : role_names: jwks.role_names,
229 0 : })
230 0 : .collect();
231 0 :
232 0 : Ok(rules)
233 0 : }
234 0 : .inspect_err(|e| tracing::debug!(error = ?e))
235 0 : .instrument(info_span!("do_get_endpoint_jwks"))
236 0 : .await
237 0 : }
238 :
239 0 : async fn do_wake_compute(
240 0 : &self,
241 0 : ctx: &RequestContext,
242 0 : user_info: &ComputeUserInfo,
243 0 : ) -> Result<NodeInfo, WakeComputeError> {
244 0 : let request_id = ctx.session_id().to_string();
245 0 : let application_name = ctx.console_application_name();
246 0 : async {
247 0 : let mut request_builder = self
248 0 : .endpoint
249 0 : .get_path("wake_compute")
250 0 : .header("X-Request-ID", &request_id)
251 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
252 0 : .query(&[("session_id", ctx.session_id())])
253 0 : .query(&[
254 0 : ("application_name", application_name.as_str()),
255 0 : ("endpointish", user_info.endpoint.as_str()),
256 0 : ]);
257 0 :
258 0 : let options = user_info.options.to_deep_object();
259 0 : if !options.is_empty() {
260 0 : request_builder = request_builder.query(&options);
261 0 : }
262 :
263 0 : let request = request_builder.build()?;
264 :
265 0 : debug!(url = request.url().as_str(), "sending http request");
266 0 : let start = Instant::now();
267 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
268 0 : let response = self.endpoint.execute(request).await?;
269 0 : drop(pause);
270 0 : info!(duration = ?start.elapsed(), "received http response");
271 0 : let body = parse_body::<WakeCompute>(response).await?;
272 :
273 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
274 0 : let (host, port) = match parse_host_port(&body.address) {
275 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
276 0 : Some(x) => x,
277 0 : };
278 0 :
279 0 : let host_addr = IpAddr::from_str(host).ok();
280 :
281 0 : let ssl_mode = match &body.server_name {
282 0 : Some(_) => SslMode::Require,
283 0 : None => SslMode::Disable,
284 : };
285 0 : let host_name = match body.server_name {
286 0 : Some(host) => host,
287 0 : None => host.to_owned(),
288 : };
289 :
290 : // Don't set anything but host and port! This config will be cached.
291 : // We'll set username and such later using the startup message.
292 : // TODO: add more type safety (in progress).
293 0 : let mut config = compute::ConnCfg::new(host_name, port);
294 :
295 0 : if let Some(addr) = host_addr {
296 0 : config.set_host_addr(addr);
297 0 : }
298 :
299 0 : config.ssl_mode(ssl_mode);
300 0 :
301 0 : let node = NodeInfo {
302 0 : config,
303 0 : aux: body.aux,
304 0 : };
305 0 :
306 0 : Ok(node)
307 0 : }
308 0 : .inspect_err(|e| tracing::debug!(error = ?e))
309 0 : .instrument(info_span!("do_wake_compute"))
310 0 : .await
311 0 : }
312 : }
313 :
314 : impl super::ControlPlaneApi for NeonControlPlaneClient {
315 : #[tracing::instrument(skip_all)]
316 : async fn get_role_secret(
317 : &self,
318 : ctx: &RequestContext,
319 : user_info: &ComputeUserInfo,
320 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
321 : let normalized_ep = &user_info.endpoint.normalize();
322 : let user = &user_info.user;
323 : if let Some(role_secret) = self
324 : .caches
325 : .project_info
326 : .get_role_secret(normalized_ep, user)
327 : {
328 : return Ok(role_secret);
329 : }
330 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
331 : let account_id = auth_info.account_id;
332 : if let Some(project_id) = auth_info.project_id {
333 : let normalized_ep_int = normalized_ep.into();
334 : self.caches.project_info.insert_role_secret(
335 : project_id,
336 : normalized_ep_int,
337 : user.into(),
338 : auth_info.secret.clone(),
339 : );
340 : self.caches.project_info.insert_allowed_ips(
341 : project_id,
342 : normalized_ep_int,
343 : Arc::new(auth_info.allowed_ips),
344 : );
345 : self.caches.project_info.insert_allowed_vpc_endpoint_ids(
346 : account_id,
347 : project_id,
348 : normalized_ep_int,
349 : Arc::new(auth_info.allowed_vpc_endpoint_ids),
350 : );
351 : self.caches.project_info.insert_block_public_or_vpc_access(
352 : project_id,
353 : normalized_ep_int,
354 : auth_info.access_blocker_flags,
355 : );
356 : ctx.set_project_id(project_id);
357 : }
358 : // When we just got a secret, we don't need to invalidate it.
359 : Ok(Cached::new_uncached(auth_info.secret))
360 : }
361 :
362 0 : async fn get_allowed_ips(
363 0 : &self,
364 0 : ctx: &RequestContext,
365 0 : user_info: &ComputeUserInfo,
366 0 : ) -> Result<CachedAllowedIps, GetAuthInfoError> {
367 0 : let normalized_ep = &user_info.endpoint.normalize();
368 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
369 0 : Metrics::get()
370 0 : .proxy
371 0 : .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
372 0 : .inc(CacheOutcome::Hit);
373 0 : return Ok(allowed_ips);
374 0 : }
375 0 : Metrics::get()
376 0 : .proxy
377 0 : .allowed_ips_cache_misses
378 0 : .inc(CacheOutcome::Miss);
379 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
380 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
381 0 : let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
382 0 : let access_blocker_flags = auth_info.access_blocker_flags;
383 0 : let user = &user_info.user;
384 0 : let account_id = auth_info.account_id;
385 0 : if let Some(project_id) = auth_info.project_id {
386 0 : let normalized_ep_int = normalized_ep.into();
387 0 : self.caches.project_info.insert_role_secret(
388 0 : project_id,
389 0 : normalized_ep_int,
390 0 : user.into(),
391 0 : auth_info.secret.clone(),
392 0 : );
393 0 : self.caches.project_info.insert_allowed_ips(
394 0 : project_id,
395 0 : normalized_ep_int,
396 0 : allowed_ips.clone(),
397 0 : );
398 0 : self.caches.project_info.insert_allowed_vpc_endpoint_ids(
399 0 : account_id,
400 0 : project_id,
401 0 : normalized_ep_int,
402 0 : allowed_vpc_endpoint_ids.clone(),
403 0 : );
404 0 : self.caches.project_info.insert_block_public_or_vpc_access(
405 0 : project_id,
406 0 : normalized_ep_int,
407 0 : access_blocker_flags,
408 0 : );
409 0 : ctx.set_project_id(project_id);
410 0 : }
411 0 : Ok(Cached::new_uncached(allowed_ips))
412 0 : }
413 :
414 0 : async fn get_allowed_vpc_endpoint_ids(
415 0 : &self,
416 0 : ctx: &RequestContext,
417 0 : user_info: &ComputeUserInfo,
418 0 : ) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
419 0 : let normalized_ep = &user_info.endpoint.normalize();
420 0 : if let Some(allowed_vpc_endpoint_ids) = self
421 0 : .caches
422 0 : .project_info
423 0 : .get_allowed_vpc_endpoint_ids(normalized_ep)
424 : {
425 0 : Metrics::get()
426 0 : .proxy
427 0 : .vpc_endpoint_id_cache_stats
428 0 : .inc(CacheOutcome::Hit);
429 0 : return Ok(allowed_vpc_endpoint_ids);
430 0 : }
431 0 :
432 0 : Metrics::get()
433 0 : .proxy
434 0 : .vpc_endpoint_id_cache_stats
435 0 : .inc(CacheOutcome::Miss);
436 :
437 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
438 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
439 0 : let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
440 0 : let access_blocker_flags = auth_info.access_blocker_flags;
441 0 : let user = &user_info.user;
442 0 : let account_id = auth_info.account_id;
443 0 : if let Some(project_id) = auth_info.project_id {
444 0 : let normalized_ep_int = normalized_ep.into();
445 0 : self.caches.project_info.insert_role_secret(
446 0 : project_id,
447 0 : normalized_ep_int,
448 0 : user.into(),
449 0 : auth_info.secret.clone(),
450 0 : );
451 0 : self.caches.project_info.insert_allowed_ips(
452 0 : project_id,
453 0 : normalized_ep_int,
454 0 : allowed_ips.clone(),
455 0 : );
456 0 : self.caches.project_info.insert_allowed_vpc_endpoint_ids(
457 0 : account_id,
458 0 : project_id,
459 0 : normalized_ep_int,
460 0 : allowed_vpc_endpoint_ids.clone(),
461 0 : );
462 0 : self.caches.project_info.insert_block_public_or_vpc_access(
463 0 : project_id,
464 0 : normalized_ep_int,
465 0 : access_blocker_flags,
466 0 : );
467 0 : ctx.set_project_id(project_id);
468 0 : }
469 0 : Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
470 0 : }
471 :
472 0 : async fn get_block_public_or_vpc_access(
473 0 : &self,
474 0 : ctx: &RequestContext,
475 0 : user_info: &ComputeUserInfo,
476 0 : ) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
477 0 : let normalized_ep = &user_info.endpoint.normalize();
478 0 : if let Some(access_blocker_flags) = self
479 0 : .caches
480 0 : .project_info
481 0 : .get_block_public_or_vpc_access(normalized_ep)
482 : {
483 0 : Metrics::get()
484 0 : .proxy
485 0 : .access_blocker_flags_cache_stats
486 0 : .inc(CacheOutcome::Hit);
487 0 : return Ok(access_blocker_flags);
488 0 : }
489 0 :
490 0 : Metrics::get()
491 0 : .proxy
492 0 : .access_blocker_flags_cache_stats
493 0 : .inc(CacheOutcome::Miss);
494 :
495 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
496 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
497 0 : let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
498 0 : let access_blocker_flags = auth_info.access_blocker_flags;
499 0 : let user = &user_info.user;
500 0 : let account_id = auth_info.account_id;
501 0 : if let Some(project_id) = auth_info.project_id {
502 0 : let normalized_ep_int = normalized_ep.into();
503 0 : self.caches.project_info.insert_role_secret(
504 0 : project_id,
505 0 : normalized_ep_int,
506 0 : user.into(),
507 0 : auth_info.secret.clone(),
508 0 : );
509 0 : self.caches.project_info.insert_allowed_ips(
510 0 : project_id,
511 0 : normalized_ep_int,
512 0 : allowed_ips.clone(),
513 0 : );
514 0 : self.caches.project_info.insert_allowed_vpc_endpoint_ids(
515 0 : account_id,
516 0 : project_id,
517 0 : normalized_ep_int,
518 0 : allowed_vpc_endpoint_ids.clone(),
519 0 : );
520 0 : self.caches.project_info.insert_block_public_or_vpc_access(
521 0 : project_id,
522 0 : normalized_ep_int,
523 0 : access_blocker_flags.clone(),
524 0 : );
525 0 : ctx.set_project_id(project_id);
526 0 : }
527 0 : Ok(Cached::new_uncached(access_blocker_flags))
528 0 : }
529 :
530 : #[tracing::instrument(skip_all)]
531 : async fn get_endpoint_jwks(
532 : &self,
533 : ctx: &RequestContext,
534 : endpoint: EndpointId,
535 : ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
536 : self.do_get_endpoint_jwks(ctx, endpoint).await
537 : }
538 :
539 : #[tracing::instrument(skip_all)]
540 : async fn wake_compute(
541 : &self,
542 : ctx: &RequestContext,
543 : user_info: &ComputeUserInfo,
544 : ) -> Result<CachedNodeInfo, WakeComputeError> {
545 : let key = user_info.endpoint_cache_key();
546 :
547 : macro_rules! check_cache {
548 : () => {
549 : if let Some(cached) = self.caches.node_info.get(&key) {
550 : let (cached, info) = cached.take_value();
551 0 : let info = info.map_err(|c| {
552 0 : info!(key = &*key, "found cached wake_compute error");
553 0 : WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
554 0 : })?;
555 :
556 : debug!(key = &*key, "found cached compute node info");
557 : ctx.set_project(info.aux.clone());
558 0 : return Ok(cached.map(|()| info));
559 : }
560 : };
561 : }
562 :
563 : // Every time we do a wakeup http request, the compute node will stay up
564 : // for some time (highly depends on the console's scale-to-zero policy);
565 : // The connection info remains the same during that period of time,
566 : // which means that we might cache it to reduce the load and latency.
567 : check_cache!();
568 :
569 : let permit = self.locks.get_permit(&key).await?;
570 :
571 : // after getting back a permit - it's possible the cache was filled
572 : // double check
573 : if permit.should_check_cache() {
574 : // TODO: if there is something in the cache, mark the permit as success.
575 : check_cache!();
576 : }
577 :
578 : // check rate limit
579 : if !self
580 : .wake_compute_endpoint_rate_limiter
581 : .check(user_info.endpoint.normalize_intern(), 1)
582 : {
583 : return Err(WakeComputeError::TooManyConnections);
584 : }
585 :
586 : let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
587 : match node {
588 : Ok(node) => {
589 : ctx.set_project(node.aux.clone());
590 : debug!(key = &*key, "created a cache entry for woken compute node");
591 :
592 : let mut stored_node = node.clone();
593 : // store the cached node as 'warm_cached'
594 : stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
595 :
596 : let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
597 :
598 0 : Ok(cached.map(|()| node))
599 : }
600 : Err(err) => match err {
601 : WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
602 : let Some(status) = &err.status else {
603 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
604 : err,
605 : )));
606 : };
607 :
608 : let reason = status
609 : .details
610 : .error_info
611 0 : .map_or(Reason::Unknown, |x| x.reason);
612 :
613 : // if we can retry this error, do not cache it.
614 : if reason.can_retry() {
615 : return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
616 : err,
617 : )));
618 : }
619 :
620 : // at this point, we should only have quota errors.
621 : debug!(
622 : key = &*key,
623 : "created a cache entry for the wake compute error"
624 : );
625 :
626 : self.caches.node_info.insert_ttl(
627 : key,
628 : Err(err.clone()),
629 : Duration::from_secs(30),
630 : );
631 :
632 : Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
633 : err,
634 : )))
635 : }
636 : err => return Err(err),
637 : },
638 : }
639 : }
640 : }
641 :
642 : /// Parse http response body, taking status code into account.
643 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
644 0 : response: http::Response,
645 0 : ) -> Result<T, ControlPlaneError> {
646 0 : let status = response.status();
647 0 : if status.is_success() {
648 : // We shouldn't log raw body because it may contain secrets.
649 0 : info!("request succeeded, processing the body");
650 0 : return Ok(response.json().await?);
651 0 : }
652 0 : let s = response.bytes().await?;
653 : // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
654 0 : info!("response_error plaintext: {:?}", s);
655 :
656 : // Don't throw an error here because it's not as important
657 : // as the fact that the request itself has failed.
658 0 : let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
659 0 : warn!("failed to parse error body: {e}");
660 0 : ControlPlaneErrorMessage {
661 0 : error: "reason unclear (malformed error message)".into(),
662 0 : http_status_code: status,
663 0 : status: None,
664 0 : }
665 0 : });
666 0 : body.http_status_code = status;
667 0 :
668 0 : warn!("console responded with an error ({status}): {body:?}");
669 0 : Err(ControlPlaneError::Message(Box::new(body)))
670 0 : }
671 :
/// Splits a `host:port` address at its last `:`.
///
/// Any `[` / `]` characters at either end of the host are stripped, so that
/// bracketed IPv6 addresses such as `[::1]:5432` yield `::1`. Returns `None`
/// when there is no `:` or when the port is not a valid `u16`.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let sep = input.rfind(':')?;
    let port_text = &input[sep + 1..];
    let raw_host = &input[..sep];
    let host = raw_host.trim_matches(|c: char| c == '[' || c == ']');
    let port: u16 = port_text.parse().ok()?;
    Some((host, port))
}
677 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }

    /// An address without a port (or with an empty one) must be rejected.
    #[test]
    fn test_parse_host_port_missing_port() {
        assert!(parse_host_port("127.0.0.1").is_none());
        assert!(parse_host_port("127.0.0.1:").is_none());
    }

    /// Ports that are not numeric or do not fit in a `u16` must be rejected.
    #[test]
    fn test_parse_host_port_invalid_port() {
        assert!(parse_host_port("127.0.0.1:notaport").is_none());
        assert!(parse_host_port("127.0.0.1:65536").is_none());
    }
}
|