Line data Source code
1 : #![deny(unsafe_code)]
2 : #![deny(clippy::undocumented_unsafe_blocks)]
3 : use std::borrow::Cow;
4 : use std::fmt;
5 :
6 : use anyhow::{Context, bail};
7 : use itertools::Itertools;
8 : use url::Host;
9 :
10 : /// Parses a string of format either `host:port` or `host` into a corresponding pair.
11 : ///
12 : /// The `host` part should be a correct `url::Host`, while `port` (if present) should be
13 : /// a valid decimal u16 of digits only.
14 57 : pub fn parse_host_port<S: AsRef<str>>(host_port: S) -> Result<(Host, Option<u16>), anyhow::Error> {
15 57 : let (host, port) = match host_port.as_ref().rsplit_once(':') {
16 3 : Some((host, port)) => (
17 3 : host,
18 3 : // +80 is a valid u16, but not a valid port
19 7 : if port.chars().all(|c| c.is_ascii_digit()) {
20 2 : Some(port.parse::<u16>().context("Unable to parse port")?)
21 : } else {
22 1 : bail!("Port contains a non-ascii-digit")
23 : },
24 : ),
25 54 : None => (host_port.as_ref(), None), // No colons, no port specified
26 : };
27 56 : let host = Host::parse(host).context("Unable to parse host")?;
28 55 : Ok((host, port))
29 5 : }
30 :
#[cfg(test)]
mod tests_parse_host_port {
    use std::net::Ipv6Addr;

    use url::Host;

    use crate::parse_host_port;

    /// Shorthand for the `Host::Domain` value expected for `name`.
    fn domain(name: &str) -> Host {
        Host::Domain(name.to_owned())
    }

    #[test]
    fn test_normal() {
        let parsed = parse_host_port("hello:123").unwrap();
        assert_eq!(parsed, (domain("hello"), Some(123)));
    }

    #[test]
    fn test_no_port() {
        let parsed = parse_host_port("hello").unwrap();
        assert_eq!(parsed, (domain("hello"), None));
    }

    #[test]
    fn test_ipv6() {
        let parsed = parse_host_port("[::1]:123").unwrap();
        assert_eq!(parsed, (Host::<String>::Ipv6(Ipv6Addr::LOCALHOST), Some(123)));
    }

    #[test]
    fn test_invalid_host() {
        // A space is not allowed anywhere in a host.
        assert!(matches!(parse_host_port("hello world"), Err(_)));
    }

