LCOV - code coverage report
Current view: top level - proxy/src/auth - credentials.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 96.0 % 351 337
Test Date: 2025-07-16 12:29:03 Functions: 90.5 % 42 38

            Line data    Source code
       1              : //! User credentials used in authentication.
       2              : 
       3              : use std::collections::HashSet;
       4              : use std::net::IpAddr;
       5              : use std::str::FromStr;
       6              : 
       7              : use itertools::Itertools;
       8              : use thiserror::Error;
       9              : use tracing::{debug, warn};
      10              : 
      11              : use crate::auth::password_hack::parse_endpoint_param;
      12              : use crate::context::RequestContext;
      13              : use crate::error::{ReportableError, UserFacingError};
      14              : use crate::metrics::{Metrics, SniGroup, SniKind};
      15              : use crate::pqproto::StartupMessageParams;
      16              : use crate::proxy::NeonOptions;
      17              : use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
      18              : use crate::types::{EndpointId, RoleName};
      19              : 
/// Errors produced while extracting the user and endpoint from a client's
/// startup packet and TLS SNI (see `ComputeUserInfoMaybeEndpoint::parse`).
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
    /// A required startup parameter (e.g. `user`) was not sent by the client.
    #[error("Parameter '{0}' is missing in startup packet.")]
    MissingKey(&'static str),

    /// The endpoint derived from the SNI hostname and the one given via the
    /// `endpoint=`/`project=` startup option are both present but differ.
    #[error(
        "Inconsistent project name inferred from \
         SNI ('{}') and project option ('{}').",
        .domain, .option,
    )]
    InconsistentProjectNames {
        domain: EndpointId,
        option: EndpointId,
    },

    /// The resolved endpoint name failed the character-set validation.
    #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
    MalformedProjectName(EndpointId),
}

// The messages above are safe to show to the client; the empty impl relies on
// the trait's default methods.
impl UserFacingError for ComputeUserInfoParseError {}

impl ReportableError for ComputeUserInfoParseError {
    /// Every parse failure stems from what the client sent, so it is
    /// classified as a user error.
    fn get_error_kind(&self) -> crate::error::ErrorKind {
        crate::error::ErrorKind::User
    }
}
      46              : 
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ComputeUserInfoMaybeEndpoint {
    /// Role name taken from the mandatory `user` startup parameter.
    pub(crate) user: RoleName,
    /// Endpoint resolved from SNI and/or the `endpoint=`/`project=` option;
    /// `None` when neither source yielded one.
    pub(crate) endpoint_id: Option<EndpointId>,
    /// Neon-specific options parsed from the startup parameters.
    pub(crate) options: NeonOptions,
}

impl ComputeUserInfoMaybeEndpoint {
    /// Borrows the resolved endpoint name as a `&str`, if there is one.
    #[inline]
    pub(crate) fn endpoint(&self) -> Option<&str> {
        self.endpoint_id.as_deref()
    }
}
      62              : 
      63           28 : pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
      64           28 :     let (subdomain, common_name) = sni.split_once('.')?;
      65           28 :     if !common_names.contains(common_name) {
      66            2 :         return None;
      67           26 :     }
      68           26 :     if subdomain == SERVERLESS_DRIVER_SNI || subdomain == AUTH_BROKER_SNI {
      69            0 :         return None;
      70           26 :     }
      71           26 :     Some(EndpointId::from(subdomain))
      72           28 : }
      73              : 
