LCOV - code coverage report
Current view: top level - proxy/src/console - messages.rs (source / functions) Coverage Total Hit
Test: e402c46de0a007db6b48dddbde450ddbb92e6ceb.info Lines: 62.0 % 184 114
Test Date: 2024-06-25 10:31:23 Functions: 22.7 % 163 37

            Line data    Source code
       1              : use measured::FixedCardinalityLabel;
       2              : use serde::{Deserialize, Serialize};
       3              : use std::fmt::{self, Display};
       4              : 
       5              : use crate::auth::IpPattern;
       6              : 
       7              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
       8              : use crate::proxy::retry::ShouldRetry;
       9              : 
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
pub struct ConsoleError {
    /// Error text from the response body. `Display` falls back to it when no
    /// user-facing message is present, and `could_retry` inspects it for
    /// quota-related phrases.
    pub error: Box<str>,
    /// HTTP status code associated with the error. It is not read from the
    /// body (`#[serde(skip)]`), so deserialization leaves it at its `Default`
    /// value. NOTE(review): presumably the caller sets it after parsing — confirm.
    #[serde(skip)]
    pub http_status_code: http::StatusCode,
    /// Optional structured status; `None` when the body has no `status` field.
    pub status: Option<Status>,
}
      19              : 
      20              : impl ConsoleError {
      21            6 :     pub fn get_reason(&self) -> Reason {
      22            6 :         self.status
      23            6 :             .as_ref()
      24            6 :             .and_then(|s| s.details.error_info.as_ref())
      25            6 :             .map(|e| e.reason)
      26            6 :             .unwrap_or(Reason::Unknown)
      27            6 :     }
      28            0 :     pub fn get_user_facing_message(&self) -> String {
      29            0 :         use super::provider::errors::REQUEST_FAILED;
      30            0 :         self.status
      31            0 :             .as_ref()
      32            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      33            0 :             .map(|m| m.message.clone().into())
      34            0 :             .unwrap_or_else(|| {
      35            0 :                 // Ask @neondatabase/control-plane for review before adding more.
      36            0 :                 match self.http_status_code {
      37              :                     http::StatusCode::NOT_FOUND => {
      38              :                         // Status 404: failed to get a project-related resource.
      39            0 :                         format!("{REQUEST_FAILED}: endpoint cannot be found")
      40              :                     }
      41              :                     http::StatusCode::NOT_ACCEPTABLE => {
      42              :                         // Status 406: endpoint is disabled (we don't allow connections).
      43            0 :                         format!("{REQUEST_FAILED}: endpoint is disabled")
      44              :                     }
      45              :                     http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
      46              :                         // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
      47            0 :                         format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
      48              :                     }
      49            0 :                     _ => REQUEST_FAILED.to_owned(),
      50              :                 }
      51            0 :             })
      52            0 :     }
      53              : }
      54              : 
      55              : impl Display for ConsoleError {
      56            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
      57            0 :         let msg = self
      58            0 :             .status
      59            0 :             .as_ref()
      60            0 :             .and_then(|s| s.details.user_facing_message.as_ref())
      61            0 :             .map(|m| m.message.as_ref())
      62            0 :             .unwrap_or_else(|| &self.error);
      63            0 :         write!(f, "{}", msg)
      64            0 :     }
      65              : }
      66              : 
      67              : impl ShouldRetry for ConsoleError {
      68           12 :     fn could_retry(&self) -> bool {
      69           12 :         if self.status.is_none() || self.status.as_ref().unwrap().details.retry_info.is_none() {
      70              :             // retry some temporary failures because the compute was in a bad state
      71              :             // (bad request can be returned when the endpoint was in transition)
      72           12 :             return match &self {
      73              :                 ConsoleError {
      74              :                     http_status_code: http::StatusCode::BAD_REQUEST,
      75              :                     ..
      76            8 :                 } => true,
      77              :                 // don't retry when quotas are exceeded
      78              :                 ConsoleError {
      79              :                     http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
      80            0 :                     ref error,
      81            0 :                     ..
      82            0 :                 } => !error.contains("compute time quota of non-primary branches is exceeded"),
      83              :                 // locked can be returned when the endpoint was in transition
      84              :                 // or when quotas are exceeded. don't retry when quotas are exceeded
      85              :                 ConsoleError {
      86              :                     http_status_code: http::StatusCode::LOCKED,
      87            0 :                     ref error,
      88            0 :                     ..
      89            0 :                 } => {
      90            0 :                     !error.contains("quota exceeded")
      91            0 :                         && !error.contains("the limit for current plan reached")
      92              :                 }
      93            4 :                 _ => false,
      94              :             };
      95            0 :         }
      96              : 
      97              :         // retry if the response has a retry delay
      98            0 :         if let Some(retry_info) = self
      99            0 :             .status
     100            0 :             .as_ref()
     101            0 :             .and_then(|s| s.details.retry_info.as_ref())
     102              :         {
     103            0 :             retry_info.retry_delay_ms > 0
     104              :         } else {
     105            0 :             false
     106              :         }
     107           12 :     }
     108              : }
     109              : 
