Line data Source code
1 : #![deny(unsafe_code)]
2 : #![deny(clippy::undocumented_unsafe_blocks)]
3 : use anyhow::{bail, Context};
4 : use itertools::Itertools;
5 : use std::borrow::Cow;
6 : use std::fmt;
7 : use url::Host;
8 :
9 : /// Parses a string of format either `host:port` or `host` into a corresponding pair.
10 : ///
11 : /// The `host` part should be a correct `url::Host`, while `port` (if present) should be
12 : /// a valid decimal u16 of digits only.
13 31 : pub fn parse_host_port<S: AsRef<str>>(host_port: S) -> Result<(Host, Option<u16>), anyhow::Error> {
14 31 : let (host, port) = match host_port.as_ref().rsplit_once(':') {
15 3 : Some((host, port)) => (
16 3 : host,
17 3 : // +80 is a valid u16, but not a valid port
18 7 : if port.chars().all(|c| c.is_ascii_digit()) {
19 2 : Some(port.parse::<u16>().context("Unable to parse port")?)
20 : } else {
21 1 : bail!("Port contains a non-ascii-digit")
22 : },
23 : ),
24 28 : None => (host_port.as_ref(), None), // No colons, no port specified
25 : };
26 30 : let host = Host::parse(host).context("Unable to parse host")?;
27 29 : Ok((host, port))
28 31 : }
29 :
#[cfg(test)]
mod tests_parse_host_port {
    use crate::parse_host_port;
    use url::Host;

    /// Parses `input` and panics if parsing fails, so each test is a single comparison.
    fn parse_ok(input: &str) -> (Host, Option<u16>) {
        parse_host_port(input).unwrap()
    }

    #[test]
    fn test_normal() {
        let expected = (Host::Domain("hello".to_owned()), Some(123));
        assert_eq!(parse_ok("hello:123"), expected);
    }

    #[test]
    fn test_no_port() {
        let expected = (Host::Domain("hello".to_owned()), None);
        assert_eq!(parse_ok("hello"), expected);
    }

    #[test]
    fn test_ipv6() {
        let expected = (Host::<String>::Ipv6(std::net::Ipv6Addr::LOCALHOST), Some(123));
        assert_eq!(parse_ok("[::1]:123"), expected);
    }

    #[test]
    fn test_invalid_host() {
        // A space is not allowed in a domain name.
        let parsed = parse_host_port("hello world");
        assert!(parsed.is_err());
    }

