Line data Source code
1 : //! Connection configuration.
2 :
3 : use std::net::IpAddr;
4 : use std::time::Duration;
5 : use std::{fmt, str};
6 :
7 : pub use postgres_protocol2::authentication::sasl::ScramKeys;
8 : use postgres_protocol2::message::frontend::StartupMessageParams;
9 : use serde::{Deserialize, Serialize};
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::net::TcpStream;
12 :
13 : use crate::connect::connect;
14 : use crate::connect_raw::{RawConnection, connect_raw};
15 : use crate::tls::{MakeTlsConnect, TlsConnect};
16 : use crate::{Client, Connection, Error};
17 :
18 : /// TLS configuration.
19 0 : #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
20 : pub enum SslMode {
21 : /// Do not use TLS.
22 : Disable,
23 : /// Attempt to connect with TLS but allow sessions without.
24 : Prefer,
25 : /// Require the use of TLS.
26 : Require,
27 : }
28 :
/// Channel binding configuration.
///
/// Controls whether authentication binds to the underlying TLS channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Try channel binding, but accept a session without it.
    Prefer,
    /// Insist on channel binding.
    Require,
}
40 :
/// Replication mode configuration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Physical (byte-level) replication.
    Physical,
    /// Logical replication.
    Logical,
}
50 :
51 : /// A host specification.
52 0 : #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
53 : pub enum Host {
54 : /// A TCP hostname.
55 : Tcp(String),
56 : }
57 :
58 : /// Precomputed keys which may override password during auth.
59 : #[derive(Debug, Clone, Copy, PartialEq, Eq)]
60 : pub enum AuthKeys {
61 : /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
62 : ScramSha256(ScramKeys<32>),
63 : }
64 :
65 : /// Connection configuration.
66 : #[derive(Clone, PartialEq, Eq)]
67 : pub struct Config {
68 : pub(crate) host_addr: Option<IpAddr>,
69 : pub(crate) host: Host,
70 : pub(crate) port: u16,
71 :
72 : pub(crate) password: Option<Vec<u8>>,
73 : pub(crate) auth_keys: Option<Box<AuthKeys>>,
74 : pub(crate) ssl_mode: SslMode,
75 : pub(crate) connect_timeout: Option<Duration>,
76 : pub(crate) channel_binding: ChannelBinding,
77 : pub(crate) server_params: StartupMessageParams,
78 :
79 : database: bool,
80 : username: bool,
81 : }
82 :
83 : impl Config {
84 : /// Creates a new configuration.
85 15 : pub fn new(host: String, port: u16) -> Config {
86 15 : Config {
87 15 : host_addr: None,
88 15 : host: Host::Tcp(host),
89 15 : port,
90 15 : password: None,
91 15 : auth_keys: None,
92 15 : ssl_mode: SslMode::Prefer,
93 15 : connect_timeout: None,
94 15 : channel_binding: ChannelBinding::Prefer,
95 15 : server_params: StartupMessageParams::default(),
96 15 :
97 15 : database: false,
98 15 : username: false,
99 15 : }
100 15 : }
101 :
102 : /// Sets the user to authenticate with.
103 : ///
104 : /// Required.
105 15 : pub fn user(&mut self, user: &str) -> &mut Config {
106 15 : self.set_param("user", user)
107 15 : }
108 :
109 : /// Gets the user to authenticate with, if one has been configured with
110 : /// the `user` method.
111 0 : pub fn user_is_set(&self) -> bool {
112 0 : self.username
113 0 : }
114 :
115 : /// Sets the password to authenticate with.
116 12 : pub fn password<T>(&mut self, password: T) -> &mut Config
117 12 : where
118 12 : T: AsRef<[u8]>,
119 12 : {
120 12 : self.password = Some(password.as_ref().to_vec());
121 12 : self
122 12 : }
123 :
124 : /// Gets the password to authenticate with, if one has been configured with
125 : /// the `password` method.
126 11 : pub fn get_password(&self) -> Option<&[u8]> {
127 11 : self.password.as_deref()
128 11 : }
129 :
130 : /// Sets precomputed protocol-specific keys to authenticate with.
131 : /// When set, this option will override `password`.
132 : /// See [`AuthKeys`] for more information.
133 0 : pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
134 0 : self.auth_keys = Some(Box::new(keys));
135 0 : self
136 0 : }
137 :
138 : /// Gets precomputed protocol-specific keys to authenticate with.
139 : /// if one has been configured with the `auth_keys` method.
140 11 : pub fn get_auth_keys(&self) -> Option<AuthKeys> {
141 11 : self.auth_keys.as_deref().copied()
142 11 : }
143 :
144 : /// Sets the name of the database to connect to.
145 : ///
146 : /// Defaults to the user.
147 15 : pub fn dbname(&mut self, dbname: &str) -> &mut Config {
148 15 : self.set_param("database", dbname)
149 15 : }
150 :
151 : /// Gets the name of the database to connect to, if one has been configured
152 : /// with the `dbname` method.
153 0 : pub fn db_is_set(&self) -> bool {
154 0 : self.database
155 0 : }
156 :
157 31 : pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
158 31 : if name == "database" {
159 15 : self.database = true;
160 16 : } else if name == "user" {
161 15 : self.username = true;
162 15 : }
163 :
164 31 : self.server_params.insert(name, value);
165 31 : self
166 31 : }
167 :
168 0 : pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config {
169 0 : self.host_addr = Some(addr);
170 0 : self
171 0 : }
172 :
173 0 : pub fn get_host_addr(&self) -> Option<IpAddr> {
174 0 : self.host_addr
175 0 : }
176 :
177 : /// Sets the SSL configuration.
178 : ///
179 : /// Defaults to `prefer`.
180 15 : pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
181 15 : self.ssl_mode = ssl_mode;
182 15 : self
183 15 : }
184 :
185 : /// Gets the SSL configuration.
186 0 : pub fn get_ssl_mode(&self) -> SslMode {
187 0 : self.ssl_mode
188 0 : }
189 :
190 : /// Gets the hosts that have been added to the configuration with `host`.
191 0 : pub fn get_host(&self) -> &Host {
192 0 : &self.host
193 0 : }
194 :
195 : /// Gets the ports that have been added to the configuration with `port`.
196 0 : pub fn get_port(&self) -> u16 {
197 0 : self.port
198 0 : }
199 :
200 : /// Sets the timeout applied to socket-level connection attempts.
201 : ///
202 : /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
203 : /// host separately. Defaults to no limit.
204 0 : pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
205 0 : self.connect_timeout = Some(connect_timeout);
206 0 : self
207 0 : }
208 :
209 : /// Gets the connection timeout, if one has been set with the
210 : /// `connect_timeout` method.
211 0 : pub fn get_connect_timeout(&self) -> Option<&Duration> {
212 0 : self.connect_timeout.as_ref()
213 0 : }
214 :
215 : /// Sets the channel binding behavior.
216 : ///
217 : /// Defaults to `prefer`.
218 11 : pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
219 11 : self.channel_binding = channel_binding;
220 11 : self
221 11 : }
222 :
223 : /// Gets the channel binding behavior.
224 0 : pub fn get_channel_binding(&self) -> ChannelBinding {
225 0 : self.channel_binding
226 0 : }
227 :
228 : /// Opens a connection to a PostgreSQL database.
229 : ///
230 : /// Requires the `runtime` Cargo feature (enabled by default).
231 0 : pub async fn connect<T>(
232 0 : &self,
233 0 : tls: &T,
234 0 : ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
235 0 : where
236 0 : T: MakeTlsConnect<TcpStream>,
237 0 : {
238 0 : connect(tls, self).await
239 0 : }
240 :
241 15 : pub async fn connect_raw<S, T>(
242 15 : &self,
243 15 : stream: S,
244 15 : tls: T,
245 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
246 15 : where
247 15 : S: AsyncRead + AsyncWrite + Unpin,
248 15 : T: TlsConnect<S>,
249 15 : {
250 15 : connect_raw(stream, tls, self).await
251 0 : }
252 : }
253 :
254 : // Omit password from debug output
255 : impl fmt::Debug for Config {
256 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257 : struct Redaction {}
258 : impl fmt::Debug for Redaction {
259 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 0 : write!(f, "_")
261 0 : }
262 : }
263 :
264 0 : f.debug_struct("Config")
265 0 : .field("password", &self.password.as_ref().map(|_| Redaction {}))
266 0 : .field("ssl_mode", &self.ssl_mode)
267 0 : .field("host", &self.host)
268 0 : .field("port", &self.port)
269 0 : .field("connect_timeout", &self.connect_timeout)
270 0 : .field("channel_binding", &self.channel_binding)
271 0 : .field("server_params", &self.server_params)
272 0 : .finish()
273 0 : }
274 : }
|