/// Structured status optionally attached to a [`ConsoleError`].
#[derive(Debug, Deserialize)]
pub struct Status {
    /// Status code string as sent by the console (not read by code in this file).
    pub code: Box<str>,
    /// Status message as sent by the console (not read by code in this file).
    pub message: Box<str>,
    /// Machine-readable details used by `ConsoleError`'s reason, retry and
    /// user-facing-message helpers.
    pub details: Details,
}
     116              : 
/// Optional details of a [`Status`]; every field may be absent from the JSON.
#[derive(Debug, Deserialize)]
pub struct Details {
    /// Source of the value returned by `ConsoleError::get_reason`.
    pub error_info: Option<ErrorInfo>,
    /// When present, `could_retry` decides solely from its `retry_delay_ms`.
    pub retry_info: Option<RetryInfo>,
    /// Message preferred by `ConsoleError`'s `Display` and `get_user_facing_message`.
    pub user_facing_message: Option<UserFacingMessage>,
}
     123              : 
/// Machine-readable error classification inside [`Details`].
#[derive(Debug, Deserialize)]
pub struct ErrorInfo {
    /// Parsed reason; unrecognized strings become `Reason::Unknown`.
    pub reason: Reason,
    // Schema could also have `metadata` field, but it's not structured. Skip it for now.
}
     129              : 
/// Machine-readable reason carried in [`ErrorInfo`].
///
/// Deserialized from the SCREAMING_SNAKE_CASE strings given in the
/// `serde(rename)` attributes. Any other string maps to [`Reason::Unknown`]
/// (`#[serde(other)]`), which is also the `Default`.
#[derive(Clone, Copy, Debug, Deserialize, Default)]
pub enum Reason {
    #[serde(rename = "ROLE_PROTECTED")]
    RoleProtected,
    // The four "not found" reasons below are the ones `Reason::is_not_found` accepts.
    #[serde(rename = "RESOURCE_NOT_FOUND")]
    ResourceNotFound,
    #[serde(rename = "PROJECT_NOT_FOUND")]
    ProjectNotFound,
    #[serde(rename = "ENDPOINT_NOT_FOUND")]
    EndpointNotFound,
    #[serde(rename = "BRANCH_NOT_FOUND")]
    BranchNotFound,
    #[serde(rename = "RATE_LIMIT_EXCEEDED")]
    RateLimitExceeded,
    #[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
    NonPrimaryBranchComputeTimeExceeded,
    #[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
    ActiveTimeQuotaExceeded,
    #[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
    ComputeTimeQuotaExceeded,
    #[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
    WrittenDataQuotaExceeded,
    #[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
    DataTransferQuotaExceeded,
    #[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
    LogicalSizeQuotaExceeded,
    /// Fallback for missing or unrecognized reasons.
    #[default]
    #[serde(other)]
    Unknown,
}
     160              : 
     161              : impl Reason {
     162            0 :     pub fn is_not_found(&self) -> bool {
     163            0 :         matches!(
     164            0 :             self,
     165              :             Reason::ResourceNotFound
     166              :                 | Reason::ProjectNotFound
     167              :                 | Reason::EndpointNotFound
     168              :                 | Reason::BranchNotFound
     169              :         )
     170            0 :     }
     171              : }
     172              : 
/// Retry hint from the console.
#[derive(Debug, Deserialize)]
pub struct RetryInfo {
    /// Suggested delay before retrying (the `_ms` suffix suggests milliseconds).
    /// `could_retry` treats the error as retryable only when this is greater than zero.
    pub retry_delay_ms: u64,
}
     177              : 
