LCOV - code coverage report
Current view: top level - proxy/src/console - messages.rs (source / functions) Coverage Total Hit
Test: fc67f8dc6087a0b4f4f0bcd74f6e1dc25fab8cf3.info Lines: 62.2 % 193 120
Test Date: 2024-09-24 13:57:57 Functions: 19.0 % 200 38

            Line data    Source code
       1              : use measured::FixedCardinalityLabel;
       2              : use serde::{Deserialize, Serialize};
       3              : use std::collections::HashMap;
       4              : use std::fmt::{self, Display};
       5              : 
       6              : use crate::auth::IpPattern;
       7              : 
       8              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
       9              : use crate::proxy::retry::CouldRetry;
      10              : use crate::RoleName;
      11              : 
/// Generic error response with human-readable description.
/// Note that we can't always present it to the user as is; see
/// [`ConsoleError::get_user_facing_message`] for the text that may be shown.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct ConsoleError {
    /// Raw error description from the console. `Display` falls back to it when
    /// no user-facing message is present.
    pub(crate) error: Box<str>,
    /// HTTP status of the console response. It is not part of the JSON body
    /// (`#[serde(skip)]`), so deserialization leaves it at `StatusCode::default()`.
    /// NOTE(review): presumably set by the caller from the HTTP response — confirm.
    #[serde(skip)]
    pub(crate) http_status_code: http::StatusCode,
    /// Structured status details, if the console provided them.
    pub(crate) status: Option<Status>,
}
      21              : 
      22              : impl ConsoleError {
      23            3 :     pub(crate) fn get_reason(&self) -> Reason {
      24            3 :         self.status
      25            3 :             .as_ref()
      26            3 :             .and_then(|s| s.details.error_info.as_ref())
      27            3 :             .map_or(Reason::Unknown, |e| e.reason)
      28            3 :     }
      29              : 
      30            0 :     pub(crate) fn get_user_facing_message(&self) -> String {
      31              :         use super::provider::errors::REQUEST_FAILED;
      32            0 :         self.status
      33            0 :             .as_ref()
      34            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      35            0 :             .map_or_else(|| {
      36            0 :                 // Ask @neondatabase/control-plane for review before adding more.
      37            0 :                 match self.http_status_code {
      38              :                     http::StatusCode::NOT_FOUND => {
      39              :                         // Status 404: failed to get a project-related resource.
      40            0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      41              :                     }
      42              :                     http::StatusCode::NOT_ACCEPTABLE => {
      43              :                         // Status 406: endpoint is disabled (we don't allow connections).
      44            0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      45              :                     }
      46              :                     http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
      47              :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      48            0 :                         format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
      49              :                     }
      50            0 :                     _ => REQUEST_FAILED.to_owned(),
      51              :                 }
      52            0 :             }, |m| m.message.clone().into())
      53            0 :     }
      54              : }
      55              : 
      56              : impl Display for ConsoleError {
      57            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      58            0 :         let msg: &str = self
      59            0 :             .status
      60            0 :             .as_ref()
      61            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      62            0 :             .map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
      63            0 :         write!(f, "{msg}")
      64            0 :     }
      65              : }
      66              : 
      67              : impl CouldRetry for ConsoleError {
      68            6 :     fn could_retry(&self) -> bool {
      69              :         // If the error message does not have a status,
      70              :         // the error is unknown and probably should not retry automatically
      71            6 :         let Some(status) = &self.status else {
      72            2 :             return false;
      73              :         };
      74              : 
      75              :         // retry if the retry info is set.
      76            4 :         if status.details.retry_info.is_some() {
      77            4 :             return true;
      78            0 :         }
      79            0 : 
      80            0 :         // if no retry info set, attempt to use the error code to guess the retry state.
      81            0 :         let reason = status
      82            0 :             .details
      83            0 :             .error_info
      84            0 :             .map_or(Reason::Unknown, |e| e.reason);
      85            0 : 
      86            0 :         reason.can_retry()
      87            6 :     }
      88              : }
      89              : 
