Line data Source code
1 : //! Production console backend.
2 :
3 : use super::{
4 : super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
5 : errors::{ApiError, GetAuthInfoError, WakeComputeError},
6 : ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
7 : NodeInfo,
8 : };
9 : use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
10 : use crate::{
11 : cache::Cached,
12 : context::RequestMonitoring,
13 : metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
14 : };
15 : use async_trait::async_trait;
16 : use futures::TryFutureExt;
17 : use std::sync::Arc;
18 : use tokio::time::Instant;
19 : use tokio_postgres::config::SslMode;
20 : use tracing::{error, info, info_span, warn, Instrument};
21 :
22 : pub struct Api {
23 : endpoint: http::Endpoint,
24 : pub caches: &'static ApiCaches,
25 : locks: &'static ApiLocks,
26 : jwt: String,
27 : }
28 :
29 : impl Api {
30 : /// Construct an API object containing the auth parameters.
31 1 : pub fn new(
32 1 : endpoint: http::Endpoint,
33 1 : caches: &'static ApiCaches,
34 1 : locks: &'static ApiLocks,
35 1 : ) -> Self {
36 1 : let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
37 0 : Ok(v) => v,
38 1 : Err(_) => "".to_string(),
39 : };
40 1 : Self {
41 1 : endpoint,
42 1 : caches,
43 1 : locks,
44 1 : jwt,
45 1 : }
46 1 : }
47 :
48 1 : pub fn url(&self) -> &str {
49 1 : self.endpoint.url().as_str()
50 1 : }
51 :
52 4 : async fn do_get_auth_info(
53 4 : &self,
54 4 : ctx: &mut RequestMonitoring,
55 4 : user_info: &ComputeUserInfo,
56 4 : ) -> Result<AuthInfo, GetAuthInfoError> {
57 4 : let request_id = uuid::Uuid::new_v4().to_string();
58 4 : let application_name = ctx.console_application_name();
59 4 : async {
60 4 : let request = self
61 4 : .endpoint
62 4 : .get("proxy_get_role_secret")
63 4 : .header("X-Request-ID", &request_id)
64 4 : .header("Authorization", format!("Bearer {}", &self.jwt))
65 4 : .query(&[("session_id", ctx.session_id)])
66 4 : .query(&[
67 4 : ("application_name", application_name.as_str()),
68 4 : ("project", user_info.endpoint.as_str()),
69 4 : ("role", user_info.user.as_str()),
70 4 : ])
71 4 : .build()?;
72 :
73 4 : info!(url = request.url().as_str(), "sending http request");
74 4 : let start = Instant::now();
75 12 : let response = self.endpoint.execute(request).await?;
76 3 : info!(duration = ?start.elapsed(), "received http response");
77 3 : let body = match parse_body::<GetRoleSecret>(response).await {
78 0 : Ok(body) => body,
79 : // Error 404 is special: it's ok not to have a secret.
80 3 : Err(e) => match e.http_status_code() {
81 0 : Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
82 3 : _otherwise => return Err(e.into()),
83 : },
84 : };
85 :
86 0 : let secret = if body.role_secret.is_empty() {
87 0 : None
88 : } else {
89 0 : let secret = scram::ServerSecret::parse(&body.role_secret)
90 0 : .map(AuthSecret::Scram)
91 0 : .ok_or(GetAuthInfoError::BadSecret)?;
92 0 : Some(secret)
93 : };
94 0 : let allowed_ips = body.allowed_ips.unwrap_or_default();
95 0 : ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
96 0 : Ok(AuthInfo {
97 0 : secret,
98 0 : allowed_ips,
99 0 : project_id: body.project_id,
100 0 : })
101 4 : }
102 4 : .map_err(crate::error::log_error)
103 4 : .instrument(info_span!("http", id = request_id))
104 12 : .await
105 4 : }
106 :
107 0 : async fn do_wake_compute(
108 0 : &self,
109 0 : ctx: &mut RequestMonitoring,
110 0 : user_info: &ComputeUserInfo,
111 0 : ) -> Result<NodeInfo, WakeComputeError> {
112 0 : let request_id = uuid::Uuid::new_v4().to_string();
113 0 : let application_name = ctx.console_application_name();
114 0 : async {
115 0 : let mut request_builder = self
116 0 : .endpoint
117 0 : .get("proxy_wake_compute")
118 0 : .header("X-Request-ID", &request_id)
119 0 : .header("Authorization", format!("Bearer {}", &self.jwt))
120 0 : .query(&[("session_id", ctx.session_id)])
121 0 : .query(&[
122 0 : ("application_name", application_name.as_str()),
123 0 : ("project", user_info.endpoint.as_str()),
124 0 : ]);
125 0 :
126 0 : let options = user_info.options.to_deep_object();
127 0 : if !options.is_empty() {
128 0 : request_builder = request_builder.query(&options);
129 0 : }
130 :
131 0 : let request = request_builder.build()?;
132 :
133 0 : info!(url = request.url().as_str(), "sending http request");
134 0 : let start = Instant::now();
135 0 : let response = self.endpoint.execute(request).await?;
136 0 : info!(duration = ?start.elapsed(), "received http response");
137 0 : let body = parse_body::<WakeCompute>(response).await?;
138 :
139 : // Unfortunately, ownership won't let us use `Option::ok_or` here.
140 0 : let (host, port) = match parse_host_port(&body.address) {
141 0 : None => return Err(WakeComputeError::BadComputeAddress(body.address)),
142 0 : Some(x) => x,
143 0 : };
144 0 :
145 0 : // Don't set anything but host and port! This config will be cached.
146 0 : // We'll set username and such later using the startup message.
147 0 : // TODO: add more type safety (in progress).
148 0 : let mut config = compute::ConnCfg::new();
149 0 : config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
150 0 :
151 0 : let node = NodeInfo {
152 0 : config,
153 0 : aux: body.aux,
154 0 : allow_self_signed_compute: false,
155 0 : };
156 0 :
157 0 : Ok(node)
158 0 : }
159 0 : .map_err(crate::error::log_error)
160 0 : .instrument(info_span!("http", id = request_id))
161 0 : .await
162 0 : }
163 : }
164 :
165 : #[async_trait]
166 : impl super::Api for Api {
167 0 : #[tracing::instrument(skip_all)]
168 : async fn get_role_secret(
169 : &self,
170 : ctx: &mut RequestMonitoring,
171 : user_info: &ComputeUserInfo,
172 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
173 0 : let ep = &user_info.endpoint;
174 0 : let user = &user_info.user;
175 0 : if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
176 0 : return Ok(role_secret);
177 0 : }
178 0 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
179 0 : if let Some(project_id) = auth_info.project_id {
180 0 : self.caches.project_info.insert_role_secret(
181 0 : &project_id,
182 0 : ep,
183 0 : user,
184 0 : auth_info.secret.clone(),
185 0 : );
186 0 : self.caches.project_info.insert_allowed_ips(
187 0 : &project_id,
188 0 : ep,
189 0 : Arc::new(auth_info.allowed_ips),
190 0 : );
191 0 : }
192 : // When we just got a secret, we don't need to invalidate it.
193 0 : Ok(Cached::new_uncached(auth_info.secret))
194 0 : }
195 :
196 4 : async fn get_allowed_ips_and_secret(
197 4 : &self,
198 4 : ctx: &mut RequestMonitoring,
199 4 : user_info: &ComputeUserInfo,
200 4 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
201 4 : let ep = &user_info.endpoint;
202 4 : if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
203 0 : ALLOWED_IPS_BY_CACHE_OUTCOME
204 0 : .with_label_values(&["hit"])
205 0 : .inc();
206 0 : return Ok((allowed_ips, None));
207 4 : }
208 4 : ALLOWED_IPS_BY_CACHE_OUTCOME
209 4 : .with_label_values(&["miss"])
210 4 : .inc();
211 12 : let auth_info = self.do_get_auth_info(ctx, user_info).await?;
212 0 : let allowed_ips = Arc::new(auth_info.allowed_ips);
213 0 : let user = &user_info.user;
214 0 : if let Some(project_id) = auth_info.project_id {
215 0 : self.caches.project_info.insert_role_secret(
216 0 : &project_id,
217 0 : ep,
218 0 : user,
219 0 : auth_info.secret.clone(),
220 0 : );
221 0 : self.caches
222 0 : .project_info
223 0 : .insert_allowed_ips(&project_id, ep, allowed_ips.clone());
224 0 : }
225 0 : Ok((
226 0 : Cached::new_uncached(allowed_ips),
227 0 : Some(Cached::new_uncached(auth_info.secret)),
228 0 : ))
229 8 : }
230 :
231 0 : #[tracing::instrument(skip_all)]
232 : async fn wake_compute(
233 : &self,
234 : ctx: &mut RequestMonitoring,
235 : user_info: &ComputeUserInfo,
236 0 : ) -> Result<CachedNodeInfo, WakeComputeError> {
237 0 : let key = user_info.endpoint_cache_key();
238 :
239 : // Every time we do a wakeup http request, the compute node will stay up
240 : // for some time (highly depends on the console's scale-to-zero policy);
241 : // The connection info remains the same during that period of time,
242 : // which means that we might cache it to reduce the load and latency.
243 0 : if let Some(cached) = self.caches.node_info.get(&key) {
244 0 : info!(key = &*key, "found cached compute node info");
245 0 : return Ok(cached);
246 0 : }
247 :
248 0 : let permit = self.locks.get_wake_compute_permit(&key).await?;
249 :
250 : // after getting back a permit - it's possible the cache was filled
251 : // double check
252 0 : if permit.should_check_cache() {
253 0 : if let Some(cached) = self.caches.node_info.get(&key) {
254 0 : info!(key = &*key, "found cached compute node info");
255 0 : return Ok(cached);
256 0 : }
257 0 : }
258 :
259 0 : let node = self.do_wake_compute(ctx, user_info).await?;
260 0 : let (_, cached) = self.caches.node_info.insert(key.clone(), node);
261 0 : info!(key = &*key, "created a cache entry for compute node info");
262 :
263 0 : Ok(cached)
264 0 : }
265 : }
266 :
267 : /// Parse http response body, taking status code into account.
268 3 : async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
269 3 : response: http::Response,
270 3 : ) -> Result<T, ApiError> {
271 3 : let status = response.status();
272 3 : if status.is_success() {
273 : // We shouldn't log raw body because it may contain secrets.
274 1 : info!("request succeeded, processing the body");
275 1 : return Ok(response.json().await?);
276 2 : }
277 :
278 : // Don't throw an error here because it's not as important
279 : // as the fact that the request itself has failed.
280 2 : let body = response.json().await.unwrap_or_else(|e| {
281 2 : warn!("failed to parse error body: {e}");
282 2 : ConsoleError {
283 2 : error: "reason unclear (malformed error message)".into(),
284 2 : }
285 2 : });
286 2 :
287 2 : let text = body.error;
288 2 : error!("console responded with an error ({status}): {text}");
289 2 : Err(ApiError::Console { status, text })
290 3 : }
291 :
/// Split a `host:port` address into its host and numeric port.
///
/// The port is whatever follows the last `:` and must parse as a `u16`.
/// An IPv6 literal may be wrapped in exactly one pair of square brackets
/// (`[2001:db8::1]:5432`); the brackets are removed from the returned host.
///
/// Returns `None` when:
/// * there is no `:` separator or the port is not a valid `u16`;
/// * the host is empty (`":5432"`, `"[]:5432"`);
/// * the brackets are unbalanced or nested (`"[::1:5432"`, `"host]:80"`, `"[[::1]]:80"`).
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (host, port) = input.rsplit_once(':')?;
    let port = port.parse().ok()?;

    // Strip a single, matched pair of IPv6 brackets; an opening bracket
    // without its closing counterpart makes the address malformed.
    let host = match host.strip_prefix('[') {
        Some(rest) => rest.strip_suffix(']')?,
        None => host,
    };

    // Whatever is left must be a non-empty host free of stray brackets.
    if host.is_empty() || host.contains(|c: char| matches!(c, '[' | ']')) {
        return None;
    }

    Some((host, port))
}
297 :
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain IPv4 address is split at the only colon.
    #[test]
    fn test_parse_host_port_v4() {
        let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 5432);
    }

    /// A bracketed IPv6 literal is split at the last colon and loses its brackets.
    #[test]
    fn test_parse_host_port_v6() {
        let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(host, "2001:db8::1");
        assert_eq!(port, 5432);
    }

    /// A DNS name is returned untouched.
    #[test]
    fn test_parse_host_port_url() {
        let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
            .expect("failed to parse");
        assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
        assert_eq!(port, 5432);
    }

    /// An address without any `:` separator has no port and must be rejected.
    #[test]
    fn test_parse_host_port_missing_port() {
        assert!(parse_host_port("localhost").is_none());
    }

    /// The port must be a number that fits into `u16`.
    #[test]
    fn test_parse_host_port_invalid_port() {
        assert!(parse_host_port("localhost:").is_none());
        assert!(parse_host_port("localhost:port").is_none());
        assert!(parse_host_port("localhost:65536").is_none());
    }
}
|