Line data Source code
1 : use std::fmt::{self, Display};
2 :
3 : use measured::FixedCardinalityLabel;
4 : use serde::{Deserialize, Serialize};
5 : use smol_str::SmolStr;
6 :
7 : use crate::auth::IpPattern;
8 : use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
9 : use crate::proxy::retry::CouldRetry;
10 :
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
///
/// `http_status_code` is not part of the JSON payload (`#[serde(skip)]`), so
/// deserialization leaves it at `http::StatusCode::default()`; presumably the
/// caller fills it in from the HTTP response — confirm at the call site.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct ControlPlaneErrorMessage {
    /// Raw error text; `Display` falls back to it when no user-facing message is present.
    pub(crate) error: Box<str>,
    /// HTTP status of the response; used by `get_user_facing_message` to pick a generic message.
    #[serde(skip)]
    pub(crate) http_status_code: http::StatusCode,
    /// Structured error details, when the control plane supplied them.
    pub(crate) status: Option<Status>,
}
20 :
21 : impl ControlPlaneErrorMessage {
22 3 : pub(crate) fn get_reason(&self) -> Reason {
23 3 : self.status
24 3 : .as_ref()
25 3 : .and_then(|s| s.details.error_info.as_ref())
26 3 : .map_or(Reason::Unknown, |e| e.reason)
27 3 : }
28 :
29 0 : pub(crate) fn get_user_facing_message(&self) -> String {
30 : use super::errors::REQUEST_FAILED;
31 0 : self.status
32 0 : .as_ref()
33 0 : .and_then(|s| s.details.user_facing_message.as_ref())
34 0 : .map_or_else(|| {
35 : // Ask @neondatabase/control-plane for review before adding more.
36 0 : match self.http_status_code {
37 : http::StatusCode::NOT_FOUND => {
38 : // Status 404: failed to get a project-related resource.
39 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
40 : }
41 : http::StatusCode::NOT_ACCEPTABLE => {
42 : // Status 406: endpoint is disabled (we don't allow connections).
43 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
44 : }
45 : http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
46 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
47 0 : format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
48 : }
49 0 : _ => REQUEST_FAILED.to_owned(),
50 : }
51 0 : }, |m| m.message.clone().into())
52 0 : }
53 : }
54 :
55 : impl Display for ControlPlaneErrorMessage {
56 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 0 : let msg: &str = self
58 0 : .status
59 0 : .as_ref()
60 0 : .and_then(|s| s.details.user_facing_message.as_ref())
61 0 : .map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
62 0 : write!(f, "{msg}")
63 0 : }
64 : }
65 :
66 : impl CouldRetry for ControlPlaneErrorMessage {
67 6 : fn could_retry(&self) -> bool {
68 : // If the error message does not have a status,
69 : // the error is unknown and probably should not retry automatically
70 6 : let Some(status) = &self.status else {
71 2 : return false;
72 : };
73 :
74 : // retry if the retry info is set.
75 4 : if status.details.retry_info.is_some() {
76 4 : return true;
77 0 : }
78 :
79 : // if no retry info set, attempt to use the error code to guess the retry state.
80 0 : let reason = status
81 0 : .details
82 0 : .error_info
83 0 : .map_or(Reason::Unknown, |e| e.reason);
84 :
85 0 : reason.can_retry()
86 6 : }
87 : }
88 :
/// Structured status attached to a control plane error response.
#[derive(Debug, Deserialize, Clone)]
// `code` and `message` are deserialized but not read in this module; the allow
// silences the unused-field lint for them.
#[allow(dead_code)]
pub(crate) struct Status {
    pub(crate) code: Box<str>,
    pub(crate) message: Box<str>,
    /// Optional structured details (reason, retry hint, user-facing message).
    pub(crate) details: Details,
}
96 :
/// Structured details of a [`Status`]; every part is optional.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct Details {
    /// Machine-readable failure reason, consulted by `get_reason` and `could_retry`.
    pub(crate) error_info: Option<ErrorInfo>,
    /// When present, `could_retry` treats the error as retryable regardless of reason.
    pub(crate) retry_info: Option<RetryInfo>,
    /// Message preferred by `get_user_facing_message` and `Display` when present.
    pub(crate) user_facing_message: Option<UserFacingMessage>,
}
103 :
/// Error classification attached to [`Details`].
#[derive(Copy, Clone, Debug, Deserialize)]
pub(crate) struct ErrorInfo {
    /// Why the request failed; unrecognized values deserialize to [`Reason::Unknown`].
    pub(crate) reason: Reason,
    // Schema could also have `metadata` field, but it's not structured. Skip it for now.
}
109 :
/// Machine-readable error reason reported by the control plane.
///
/// Deserialized from the SCREAMING_SNAKE_CASE names given in each `rename`.
/// Any unrecognized name maps to [`Reason::Unknown`] (`#[serde(other)]`),
/// which is also the `Default`.
#[derive(Clone, Copy, Debug, Deserialize, Default)]
pub(crate) enum Reason {
    /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
    #[serde(rename = "ROLE_PROTECTED")]
    RoleProtected,
    /// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
    /// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
    /// access the requested resource.
    /// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
    #[serde(rename = "RESOURCE_NOT_FOUND")]
    ResourceNotFound,
    /// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested project.
    #[serde(rename = "PROJECT_NOT_FOUND")]
    ProjectNotFound,
    /// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested endpoint.
    #[serde(rename = "ENDPOINT_NOT_FOUND")]
    EndpointNotFound,
    /// EndpointDisabled indicates that the endpoint has been disabled and does not accept connections.
    #[serde(rename = "ENDPOINT_DISABLED")]
    EndpointDisabled,
    /// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested branch.
    #[serde(rename = "BRANCH_NOT_FOUND")]
    BranchNotFound,
    /// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong.
    #[serde(rename = "INVALID_EPHEMERAL_OPTIONS")]
    InvalidEphemeralEndpointOptions,
    /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
    #[serde(rename = "RATE_LIMIT_EXCEEDED")]
    RateLimitExceeded,
    /// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
    /// exceeded.
    #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
    NonDefaultBranchComputeTimeExceeded,
    /// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
    #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
    ActiveTimeQuotaExceeded,
    /// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
    #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
    ComputeTimeQuotaExceeded,
    /// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
    #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
    WrittenDataQuotaExceeded,
    /// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
    #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
    DataTransferQuotaExceeded,
    /// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
    #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
    LogicalSizeQuotaExceeded,
    /// ActiveEndpointsLimitExceeded indicates that the limit of concurrently active endpoints was exceeded.
    #[serde(rename = "ACTIVE_ENDPOINTS_LIMIT_EXCEEDED")]
    ActiveEndpointsLimitExceeded,
    /// RunningOperations indicates that the project already has some running operations
    /// and scheduling of new ones is prohibited.
    #[serde(rename = "RUNNING_OPERATIONS")]
    RunningOperations,
    /// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
    #[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
    ConcurrencyLimitReached,
    /// LockAlreadyTaken indicates that we attempted to take a lock that was already taken.
    #[serde(rename = "LOCK_ALREADY_TAKEN")]
    LockAlreadyTaken,
    /// EndpointIdle indicates that the endpoint cannot become active, because it's idle.
    #[serde(rename = "ENDPOINT_IDLE")]
    EndpointIdle,
    /// ProjectUnderMaintenance indicates that the project is currently undergoing maintenance,
    /// and thus cannot accept connections.
    #[serde(rename = "PROJECT_UNDER_MAINTENANCE")]
    ProjectUnderMaintenance,
    /// Any reason not listed above (catch-all via `#[serde(other)]`); also the default.
    #[default]
    #[serde(other)]
    Unknown,
}
185 :
186 : impl Reason {
187 0 : pub(crate) fn is_not_found(self) -> bool {
188 0 : matches!(
189 0 : self,
190 : Reason::ResourceNotFound
191 : | Reason::ProjectNotFound
192 : | Reason::EndpointNotFound
193 : | Reason::BranchNotFound
194 : )
195 0 : }
196 :
197 0 : pub(crate) fn can_retry(self) -> bool {
198 0 : match self {
199 : // do not retry role protected errors
200 : // not a transient error
201 0 : Reason::RoleProtected => false,
202 : // on retry, it will still not be found or valid
203 : Reason::ResourceNotFound
204 : | Reason::ProjectNotFound
205 : | Reason::EndpointNotFound
206 : | Reason::EndpointDisabled
207 : | Reason::BranchNotFound
208 0 : | Reason::InvalidEphemeralEndpointOptions => false,
209 : // we were asked to go away
210 : Reason::RateLimitExceeded
211 : | Reason::NonDefaultBranchComputeTimeExceeded
212 : | Reason::ActiveTimeQuotaExceeded
213 : | Reason::ComputeTimeQuotaExceeded
214 : | Reason::WrittenDataQuotaExceeded
215 : | Reason::DataTransferQuotaExceeded
216 : | Reason::LogicalSizeQuotaExceeded
217 0 : | Reason::ActiveEndpointsLimitExceeded => false,
218 : // transient error. control plane is currently busy
219 : // but might be ready soon
220 : Reason::RunningOperations
221 : | Reason::ConcurrencyLimitReached
222 : | Reason::LockAlreadyTaken
223 : | Reason::EndpointIdle
224 0 : | Reason::ProjectUnderMaintenance => true,
225 : // unknown error. better not retry it.
226 0 : Reason::Unknown => false,
227 : }
228 0 : }
229 : }
230 :
/// Retry hint from the control plane. In this module only its presence is
/// checked (`could_retry` retries whenever it is set).
#[derive(Copy, Clone, Debug, Deserialize)]
// `retry_delay_ms` is deserialized but not read in this module; the allow
// silences the unused-field lint.
#[allow(dead_code)]
pub(crate) struct RetryInfo {
    pub(crate) retry_delay_ms: u64,
}
236 :
/// Message the control plane supplied for display to the end user.
/// Returned verbatim by `get_user_facing_message` and `Display` when present.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct UserFacingMessage {
    pub(crate) message: Box<str>,
}
241 :
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/get_endpoint_access_control` API method.
///
/// NOTE(review): no `Debug` derive, presumably so `role_secret` cannot be
/// logged by accident — confirm before adding one.
#[derive(Deserialize)]
pub(crate) struct GetEndpointAccessControl {
    /// The role's secret; the only required field.
    pub(crate) role_secret: Box<str>,

    pub(crate) project_id: Option<ProjectIdInt>,
    pub(crate) account_id: Option<AccountIdInt>,

    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
    pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
    pub(crate) block_public_connections: Option<bool>,
    pub(crate) block_vpc_connections: Option<bool>,

    /// Falls back to `EndpointRateLimitConfig::default()` when absent.
    #[serde(default)]
    pub(crate) rate_limits: EndpointRateLimitConfig,
}
259 :
260 0 : #[derive(Copy, Clone, Deserialize, Default)]
261 : pub struct EndpointRateLimitConfig {
262 : pub connection_attempts: ConnectionAttemptsLimit,
263 : }
264 :
/// Connection-attempt limits per client protocol.
///
/// Each field is `None` when absent. Presumably `None` means no limit is
/// configured for that protocol — confirm with the rate limiter.
#[derive(Copy, Clone, Deserialize, Default)]
pub struct ConnectionAttemptsLimit {
    /// Limit for TCP connections.
    pub tcp: Option<LeakyBucketSetting>,
    /// Limit for WebSocket connections.
    pub ws: Option<LeakyBucketSetting>,
    /// Limit for HTTP connections.
    pub http: Option<LeakyBucketSetting>,
}
271 :
/// Parameters of a leaky-bucket rate limit.
#[derive(Copy, Clone, Deserialize)]
pub struct LeakyBucketSetting {
    /// Presumably the sustained rate in requests per second — confirm units.
    pub rps: f64,
    /// Presumably the bucket capacity, i.e. the allowed burst size — confirm.
    pub burst: f64,
}
277 :
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub(crate) struct WakeCompute {
    /// The compute node's `host:port`.
    pub(crate) address: Box<str>,
    /// Optional server name. NOTE(review): presumably the TLS server name to
    /// use when connecting — confirm at the call site.
    pub(crate) server_name: Option<String>,
    /// Labels for metrics.
    pub(crate) aux: MetricsAuxInfo,
}
286 :
/// Async response which concludes the console redirect auth flow.
/// Also known as `kickResponse` in the console.
///
/// NOTE(review): `session_id` borrows from the input, and serde_json cannot
/// deserialize a borrowed `&str` from a string containing escape sequences.
/// This is fine for plain IDs — confirm the console never escapes this value.
#[derive(Debug, Deserialize)]
pub(crate) struct KickSession<'a> {
    /// Session ID is assigned by the proxy.
    pub(crate) session_id: &'a str,

    /// Compute node connection params.
    /// On the wire this is wrapped as `{"Success": {...}}`, which `parse_db_info` unwraps.
    #[serde(deserialize_with = "KickSession::parse_db_info")]
    pub(crate) result: DatabaseInfo,
}
298 :
299 : impl KickSession<'_> {
300 1 : fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
301 1 : where
302 1 : D: serde::Deserializer<'de>,
303 : {
304 0 : #[derive(Deserialize)]
305 : enum Wrapper {
306 : // Currently, console only reports `Success`.
307 : // `Failure(String)` used to be here... RIP.
308 : Success(DatabaseInfo),
309 : }
310 :
311 1 : Wrapper::deserialize(des).map(|x| match x {
312 1 : Wrapper::Success(info) => info,
313 1 : })
314 1 : }
315 : }
316 :
/// Compute node connection params.
///
/// `Debug` is implemented by hand so that the password is never printed.
#[derive(Deserialize)]
pub(crate) struct DatabaseInfo {
    pub(crate) host: Box<str>,
    pub(crate) port: u16,
    pub(crate) dbname: Box<str>,
    pub(crate) user: Box<str>,
    /// Console always provides a password, but it might
    /// be inconvenient for debug with local PG instance.
    pub(crate) password: Option<Box<str>>,
    /// Labels for metrics.
    pub(crate) aux: MetricsAuxInfo,
    /// `None` when absent from the payload.
    #[serde(default)]
    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
    /// `None` when absent from the payload.
    #[serde(default)]
    pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
    /// `None` when absent from the payload.
    #[serde(default)]
    pub(crate) public_access_allowed: Option<bool>,
}
335 :
336 : // Manually implement debug to omit sensitive info.
337 : impl fmt::Debug for DatabaseInfo {
338 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339 0 : f.debug_struct("DatabaseInfo")
340 0 : .field("host", &self.host)
341 0 : .field("port", &self.port)
342 0 : .field("dbname", &self.dbname)
343 0 : .field("user", &self.user)
344 0 : .field("allowed_ips", &self.allowed_ips)
345 0 : .field("allowed_vpc_endpoint_ids", &self.allowed_vpc_endpoint_ids)
346 0 : .finish_non_exhaustive()
347 0 : }
348 : }
349 :
/// Various labels for prometheus metrics.
/// Also known as `ProxyMetricsAuxInfo` in the console.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct MetricsAuxInfo {
    /// Interned endpoint ID.
    pub(crate) endpoint_id: EndpointIdInt,
    /// Interned project ID.
    pub(crate) project_id: ProjectIdInt,
    /// Interned branch ID.
    pub(crate) branch_id: BranchIdInt,
    // note: we don't use interned strings for compute IDs.
    // they churn too quickly and we have no way to clean up interned strings.
    pub(crate) compute_id: SmolStr,
    /// Falls back to [`ColdStartInfo::Unknown`] (its `Default`) when absent.
    #[serde(default)]
    pub(crate) cold_start_info: ColdStartInfo,
}
363 :
/// How a compute connection was obtained, used as a metrics label.
///
/// Serialized in snake_case. `VmPoolHit`/`VmPoolMiss` are spelled
/// `pool_hit`/`pool_miss` both in serde and in metric labels.
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
#[serde(rename_all = "snake_case")]
pub enum ColdStartInfo {
    /// No cold start information available (the default).
    #[default]
    Unknown,
    /// Compute was already running
    Warm,
    #[serde(rename = "pool_hit")]
    #[label(rename = "pool_hit")]
    /// Compute was not running but there was an available VM
    VmPoolHit,
    #[serde(rename = "pool_miss")]
    #[label(rename = "pool_miss")]
    /// Compute was not running and there were no VMs available
    VmPoolMiss,

    // not provided by control plane
    /// Connection available from HTTP pool
    HttpPoolHit,
    /// Cached connection info
    WarmCached,
}
386 :
387 : impl ColdStartInfo {
388 0 : pub(crate) fn as_str(self) -> &'static str {
389 0 : match self {
390 0 : ColdStartInfo::Unknown => "unknown",
391 0 : ColdStartInfo::Warm => "warm",
392 0 : ColdStartInfo::VmPoolHit => "pool_hit",
393 0 : ColdStartInfo::VmPoolMiss => "pool_miss",
394 0 : ColdStartInfo::HttpPoolHit => "http_pool_hit",
395 0 : ColdStartInfo::WarmCached => "warm_cached",
396 : }
397 0 : }
398 : }
399 :
/// Control plane response listing the JWKS sources configured for an endpoint.
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
    pub jwks: Vec<JwksSettings>,
}
404 :
/// One JWKS source configured for an endpoint.
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
    pub id: String,
    /// URL the key set is fetched from.
    pub jwks_url: url::Url,
    /// Wire name is `provider_name`. The leading underscore silences the
    /// unused-field lint, since the value is not read in this module.
    #[serde(rename = "provider_name")]
    pub _provider_name: String,
    /// Presumably the expected JWT `aud` claim, if any — confirm.
    pub jwt_audience: Option<String>,
    /// Presumably the roles this key set may authenticate — confirm.
    pub role_names: Vec<RoleNameInt>,
}
414 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Minimal valid `MetricsAuxInfo` payload shared by the parsing tests.
    fn dummy_aux() -> serde_json::Value {
        json!({
            "endpoint_id": "endpoint",
            "project_id": "project",
            "branch_id": "branch",
            "compute_id": "compute",
            "cold_start_info": "unknown",
        })
    }

    #[test]
    fn parse_kick_session() -> anyhow::Result<()> {
        // This is what the console's kickResponse looks like.
        let json = json!({
            "session_id": "deadbeef",
            "result": {
                "Success": {
                    "host": "localhost",
                    "port": 5432,
                    "dbname": "postgres",
                    "user": "john_doe",
                    "password": "password",
                    "aux": dummy_aux(),
                }
            }
        });
        serde_json::from_str::<KickSession<'_>>(&json.to_string())?;

        Ok(())
    }

    #[test]
    fn parse_db_info() -> anyhow::Result<()> {
        // with password
        serde_json::from_value::<DatabaseInfo>(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
        }))?;

        // without password
        serde_json::from_value::<DatabaseInfo>(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "aux": dummy_aux(),
        }))?;

        // new field (forward compatibility)
        serde_json::from_value::<DatabaseInfo>(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "project": "hello_world",
            "N.E.W": "forward compatibility check",
            "aux": dummy_aux(),
        }))?;

        // with allowed_ips
        let dbinfo = serde_json::from_value::<DatabaseInfo>(json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
            "allowed_ips": ["127.0.0.1"],
        }))?;

        assert_eq!(
            dbinfo.allowed_ips,
            Some(vec![IpPattern::Single("127.0.0.1".parse()?)])
        );

        Ok(())
    }

    #[test]
    fn parse_wake_compute() -> anyhow::Result<()> {
        let json = json!({
            "address": "0.0.0.0",
            "aux": dummy_aux(),
        });
        serde_json::from_str::<WakeCompute>(&json.to_string())?;
        Ok(())
    }

    #[test]
    fn parse_get_role_secret() -> anyhow::Result<()> {
        // Empty `allowed_ips` and `allowed_vpc_endpoint_ids` field.
        let json = json!({
            "role_secret": "secret",
        });
        serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
        let json = json!({
            "role_secret": "secret",
            "allowed_ips": ["8.8.8.8"],
        });
        serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
        let json = json!({
            "role_secret": "secret",
            "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
        });
        serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
        let json = json!({
            "role_secret": "secret",
            "allowed_ips": ["8.8.8.8"],
            "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
        });
        serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
        let json = json!({
            "role_secret": "secret",
            "allowed_ips": ["8.8.8.8"],
            "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
            "project_id": "project",
        });
        serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;

        Ok(())
    }

    #[test]
    fn control_plane_error_could_retry() -> anyhow::Result<()> {
        /// Builds an error message with the given (possibly null) `status` object.
        fn with_status(status: serde_json::Value) -> anyhow::Result<ControlPlaneErrorMessage> {
            Ok(serde_json::from_value(json!({
                "error": "error",
                "status": status,
            }))?)
        }
        /// Builds an error message carrying only an `error_info.reason`.
        fn with_reason(reason: &str) -> anyhow::Result<ControlPlaneErrorMessage> {
            with_status(json!({
                "code": "code",
                "message": "message",
                "details": { "error_info": { "reason": reason } },
            }))
        }

        // No structured status: unknown error, never retried.
        assert!(!with_status(serde_json::Value::Null)?.could_retry());

        // Explicit retry info wins even over a non-retryable reason.
        let hinted = with_status(json!({
            "code": "code",
            "message": "message",
            "details": {
                "error_info": { "reason": "ROLE_PROTECTED" },
                "retry_info": { "retry_delay_ms": 100 },
            },
        }))?;
        assert!(hinted.could_retry());

        // Without retry info, the reason decides.
        assert!(with_reason("RUNNING_OPERATIONS")?.could_retry());
        let not_found = with_reason("ENDPOINT_NOT_FOUND")?;
        assert!(not_found.get_reason().is_not_found());
        assert!(!not_found.could_retry());

        // Unrecognized reasons map to `Unknown` (forward compatibility) and are not retried.
        let unknown = with_reason("SOME_FUTURE_REASON")?;
        assert!(matches!(unknown.get_reason(), Reason::Unknown));
        assert!(!unknown.could_retry());

        Ok(())
    }

    #[test]
    fn control_plane_error_messages() -> anyhow::Result<()> {
        // A user-facing message from the control plane is preferred everywhere.
        let with_message: ControlPlaneErrorMessage = serde_json::from_value(json!({
            "error": "internal details",
            "status": {
                "code": "code",
                "message": "message",
                "details": {
                    "user_facing_message": { "message": "shown to user" },
                },
            },
        }))?;
        assert_eq!(with_message.get_user_facing_message(), "shown to user");
        assert_eq!(with_message.to_string(), "shown to user");

        // Without one, `Display` falls back to the raw error text.
        let without_status: ControlPlaneErrorMessage =
            serde_json::from_value(json!({ "error": "internal details" }))?;
        assert_eq!(without_status.to_string(), "internal details");

        Ok(())
    }

    #[test]
    fn cold_start_info_as_str_matches_serde() -> anyhow::Result<()> {
        // `as_str` is hand-written; keep it in sync with the serde representation.
        for info in [
            ColdStartInfo::Unknown,
            ColdStartInfo::Warm,
            ColdStartInfo::VmPoolHit,
            ColdStartInfo::VmPoolMiss,
            ColdStartInfo::HttpPoolHit,
            ColdStartInfo::WarmCached,
        ] {
            assert_eq!(serde_json::to_value(info)?, json!(info.as_str()));
        }
        Ok(())
    }
}
|