LCOV - code coverage report
Current view: top level - proxy/src/console - messages.rs (source / functions) Coverage Total Hit
Test: 42f947419473a288706e86ecdf7c2863d760d5d7.info Lines: 61.3 % 186 114
Test Date: 2024-08-02 21:34:27 Functions: 23.2 % 164 38

            Line data    Source code
       1              : use measured::FixedCardinalityLabel;
       2              : use serde::{Deserialize, Serialize};
       3              : use std::fmt::{self, Display};
       4              : 
       5              : use crate::auth::IpPattern;
       6              : 
       7              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
       8              : use crate::proxy::retry::CouldRetry;
       9              : 
      10              : /// Generic error response with human-readable description.
      11              : /// Note that we can't always present it to user as is.
      12            0 : #[derive(Debug, Deserialize, Clone)]
      13              : pub struct ConsoleError {
      14              :     pub error: Box<str>,
      15              :     #[serde(skip)]
      16              :     pub http_status_code: http::StatusCode,
      17              :     pub status: Option<Status>,
      18              : }
      19              : 
      20              : impl ConsoleError {
      21            6 :     pub fn get_reason(&self) -> Reason {
      22            6 :         self.status
      23            6 :             .as_ref()
      24            6 :             .and_then(|s| s.details.error_info.as_ref())
      25            6 :             .map(|e| e.reason)
      26            6 :             .unwrap_or(Reason::Unknown)
      27            6 :     }
      28            0 :     pub fn get_user_facing_message(&self) -> String {
      29            0 :         use super::provider::errors::REQUEST_FAILED;
      30            0 :         self.status
      31            0 :             .as_ref()
      32            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      33            0 :             .map(|m| m.message.clone().into())
      34            0 :             .unwrap_or_else(|| {
      35            0 :                 // Ask @neondatabase/control-plane for review before adding more.
      36            0 :                 match self.http_status_code {
      37              :                     http::StatusCode::NOT_FOUND => {
      38              :                         // Status 404: failed to get a project-related resource.
      39            0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      40              :                     }
      41              :                     http::StatusCode::NOT_ACCEPTABLE => {
      42              :                         // Status 406: endpoint is disabled (we don't allow connections).
      43            0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      44              :                     }
      45              :                     http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
      46              :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      47            0 :                         format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
      48              :                     }
      49            0 :                     _ => REQUEST_FAILED.to_owned(),
      50              :                 }
      51            0 :             })
      52            0 :     }
      53              : }
      54              : 
      55              : impl Display for ConsoleError {
      56            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
      57            0 :         let msg = self
      58            0 :             .status
      59            0 :             .as_ref()
      60            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      61            0 :             .map(|m| m.message.as_ref())
      62            0 :             .unwrap_or_else(|| &self.error);
      63            0 :         write!(f, "{}", msg)
      64            0 :     }
      65              : }
      66              : 
      67              : impl CouldRetry for ConsoleError {
      68           12 :     fn could_retry(&self) -> bool {
      69              :         // If the error message does not have a status,
      70              :         // the error is unknown and probably should not retry automatically
      71           12 :         let Some(status) = &self.status else {
      72            4 :             return false;
      73              :         };
      74              : 
      75              :         // retry if the retry info is set.
      76            8 :         if status.details.retry_info.is_some() {
      77            8 :             return true;
      78            0 :         }
      79            0 : 
      80            0 :         // if no retry info set, attempt to use the error code to guess the retry state.
      81            0 :         let reason = status
      82            0 :             .details
      83            0 :             .error_info
      84            0 :             .map_or(Reason::Unknown, |e| e.reason);
      85            0 : 
      86            0 :         reason.can_retry()
      87           12 :     }
      88              : }
      89              : 
      90            0 : #[derive(Debug, Deserialize, Clone)]
      91              : pub struct Status {
      92              :     pub code: Box<str>,
      93              :     pub message: Box<str>,
      94              :     pub details: Details,
      95              : }
      96              : 
      97            0 : #[derive(Debug, Deserialize, Clone)]
      98              : pub struct Details {
      99              :     pub error_info: Option<ErrorInfo>,
     100              :     pub retry_info: Option<RetryInfo>,
     101              :     pub user_facing_message: Option<UserFacingMessage>,
     102              : }
     103              : 
     104            0 : #[derive(Copy, Clone, Debug, Deserialize)]
     105              : pub struct ErrorInfo {
     106              :     pub reason: Reason,
     107              :     // Schema could also have `metadata` field, but it's not structured. Skip it for now.
     108              : }
     109              : 
     110            0 : #[derive(Clone, Copy, Debug, Deserialize, Default)]
     111              : pub enum Reason {
     112              :     /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
     113              :     #[serde(rename = "ROLE_PROTECTED")]
     114              :     RoleProtected,
     115              :     /// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
     116              :     /// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
     117              :     /// access the requested resource.
     118              :     /// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
     119              :     #[serde(rename = "RESOURCE_NOT_FOUND")]
     120              :     ResourceNotFound,
     121              :     /// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
     122              :     /// or that the subject doesn't have enough permissions to access the requested project.
     123              :     #[serde(rename = "PROJECT_NOT_FOUND")]
     124              :     ProjectNotFound,
     125              :     /// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
     126              :     /// or that the subject doesn't have enough permissions to access the requested endpoint.
     127              :     #[serde(rename = "ENDPOINT_NOT_FOUND")]
     128              :     EndpointNotFound,
     129              :     /// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
     130              :     /// or that the subject doesn't have enough permissions to access the requested branch.
     131              :     #[serde(rename = "BRANCH_NOT_FOUND")]
     132              :     BranchNotFound,
     133              :     /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
     134              :     #[serde(rename = "RATE_LIMIT_EXCEEDED")]
     135              :     RateLimitExceeded,
     136              :     /// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
     137              :     /// exceeded.
     138              :     #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
     139              :     NonDefaultBranchComputeTimeExceeded,
     140              :     /// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
     141              :     #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
     142              :     ActiveTimeQuotaExceeded,
     143              :     /// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
     144              :     #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
     145              :     ComputeTimeQuotaExceeded,
     146              :     /// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
     147              :     #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
     148              :     WrittenDataQuotaExceeded,
     149              :     /// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
     150              :     #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
     151              :     DataTransferQuotaExceeded,
     152              :     /// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
     153              :     #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
     154              :     LogicalSizeQuotaExceeded,
     155              :     /// RunningOperations indicates that the project already has some running operations
     156              :     /// and scheduling of new ones is prohibited.
     157              :     #[serde(rename = "RUNNING_OPERATIONS")]
     158              :     RunningOperations,
     159              :     /// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
     160              :     #[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
     161              :     ConcurrencyLimitReached,
     162              :     /// LockAlreadyTaken indicates that the we attempted to take a lock that was already taken.
     163              :     #[serde(rename = "LOCK_ALREADY_TAKEN")]
     164              :     LockAlreadyTaken,
     165              :     #[default]
     166              :     #[serde(other)]
     167              :     Unknown,
     168              : }
     169              : 
     170              : impl Reason {
     171            0 :     pub fn is_not_found(&self) -> bool {
     172            0 :         matches!(
     173            0 :             self,
     174              :             Reason::ResourceNotFound
     175              :                 | Reason::ProjectNotFound
     176              :                 | Reason::EndpointNotFound
     177              :                 | Reason::BranchNotFound
     178              :         )
     179            0 :     }
     180              : 
     181            0 :     pub fn can_retry(&self) -> bool {
     182            0 :         match self {
     183              :             // do not retry role protected errors
     184              :             // not a transitive error
     185            0 :             Reason::RoleProtected => false,
     186              :             // on retry, it will still not be found
     187              :             Reason::ResourceNotFound
     188              :             | Reason::ProjectNotFound
     189              :             | Reason::EndpointNotFound
     190            0 :             | Reason::BranchNotFound => false,
     191              :             // we were asked to go away
     192              :             Reason::RateLimitExceeded
     193              :             | Reason::NonDefaultBranchComputeTimeExceeded
     194              :             | Reason::ActiveTimeQuotaExceeded
     195              :             | Reason::ComputeTimeQuotaExceeded
     196              :             | Reason::WrittenDataQuotaExceeded
     197              :             | Reason::DataTransferQuotaExceeded
     198            0 :             | Reason::LogicalSizeQuotaExceeded => false,
     199              :             // transitive error. control plane is currently busy
     200              :             // but might be ready soon
     201              :             Reason::RunningOperations
     202              :             | Reason::ConcurrencyLimitReached
     203            0 :             | Reason::LockAlreadyTaken => true,
     204              :             // unknown error. better not retry it.
     205            0 :             Reason::Unknown => false,
     206              :         }
     207            0 :     }
     208              : }
     209              : 
     210            0 : #[derive(Copy, Clone, Debug, Deserialize)]
     211              : pub struct RetryInfo {
     212              :     pub retry_delay_ms: u64,
     213              : }
     214              : 
     215            0 : #[derive(Debug, Deserialize, Clone)]
     216              : pub struct UserFacingMessage {
     217              :     pub message: Box<str>,
     218              : }
     219              : 
     220              : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
     221              : /// Returned by the `/proxy_get_role_secret` API method.
     222           18 : #[derive(Deserialize)]
     223              : pub struct GetRoleSecret {
     224              :     pub role_secret: Box<str>,
     225              :     pub allowed_ips: Option<Vec<IpPattern>>,
     226              :     pub project_id: Option<ProjectIdInt>,
     227              : }
     228              : 
     229              : // Manually implement debug to omit sensitive info.
     230              : impl fmt::Debug for GetRoleSecret {
     231            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     232            0 :         f.debug_struct("GetRoleSecret").finish_non_exhaustive()
     233            0 :     }
     234              : }
     235              : 
     236              : /// Response which holds compute node's `host:port` pair.
     237              : /// Returned by the `/proxy_wake_compute` API method.
     238            6 : #[derive(Debug, Deserialize)]
     239              : pub struct WakeCompute {
     240              :     pub address: Box<str>,
     241              :     pub aux: MetricsAuxInfo,
     242              : }
     243              : 
     244              : /// Async response which concludes the link auth flow.
     245              : /// Also known as `kickResponse` in the console.
     246            8 : #[derive(Debug, Deserialize)]
     247              : pub struct KickSession<'a> {
     248              :     /// Session ID is assigned by the proxy.
     249              :     pub session_id: &'a str,
     250              : 
     251              :     /// Compute node connection params.
     252              :     #[serde(deserialize_with = "KickSession::parse_db_info")]
     253              :     pub result: DatabaseInfo,
     254              : }
     255              : 
     256              : impl KickSession<'_> {
     257            2 :     fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
     258            2 :     where
     259            2 :         D: serde::Deserializer<'de>,
     260            2 :     {
     261            4 :         #[derive(Deserialize)]
     262            2 :         enum Wrapper {
     263            2 :             // Currently, console only reports `Success`.
     264            2 :             // `Failure(String)` used to be here... RIP.
     265            2 :             Success(DatabaseInfo),
     266            2 :         }
     267            2 : 
     268            2 :         Wrapper::deserialize(des).map(|x| match x {
     269            2 :             Wrapper::Success(info) => info,
     270            2 :         })
     271            2 :     }
     272              : }
     273              : 
     274              : /// Compute node connection params.
     275           56 : #[derive(Deserialize)]
     276              : pub struct DatabaseInfo {
     277              :     pub host: Box<str>,
     278              :     pub port: u16,
     279              :     pub dbname: Box<str>,
     280              :     pub user: Box<str>,
     281              :     /// Console always provides a password, but it might
     282              :     /// be inconvenient for debug with local PG instance.
     283              :     pub password: Option<Box<str>>,
     284              :     pub aux: MetricsAuxInfo,
     285              : }
     286              : 
     287              : // Manually implement debug to omit sensitive info.
     288              : impl fmt::Debug for DatabaseInfo {
     289            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
     290            0 :         f.debug_struct("DatabaseInfo")
     291            0 :             .field("host", &self.host)
     292            0 :             .field("port", &self.port)
     293            0 :             .field("dbname", &self.dbname)
     294            0 :             .field("user", &self.user)
     295            0 :             .finish_non_exhaustive()
     296            0 :     }
     297              : }
     298              : 
     299              : /// Various labels for prometheus metrics.
     300              : /// Also known as `ProxyMetricsAuxInfo` in the console.
     301           50 : #[derive(Debug, Deserialize, Clone)]
     302              : pub struct MetricsAuxInfo {
     303              :     pub endpoint_id: EndpointIdInt,
     304              :     pub project_id: ProjectIdInt,
     305              :     pub branch_id: BranchIdInt,
     306              :     #[serde(default)]
     307              :     pub cold_start_info: ColdStartInfo,
     308              : }
     309              : 
     310           20 : #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
     311              : #[serde(rename_all = "snake_case")]
     312              : pub enum ColdStartInfo {
     313              :     #[default]
     314              :     Unknown,
     315              :     /// Compute was already running
     316              :     Warm,
     317              :     #[serde(rename = "pool_hit")]
     318              :     #[label(rename = "pool_hit")]
     319              :     /// Compute was not running but there was an available VM
     320              :     VmPoolHit,
     321              :     #[serde(rename = "pool_miss")]
     322              :     #[label(rename = "pool_miss")]
     323              :     /// Compute was not running and there were no VMs available
     324              :     VmPoolMiss,
     325              : 
     326              :     // not provided by control plane
     327              :     /// Connection available from HTTP pool
     328              :     HttpPoolHit,
     329              :     /// Cached connection info
     330              :     WarmCached,
     331              : }
     332              : 
     333              : impl ColdStartInfo {
     334            0 :     pub fn as_str(&self) -> &'static str {
     335            0 :         match self {
     336            0 :             ColdStartInfo::Unknown => "unknown",
     337            0 :             ColdStartInfo::Warm => "warm",
     338            0 :             ColdStartInfo::VmPoolHit => "pool_hit",
     339            0 :             ColdStartInfo::VmPoolMiss => "pool_miss",
     340            0 :             ColdStartInfo::HttpPoolHit => "http_pool_hit",
     341            0 :             ColdStartInfo::WarmCached => "warm_cached",
     342              :         }
     343            0 :     }
     344              : }
     345              : 
     346              : #[cfg(test)]
     347              : mod tests {
     348              :     use super::*;
     349              :     use serde_json::json;
     350              : 
     351           10 :     fn dummy_aux() -> serde_json::Value {
     352           10 :         json!({
     353           10 :             "endpoint_id": "endpoint",
     354           10 :             "project_id": "project",
     355           10 :             "branch_id": "branch",
     356           10 :             "cold_start_info": "unknown",
     357           10 :         })
     358           10 :     }
     359              : 
     360              :     #[test]
     361            2 :     fn parse_kick_session() -> anyhow::Result<()> {
     362            2 :         // This is what the console's kickResponse looks like.
     363            2 :         let json = json!({
     364            2 :             "session_id": "deadbeef",
     365            2 :             "result": {
     366            2 :                 "Success": {
     367            2 :                     "host": "localhost",
     368            2 :                     "port": 5432,
     369            2 :                     "dbname": "postgres",
     370            2 :                     "user": "john_doe",
     371            2 :                     "password": "password",
     372            2 :                     "aux": dummy_aux(),
     373            2 :                 }
     374            2 :             }
     375            2 :         });
     376            2 :         let _: KickSession = serde_json::from_str(&json.to_string())?;
     377              : 
     378            2 :         Ok(())
     379            2 :     }
     380              : 
     381              :     #[test]
     382            2 :     fn parse_db_info() -> anyhow::Result<()> {
     383              :         // with password
     384            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     385            2 :             "host": "localhost",
     386            2 :             "port": 5432,
     387            2 :             "dbname": "postgres",
     388            2 :             "user": "john_doe",
     389            2 :             "password": "password",
     390            2 :             "aux": dummy_aux(),
     391            2 :         }))?;
     392              : 
     393              :         // without password
     394            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     395            2 :             "host": "localhost",
     396            2 :             "port": 5432,
     397            2 :             "dbname": "postgres",
     398            2 :             "user": "john_doe",
     399            2 :             "aux": dummy_aux(),
     400            2 :         }))?;
     401              : 
     402              :         // new field (forward compatibility)
     403            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     404            2 :             "host": "localhost",
     405            2 :             "port": 5432,
     406            2 :             "dbname": "postgres",
     407            2 :             "user": "john_doe",
     408            2 :             "project": "hello_world",
     409            2 :             "N.E.W": "forward compatibility check",
     410            2 :             "aux": dummy_aux(),
     411            2 :         }))?;
     412              : 
     413            2 :         Ok(())
     414            2 :     }
     415              : 
     416              :     #[test]
     417            2 :     fn parse_wake_compute() -> anyhow::Result<()> {
     418            2 :         let json = json!({
     419            2 :             "address": "0.0.0.0",
     420            2 :             "aux": dummy_aux(),
     421            2 :         });
     422            2 :         let _: WakeCompute = serde_json::from_str(&json.to_string())?;
     423            2 :         Ok(())
     424            2 :     }
     425              : 
     426              :     #[test]
     427            2 :     fn parse_get_role_secret() -> anyhow::Result<()> {
     428            2 :         // Empty `allowed_ips` field.
     429            2 :         let json = json!({
     430            2 :             "role_secret": "secret",
     431            2 :         });
     432            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     433            2 :         let json = json!({
     434            2 :             "role_secret": "secret",
     435            2 :             "allowed_ips": ["8.8.8.8"],
     436            2 :         });
     437            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     438            2 :         let json = json!({
     439            2 :             "role_secret": "secret",
     440            2 :             "allowed_ips": ["8.8.8.8"],
     441            2 :             "project_id": "project",
     442            2 :         });
     443            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     444              : 
     445            2 :         Ok(())
     446            2 :     }
     447              : }
        

Generated by: LCOV version 2.1-beta