    #[test]
    fn test_invalid_port() {
        // `+80` parses as a u16 but is rejected because it is not digits-only.
        assert!(matches!(parse_host_port("hello:+80"), Err(_)));
    }
}
68 :
/// Connection parameters for a PostgreSQL server: host, port, an optional password and
/// extra server command-line options.
///
/// `Debug` is implemented by hand so that the password is always redacted.
#[derive(Clone)]
pub struct PgConnectionConfig {
    // Server host as a `url::Host` (domain name or IP address).
    host: Host,
    // Server TCP port.
    port: u16,
    // Optional password; never exposed through a getter, `Display` or `Debug`.
    password: Option<String>,
    // Unescaped server command-line options; escaped and joined with spaces when building
    // the `options` connection parameter in `to_tokio_postgres_config`.
    options: Vec<String>,
}
76 :
77 : /// A simplified PostgreSQL connection configuration. Supports only a subset of possible
78 : /// settings for simplicity. A password getter or `to_connection_string` methods are not
79 : /// added by design to avoid accidentally leaking password through logging, command line
80 : /// arguments to a child process, or likewise.
81 : impl PgConnectionConfig {
82 56 : pub fn new_host_port(host: Host, port: u16) -> Self {
83 56 : PgConnectionConfig {
84 56 : host,
85 56 : port,
86 56 : password: None,
87 56 : options: vec![],
88 56 : }
89 56 : }
90 :
91 33 : pub fn host(&self) -> &Host {
92 33 : &self.host
93 33 : }
94 :
95 8 : pub fn port(&self) -> u16 {
96 8 : self.port
97 8 : }
98 :
99 0 : pub fn set_host(mut self, h: Host) -> Self {
100 0 : self.host = h;
101 0 : self
102 0 : }
103 :
104 0 : pub fn set_port(mut self, p: u16) -> Self {
105 0 : self.port = p;
106 0 : self
107 0 : }
108 :
109 53 : pub fn set_password(mut self, s: Option<String>) -> Self {
110 53 : self.password = s;
111 53 : self
112 53 : }
113 :
114 61 : pub fn extend_options<I: IntoIterator<Item = S>, S: Into<String>>(mut self, i: I) -> Self {
115 220 : self.options.extend(i.into_iter().map(|s| s.into()));
116 61 : self
117 61 : }
118 :
119 : /// Return a `<host>:<port>` string.
120 4 : pub fn raw_address(&self) -> String {
121 4 : format!("{}:{}", self.host(), self.port())
122 4 : }
123 :
124 : /// Build a client library-specific connection configuration.
125 : /// Used for testing and when we need to add some obscure configuration
126 : /// elements at the last moment.
127 1 : pub fn to_tokio_postgres_config(&self) -> tokio_postgres::Config {
128 1 : // Use `tokio_postgres::Config` instead of `postgres::Config` because
129 1 : // the former supports more options to fiddle with later.
130 1 : let mut config = tokio_postgres::Config::new();
131 1 : config.host(&self.host().to_string()).port(self.port);
132 1 : if let Some(password) = &self.password {
133 0 : config.password(password);
134 1 : }
135 1 : if !self.options.is_empty() {
136 1 : // These options are command-line options and should be escaped before being passed
137 1 : // as an 'options' connection string parameter, see
138 1 : // https://www.postgresql.org/docs/15/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
139 1 : //
140 1 : // They will be space-separated, so each space inside an option should be escaped,
141 1 : // and all backslashes should be escaped before that. Although we don't expect options
142 1 : // with spaces at the moment, they're supported by PostgreSQL. Hence we support them
143 1 : // in this typesafe interface.
144 1 : //
145 1 : // We use `Cow` to avoid allocations in the best case (no escaping). A fully imperative
146 1 : // solution would require 1-2 allocations in the worst case as well, but it's harder to
147 1 : // implement and this function is hardly a bottleneck. The function is only called around
148 1 : // establishing a new connection.
149 1 : #[allow(unstable_name_collisions)]
150 1 : config.options(
151 1 : &self
152 1 : .options
153 1 : .iter()
154 4 : .map(|s| {
155 4 : if s.contains(['\\', ' ']) {
156 2 : Cow::Owned(s.replace('\\', "\\\\").replace(' ', "\\ "))
157 : } else {
158 2 : Cow::Borrowed(s.as_str())
159 : }
160 4 : })
161 1 : .intersperse(Cow::Borrowed(" ")) // TODO: use impl from std once it's stabilized
162 1 : .collect::<String>(),
163 1 : );
164 1 : }
165 1 : config
166 1 : }
167 :
168 : /// Connect using postgres protocol with TLS disabled.
169 0 : pub async fn connect_no_tls(
170 0 : &self,
171 0 : ) -> Result<
172 0 : (
173 0 : tokio_postgres::Client,
174 0 : tokio_postgres::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>,
175 0 : ),
176 0 : tokio_postgres::Error,
177 0 : > {
178 0 : self.to_tokio_postgres_config()
179 0 : .connect(tokio_postgres::NoTls)
180 0 : .await
181 0 : }
182 : }
183 :
184 : impl fmt::Display for PgConnectionConfig {
185 0 : fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186 0 : // The password is intentionally hidden and not part of this display string.
187 0 : write!(f, "postgresql://{}:{}", self.host, self.port)
188 0 : }
189 : }
190 :
191 : impl fmt::Debug for PgConnectionConfig {
192 3 : fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193 3 : // We want `password: Some(REDACTED-STRING)`, not `password: Some("REDACTED-STRING")`
194 3 : // so even if the password is `REDACTED-STRING` (quite unlikely) there is no confusion.
195 3 : // Hence `format_args!()`, it returns a "safe" string which is not escaped by `Debug`.
196 3 : f.debug_struct("PgConnectionConfig")
197 3 : .field("host", &self.host)
198 3 : .field("port", &self.port)
199 3 : .field(
200 3 : "password",
201 3 : &self
202 3 : .password
203 3 : .as_ref()
204 3 : .map(|_| format_args!("REDACTED-STRING")),
205 3 : )
206 3 : .finish()
207 3 : }
208 : }
209 :
#[cfg(test)]
mod tests_pg_connection_config {
    use std::net::Ipv6Addr;

    use once_cell::sync::Lazy;
    use url::Host;

    use crate::PgConnectionConfig;

    static STUB_HOST: Lazy<Host> = Lazy::new(|| Host::Domain("stub.host.example".to_owned()));

    /// Checks the host, port and raw address of a config built from `STUB_HOST` and port 123.
    fn assert_stub_address(config: &PgConnectionConfig) {
        assert_eq!(config.host(), &*STUB_HOST);
        assert_eq!(config.port(), 123);
        assert_eq!(config.raw_address(), "stub.host.example:123");
    }

    #[test]
    fn test_no_password() {
        let config = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123);
        assert_stub_address(&config);
        assert_eq!(
            format!("{config:?}"),
            "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: None }"
        );
    }

    #[test]
    fn test_ipv6() {
        // May be a special case because hostname contains a colon.
        let ipv6_host = Host::parse("[::1]").unwrap();
        let config = PgConnectionConfig::new_host_port(ipv6_host, 123);
        assert_eq!(config.host(), &Host::<String>::Ipv6(Ipv6Addr::LOCALHOST));
        assert_eq!(config.port(), 123);
        assert_eq!(config.raw_address(), "[::1]:123");
        assert_eq!(
            format!("{config:?}"),
            "PgConnectionConfig { host: Ipv6(::1), port: 123, password: None }"
        );
    }

    #[test]
    fn test_with_password() {
        let password = Some("password".to_owned());
        let config = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123).set_password(password);
        assert_stub_address(&config);
        assert_eq!(
            format!("{config:?}"),
            "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: Some(REDACTED-STRING) }"
        );
    }

    #[test]
    fn test_with_options() {
        let options = ["hello", "world", "with space", "and \\ backslashes"];
        let config = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123).extend_options(options);
        assert_stub_address(&config);
        let tokio_config = config.to_tokio_postgres_config();
        assert_eq!(
            tokio_config.get_options(),
            Some("hello world with\\ space and\\ \\\\\\ backslashes")
        );
    }
}
|