Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
7 : NodeInfo,
8 : };
9 : use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
10 : use crate::{
11 : cache::Cached,
12 : context::RequestMonitoring,
13 : metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
14 : };
15 : use async_trait::async_trait;
16 : use futures::TryFutureExt;
17 : use std::sync::Arc;
18 : use tokio::time::Instant;
19 : use tokio_postgres::config::SslMode;
20 : use tracing::{error, info, info_span, warn, Instrument};
21 :
/// Client for the console HTTP API, backed by shared caches and locks.
pub struct Api {
    /// HTTP endpoint that console requests are built against and executed on.
    endpoint: http::Endpoint,
    /// Shared caches: `project_info` holds role secrets and allowed IPs,
    /// `node_info` holds the info of already woken compute nodes.
    pub caches: &'static ApiCaches,
    /// Source of the permit acquired (by endpoint cache key) before a
    /// wake-compute request is sent.
    locks: &'static ApiLocks,
    /// Token sent as `Authorization: Bearer <jwt>` with every request.
    /// Read from `NEON_PROXY_TO_CONTROLPLANE_TOKEN` in `Api::new`; empty when unset.
    jwt: String,
}
28 :
29 : impl Api {
30 : /// Construct an API object containing the auth parameters.
31 1 : pub fn new(
32 1 : endpoint: http::Endpoint,
33 1 : caches: &'static ApiCaches,
34 1 : locks: &'static ApiLocks,
35 1 : ) -> Self {
36 1 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
37 0 : Ok(v) => v,
38 1 : Err(_) => "".to_string(),
39 : };
40 1 : Self {
41 1 : endpoint,
42 1 : caches,
43 1 : locks,
44 1 : jwt,
45 1 : }
46 1 : }
47 :
48 1 : pub fn url(&self) -> &str {
49 1 : self.endpoint.url().as_str()
50 1 : }
51 :
52 4 : async fn do_get_auth_info(
53 4 : &self,
54 4 : ctx: &mut RequestMonitoring,
55 4 : user_info: &ComputeUserInfo,
56 4 : ) -> Result<AuthInfo, GetAuthInfoError> {
57 4 : let request_id = uuid::Uuid::new_v4().to_string();
58 4 : let application_name = ctx.console_application_name();
59 4 : async {
60 4 : let request = self
61 4 : .endpoint
62 4 : .get("proxy_get_role_secret")
63 4 : .header("X-Request-ID", &request_id)
64 4 : .header("Authorization", format!("Bearer {}", &self.jwt))
65 4 : .query(&[("session_id", ctx.session_id)])
66 4 : .query(&[
67 4 : ("application_name", application_name.as_str()),
68 4 : ("project", user_info.endpoint.as_str()),
69 4 : ("role", user_info.user.as_str()),
70 4 : ])
71 4 : .build()?;
72 :
73 4 : info!(url = request.url().as_str(), "sending http request");
74 4 : let start = Instant::now();
75 12 : let response = self.endpoint.execute(request).await?;
76 3 : info!(duration = ?start.elapsed(), "received http response");
77 3 : let body = match parse_body::<GetRoleSecret>(response).await {
78 0 : Ok(body) => body,
79 : // Error 404 is special: it's ok not to have a secret.
80 3 : Err(e) => match e.http_status_code() {
81 0 : Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
82 3 : _otherwise => return Err(e.into()),
83 : },
84 : };
85 :
86 0 : let secret = if body.role_secret.is_empty() {
87 0 : None
88 : } else {
89 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
90 0 : .map(AuthSecret::Scram)
91 0 : .ok_or(GetAuthInfoError::BadSecret)?;
92 0 : Some(secret)
93 : };
94 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
95 0 : ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
96 0 : Ok(AuthInfo {
97 0 : secret,
98 0 : allowed_ips,
99 0 : project_id: body.project_id,
100 0 : })
101 4 : }
102 4 : .map_err(crate::error::log_error)
103 4 : .instrument(info_span!("http", id = request_id))
104 12 : .await
105 4 : }
106 :
107 0 : async fn do_wake_compute(
108 0 : &self,
109 0 : ctx: &mut RequestMonitoring,
110 0 : user_info: &ComputeUserInfo,
111 0 : ) -> Result<NodeInfo, WakeComputeError> {
112 0 : let request_id = uuid::Uuid::new_v4().to_string();
113 0 : let application_name = ctx.console_application_name();
114 0 : async {
115 0 : let mut request_builder = self
116 0 : .endpoint
117 0 : .get("proxy_wake_compute")
118 0 : .header("X-Request-ID", &request_id)
119 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
120 0 : .query(&[("session_id", ctx.session_id)])
121 0 : .query(&[
122 0 : ("application_name", application_name.as_str()),
123 0 : ("project", user_info.endpoint.as_str()),
124 0 : ]);
125 0 :
126 0 : let options = user_info.options.to_deep_object();
127 0 : if !options.is_empty() {
128 0 : request_builder = request_builder.query(&options);
129 0 : }
130 :
131 0 : let request = request_builder.build()?;
132 :
133 0 : info!(url = request.url().as_str(), "sending http request");
134 0 : let start = Instant::now();
135 0 : let response = self.endpoint.execute(request).await?;
136 0 : info!(duration = ?start.elapsed(), "received http response");
137 0 : let body = parse_body::<WakeCompute>(response).await?;
138 :
139 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
140 0 : let (host, port) = match parse_host_port(&body.address) {
141 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
142 0 : Some(x) => x,
143 0 : };
144 0 :
145 0 : // Don't set anything but host and port! This config will be cached.
146 0 : // We'll set username and such later using the startup message.
147 0 : // TODO: add more type safety (in progress).
148 0 : let mut config = compute::ConnCfg::new();
149 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
150 0 :
151 0 : let node = NodeInfo {
152 0 : config,
153 0 : aux: body.aux,
154 0 : allow_self_signed_compute: false,
155 0 : };
156 0 :
157 0 : Ok(node)
158 0 : }
159 0 : .map_err(crate::error::log_error)
160 0 : .instrument(info_span!("http", id = request_id))
161 0 : .await
162 0 : }
163 : }
164 :
165 : #[async_trait]
166 : impl super::Api for Api {
167 0 : #[tracing::instrument(skip_all)]
168 : async fn get_role_secret(
169 : &self,
170 : ctx: &mut RequestMonitoring,
171 : user_info: &ComputeUserInfo,
172 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
173 0 : let ep = &user_info.endpoint;
174 0 : let user = &user_info.user;
175 0 : if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
176 0 : return Ok(role_secret);
177 0 : }
178 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
179 0 : if let Some(project_id) = auth_info.project_id {
180 0 : self.caches.project_info.insert_role_secret(
181 0 : &project_id,
182 0 : ep,
183 0 : user,
184 0 : auth_info.secret.clone(),
185 0 : );
186 0 : self.caches.project_info.insert_allowed_ips(
187 0 : &project_id,
188 0 : ep,
189 0 : Arc::new(auth_info.allowed_ips),
190 0 : );
191 0 : ctx.set_project_id(project_id);
192 0 : }
193 : // When we just got a secret, we don't need to invalidate it.
194 0 : Ok(Cached::new_uncached(auth_info.secret))
195 0 : }
196 :
197 4 : async fn get_allowed_ips_and_secret(
198 4 : &self,
199 4 : ctx: &mut RequestMonitoring,
200 4 : user_info: &ComputeUserInfo,
201 4 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
202 4 : let ep = &user_info.endpoint;
203 4 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
204 0 : ALLOWED_IPS_BY_CACHE_OUTCOME
205 0 : .with_label_values(&["hit"])
206 0 : .inc();
207 0 : return Ok((allowed_ips, None));
208 4 : }
209 4 : ALLOWED_IPS_BY_CACHE_OUTCOME
210 4 : .with_label_values(&["miss"])
211 4 : .inc();
212 12 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
213 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
214 0 : let user = &user_info.user;
215 0 : if let Some(project_id) = auth_info.project_id {
216 0 : self.caches.project_info.insert_role_secret(
217 0 : &project_id,
218 0 : ep,
219 0 : user,
220 0 : auth_info.secret.clone(),
221 0 : );
222 0 : self.caches
223 0 : .project_info
224 0 : .insert_allowed_ips(&project_id, ep, allowed_ips.clone());
225 0 : ctx.set_project_id(project_id);
226 0 : }
227 0 : Ok((
228 0 : Cached::new_uncached(allowed_ips),
229 0 : Some(Cached::new_uncached(auth_info.secret)),
230 0 : ))
231 12 : }
232 :
233 0 : #[tracing::instrument(skip_all)]
234 : async fn wake_compute(
235 : &self,
236 : ctx: &mut RequestMonitoring,
237 : user_info: &ComputeUserInfo,
238 0 : ) -> Result<CachedNodeInfo, WakeComputeError> {
239 0 : let key = user_info.endpoint_cache_key();
240 :
241 : // Every time we do a wakeup http request, the compute node will stay up
242 : // for some time (highly depends on the console's scale-to-zero policy);
243 : // The connection info remains the same during that period of time,
244 : // which means that we might cache it to reduce the load and latency.
245 0 : if let Some(cached) = self.caches.node_info.get(&key) {
246 0 : info!(key = &*key, "found cached compute node info");
247 0 : return Ok(cached);
248 0 : }
249 :
250 0 : let permit = self.locks.get_wake_compute_permit(&key).await?;
251 :
252 : // after getting back a permit - it's possible the cache was filled
253 : // double check
254 0 : if permit.should_check_cache() {
255 0 : if let Some(cached) = self.caches.node_info.get(&key) {
256 0 : info!(key = &*key, "found cached compute node info");
257 0 : return Ok(cached);
258 0 : }
259 0 : }
260 :
261 0 : let node = self.do_wake_compute(ctx, user_info).await?;
262 0 : let (_, cached) = self.caches.node_info.insert(key.clone(), node);
263 0 : info!(key = &*key, "created a cache entry for compute node info");
264 :
265 0 : Ok(cached)
266 0 : }
267 : }
268 :
269 : /// Parse http response body, taking status code into account.
270 3 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
271 3 : response: http::Response,
272 3 : ) -> Result<T, ApiError> {
273 3 : let status = response.status();
274 3 : if status.is_success() {
275 : // We shouldn't log raw body because it may contain secrets.
276 1 : info!("request succeeded, processing the body");
277 1 : return Ok(response.json().await?);
278 2 : }
279 :
280 : // Don't throw an error here because it's not as important
281 : // as the fact that the request itself has failed.
282 2 : let body = response.json().await.unwrap_or_else(|e| {
283 2 : warn!("failed to parse error body: {e}");
284 2 : ConsoleError {
285 2 : error: "reason unclear (malformed error message)".into(),
286 2 : }
287 2 : });
288 2 :
289 2 : let text = body.error;
290 2 : error!("console responded with an error ({status}): {text}");
291 2 : Err(ApiError::Console { status, text })
292 3 : }
293 :
/// Split a `host:port` address into its host and numeric port.
///
/// The port is whatever follows the last `:` and must parse as a `u16`.
/// An IPv6 literal may be wrapped in exactly one pair of square brackets
/// (`[2001:db8::1]:5432`); the brackets are removed from the returned host.
///
/// Returns `None` when there is no `:` separator, the port is not a valid
/// `u16`, the host is empty, or the host contains unbalanced or stray
/// square brackets.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port = port.parse().ok()?;

    // Strip one matched pair of brackets only; an opening bracket without
    // its closing counterpart makes the whole address malformed.
    let host = match host.strip_prefix('[') {
        Some(rest) => rest.strip_suffix(']')?,
        None => host,
    };

    // Neither hostnames nor IP literals may contain brackets, and an empty
    // host cannot be connected to.
    if host.is_empty() || host.contains(|c| c == '[' || c == ']') {
        return None;
    }

    Some((host, port))
}
299 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain IPv4 address with a port.
    #[test]
    fn test_parse_host_port_v4() {
        assert_eq!(parse_host_port("127.0.0.1:5432"), Some(("127.0.0.1", 5432)));
    }

    /// Bracketed IPv6 literal: the brackets must not be part of the host.
    #[test]
    fn test_parse_host_port_v6() {
        assert_eq!(
            parse_host_port("[2001:db8::1]:5432"),
            Some(("2001:db8::1", 5432))
        );
    }

    /// DNS name with a port.
    #[test]
    fn test_parse_host_port_url() {
        assert_eq!(
            parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432"),
            Some(("compute-foo-bar-1234.default.svc.cluster.local", 5432))
        );
    }

    /// Without a `:` separator there is no port to extract.
    #[test]
    fn test_parse_host_port_missing_port() {
        assert_eq!(parse_host_port("localhost"), None);
        assert_eq!(parse_host_port(""), None);
    }

    /// The port must be a valid `u16`.
    #[test]
    fn test_parse_host_port_invalid_port() {
        assert_eq!(parse_host_port("localhost:"), None);
        assert_eq!(parse_host_port("localhost:abc"), None);
        assert_eq!(parse_host_port("localhost:65536"), None);
        assert_eq!(parse_host_port("localhost:-1"), None);
    }
}
|