/// Message the console marks as suitable for the end user.
#[derive(Debug, Deserialize)]
pub struct UserFacingMessage {
    /// Text used verbatim by `ConsoleError`'s `Display` and `get_user_facing_message`.
    pub message: Box<str>,
}
     182              : 
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/proxy_get_role_secret` API method.
#[derive(Deserialize)]
pub struct GetRoleSecret {
    /// The role's secret. Sensitive: the manual `Debug` impl prints no fields.
    pub role_secret: Box<str>,
    /// IP allow-list for the role; `None` when the field is absent from the response.
    pub allowed_ips: Option<Vec<IpPattern>>,
    /// Interned project id; `None` when the field is absent from the response.
    pub project_id: Option<ProjectIdInt>,
}
     191              : 
     192              : // Manually implement debug to omit sensitive info.
     193              : impl fmt::Debug for GetRoleSecret {
     194            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     195            0 :         f.debug_struct("GetRoleSecret").finish_non_exhaustive()
     196            0 :     }
     197              : }
     198              : 
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
    /// Compute node address as returned by the console.
    pub address: Box<str>,
    /// Metric labels for this compute.
    pub aux: MetricsAuxInfo,
}
     206              : 
/// Async response which concludes the link auth flow.
/// Also known as `kickResponse` in the console.
#[derive(Debug, Deserialize)]
pub struct KickSession<'a> {
    /// Session ID is assigned by the proxy.
    /// Borrowed (`&'a str`), so the input text must outlive the parsed value.
    pub session_id: &'a str,

    /// Compute node connection params.
    /// On the wire this is wrapped as `{"Success": {...}}` and unwrapped by
    /// `KickSession::parse_db_info`.
    #[serde(deserialize_with = "KickSession::parse_db_info")]
    pub result: DatabaseInfo,
}
     218              : 
     219              : impl KickSession<'_> {
     220            2 :     fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
     221            2 :     where
     222            2 :         D: serde::Deserializer<'de>,
     223            2 :     {
     224            4 :         #[derive(Deserialize)]
     225            2 :         enum Wrapper {
     226            2 :             // Currently, console only reports `Success`.
     227            2 :             // `Failure(String)` used to be here... RIP.
     228            2 :             Success(DatabaseInfo),
     229            2 :         }
     230            2 : 
     231            2 :         Wrapper::deserialize(des).map(|x| match x {
     232            2 :             Wrapper::Success(info) => info,
     233            2 :         })
     234            2 :     }
     235              : }
     236              : 
/// Compute node connection params.
///
/// Unknown JSON fields are ignored (there is no `deny_unknown_fields`), which
/// keeps parsing forward compatible with new console fields.
#[derive(Deserialize)]
pub struct DatabaseInfo {
    pub host: Box<str>,
    pub port: u16,
    pub dbname: Box<str>,
    pub user: Box<str>,
    /// Console always provides a password, but it might
    /// be inconvenient for debug with local PG instance.
    /// Sensitive: omitted from the manual `Debug` impl.
    pub password: Option<Box<str>>,
    /// Metric labels for this compute.
    pub aux: MetricsAuxInfo,
}
     249              : 
     250              : // Manually implement debug to omit sensitive info.
     251              : impl fmt::Debug for DatabaseInfo {
     252            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
     253            0 :         f.debug_struct("DatabaseInfo")
     254            0 :             .field("host", &self.host)
     255            0 :             .field("port", &self.port)
     256            0 :             .field("dbname", &self.dbname)
     257            0 :             .field("user", &self.user)
     258            0 :             .finish_non_exhaustive()
     259            0 :     }
     260              : }
     261              : 
/// Various labels for prometheus metrics.
/// Also known as `ProxyMetricsAuxInfo` in the console.
#[derive(Debug, Deserialize, Clone)]
pub struct MetricsAuxInfo {
    pub endpoint_id: EndpointIdInt,
    pub project_id: ProjectIdInt,
    pub branch_id: BranchIdInt,
    /// May be omitted by the console; then it falls back to
    /// `ColdStartInfo::Unknown`, the enum's `Default`.
    #[serde(default)]
    pub cold_start_info: ColdStartInfo,
}
     272              : 
/// How the compute serving a connection was obtained; usable as a metrics label.
///
/// Serialized in snake_case: `unknown`, `warm`, `pool_hit`, `pool_miss`,
/// `http_pool_hit`, `warm_cached`.
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
#[serde(rename_all = "snake_case")]
pub enum ColdStartInfo {
    /// Default value, also used when the console omits the field.
    #[default]
    Unknown,
    /// Compute was already running
    Warm,
    #[serde(rename = "pool_hit")]
    #[label(rename = "pool_hit")]
    /// Compute was not running but there was an available VM
    VmPoolHit,
    #[serde(rename = "pool_miss")]
    #[label(rename = "pool_miss")]
    /// Compute was not running and there were no VMs available
    VmPoolMiss,

