Line data Source code
1 : use std::fmt::{self, Display};
2 : use std::time::Duration;
3 :
4 : use measured::FixedCardinalityLabel;
5 : use serde::{Deserialize, Serialize};
6 : use smol_str::SmolStr;
7 : use tokio::time::Instant;
8 :
9 : use crate::auth::IpPattern;
10 : use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
11 : use crate::proxy::retry::CouldRetry;
12 :
13 : /// Generic error response with human-readable description.
14 : /// Note that we can't always present it to user as is.
15 0 : #[derive(Debug, Deserialize, Clone)]
16 : pub(crate) struct ControlPlaneErrorMessage {
17 : pub(crate) error: Box<str>,
18 : #[serde(skip)]
19 : pub(crate) http_status_code: http::StatusCode,
20 : pub(crate) status: Option<Status>,
21 : }
22 :
23 : impl ControlPlaneErrorMessage {
24 6 : pub(crate) fn get_reason(&self) -> Reason {
25 6 : self.status
26 6 : .as_ref()
27 6 : .and_then(|s| s.details.error_info.as_ref())
28 6 : .map_or(Reason::Unknown, |e| e.reason)
29 6 : }
30 :
31 0 : pub(crate) fn get_user_facing_message(&self) -> String {
32 : use super::errors::REQUEST_FAILED;
33 0 : self.status
34 0 : .as_ref()
35 0 : .and_then(|s| s.details.user_facing_message.as_ref())
36 0 : .map_or_else(|| {
37 : // Ask @neondatabase/control-plane for review before adding more.
38 0 : match self.http_status_code {
39 : http::StatusCode::NOT_FOUND => {
40 : // Status 404: failed to get a project-related resource.
41 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
42 : }
43 : http::StatusCode::NOT_ACCEPTABLE => {
44 : // Status 406: endpoint is disabled (we don't allow connections).
45 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
46 : }
47 : http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
48 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
49 0 : format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
50 : }
51 0 : _ => REQUEST_FAILED.to_owned(),
52 : }
53 0 : }, |m| m.message.clone().into())
54 0 : }
55 : }
56 :
57 : impl Display for ControlPlaneErrorMessage {
58 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 0 : let msg: &str = self
60 0 : .status
61 0 : .as_ref()
62 0 : .and_then(|s| s.details.user_facing_message.as_ref())
63 0 : .map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
64 0 : write!(f, "{msg}")
65 0 : }
66 : }
67 :
68 : impl CouldRetry for ControlPlaneErrorMessage {
69 6 : fn could_retry(&self) -> bool {
70 : // If the error message does not have a status,
71 : // the error is unknown and probably should not retry automatically
72 6 : let Some(status) = &self.status else {
73 2 : return false;
74 : };
75 :
76 : // retry if the retry info is set.
77 4 : if status.details.retry_info.is_some() {
78 4 : return true;
79 0 : }
80 :
81 : // if no retry info set, attempt to use the error code to guess the retry state.
82 0 : let reason = status
83 0 : .details
84 0 : .error_info
85 0 : .map_or(Reason::Unknown, |e| e.reason);
86 :
87 0 : reason.can_retry()
88 6 : }
89 : }
90 :
91 0 : #[derive(Debug, Deserialize, Clone)]
92 : #[allow(dead_code)]
93 : pub(crate) struct Status {
94 : pub(crate) code: Box<str>,
95 : pub(crate) message: Box<str>,
96 : pub(crate) details: Details,
97 : }
98 :
99 0 : #[derive(Debug, Deserialize, Clone)]
100 : pub(crate) struct Details {
101 : pub(crate) error_info: Option<ErrorInfo>,
102 : pub(crate) retry_info: Option<RetryInfo>,
103 : pub(crate) user_facing_message: Option<UserFacingMessage>,
104 : }
105 :
106 0 : #[derive(Copy, Clone, Debug, Deserialize)]
107 : pub(crate) struct ErrorInfo {
108 : pub(crate) reason: Reason,
109 : // Schema could also have `metadata` field, but it's not structured. Skip it for now.
110 : }
111 :
112 0 : #[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
113 : pub(crate) enum Reason {
114 : /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
115 : #[serde(rename = "ROLE_PROTECTED")]
116 : RoleProtected,
117 : /// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
118 : /// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
119 : /// access the requested resource.
120 : /// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
121 : #[serde(rename = "RESOURCE_NOT_FOUND")]
122 : ResourceNotFound,
123 : /// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
124 : /// or that the subject doesn't have enough permissions to access the requested project.
125 : #[serde(rename = "PROJECT_NOT_FOUND")]
126 : ProjectNotFound,
127 : /// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
128 : /// or that the subject doesn't have enough permissions to access the requested endpoint.
129 : #[serde(rename = "ENDPOINT_NOT_FOUND")]
130 : EndpointNotFound,
131 : /// EndpointDisabled indicates that the endpoint has been disabled and does not accept connections.
132 : #[serde(rename = "ENDPOINT_DISABLED")]
133 : EndpointDisabled,
134 : /// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
135 : /// or that the subject doesn't have enough permissions to access the requested branch.
136 : #[serde(rename = "BRANCH_NOT_FOUND")]
137 : BranchNotFound,
138 : /// WrongLsnOrTimestamp indicates that the specified LSN or timestamp are wrong.
139 : #[serde(rename = "WRONG_LSN_OR_TIMESTAMP")]
140 : WrongLsnOrTimestamp,
141 : /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
142 : #[serde(rename = "RATE_LIMIT_EXCEEDED")]
143 : RateLimitExceeded,
144 : /// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
145 : /// exceeded.
146 : #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
147 : NonDefaultBranchComputeTimeExceeded,
148 : /// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
149 : #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
150 : ActiveTimeQuotaExceeded,
151 : /// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
152 : #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
153 : ComputeTimeQuotaExceeded,
154 : /// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
155 : #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
156 : WrittenDataQuotaExceeded,
157 : /// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
158 : #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
159 : DataTransferQuotaExceeded,
160 : /// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
161 : #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
162 : LogicalSizeQuotaExceeded,
163 : /// ActiveEndpointsLimitExceeded indicates that the limit of concurrently active endpoints was exceeded.
164 : #[serde(rename = "ACTIVE_ENDPOINTS_LIMIT_EXCEEDED")]
165 : ActiveEndpointsLimitExceeded,
166 : /// RunningOperations indicates that the project already has some running operations
167 : /// and scheduling of new ones is prohibited.
168 : #[serde(rename = "RUNNING_OPERATIONS")]
169 : RunningOperations,
170 : /// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
171 : #[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
172 : ConcurrencyLimitReached,
173 : /// LockAlreadyTaken indicates that the we attempted to take a lock that was already taken.
174 : #[serde(rename = "LOCK_ALREADY_TAKEN")]
175 : LockAlreadyTaken,
176 : /// EndpointIdle indicates that the endpoint cannot become active, because it's idle.
177 : #[serde(rename = "ENDPOINT_IDLE")]
178 : EndpointIdle,
179 : /// ProjectUnderMaintenance indicates that the project is currently ongoing maintenance,
180 : /// and thus cannot accept connections.
181 : #[serde(rename = "PROJECT_UNDER_MAINTENANCE")]
182 : ProjectUnderMaintenance,
183 : #[default]
184 : #[serde(other)]
185 : Unknown,
186 : }
187 :
188 : impl Reason {
189 0 : pub(crate) fn is_not_found(self) -> bool {
190 0 : matches!(
191 0 : self,
192 : Reason::ResourceNotFound
193 : | Reason::ProjectNotFound
194 : | Reason::EndpointNotFound
195 : | Reason::BranchNotFound
196 : )
197 0 : }
198 :
199 0 : pub(crate) fn can_retry(self) -> bool {
200 0 : match self {
201 : // do not retry role protected errors
202 : // not a transient error
203 0 : Reason::RoleProtected => false,
204 : // on retry, it will still not be found or valid
205 : Reason::ResourceNotFound
206 : | Reason::ProjectNotFound
207 : | Reason::EndpointNotFound
208 : | Reason::EndpointDisabled
209 : | Reason::BranchNotFound
210 0 : | Reason::WrongLsnOrTimestamp => false,
211 : // we were asked to go away
212 : Reason::RateLimitExceeded
213 : | Reason::NonDefaultBranchComputeTimeExceeded
214 : | Reason::ActiveTimeQuotaExceeded
215 : | Reason::ComputeTimeQuotaExceeded
216 : | Reason::WrittenDataQuotaExceeded
217 : | Reason::DataTransferQuotaExceeded
218 : | Reason::LogicalSizeQuotaExceeded
219 0 : | Reason::ActiveEndpointsLimitExceeded => false,
220 : // transient error. control plane is currently busy
221 : // but might be ready soon
222 : Reason::RunningOperations
223 : | Reason::ConcurrencyLimitReached
224 : | Reason::LockAlreadyTaken
225 : | Reason::EndpointIdle
226 0 : | Reason::ProjectUnderMaintenance => true,
227 : // unknown error. better not retry it.
228 0 : Reason::Unknown => false,
229 : }
230 0 : }
231 : }
232 :
233 : #[derive(Copy, Clone, Debug, Deserialize)]
234 : #[allow(dead_code)]
235 : pub(crate) struct RetryInfo {
236 : #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
237 : pub(crate) retry_at: Instant,
238 : }
239 :
240 0 : fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
241 0 : let millis = u64::deserialize(d)?;
242 0 : Ok(Instant::now() + Duration::from_millis(millis))
243 0 : }
244 :
245 0 : #[derive(Debug, Deserialize, Clone)]
246 : pub(crate) struct UserFacingMessage {
247 : pub(crate) message: Box<str>,
248 : }
249 :
250 : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
251 : /// Returned by the `/get_endpoint_access_control` API method.
252 0 : #[derive(Deserialize)]
253 : pub(crate) struct GetEndpointAccessControl {
254 : pub(crate) role_secret: Box<str>,
255 :
256 : pub(crate) project_id: Option<ProjectIdInt>,
257 : pub(crate) account_id: Option<AccountIdInt>,
258 :
259 : pub(crate) allowed_ips: Option<Vec<IpPattern>>,
260 : pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
261 : pub(crate) block_public_connections: Option<bool>,
262 : pub(crate) block_vpc_connections: Option<bool>,
263 :
264 : #[serde(default)]
265 : pub(crate) rate_limits: EndpointRateLimitConfig,
266 : }
267 :
268 0 : #[derive(Copy, Clone, Deserialize, Default, Debug)]
269 : pub struct EndpointRateLimitConfig {
270 : pub connection_attempts: ConnectionAttemptsLimit,
271 : }
272 :
273 0 : #[derive(Copy, Clone, Deserialize, Default, Debug)]
274 : pub struct ConnectionAttemptsLimit {
275 : pub tcp: Option<LeakyBucketSetting>,
276 : pub ws: Option<LeakyBucketSetting>,
277 : pub http: Option<LeakyBucketSetting>,
278 : }
279 :
280 0 : #[derive(Copy, Clone, Deserialize, Debug)]
281 : pub struct LeakyBucketSetting {
282 : pub rps: f64,
283 : pub burst: f64,
284 : }
285 :
286 : /// Response which holds compute node's `host:port` pair.
287 : /// Returned by the `/proxy_wake_compute` API method.
288 0 : #[derive(Debug, Deserialize)]
289 : pub(crate) struct WakeCompute {
290 : pub(crate) address: Box<str>,
291 : pub(crate) server_name: Option<String>,
292 : pub(crate) aux: MetricsAuxInfo,
293 : }
294 :
295 : /// Async response which concludes the console redirect auth flow.
296 : /// Also known as `kickResponse` in the console.
297 : #[derive(Debug, Deserialize)]
298 : pub(crate) struct KickSession<'a> {
299 : /// Session ID is assigned by the proxy.
300 : pub(crate) session_id: &'a str,
301 :
302 : /// Compute node connection params.
303 : #[serde(deserialize_with = "KickSession::parse_db_info")]
304 : pub(crate) result: DatabaseInfo,
305 : }
306 :
307 : impl KickSession<'_> {
308 1 : fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
309 1 : where
310 1 : D: serde::Deserializer<'de>,
311 : {
312 0 : #[derive(Deserialize)]
313 : enum Wrapper {
314 : // Currently, console only reports `Success`.
315 : // `Failure(String)` used to be here... RIP.
316 : Success(DatabaseInfo),
317 : }
318 :
319 1 : Wrapper::deserialize(des).map(|x| match x {
320 1 : Wrapper::Success(info) => info,
321 1 : })
322 1 : }
323 : }
324 :
325 : /// Compute node connection params.
326 0 : #[derive(Deserialize)]
327 : pub(crate) struct DatabaseInfo {
328 : pub(crate) host: Box<str>,
329 : pub(crate) port: u16,
330 : pub(crate) dbname: Box<str>,
331 : pub(crate) user: Box<str>,
332 : /// Console always provides a password, but it might
333 : /// be inconvenient for debug with local PG instance.
334 : pub(crate) password: Option<Box<str>>,
335 : pub(crate) aux: MetricsAuxInfo,
336 : #[serde(default)]
337 : pub(crate) allowed_ips: Option<Vec<IpPattern>>,
338 : #[serde(default)]
339 : pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
340 : #[serde(default)]
341 : pub(crate) public_access_allowed: Option<bool>,
342 : }
343 :
344 : // Manually implement debug to omit sensitive info.
345 : impl fmt::Debug for DatabaseInfo {
346 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 0 : f.debug_struct("DatabaseInfo")
348 0 : .field("host", &self.host)
349 0 : .field("port", &self.port)
350 0 : .field("dbname", &self.dbname)
351 0 : .field("user", &self.user)
352 0 : .field("allowed_ips", &self.allowed_ips)
353 0 : .field("allowed_vpc_endpoint_ids", &self.allowed_vpc_endpoint_ids)
354 0 : .finish_non_exhaustive()
355 0 : }
356 : }
357 :
358 : /// Various labels for prometheus metrics.
359 : /// Also known as `ProxyMetricsAuxInfo` in the console.
360 0 : #[derive(Debug, Deserialize, Clone)]
361 : pub(crate) struct MetricsAuxInfo {
362 : pub(crate) endpoint_id: EndpointIdInt,
363 : pub(crate) project_id: ProjectIdInt,
364 : pub(crate) branch_id: BranchIdInt,
365 : // note: we don't use interned strings for compute IDs.
366 : // they churn too quickly and we have no way to clean up interned strings.
367 : pub(crate) compute_id: SmolStr,
368 : #[serde(default)]
369 : pub(crate) cold_start_info: ColdStartInfo,
370 : }
371 :
372 0 : #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
373 : #[serde(rename_all = "snake_case")]
374 : pub enum ColdStartInfo {
375 : #[default]
376 : Unknown,
377 : /// Compute was already running
378 : Warm,
379 : #[serde(rename = "pool_hit")]
380 : #[label(rename = "pool_hit")]
381 : /// Compute was not running but there was an available VM
382 : VmPoolHit,
383 : #[serde(rename = "pool_miss")]
384 : #[label(rename = "pool_miss")]
385 : /// Compute was not running and there were no VMs available
386 : VmPoolMiss,
387 :
388 : // not provided by control plane
389 : /// Connection available from HTTP pool
390 : HttpPoolHit,
391 : /// Cached connection info
392 : WarmCached,
393 : }
394 :
395 : impl ColdStartInfo {
396 0 : pub(crate) fn as_str(self) -> &'static str {
397 0 : match self {
398 0 : ColdStartInfo::Unknown => "unknown",
399 0 : ColdStartInfo::Warm => "warm",
400 0 : ColdStartInfo::VmPoolHit => "pool_hit",
401 0 : ColdStartInfo::VmPoolMiss => "pool_miss",
402 0 : ColdStartInfo::HttpPoolHit => "http_pool_hit",
403 0 : ColdStartInfo::WarmCached => "warm_cached",
404 : }
405 0 : }
406 : }
407 :
408 0 : #[derive(Debug, Deserialize, Clone)]
409 : pub struct EndpointJwksResponse {
410 : pub jwks: Vec<JwksSettings>,
411 : }
412 :
413 0 : #[derive(Debug, Deserialize, Clone)]
414 : pub struct JwksSettings {
415 : pub id: String,
416 : pub jwks_url: url::Url,
417 : #[serde(rename = "provider_name")]
418 : pub _provider_name: String,
419 : pub jwt_audience: Option<String>,
420 : pub role_names: Vec<RoleNameInt>,
421 : }
422 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// A complete `aux` object that parses as [`MetricsAuxInfo`].
    fn dummy_aux() -> serde_json::Value {
        json!({
            "endpoint_id": "endpoint",
            "project_id": "project",
            "branch_id": "branch",
            "compute_id": "compute",
            "cold_start_info": "unknown",
        })
    }

    #[test]
    fn parse_kick_session() -> anyhow::Result<()> {
        // This is what the console's kickResponse looks like.
        let payload = json!({
            "session_id": "deadbeef",
            "result": {
                "Success": {
                    "host": "localhost",
                    "port": 5432,
                    "dbname": "postgres",
                    "user": "john_doe",
                    "password": "password",
                    "aux": dummy_aux(),
                }
            }
        })
        .to_string();

        serde_json::from_str::<KickSession<'_>>(&payload)?;
        Ok(())
    }

    #[test]
    fn parse_db_info() -> anyhow::Result<()> {
        // Payloads that only have to parse successfully.
        let accepted = vec![
            // with password
            json!({
                "host": "localhost",
                "port": 5432,
                "dbname": "postgres",
                "user": "john_doe",
                "password": "password",
                "aux": dummy_aux(),
            }),
            // without password
            json!({
                "host": "localhost",
                "port": 5432,
                "dbname": "postgres",
                "user": "john_doe",
                "aux": dummy_aux(),
            }),
            // new field (forward compatibility)
            json!({
                "host": "localhost",
                "port": 5432,
                "dbname": "postgres",
                "user": "john_doe",
                "project": "hello_world",
                "N.E.W": "forward compatibility check",
                "aux": dummy_aux(),
            }),
        ];
        for case in accepted {
            serde_json::from_value::<DatabaseInfo>(case)?;
        }

        // with allowed_ips
        let with_ips = json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
            "allowed_ips": ["127.0.0.1"],
        });
        let dbinfo = serde_json::from_value::<DatabaseInfo>(with_ips)?;
        let expected = Some(vec![IpPattern::Single("127.0.0.1".parse()?)]);
        assert_eq!(dbinfo.allowed_ips, expected);

        Ok(())
    }

    #[test]
    fn parse_wake_compute() -> anyhow::Result<()> {
        let payload = json!({
            "address": "0.0.0.0",
            "aux": dummy_aux(),
        })
        .to_string();

        serde_json::from_str::<WakeCompute>(&payload)?;
        Ok(())
    }

    #[test]
    fn parse_get_role_secret() -> anyhow::Result<()> {
        let cases = vec![
            // Empty `allowed_ips` and `allowed_vpc_endpoint_ids` field.
            json!({
                "role_secret": "secret",
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
            }),
            json!({
                "role_secret": "secret",
                "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
                "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
                "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
                "project_id": "project",
            }),
        ];
        for case in &cases {
            serde_json::from_str::<GetEndpointAccessControl>(&case.to_string())?;
        }

        Ok(())
    }
}
|