LCOV - code coverage report
Current view: top level - proxy/src/console - messages.rs (source / functions) Coverage Total Hit
Test: 322b88762cba8ea666f63cda880cccab6936bf37.info Lines: 89.6 % 115 103
Test Date: 2024-02-29 11:57:12 Functions: 41.5 % 106 44

            Line data    Source code
       1              : use serde::Deserialize;
       2              : use std::fmt;
       3              : 
       4              : use crate::auth::IpPattern;
       5              : 
       6              : use crate::{BranchId, EndpointId, ProjectId};
       7              : 
       8              : /// Generic error response with human-readable description.
       9              : /// Note that we can't always present it to user as is.
      10            0 : #[derive(Debug, Deserialize)]
      11              : pub struct ConsoleError {
      12              :     pub error: Box<str>,
      13              : }
      14              : 
      15              : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
      16              : /// Returned by the `/proxy_get_role_secret` API method.
      17           30 : #[derive(Deserialize)]
      18              : pub struct GetRoleSecret {
      19              :     pub role_secret: Box<str>,
      20              :     pub allowed_ips: Option<Vec<IpPattern>>,
      21              :     pub project_id: Option<ProjectId>,
      22              : }
      23              : 
      24              : // Manually implement debug to omit sensitive info.
      25              : impl fmt::Debug for GetRoleSecret {
      26            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      27            0 :         f.debug_struct("GetRoleSecret").finish_non_exhaustive()
      28            0 :     }
      29              : }
      30              : 
      31              : /// Response which holds compute node's `host:port` pair.
      32              : /// Returned by the `/proxy_wake_compute` API method.
      33           10 : #[derive(Debug, Deserialize)]
      34              : pub struct WakeCompute {
      35              :     pub address: Box<str>,
      36              :     pub aux: MetricsAuxInfo,
      37              : }
      38              : 
      39              : /// Async response which concludes the link auth flow.
      40              : /// Also known as `kickResponse` in the console.
      41           12 : #[derive(Debug, Deserialize)]
      42              : pub struct KickSession<'a> {
      43              :     /// Session ID is assigned by the proxy.
      44              :     pub session_id: &'a str,
      45              : 
      46              :     /// Compute node connection params.
      47              :     #[serde(deserialize_with = "KickSession::parse_db_info")]
      48              :     pub result: DatabaseInfo,
      49              : }
      50              : 
      51              : impl KickSession<'_> {
      52            2 :     fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
      53            2 :     where
      54            2 :         D: serde::Deserializer<'de>,
      55            2 :     {
      56            4 :         #[derive(Deserialize)]
      57            2 :         enum Wrapper {
      58            2 :             // Currently, console only reports `Success`.
      59            2 :             // `Failure(String)` used to be here... RIP.
      60            2 :             Success(DatabaseInfo),
      61            2 :         }
      62            2 : 
      63            2 :         Wrapper::deserialize(des).map(|x| match x {
      64            2 :             Wrapper::Success(info) => info,
      65            2 :         })
      66            2 :     }
      67              : }
      68              : 
      69              : /// Compute node connection params.
      70          100 : #[derive(Deserialize)]
      71              : pub struct DatabaseInfo {
      72              :     pub host: Box<str>,
      73              :     pub port: u16,
      74              :     pub dbname: Box<str>,
      75              :     pub user: Box<str>,
      76              :     /// Console always provides a password, but it might
      77              :     /// be inconvenient for debug with local PG instance.
      78              :     pub password: Option<Box<str>>,
      79              :     pub aux: MetricsAuxInfo,
      80              : }
      81              : 
      82              : // Manually implement debug to omit sensitive info.
      83              : impl fmt::Debug for DatabaseInfo {
      84            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
      85            0 :         f.debug_struct("DatabaseInfo")
      86            0 :             .field("host", &self.host)
      87            0 :             .field("port", &self.port)
      88            0 :             .field("dbname", &self.dbname)
      89            0 :             .field("user", &self.user)
      90            0 :             .finish_non_exhaustive()
      91            0 :     }
      92              : }
      93              : 
      94              : /// Various labels for prometheus metrics.
      95              : /// Also known as `ProxyMetricsAuxInfo` in the console.
      96           70 : #[derive(Debug, Deserialize, Clone, Default)]
      97              : pub struct MetricsAuxInfo {
      98              :     pub endpoint_id: EndpointId,
      99              :     pub project_id: ProjectId,
     100              :     pub branch_id: BranchId,
     101              :     pub is_cold_start: Option<bool>,
     102              : }
     103              : 
     104              : #[cfg(test)]
     105              : mod tests {
     106              :     use super::*;
     107              :     use serde_json::json;
     108              : 
     109           10 :     fn dummy_aux() -> serde_json::Value {
     110           10 :         json!({
     111           10 :             "endpoint_id": "endpoint",
     112           10 :             "project_id": "project",
     113           10 :             "branch_id": "branch",
     114           10 :         })
     115           10 :     }
     116              : 
     117            2 :     #[test]
     118            2 :     fn parse_kick_session() -> anyhow::Result<()> {
     119            2 :         // This is what the console's kickResponse looks like.
     120            2 :         let json = json!({
     121            2 :             "session_id": "deadbeef",
     122            2 :             "result": {
     123            2 :                 "Success": {
     124            2 :                     "host": "localhost",
     125            2 :                     "port": 5432,
     126            2 :                     "dbname": "postgres",
     127            2 :                     "user": "john_doe",
     128            2 :                     "password": "password",
     129            2 :                     "aux": dummy_aux(),
     130            2 :                 }
     131            2 :             }
     132            2 :         });
     133            2 :         let _: KickSession = serde_json::from_str(&json.to_string())?;
     134              : 
     135            2 :         Ok(())
     136            2 :     }
     137              : 
     138            2 :     #[test]
     139            2 :     fn parse_db_info() -> anyhow::Result<()> {
     140              :         // with password
     141            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     142            2 :             "host": "localhost",
     143            2 :             "port": 5432,
     144            2 :             "dbname": "postgres",
     145            2 :             "user": "john_doe",
     146            2 :             "password": "password",
     147            2 :             "aux": dummy_aux(),
     148            2 :         }))?;
     149              : 
     150              :         // without password
     151            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     152            2 :             "host": "localhost",
     153            2 :             "port": 5432,
     154            2 :             "dbname": "postgres",
     155            2 :             "user": "john_doe",
     156            2 :             "aux": dummy_aux(),
     157            2 :         }))?;
     158              : 
     159              :         // new field (forward compatibility)
     160            2 :         let _: DatabaseInfo = serde_json::from_value(json!({
     161            2 :             "host": "localhost",
     162            2 :             "port": 5432,
     163            2 :             "dbname": "postgres",
     164            2 :             "user": "john_doe",
     165            2 :             "project": "hello_world",
     166            2 :             "N.E.W": "forward compatibility check",
     167            2 :             "aux": dummy_aux(),
     168            2 :         }))?;
     169              : 
     170            2 :         Ok(())
     171            2 :     }
     172              : 
     173            2 :     #[test]
     174            2 :     fn parse_wake_compute() -> anyhow::Result<()> {
     175            2 :         let json = json!({
     176            2 :             "address": "0.0.0.0",
     177            2 :             "aux": dummy_aux(),
     178            2 :         });
     179            2 :         let _: WakeCompute = serde_json::from_str(&json.to_string())?;
     180            2 :         Ok(())
     181            2 :     }
     182              : 
     183            2 :     #[test]
     184            2 :     fn parse_get_role_secret() -> anyhow::Result<()> {
     185            2 :         // Empty `allowed_ips` field.
     186            2 :         let json = json!({
     187            2 :             "role_secret": "secret",
     188            2 :         });
     189            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     190            2 :         let json = json!({
     191            2 :             "role_secret": "secret",
     192            2 :             "allowed_ips": ["8.8.8.8"],
     193            2 :         });
     194            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     195            2 :         let json = json!({
     196            2 :             "role_secret": "secret",
     197            2 :             "allowed_ips": ["8.8.8.8"],
     198            2 :             "project_id": "project",
     199            2 :         });
     200            2 :         let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
     201              : 
     202            2 :         Ok(())
     203            2 :     }
     204              : }
        

Generated by: LCOV version 2.1-beta