TLA Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use crate::{
4 : auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
5 : metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
6 : };
7 : use itertools::Itertools;
8 : use pq_proto::StartupMessageParams;
9 : use smol_str::SmolStr;
10 : use std::{collections::HashSet, net::IpAddr};
11 : use thiserror::Error;
12 : use tracing::{info, warn};
13 :
14 UBC 0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
15 : pub enum ClientCredsParseError {
16 : #[error("Parameter '{0}' is missing in startup packet.")]
17 : MissingKey(&'static str),
18 :
19 : #[error(
20 : "Inconsistent project name inferred from \
21 : SNI ('{}') and project option ('{}').",
22 : .domain, .option,
23 : )]
24 : InconsistentProjectNames { domain: SmolStr, option: SmolStr },
25 :
26 : #[error(
27 : "Common name inferred from SNI ('{}') is not known",
28 : .cn,
29 : )]
30 : UnknownCommonName { cn: String },
31 :
32 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
33 : MalformedProjectName(SmolStr),
34 : }
35 :
36 : impl UserFacingError for ClientCredsParseError {}
37 :
38 : /// Various client credentials which we use for authentication.
39 : /// Note that we don't store any kind of client key or password here.
40 0 : #[derive(Debug, Clone, PartialEq, Eq)]
41 : pub struct ClientCredentials {
42 : pub user: SmolStr,
43 : // TODO: this is a severe misnomer! We should think of a new name ASAP.
44 : pub project: Option<SmolStr>,
45 :
46 : pub cache_key: SmolStr,
47 : }
48 :
49 : impl ClientCredentials {
50 : #[inline]
51 CBC 46 : pub fn project(&self) -> Option<&str> {
52 46 : self.project.as_deref()
53 46 : }
54 : }
55 :
56 : impl ClientCredentials {
57 100 : pub fn parse(
58 100 : ctx: &mut RequestMonitoring,
59 100 : params: &StartupMessageParams,
60 100 : sni: Option<&str>,
61 100 : common_names: Option<HashSet<String>>,
62 100 : ) -> Result<Self, ClientCredsParseError> {
63 100 : use ClientCredsParseError::*;
64 100 :
65 100 : // Some parameters are stored in the startup message.
66 100 : let get_param = |key| params.get(key).ok_or(MissingKey(key));
67 100 : let user: SmolStr = get_param("user")?.into();
68 100 :
69 100 : // record the values if we have them
70 100 : ctx.set_application(params.get("application_name").map(SmolStr::from));
71 100 : ctx.set_user(user.clone());
72 100 : ctx.set_endpoint_id(sni.map(SmolStr::from));
73 100 :
74 100 : // Project name might be passed via PG's command-line options.
75 100 : let project_option = params
76 100 : .options_raw()
77 100 : .and_then(|options| {
78 94 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
79 94 : // However, if both are present, we don't exactly know which one to use.
80 94 : // Therefore we require that only one of them is present.
81 94 : options
82 94 : .filter_map(parse_endpoint_param)
83 94 : .at_most_one()
84 94 : .ok()?
85 100 : })
86 100 : .map(|name| name.into());
87 :
88 100 : let project_from_domain = if let Some(sni_str) = sni {
89 77 : if let Some(cn) = common_names {
90 77 : let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
91 :
92 77 : let project = common_name_from_sni
93 77 : .and_then(|domain| {
94 77 : if cn.contains(domain) {
95 76 : subdomain_from_sni(sni_str, domain)
96 : } else {
97 1 : None
98 : }
99 77 : })
100 77 : .ok_or_else(|| UnknownCommonName {
101 1 : cn: common_name_from_sni.unwrap_or("").into(),
102 77 : })?;
103 :
104 76 : Some(project)
105 : } else {
106 UBC 0 : None
107 : }
108 : } else {
109 CBC 23 : None
110 : };
111 :
112 99 : let project = match (project_option, project_from_domain) {
113 : // Invariant: if we have both project name variants, they should match.
114 2 : (Some(option), Some(domain)) if option != domain => {
115 1 : Some(Err(InconsistentProjectNames { domain, option }))
116 : }
117 : // Invariant: project name may not contain certain characters.
118 98 : (a, b) => a.or(b).map(|name| match project_name_valid(&name) {
119 UBC 0 : false => Err(MalformedProjectName(name)),
120 CBC 91 : true => Ok(name),
121 98 : }),
122 : }
123 99 : .transpose()?;
124 :
125 98 : info!(%user, project = project.as_deref(), "credentials");
126 98 : if sni.is_some() {
127 75 : info!("Connection with sni");
128 75 : NUM_CONNECTION_ACCEPTED_BY_SNI
129 75 : .with_label_values(&["sni"])
130 75 : .inc();
131 23 : } else if project.is_some() {
132 16 : NUM_CONNECTION_ACCEPTED_BY_SNI
133 16 : .with_label_values(&["no_sni"])
134 16 : .inc();
135 16 : info!("Connection without sni");
136 : } else {
137 7 : NUM_CONNECTION_ACCEPTED_BY_SNI
138 7 : .with_label_values(&["password_hack"])
139 7 : .inc();
140 7 : info!("Connection with password hack");
141 : }
142 :
143 98 : let cache_key = format!(
144 98 : "{}{}",
145 98 : project.as_deref().unwrap_or(""),
146 98 : neon_options_str(params)
147 98 : )
148 98 : .into();
149 98 :
150 98 : Ok(Self {
151 98 : user,
152 98 : project,
153 98 : cache_key,
154 98 : })
155 100 : }
156 : }
157 :
158 86 : pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
159 86 : if ip_list.is_empty() {
160 75 : return true;
161 11 : }
162 21 : for ip in ip_list {
163 : // We expect that all ip addresses from control plane are correct.
164 : // However, if some of them are broken, we still can check the others.
165 16 : match parse_ip_pattern(ip) {
166 15 : Ok(pattern) => {
167 15 : if check_ip(peer_addr, &pattern) {
168 6 : return true;
169 9 : }
170 : }
171 1 : Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
172 : }
173 : }
174 5 : false
175 86 : }
176 :
177 3 : #[derive(Debug, Clone, Eq, PartialEq)]
178 : enum IpPattern {
179 : Subnet(ipnet::IpNet),
180 : Range(IpAddr, IpAddr),
181 : Single(IpAddr),
182 : }
183 :
184 24 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
185 24 : if pattern.contains('/') {
186 2 : let subnet: ipnet::IpNet = pattern.parse()?;
187 1 : return Ok(IpPattern::Subnet(subnet));
188 22 : }
189 22 : if let Some((start, end)) = pattern.split_once('-') {
190 3 : let start: IpAddr = start.parse()?;
191 2 : let end: IpAddr = end.parse()?;
192 1 : return Ok(IpPattern::Range(start, end));
193 19 : }
194 19 : let addr: IpAddr = pattern.parse()?;
195 16 : Ok(IpPattern::Single(addr))
196 24 : }
197 :
198 25 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
199 25 : match pattern {
200 3 : IpPattern::Subnet(subnet) => subnet.contains(ip),
201 5 : IpPattern::Range(start, end) => start <= ip && ip <= end,
202 17 : IpPattern::Single(addr) => addr == ip,
203 : }
204 25 : }
205 :
/// Checks that `name` consists solely of alphanumeric characters and hyphens.
///
/// "Alphanumeric" follows `char::is_alphanumeric`, so non-ASCII letters and
/// digits are accepted too. The empty string is considered valid.
fn project_name_valid(name: &str) -> bool {
    let is_forbidden = |ch: char| ch != '-' && !ch.is_alphanumeric();
    !name.chars().any(is_forbidden)
}
209 :
210 76 : fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
211 76 : sni.strip_suffix(common_name)?
212 76 : .strip_suffix('.')
213 76 : .map(SmolStr::from)
214 76 : }
215 :
#[cfg(test)]
mod tests {
    use super::*;
    use ClientCredsParseError::*;

