Line data Source code
1 : //! Connection configuration.
2 :
3 : use std::net::IpAddr;
4 : use std::time::Duration;
5 : use std::{fmt, str};
6 :
7 : pub use postgres_protocol2::authentication::sasl::ScramKeys;
8 : use postgres_protocol2::message::frontend::StartupMessageParams;
9 : use serde::{Deserialize, Serialize};
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::net::TcpStream;
12 :
13 : use crate::connect::connect;
14 : use crate::connect_raw::{RawConnection, connect_raw};
15 : use crate::tls::{MakeTlsConnect, TlsConnect};
16 : use crate::{Client, Connection, Error};
17 :
18 : /// TLS configuration.
19 0 : #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
20 : #[non_exhaustive]
21 : pub enum SslMode {
22 : /// Do not use TLS.
23 : Disable,
24 : /// Attempt to connect with TLS but allow sessions without.
25 : Prefer,
26 : /// Require the use of TLS.
27 : Require,
28 : }
29 :
/// Controls whether channel binding is used during authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Try channel binding, but accept a session without it.
    Prefer,
    /// Only proceed when channel binding is in use.
    Require,
}
41 :
/// Selects the kind of replication a connection is opened for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Physical replication.
    Physical,
    /// Logical replication.
    Logical,
}
51 :
52 : /// A host specification.
53 0 : #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54 : pub enum Host {
55 : /// A TCP hostname.
56 : Tcp(String),
57 : }
58 :
59 : /// Precomputed keys which may override password during auth.
60 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
61 : pub enum AuthKeys {
62 : /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
63 : ScramSha256(ScramKeys<32>),
64 : }
65 :
66 : /// Connection configuration.
67 : #[derive(Clone, PartialEq, Eq)]
68 : pub struct Config {
69 : pub(crate) host_addr: Option<IpAddr>,
70 : pub(crate) host: Host,
71 : pub(crate) port: u16,
72 :
73 : pub(crate) password: Option<Vec<u8>>,
74 : pub(crate) auth_keys: Option<Box<AuthKeys>>,
75 : pub(crate) ssl_mode: SslMode,
76 : pub(crate) connect_timeout: Option<Duration>,
77 : pub(crate) channel_binding: ChannelBinding,
78 : pub(crate) server_params: StartupMessageParams,
79 :
80 : database: bool,
81 : username: bool,
82 : }
83 :
84 : impl Config {
85 : /// Creates a new configuration.
86 25 : pub fn new(host: String, port: u16) -> Config {
87 25 : Config {
88 25 : host_addr: None,
89 25 : host: Host::Tcp(host),
90 25 : port,
91 25 : password: None,
92 25 : auth_keys: None,
93 25 : ssl_mode: SslMode::Prefer,
94 25 : connect_timeout: None,
95 25 : channel_binding: ChannelBinding::Prefer,
96 25 : server_params: StartupMessageParams::default(),
97 25 :
98 25 : database: false,
99 25 : username: false,
100 25 : }
101 25 : }
102 :
103 : /// Sets the user to authenticate with.
104 : ///
105 : /// Required.
106 15 : pub fn user(&mut self, user: &str) -> &mut Config {
107 15 : self.set_param("user", user)
108 15 : }
109 :
110 : /// Gets the user to authenticate with, if one has been configured with
111 : /// the `user` method.
112 0 : pub fn user_is_set(&self) -> bool {
113 0 : self.username
114 0 : }
115 :
116 : /// Sets the password to authenticate with.
117 22 : pub fn password<T>(&mut self, password: T) -> &mut Config
118 22 : where
119 22 : T: AsRef<[u8]>,
120 22 : {
121 22 : self.password = Some(password.as_ref().to_vec());
122 22 : self
123 22 : }
124 :
125 : /// Gets the password to authenticate with, if one has been configured with
126 : /// the `password` method.
127 15 : pub fn get_password(&self) -> Option<&[u8]> {
128 15 : self.password.as_deref()
129 15 : }
130 :
131 : /// Sets precomputed protocol-specific keys to authenticate with.
132 : /// When set, this option will override `password`.
133 : /// See [`AuthKeys`] for more information.
134 0 : pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
135 0 : self.auth_keys = Some(Box::new(keys));
136 0 : self
137 0 : }
138 :
139 : /// Gets precomputed protocol-specific keys to authenticate with.
140 : /// if one has been configured with the `auth_keys` method.
141 15 : pub fn get_auth_keys(&self) -> Option<AuthKeys> {
142 15 : self.auth_keys.as_deref().copied()
143 15 : }
144 :
145 : /// Sets the name of the database to connect to.
146 : ///
147 : /// Defaults to the user.
148 15 : pub fn dbname(&mut self, dbname: &str) -> &mut Config {
149 15 : self.set_param("database", dbname)
150 15 : }
151 :
152 : /// Gets the name of the database to connect to, if one has been configured
153 : /// with the `dbname` method.
154 0 : pub fn db_is_set(&self) -> bool {
155 0 : self.database
156 0 : }
157 :
158 31 : pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
159 31 : if name == "database" {
160 15 : self.database = true;
161 16 : } else if name == "user" {
162 15 : self.username = true;
163 15 : }
164 :
165 31 : self.server_params.insert(name, value);
166 31 : self
167 31 : }
168 :
169 0 : pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config {
170 0 : self.host_addr = Some(addr);
171 0 : self
172 0 : }
173 :
174 0 : pub fn get_host_addr(&self) -> Option<IpAddr> {
175 0 : self.host_addr
176 0 : }
177 :
178 : /// Sets the SSL configuration.
179 : ///
180 : /// Defaults to `prefer`.
181 15 : pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
182 15 : self.ssl_mode = ssl_mode;
183 15 : self
184 15 : }
185 :
186 : /// Gets the SSL configuration.
187 0 : pub fn get_ssl_mode(&self) -> SslMode {
188 0 : self.ssl_mode
189 0 : }
190 :
191 : /// Gets the hosts that have been added to the configuration with `host`.
192 0 : pub fn get_host(&self) -> &Host {
193 0 : &self.host
194 0 : }
195 :
196 : /// Gets the ports that have been added to the configuration with `port`.
197 0 : pub fn get_port(&self) -> u16 {
198 0 : self.port
199 0 : }
200 :
201 : /// Sets the timeout applied to socket-level connection attempts.
202 : ///
203 : /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
204 : /// host separately. Defaults to no limit.
205 0 : pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
206 0 : self.connect_timeout = Some(connect_timeout);
207 0 : self
208 0 : }
209 :
210 : /// Gets the connection timeout, if one has been set with the
211 : /// `connect_timeout` method.
212 0 : pub fn get_connect_timeout(&self) -> Option<&Duration> {
213 0 : self.connect_timeout.as_ref()
214 0 : }
215 :
216 : /// Sets the channel binding behavior.
217 : ///
218 : /// Defaults to `prefer`.
219 11 : pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
220 11 : self.channel_binding = channel_binding;
221 11 : self
222 11 : }
223 :
224 : /// Gets the channel binding behavior.
225 0 : pub fn get_channel_binding(&self) -> ChannelBinding {
226 0 : self.channel_binding
227 0 : }
228 :
229 : /// Opens a connection to a PostgreSQL database.
230 : ///
231 : /// Requires the `runtime` Cargo feature (enabled by default).
232 0 : pub async fn connect<T>(
233 0 : &self,
234 0 : tls: T,
235 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
236 0 : where
237 0 : T: MakeTlsConnect<TcpStream>,
238 0 : {
239 0 : connect(tls, self).await
240 0 : }
241 :
242 15 : pub async fn connect_raw<S, T>(
243 15 : &self,
244 15 : stream: S,
245 15 : tls: T,
246 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
247 15 : where
248 15 : S: AsyncRead + AsyncWrite + Unpin,
249 15 : T: TlsConnect<S>,
250 15 : {
251 15 : connect_raw(stream, tls, self).await
252 0 : }
253 : }
254 :
255 : // Omit password from debug output
256 : impl fmt::Debug for Config {
257 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 : struct Redaction {}
259 : impl fmt::Debug for Redaction {
260 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 0 : write!(f, "_")
262 0 : }
263 : }
264 :
265 0 : f.debug_struct("Config")
266 0 : .field("password", &self.password.as_ref().map(|_| Redaction {}))
267 0 : .field("ssl_mode", &self.ssl_mode)
268 0 : .field("host", &self.host)
269 0 : .field("port", &self.port)
270 0 : .field("connect_timeout", &self.connect_timeout)
271 0 : .field("channel_binding", &self.channel_binding)
272 0 : .field("server_params", &self.server_params)
273 0 : .finish()
274 0 : }
275 : }
|