impl ComputeUserInfoMaybeEndpoint {
    /// Builds user info from a client's startup parameters and optional TLS SNI.
    ///
    /// * `ctx` — request context; receives the endpoint id when one is resolved,
    ///   and supplies the protocol label for the connection metric.
    /// * `params` — startup message parameters; `user` is required, `options`
    ///   may carry an `endpoint=`/`project=` entry and Neon options.
    /// * `sni` — TLS SNI hostname, if the client sent one.
    /// * `common_names` — domains under which `<endpoint>.<domain>` SNIs are
    ///   recognised; when `None`, the SNI never yields an endpoint.
    ///
    /// The endpoint is taken from the startup option and/or the SNI. If both are
    /// present they must be equal; otherwise whichever one exists is used.
    ///
    /// # Errors
    /// * [`ComputeUserInfoParseError::MissingKey`] if `user` is absent.
    /// * [`ComputeUserInfoParseError::InconsistentProjectNames`] if the option
    ///   and SNI name different endpoints.
    /// * [`ComputeUserInfoParseError::MalformedProjectName`] if the chosen name
    ///   fails `project_name_valid`.
    ///
    /// On error, neither the context nor the metrics are touched.
    pub(crate) fn parse(
        ctx: &RequestContext,
        params: &StartupMessageParams,
        sni: Option<&str>,
        common_names: Option<&HashSet<String>>,
    ) -> Result<Self, ComputeUserInfoParseError> {
        // Some parameters are stored in the startup message.
        let get_param = |key| {
            params
                .get(key)
                .ok_or(ComputeUserInfoParseError::MissingKey(key))
        };
        let user: RoleName = get_param("user")?.into();

        // Project name might be passed via PG's command-line options.
        let endpoint_option = params
            .options_raw()
            .and_then(|options| {
                // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
                // However, if both are present, we don't exactly know which one to use.
                // Therefore we require that only one of them is present.
                // `at_most_one` errors on 2+ matches; `.ok()?` turns that error into
                // `None`, so ambiguous options are silently ignored rather than rejected.
                options
                    .filter_map(parse_endpoint_param)
                    .at_most_one()
                    .ok()?
            })
            .map(|name| name.into());

        // Only recognised `<endpoint>.<common-name>` hostnames yield an endpoint.
        let endpoint_from_domain =
            sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));

        let endpoint = match (endpoint_option, endpoint_from_domain) {
            // Invariant: if we have both project name variants, they should match.
            (Some(option), Some(domain)) if option != domain => {
                Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
                    domain,
                    option,
                }))
            }
            // Invariant: project name may not contain certain characters.
            // When both are present (and equal), the option value is kept.
            (a, b) => a.or(b).map(|name| {
                if project_name_valid(name.as_ref()) {
                    Ok(name)
                } else {
                    Err(ComputeUserInfoParseError::MalformedProjectName(name))
                }
            }),
        }
        // Option<Result<_>> -> Result<Option<_>> so `?` can propagate the error.
        .transpose()?;

        if let Some(ep) = &endpoint {
            ctx.set_endpoint_id(ep.clone());
        }

        let metrics = Metrics::get();
        debug!(%user, "credentials");

        // Classification: any SNI counts as `Sni` (even one that did not resolve to an
        // endpoint); otherwise an endpoint from options is `NoSni`; with neither, the
        // connection is counted as using the password hack.
        let protocol = ctx.protocol();
        let kind = if sni.is_some() {
            debug!("Connection with sni");
            SniKind::Sni
        } else if endpoint.is_some() {
            debug!("Connection without sni");
            SniKind::NoSni
        } else {
            debug!("Connection with password hack");
            SniKind::PasswordHack
        };

        metrics
            .proxy
            .accepted_connections_by_sni
            .inc(SniGroup { protocol, kind });

        let options = NeonOptions::parse_params(params);

        Ok(Self {
            user,
            endpoint_id: endpoint,
            options,
        })
    }
}
     158              : 
     159           10 : pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
     160           10 :     ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
     161           10 : }
     162              : 
