Line data Source code
1 : //! Connection configuration.
2 :
3 : use crate::connect::connect;
4 : use crate::connect_raw::connect_raw;
5 : use crate::connect_raw::RawConnection;
6 : use crate::tls::MakeTlsConnect;
7 : use crate::tls::TlsConnect;
8 : use crate::{Client, Connection, Error};
9 : use postgres_protocol2::message::frontend::StartupMessageParams;
10 : use std::fmt;
11 : use std::str;
12 : use std::time::Duration;
13 : use tokio::io::{AsyncRead, AsyncWrite};
14 :
15 : pub use postgres_protocol2::authentication::sasl::ScramKeys;
16 : use tokio::net::TcpStream;
17 :
/// Controls whether TLS is used for the connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslMode {
    /// TLS is never used.
    Disable,
    /// TLS is attempted, but a session without it is still accepted.
    Prefer,
    /// TLS is mandatory.
    Require,
}
29 :
/// Controls whether SCRAM channel binding is used during authentication.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Channel binding is never used.
    Disable,
    /// Channel binding is attempted, but a session without it is still accepted.
    Prefer,
    /// Channel binding is mandatory.
    Require,
}
41 :
/// Selects the kind of replication session to request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// A physical replication session.
    Physical,
    /// A logical replication session.
    Logical,
}
51 :
/// Where to connect to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Host {
    /// A hostname (or address) reached over TCP.
    Tcp(String),
}
58 :
59 : /// Precomputed keys which may override password during auth.
60 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
61 : pub enum AuthKeys {
62 : /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
63 : ScramSha256(ScramKeys<32>),
64 : }
65 :
66 : /// Connection configuration.
67 : #[derive(Clone, PartialEq, Eq)]
68 : pub struct Config {
69 : pub(crate) host: Host,
70 : pub(crate) port: u16,
71 :
72 : pub(crate) password: Option<Vec<u8>>,
73 : pub(crate) auth_keys: Option<Box<AuthKeys>>,
74 : pub(crate) ssl_mode: SslMode,
75 : pub(crate) connect_timeout: Option<Duration>,
76 : pub(crate) channel_binding: ChannelBinding,
77 : pub(crate) server_params: StartupMessageParams,
78 :
79 : database: bool,
80 : username: bool,
81 : }
82 :
83 : impl Config {
84 : /// Creates a new configuration.
85 25 : pub fn new(host: String, port: u16) -> Config {
86 25 : Config {
87 25 : host: Host::Tcp(host),
88 25 : port,
89 25 : password: None,
90 25 : auth_keys: None,
91 25 : ssl_mode: SslMode::Prefer,
92 25 : connect_timeout: None,
93 25 : channel_binding: ChannelBinding::Prefer,
94 25 : server_params: StartupMessageParams::default(),
95 25 :
96 25 : database: false,
97 25 : username: false,
98 25 : }
99 25 : }
100 :
101 : /// Sets the user to authenticate with.
102 : ///
103 : /// Required.
104 15 : pub fn user(&mut self, user: &str) -> &mut Config {
105 15 : self.set_param("user", user)
106 15 : }
107 :
108 : /// Gets the user to authenticate with, if one has been configured with
109 : /// the `user` method.
110 0 : pub fn user_is_set(&self) -> bool {
111 0 : self.username
112 0 : }
113 :
114 : /// Sets the password to authenticate with.
115 22 : pub fn password<T>(&mut self, password: T) -> &mut Config
116 22 : where
117 22 : T: AsRef<[u8]>,
118 22 : {
119 22 : self.password = Some(password.as_ref().to_vec());
120 22 : self
121 22 : }
122 :
123 : /// Gets the password to authenticate with, if one has been configured with
124 : /// the `password` method.
125 15 : pub fn get_password(&self) -> Option<&[u8]> {
126 15 : self.password.as_deref()
127 15 : }
128 :
129 : /// Sets precomputed protocol-specific keys to authenticate with.
130 : /// When set, this option will override `password`.
131 : /// See [`AuthKeys`] for more information.
132 0 : pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
133 0 : self.auth_keys = Some(Box::new(keys));
134 0 : self
135 0 : }
136 :
137 : /// Gets precomputed protocol-specific keys to authenticate with.
138 : /// if one has been configured with the `auth_keys` method.
139 15 : pub fn get_auth_keys(&self) -> Option<AuthKeys> {
140 15 : self.auth_keys.as_deref().copied()
141 15 : }
142 :
143 : /// Sets the name of the database to connect to.
144 : ///
145 : /// Defaults to the user.
146 15 : pub fn dbname(&mut self, dbname: &str) -> &mut Config {
147 15 : self.set_param("database", dbname)
148 15 : }
149 :
150 : /// Gets the name of the database to connect to, if one has been configured
151 : /// with the `dbname` method.
152 0 : pub fn db_is_set(&self) -> bool {
153 0 : self.database
154 0 : }
155 :
156 31 : pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
157 31 : if name == "database" {
158 15 : self.database = true;
159 16 : } else if name == "user" {
160 15 : self.username = true;
161 15 : }
162 :
163 31 : self.server_params.insert(name, value);
164 31 : self
165 31 : }
166 :
167 : /// Sets the SSL configuration.
168 : ///
169 : /// Defaults to `prefer`.
170 15 : pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
171 15 : self.ssl_mode = ssl_mode;
172 15 : self
173 15 : }
174 :
175 : /// Gets the SSL configuration.
176 0 : pub fn get_ssl_mode(&self) -> SslMode {
177 0 : self.ssl_mode
178 0 : }
179 :
180 : /// Gets the hosts that have been added to the configuration with `host`.
181 0 : pub fn get_host(&self) -> &Host {
182 0 : &self.host
183 0 : }
184 :
185 : /// Gets the ports that have been added to the configuration with `port`.
186 0 : pub fn get_port(&self) -> u16 {
187 0 : self.port
188 0 : }
189 :
190 : /// Sets the timeout applied to socket-level connection attempts.
191 : ///
192 : /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
193 : /// host separately. Defaults to no limit.
194 0 : pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
195 0 : self.connect_timeout = Some(connect_timeout);
196 0 : self
197 0 : }
198 :
199 : /// Gets the connection timeout, if one has been set with the
200 : /// `connect_timeout` method.
201 0 : pub fn get_connect_timeout(&self) -> Option<&Duration> {
202 0 : self.connect_timeout.as_ref()
203 0 : }
204 :
205 : /// Sets the channel binding behavior.
206 : ///
207 : /// Defaults to `prefer`.
208 11 : pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
209 11 : self.channel_binding = channel_binding;
210 11 : self
211 11 : }
212 :
213 : /// Gets the channel binding behavior.
214 0 : pub fn get_channel_binding(&self) -> ChannelBinding {
215 0 : self.channel_binding
216 0 : }
217 :
218 : /// Opens a connection to a PostgreSQL database.
219 : ///
220 : /// Requires the `runtime` Cargo feature (enabled by default).
221 0 : pub async fn connect<T>(
222 0 : &self,
223 0 : tls: T,
224 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
225 0 : where
226 0 : T: MakeTlsConnect<TcpStream>,
227 0 : {
228 0 : connect(tls, self).await
229 0 : }
230 :
231 15 : pub async fn connect_raw<S, T>(
232 15 : &self,
233 15 : stream: S,
234 15 : tls: T,
235 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
236 15 : where
237 15 : S: AsyncRead + AsyncWrite + Unpin,
238 15 : T: TlsConnect<S>,
239 15 : {
240 15 : connect_raw(stream, tls, self).await
241 15 : }
242 : }
243 :
244 : // Omit password from debug output
245 : impl fmt::Debug for Config {
246 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 : struct Redaction {}
248 : impl fmt::Debug for Redaction {
249 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 0 : write!(f, "_")
251 0 : }
252 : }
253 :
254 0 : f.debug_struct("Config")
255 0 : .field("password", &self.password.as_ref().map(|_| Redaction {}))
256 0 : .field("ssl_mode", &self.ssl_mode)
257 0 : .field("host", &self.host)
258 0 : .field("port", &self.port)
259 0 : .field("connect_timeout", &self.connect_timeout)
260 0 : .field("channel_binding", &self.channel_binding)
261 0 : .field("server_params", &self.server_params)
262 0 : .finish()
263 0 : }
264 : }
|