Line data Source code
1 : //! Connection configuration.
2 :
3 : use std::net::IpAddr;
4 : use std::time::Duration;
5 : use std::{fmt, str};
6 :
7 : pub use postgres_protocol2::authentication::sasl::ScramKeys;
8 : use postgres_protocol2::message::frontend::StartupMessageParams;
9 : use serde::{Deserialize, Serialize};
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::net::TcpStream;
12 :
13 : use crate::connect::connect;
14 : use crate::connect_raw::{RawConnection, connect_raw};
15 : use crate::connect_tls::connect_tls;
16 : use crate::maybe_tls_stream::MaybeTlsStream;
17 : use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
18 : use crate::{Client, Connection, Error};
19 :
20 : /// TLS configuration.
21 0 : #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
22 : pub enum SslMode {
23 : /// Do not use TLS.
24 : Disable,
25 : /// Attempt to connect with TLS but allow sessions without.
26 : Prefer,
27 : /// Require the use of TLS.
28 : Require,
29 : }
30 :
/// Policy for SCRAM channel binding.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Try channel binding, but accept sessions that cannot use it.
    Prefer,
    /// Insist on channel binding.
    Require,
}
42 :
/// Which kind of replication a connection should operate in.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReplicationMode {
    /// Physical (WAL-level) replication.
    Physical,
    /// Logical replication.
    Logical,
}
52 :
53 : /// A host specification.
54 0 : #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55 : pub enum Host {
56 : /// A TCP hostname.
57 : Tcp(String),
58 : }
59 :
60 : /// Precomputed keys which may override password during auth.
61 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
62 : pub enum AuthKeys {
63 : /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
64 : ScramSha256(ScramKeys<32>),
65 : }
66 :
67 : /// Connection configuration.
68 : #[derive(Clone, PartialEq, Eq)]
69 : pub struct Config {
70 : pub(crate) host_addr: Option<IpAddr>,
71 : pub(crate) host: Host,
72 : pub(crate) port: u16,
73 :
74 : pub(crate) password: Option<Vec<u8>>,
75 : pub(crate) auth_keys: Option<Box<AuthKeys>>,
76 : pub(crate) ssl_mode: SslMode,
77 : pub(crate) connect_timeout: Option<Duration>,
78 : pub(crate) channel_binding: ChannelBinding,
79 : pub(crate) server_params: StartupMessageParams,
80 :
81 : database: bool,
82 : username: bool,
83 : }
84 :
85 : impl Config {
86 : /// Creates a new configuration.
87 15 : pub fn new(host: String, port: u16) -> Config {
88 15 : Config {
89 15 : host_addr: None,
90 15 : host: Host::Tcp(host),
91 15 : port,
92 15 : password: None,
93 15 : auth_keys: None,
94 15 : ssl_mode: SslMode::Prefer,
95 15 : connect_timeout: None,
96 15 : channel_binding: ChannelBinding::Prefer,
97 15 : server_params: StartupMessageParams::default(),
98 15 :
99 15 : database: false,
100 15 : username: false,
101 15 : }
102 15 : }
103 :
104 : /// Sets the user to authenticate with.
105 : ///
106 : /// Required.
107 15 : pub fn user(&mut self, user: &str) -> &mut Config {
108 15 : self.set_param("user", user)
109 15 : }
110 :
111 : /// Gets the user to authenticate with, if one has been configured with
112 : /// the `user` method.
113 0 : pub fn user_is_set(&self) -> bool {
114 0 : self.username
115 0 : }
116 :
117 : /// Sets the password to authenticate with.
118 12 : pub fn password<T>(&mut self, password: T) -> &mut Config
119 12 : where
120 12 : T: AsRef<[u8]>,
121 : {
122 12 : self.password = Some(password.as_ref().to_vec());
123 12 : self
124 0 : }
125 :
126 : /// Gets the password to authenticate with, if one has been configured with
127 : /// the `password` method.
128 11 : pub fn get_password(&self) -> Option<&[u8]> {
129 11 : self.password.as_deref()
130 11 : }
131 :
132 : /// Sets precomputed protocol-specific keys to authenticate with.
133 : /// When set, this option will override `password`.
134 : /// See [`AuthKeys`] for more information.
135 0 : pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
136 0 : self.auth_keys = Some(Box::new(keys));
137 0 : self
138 0 : }
139 :
140 : /// Gets precomputed protocol-specific keys to authenticate with.
141 : /// if one has been configured with the `auth_keys` method.
142 11 : pub fn get_auth_keys(&self) -> Option<AuthKeys> {
143 11 : self.auth_keys.as_deref().copied()
144 11 : }
145 :
146 : /// Sets the name of the database to connect to.
147 : ///
148 : /// Defaults to the user.
149 15 : pub fn dbname(&mut self, dbname: &str) -> &mut Config {
150 15 : self.set_param("database", dbname)
151 15 : }
152 :
153 : /// Gets the name of the database to connect to, if one has been configured
154 : /// with the `dbname` method.
155 0 : pub fn db_is_set(&self) -> bool {
156 0 : self.database
157 0 : }
158 :
159 31 : pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
160 31 : if name == "database" {
161 15 : self.database = true;
162 16 : } else if name == "user" {
163 15 : self.username = true;
164 15 : }
165 :
166 31 : self.server_params.insert(name, value);
167 31 : self
168 31 : }
169 :
170 0 : pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config {
171 0 : self.host_addr = Some(addr);
172 0 : self
173 0 : }
174 :
175 0 : pub fn get_host_addr(&self) -> Option<IpAddr> {
176 0 : self.host_addr
177 0 : }
178 :
179 : /// Sets the SSL configuration.
180 : ///
181 : /// Defaults to `prefer`.
182 15 : pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
183 15 : self.ssl_mode = ssl_mode;
184 15 : self
185 15 : }
186 :
187 : /// Gets the SSL configuration.
188 0 : pub fn get_ssl_mode(&self) -> SslMode {
189 0 : self.ssl_mode
190 0 : }
191 :
192 : /// Gets the hosts that have been added to the configuration with `host`.
193 0 : pub fn get_host(&self) -> &Host {
194 0 : &self.host
195 0 : }
196 :
197 : /// Gets the ports that have been added to the configuration with `port`.
198 0 : pub fn get_port(&self) -> u16 {
199 0 : self.port
200 0 : }
201 :
202 : /// Sets the timeout applied to socket-level connection attempts.
203 : ///
204 : /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
205 : /// host separately. Defaults to no limit.
206 0 : pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
207 0 : self.connect_timeout = Some(connect_timeout);
208 0 : self
209 0 : }
210 :
211 : /// Gets the connection timeout, if one has been set with the
212 : /// `connect_timeout` method.
213 0 : pub fn get_connect_timeout(&self) -> Option<&Duration> {
214 0 : self.connect_timeout.as_ref()
215 0 : }
216 :
217 : /// Sets the channel binding behavior.
218 : ///
219 : /// Defaults to `prefer`.
220 11 : pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
221 11 : self.channel_binding = channel_binding;
222 11 : self
223 11 : }
224 :
225 : /// Gets the channel binding behavior.
226 0 : pub fn get_channel_binding(&self) -> ChannelBinding {
227 0 : self.channel_binding
228 0 : }
229 :
230 : /// Opens a connection to a PostgreSQL database.
231 : ///
232 : /// Requires the `runtime` Cargo feature (enabled by default).
233 0 : pub async fn connect<T>(
234 0 : &self,
235 0 : tls: &T,
236 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
237 0 : where
238 0 : T: MakeTlsConnect<TcpStream>,
239 0 : {
240 0 : connect(tls, self).await
241 0 : }
242 :
243 15 : pub async fn tls_and_authenticate<S, T>(
244 15 : &self,
245 15 : stream: S,
246 15 : tls: T,
247 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
248 15 : where
249 15 : S: AsyncRead + AsyncWrite + Unpin,
250 15 : T: TlsConnect<S>,
251 0 : {
252 15 : let stream = connect_tls(stream, self.ssl_mode, tls).await?;
253 15 : connect_raw(stream, self).await
254 0 : }
255 :
256 0 : pub async fn authenticate<S, T>(
257 0 : &self,
258 0 : stream: MaybeTlsStream<S, T>,
259 0 : ) -> Result<RawConnection<S, T>, Error>
260 0 : where
261 0 : S: AsyncRead + AsyncWrite + Unpin,
262 0 : T: TlsStream + Unpin,
263 0 : {
264 0 : connect_raw(stream, self).await
265 0 : }
266 : }
267 :
268 : // Omit password from debug output
269 : impl fmt::Debug for Config {
270 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271 : struct Redaction {}
272 : impl fmt::Debug for Redaction {
273 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 0 : write!(f, "_")
275 0 : }
276 : }
277 :
278 0 : f.debug_struct("Config")
279 0 : .field("password", &self.password.as_ref().map(|_| Redaction {}))
280 0 : .field("ssl_mode", &self.ssl_mode)
281 0 : .field("host", &self.host)
282 0 : .field("port", &self.port)
283 0 : .field("connect_timeout", &self.connect_timeout)
284 0 : .field("channel_binding", &self.channel_binding)
285 0 : .field("server_params", &self.server_params)
286 0 : .finish()
287 0 : }
288 : }
|