/// Structured `status` object attached to a [`ConsoleError`].
#[derive(Debug, Deserialize, Clone)]
// NOTE(review): `code` and `message` are not read in this file (only `details`
// is), which is presumably why `dead_code` is allowed here — confirm before removing.
#[allow(dead_code)]
pub(crate) struct Status {
    pub(crate) code: Box<str>,
    pub(crate) message: Box<str>,
    /// Error reason, retry hint and user-facing message.
    pub(crate) details: Details,
}
      97              : 
/// Details of a console [`Status`]. Every part is optional.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct Details {
    /// Why the request failed; see [`Reason`].
    pub(crate) error_info: Option<ErrorInfo>,
    /// When present, `ConsoleError::could_retry` treats the error as retryable.
    pub(crate) retry_info: Option<RetryInfo>,
    /// Message that is returned to the end user as is.
    pub(crate) user_facing_message: Option<UserFacingMessage>,
}
     104              : 
/// Machine-readable error information from the console.
#[derive(Copy, Clone, Debug, Deserialize)]
pub(crate) struct ErrorInfo {
    /// The error reason. Unrecognised values deserialize to [`Reason::Unknown`].
    pub(crate) reason: Reason,
    // Schema could also have `metadata` field, but it's not structured. Skip it for now.
}
     110              : 
/// Machine-readable reason carried in a console error's `error_info`.
///
/// Variants map to the console's SCREAMING_SNAKE_CASE strings via `#[serde(rename)]`.
/// Any unrecognised string deserializes to [`Reason::Unknown`] (`#[serde(other)]`),
/// which is also the `Default`.
#[derive(Clone, Copy, Debug, Deserialize, Default)]
pub(crate) enum Reason {
    /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
    #[serde(rename = "ROLE_PROTECTED")]
    RoleProtected,
    /// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
    /// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
    /// access the requested resource.
    /// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
    #[serde(rename = "RESOURCE_NOT_FOUND")]
    ResourceNotFound,
    /// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested project.
    #[serde(rename = "PROJECT_NOT_FOUND")]
    ProjectNotFound,
    /// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested endpoint.
    #[serde(rename = "ENDPOINT_NOT_FOUND")]
    EndpointNotFound,
    /// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
    /// or that the subject doesn't have enough permissions to access the requested branch.
    #[serde(rename = "BRANCH_NOT_FOUND")]
    BranchNotFound,
    /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
    #[serde(rename = "RATE_LIMIT_EXCEEDED")]
    RateLimitExceeded,
    /// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
    /// exceeded.
    /// Note: on the wire this is still spelled `NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED`.
    #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
    NonDefaultBranchComputeTimeExceeded,
    /// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
    #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
    ActiveTimeQuotaExceeded,
    /// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
    #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
    ComputeTimeQuotaExceeded,
    /// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
    #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
    WrittenDataQuotaExceeded,
    /// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
    #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
    DataTransferQuotaExceeded,
    /// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
    #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
    LogicalSizeQuotaExceeded,
    /// RunningOperations indicates that the project already has some running operations
    /// and scheduling of new ones is prohibited.
    #[serde(rename = "RUNNING_OPERATIONS")]
    RunningOperations,
    /// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
    #[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
    ConcurrencyLimitReached,
    /// LockAlreadyTaken indicates that we attempted to take a lock that was already taken.
    #[serde(rename = "LOCK_ALREADY_TAKEN")]
    LockAlreadyTaken,
    /// Fallback for any reason string not listed above (`#[serde(other)]`); also the `Default`.
    #[default]
    #[serde(other)]
    Unknown,
}
     170              : 
     171              : impl Reason {
     172            0 :     pub(crate) fn is_not_found(self) -> bool {
     173            0 :         matches!(
     174            0 :             self,
     175              :             Reason::ResourceNotFound
     176              :                 | Reason::ProjectNotFound
     177              :                 | Reason::EndpointNotFound
     178              :                 | Reason::BranchNotFound
     179              :         )
     180            0 :     }
     181              : 
     182            0 :     pub(crate) fn can_retry(self) -> bool {
     183            0 :         match self {
     184              :             // do not retry role protected errors
     185              :             // not a transitive error
     186            0 :             Reason::RoleProtected => false,
     187              :             // on retry, it will still not be found
     188              :             Reason::ResourceNotFound
     189              :             | Reason::ProjectNotFound
     190              :             | Reason::EndpointNotFound
     191            0 :             | Reason::BranchNotFound => false,
     192              :             // we were asked to go away
     193              :             Reason::RateLimitExceeded
     194              :             | Reason::NonDefaultBranchComputeTimeExceeded
     195              :             | Reason::ActiveTimeQuotaExceeded
     196              :             | Reason::ComputeTimeQuotaExceeded
     197              :             | Reason::WrittenDataQuotaExceeded
     198              :             | Reason::DataTransferQuotaExceeded
     199            0 :             | Reason::LogicalSizeQuotaExceeded => false,
     200              :             // transitive error. control plane is currently busy
     201              :             // but might be ready soon
     202              :             Reason::RunningOperations
     203              :             | Reason::ConcurrencyLimitReached
     204            0 :             | Reason::LockAlreadyTaken => true,
     205              :             // unknown error. better not retry it.
     206            0 :             Reason::Unknown => false,
     207              :         }
     208            0 :     }
     209              : }
     210              : 
