Line data Source code
1 : use std::fmt::{self, Display};
2 :
3 : use measured::FixedCardinalityLabel;
4 : use serde::{Deserialize, Serialize};
5 :
6 : use crate::auth::IpPattern;
7 : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
8 : use crate::proxy::retry::CouldRetry;
9 :
10 : /// Generic error response with human-readable description.
11 : /// Note that we can't always present it to user as is.
12 0 : #[derive(Debug, Deserialize, Clone)]
13 : pub(crate) struct ControlPlaneErrorMessage {
14 : pub(crate) error: Box<str>,
15 : #[serde(skip)]
16 : pub(crate) http_status_code: http::StatusCode,
17 : pub(crate) status: Option<Status>,
18 : }
19 :
20 : impl ControlPlaneErrorMessage {
21 3 : pub(crate) fn get_reason(&self) -> Reason {
22 3 : self.status
23 3 : .as_ref()
24 3 : .and_then(|s| s.details.error_info.as_ref())
25 3 : .map_or(Reason::Unknown, |e| e.reason)
26 3 : }
27 :
28 0 : pub(crate) fn get_user_facing_message(&self) -> String {
29 : use super::errors::REQUEST_FAILED;
30 0 : self.status
31 0 : .as_ref()
32 0 : .and_then(|s| s.details.user_facing_message.as_ref())
33 0 : .map_or_else(|| {
34 0 : // Ask @neondatabase/control-plane for review before adding more.
35 0 : match self.http_status_code {
36 : http::StatusCode::NOT_FOUND => {
37 : // Status 404: failed to get a project-related resource.
38 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
39 : }
40 : http::StatusCode::NOT_ACCEPTABLE => {
41 : // Status 406: endpoint is disabled (we don't allow connections).
42 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
43 : }
44 : http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
45 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
46 0 : format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
47 : }
48 0 : _ => REQUEST_FAILED.to_owned(),
49 : }
50 0 : }, |m| m.message.clone().into())
51 0 : }
52 : }
53 :
54 : impl Display for ControlPlaneErrorMessage {
55 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 0 : let msg: &str = self
57 0 : .status
58 0 : .as_ref()
59 0 : .and_then(|s| s.details.user_facing_message.as_ref())
60 0 : .map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
61 0 : write!(f, "{msg}")
62 0 : }
63 : }
64 :
65 : impl CouldRetry for ControlPlaneErrorMessage {
66 6 : fn could_retry(&self) -> bool {
67 : // If the error message does not have a status,
68 : // the error is unknown and probably should not retry automatically
69 6 : let Some(status) = &self.status else {
70 2 : return false;
71 : };
72 :
73 : // retry if the retry info is set.
74 4 : if status.details.retry_info.is_some() {
75 4 : return true;
76 0 : }
77 0 :
78 0 : // if no retry info set, attempt to use the error code to guess the retry state.
79 0 : let reason = status
80 0 : .details
81 0 : .error_info
82 0 : .map_or(Reason::Unknown, |e| e.reason);
83 0 :
84 0 : reason.can_retry()
85 6 : }
86 : }
87 :
88 0 : #[derive(Debug, Deserialize, Clone)]
89 : #[allow(dead_code)]
90 : pub(crate) struct Status {
91 : pub(crate) code: Box<str>,
92 : pub(crate) message: Box<str>,
93 : pub(crate) details: Details,
94 : }
95 :
96 0 : #[derive(Debug, Deserialize, Clone)]
97 : pub(crate) struct Details {
98 : pub(crate) error_info: Option<ErrorInfo>,
99 : pub(crate) retry_info: Option<RetryInfo>,
100 : pub(crate) user_facing_message: Option<UserFacingMessage>,
101 : }
102 :
103 0 : #[derive(Copy, Clone, Debug, Deserialize)]
104 : pub(crate) struct ErrorInfo {
105 : pub(crate) reason: Reason,
106 : // Schema could also have `metadata` field, but it's not structured. Skip it for now.
107 : }
108 :
109 0 : #[derive(Clone, Copy, Debug, Deserialize, Default)]
110 : pub(crate) enum Reason {
111 : /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
112 : #[serde(rename = "ROLE_PROTECTED")]
113 : RoleProtected,
114 : /// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
115 : /// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
116 : /// access the requested resource.
117 : /// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
118 : #[serde(rename = "RESOURCE_NOT_FOUND")]
119 : ResourceNotFound,
120 : /// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
121 : /// or that the subject doesn't have enough permissions to access the requested project.
122 : #[serde(rename = "PROJECT_NOT_FOUND")]
123 : ProjectNotFound,
124 : /// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
125 : /// or that the subject doesn't have enough permissions to access the requested endpoint.
126 : #[serde(rename = "ENDPOINT_NOT_FOUND")]
127 : EndpointNotFound,
128 : /// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
129 : /// or that the subject doesn't have enough permissions to access the requested branch.
130 : #[serde(rename = "BRANCH_NOT_FOUND")]
131 : BranchNotFound,
132 : /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
133 : #[serde(rename = "RATE_LIMIT_EXCEEDED")]
134 : RateLimitExceeded,
135 : /// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
136 : /// exceeded.
137 : #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
138 : NonDefaultBranchComputeTimeExceeded,
139 : /// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
140 : #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
141 : ActiveTimeQuotaExceeded,
142 : /// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
143 : #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
144 : ComputeTimeQuotaExceeded,
145 : /// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
146 : #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
147 : WrittenDataQuotaExceeded,
148 : /// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
149 : #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
150 : DataTransferQuotaExceeded,
151 : /// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
152 : #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
153 : LogicalSizeQuotaExceeded,
154 : /// RunningOperations indicates that the project already has some running operations
155 : /// and scheduling of new ones is prohibited.
156 : #[serde(rename = "RUNNING_OPERATIONS")]
157 : RunningOperations,
158 : /// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
159 : #[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
160 : ConcurrencyLimitReached,
161 : /// LockAlreadyTaken indicates that the we attempted to take a lock that was already taken.
162 : #[serde(rename = "LOCK_ALREADY_TAKEN")]
163 : LockAlreadyTaken,
164 : /// ActiveEndpointsLimitExceeded indicates that the limit of concurrently active endpoints was exceeded.
165 : #[serde(rename = "ACTIVE_ENDPOINTS_LIMIT_EXCEEDED")]
166 : ActiveEndpointsLimitExceeded,
167 : #[default]
168 : #[serde(other)]
169 : Unknown,
170 : }
171 :
172 : impl Reason {
173 0 : pub(crate) fn is_not_found(self) -> bool {
174 0 : matches!(
175 0 : self,
176 : Reason::ResourceNotFound
177 : | Reason::ProjectNotFound
178 : | Reason::EndpointNotFound
179 : | Reason::BranchNotFound
180 : )
181 0 : }
182 :
183 0 : pub(crate) fn can_retry(self) -> bool {
184 0 : match self {
185 : // do not retry role protected errors
186 : // not a transitive error
187 0 : Reason::RoleProtected => false,
188 : // on retry, it will still not be found
189 : Reason::ResourceNotFound
190 : | Reason::ProjectNotFound
191 : | Reason::EndpointNotFound
192 0 : | Reason::BranchNotFound => false,
193 : // we were asked to go away
194 : Reason::RateLimitExceeded
195 : | Reason::NonDefaultBranchComputeTimeExceeded
196 : | Reason::ActiveTimeQuotaExceeded
197 : | Reason::ComputeTimeQuotaExceeded
198 : | Reason::WrittenDataQuotaExceeded
199 : | Reason::DataTransferQuotaExceeded
200 : | Reason::LogicalSizeQuotaExceeded
201 0 : | Reason::ActiveEndpointsLimitExceeded => false,
202 : // transitive error. control plane is currently busy
203 : // but might be ready soon
204 : Reason::RunningOperations
205 : | Reason::ConcurrencyLimitReached
206 0 : | Reason::LockAlreadyTaken => true,
207 : // unknown error. better not retry it.
208 0 : Reason::Unknown => false,
209 : }
210 0 : }
211 : }
212 :
213 0 : #[derive(Copy, Clone, Debug, Deserialize)]
214 : #[allow(dead_code)]
215 : pub(crate) struct RetryInfo {
216 : pub(crate) retry_delay_ms: u64,
217 : }
218 :
219 0 : #[derive(Debug, Deserialize, Clone)]
220 : pub(crate) struct UserFacingMessage {
221 : pub(crate) message: Box<str>,
222 : }
223 :
224 : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
225 : /// Returned by the `/proxy_get_role_secret` API method.
226 9 : #[derive(Deserialize)]
227 : pub(crate) struct GetRoleSecret {
228 : pub(crate) role_secret: Box<str>,
229 : pub(crate) allowed_ips: Option<Vec<IpPattern>>,
230 : pub(crate) project_id: Option<ProjectIdInt>,
231 : }
232 :
233 : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
234 : /// Returned by the `/get_endpoint_access_control` API method.
235 0 : #[derive(Deserialize)]
236 : pub(crate) struct GetEndpointAccessControl {
237 : pub(crate) role_secret: Box<str>,
238 : pub(crate) allowed_ips: Option<Vec<IpPattern>>,
239 : pub(crate) project_id: Option<ProjectIdInt>,
240 : pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
241 : }
242 :
243 : // Manually implement debug to omit sensitive info.
244 : impl fmt::Debug for GetRoleSecret {
245 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 0 : f.debug_struct("GetRoleSecret").finish_non_exhaustive()
247 0 : }
248 : }
249 :
250 : /// Response which holds compute node's `host:port` pair.
251 : /// Returned by the `/proxy_wake_compute` API method.
252 3 : #[derive(Debug, Deserialize)]
253 : pub(crate) struct WakeCompute {
254 : pub(crate) address: Box<str>,
255 : pub(crate) aux: MetricsAuxInfo,
256 : }
257 :
258 : /// Async response which concludes the console redirect auth flow.
259 : /// Also known as `kickResponse` in the console.
260 4 : #[derive(Debug, Deserialize)]
261 : pub(crate) struct KickSession<'a> {
262 : /// Session ID is assigned by the proxy.
263 : pub(crate) session_id: &'a str,
264 :
265 : /// Compute node connection params.
266 : #[serde(deserialize_with = "KickSession::parse_db_info")]
267 : pub(crate) result: DatabaseInfo,
268 : }
269 :
270 : impl KickSession<'_> {
271 1 : fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
272 1 : where
273 1 : D: serde::Deserializer<'de>,
274 1 : {
275 2 : #[derive(Deserialize)]
276 : enum Wrapper {
277 : // Currently, console only reports `Success`.
278 : // `Failure(String)` used to be here... RIP.
279 : Success(DatabaseInfo),
280 : }
281 :
282 1 : Wrapper::deserialize(des).map(|x| match x {
283 1 : Wrapper::Success(info) => info,
284 1 : })
285 1 : }
286 : }
287 :
288 : /// Compute node connection params.
289 36 : #[derive(Deserialize)]
290 : pub(crate) struct DatabaseInfo {
291 : pub(crate) host: Box<str>,
292 : pub(crate) port: u16,
293 : pub(crate) dbname: Box<str>,
294 : pub(crate) user: Box<str>,
295 : /// Console always provides a password, but it might
296 : /// be inconvenient for debug with local PG instance.
297 : pub(crate) password: Option<Box<str>>,
298 : pub(crate) aux: MetricsAuxInfo,
299 : #[serde(default)]
300 : pub(crate) allowed_ips: Option<Vec<IpPattern>>,
301 : }
302 :
303 : // Manually implement debug to omit sensitive info.
304 : impl fmt::Debug for DatabaseInfo {
305 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 0 : f.debug_struct("DatabaseInfo")
307 0 : .field("host", &self.host)
308 0 : .field("port", &self.port)
309 0 : .field("dbname", &self.dbname)
310 0 : .field("user", &self.user)
311 0 : .field("allowed_ips", &self.allowed_ips)
312 0 : .finish_non_exhaustive()
313 0 : }
314 : }
315 :
316 : /// Various labels for prometheus metrics.
317 : /// Also known as `ProxyMetricsAuxInfo` in the console.
318 30 : #[derive(Debug, Deserialize, Clone)]
319 : pub(crate) struct MetricsAuxInfo {
320 : pub(crate) endpoint_id: EndpointIdInt,
321 : pub(crate) project_id: ProjectIdInt,
322 : pub(crate) branch_id: BranchIdInt,
323 : #[serde(default)]
324 : pub(crate) cold_start_info: ColdStartInfo,
325 : }
326 :
327 12 : #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
328 : #[serde(rename_all = "snake_case")]
329 : pub enum ColdStartInfo {
330 : #[default]
331 : Unknown,
332 : /// Compute was already running
333 : Warm,
334 : #[serde(rename = "pool_hit")]
335 : #[label(rename = "pool_hit")]
336 : /// Compute was not running but there was an available VM
337 : VmPoolHit,
338 : #[serde(rename = "pool_miss")]
339 : #[label(rename = "pool_miss")]
340 : /// Compute was not running and there were no VMs available
341 : VmPoolMiss,
342 :
343 : // not provided by control plane
344 : /// Connection available from HTTP pool
345 : HttpPoolHit,
346 : /// Cached connection info
347 : WarmCached,
348 : }
349 :
350 : impl ColdStartInfo {
351 0 : pub(crate) fn as_str(self) -> &'static str {
352 0 : match self {
353 0 : ColdStartInfo::Unknown => "unknown",
354 0 : ColdStartInfo::Warm => "warm",
355 0 : ColdStartInfo::VmPoolHit => "pool_hit",
356 0 : ColdStartInfo::VmPoolMiss => "pool_miss",
357 0 : ColdStartInfo::HttpPoolHit => "http_pool_hit",
358 0 : ColdStartInfo::WarmCached => "warm_cached",
359 : }
360 0 : }
361 : }
362 :
363 0 : #[derive(Debug, Deserialize, Clone)]
364 : pub struct EndpointJwksResponse {
365 : pub jwks: Vec<JwksSettings>,
366 : }
367 :
368 0 : #[derive(Debug, Deserialize, Clone)]
369 : pub struct JwksSettings {
370 : pub id: String,
371 : pub jwks_url: url::Url,
372 : pub provider_name: String,
373 : pub jwt_audience: Option<String>,
374 : pub role_names: Vec<RoleNameInt>,
375 : }
376 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a valid `MetricsAuxInfo` payload shared by the tests below.
    fn dummy_aux() -> serde_json::Value {
        json!({
            "endpoint_id": "endpoint",
            "project_id": "project",
            "branch_id": "branch",
            "cold_start_info": "unknown",
        })
    }

    /// The console's `kickResponse` shape must deserialize into `KickSession`.
    #[test]
    fn parse_kick_session() -> anyhow::Result<()> {
        // This is what the console's kickResponse looks like.
        let payload = json!({
            "session_id": "deadbeef",
            "result": {
                "Success": {
                    "host": "localhost",
                    "port": 5432,
                    "dbname": "postgres",
                    "user": "john_doe",
                    "password": "password",
                    "aux": dummy_aux(),
                }
            }
        })
        .to_string();

        // `session_id` borrows from the input, so parse from the string form.
        let _session: KickSession<'_> = serde_json::from_str(&payload)?;

        Ok(())
    }

    /// `DatabaseInfo` must accept optional and unknown fields.
    #[test]
    fn parse_db_info() -> anyhow::Result<()> {
        // with password
        let _with_password: DatabaseInfo = serde_json::from_value(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
        }))?;

        // without password
        let _without_password: DatabaseInfo = serde_json::from_value(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "aux": dummy_aux(),
        }))?;

        // new field (forward compatibility)
        let _with_unknown_fields: DatabaseInfo = serde_json::from_value(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "project": "hello_world",
            "N.E.W": "forward compatibility check",
            "aux": dummy_aux(),
        }))?;

        // with allowed_ips
        let with_allowed_ips: DatabaseInfo = serde_json::from_value(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
            "allowed_ips": ["127.0.0.1"],
        }))?;

        let expected = Some(vec![IpPattern::Single("127.0.0.1".parse()?)]);
        assert_eq!(with_allowed_ips.allowed_ips, expected);

        Ok(())
    }

    /// A minimal wake-compute payload must deserialize.
    #[test]
    fn parse_wake_compute() -> anyhow::Result<()> {
        let payload = json!({
            "address": "0.0.0.0",
            "aux": dummy_aux(),
        })
        .to_string();
        let _woken: WakeCompute = serde_json::from_str(&payload)?;
        Ok(())
    }

    /// `GetRoleSecret` must parse with and without its optional fields.
    #[test]
    fn parse_get_role_secret() -> anyhow::Result<()> {
        let payloads = [
            // Empty `allowed_ips` field.
            json!({
                "role_secret": "secret",
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
                "project_id": "project",
            }),
        ];

        for payload in &payloads {
            serde_json::from_str::<GetRoleSecret>(&payload.to_string())?;
        }

        Ok(())
    }
}
|