    /// Calls `ClientCredentials::parse` with a context obtained from
    /// `RequestMonitoring::test()`.
    fn parse_creds(
        params: &StartupMessageParams,
        sni: Option<&str>,
        common_names: Option<HashSet<String>>,
    ) -> Result<ClientCredentials, ClientCredsParseError> {
        let mut ctx = RequestMonitoring::test();
        ClientCredentials::parse(&mut ctx, params, sni, common_names)
    }

    /// Builds the set of accepted common names from string literals.
    fn known_names(names: &[&str]) -> Option<HashSet<String>> {
        Some(names.iter().map(|name| name.to_string()).collect())
    }

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let params = StartupMessageParams::new([("user", "john_doe")]);

        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project, None);

        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);

        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);

        let creds = parse_creds(&params, Some("foo.localhost"), known_names(&["localhost"]))?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project(), Some("foo"));
        assert_eq!(creds.cache_key, "foo");

        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        // Several endpoint options are ambiguous, so none of them is used.
        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project, None);

        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        // Mixing `endpoint` and `project` is ambiguous as well.
        let creds = parse_creds(&params, None, None)?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project, None);

        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let params =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let creds = parse_creds(&params, Some("baz.localhost"), known_names(&["localhost"]))?;
        assert_eq!(creds.user, "john_doe");
        assert_eq!(creds.project(), Some("baz"));

        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);

        for sni in ["p1.a.com", "p1.b.com"] {
            let creds = parse_creds(&params, Some(sni), known_names(&["a.com", "b.com"]))?;
            assert_eq!(creds.project(), Some("p1"));
        }

        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let params =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let err = parse_creds(
            &params,
            Some("second.localhost"),
            known_names(&["localhost"]),
        )
        .expect_err("should fail");

        assert_eq!(
            err,
            InconsistentProjectNames {
                domain: "second".into(),
                option: "first".into(),
            }
        );
    }

    #[test]
    fn parse_inconsistent_sni() {
        let params = StartupMessageParams::new([("user", "john_doe")]);

        let err = parse_creds(
            &params,
            Some("project.localhost"),
            known_names(&["example.com"]),
        )
        .expect_err("should fail");

        assert_eq!(
            err,
            UnknownCommonName {
                cn: "localhost".into(),
            }
        );
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let creds = parse_creds(
            &params,
            Some("project.localhost"),
            known_names(&["localhost"]),
        )?;
        assert_eq!(creds.project(), Some("project"));
        assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");

        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);

        let cases: [(Vec<String>, bool); 4] = [
            (vec![], true),
            (vec!["127.0.0.1".into()], true),
            (vec!["8.8.8.8".into()], false),
            // If there is an incorrect address, it will be skipped.
            (vec!["88.8.8".into(), "127.0.0.1".into()], true),
        ];

        for (ip_list, expected) in &cases {
            assert_eq!(
                check_peer_addr_is_in_list(&peer_addr, ip_list),
                *expected,
                "ip list: {ip_list:?}"
            );
        }
    }

    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);

        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error
        let malformed = [
            "300.0.1.2",
            "30.1.2",
            "127.0.0.1/33",
            "127.0.0.1-127.0.3",
            "1234.0.0.1-127.0.3.0",
        ];
        for text in malformed {
            assert!(parse_ip_pattern(text).is_err(), "pattern: {text:?}");
        }

        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        let zero_addr = IpAddr::from([0, 0, 0, 0]);
        let subnet = |addr, prefix_len| ipnet::IpNet::new(addr, prefix_len).map(IpPattern::Subnet);

        // Success
        let matching = [
            IpPattern::Single(peer_addr),
            subnet(peer_addr_prev, 31)?,
            subnet(peer_addr_next, 30)?,
            IpPattern::Range(zero_addr, IpAddr::from([200, 0, 1, 2])),
            IpPattern::Range(peer_addr, peer_addr),
        ];
        for pattern in &matching {
            assert!(check_ip(&peer_addr, pattern), "{pattern:?} should match");
        }

        // Not success
        let not_matching = [
            IpPattern::Single(peer_addr_prev),
            subnet(peer_addr_next, 31)?,
            IpPattern::Range(zero_addr, peer_addr_prev),
            IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0])),
            // Nothing checks that a range has start <= end, which is fine as
            // long as such ranges never match anything.
            IpPattern::Range(peer_addr, peer_addr_prev),
        ];
        for pattern in &not_matching {
            assert!(!check_ip(&peer_addr, pattern), "{pattern:?} should not match");
        }

        Ok(())
    }
}
|