/// One entry of an IP allowlist.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum IpPattern {
    /// CIDR notation, e.g. `10.0.0.0/8` (any pattern containing `/`).
    Subnet(ipnet::IpNet),
    /// `start-end` address range, compared inclusively on both ends.
    Range(IpAddr, IpAddr),
    /// A single exact address.
    Single(IpAddr),
    /// Placeholder for an entry that failed to parse during deserialization;
    /// it never matches any address.
    None,
}
     170              : 
     171              : impl<'de> serde::de::Deserialize<'de> for IpPattern {
     172           10 :     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     173           10 :     where
     174           10 :         D: serde::Deserializer<'de>,
     175              :     {
     176              :         struct StrVisitor;
     177              :         impl serde::de::Visitor<'_> for StrVisitor {
     178              :             type Value = IpPattern;
     179              : 
     180            0 :             fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     181            0 :                 write!(
     182            0 :                     formatter,
     183            0 :                     "comma separated list with ip address, ip address range, or ip address subnet mask"
     184              :                 )
     185            0 :             }
     186              : 
     187           10 :             fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     188           10 :             where
     189           10 :                 E: serde::de::Error,
     190              :             {
     191           10 :                 Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
     192            1 :                     warn!("Cannot parse ip pattern {v}: {e}");
     193            1 :                     IpPattern::None
     194            1 :                 }))
     195           10 :             }
     196              :         }
     197           10 :         deserializer.deserialize_str(StrVisitor)
     198           10 :     }
     199              : }
     200              : 
     201              : impl FromStr for IpPattern {
     202              :     type Err = anyhow::Error;
     203              : 
     204            2 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     205            2 :         parse_ip_pattern(s)
     206            2 :     }
     207              : }
     208              : 
     209           20 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
     210           20 :     if pattern.contains('/') {
     211            2 :         let subnet: ipnet::IpNet = pattern.parse()?;
     212            1 :         return Ok(IpPattern::Subnet(subnet));
     213           18 :     }
     214           18 :     if let Some((start, end)) = pattern.split_once('-') {
     215            3 :         let start: IpAddr = start.parse()?;
     216            2 :         let end: IpAddr = end.parse()?;
     217            1 :         return Ok(IpPattern::Range(start, end));
     218           15 :     }
     219           15 :     let addr: IpAddr = pattern.parse()?;
     220           12 :     Ok(IpPattern::Single(addr))
     221           20 : }
     222              : 
     223           16 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
     224           16 :     match pattern {
     225            3 :         IpPattern::Subnet(subnet) => subnet.contains(ip),
     226            5 :         IpPattern::Range(start, end) => start <= ip && ip <= end,
     227            7 :         IpPattern::Single(addr) => addr == ip,
     228            1 :         IpPattern::None => false,
     229              :     }
     230           16 : }
     231              : 
     232            8 : fn project_name_valid(name: &str) -> bool {
     233           39 :     name.chars().all(|c| c.is_alphanumeric() || c == '-')
     234            8 : }
     235              : 
#[cfg(test)]
mod tests {
    use ComputeUserInfoParseError::*;
    use serde_json::json;

    use super::*;

    // --- ComputeUserInfoMaybeEndpoint::parse ---

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to PostgreSQL, only `user` is required.
        let options = StartupMessageParams::new([("user", "john_doe")]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");

        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    // Multiple endpoint options are ambiguous: parsing succeeds but no endpoint is chosen.
    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    // `endpoint=` and `project=` together are likewise ambiguous and ignored.
    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let sni = Some("baz.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));

        Ok(())
    }

    // Any of several configured common names may carry the endpoint label.
    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.a.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.b.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let options =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let sni = Some("second.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    // An SNI outside the configured common names yields no endpoint but is not an error.
    #[test]
    fn parse_unknown_sni() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert!(info.endpoint_id.is_none());
    }

    #[test]
    fn parse_unknown_sni_with_options() {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "endpoint=foo-bar-baz-1234"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["localhost".into()].into());
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );

        Ok(())
    }

    // --- IP allowlist ---

    #[test]
    fn test_check_peer_addr_is_in_list() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["8.8.8.8"])));
        // If there is an incorrect address, it will be skipped.
        assert!(check(json!(["88.8.8", "127.0.0.1"])));
    }
    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error
        assert!(parse_ip_pattern("300.0.1.2").is_err());
        assert!(parse_ip_pattern("30.1.2").is_err());
        assert!(parse_ip_pattern("127.0.0.1/33").is_err());
        assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
        assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr)
        ));

        // Not success
        assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
        ));
        // Ranges are not checked for start <= end; that is acceptable as long as an
        // inverted range never matches, which is what this case verifies.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr_prev)
        ));
        Ok(())
    }

    #[test]
    fn test_connection_blocker() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["255.255.255.255"])));
    }
}
        

Generated by: LCOV version 2.1-beta