/// Retry hint from the console. Its presence alone makes `ConsoleError::could_retry`
/// return `true`.
#[derive(Copy, Clone, Debug, Deserialize)]
// NOTE(review): `retry_delay_ms` is not read in this file, which is presumably why
// `dead_code` is allowed — confirm before removing.
#[allow(dead_code)]
pub(crate) struct RetryInfo {
    /// Suggested delay before retrying, in milliseconds (per the field name).
    pub(crate) retry_delay_ms: u64,
}
     216              : 
/// Message supplied by the console that may be shown to the end user as is.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct UserFacingMessage {
    pub(crate) message: Box<str>,
}
     221              : 
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/proxy_get_role_secret` API method.
///
/// `Debug` is implemented by hand so that none of the fields are printed.
#[derive(Deserialize)]
pub(crate) struct GetRoleSecret {
    /// The role secret as sent by the console (presumably encoded; parsed elsewhere).
    pub(crate) role_secret: Box<str>,
    /// Allowed client IP patterns; `None` when the console omits the field.
    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
    /// Project id, if the console reported one.
    pub(crate) project_id: Option<ProjectIdInt>,
}
     230              : 
     231              : // Manually implement debug to omit sensitive info.
     232              : impl fmt::Debug for GetRoleSecret {
     233            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     234            0 :         f.debug_struct("GetRoleSecret").finish_non_exhaustive()
     235            0 :     }
     236              : }
     237              : 
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub(crate) struct WakeCompute {
    /// Compute node address (`host:port`, per the type-level doc).
    pub(crate) address: Box<str>,
    /// Labels for metrics about this compute.
    pub(crate) aux: MetricsAuxInfo,
}
     245              : 
/// Async response which concludes the web auth flow.
/// Also known as `kickResponse` in the console.
#[derive(Debug, Deserialize)]
pub(crate) struct KickSession<'a> {
    /// Session ID is assigned by the proxy.
    /// Borrowed from the input, so the JSON text must outlive this struct.
    /// NOTE(review): serde_json cannot borrow strings containing escape sequences.
    pub(crate) session_id: &'a str,

    /// Compute node connection params.
    /// The console wraps them as `{"Success": {...}}`; `parse_db_info` unwraps them.
    #[serde(deserialize_with = "KickSession::parse_db_info")]
    pub(crate) result: DatabaseInfo,
}
     257              : 
     258              : impl KickSession<'_> {
     259            1 :     fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
     260            1 :     where
     261            1 :         D: serde::Deserializer<'de>,
     262            1 :     {
     263            2 :         #[derive(Deserialize)]
     264              :         enum Wrapper {
     265              :             // Currently, console only reports `Success`.
     266              :             // `Failure(String)` used to be here... RIP.
     267              :             Success(DatabaseInfo),
     268              :         }
     269              : 
     270            1 :         Wrapper::deserialize(des).map(|x| match x {
     271            1 :             Wrapper::Success(info) => info,
     272            1 :         })
     273            1 :     }
     274              : }
     275              : 
/// Compute node connection params.
///
/// `Debug` is implemented by hand to keep the password out of logs.
#[derive(Deserialize)]
pub(crate) struct DatabaseInfo {
    pub(crate) host: Box<str>,
    pub(crate) port: u16,
    pub(crate) dbname: Box<str>,
    pub(crate) user: Box<str>,
    /// Console always provides a password, but it might
    /// be inconvenient when debugging against a local PG instance.
    pub(crate) password: Option<Box<str>>,
    pub(crate) aux: MetricsAuxInfo,
    /// Allowed client IP patterns; `None` when the field is absent.
    #[serde(default)]
    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
}
     290              : 
     291              : // Manually implement debug to omit sensitive info.
     292              : impl fmt::Debug for DatabaseInfo {
     293            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     294            0 :         f.debug_struct("DatabaseInfo")
     295            0 :             .field("host", &self.host)
     296            0 :             .field("port", &self.port)
     297            0 :             .field("dbname", &self.dbname)
     298            0 :             .field("user", &self.user)
     299            0 :             .field("allowed_ips", &self.allowed_ips)
     300            0 :             .finish_non_exhaustive()
     301            0 :     }
     302              : }
     303              : 
/// Various labels for prometheus metrics.
/// Also known as `ProxyMetricsAuxInfo` in the console.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct MetricsAuxInfo {
    pub(crate) endpoint_id: EndpointIdInt,
    pub(crate) project_id: ProjectIdInt,
    pub(crate) branch_id: BranchIdInt,
    /// Defaults to [`ColdStartInfo::Unknown`] when the console omits the field.
    #[serde(default)]
    pub(crate) cold_start_info: ColdStartInfo,
}
     314              : 
/// How the compute serving a connection was obtained. It is used as a
/// fixed-cardinality metrics label and (de)serialized in snake_case.
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
#[serde(rename_all = "snake_case")]
pub enum ColdStartInfo {
    /// Not reported; also the `Default`.
    #[default]
    Unknown,
    /// Compute was already running
    Warm,
    #[serde(rename = "pool_hit")]
    #[label(rename = "pool_hit")]
    /// Compute was not running but there was an available VM
    VmPoolHit,
    #[serde(rename = "pool_miss")]
    #[label(rename = "pool_miss")]
    /// Compute was not running and there were no VMs available
    VmPoolMiss,

    // not provided by control plane
    /// Connection available from HTTP pool
    HttpPoolHit,
    /// Cached connection info
    WarmCached,
}
     337              : 
     338              : impl ColdStartInfo {
     339            0 :     pub(crate) fn as_str(self) -> &'static str {
     340            0 :         match self {
     341            0 :             ColdStartInfo::Unknown => "unknown",
     342            0 :             ColdStartInfo::Warm => "warm",
     343            0 :             ColdStartInfo::VmPoolHit => "pool_hit",
     344            0 :             ColdStartInfo::VmPoolMiss => "pool_miss",
     345            0 :             ColdStartInfo::HttpPoolHit => "http_pool_hit",
     346            0 :             ColdStartInfo::WarmCached => "warm_cached",
     347              :         }
     348            0 :     }
     349              : }
     350              : 
/// JWKS configuration keyed by role name.
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
    pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
     355              : 
/// List of JWKS settings (presumably those configured for an endpoint, per the name).
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
    pub jwks: Vec<JwksSettings>,
}
     360              : 
/// A single JWKS provider configuration.
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
    pub id: String,
    pub project_id: ProjectIdInt,
    pub branch_id: BranchIdInt,
    /// URL of the JWKS document; rejected at deserialization if not a valid URL.
    pub jwks_url: url::Url,
    pub provider_name: String,
    /// Optional expected JWT audience.
    pub jwt_audience: Option<String>,
}
     370              : 
     371              : #[cfg(test)]
     372              : mod tests {
     373              :     use super::*;
     374              :     use serde_json::json;
     375              : 
     376            6 :     fn dummy_aux() -> serde_json::Value {
     377            6 :         json!({
     378            6 :             "endpoint_id": "endpoint",
     379            6 :             "project_id": "project",
     380            6 :             "branch_id": "branch",
     381            6 :             "cold_start_info": "unknown",
     382            6 :         })
     383            6 :     }
     384              : 
     385              :     #[test]
     386            1 :     fn parse_kick_session() -> anyhow::Result<()> {
     387            1 :         // This is what the console's kickResponse looks like.
     388            1 :         let json = json!({
     389            1 :             "session_id": "deadbeef",
     390            1 :             "result": {
     391            1 :                 "Success": {
     392            1 :                     "host": "localhost",
     393            1 :                     "port": 5432,
     394            1 :                     "dbname": "postgres",
     395            1 :                     "user": "john_doe",
     396            1 :                     "password": "password",
     397            1 :                     "aux": dummy_aux(),
     398            1 :                 }
     399            1 :             }
     400            1 :         });
     401            1 :         serde_json::from_str::<KickSession<'_>>(&json.to_string())?;
     402              : 
     403            1 :         Ok(())
     404            1 :     }
     405              : 
     406              :     #[test]
     407            1 :     fn parse_db_info() -> anyhow::Result<()> {
     408            1 :         // with password
     409            1 :         serde_json::from_value::<DatabaseInfo>(json!({
     410            1 :             "host": "localhost",
     411            1 :             "port": 5432,
     412            1 :             "dbname": "postgres",
     413            1 :             "user": "john_doe",
     414            1 :             "password": "password",
     415            1 :             "aux": dummy_aux(),
     416            1 :         }))?;
     417              : 
     418              :         // without password
     419            1 :         serde_json::from_value::<DatabaseInfo>(json!({
     420            1 :             "host": "localhost",
     421            1 :             "port": 5432,
     422            1 :             "dbname": "postgres",
     423            1 :             "user": "john_doe",
     424            1 :             "aux": dummy_aux(),
     425            1 :         }))?;
     426              : 
     427              :         // new field (forward compatibility)
     428            1 :         serde_json::from_value::<DatabaseInfo>(json!({
     429            1 :             "host": "localhost",
     430            1 :             "port": 5432,
     431            1 :             "dbname": "postgres",
     432            1 :             "user": "john_doe",
     433            1 :             "project": "hello_world",
     434            1 :             "N.E.W": "forward compatibility check",
     435            1 :             "aux": dummy_aux(),
     436            1 :         }))?;
     437              : 
     438              :         // with allowed_ips
     439            1 :         let dbinfo = serde_json::from_value::<DatabaseInfo>(json!({
     440            1 :             "host": "localhost",
     441            1 :             "port": 5432,
     442            1 :             "dbname": "postgres",
     443            1 :             "user": "john_doe",
     444            1 :             "password": "password",
     445            1 :             "aux": dummy_aux(),
     446            1 :             "allowed_ips": ["127.0.0.1"],
     447            1 :         }))?;
     448              : 
     449            1 :         assert_eq!(
     450            1 :             dbinfo.allowed_ips,
     451            1 :             Some(vec![IpPattern::Single("127.0.0.1".parse()?)])
     452              :         );
     453              : 
     454            1 :         Ok(())
     455            1 :     }
     456              : 
     457              :     #[test]
     458            1 :     fn parse_wake_compute() -> anyhow::Result<()> {
     459            1 :         let json = json!({
     460            1 :             "address": "0.0.0.0",
     461            1 :             "aux": dummy_aux(),
     462            1 :         });
     463            1 :         serde_json::from_str::<WakeCompute>(&json.to_string())?;
     464            1 :         Ok(())
     465            1 :     }
     466              : 
     467              :     #[test]
     468            1 :     fn parse_get_role_secret() -> anyhow::Result<()> {
     469            1 :         // Empty `allowed_ips` field.
     470            1 :         let json = json!({
     471            1 :             "role_secret": "secret",
     472            1 :         });
     473            1 :         serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
     474            1 :         let json = json!({
     475            1 :             "role_secret": "secret",
     476            1 :             "allowed_ips": ["8.8.8.8"],
     477            1 :         });
     478            1 :         serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
     479            1 :         let json = json!({
     480            1 :             "role_secret": "secret",
     481            1 :             "allowed_ips": ["8.8.8.8"],
     482            1 :             "project_id": "project",
     483            1 :         });
     484            1 :         serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
     485              : 
     486            1 :         Ok(())
     487            1 :     }
     488              : }
        

Generated by: LCOV version 2.1-beta