LCOV - differential code coverage report
Current view: top level - proxy/src/auth - credentials.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 98.1 % 367 360 7 360
Current Date: 2024-01-09 02:06:09 Functions: 83.3 % 60 50 10 50
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! User credentials used in authentication.
       2                 : 
       3                 : use crate::{
       4                 :     auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
       5                 :     metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
       6                 : };
       7                 : use itertools::Itertools;
       8                 : use pq_proto::StartupMessageParams;
       9                 : use smol_str::SmolStr;
      10                 : use std::{collections::HashSet, net::IpAddr};
      11                 : use thiserror::Error;
      12                 : use tracing::{info, warn};
      13                 : 
      14 UBC           0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
      15                 : pub enum ClientCredsParseError {
      16                 :     #[error("Parameter '{0}' is missing in startup packet.")]
      17                 :     MissingKey(&'static str),
      18                 : 
      19                 :     #[error(
      20                 :         "Inconsistent project name inferred from \
      21                 :          SNI ('{}') and project option ('{}').",
      22                 :         .domain, .option,
      23                 :     )]
      24                 :     InconsistentProjectNames { domain: SmolStr, option: SmolStr },
      25                 : 
      26                 :     #[error(
      27                 :         "Common name inferred from SNI ('{}') is not known",
      28                 :         .cn,
      29                 :     )]
      30                 :     UnknownCommonName { cn: String },
      31                 : 
      32                 :     #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
      33                 :     MalformedProjectName(SmolStr),
      34                 : }
      35                 : 
      36                 : impl UserFacingError for ClientCredsParseError {}
      37                 : 
      38                 : /// Various client credentials which we use for authentication.
      39                 : /// Note that we don't store any kind of client key or password here.
      40               0 : #[derive(Debug, Clone, PartialEq, Eq)]
      41                 : pub struct ClientCredentials {
      42                 :     pub user: SmolStr,
      43                 :     // TODO: this is a severe misnomer! We should think of a new name ASAP.
      44                 :     pub project: Option<SmolStr>,
      45                 : 
      46                 :     pub cache_key: SmolStr,
      47                 : }
      48                 : 
      49                 : impl ClientCredentials {
      50                 :     #[inline]
      51 CBC          46 :     pub fn project(&self) -> Option<&str> {
      52              46 :         self.project.as_deref()
      53              46 :     }
      54                 : }
      55                 : 
      56                 : impl ClientCredentials {
      57             100 :     pub fn parse(
      58             100 :         ctx: &mut RequestMonitoring,
      59             100 :         params: &StartupMessageParams,
      60             100 :         sni: Option<&str>,
      61             100 :         common_names: Option<HashSet<String>>,
      62             100 :     ) -> Result<Self, ClientCredsParseError> {
      63             100 :         use ClientCredsParseError::*;
      64             100 : 
      65             100 :         // Some parameters are stored in the startup message.
      66             100 :         let get_param = |key| params.get(key).ok_or(MissingKey(key));
      67             100 :         let user: SmolStr = get_param("user")?.into();
      68             100 : 
      69             100 :         // record the values if we have them
      70             100 :         ctx.set_application(params.get("application_name").map(SmolStr::from));
      71             100 :         ctx.set_user(user.clone());
      72             100 :         ctx.set_endpoint_id(sni.map(SmolStr::from));
      73             100 : 
      74             100 :         // Project name might be passed via PG's command-line options.
      75             100 :         let project_option = params
      76             100 :             .options_raw()
      77             100 :             .and_then(|options| {
      78              94 :                 // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
      79              94 :                 // However, if both are present, we don't exactly know which one to use.
      80              94 :                 // Therefore we require that only one of them is present.
      81              94 :                 options
      82              94 :                     .filter_map(parse_endpoint_param)
      83              94 :                     .at_most_one()
      84              94 :                     .ok()?
      85             100 :             })
      86             100 :             .map(|name| name.into());
      87                 : 
      88             100 :         let project_from_domain = if let Some(sni_str) = sni {
      89              77 :             if let Some(cn) = common_names {
      90              77 :                 let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
      91                 : 
      92              77 :                 let project = common_name_from_sni
      93              77 :                     .and_then(|domain| {
      94              77 :                         if cn.contains(domain) {
      95              76 :                             subdomain_from_sni(sni_str, domain)
      96                 :                         } else {
      97               1 :                             None
      98                 :                         }
      99              77 :                     })
     100              77 :                     .ok_or_else(|| UnknownCommonName {
     101               1 :                         cn: common_name_from_sni.unwrap_or("").into(),
     102              77 :                     })?;
     103                 : 
     104              76 :                 Some(project)
     105                 :             } else {
     106 UBC           0 :                 None
     107                 :             }
     108                 :         } else {
     109 CBC          23 :             None
     110                 :         };
     111                 : 
     112              99 :         let project = match (project_option, project_from_domain) {
     113                 :             // Invariant: if we have both project name variants, they should match.
     114               2 :             (Some(option), Some(domain)) if option != domain => {
     115               1 :                 Some(Err(InconsistentProjectNames { domain, option }))
     116                 :             }
     117                 :             // Invariant: project name may not contain certain characters.
     118              98 :             (a, b) => a.or(b).map(|name| match project_name_valid(&name) {
     119 UBC           0 :                 false => Err(MalformedProjectName(name)),
     120 CBC          91 :                 true => Ok(name),
     121              98 :             }),
     122                 :         }
     123              99 :         .transpose()?;
     124                 : 
     125              98 :         info!(%user, project = project.as_deref(), "credentials");
     126              98 :         if sni.is_some() {
     127              75 :             info!("Connection with sni");
     128              75 :             NUM_CONNECTION_ACCEPTED_BY_SNI
     129              75 :                 .with_label_values(&["sni"])
     130              75 :                 .inc();
     131              23 :         } else if project.is_some() {
     132              16 :             NUM_CONNECTION_ACCEPTED_BY_SNI
     133              16 :                 .with_label_values(&["no_sni"])
     134              16 :                 .inc();
     135              16 :             info!("Connection without sni");
     136                 :         } else {
     137               7 :             NUM_CONNECTION_ACCEPTED_BY_SNI
     138               7 :                 .with_label_values(&["password_hack"])
     139               7 :                 .inc();
     140               7 :             info!("Connection with password hack");
     141                 :         }
     142                 : 
     143              98 :         let cache_key = format!(
     144              98 :             "{}{}",
     145              98 :             project.as_deref().unwrap_or(""),
     146              98 :             neon_options_str(params)
     147              98 :         )
     148              98 :         .into();
     149              98 : 
     150              98 :         Ok(Self {
     151              98 :             user,
     152              98 :             project,
     153              98 :             cache_key,
     154              98 :         })
     155             100 :     }
     156                 : }
     157                 : 
     158              86 : pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
     159              86 :     if ip_list.is_empty() {
     160              75 :         return true;
     161              11 :     }
     162              21 :     for ip in ip_list {
     163                 :         // We expect that all ip addresses from control plane are correct.
     164                 :         // However, if some of them are broken, we still can check the others.
     165              16 :         match parse_ip_pattern(ip) {
     166              15 :             Ok(pattern) => {
     167              15 :                 if check_ip(peer_addr, &pattern) {
     168               6 :                     return true;
     169               9 :                 }
     170                 :             }
     171               1 :             Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
     172                 :         }
     173                 :     }
     174               5 :     false
     175              86 : }
     176                 : 
     177               3 : #[derive(Debug, Clone, Eq, PartialEq)]
     178                 : enum IpPattern {
     179                 :     Subnet(ipnet::IpNet),
     180                 :     Range(IpAddr, IpAddr),
     181                 :     Single(IpAddr),
     182                 : }
     183                 : 
     184              24 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
     185              24 :     if pattern.contains('/') {
     186               2 :         let subnet: ipnet::IpNet = pattern.parse()?;
     187               1 :         return Ok(IpPattern::Subnet(subnet));
     188              22 :     }
     189              22 :     if let Some((start, end)) = pattern.split_once('-') {
     190               3 :         let start: IpAddr = start.parse()?;
     191               2 :         let end: IpAddr = end.parse()?;
     192               1 :         return Ok(IpPattern::Range(start, end));
     193              19 :     }
     194              19 :     let addr: IpAddr = pattern.parse()?;
     195              16 :     Ok(IpPattern::Single(addr))
     196              24 : }
     197                 : 
     198              25 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
     199              25 :     match pattern {
     200               3 :         IpPattern::Subnet(subnet) => subnet.contains(ip),
     201               5 :         IpPattern::Range(start, end) => start <= ip && ip <= end,
     202              17 :         IpPattern::Single(addr) => addr == ip,
     203                 :     }
     204              25 : }
     205                 : 
     206              91 : fn project_name_valid(name: &str) -> bool {
     207             628 :     name.chars().all(|c| c.is_alphanumeric() || c == '-')
     208              91 : }
     209                 : 
     210              76 : fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
     211              76 :     sni.strip_suffix(common_name)?
     212              76 :         .strip_suffix('.')
     213              76 :         .map(SmolStr::from)
     214              76 : }
     215                 : 
     216                 : #[cfg(test)]
     217                 : mod tests {
     218                 :     use super::*;
     219                 :     use ClientCredsParseError::*;
     220                 : 
     221               1 :     #[test]
     222               1 :     fn parse_bare_minimum() -> anyhow::Result<()> {
     223               1 :         // According to postgresql, only `user` should be required.
     224               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     225               1 :         let mut ctx = RequestMonitoring::test();
     226               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     227               1 :         assert_eq!(creds.user, "john_doe");
     228               1 :         assert_eq!(creds.project, None);
     229                 : 
     230               1 :         Ok(())
     231               1 :     }
     232                 : 
     233               1 :     #[test]
     234               1 :     fn parse_excessive() -> anyhow::Result<()> {
     235               1 :         let options = StartupMessageParams::new([
     236               1 :             ("user", "john_doe"),
     237               1 :             ("database", "world"), // should be ignored
     238               1 :             ("foo", "bar"),        // should be ignored
     239               1 :         ]);
     240               1 :         let mut ctx = RequestMonitoring::test();
     241               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     242               1 :         assert_eq!(creds.user, "john_doe");
     243               1 :         assert_eq!(creds.project, None);
     244                 : 
     245               1 :         Ok(())
     246               1 :     }
     247                 : 
     248               1 :     #[test]
     249               1 :     fn parse_project_from_sni() -> anyhow::Result<()> {
     250               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     251               1 : 
     252               1 :         let sni = Some("foo.localhost");
     253               1 :         let common_names = Some(["localhost".into()].into());
     254               1 : 
     255               1 :         let mut ctx = RequestMonitoring::test();
     256               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
     257               1 :         assert_eq!(creds.user, "john_doe");
     258               1 :         assert_eq!(creds.project.as_deref(), Some("foo"));
     259               1 :         assert_eq!(creds.cache_key, "foo");
     260                 : 
     261               1 :         Ok(())
     262               1 :     }
     263                 : 
     264               1 :     #[test]
     265               1 :     fn parse_project_from_options() -> anyhow::Result<()> {
     266               1 :         let options = StartupMessageParams::new([
     267               1 :             ("user", "john_doe"),
     268               1 :             ("options", "-ckey=1 project=bar -c geqo=off"),
     269               1 :         ]);
     270               1 : 
     271               1 :         let mut ctx = RequestMonitoring::test();
     272               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     273               1 :         assert_eq!(creds.user, "john_doe");
     274               1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     275                 : 
     276               1 :         Ok(())
     277               1 :     }
     278                 : 
     279               1 :     #[test]
     280               1 :     fn parse_endpoint_from_options() -> anyhow::Result<()> {
     281               1 :         let options = StartupMessageParams::new([
     282               1 :             ("user", "john_doe"),
     283               1 :             ("options", "-ckey=1 endpoint=bar -c geqo=off"),
     284               1 :         ]);
     285               1 : 
     286               1 :         let mut ctx = RequestMonitoring::test();
     287               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     288               1 :         assert_eq!(creds.user, "john_doe");
     289               1 :         assert_eq!(creds.project.as_deref(), Some("bar"));
     290                 : 
     291               1 :         Ok(())
     292               1 :     }
     293                 : 
     294               1 :     #[test]
     295               1 :     fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
     296               1 :         let options = StartupMessageParams::new([
     297               1 :             ("user", "john_doe"),
     298               1 :             (
     299               1 :                 "options",
     300               1 :                 "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
     301               1 :             ),
     302               1 :         ]);
     303               1 : 
     304               1 :         let mut ctx = RequestMonitoring::test();
     305               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     306               1 :         assert_eq!(creds.user, "john_doe");
     307               1 :         assert!(creds.project.is_none());
     308                 : 
     309               1 :         Ok(())
     310               1 :     }
     311                 : 
     312               1 :     #[test]
     313               1 :     fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
     314               1 :         let options = StartupMessageParams::new([
     315               1 :             ("user", "john_doe"),
     316               1 :             ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
     317               1 :         ]);
     318               1 : 
     319               1 :         let mut ctx = RequestMonitoring::test();
     320               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
     321               1 :         assert_eq!(creds.user, "john_doe");
     322               1 :         assert!(creds.project.is_none());
     323                 : 
     324               1 :         Ok(())
     325               1 :     }
     326                 : 
     327               1 :     #[test]
     328               1 :     fn parse_projects_identical() -> anyhow::Result<()> {
     329               1 :         let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
     330               1 : 
     331               1 :         let sni = Some("baz.localhost");
     332               1 :         let common_names = Some(["localhost".into()].into());
     333               1 : 
     334               1 :         let mut ctx = RequestMonitoring::test();
     335               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
     336               1 :         assert_eq!(creds.user, "john_doe");
     337               1 :         assert_eq!(creds.project.as_deref(), Some("baz"));
     338                 : 
     339               1 :         Ok(())
     340               1 :     }
     341                 : 
     342               1 :     #[test]
     343               1 :     fn parse_multi_common_names() -> anyhow::Result<()> {
     344               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     345               1 : 
     346               1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     347               1 :         let sni = Some("p1.a.com");
     348               1 :         let mut ctx = RequestMonitoring::test();
     349               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
     350               1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     351                 : 
     352               1 :         let common_names = Some(["a.com".into(), "b.com".into()].into());
     353               1 :         let sni = Some("p1.b.com");
     354               1 :         let mut ctx = RequestMonitoring::test();
     355               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
     356               1 :         assert_eq!(creds.project.as_deref(), Some("p1"));
     357                 : 
     358               1 :         Ok(())
     359               1 :     }
     360                 : 
     361               1 :     #[test]
     362               1 :     fn parse_projects_different() {
     363               1 :         let options =
     364               1 :             StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
     365               1 : 
     366               1 :         let sni = Some("second.localhost");
     367               1 :         let common_names = Some(["localhost".into()].into());
     368               1 : 
     369               1 :         let mut ctx = RequestMonitoring::test();
     370               1 :         let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
     371               1 :             .expect_err("should fail");
     372               1 :         match err {
     373               1 :             InconsistentProjectNames { domain, option } => {
     374               1 :                 assert_eq!(option, "first");
     375               1 :                 assert_eq!(domain, "second");
     376                 :             }
     377 UBC           0 :             _ => panic!("bad error: {err:?}"),
     378                 :         }
     379 CBC           1 :     }
     380                 : 
     381               1 :     #[test]
     382               1 :     fn parse_inconsistent_sni() {
     383               1 :         let options = StartupMessageParams::new([("user", "john_doe")]);
     384               1 : 
     385               1 :         let sni = Some("project.localhost");
     386               1 :         let common_names = Some(["example.com".into()].into());
     387               1 : 
     388               1 :         let mut ctx = RequestMonitoring::test();
     389               1 :         let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
     390               1 :             .expect_err("should fail");
     391               1 :         match err {
     392               1 :             UnknownCommonName { cn } => {
     393               1 :                 assert_eq!(cn, "localhost");
     394                 :             }
     395 UBC           0 :             _ => panic!("bad error: {err:?}"),
     396                 :         }
     397 CBC           1 :     }
     398                 : 
     399               1 :     #[test]
     400               1 :     fn parse_neon_options() -> anyhow::Result<()> {
     401               1 :         let options = StartupMessageParams::new([
     402               1 :             ("user", "john_doe"),
     403               1 :             ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
     404               1 :         ]);
     405               1 : 
     406               1 :         let sni = Some("project.localhost");
     407               1 :         let common_names = Some(["localhost".into()].into());
     408               1 :         let mut ctx = RequestMonitoring::test();
     409               1 :         let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
     410               1 :         assert_eq!(creds.project.as_deref(), Some("project"));
     411               1 :         assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");
     412                 : 
     413               1 :         Ok(())
     414               1 :     }
     415                 : 
     416               1 :     #[test]
     417               1 :     fn test_check_peer_addr_is_in_list() {
     418               1 :         let peer_addr = IpAddr::from([127, 0, 0, 1]);
     419               1 :         assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
     420               1 :         assert!(check_peer_addr_is_in_list(
     421               1 :             &peer_addr,
     422               1 :             &vec!["127.0.0.1".into()]
     423               1 :         ));
     424               1 :         assert!(!check_peer_addr_is_in_list(
     425               1 :             &peer_addr,
     426               1 :             &vec!["8.8.8.8".into()]
     427               1 :         ));
     428                 :         // If there is an incorrect address, it will be skipped.
     429               1 :         assert!(check_peer_addr_is_in_list(
     430               1 :             &peer_addr,
     431               1 :             &vec!["88.8.8".into(), "127.0.0.1".into()]
     432               1 :         ));
     433               1 :     }
     434               1 :     #[test]
     435               1 :     fn test_parse_ip_v4() -> anyhow::Result<()> {
     436               1 :         let peer_addr = IpAddr::from([127, 0, 0, 1]);
     437                 :         // Ok
     438               1 :         assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
     439               1 :         assert_eq!(
     440               1 :             parse_ip_pattern("127.0.0.1/31")?,
     441               1 :             IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
     442                 :         );
     443               1 :         assert_eq!(
     444               1 :             parse_ip_pattern("0.0.0.0-200.0.1.2")?,
     445               1 :             IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
     446                 :         );
     447                 : 
     448                 :         // Error
     449               1 :         assert!(parse_ip_pattern("300.0.1.2").is_err());
     450               1 :         assert!(parse_ip_pattern("30.1.2").is_err());
     451               1 :         assert!(parse_ip_pattern("127.0.0.1/33").is_err());
     452               1 :         assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
     453               1 :         assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
     454               1 :         Ok(())
     455               1 :     }
     456                 : 
     457               1 :     #[test]
     458               1 :     fn test_check_ipv4() -> anyhow::Result<()> {
     459               1 :         let peer_addr = IpAddr::from([127, 0, 0, 1]);
     460               1 :         let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
     461               1 :         let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
     462                 :         // Success
     463               1 :         assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
     464               1 :         assert!(check_ip(
     465               1 :             &peer_addr,
     466               1 :             &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
     467                 :         ));
     468               1 :         assert!(check_ip(
     469               1 :             &peer_addr,
     470               1 :             &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
     471                 :         ));
     472               1 :         assert!(check_ip(
     473               1 :             &peer_addr,
     474               1 :             &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
     475               1 :         ));
     476               1 :         assert!(check_ip(
     477               1 :             &peer_addr,
     478               1 :             &IpPattern::Range(peer_addr, peer_addr)
     479               1 :         ));
     480                 : 
     481                 :         // Not success
     482               1 :         assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
     483 UBC           0 :         assert!(!check_ip(
     484 CBC           1 :             &peer_addr,
     485               1 :             &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
     486                 :         ));
     487               1 :         assert!(!check_ip(
     488               1 :             &peer_addr,
     489               1 :             &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
     490               1 :         ));
     491               1 :         assert!(!check_ip(
     492               1 :             &peer_addr,
     493               1 :             &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
     494               1 :         ));
     495                 :         // There is no check that for range start <= end. But it's fine as long as for all this cases the result is false.
     496               1 :         assert!(!check_ip(
     497               1 :             &peer_addr,
     498               1 :             &IpPattern::Range(peer_addr, peer_addr_prev)
     499               1 :         ));
     500               1 :         Ok(())
     501               1 :     }
     502                 : }
        

Generated by: LCOV version 2.1-beta