Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
7 : NodeInfo,
8 : };
9 : use crate::{
10 : auth::backend::ComputeUserInfo,
11 : compute,
12 : console::messages::{ColdStartInfo, Reason},
13 : http,
14 : metrics::{CacheOutcome, Metrics},
15 : rate_limiter::WakeComputeRateLimiter,
16 : scram, EndpointCacheKey,
17 : };
18 : use crate::{cache::Cached, context::RequestMonitoring};
19 : use futures::TryFutureExt;
20 : use std::{sync::Arc, time::Duration};
21 : use tokio::time::Instant;
22 : use tokio_postgres::config::SslMode;
23 : use tracing::{debug, error, info, info_span, warn, Instrument};
24 :
25 : pub struct Api {
26 : endpoint: http::Endpoint,
27 : pub caches: &'static ApiCaches,
28 : pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
29 : pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
30 : jwt: String,
31 : }
32 :
33 : impl Api {
34 : /// Construct an API object containing the auth parameters.
35 0 : pub fn new(
36 0 : endpoint: http::Endpoint,
37 0 : caches: &'static ApiCaches,
38 0 : locks: &'static ApiLocks<EndpointCacheKey>,
39 0 : wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
40 0 : ) -> Self {
41 0 : let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
42 0 : Self {
43 0 : endpoint,
44 0 : caches,
45 0 : locks,
46 0 : wake_compute_endpoint_rate_limiter,
47 0 : jwt,
48 0 : }
49 0 : }
50 :
51 0 : pub(crate) fn url(&self) -> &str {
52 0 : self.endpoint.url().as_str()
53 0 : }
54 :
55 0 : async fn do_get_auth_info(
56 0 : &self,
57 0 : ctx: &RequestMonitoring,
58 0 : user_info: &ComputeUserInfo,
59 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
60 0 : if !self
61 0 : .caches
62 0 : .endpoints_cache
63 0 : .is_valid(ctx, &user_info.endpoint.normalize())
64 0 : .await
65 : {
66 0 : info!("endpoint is not valid, skipping the request");
67 0 : return Ok(AuthInfo::default());
68 0 : }
69 0 : let request_id = ctx.session_id().to_string();
70 0 : let application_name = ctx.console_application_name();
71 0 : async {
72 0 : let request = self
73 0 : .endpoint
74 0 : .get("proxy_get_role_secret")
75 0 : .header("X-Request-ID", &request_id)
76 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
77 0 : .query(&[("session_id", ctx.session_id())])
78 0 : .query(&[
79 0 : ("application_name", application_name.as_str()),
80 0 : ("project", user_info.endpoint.as_str()),
81 0 : ("role", user_info.user.as_str()),
82 0 : ])
83 0 : .build()?;
84 :
85 0 : info!(url = request.url().as_str(), "sending http request");
86 0 : let start = Instant::now();
87 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
88 0 : let response = self.endpoint.execute(request).await?;
89 0 : drop(pause);
90 0 : info!(duration = ?start.elapsed(), "received http response");
91 0 : let body = match parse_body::<GetRoleSecret>(response).await {
92 0 : Ok(body) => body,
93 : // Error 404 is special: it's ok not to have a secret.
94 : // TODO(anna): retry
95 0 : Err(e) => {
96 0 : return if e.get_reason().is_not_found() {
97 0 : Ok(AuthInfo::default())
98 : } else {
99 0 : Err(e.into())
100 : }
101 : }
102 : };
103 :
104 0 : let secret = if body.role_secret.is_empty() {
105 0 : None
106 : } else {
107 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
108 0 : .map(AuthSecret::Scram)
109 0 : .ok_or(GetAuthInfoError::BadSecret)?;
110 0 : Some(secret)
111 : };
112 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
113 0 : Metrics::get()
114 0 : .proxy
115 0 : .allowed_ips_number
116 0 : .observe(allowed_ips.len() as f64);
117 0 : Ok(AuthInfo {
118 0 : secret,
119 0 : allowed_ips,
120 0 : project_id: body.project_id,
121 0 : })
122 0 : }
123 0 : .map_err(crate::error::log_error)
124 0 : .instrument(info_span!("http", id = request_id))
125 0 : .await
126 0 : }
127 :
128 0 : async fn do_wake_compute(
129 0 : &self,
130 0 : ctx: &RequestMonitoring,
131 0 : user_info: &ComputeUserInfo,
132 0 : ) -> Result<NodeInfo, WakeComputeError> {
133 0 : let request_id = ctx.session_id().to_string();
134 0 : let application_name = ctx.console_application_name();
135 0 : async {
136 0 : let mut request_builder = self
137 0 : .endpoint
138 0 : .get("proxy_wake_compute")
139 0 : .header("X-Request-ID", &request_id)
140 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
141 0 : .query(&[("session_id", ctx.session_id())])
142 0 : .query(&[
143 0 : ("application_name", application_name.as_str()),
144 0 : ("project", user_info.endpoint.as_str()),
145 0 : ]);
146 0 :
147 0 : let options = user_info.options.to_deep_object();
148 0 : if !options.is_empty() {
149 0 : request_builder = request_builder.query(&options);
150 0 : }
151 :
152 0 : let request = request_builder.build()?;
153 :
154 0 : info!(url = request.url().as_str(), "sending http request");
155 0 : let start = Instant::now();
156 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
157 0 : let response = self.endpoint.execute(request).await?;
158 0 : drop(pause);
159 0 : info!(duration = ?start.elapsed(), "received http response");
160 0 : let body = parse_body::<WakeCompute>(response).await?;
161 :
162 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
163 0 : let (host, port) = match parse_host_port(&body.address) {
164 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
165 0 : Some(x) => x,
166 0 : };
167 0 :
168 0 : // Don't set anything but host and port! This config will be cached.
169 0 : // We'll set username and such later using the startup message.
170 0 : // TODO: add more type safety (in progress).
171 0 : let mut config = compute::ConnCfg::new();
172 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
173 0 :
174 0 : let node = NodeInfo {
175 0 : config,
176 0 : aux: body.aux,
177 0 : allow_self_signed_compute: false,
178 0 : };
179 0 :
180 0 : Ok(node)
181 0 : }
182 0 : .map_err(crate::error::log_error)
183 0 : .instrument(info_span!("http", id = request_id))
184 0 : .await
185 0 : }
186 : }
187 :
188 : impl super::Api for Api {
189 0 : #[tracing::instrument(skip_all)]
190 : async fn get_role_secret(
191 : &self,
192 : ctx: &RequestMonitoring,
193 : user_info: &ComputeUserInfo,
194 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
195 : let normalized_ep = &user_info.endpoint.normalize();
196 : let user = &user_info.user;
197 : if let Some(role_secret) = self
198 : .caches
199 : .project_info
200 : .get_role_secret(normalized_ep, user)
201 : {
202 : return Ok(role_secret);
203 : }
204 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
205 : if let Some(project_id) = auth_info.project_id {
206 : let normalized_ep_int = normalized_ep.into();
207 : self.caches.project_info.insert_role_secret(
208 : project_id,
209 : normalized_ep_int,
210 : user.into(),
211 : auth_info.secret.clone(),
212 : );
213 : self.caches.project_info.insert_allowed_ips(
214 : project_id,
215 : normalized_ep_int,
216 : Arc::new(auth_info.allowed_ips),
217 : );
218 : ctx.set_project_id(project_id);
219 : }
220 : // When we just got a secret, we don't need to invalidate it.
221 : Ok(Cached::new_uncached(auth_info.secret))
222 : }
223 :
224 0 : async fn get_allowed_ips_and_secret(
225 0 : &self,
226 0 : ctx: &RequestMonitoring,
227 0 : user_info: &ComputeUserInfo,
228 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
229 0 : let normalized_ep = &user_info.endpoint.normalize();
230 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
231 0 : Metrics::get()
232 0 : .proxy
233 0 : .allowed_ips_cache_misses
234 0 : .inc(CacheOutcome::Hit);
235 0 : return Ok((allowed_ips, None));
236 0 : }
237 0 : Metrics::get()
238 0 : .proxy
239 0 : .allowed_ips_cache_misses
240 0 : .inc(CacheOutcome::Miss);
241 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
242 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
243 0 : let user = &user_info.user;
244 0 : if let Some(project_id) = auth_info.project_id {
245 0 : let normalized_ep_int = normalized_ep.into();
246 0 : self.caches.project_info.insert_role_secret(
247 0 : project_id,
248 0 : normalized_ep_int,
249 0 : user.into(),
250 0 : auth_info.secret.clone(),
251 0 : );
252 0 : self.caches.project_info.insert_allowed_ips(
253 0 : project_id,
254 0 : normalized_ep_int,
255 0 : allowed_ips.clone(),
256 0 : );
257 0 : ctx.set_project_id(project_id);
258 0 : }
259 0 : Ok((
260 0 : Cached::new_uncached(allowed_ips),
261 0 : Some(Cached::new_uncached(auth_info.secret)),
262 0 : ))
263 0 : }
264 :
265 0 : #[tracing::instrument(skip_all)]
266 : async fn wake_compute(
267 : &self,
268 : ctx: &RequestMonitoring,
269 : user_info: &ComputeUserInfo,
270 : ) -> Result<CachedNodeInfo, WakeComputeError> {
271 : let key = user_info.endpoint_cache_key();
272 :
273 : macro_rules! check_cache {
274 : () => {
275 : if let Some(cached) = self.caches.node_info.get(&key) {
276 : let (cached, info) = cached.take_value();
277 0 : let info = info.map_err(|c| {
278 0 : info!(key = &*key, "found cached wake_compute error");
279 0 : WakeComputeError::ApiError(ApiError::Console(*c))
280 0 : })?;
281 :
282 : debug!(key = &*key, "found cached compute node info");
283 : ctx.set_project(info.aux.clone());
284 0 : return Ok(cached.map(|()| info));
285 : }
286 : };
287 : }
288 :
289 : // Every time we do a wakeup http request, the compute node will stay up
290 : // for some time (highly depends on the console's scale-to-zero policy);
291 : // The connection info remains the same during that period of time,
292 : // which means that we might cache it to reduce the load and latency.
293 : check_cache!();
294 :
295 : let permit = self.locks.get_permit(&key).await?;
296 :
297 : // after getting back a permit - it's possible the cache was filled
298 : // double check
299 : if permit.should_check_cache() {
300 : check_cache!();
301 : }
302 :
303 : // check rate limit
304 : if !self
305 : .wake_compute_endpoint_rate_limiter
306 : .check(user_info.endpoint.normalize_intern(), 1)
307 : {
308 : return Err(WakeComputeError::TooManyConnections);
309 : }
310 :
311 : let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
312 : match node {
313 : Ok(node) => {
314 : ctx.set_project(node.aux.clone());
315 : debug!(key = &*key, "created a cache entry for woken compute node");
316 :
317 : let mut stored_node = node.clone();
318 : // store the cached node as 'warm_cached'
319 : stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
320 :
321 : let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
322 :
323 0 : Ok(cached.map(|()| node))
324 : }
325 : Err(err) => match err {
326 : WakeComputeError::ApiError(ApiError::Console(err)) => {
327 : let Some(status) = &err.status else {
328 : return Err(WakeComputeError::ApiError(ApiError::Console(err)));
329 : };
330 :
331 : let reason = status
332 : .details
333 : .error_info
334 0 : .map_or(Reason::Unknown, |x| x.reason);
335 :
336 : // if we can retry this error, do not cache it.
337 : if reason.can_retry() {
338 : return Err(WakeComputeError::ApiError(ApiError::Console(err)));
339 : }
340 :
341 : // at this point, we should only have quota errors.
342 : debug!(
343 : key = &*key,
344 : "created a cache entry for the wake compute error"
345 : );
346 :
347 : self.caches.node_info.insert_ttl(
348 : key,
349 : Err(Box::new(err.clone())),
350 : Duration::from_secs(30),
351 : );
352 :
353 : Err(WakeComputeError::ApiError(ApiError::Console(err)))
354 : }
355 : err => return Err(err),
356 : },
357 : }
358 : }
359 : }
360 :
361 : /// Parse http response body, taking status code into account.
362 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
363 0 : response: http::Response,
364 0 : ) -> Result<T, ApiError> {
365 0 : let status = response.status();
366 0 : if status.is_success() {
367 : // We shouldn't log raw body because it may contain secrets.
368 0 : info!("request succeeded, processing the body");
369 0 : return Ok(response.json().await?);
370 0 : }
371 0 : let s = response.bytes().await?;
372 : // Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
373 0 : info!("response_error plaintext: {:?}", s);
374 :
375 : // Don't throw an error here because it's not as important
376 : // as the fact that the request itself has failed.
377 0 : let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
378 0 : warn!("failed to parse error body: {e}");
379 0 : ConsoleError {
380 0 : error: "reason unclear (malformed error message)".into(),
381 0 : http_status_code: status,
382 0 : status: None,
383 0 : }
384 0 : });
385 0 : body.http_status_code = status;
386 0 :
387 0 : error!("console responded with an error ({status}): {body:?}");
388 0 : Err(ApiError::Console(body))
389 0 : }
390 :
/// Split a compute address of the form `host:port` into its host and port.
///
/// IPv6 literals must be enclosed in exactly one pair of brackets
/// (`[2001:db8::1]:5432`). The brackets are removed from the returned host.
///
/// Returns `None` when:
/// - there is no `:` separator,
/// - the port is not a valid `u16`,
/// - the host is empty,
/// - the host has unbalanced or stray brackets.
///
/// An unbracketed IPv6 address is inherently ambiguous: whatever follows its last `:` is
/// taken as the port.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port: u16 = port.parse().ok()?;

    // Strip a single enclosing bracket pair; an opening bracket without its closing
    // counterpart is malformed.
    let host = match host.strip_prefix('[') {
        Some(bracketed) => bracketed.strip_suffix(']')?,
        None => host,
    };

    // Reject an empty host and any leftover brackets (e.g. `host]:5432`, `[[::1]]:5432`),
    // which would otherwise produce a bogus host name for the connection config.
    if host.is_empty() || host.contains(|c: char| c == '[' || c == ']') {
        return None;
    }

    Some((host, port))
}
396 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input` and assert it splits into exactly `host` and `port`.
    #[track_caller]
    fn assert_splits_into(input: &str, host: &str, port: u16) {
        let parsed = parse_host_port(input).expect("failed to parse");
        assert_eq!(parsed, (host, port));
    }

    #[test]
    fn test_parse_host_port_v4() {
        assert_splits_into("127.0.0.1:5432", "127.0.0.1", 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        assert_splits_into("[2001:db8::1]:5432", "2001:db8::1", 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        assert_splits_into(
            "compute-foo-bar-1234.default.svc.cluster.local:5432",
            "compute-foo-bar-1234.default.svc.cluster.local",
            5432,
        );
    }
}
|