Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
7 : NodeInfo,
8 : };
9 : use crate::{
10 : auth::backend::ComputeUserInfo,
11 : compute,
12 : console::messages::ColdStartInfo,
13 : http,
14 : metrics::{CacheOutcome, Metrics},
15 : rate_limiter::EndpointRateLimiter,
16 : scram, EndpointCacheKey, Normalize,
17 : };
18 : use crate::{cache::Cached, context::RequestMonitoring};
19 : use futures::TryFutureExt;
20 : use std::sync::Arc;
21 : use tokio::time::Instant;
22 : use tokio_postgres::config::SslMode;
23 : use tracing::{error, info, info_span, warn, Instrument};
24 :
/// Client for the production console (control plane) HTTP API.
pub struct Api {
    /// Control-plane HTTP endpoint; requests are built with `endpoint.get(..)`
    /// and sent with `endpoint.execute(..)`.
    endpoint: http::Endpoint,
    /// Shared caches used by this client (`endpoints_cache`, `project_info`,
    /// `node_info`).
    pub caches: &'static ApiCaches,
    /// Per-cache-key locks; `wake_compute` obtains a permit here before
    /// asking the control plane to wake a compute.
    pub locks: &'static ApiLocks<EndpointCacheKey>,
    /// Rate limiter checked (keyed by the normalized endpoint) before a
    /// wake-compute request is made on a cache miss.
    pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    /// Bearer token sent in the `Authorization` header of every request.
    /// Read from `NEON_PROXY_TO_CONTROLPLANE_TOKEN` in `Api::new`; empty if
    /// the variable could not be read.
    jwt: String,
}
32 :
33 : impl Api {
34 : /// Construct an API object containing the auth parameters.
35 0 : pub fn new(
36 0 : endpoint: http::Endpoint,
37 0 : caches: &'static ApiCaches,
38 0 : locks: &'static ApiLocks<EndpointCacheKey>,
39 0 : wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
40 0 : ) -> Self {
41 0 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
42 0 : Ok(v) => v,
43 0 : Err(_) => "".to_string(),
44 : };
45 0 : Self {
46 0 : endpoint,
47 0 : caches,
48 0 : locks,
49 0 : wake_compute_endpoint_rate_limiter,
50 0 : jwt,
51 0 : }
52 0 : }
53 :
54 0 : pub fn url(&self) -> &str {
55 0 : self.endpoint.url().as_str()
56 0 : }
57 :
58 0 : async fn do_get_auth_info(
59 0 : &self,
60 0 : ctx: &mut RequestMonitoring,
61 0 : user_info: &ComputeUserInfo,
62 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
63 0 : if !self
64 0 : .caches
65 0 : .endpoints_cache
66 0 : .is_valid(ctx, &user_info.endpoint.normalize())
67 0 : .await
68 : {
69 0 : info!("endpoint is not valid, skipping the request");
70 0 : return Ok(AuthInfo::default());
71 0 : }
72 0 : let request_id = ctx.session_id.to_string();
73 0 : let application_name = ctx.console_application_name();
74 0 : async {
75 0 : let request = self
76 0 : .endpoint
77 0 : .get("proxy_get_role_secret")
78 0 : .header("X-Request-ID", &request_id)
79 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
80 0 : .query(&[("session_id", ctx.session_id)])
81 0 : .query(&[
82 0 : ("application_name", application_name.as_str()),
83 0 : ("project", user_info.endpoint.as_str()),
84 0 : ("role", user_info.user.as_str()),
85 0 : ])
86 0 : .build()?;
87 :
88 0 : info!(url = request.url().as_str(), "sending http request");
89 0 : let start = Instant::now();
90 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
91 0 : let response = self.endpoint.execute(request).await?;
92 0 : drop(pause);
93 0 : info!(duration = ?start.elapsed(), "received http response");
94 0 : let body = match parse_body::<GetRoleSecret>(response).await {
95 0 : Ok(body) => body,
96 : // Error 404 is special: it's ok not to have a secret.
97 0 : Err(e) => match e.http_status_code() {
98 : Some(http::StatusCode::NOT_FOUND) => {
99 0 : return Ok(AuthInfo::default());
100 : }
101 0 : _otherwise => return Err(e.into()),
102 : },
103 : };
104 :
105 0 : let secret = if body.role_secret.is_empty() {
106 0 : None
107 : } else {
108 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
109 0 : .map(AuthSecret::Scram)
110 0 : .ok_or(GetAuthInfoError::BadSecret)?;
111 0 : Some(secret)
112 : };
113 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
114 0 : Metrics::get()
115 0 : .proxy
116 0 : .allowed_ips_number
117 0 : .observe(allowed_ips.len() as f64);
118 0 : Ok(AuthInfo {
119 0 : secret,
120 0 : allowed_ips,
121 0 : project_id: body.project_id,
122 0 : })
123 0 : }
124 0 : .map_err(crate::error::log_error)
125 0 : .instrument(info_span!("http", id = request_id))
126 0 : .await
127 0 : }
128 :
129 0 : async fn do_wake_compute(
130 0 : &self,
131 0 : ctx: &mut RequestMonitoring,
132 0 : user_info: &ComputeUserInfo,
133 0 : ) -> Result<NodeInfo, WakeComputeError> {
134 0 : let request_id = ctx.session_id.to_string();
135 0 : let application_name = ctx.console_application_name();
136 0 : async {
137 0 : let mut request_builder = self
138 0 : .endpoint
139 0 : .get("proxy_wake_compute")
140 0 : .header("X-Request-ID", &request_id)
141 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
142 0 : .query(&[("session_id", ctx.session_id)])
143 0 : .query(&[
144 0 : ("application_name", application_name.as_str()),
145 0 : ("project", user_info.endpoint.as_str()),
146 0 : ]);
147 0 :
148 0 : let options = user_info.options.to_deep_object();
149 0 : if !options.is_empty() {
150 0 : request_builder = request_builder.query(&options);
151 0 : }
152 :
153 0 : let request = request_builder.build()?;
154 :
155 0 : info!(url = request.url().as_str(), "sending http request");
156 0 : let start = Instant::now();
157 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
158 0 : let response = self.endpoint.execute(request).await?;
159 0 : drop(pause);
160 0 : info!(duration = ?start.elapsed(), "received http response");
161 0 : let body = parse_body::<WakeCompute>(response).await?;
162 :
163 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
164 0 : let (host, port) = match parse_host_port(&body.address) {
165 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
166 0 : Some(x) => x,
167 0 : };
168 0 :
169 0 : // Don't set anything but host and port! This config will be cached.
170 0 : // We'll set username and such later using the startup message.
171 0 : // TODO: add more type safety (in progress).
172 0 : let mut config = compute::ConnCfg::new();
173 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
174 0 :
175 0 : let node = NodeInfo {
176 0 : config,
177 0 : aux: body.aux,
178 0 : allow_self_signed_compute: false,
179 0 : };
180 0 :
181 0 : Ok(node)
182 0 : }
183 0 : .map_err(crate::error::log_error)
184 0 : .instrument(info_span!("http", id = request_id))
185 0 : .await
186 0 : }
187 : }
188 :
189 : impl super::Api for Api {
190 0 : #[tracing::instrument(skip_all)]
191 : async fn get_role_secret(
192 : &self,
193 : ctx: &mut RequestMonitoring,
194 : user_info: &ComputeUserInfo,
195 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
196 : let normalized_ep = &user_info.endpoint.normalize();
197 : let user = &user_info.user;
198 : if let Some(role_secret) = self
199 : .caches
200 : .project_info
201 : .get_role_secret(normalized_ep, user)
202 : {
203 : return Ok(role_secret);
204 : }
205 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
206 : if let Some(project_id) = auth_info.project_id {
207 : let normalized_ep_int = normalized_ep.into();
208 : self.caches.project_info.insert_role_secret(
209 : project_id,
210 : normalized_ep_int,
211 : user.into(),
212 : auth_info.secret.clone(),
213 : );
214 : self.caches.project_info.insert_allowed_ips(
215 : project_id,
216 : normalized_ep_int,
217 : Arc::new(auth_info.allowed_ips),
218 : );
219 : ctx.set_project_id(project_id);
220 : }
221 : // When we just got a secret, we don't need to invalidate it.
222 : Ok(Cached::new_uncached(auth_info.secret))
223 : }
224 :
225 0 : async fn get_allowed_ips_and_secret(
226 0 : &self,
227 0 : ctx: &mut RequestMonitoring,
228 0 : user_info: &ComputeUserInfo,
229 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
230 0 : let normalized_ep = &user_info.endpoint.normalize();
231 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
232 0 : Metrics::get()
233 0 : .proxy
234 0 : .allowed_ips_cache_misses
235 0 : .inc(CacheOutcome::Hit);
236 0 : return Ok((allowed_ips, None));
237 0 : }
238 0 : Metrics::get()
239 0 : .proxy
240 0 : .allowed_ips_cache_misses
241 0 : .inc(CacheOutcome::Miss);
242 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
243 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
244 0 : let user = &user_info.user;
245 0 : if let Some(project_id) = auth_info.project_id {
246 0 : let normalized_ep_int = normalized_ep.into();
247 0 : self.caches.project_info.insert_role_secret(
248 0 : project_id,
249 0 : normalized_ep_int,
250 0 : user.into(),
251 0 : auth_info.secret.clone(),
252 0 : );
253 0 : self.caches.project_info.insert_allowed_ips(
254 0 : project_id,
255 0 : normalized_ep_int,
256 0 : allowed_ips.clone(),
257 0 : );
258 0 : ctx.set_project_id(project_id);
259 0 : }
260 0 : Ok((
261 0 : Cached::new_uncached(allowed_ips),
262 0 : Some(Cached::new_uncached(auth_info.secret)),
263 0 : ))
264 0 : }
265 :
266 0 : #[tracing::instrument(skip_all)]
267 : async fn wake_compute(
268 : &self,
269 : ctx: &mut RequestMonitoring,
270 : user_info: &ComputeUserInfo,
271 : ) -> Result<CachedNodeInfo, WakeComputeError> {
272 : let key = user_info.endpoint_cache_key();
273 :
274 : // Every time we do a wakeup http request, the compute node will stay up
275 : // for some time (highly depends on the console's scale-to-zero policy);
276 : // The connection info remains the same during that period of time,
277 : // which means that we might cache it to reduce the load and latency.
278 : if let Some(cached) = self.caches.node_info.get(&key) {
279 : info!(key = &*key, "found cached compute node info");
280 : ctx.set_project(cached.aux.clone());
281 : return Ok(cached);
282 : }
283 :
284 : // check rate limit
285 : if !self
286 : .wake_compute_endpoint_rate_limiter
287 : .check(user_info.endpoint.normalize().into(), 1)
288 : {
289 : return Err(WakeComputeError::TooManyConnections);
290 : }
291 :
292 : let permit = self.locks.get_permit(&key).await?;
293 :
294 : // after getting back a permit - it's possible the cache was filled
295 : // double check
296 : if permit.should_check_cache() {
297 : if let Some(cached) = self.caches.node_info.get(&key) {
298 : info!(key = &*key, "found cached compute node info");
299 : ctx.set_project(cached.aux.clone());
300 : return Ok(cached);
301 : }
302 : }
303 :
304 : let mut node = self.do_wake_compute(ctx, user_info).await?;
305 : ctx.set_project(node.aux.clone());
306 : let cold_start_info = node.aux.cold_start_info;
307 : info!("woken up a compute node");
308 :
309 : // store the cached node as 'warm'
310 : node.aux.cold_start_info = ColdStartInfo::WarmCached;
311 : let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
312 : cached.aux.cold_start_info = cold_start_info;
313 :
314 : info!(key = &*key, "created a cache entry for compute node info");
315 :
316 : Ok(cached)
317 : }
318 : }
319 :
320 : /// Parse http response body, taking status code into account.
321 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
322 0 : response: http::Response,
323 0 : ) -> Result<T, ApiError> {
324 0 : let status = response.status();
325 0 : if status.is_success() {
326 : // We shouldn't log raw body because it may contain secrets.
327 0 : info!("request succeeded, processing the body");
328 0 : return Ok(response.json().await?);
329 0 : }
330 :
331 : // Don't throw an error here because it's not as important
332 : // as the fact that the request itself has failed.
333 0 : let body = response.json().await.unwrap_or_else(|e| {
334 0 : warn!("failed to parse error body: {e}");
335 0 : ConsoleError {
336 0 : error: "reason unclear (malformed error message)".into(),
337 0 : }
338 0 : });
339 0 :
340 0 : let text = body.error;
341 0 : error!("console responded with an error ({status}): {text}");
342 0 : Err(ApiError::Console { status, text })
343 0 : }
344 :
/// Split a compute address of the form `host:port` into host and port.
///
/// The split happens at the *last* `:`. An IPv6 literal is expected in
/// bracketed form (`[2001:db8::1]:5432`), and exactly one enclosing bracket
/// pair is stripped from the returned host.
///
/// Returns `None` when:
/// - there is no `:`;
/// - the port is not a valid `u16`;
/// - the host is empty;
/// - the host contains unbalanced or stray brackets (e.g. `[::1:5432`).
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port: u16 = port.parse().ok()?;

    // Strip one balanced `[...]` pair; a lone `[` or `]` is malformed.
    let host = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']')?,
        None if host.ends_with(']') => return None,
        None => host,
    };

    if host.is_empty() || host.contains(|c| c == '[' || c == ']') {
        return None;
    }
    Some((host, port))
}
350 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }

    /// Addresses the control plane should never send must be rejected rather
    /// than turned into a bogus connection config.
    #[test]
    fn test_parse_host_port_rejects_malformed() {
        // No port separator at all.
        assert_eq!(parse_host_port("127.0.0.1"), None);
        // Port is not a number.
        assert_eq!(parse_host_port("localhost:pg"), None);
        // Port does not fit into a u16.
        assert_eq!(parse_host_port("localhost:65536"), None);
        // Empty port.
        assert_eq!(parse_host_port("localhost:"), None);
        // Bracketed IPv6 literal without a port.
        assert_eq!(parse_host_port("[::1]"), None);
    }
}
|