LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - config.rs (source / functions) Coverage Total Hit
Test: 4f58e98c51285c7fa348e0b410c88a10caf68ad2.info Lines: 54.4 % 114 62
Test Date: 2025-01-07 20:58:07 Functions: 43.2 % 37 16

            Line data    Source code
       1              : //! Connection configuration.
       2              : 
       3              : use crate::connect::connect;
       4              : use crate::connect_raw::connect_raw;
       5              : use crate::connect_raw::RawConnection;
       6              : use crate::tls::MakeTlsConnect;
       7              : use crate::tls::TlsConnect;
       8              : use crate::{Client, Connection, Error};
       9              : use postgres_protocol2::message::frontend::StartupMessageParams;
      10              : use std::fmt;
      11              : use std::str;
      12              : use std::time::Duration;
      13              : use tokio::io::{AsyncRead, AsyncWrite};
      14              : 
      15              : pub use postgres_protocol2::authentication::sasl::ScramKeys;
      16              : use tokio::net::TcpStream;
      17              : 
      18              : /// TLS configuration.
      19              : #[derive(Debug, Copy, Clone, PartialEq, Eq)]
      20              : #[non_exhaustive]
      21              : pub enum SslMode {
      22              :     /// Do not use TLS.
      23              :     Disable,
      24              :     /// Attempt to connect with TLS but allow sessions without.
      25              :     Prefer,
      26              :     /// Require the use of TLS.
      27              :     Require,
      28              : }
      29              : 
      30              : /// Channel binding configuration.
      31              : #[derive(Debug, Copy, Clone, PartialEq, Eq)]
      32              : #[non_exhaustive]
      33              : pub enum ChannelBinding {
      34              :     /// Do not use channel binding.
      35              :     Disable,
      36              :     /// Attempt to use channel binding but allow sessions without.
      37              :     Prefer,
      38              :     /// Require the use of channel binding.
      39              :     Require,
      40              : }
      41              : 
      42              : /// Replication mode configuration.
      43              : #[derive(Debug, Copy, Clone, PartialEq, Eq)]
      44              : #[non_exhaustive]
      45              : pub enum ReplicationMode {
      46              :     /// Physical replication.
      47              :     Physical,
      48              :     /// Logical replication.
      49              :     Logical,
      50              : }
      51              : 
      52              : /// A host specification.
      53              : #[derive(Debug, Clone, PartialEq, Eq)]
      54              : pub enum Host {
      55              :     /// A TCP hostname.
      56              :     Tcp(String),
      57              : }
      58              : 
      59              : /// Precomputed keys which may override password during auth.
      60              : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
      61              : pub enum AuthKeys {
      62              :     /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
      63              :     ScramSha256(ScramKeys<32>),
      64              : }
      65              : 
      66              : /// Connection configuration.
      67              : #[derive(Clone, PartialEq, Eq)]
      68              : pub struct Config {
      69              :     pub(crate) host: Host,
      70              :     pub(crate) port: u16,
      71              : 
      72              :     pub(crate) password: Option<Vec<u8>>,
      73              :     pub(crate) auth_keys: Option<Box<AuthKeys>>,
      74              :     pub(crate) ssl_mode: SslMode,
      75              :     pub(crate) connect_timeout: Option<Duration>,
      76              :     pub(crate) channel_binding: ChannelBinding,
      77              :     pub(crate) server_params: StartupMessageParams,
      78              : 
      79              :     database: bool,
      80              :     username: bool,
      81              : }
      82              : 
      83              : impl Config {
      84              :     /// Creates a new configuration.
      85           25 :     pub fn new(host: String, port: u16) -> Config {
      86           25 :         Config {
      87           25 :             host: Host::Tcp(host),
      88           25 :             port,
      89           25 :             password: None,
      90           25 :             auth_keys: None,
      91           25 :             ssl_mode: SslMode::Prefer,
      92           25 :             connect_timeout: None,
      93           25 :             channel_binding: ChannelBinding::Prefer,
      94           25 :             server_params: StartupMessageParams::default(),
      95           25 : 
      96           25 :             database: false,
      97           25 :             username: false,
      98           25 :         }
      99           25 :     }
     100              : 
     101              :     /// Sets the user to authenticate with.
     102              :     ///
     103              :     /// Required.
     104           15 :     pub fn user(&mut self, user: &str) -> &mut Config {
     105           15 :         self.set_param("user", user)
     106           15 :     }
     107              : 
     108              :     /// Returns `true` if a user has been configured, either with the `user`
     109              :     /// method or with `set_param("user", ..)`.
     110            0 :     pub fn user_is_set(&self) -> bool {
     111            0 :         self.username
     112            0 :     }
     113              : 
     114              :     /// Sets the password to authenticate with.
     115           22 :     pub fn password<T>(&mut self, password: T) -> &mut Config
     116           22 :     where
     117           22 :         T: AsRef<[u8]>,
     118           22 :     {
     119           22 :         self.password = Some(password.as_ref().to_vec());
     120           22 :         self
     121           22 :     }
     122              : 
     123              :     /// Gets the password to authenticate with, if one has been configured with
     124              :     /// the `password` method.
     125           15 :     pub fn get_password(&self) -> Option<&[u8]> {
     126           15 :         self.password.as_deref()
     127           15 :     }
     128              : 
     129              :     /// Sets precomputed protocol-specific keys to authenticate with.
     130              :     /// When set, this option will override `password`.
     131              :     /// See [`AuthKeys`] for more information.
     132            0 :     pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
     133            0 :         self.auth_keys = Some(Box::new(keys));
     134            0 :         self
     135            0 :     }
     136              : 
     137              :     /// Gets the precomputed protocol-specific keys to authenticate with,
     138              :     /// if they have been configured with the `auth_keys` method.
     139           15 :     pub fn get_auth_keys(&self) -> Option<AuthKeys> {
     140           15 :         self.auth_keys.as_deref().copied()
     141           15 :     }
     142              : 
     143              :     /// Sets the name of the database to connect to.
     144              :     ///
     145              :     /// Defaults to the user.
     146           15 :     pub fn dbname(&mut self, dbname: &str) -> &mut Config {
     147           15 :         self.set_param("database", dbname)
     148           15 :     }
     149              : 
     150              :     /// Returns `true` if a database name has been configured, either with the
     151              :     /// `dbname` method or with `set_param("database", ..)`.
     152            0 :     pub fn db_is_set(&self) -> bool {
     153            0 :         self.database
     154            0 :     }
     155              : 
     156           31 :     pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
     157           31 :         if name == "database" {
     158           15 :             self.database = true;
     159           16 :         } else if name == "user" {
     160           15 :             self.username = true;
     161           15 :         }
     162              : 
     163           31 :         self.server_params.insert(name, value);
     164           31 :         self
     165           31 :     }
     166              : 
     167              :     /// Sets the SSL configuration.
     168              :     ///
     169              :     /// Defaults to `prefer`.
     170           15 :     pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
     171           15 :         self.ssl_mode = ssl_mode;
     172           15 :         self
     173           15 :     }
     174              : 
     175              :     /// Gets the SSL configuration.
     176            0 :     pub fn get_ssl_mode(&self) -> SslMode {
     177            0 :         self.ssl_mode
     178            0 :     }
     179              : 
     180              :     /// Gets the host of this configuration, as supplied to `new`.
     181            0 :     pub fn get_host(&self) -> &Host {
     182            0 :         &self.host
     183            0 :     }
     184              : 
     185              :     /// Gets the port of this configuration, as supplied to `new`.
     186            0 :     pub fn get_port(&self) -> u16 {
     187            0 :         self.port
     188            0 :     }
     189              : 
     190              :     /// Sets the timeout applied to socket-level connection attempts.
     191              :     ///
     192              :     /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
     193              :     /// host separately. Defaults to no limit.
     194            0 :     pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
     195            0 :         self.connect_timeout = Some(connect_timeout);
     196            0 :         self
     197            0 :     }
     198              : 
     199              :     /// Gets the connection timeout, if one has been set with the
     200              :     /// `connect_timeout` method.
     201            0 :     pub fn get_connect_timeout(&self) -> Option<&Duration> {
     202            0 :         self.connect_timeout.as_ref()
     203            0 :     }
     204              : 
     205              :     /// Sets the channel binding behavior.
     206              :     ///
     207              :     /// Defaults to `prefer`.
     208           11 :     pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
     209           11 :         self.channel_binding = channel_binding;
     210           11 :         self
     211           11 :     }
     212              : 
     213              :     /// Gets the channel binding behavior.
     214            0 :     pub fn get_channel_binding(&self) -> ChannelBinding {
     215            0 :         self.channel_binding
     216            0 :     }
     217              : 
     218              :     /// Opens a connection to a PostgreSQL database.
     219              :     ///
     220              :     /// Requires the `runtime` Cargo feature (enabled by default).
     221            0 :     pub async fn connect<T>(
     222            0 :         &self,
     223            0 :         tls: T,
     224            0 :     ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
     225            0 :     where
     226            0 :         T: MakeTlsConnect<TcpStream>,
     227            0 :     {
     228            0 :         connect(tls, self).await
     229            0 :     }
     230              : 
     231           15 :     pub async fn connect_raw<S, T>(
     232           15 :         &self,
     233           15 :         stream: S,
     234           15 :         tls: T,
     235           15 :     ) -> Result<RawConnection<S, T::Stream>, Error>
     236           15 :     where
     237           15 :         S: AsyncRead + AsyncWrite + Unpin,
     238           15 :         T: TlsConnect<S>,
     239           15 :     {
     240           15 :         connect_raw(stream, tls, self).await
     241           15 :     }
     242              : }
     243              : 
     244              : // Redact the password in debug output; `auth_keys` is left out entirely.
     245              : impl fmt::Debug for Config {
     246            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     247              :         struct Redaction {}
     248              :         impl fmt::Debug for Redaction {
     249            0 :             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     250            0 :                 write!(f, "_")
     251            0 :             }
     252              :         }
     253              : 
     254            0 :         f.debug_struct("Config")
     255            0 :             .field("password", &self.password.as_ref().map(|_| Redaction {}))
     256            0 :             .field("ssl_mode", &self.ssl_mode)
     257            0 :             .field("host", &self.host)
     258            0 :             .field("port", &self.port)
     259            0 :             .field("connect_timeout", &self.connect_timeout)
     260            0 :             .field("channel_binding", &self.channel_binding)
     261            0 :             .field("server_params", &self.server_params)
     262            0 :             .finish()
     263            0 :     }
     264              : }
        

Generated by: LCOV version 2.1-beta