Line data Source code
1 : //! Connection configuration.
2 :
3 : use crate::connect::connect;
4 : use crate::connect_raw::connect_raw;
5 : use crate::connect_raw::RawConnection;
6 : use crate::tls::MakeTlsConnect;
7 : use crate::tls::TlsConnect;
8 : use crate::{Client, Connection, Error};
9 : use postgres_protocol2::message::frontend::StartupMessageParams;
10 : use serde::{Deserialize, Serialize};
11 : use std::fmt;
12 : use std::str;
13 : use std::time::Duration;
14 : use tokio::io::{AsyncRead, AsyncWrite};
15 :
16 : pub use postgres_protocol2::authentication::sasl::ScramKeys;
17 : use tokio::net::TcpStream;
18 :
19 : /// TLS configuration.
20 0 : #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
21 : #[non_exhaustive]
22 : pub enum SslMode {
23 : /// Do not use TLS.
24 : Disable,
25 : /// Attempt to connect with TLS but allow sessions without.
26 : Prefer,
27 : /// Require the use of TLS.
28 : Require,
29 : }
30 :
/// Policy for channel binding during authentication.
///
/// [`Config::new`] starts out with [`ChannelBinding::Prefer`].
#[derive(Debug, Copy, Clone)]
#[derive(PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Try to use channel binding, but still allow a session without it.
    Prefer,
    /// Insist on channel binding.
    Require,
}
42 :
/// Which kind of replication a connection is configured for.
#[derive(Debug, Copy, Clone)]
#[derive(PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Physical (byte-level) replication.
    Physical,
    /// Logical replication.
    Logical,
}
52 :
53 : /// A host specification.
54 0 : #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55 : pub enum Host {
56 : /// A TCP hostname.
57 : Tcp(String),
58 : }
59 :
60 : /// Precomputed keys which may override password during auth.
61 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
62 : pub enum AuthKeys {
63 : /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
64 : ScramSha256(ScramKeys<32>),
65 : }
66 :
67 : /// Connection configuration.
68 : #[derive(Clone, PartialEq, Eq)]
69 : pub struct Config {
70 : pub(crate) host: Host,
71 : pub(crate) port: u16,
72 :
73 : pub(crate) password: Option<Vec<u8>>,
74 : pub(crate) auth_keys: Option<Box<AuthKeys>>,
75 : pub(crate) ssl_mode: SslMode,
76 : pub(crate) connect_timeout: Option<Duration>,
77 : pub(crate) channel_binding: ChannelBinding,
78 : pub(crate) server_params: StartupMessageParams,
79 :
80 : database: bool,
81 : username: bool,
82 : }
83 :
84 : impl Config {
85 : /// Creates a new configuration.
86 25 : pub fn new(host: String, port: u16) -> Config {
87 25 : Config {
88 25 : host: Host::Tcp(host),
89 25 : port,
90 25 : password: None,
91 25 : auth_keys: None,
92 25 : ssl_mode: SslMode::Prefer,
93 25 : connect_timeout: None,
94 25 : channel_binding: ChannelBinding::Prefer,
95 25 : server_params: StartupMessageParams::default(),
96 25 :
97 25 : database: false,
98 25 : username: false,
99 25 : }
100 25 : }
101 :
102 : /// Sets the user to authenticate with.
103 : ///
104 : /// Required.
105 15 : pub fn user(&mut self, user: &str) -> &mut Config {
106 15 : self.set_param("user", user)
107 15 : }
108 :
109 : /// Gets the user to authenticate with, if one has been configured with
110 : /// the `user` method.
111 0 : pub fn user_is_set(&self) -> bool {
112 0 : self.username
113 0 : }
114 :
115 : /// Sets the password to authenticate with.
116 22 : pub fn password<T>(&mut self, password: T) -> &mut Config
117 22 : where
118 22 : T: AsRef<[u8]>,
119 22 : {
120 22 : self.password = Some(password.as_ref().to_vec());
121 22 : self
122 22 : }
123 :
124 : /// Gets the password to authenticate with, if one has been configured with
125 : /// the `password` method.
126 15 : pub fn get_password(&self) -> Option<&[u8]> {
127 15 : self.password.as_deref()
128 15 : }
129 :
130 : /// Sets precomputed protocol-specific keys to authenticate with.
131 : /// When set, this option will override `password`.
132 : /// See [`AuthKeys`] for more information.
133 0 : pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
134 0 : self.auth_keys = Some(Box::new(keys));
135 0 : self
136 0 : }
137 :
138 : /// Gets precomputed protocol-specific keys to authenticate with.
139 : /// if one has been configured with the `auth_keys` method.
140 15 : pub fn get_auth_keys(&self) -> Option<AuthKeys> {
141 15 : self.auth_keys.as_deref().copied()
142 15 : }
143 :
144 : /// Sets the name of the database to connect to.
145 : ///
146 : /// Defaults to the user.
147 15 : pub fn dbname(&mut self, dbname: &str) -> &mut Config {
148 15 : self.set_param("database", dbname)
149 15 : }
150 :
151 : /// Gets the name of the database to connect to, if one has been configured
152 : /// with the `dbname` method.
153 0 : pub fn db_is_set(&self) -> bool {
154 0 : self.database
155 0 : }
156 :
157 31 : pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
158 31 : if name == "database" {
159 15 : self.database = true;
160 16 : } else if name == "user" {
161 15 : self.username = true;
162 15 : }
163 :
164 31 : self.server_params.insert(name, value);
165 31 : self
166 31 : }
167 :
168 : /// Sets the SSL configuration.
169 : ///
170 : /// Defaults to `prefer`.
171 15 : pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
172 15 : self.ssl_mode = ssl_mode;
173 15 : self
174 15 : }
175 :
176 : /// Gets the SSL configuration.
177 0 : pub fn get_ssl_mode(&self) -> SslMode {
178 0 : self.ssl_mode
179 0 : }
180 :
181 : /// Gets the hosts that have been added to the configuration with `host`.
182 0 : pub fn get_host(&self) -> &Host {
183 0 : &self.host
184 0 : }
185 :
186 : /// Gets the ports that have been added to the configuration with `port`.
187 0 : pub fn get_port(&self) -> u16 {
188 0 : self.port
189 0 : }
190 :
191 : /// Sets the timeout applied to socket-level connection attempts.
192 : ///
193 : /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
194 : /// host separately. Defaults to no limit.
195 0 : pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
196 0 : self.connect_timeout = Some(connect_timeout);
197 0 : self
198 0 : }
199 :
200 : /// Gets the connection timeout, if one has been set with the
201 : /// `connect_timeout` method.
202 0 : pub fn get_connect_timeout(&self) -> Option<&Duration> {
203 0 : self.connect_timeout.as_ref()
204 0 : }
205 :
206 : /// Sets the channel binding behavior.
207 : ///
208 : /// Defaults to `prefer`.
209 11 : pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
210 11 : self.channel_binding = channel_binding;
211 11 : self
212 11 : }
213 :
214 : /// Gets the channel binding behavior.
215 0 : pub fn get_channel_binding(&self) -> ChannelBinding {
216 0 : self.channel_binding
217 0 : }
218 :
219 : /// Opens a connection to a PostgreSQL database.
220 : ///
221 : /// Requires the `runtime` Cargo feature (enabled by default).
222 0 : pub async fn connect<T>(
223 0 : &self,
224 0 : tls: T,
225 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
226 0 : where
227 0 : T: MakeTlsConnect<TcpStream>,
228 0 : {
229 0 : connect(tls, self).await
230 0 : }
231 :
232 15 : pub async fn connect_raw<S, T>(
233 15 : &self,
234 15 : stream: S,
235 15 : tls: T,
236 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
237 15 : where
238 15 : S: AsyncRead + AsyncWrite + Unpin,
239 15 : T: TlsConnect<S>,
240 15 : {
241 15 : connect_raw(stream, tls, self).await
242 15 : }
243 : }
244 :
245 : // Omit password from debug output
246 : impl fmt::Debug for Config {
247 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 : struct Redaction {}
249 : impl fmt::Debug for Redaction {
250 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 0 : write!(f, "_")
252 0 : }
253 : }
254 :
255 0 : f.debug_struct("Config")
256 0 : .field("password", &self.password.as_ref().map(|_| Redaction {}))
257 0 : .field("ssl_mode", &self.ssl_mode)
258 0 : .field("host", &self.host)
259 0 : .field("port", &self.port)
260 0 : .field("connect_timeout", &self.connect_timeout)
261 0 : .field("channel_binding", &self.channel_binding)
262 0 : .field("server_params", &self.server_params)
263 0 : .finish()
264 0 : }
265 : }
|