    // not provided by control plane
    /// Connection available from HTTP pool
    HttpPoolHit,
    /// Cached connection info
    WarmCached,
}
     295              : 
     296              : impl ColdStartInfo {
     297            0 :     pub fn as_str(&self) -> &'static str {
     298            0 :         match self {
     299            0 :             ColdStartInfo::Unknown => "unknown",
     300            0 :             ColdStartInfo::Warm => "warm",
     301            0 :             ColdStartInfo::VmPoolHit => "pool_hit",
     302            0 :             ColdStartInfo::VmPoolMiss => "pool_miss",
     303            0 :             ColdStartInfo::HttpPoolHit => "http_pool_hit",
     304            0 :             ColdStartInfo::WarmCached => "warm_cached",
     305              :         }
     306            0 :     }
     307              : }
     308              : 
     309              : #[cfg(test)]
     310              : mod tests {
     311              :     use super::*;
     312              :     use serde_json::json;
     313              : 
     314           10 :     fn dummy_aux() -> serde_json::Value {
     315           10 :         json!({
     316           10 :             "endpoint_id": "endpoint",
     317           10 :             "project_id": "project",
     318           10 :             "branch_id": "branch",
     319           10 :             "cold_start_info": "unknown",
     320           10 :         })
     321           10 :     }
     322              : 
     323              :     #[test]
     324            2 :     fn parse_kick_session() -> anyhow::Result<()> {
     325            2 :         // This is what the console's kickResponse looks like.
     326            2 :         let json = json!({
     327            2 :             "session_id": "deadbeef",
     328            2 :             "result": {
     329            2 :                 "Success": {
     330            2 :                     "host": "localhost",
     331            2 :                     "port": 5432,
     332            2 :                     "dbname": "postgres",
     333            2 :                     "user": "john_doe",
     334            2 :                     "password": "password",
     335            2 :                     "aux": dummy_aux(),
     336            2 :                 }
     337            2 :             }
     338            2 :         });
     339            2 :         let _: KickSession = serde_json::from_str(&json.to_string())?;
     340              : 
     341            2 :         Ok(())
     342            2 :     }
     343              : 
     344              :     #[test]
     345            2 :     fn parse_db_info() -> anyhow::Result<()> {
     346              :         // with password
     347            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     348            2 :             "host": "localhost",
     349            2 :             "port": 5432,
     350            2 :             "dbname": "postgres",
     351            2 :             "user": "john_doe",
     352            2 :             "password": "password",
     353            2 :             "aux": dummy_aux(),
     354            2 :         }))?;
     355              : 
     356              :         // without password
     357            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     358            2 :             "host": "localhost",
     359            2 :             "port": 5432,
     360            2 :             "dbname": "postgres",
     361            2 :             "user": "john_doe",
     362            2 :             "aux": dummy_aux(),
     363            2 :         }))?;
     364              : 
     365              :         // new field (forward compatibility)
     366            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     367            2 :             "host": "localhost",
     368            2 :             "port": 5432,
     369            2 :             "dbname": "postgres",
     370            2 :             "user": "john_doe",
     371            2 :             "project": "hello_world",
     372            2 :             "N.E.W": "forward compatibility check",
     373            2 :             "aux": dummy_aux(),
     374            2 :         }))?;
     375              : 
     376            2 :         Ok(())
     377            2 :     }
     378              : 
     379              :     #[test]
     380            2 :     fn parse_wake_compute() -> anyhow::Result<()> {
     381            2 :         let json = json!({
     382            2 :             "address": "0.0.0.0",
     383            2 :             "aux": dummy_aux(),
     384            2 :         });
     385            2 :         let _: WakeCompute = serde_json::from_str(&json.to_string())?;
     386            2 :         Ok(())
     387            2 :     }
     388              : 
     389              :     #[test]
     390            2 :     fn parse_get_role_secret() -> anyhow::Result<()> {
     391            2 :         // Empty `allowed_ips` field.
     392            2 :         let json = json!({
     393            2 :             "role_secret": "secret",
     394            2 :         });
     395            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     396            2 :         let json = json!({
     397            2 :             "role_secret": "secret",
     398            2 :             "allowed_ips": ["8.8.8.8"],
     399            2 :         });
     400            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     401            2 :         let json = json!({
     402            2 :             "role_secret": "secret",
     403            2 :             "allowed_ips": ["8.8.8.8"],
     404            2 :             "project_id": "project",
     405            2 :         });
     406            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     407              : 
     408            2 :         Ok(())
     409            2 :     }
     410              : }
        

Generated by: LCOV version 2.1-beta