LCOV - differential code coverage report
Current view: top level - proxy/src/auth - credentials.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 93.8 % 195 183 12 183
Current Date: 2023-10-19 02:04:12 Functions: 81.4 % 43 35 8 35
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! User credentials used in authentication.
       2                 : 
       3                 : use crate::{auth::password_hack::parse_endpoint_param, error::UserFacingError};
       4                 : use itertools::Itertools;
       5                 : use pq_proto::StartupMessageParams;
       6                 : use std::collections::HashSet;
       7                 : use thiserror::Error;
       8                 : use tracing::info;
       9                 : 
      10 UBC           0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
      11                 : pub enum ClientCredsParseError {
      12                 :     #[error("Parameter '{0}' is missing in startup packet.")]
      13                 :     MissingKey(&'static str),
      14                 : 
      15                 :     #[error(
      16                 :         "Inconsistent project name inferred from \
      17                 :          SNI ('{}') and project option ('{}').",
      18                 :         .domain, .option,
      19                 :     )]
      20                 :     InconsistentProjectNames { domain: String, option: String },
      21                 : 
      22                 :     #[error(
      23                 :         "Common name inferred from SNI ('{}') is not known",
      24                 :         .cn,
      25                 :     )]
      26                 :     UnknownCommonName { cn: String },
      27                 : 
      28                 :     #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
      29                 :     MalformedProjectName(String),
      30                 : }
      31                 : 
      32                 : impl UserFacingError for ClientCredsParseError {}
      33                 : 
      34                 : /// Various client credentials which we use for authentication.
      35                 : /// Note that we don't store any kind of client key or password here.
      36               0 : #[derive(Debug, Clone, PartialEq, Eq)]
      37                 : pub struct ClientCredentials<'a> {
      38                 :     pub user: &'a str,
      39                 :     // TODO: this is a severe misnomer! We should think of a new name ASAP.
      40                 :     pub project: Option<String>,
      41                 : }
      42                 : 
      43                 : impl ClientCredentials<'_> {
      44                 :     #[inline]
      45 CBC          30 :     pub fn project(&self) -> Option<&str> {
      46              30 :         self.project.as_deref()
      47              30 :     }
      48                 : }
      49                 : 
      50                 : impl<'a> ClientCredentials<'a> {
      51                 :     #[cfg(test)]
      52 UBC           0 :     pub fn new_noop() -> Self {
      53               0 :         ClientCredentials {
      54               0 :             user: "",
      55               0 :             project: None,
      56               0 :         }
      57               0 :     }
      58                 : 
      59 CBC          69 :     pub fn parse(
      60              69 :         params: &'a StartupMessageParams,
      61              69 :         sni: Option<&str>,
      62              69 :         common_names: Option<HashSet<String>>,
      63              69 :     ) -> Result<Self, ClientCredsParseError> {
      64              69 :         use ClientCredsParseError::*;
      65              69 : 
      66              69 :         // Some parameters are stored in the startup message.
      67              69 :         let get_param = |key| params.get(key).ok_or(MissingKey(key));
      68              69 :         let user = get_param("user")?;
      69                 : 
      70                 :         // Project name might be passed via PG's command-line options.
      71              69 :         let project_option = params
      72              69 :             .options_raw()
      73              69 :             .and_then(|options| {
      74                 :                 // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
      75                 :                 // However, if both are present, we don't exactly know which one to use.
      76                 :                 // Therefore we require that only one of them is present.
      77              36 :                 options
      78              36 :                     .filter_map(parse_endpoint_param)
      79              36 :                     .at_most_one()
      80              36 :                     .ok()?
      81              69 :             })
      82              69 :             .map(|name| name.to_string());
      83                 : 
      84              69 :         let project_from_domain = if let Some(sni_str) = sni {
      85              54 :             if let Some(cn) = common_names {
      86              54 :                 let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
      87                 : 
      88              54 :                 let project = common_name_from_sni
      89              54 :                     .and_then(|domain| {
      90              54 :                         if cn.contains(domain) {
      91              53 :                             subdomain_from_sni(sni_str, domain)
      92                 :                         } else {
      93               1 :                             None
      94                 :                         }
      95              54 :                     })
      96              54 :                     .ok_or_else(|| UnknownCommonName {
      97               1 :                         cn: common_name_from_sni.unwrap_or("").into(),
      98              54 :                     })?;
      99                 : 
     100              53 :                 Some(project)
     101                 :             } else {
     102 UBC           0 :                 None
     103                 :             }
     104                 :         } else {
     105 CBC          15 :             None
     106                 :         };
     107                 : 
     108              68 :         let project = match (project_option, project_from_domain) {
     109                 :             // Invariant: if we have both project name variants, they should match.
     110               2 :             (Some(option), Some(domain)) if option != domain => {
     111               1 :                 Some(Err(InconsistentProjectNames { domain, option }))
     112                 :             }
     113                 :             // Invariant: project name may not contain certain characters.
     114              67 :             (a, b) => a.or(b).map(|name| match project_name_valid(&name) {
     115 UBC           0 :                 false => Err(MalformedProjectName(name)),
     116 CBC          60 :                 true => Ok(name),
     117              67 :             }),
     118                 :         }
     119              68 :         .transpose()?;
     120                 : 
     121              67 :         info!(user, project = project.as_deref(), "credentials");
     122                 : 
     123              67 :         Ok(Self { user, project })
     124              69 :     }
     125                 : }
     126                 : 
     127              60 : fn project_name_valid(name: &str) -> bool {
     128             351 :     name.chars().all(|c| c.is_alphanumeric() || c == '-')
     129              60 : }
     130                 : 
     131                 : fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
     132              53 :     sni.strip_suffix(common_name)?
     133              53 :         .strip_suffix('.')
     134              53 :         .map(str::to_owned)
     135              53 : }
     136                 : 
     137                 : #[cfg(test)]
     138                 : mod tests {
     139                 :     use super::*;
     140                 :     use ClientCredsParseError::*;
     141                 : 
     142               1 :     #[test]
     143               1 :     fn parse_bare_minimum() -> anyhow::Result<()> {
     144               1 :         // According to postgresql, only `user` should be required.
     145               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     146                 : 
     147               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     148               1 :         assert_eq!(creds.user, "john_doe");
     149               1 :         assert_eq!(creds.project, None);
     150                 : 
     151               1 :         Ok(())
     152               1 :     }
     153                 : 
     154               1 :     #[test]
     155               1 :     fn parse_excessive() -> anyhow::Result<()> {
     156               1 :         let options = StartupMessageParams::new([
     157               1 :             ("user", "john_doe"),
     158               1 :             ("database", "world"), // should be ignored
     159               1 :             ("foo", "bar"),        // should be ignored
     160               1 :         ]);
     161                 : 
     162               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     163               1 :         assert_eq!(creds.user, "john_doe");
     164               1 :         assert_eq!(creds.project, None);
     165                 : 
     166               1 :         Ok(())
     167               1 :     }
     168                 : 
     169               1 :     #[test]
     170               1 :     fn parse_project_from_sni() -> anyhow::Result<()> {
     171               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     172               1 : 
     173               1 :         let sni = Some("foo.localhost");
     174               1 :         let common_names = Some(["localhost".into()].into());
     175                 : 
     176               1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     177               1 :         assert_eq!(creds.user, "john_doe");
     178               1 :         assert_eq!(creds.project.as_deref(), Some("foo"));
     179                 : 
     180               1 :         Ok(())
     181               1 :     }
     182                 : 
     183               1 :     #[test]
     184               1 :     fn parse_project_from_options() -> anyhow::Result<()> {
     185               1 :         let options = StartupMessageParams::new([
     186               1 :             ("user", "john_doe"),
     187               1 :             ("options", "-ckey=1 project=bar -c geqo=off"),
     188               1 :         ]);
     189                 : 
     190               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     191               1 :         assert_eq!(creds.user, "john_doe");
     192               1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     193                 : 
     194               1 :         Ok(())
     195               1 :     }
     196                 : 
     197               1 :     #[test]
     198               1 :     fn parse_endpoint_from_options() -> anyhow::Result<()> {
     199               1 :         let options = StartupMessageParams::new([
     200               1 :             ("user", "john_doe"),
     201               1 :             ("options", "-ckey=1 endpoint=bar -c geqo=off"),
     202               1 :         ]);
     203                 : 
     204               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     205               1 :         assert_eq!(creds.user, "john_doe");
     206               1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     207                 : 
     208               1 :         Ok(())
     209               1 :     }
     210                 : 
     211               1 :     #[test]
     212               1 :     fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
     213               1 :         let options = StartupMessageParams::new([
     214               1 :             ("user", "john_doe"),
     215               1 :             (
     216               1 :                 "options",
     217               1 :                 "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
     218               1 :             ),
     219               1 :         ]);
     220                 : 
     221               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     222               1 :         assert_eq!(creds.user, "john_doe");
     223               1 :         assert!(creds.project.is_none());
     224                 : 
     225               1 :         Ok(())
     226               1 :     }
     227                 : 
     228               1 :     #[test]
     229               1 :     fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
     230               1 :         let options = StartupMessageParams::new([
     231               1 :             ("user", "john_doe"),
     232               1 :             ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
     233               1 :         ]);
     234                 : 
     235               1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     236               1 :         assert_eq!(creds.user, "john_doe");
     237               1 :         assert!(creds.project.is_none());
     238                 : 
     239               1 :         Ok(())
     240               1 :     }
     241                 : 
     242               1 :     #[test]
     243               1 :     fn parse_projects_identical() -> anyhow::Result<()> {
     244               1 :         let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
     245               1 : 
     246               1 :         let sni = Some("baz.localhost");
     247               1 :         let common_names = Some(["localhost".into()].into());
     248                 : 
     249               1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     250               1 :         assert_eq!(creds.user, "john_doe");
     251               1 :         assert_eq!(creds.project.as_deref(), Some("baz"));
     252                 : 
     253               1 :         Ok(())
     254               1 :     }
     255                 : 
     256               1 :     #[test]
     257               1 :     fn parse_multi_common_names() -> anyhow::Result<()> {
     258               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     259               1 : 
     260               1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     261               1 :         let sni = Some("p1.a.com");
     262               1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     263               1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     264                 : 
     265               1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     266               1 :         let sni = Some("p1.b.com");
     267               1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     268               1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     269                 : 
     270               1 :         Ok(())
     271               1 :     }
     272                 : 
     273               1 :     #[test]
     274               1 :     fn parse_projects_different() {
     275               1 :         let options =
     276               1 :             StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
     277               1 : 
     278               1 :         let sni = Some("second.localhost");
     279               1 :         let common_names = Some(["localhost".into()].into());
     280               1 : 
     281               1 :         let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
     282               1 :         match err {
     283               1 :             InconsistentProjectNames { domain, option } => {
     284               1 :                 assert_eq!(option, "first");
     285               1 :                 assert_eq!(domain, "second");
     286                 :             }
     287 UBC           0 :             _ => panic!("bad error: {err:?}"),
     288                 :         }
     289 CBC           1 :     }
     290                 : 
     291               1 :     #[test]
     292               1 :     fn parse_inconsistent_sni() {
     293               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     294               1 : 
     295               1 :         let sni = Some("project.localhost");
     296               1 :         let common_names = Some(["example.com".into()].into());
     297               1 : 
     298               1 :         let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
     299               1 :         match err {
     300               1 :             UnknownCommonName { cn } => {
     301               1 :                 assert_eq!(cn, "localhost");
     302                 :             }
     303 UBC           0 :             _ => panic!("bad error: {err:?}"),
     304                 :         }
     305 CBC           1 :     }
     306                 : }
        

Generated by: LCOV version 2.1-beta