LCOV - code coverage report
Current view: top level - libs/postgres_connection/src - lib.rs (source / functions) Coverage Total Hit
Test: f081ec316c96fa98335efd15ef501745aa4f015d.info Lines: 71.4 % 182 130
Test Date: 2024-06-25 15:11:17 Functions: 67.6 % 37 25

            Line data    Source code
       1              : #![deny(unsafe_code)]
       2              : #![deny(clippy::undocumented_unsafe_blocks)]
       3              : use anyhow::{bail, Context};
       4              : use itertools::Itertools;
       5              : use std::borrow::Cow;
       6              : use std::fmt;
       7              : use url::Host;
       8              : 
       9              : /// Parses a string of format either `host:port` or `host` into a corresponding pair.
      10              : /// The `host` part should be a correct `url::Host`, while `port` (if present) should be
      11              : /// a valid decimal u16 of digits only.
      12           36 : pub fn parse_host_port<S: AsRef<str>>(host_port: S) -> Result<(Host, Option<u16>), anyhow::Error> {
      13           36 :     let (host, port) = match host_port.as_ref().rsplit_once(':') {
      14            6 :         Some((host, port)) => (
      15            6 :             host,
      16            6 :             // +80 is a valid u16, but not a valid port
      17           14 :             if port.chars().all(|c| c.is_ascii_digit()) {
      18            4 :                 Some(port.parse::<u16>().context("Unable to parse port")?)
      19              :             } else {
      20            2 :                 bail!("Port contains a non-ascii-digit")
      21              :             },
      22              :         ),
      23           30 :         None => (host_port.as_ref(), None), // No colons, no port specified
      24              :     };
      25           34 :     let host = Host::parse(host).context("Unable to parse host")?;
      26           32 :     Ok((host, port))
      27           36 : }
      28              : 
      29              : #[cfg(test)]
      30              : mod tests_parse_host_port {
      31              :     use crate::parse_host_port;
      32              :     use url::Host;
      33              : 
      34              :     #[test]
      35            2 :     fn test_normal() {
      36            2 :         let (host, port) = parse_host_port("hello:123").unwrap();
      37            2 :         assert_eq!(host, Host::Domain("hello".to_owned()));
      38            2 :         assert_eq!(port, Some(123));
      39            2 :     }
      40              : 
      41              :     #[test]
      42            2 :     fn test_no_port() {
      43            2 :         let (host, port) = parse_host_port("hello").unwrap();
      44            2 :         assert_eq!(host, Host::Domain("hello".to_owned()));
      45            2 :         assert_eq!(port, None);
      46            2 :     }
      47              : 
      48              :     #[test]
      49            2 :     fn test_ipv6() {
      50            2 :         let (host, port) = parse_host_port("[::1]:123").unwrap();
      51            2 :         assert_eq!(host, Host::<String>::Ipv6(std::net::Ipv6Addr::LOCALHOST));
      52            2 :         assert_eq!(port, Some(123));
      53            2 :     }
      54              : 
      55              :     #[test]
      56            2 :     fn test_invalid_host() {
      57            2 :         assert!(parse_host_port("hello world").is_err());
      58            2 :     }
      59              : 
      60              :     #[test]
      61            2 :     fn test_invalid_port() {
      62            2 :         assert!(parse_host_port("hello:+80").is_err());
      63            2 :     }
      64              : }
      65              : 
      66              : #[derive(Clone)]
      67              : pub struct PgConnectionConfig {
      68              :     host: Host,
      69              :     port: u16,
      70              :     password: Option<String>,
      71              :     options: Vec<String>,
      72              : }
      73              : 
      74              : /// A simplified PostgreSQL connection configuration. Supports only a subset of possible
      75              : /// settings for simplicity. A password getter or `to_connection_string` methods are not
      76              : /// added by design to avoid accidentally leaking password through logging, command line
      77              : /// arguments to a child process, or likewise.
      78              : impl PgConnectionConfig {
      79           32 :     pub fn new_host_port(host: Host, port: u16) -> Self {
      80           32 :         PgConnectionConfig {
      81           32 :             host,
      82           32 :             port,
      83           32 :             password: None,
      84           32 :             options: vec![],
      85           32 :         }
      86           32 :     }
      87              : 
      88           24 :     pub fn host(&self) -> &Host {
      89           24 :         &self.host
      90           24 :     }
      91              : 
      92           12 :     pub fn port(&self) -> u16 {
      93           12 :         self.port
      94           12 :     }
      95              : 
      96            0 :     pub fn set_host(mut self, h: Host) -> Self {
      97            0 :         self.host = h;
      98            0 :         self
      99            0 :     }
     100              : 
     101            0 :     pub fn set_port(mut self, p: u16) -> Self {
     102            0 :         self.port = p;
     103            0 :         self
     104            0 :     }
     105              : 
     106           28 :     pub fn set_password(mut self, s: Option<String>) -> Self {
     107           28 :         self.password = s;
     108           28 :         self
     109           28 :     }
     110              : 
     111           30 :     pub fn extend_options<I: IntoIterator<Item = S>, S: Into<String>>(mut self, i: I) -> Self {
     112           82 :         self.options.extend(i.into_iter().map(|s| s.into()));
     113           30 :         self
     114           30 :     }
     115              : 
     116              :     /// Return a `<host>:<port>` string.
     117            6 :     pub fn raw_address(&self) -> String {
     118            6 :         format!("{}:{}", self.host(), self.port())
     119            6 :     }
     120              : 
     121              :     /// Build a client library-specific connection configuration.
     122              :     /// Used for testing and when we need to add some obscure configuration
     123              :     /// elements at the last moment.
     124            0 :     pub fn to_tokio_postgres_config(&self) -> tokio_postgres::Config {
     125            0 :         // Use `tokio_postgres::Config` instead of `postgres::Config` because
     126            0 :         // the former supports more options to fiddle with later.
     127            0 :         let mut config = tokio_postgres::Config::new();
     128            0 :         config.host(&self.host().to_string()).port(self.port);
     129            0 :         if let Some(password) = &self.password {
     130            0 :             config.password(password);
     131            0 :         }
     132            0 :         if !self.options.is_empty() {
     133            0 :             // These options are command-line options and should be escaped before being passed
     134            0 :             // as an 'options' connection string parameter, see
     135            0 :             // https://www.postgresql.org/docs/15/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
     136            0 :             //
     137            0 :             // They will be space-separated, so each space inside an option should be escaped,
     138            0 :             // and all backslashes should be escaped before that. Although we don't expect options
     139            0 :             // with spaces at the moment, they're supported by PostgreSQL. Hence we support them
     140            0 :             // in this typesafe interface.
     141            0 :             //
     142            0 :             // We use `Cow` to avoid allocations in the best case (no escaping). A fully imperative
     143            0 :             // solution would require 1-2 allocations in the worst case as well, but it's harder to
     144            0 :             // implement and this function is hardly a bottleneck. The function is only called around
     145            0 :             // establishing a new connection.
     146            0 :             #[allow(unstable_name_collisions)]
     147            0 :             config.options(&encode_options(&self.options));
     148            0 :         }
     149            0 :         config
     150            0 :     }
     151              : 
     152              :     /// Connect using postgres protocol with TLS disabled.
     153            0 :     pub async fn connect_no_tls(
     154            0 :         &self,
     155            0 :     ) -> Result<
     156            0 :         (
     157            0 :             tokio_postgres::Client,
     158            0 :             tokio_postgres::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>,
     159            0 :         ),
     160            0 :         postgres::Error,
     161            0 :     > {
     162            0 :         self.to_tokio_postgres_config()
     163            0 :             .connect(postgres::NoTls)
     164            0 :             .await
     165            0 :     }
     166              : }
     167              : 
     168              : #[allow(unstable_name_collisions)]
     169            2 : fn encode_options(options: &[String]) -> String {
     170            2 :     options
     171            2 :         .iter()
     172            8 :         .map(|s| {
     173            8 :             if s.contains(['\\', ' ']) {
     174            4 :                 Cow::Owned(s.replace('\\', "\\\\").replace(' ', "\\ "))
     175              :             } else {
     176            4 :                 Cow::Borrowed(s.as_str())
     177              :             }
     178            8 :         })
     179            2 :         .intersperse(Cow::Borrowed(" ")) // TODO: use impl from std once it's stabilized
     180            2 :         .collect::<String>()
     181            2 : }
     182              : 
     183              : impl fmt::Display for PgConnectionConfig {
     184            0 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
     185            0 :         // The password is intentionally hidden and not part of this display string.
     186            0 :         write!(f, "postgresql://{}:{}", self.host, self.port)
     187            0 :     }
     188              : }
     189              : 
     190              : impl fmt::Debug for PgConnectionConfig {
     191            6 :     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
     192            6 :         // We want `password: Some(REDACTED-STRING)`, not `password: Some("REDACTED-STRING")`
     193            6 :         // so even if the password is `REDACTED-STRING` (quite unlikely) there is no confusion.
     194            6 :         // Hence `format_args!()`, it returns a "safe" string which is not escaped by `Debug`.
     195            6 :         f.debug_struct("PgConnectionConfig")
     196            6 :             .field("host", &self.host)
     197            6 :             .field("port", &self.port)
     198            6 :             .field(
     199            6 :                 "password",
     200            6 :                 &self
     201            6 :                     .password
     202            6 :                     .as_ref()
     203            6 :                     .map(|_| format_args!("REDACTED-STRING")),
     204            6 :             )
     205            6 :             .finish()
     206            6 :     }
     207              : }
     208              : 
     209              : #[cfg(test)]
     210              : mod tests_pg_connection_config {
     211              :     use crate::{encode_options, PgConnectionConfig};
     212              :     use once_cell::sync::Lazy;
     213              :     use url::Host;
     214              : 
     215            4 :     static STUB_HOST: Lazy<Host> = Lazy::new(|| Host::Domain("stub.host.example".to_owned()));
     216              : 
     217              :     #[test]
     218            2 :     fn test_no_password() {
     219            2 :         let cfg = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123);
     220            2 :         assert_eq!(cfg.host(), &*STUB_HOST);
     221            2 :         assert_eq!(cfg.port(), 123);
     222            2 :         assert_eq!(cfg.raw_address(), "stub.host.example:123");
     223            2 :         assert_eq!(
     224            2 :             format!("{:?}", cfg),
     225            2 :             "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: None }"
     226            2 :         );
     227            2 :     }
     228              : 
     229              :     #[test]
     230            2 :     fn test_ipv6() {
     231            2 :         // May be a special case because hostname contains a colon.
     232            2 :         let cfg = PgConnectionConfig::new_host_port(Host::parse("[::1]").unwrap(), 123);
     233            2 :         assert_eq!(
     234            2 :             cfg.host(),
     235            2 :             &Host::<String>::Ipv6(std::net::Ipv6Addr::LOCALHOST)
     236            2 :         );
     237            2 :         assert_eq!(cfg.port(), 123);
     238            2 :         assert_eq!(cfg.raw_address(), "[::1]:123");
     239            2 :         assert_eq!(
     240            2 :             format!("{:?}", cfg),
     241            2 :             "PgConnectionConfig { host: Ipv6(::1), port: 123, password: None }"
     242            2 :         );
     243            2 :     }
     244              : 
     245              :     #[test]
     246            2 :     fn test_with_password() {
     247            2 :         let cfg = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123)
     248            2 :             .set_password(Some("password".to_owned()));
     249            2 :         assert_eq!(cfg.host(), &*STUB_HOST);
     250            2 :         assert_eq!(cfg.port(), 123);
     251            2 :         assert_eq!(cfg.raw_address(), "stub.host.example:123");
     252            2 :         assert_eq!(
     253            2 :             format!("{:?}", cfg),
     254            2 :             "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: Some(REDACTED-STRING) }"
     255            2 :         );
     256            2 :     }
     257              : 
     258              :     #[test]
     259            2 :     fn test_with_options() {
     260            2 :         let options = encode_options(&[
     261            2 :             "hello".to_owned(),
     262            2 :             "world".to_owned(),
     263            2 :             "with space".to_owned(),
     264            2 :             "and \\ backslashes".to_owned(),
     265            2 :         ]);
     266            2 :         assert_eq!(options, "hello world with\\ space and\\ \\\\\\ backslashes");
     267            2 :     }
     268              : }
        

Generated by: LCOV version 2.1-beta