    #[test]
    fn test_invalid_port() {
        // `+80` parses as a u16 but must be rejected as a port.
        let parsed = parse_host_port("hello:+80");
        assert!(parsed.is_err());
    }
}
66 :
67 : #[derive(Clone)]
68 : pub struct PgConnectionConfig {
69 : host: Host,
70 : port: u16,
71 : password: Option<String>,
72 : options: Vec<String>,
73 : }
74 :
75 : /// A simplified PostgreSQL connection configuration. Supports only a subset of possible
76 : /// settings for simplicity. A password getter or `to_connection_string` methods are not
77 : /// added by design to avoid accidentally leaking password through logging, command line
78 : /// arguments to a child process, or likewise.
79 : impl PgConnectionConfig {
80 30 : pub fn new_host_port(host: Host, port: u16) -> Self {
81 30 : PgConnectionConfig {
82 30 : host,
83 30 : port,
84 30 : password: None,
85 30 : options: vec![],
86 30 : }
87 30 : }
88 :
89 21 : pub fn host(&self) -> &Host {
90 21 : &self.host
91 21 : }
92 :
93 8 : pub fn port(&self) -> u16 {
94 8 : self.port
95 8 : }
96 :
97 0 : pub fn set_host(mut self, h: Host) -> Self {
98 0 : self.host = h;
99 0 : self
100 0 : }
101 :
102 0 : pub fn set_port(mut self, p: u16) -> Self {
103 0 : self.port = p;
104 0 : self
105 0 : }
106 :
107 27 : pub fn set_password(mut self, s: Option<String>) -> Self {
108 27 : self.password = s;
109 27 : self
110 27 : }
111 :
112 31 : pub fn extend_options<I: IntoIterator<Item = S>, S: Into<String>>(mut self, i: I) -> Self {
113 190 : self.options.extend(i.into_iter().map(|s| s.into()));
114 31 : self
115 31 : }
116 :
117 : /// Return a `<host>:<port>` string.
118 4 : pub fn raw_address(&self) -> String {
119 4 : format!("{}:{}", self.host(), self.port())
120 4 : }
121 :
122 : /// Build a client library-specific connection configuration.
123 : /// Used for testing and when we need to add some obscure configuration
124 : /// elements at the last moment.
125 1 : pub fn to_tokio_postgres_config(&self) -> tokio_postgres::Config {
126 1 : // Use `tokio_postgres::Config` instead of `postgres::Config` because
127 1 : // the former supports more options to fiddle with later.
128 1 : let mut config = tokio_postgres::Config::new();
129 1 : config.host(&self.host().to_string()).port(self.port);
130 1 : if let Some(password) = &self.password {
131 0 : config.password(password);
132 1 : }
133 1 : if !self.options.is_empty() {
134 1 : // These options are command-line options and should be escaped before being passed
135 1 : // as an 'options' connection string parameter, see
136 1 : // https://www.postgresql.org/docs/15/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
137 1 : //
138 1 : // They will be space-separated, so each space inside an option should be escaped,
139 1 : // and all backslashes should be escaped before that. Although we don't expect options
140 1 : // with spaces at the moment, they're supported by PostgreSQL. Hence we support them
141 1 : // in this typesafe interface.
142 1 : //
143 1 : // We use `Cow` to avoid allocations in the best case (no escaping). A fully imperative
144 1 : // solution would require 1-2 allocations in the worst case as well, but it's harder to
145 1 : // implement and this function is hardly a bottleneck. The function is only called around
146 1 : // establishing a new connection.
147 1 : #[allow(unstable_name_collisions)]
148 1 : config.options(
149 1 : &self
150 1 : .options
151 1 : .iter()
152 4 : .map(|s| {
153 4 : if s.contains(['\\', ' ']) {
154 2 : Cow::Owned(s.replace('\\', "\\\\").replace(' ', "\\ "))
155 : } else {
156 2 : Cow::Borrowed(s.as_str())
157 : }
158 4 : })
159 1 : .intersperse(Cow::Borrowed(" ")) // TODO: use impl from std once it's stabilized
160 1 : .collect::<String>(),
161 1 : );
162 1 : }
163 1 : config
164 1 : }
165 :
166 : /// Connect using postgres protocol with TLS disabled.
167 0 : pub async fn connect_no_tls(
168 0 : &self,
169 0 : ) -> Result<
170 0 : (
171 0 : tokio_postgres::Client,
172 0 : tokio_postgres::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>,
173 0 : ),
174 0 : postgres::Error,
175 0 : > {
176 0 : self.to_tokio_postgres_config()
177 0 : .connect(postgres::NoTls)
178 0 : .await
179 0 : }
180 : }
181 :
182 : impl fmt::Display for PgConnectionConfig {
183 0 : fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
184 0 : // The password is intentionally hidden and not part of this display string.
185 0 : write!(f, "postgresql://{}:{}", self.host, self.port)
186 0 : }
187 : }
188 :
189 : impl fmt::Debug for PgConnectionConfig {
190 3 : fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
191 3 : // We want `password: Some(REDACTED-STRING)`, not `password: Some("REDACTED-STRING")`
192 3 : // so even if the password is `REDACTED-STRING` (quite unlikely) there is no confusion.
193 3 : // Hence `format_args!()`, it returns a "safe" string which is not escaped by `Debug`.
194 3 : f.debug_struct("PgConnectionConfig")
195 3 : .field("host", &self.host)
196 3 : .field("port", &self.port)
197 3 : .field(
198 3 : "password",
199 3 : &self
200 3 : .password
201 3 : .as_ref()
202 3 : .map(|_| format_args!("REDACTED-STRING")),
203 3 : )
204 3 : .finish()
205 3 : }
206 : }
207 :
#[cfg(test)]
mod tests_pg_connection_config {
    use crate::PgConnectionConfig;
    use url::Host;

    /// Placeholder domain shared by most tests.
    fn stub_host() -> Host {
        Host::Domain("stub.host.example".to_owned())
    }

    /// Config for `stub_host()` on port 123, the starting point of most tests.
    fn stub_config() -> PgConnectionConfig {
        PgConnectionConfig::new_host_port(stub_host(), 123)
    }

    /// Checks the address accessors that every stub-host test verifies.
    fn assert_stub_address(cfg: &PgConnectionConfig) {
        assert_eq!(cfg.host(), &stub_host());
        assert_eq!(cfg.port(), 123);
        assert_eq!(cfg.raw_address(), "stub.host.example:123");
    }

    #[test]
    fn test_no_password() {
        let cfg = stub_config();
        assert_stub_address(&cfg);
        let expected_debug =
            "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: None }";
        assert_eq!(format!("{:?}", cfg), expected_debug);
    }

    #[test]
    fn test_ipv6() {
        // May be a special case because hostname contains a colon.
        let cfg = PgConnectionConfig::new_host_port(Host::parse("[::1]").unwrap(), 123);
        let loopback = Host::<String>::Ipv6(std::net::Ipv6Addr::LOCALHOST);
        assert_eq!(cfg.host(), &loopback);
        assert_eq!(cfg.port(), 123);
        assert_eq!(cfg.raw_address(), "[::1]:123");
        let expected_debug = "PgConnectionConfig { host: Ipv6(::1), port: 123, password: None }";
        assert_eq!(format!("{:?}", cfg), expected_debug);
    }

    #[test]
    fn test_with_password() {
        let cfg = stub_config().set_password(Some("password".to_owned()));
        assert_stub_address(&cfg);
        let expected_debug = "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: Some(REDACTED-STRING) }";
        assert_eq!(format!("{:?}", cfg), expected_debug);
    }

    #[test]
    fn test_with_options() {
        let raw_options = ["hello", "world", "with space", "and \\ backslashes"];
        let cfg = stub_config().extend_options(raw_options);
        assert_stub_address(&cfg);
        let expected_options = "hello world with\\ space and\\ \\\\\\ backslashes";
        assert_eq!(
            cfg.to_tokio_postgres_config().get_options(),
            Some(expected_options)
        );
    }
}
|