LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - config.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 44.8 % 134 60
Test Date: 2025-07-16 12:29:03 Functions: 30.4 % 46 14

            Line data    Source code
       1              : //! Connection configuration.
       2              : 
       3              : use std::net::IpAddr;
       4              : use std::time::Duration;
       5              : use std::{fmt, str};
       6              : 
       7              : pub use postgres_protocol2::authentication::sasl::ScramKeys;
       8              : use postgres_protocol2::message::frontend::StartupMessageParams;
       9              : use serde::{Deserialize, Serialize};
      10              : use tokio::io::{AsyncRead, AsyncWrite};
      11              : use tokio::net::TcpStream;
      12              : 
      13              : use crate::connect::connect;
      14              : use crate::connect_raw::{RawConnection, connect_raw};
      15              : use crate::connect_tls::connect_tls;
      16              : use crate::maybe_tls_stream::MaybeTlsStream;
      17              : use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
      18              : use crate::{Client, Connection, Error};
      19              : 
      20              : /// TLS configuration.
      21            0 : #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
      22              : pub enum SslMode {
      23              :     /// Do not use TLS.
      24              :     Disable,
      25              :     /// Attempt to connect with TLS but allow sessions without.
      26              :     Prefer,
      27              :     /// Require the use of TLS.
      28              :     Require,
      29              : }
      30              : 
      31              : /// Channel binding configuration.
      32              : #[derive(Debug, Copy, Clone, PartialEq, Eq)]
      33              : #[non_exhaustive]
      34              : pub enum ChannelBinding {
      35              :     /// Do not use channel binding.
      36              :     Disable,
      37              :     /// Attempt to use channel binding but allow sessions without.
      38              :     Prefer,
      39              :     /// Require the use of channel binding.
      40              :     Require,
      41              : }
      42              : 
      43              : /// Replication mode configuration.
      44              : #[derive(Debug, Copy, Clone, PartialEq, Eq)]
      45              : #[non_exhaustive]
      46              : pub enum ReplicationMode {
      47              :     /// Physical replication.
      48              :     Physical,
      49              :     /// Logical replication.
      50              :     Logical,
      51              : }
      52              : 
      53              : /// A host specification.
      54            0 : #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
      55              : pub enum Host {
      56              :     /// A TCP hostname.
      57              :     Tcp(String),
      58              : }
      59              : 
      60              : /// Precomputed keys which may override password during auth.
      61              : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
      62              : pub enum AuthKeys {
      63              :     /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
      64              :     ScramSha256(ScramKeys<32>),
      65              : }
      66              : 
      67              : /// Connection configuration.
      68              : #[derive(Clone, PartialEq, Eq)]
      69              : pub struct Config {
      70              :     pub(crate) host_addr: Option<IpAddr>,
      71              :     pub(crate) host: Host,
      72              :     pub(crate) port: u16,
      73              : 
      74              :     pub(crate) password: Option<Vec<u8>>,
      75              :     pub(crate) auth_keys: Option<Box<AuthKeys>>,
      76              :     pub(crate) ssl_mode: SslMode,
      77              :     pub(crate) connect_timeout: Option<Duration>,
      78              :     pub(crate) channel_binding: ChannelBinding,
      79              :     pub(crate) server_params: StartupMessageParams,
      80              : 
      81              :     database: bool,
      82              :     username: bool,
      83              : }
      84              : 
      85              : impl Config {
      86              :     /// Creates a new configuration.
      87           15 :     pub fn new(host: String, port: u16) -> Config {
      88           15 :         Config {
      89           15 :             host_addr: None,
      90           15 :             host: Host::Tcp(host),
      91           15 :             port,
      92           15 :             password: None,
      93           15 :             auth_keys: None,
      94           15 :             ssl_mode: SslMode::Prefer,
      95           15 :             connect_timeout: None,
      96           15 :             channel_binding: ChannelBinding::Prefer,
      97           15 :             server_params: StartupMessageParams::default(),
      98           15 : 
      99           15 :             database: false,
     100           15 :             username: false,
     101           15 :         }
     102           15 :     }
     103              : 
     104              :     /// Sets the user to authenticate with.
     105              :     ///
     106              :     /// Required.
     107           15 :     pub fn user(&mut self, user: &str) -> &mut Config {
     108           15 :         self.set_param("user", user)
     109           15 :     }
     110              : 
     111              :     /// Gets the user to authenticate with, if one has been configured with
     112              :     /// the `user` method.
     113            0 :     pub fn user_is_set(&self) -> bool {
     114            0 :         self.username
     115            0 :     }
     116              : 
     117              :     /// Sets the password to authenticate with.
     118           12 :     pub fn password<T>(&mut self, password: T) -> &mut Config
     119           12 :     where
     120           12 :         T: AsRef<[u8]>,
     121              :     {
     122           12 :         self.password = Some(password.as_ref().to_vec());
     123           12 :         self
     124            0 :     }
     125              : 
     126              :     /// Gets the password to authenticate with, if one has been configured with
     127              :     /// the `password` method.
     128           11 :     pub fn get_password(&self) -> Option<&[u8]> {
     129           11 :         self.password.as_deref()
     130           11 :     }
     131              : 
     132              :     /// Sets precomputed protocol-specific keys to authenticate with.
     133              :     /// When set, this option will override `password`.
     134              :     /// See [`AuthKeys`] for more information.
     135            0 :     pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
     136            0 :         self.auth_keys = Some(Box::new(keys));
     137            0 :         self
     138            0 :     }
     139              : 
     140              :     /// Gets precomputed protocol-specific keys to authenticate with.
     141              :     /// if one has been configured with the `auth_keys` method.
     142           11 :     pub fn get_auth_keys(&self) -> Option<AuthKeys> {
     143           11 :         self.auth_keys.as_deref().copied()
     144           11 :     }
     145              : 
     146              :     /// Sets the name of the database to connect to.
     147              :     ///
     148              :     /// Defaults to the user.
     149           15 :     pub fn dbname(&mut self, dbname: &str) -> &mut Config {
     150           15 :         self.set_param("database", dbname)
     151           15 :     }
     152              : 
     153              :     /// Gets the name of the database to connect to, if one has been configured
     154              :     /// with the `dbname` method.
     155            0 :     pub fn db_is_set(&self) -> bool {
     156            0 :         self.database
     157            0 :     }
     158              : 
     159           31 :     pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
     160           31 :         if name == "database" {
     161           15 :             self.database = true;
     162           16 :         } else if name == "user" {
     163           15 :             self.username = true;
     164           15 :         }
     165              : 
     166           31 :         self.server_params.insert(name, value);
     167           31 :         self
     168           31 :     }
     169              : 
     170            0 :     pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config {
     171            0 :         self.host_addr = Some(addr);
     172            0 :         self
     173            0 :     }
     174              : 
     175            0 :     pub fn get_host_addr(&self) -> Option<IpAddr> {
     176            0 :         self.host_addr
     177            0 :     }
     178              : 
     179              :     /// Sets the SSL configuration.
     180              :     ///
     181              :     /// Defaults to `prefer`.
     182           15 :     pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
     183           15 :         self.ssl_mode = ssl_mode;
     184           15 :         self
     185           15 :     }
     186              : 
     187              :     /// Gets the SSL configuration.
     188            0 :     pub fn get_ssl_mode(&self) -> SslMode {
     189            0 :         self.ssl_mode
     190            0 :     }
     191              : 
     192              :     /// Gets the hosts that have been added to the configuration with `host`.
     193            0 :     pub fn get_host(&self) -> &Host {
     194            0 :         &self.host
     195            0 :     }
     196              : 
     197              :     /// Gets the ports that have been added to the configuration with `port`.
     198            0 :     pub fn get_port(&self) -> u16 {
     199            0 :         self.port
     200            0 :     }
     201              : 
     202              :     /// Sets the timeout applied to socket-level connection attempts.
     203              :     ///
     204              :     /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
     205              :     /// host separately. Defaults to no limit.
     206            0 :     pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
     207            0 :         self.connect_timeout = Some(connect_timeout);
     208            0 :         self
     209            0 :     }
     210              : 
     211              :     /// Gets the connection timeout, if one has been set with the
     212              :     /// `connect_timeout` method.
     213            0 :     pub fn get_connect_timeout(&self) -> Option<&Duration> {
     214            0 :         self.connect_timeout.as_ref()
     215            0 :     }
     216              : 
     217              :     /// Sets the channel binding behavior.
     218              :     ///
     219              :     /// Defaults to `prefer`.
     220           11 :     pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
     221           11 :         self.channel_binding = channel_binding;
     222           11 :         self
     223           11 :     }
     224              : 
     225              :     /// Gets the channel binding behavior.
     226            0 :     pub fn get_channel_binding(&self) -> ChannelBinding {
     227            0 :         self.channel_binding
     228            0 :     }
     229              : 
     230              :     /// Opens a connection to a PostgreSQL database.
     231              :     ///
     232              :     /// Requires the `runtime` Cargo feature (enabled by default).
     233            0 :     pub async fn connect<T>(
     234            0 :         &self,
     235            0 :         tls: &T,
     236            0 :     ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
     237            0 :     where
     238            0 :         T: MakeTlsConnect<TcpStream>,
     239            0 :     {
     240            0 :         connect(tls, self).await
     241            0 :     }
     242              : 
     243           15 :     pub async fn tls_and_authenticate<S, T>(
     244           15 :         &self,
     245           15 :         stream: S,
     246           15 :         tls: T,
     247           15 :     ) -> Result<RawConnection<S, T::Stream>, Error>
     248           15 :     where
     249           15 :         S: AsyncRead + AsyncWrite + Unpin,
     250           15 :         T: TlsConnect<S>,
     251            0 :     {
     252           15 :         let stream = connect_tls(stream, self.ssl_mode, tls).await?;
     253           15 :         connect_raw(stream, self).await
     254            0 :     }
     255              : 
     256            0 :     pub async fn authenticate<S, T>(
     257            0 :         &self,
     258            0 :         stream: MaybeTlsStream<S, T>,
     259            0 :     ) -> Result<RawConnection<S, T>, Error>
     260            0 :     where
     261            0 :         S: AsyncRead + AsyncWrite + Unpin,
     262            0 :         T: TlsStream + Unpin,
     263            0 :     {
     264            0 :         connect_raw(stream, self).await
     265            0 :     }
     266              : }
     267              : 
     268              : // Omit password from debug output
     269              : impl fmt::Debug for Config {
     270            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     271              :         struct Redaction {}
     272              :         impl fmt::Debug for Redaction {
     273            0 :             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     274            0 :                 write!(f, "_")
     275            0 :             }
     276              :         }
     277              : 
     278            0 :         f.debug_struct("Config")
     279            0 :             .field("password", &self.password.as_ref().map(|_| Redaction {}))
     280            0 :             .field("ssl_mode", &self.ssl_mode)
     281            0 :             .field("host", &self.host)
     282            0 :             .field("port", &self.port)
     283            0 :             .field("connect_timeout", &self.connect_timeout)
     284            0 :             .field("channel_binding", &self.channel_binding)
     285            0 :             .field("server_params", &self.server_params)
     286            0 :             .finish()
     287            0 :     }
     288              : }
        

Generated by: LCOV version 2.1-beta