LCOV - code coverage report
Current view: top level - proxy/src/auth - credentials.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 93.9 % 196 184
Test Date: 2023-09-06 10:18:01 Functions: 81.4 % 43 35

            Line data    Source code
       1              : //! User credentials used in authentication.
       2              : 
       3              : use crate::{auth::password_hack::parse_endpoint_param, error::UserFacingError};
       4              : use itertools::Itertools;
       5              : use pq_proto::StartupMessageParams;
       6              : use std::collections::HashSet;
       7              : use thiserror::Error;
       8              : use tracing::info;
       9              : 
      10            0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
      11              : pub enum ClientCredsParseError {
      12              :     #[error("Parameter '{0}' is missing in startup packet.")]
      13              :     MissingKey(&'static str),
      14              : 
      15              :     #[error(
      16              :         "Inconsistent project name inferred from \
      17              :          SNI ('{}') and project option ('{}').",
      18              :         .domain, .option,
      19              :     )]
      20              :     InconsistentProjectNames { domain: String, option: String },
      21              : 
      22              :     #[error(
      23              :         "Common name inferred from SNI ('{}') is not known",
      24              :         .cn,
      25              :     )]
      26              :     UnknownCommonName { cn: String },
      27              : 
      28              :     #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
      29              :     MalformedProjectName(String),
      30              : }
      31              : 
      32              : impl UserFacingError for ClientCredsParseError {}
      33              : 
      34              : /// Various client credentials which we use for authentication.
      35              : /// Note that we don't store any kind of client key or password here.
      36            0 : #[derive(Debug, Clone, PartialEq, Eq)]
      37              : pub struct ClientCredentials<'a> {
      38              :     pub user: &'a str,
      39              :     // TODO: this is a severe misnomer! We should think of a new name ASAP.
      40              :     pub project: Option<String>,
      41              : }
      42              : 
      43              : impl ClientCredentials<'_> {
      44              :     #[inline]
      45           28 :     pub fn project(&self) -> Option<&str> {
      46           28 :         self.project.as_deref()
      47           28 :     }
      48              : }
      49              : 
      50              : impl<'a> ClientCredentials<'a> {
      51              :     #[cfg(test)]
      52            0 :     pub fn new_noop() -> Self {
      53            0 :         ClientCredentials {
      54            0 :             user: "",
      55            0 :             project: None,
      56            0 :         }
      57            0 :     }
      58              : 
      59           60 :     pub fn parse(
      60           60 :         params: &'a StartupMessageParams,
      61           60 :         sni: Option<&str>,
      62           60 :         common_names: Option<HashSet<String>>,
      63           60 :     ) -> Result<Self, ClientCredsParseError> {
      64           60 :         use ClientCredsParseError::*;
      65           60 : 
      66           60 :         // Some parameters are stored in the startup message.
      67           60 :         let get_param = |key| params.get(key).ok_or(MissingKey(key));
      68           60 :         let user = get_param("user")?;
      69              : 
      70              :         // Project name might be passed via PG's command-line options.
      71           60 :         let project_option = params
      72           60 :             .options_raw()
      73           60 :             .and_then(|options| {
      74              :                 // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
      75              :                 // However, if both are present, we don't exactly know which one to use.
      76              :                 // Therefore we require that only one of them is present.
      77           34 :                 options
      78           34 :                     .filter_map(parse_endpoint_param)
      79           34 :                     .at_most_one()
      80           34 :                     .ok()?
      81           60 :             })
      82           60 :             .map(|name| name.to_string());
      83              : 
      84           60 :         let project_from_domain = if let Some(sni_str) = sni {
      85           45 :             if let Some(cn) = common_names {
      86           45 :                 let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
      87              : 
      88           45 :                 let project = common_name_from_sni
      89           45 :                     .and_then(|domain| {
      90           45 :                         if cn.contains(domain) {
      91           44 :                             subdomain_from_sni(sni_str, domain)
      92              :                         } else {
      93            1 :                             None
      94              :                         }
      95           45 :                     })
      96           45 :                     .ok_or_else(|| UnknownCommonName {
      97            1 :                         cn: common_name_from_sni.unwrap_or("").into(),
      98           45 :                     })?;
      99              : 
     100           44 :                 Some(project)
     101              :             } else {
     102            0 :                 None
     103              :             }
     104              :         } else {
     105           15 :             None
     106              :         };
     107              : 
     108           59 :         let project = match (project_option, project_from_domain) {
     109              :             // Invariant: if we have both project name variants, they should match.
     110            2 :             (Some(option), Some(domain)) if option != domain => {
     111            1 :                 Some(Err(InconsistentProjectNames { domain, option }))
     112              :             }
     113              :             // Invariant: project name may not contain certain characters.
     114           58 :             (a, b) => a.or(b).map(|name| match project_name_valid(&name) {
     115            0 :                 false => Err(MalformedProjectName(name)),
     116           51 :                 true => Ok(name),
     117           58 :             }),
     118              :         }
     119           59 :         .transpose()?;
     120              : 
     121           58 :         info!(user, project = project.as_deref(), "credentials");
     122              : 
     123           58 :         Ok(Self { user, project })
     124           60 :     }
     125              : }
     126              : 
     127           51 : fn project_name_valid(name: &str) -> bool {
     128          306 :     name.chars().all(|c| c.is_alphanumeric() || c == '-')
     129           51 : }
     130              : 
     131           44 : fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
     132           44 :     sni.strip_suffix(common_name)?
     133           44 :         .strip_suffix('.')
     134           44 :         .map(str::to_owned)
     135           44 : }
     136              : 
     137              : #[cfg(test)]
     138              : mod tests {
     139              :     use super::*;
     140              :     use ClientCredsParseError::*;
     141              : 
     142            1 :     #[test]
     143            1 :     fn parse_bare_minimum() -> anyhow::Result<()> {
     144            1 :         // According to postgresql, only `user` should be required.
     145            1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     146              : 
     147            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     148            1 :         assert_eq!(creds.user, "john_doe");
     149            1 :         assert_eq!(creds.project, None);
     150              : 
     151            1 :         Ok(())
     152            1 :     }
     153              : 
     154            1 :     #[test]
     155            1 :     fn parse_excessive() -> anyhow::Result<()> {
     156            1 :         let options = StartupMessageParams::new([
     157            1 :             ("user", "john_doe"),
     158            1 :             ("database", "world"), // should be ignored
     159            1 :             ("foo", "bar"),        // should be ignored
     160            1 :         ]);
     161              : 
     162            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     163            1 :         assert_eq!(creds.user, "john_doe");
     164            1 :         assert_eq!(creds.project, None);
     165              : 
     166            1 :         Ok(())
     167            1 :     }
     168              : 
     169            1 :     #[test]
     170            1 :     fn parse_project_from_sni() -> anyhow::Result<()> {
     171            1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     172            1 : 
     173            1 :         let sni = Some("foo.localhost");
     174            1 :         let common_names = Some(["localhost".into()].into());
     175              : 
     176            1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     177            1 :         assert_eq!(creds.user, "john_doe");
     178            1 :         assert_eq!(creds.project.as_deref(), Some("foo"));
     179              : 
     180            1 :         Ok(())
     181            1 :     }
     182              : 
     183            1 :     #[test]
     184            1 :     fn parse_project_from_options() -> anyhow::Result<()> {
     185            1 :         let options = StartupMessageParams::new([
     186            1 :             ("user", "john_doe"),
     187            1 :             ("options", "-ckey=1 project=bar -c geqo=off"),
     188            1 :         ]);
     189              : 
     190            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     191            1 :         assert_eq!(creds.user, "john_doe");
     192            1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     193              : 
     194            1 :         Ok(())
     195            1 :     }
     196              : 
     197            1 :     #[test]
     198            1 :     fn parse_endpoint_from_options() -> anyhow::Result<()> {
     199            1 :         let options = StartupMessageParams::new([
     200            1 :             ("user", "john_doe"),
     201            1 :             ("options", "-ckey=1 endpoint=bar -c geqo=off"),
     202            1 :         ]);
     203              : 
     204            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     205            1 :         assert_eq!(creds.user, "john_doe");
     206            1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     207              : 
     208            1 :         Ok(())
     209            1 :     }
     210              : 
     211            1 :     #[test]
     212            1 :     fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
     213            1 :         let options = StartupMessageParams::new([
     214            1 :             ("user", "john_doe"),
     215            1 :             (
     216            1 :                 "options",
     217            1 :                 "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
     218            1 :             ),
     219            1 :         ]);
     220              : 
     221            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     222            1 :         assert_eq!(creds.user, "john_doe");
     223            1 :         assert!(creds.project.is_none());
     224              : 
     225            1 :         Ok(())
     226            1 :     }
     227              : 
     228            1 :     #[test]
     229            1 :     fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
     230            1 :         let options = StartupMessageParams::new([
     231            1 :             ("user", "john_doe"),
     232            1 :             ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
     233            1 :         ]);
     234              : 
     235            1 :         let creds = ClientCredentials::parse(&options, None, None)?;
     236            1 :         assert_eq!(creds.user, "john_doe");
     237            1 :         assert!(creds.project.is_none());
     238              : 
     239            1 :         Ok(())
     240            1 :     }
     241              : 
     242            1 :     #[test]
     243            1 :     fn parse_projects_identical() -> anyhow::Result<()> {
     244            1 :         let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
     245            1 : 
     246            1 :         let sni = Some("baz.localhost");
     247            1 :         let common_names = Some(["localhost".into()].into());
     248              : 
     249            1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     250            1 :         assert_eq!(creds.user, "john_doe");
     251            1 :         assert_eq!(creds.project.as_deref(), Some("baz"));
     252              : 
     253            1 :         Ok(())
     254            1 :     }
     255              : 
     256            1 :     #[test]
     257            1 :     fn parse_multi_common_names() -> anyhow::Result<()> {
     258            1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     259            1 : 
     260            1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     261            1 :         let sni = Some("p1.a.com");
     262            1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     263            1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     264              : 
     265            1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     266            1 :         let sni = Some("p1.b.com");
     267            1 :         let creds = ClientCredentials::parse(&options, sni, common_names)?;
     268            1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     269              : 
     270            1 :         Ok(())
     271            1 :     }
     272              : 
     273            1 :     #[test]
     274            1 :     fn parse_projects_different() {
     275            1 :         let options =
     276            1 :             StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
     277            1 : 
     278            1 :         let sni = Some("second.localhost");
     279            1 :         let common_names = Some(["localhost".into()].into());
     280            1 : 
     281            1 :         let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
     282            1 :         match err {
     283            1 :             InconsistentProjectNames { domain, option } => {
     284            1 :                 assert_eq!(option, "first");
     285            1 :                 assert_eq!(domain, "second");
     286              :             }
     287            0 :             _ => panic!("bad error: {err:?}"),
     288              :         }
     289            1 :     }
     290              : 
     291            1 :     #[test]
     292            1 :     fn parse_inconsistent_sni() {
     293            1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     294            1 : 
     295            1 :         let sni = Some("project.localhost");
     296            1 :         let common_names = Some(["example.com".into()].into());
     297            1 : 
     298            1 :         let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
     299            1 :         match err {
     300            1 :             UnknownCommonName { cn } => {
     301            1 :                 assert_eq!(cn, "localhost");
     302              :             }
     303            0 :             _ => panic!("bad error: {err:?}"),
     304              :         }
     305            1 :     }
     306              : }
        

Generated by: LCOV version 2.1-beta