Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
7 : NodeInfo,
8 : };
9 : use crate::{
10 : auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
11 : };
12 : use crate::{
13 : cache::Cached,
14 : context::RequestMonitoring,
15 : metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
16 : };
17 : use futures::TryFutureExt;
18 : use std::sync::Arc;
19 : use tokio::time::Instant;
20 : use tokio_postgres::config::SslMode;
21 : use tracing::{error, info, info_span, warn, Instrument};
22 :
23 : pub struct Api {
24 : endpoint: http::Endpoint,
25 : pub caches: &'static ApiCaches,
26 : locks: &'static ApiLocks,
27 : jwt: String,
28 : }
29 :
30 : impl Api {
31 : /// Construct an API object containing the auth parameters.
32 0 : pub fn new(
33 0 : endpoint: http::Endpoint,
34 0 : caches: &'static ApiCaches,
35 0 : locks: &'static ApiLocks,
36 0 : ) -> Self {
37 0 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
38 0 : Ok(v) => v,
39 0 : Err(_) => "".to_string(),
40 : };
41 0 : Self {
42 0 : endpoint,
43 0 : caches,
44 0 : locks,
45 0 : jwt,
46 0 : }
47 0 : }
48 :
49 0 : pub fn url(&self) -> &str {
50 0 : self.endpoint.url().as_str()
51 0 : }
52 :
53 0 : async fn do_get_auth_info(
54 0 : &self,
55 0 : ctx: &mut RequestMonitoring,
56 0 : user_info: &ComputeUserInfo,
57 0 : ) -> Result<AuthInfo, GetAuthInfoError> {
58 0 : let request_id = ctx.session_id.to_string();
59 0 : let application_name = ctx.console_application_name();
60 0 : async {
61 0 : let request = self
62 0 : .endpoint
63 0 : .get("proxy_get_role_secret")
64 0 : .header("X-Request-ID", &request_id)
65 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
66 0 : .query(&[("session_id", ctx.session_id)])
67 0 : .query(&[
68 0 : ("application_name", application_name.as_str()),
69 0 : ("project", user_info.endpoint.as_str()),
70 0 : ("role", user_info.user.as_str()),
71 0 : ])
72 0 : .build()?;
73 :
74 0 : info!(url = request.url().as_str(), "sending http request");
75 0 : let start = Instant::now();
76 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
77 0 : let response = self.endpoint.execute(request).await?;
78 0 : drop(pause);
79 0 : info!(duration = ?start.elapsed(), "received http response");
80 0 : let body = match parse_body::<GetRoleSecret>(response).await {
81 0 : Ok(body) => body,
82 : // Error 404 is special: it's ok not to have a secret.
83 0 : Err(e) => match e.http_status_code() {
84 0 : Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
85 0 : _otherwise => return Err(e.into()),
86 : },
87 : };
88 :
89 0 : let secret = if body.role_secret.is_empty() {
90 0 : None
91 : } else {
92 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
93 0 : .map(AuthSecret::Scram)
94 0 : .ok_or(GetAuthInfoError::BadSecret)?;
95 0 : Some(secret)
96 : };
97 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
98 0 : ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
99 0 : Ok(AuthInfo {
100 0 : secret,
101 0 : allowed_ips,
102 0 : project_id: body.project_id,
103 0 : })
104 0 : }
105 0 : .map_err(crate::error::log_error)
106 0 : .instrument(info_span!("http", id = request_id))
107 0 : .await
108 0 : }
109 :
110 0 : async fn do_wake_compute(
111 0 : &self,
112 0 : ctx: &mut RequestMonitoring,
113 0 : user_info: &ComputeUserInfo,
114 0 : ) -> Result<NodeInfo, WakeComputeError> {
115 0 : let request_id = ctx.session_id.to_string();
116 0 : let application_name = ctx.console_application_name();
117 0 : async {
118 0 : let mut request_builder = self
119 0 : .endpoint
120 0 : .get("proxy_wake_compute")
121 0 : .header("X-Request-ID", &request_id)
122 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
123 0 : .query(&[("session_id", ctx.session_id)])
124 0 : .query(&[
125 0 : ("application_name", application_name.as_str()),
126 0 : ("project", user_info.endpoint.as_str()),
127 0 : ]);
128 0 :
129 0 : let options = user_info.options.to_deep_object();
130 0 : if !options.is_empty() {
131 0 : request_builder = request_builder.query(&options);
132 0 : }
133 :
134 0 : let request = request_builder.build()?;
135 :
136 0 : info!(url = request.url().as_str(), "sending http request");
137 0 : let start = Instant::now();
138 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
139 0 : let response = self.endpoint.execute(request).await?;
140 0 : drop(pause);
141 0 : info!(duration = ?start.elapsed(), "received http response");
142 0 : let body = parse_body::<WakeCompute>(response).await?;
143 :
144 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
145 0 : let (host, port) = match parse_host_port(&body.address) {
146 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
147 0 : Some(x) => x,
148 0 : };
149 0 :
150 0 : // Don't set anything but host and port! This config will be cached.
151 0 : // We'll set username and such later using the startup message.
152 0 : // TODO: add more type safety (in progress).
153 0 : let mut config = compute::ConnCfg::new();
154 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
155 0 :
156 0 : let node = NodeInfo {
157 0 : config,
158 0 : aux: body.aux,
159 0 : allow_self_signed_compute: false,
160 0 : };
161 0 :
162 0 : Ok(node)
163 0 : }
164 0 : .map_err(crate::error::log_error)
165 0 : .instrument(info_span!("http", id = request_id))
166 0 : .await
167 0 : }
168 : }
169 :
170 : impl super::Api for Api {
171 0 : #[tracing::instrument(skip_all)]
172 : async fn get_role_secret(
173 : &self,
174 : ctx: &mut RequestMonitoring,
175 : user_info: &ComputeUserInfo,
176 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
177 : let ep = &user_info.endpoint;
178 : let user = &user_info.user;
179 : if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
180 : return Ok(role_secret);
181 : }
182 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
183 : if let Some(project_id) = auth_info.project_id {
184 : let ep_int = ep.into();
185 : self.caches.project_info.insert_role_secret(
186 : project_id,
187 : ep_int,
188 : user.into(),
189 : auth_info.secret.clone(),
190 : );
191 : self.caches.project_info.insert_allowed_ips(
192 : project_id,
193 : ep_int,
194 : Arc::new(auth_info.allowed_ips),
195 : );
196 : ctx.set_project_id(project_id);
197 : }
198 : // When we just got a secret, we don't need to invalidate it.
199 : Ok(Cached::new_uncached(auth_info.secret))
200 : }
201 :
202 0 : async fn get_allowed_ips_and_secret(
203 0 : &self,
204 0 : ctx: &mut RequestMonitoring,
205 0 : user_info: &ComputeUserInfo,
206 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
207 0 : let ep = &user_info.endpoint;
208 0 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
209 0 : ALLOWED_IPS_BY_CACHE_OUTCOME
210 0 : .with_label_values(&["hit"])
211 0 : .inc();
212 0 : return Ok((allowed_ips, None));
213 0 : }
214 0 : ALLOWED_IPS_BY_CACHE_OUTCOME
215 0 : .with_label_values(&["miss"])
216 0 : .inc();
217 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
218 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
219 0 : let user = &user_info.user;
220 0 : if let Some(project_id) = auth_info.project_id {
221 0 : let ep_int = ep.into();
222 0 : self.caches.project_info.insert_role_secret(
223 0 : project_id,
224 0 : ep_int,
225 0 : user.into(),
226 0 : auth_info.secret.clone(),
227 0 : );
228 0 : self.caches
229 0 : .project_info
230 0 : .insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
231 0 : ctx.set_project_id(project_id);
232 0 : }
233 0 : Ok((
234 0 : Cached::new_uncached(allowed_ips),
235 0 : Some(Cached::new_uncached(auth_info.secret)),
236 0 : ))
237 0 : }
238 :
239 0 : #[tracing::instrument(skip_all)]
240 : async fn wake_compute(
241 : &self,
242 : ctx: &mut RequestMonitoring,
243 : user_info: &ComputeUserInfo,
244 : ) -> Result<CachedNodeInfo, WakeComputeError> {
245 : let key = user_info.endpoint_cache_key();
246 :
247 : // Every time we do a wakeup http request, the compute node will stay up
248 : // for some time (highly depends on the console's scale-to-zero policy);
249 : // The connection info remains the same during that period of time,
250 : // which means that we might cache it to reduce the load and latency.
251 : if let Some(cached) = self.caches.node_info.get(&key) {
252 0 : info!(key = &*key, "found cached compute node info");
253 : ctx.set_project(cached.aux.clone());
254 : return Ok(cached);
255 : }
256 :
257 : let permit = self.locks.get_wake_compute_permit(&key).await?;
258 :
259 : // after getting back a permit - it's possible the cache was filled
260 : // double check
261 : if permit.should_check_cache() {
262 : if let Some(cached) = self.caches.node_info.get(&key) {
263 0 : info!(key = &*key, "found cached compute node info");
264 : ctx.set_project(cached.aux.clone());
265 : return Ok(cached);
266 : }
267 : }
268 :
269 : let mut node = self.do_wake_compute(ctx, user_info).await?;
270 : ctx.set_project(node.aux.clone());
271 : let cold_start_info = node.aux.cold_start_info;
272 0 : info!("woken up a compute node");
273 :
274 : // store the cached node as 'warm'
275 : node.aux.cold_start_info = ColdStartInfo::WarmCached;
276 : let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
277 : cached.aux.cold_start_info = cold_start_info;
278 :
279 0 : info!(key = &*key, "created a cache entry for compute node info");
280 :
281 : Ok(cached)
282 : }
283 : }
284 :
285 : /// Parse http response body, taking status code into account.
286 0 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
287 0 : response: http::Response,
288 0 : ) -> Result<T, ApiError> {
289 0 : let status = response.status();
290 0 : if status.is_success() {
291 : // We shouldn't log raw body because it may contain secrets.
292 0 : info!("request succeeded, processing the body");
293 0 : return Ok(response.json().await?);
294 0 : }
295 :
296 : // Don't throw an error here because it's not as important
297 : // as the fact that the request itself has failed.
298 0 : let body = response.json().await.unwrap_or_else(|e| {
299 0 : warn!("failed to parse error body: {e}");
300 0 : ConsoleError {
301 0 : error: "reason unclear (malformed error message)".into(),
302 0 : }
303 0 : });
304 0 :
305 0 : let text = body.error;
306 0 : error!("console responded with an error ({status}): {text}");
307 0 : Err(ApiError::Console { status, text })
308 0 : }
309 :
/// Split a `host:port` address into its host and numeric port.
///
/// The split happens at the last `:`, so bracketed IPv6 literals such as
/// `[2001:db8::1]:5432` are supported; surrounding `[` / `]` characters are
/// stripped from the returned host.
///
/// Returns `None` when there is no `:` separator, when the port is not a
/// valid `u16`, or when the host is empty once the brackets are removed.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port: u16 = port.parse().ok()?;

    let ipv6_brackets: &[_] = &['[', ']'];
    let host = host.trim_matches(ipv6_brackets);

    // An address without a host (":5432", "[]:5432") cannot be connected to,
    // so report it as unparsable instead of handing back an empty host.
    if host.is_empty() {
        return None;
    }

    Some((host, port))
}
315 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain IPv4 address is split at the only colon.
    #[test]
    fn test_parse_host_port_v4() {
        assert_eq!(
            parse_host_port("127.0.0.1:5432"),
            Some(("127.0.0.1", 5432))
        );
    }

    /// A bracketed IPv6 literal is split at the last colon and unbracketed.
    #[test]
    fn test_parse_host_port_v6() {
        assert_eq!(
            parse_host_port("[2001:db8::1]:5432"),
            Some(("2001:db8::1", 5432))
        );
    }

    /// A DNS name is returned unchanged as the host.
    #[test]
    fn test_parse_host_port_url() {
        assert_eq!(
            parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432"),
            Some(("compute-foo-bar-1234.default.svc.cluster.local", 5432))
        );
    }
}
|