Line data Source code
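//! User credentials used in authentication: parsing of the role and endpoint from the
//! startup packet and SNI, plus IP allowlist matching of the peer address.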
2 :
3 : use std::collections::HashSet;
4 : use std::net::IpAddr;
5 : use std::str::FromStr;
6 :
7 : use itertools::Itertools;
8 : use pq_proto::StartupMessageParams;
9 : use thiserror::Error;
10 : use tracing::{debug, warn};
11 :
12 : use crate::auth::password_hack::parse_endpoint_param;
13 : use crate::context::RequestContext;
14 : use crate::error::{ReportableError, UserFacingError};
15 : use crate::metrics::{Metrics, SniKind};
16 : use crate::proxy::NeonOptions;
17 : use crate::serverless::SERVERLESS_DRIVER_SNI;
18 : use crate::types::{EndpointId, RoleName};
19 :
/// Errors produced while extracting user info (role and endpoint) from a client connection.
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
    /// A required startup-packet parameter (named by the payload) was absent.
    #[error("Parameter '{0}' is missing in startup packet.")]
    MissingKey(&'static str),

    /// The endpoint named by the SNI subdomain differs from the one named in the
    /// `options` startup parameter.
    #[error(
        "Inconsistent project name inferred from \
         SNI ('{}') and project option ('{}').",
        .domain, .option,
    )]
    InconsistentProjectNames {
        /// Endpoint taken from the SNI subdomain.
        domain: EndpointId,
        /// Endpoint taken from the `options` parameter.
        option: EndpointId,
    },

    /// The SNI hostname does not end in a configured common name. `cn` is either the part
    /// after the first `.` or, when the hostname has no `.`, the whole hostname.
    #[error(
        "Common name inferred from SNI ('{}') is not known",
        .cn,
    )]
    UnknownCommonName { cn: String },

    /// The chosen endpoint name was rejected by `project_name_valid`.
    #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
    MalformedProjectName(EndpointId),
}
44 :
// Relies entirely on the trait's default methods.
// NOTE(review): presumably this means the `Display` message is what the client sees —
// confirm against `crate::error::UserFacingError`.
impl UserFacingError for ComputeUserInfoParseError {}
46 :
47 : impl ReportableError for ComputeUserInfoParseError {
48 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
49 0 : crate::error::ErrorKind::User
50 0 : }
51 : }
52 :
53 : /// Various client credentials which we use for authentication.
54 : /// Note that we don't store any kind of client key or password here.
55 : #[derive(Debug, Clone, PartialEq, Eq)]
56 : pub(crate) struct ComputeUserInfoMaybeEndpoint {
57 : pub(crate) user: RoleName,
58 : pub(crate) endpoint_id: Option<EndpointId>,
59 : pub(crate) options: NeonOptions,
60 : }
61 :
62 : impl ComputeUserInfoMaybeEndpoint {
63 : #[inline]
64 0 : pub(crate) fn endpoint(&self) -> Option<&str> {
65 0 : self.endpoint_id.as_deref()
66 0 : }
67 : }
68 :
69 27 : pub(crate) fn endpoint_sni(
70 27 : sni: &str,
71 27 : common_names: &HashSet<String>,
72 27 : ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
73 27 : let Some((subdomain, common_name)) = sni.split_once('.') else {
74 0 : return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
75 : };
76 27 : if !common_names.contains(common_name) {
77 1 : return Err(ComputeUserInfoParseError::UnknownCommonName {
78 1 : cn: common_name.into(),
79 1 : });
80 26 : }
81 26 : if subdomain == SERVERLESS_DRIVER_SNI {
82 0 : return Ok(None);
83 26 : }
84 26 : Ok(Some(EndpointId::from(subdomain)))
85 27 : }
86 :
87 : impl ComputeUserInfoMaybeEndpoint {
88 13 : pub(crate) fn parse(
89 13 : ctx: &RequestContext,
90 13 : params: &StartupMessageParams,
91 13 : sni: Option<&str>,
92 13 : common_names: Option<&HashSet<String>>,
93 13 : ) -> Result<Self, ComputeUserInfoParseError> {
94 13 : // Some parameters are stored in the startup message.
95 13 : let get_param = |key| {
96 13 : params
97 13 : .get(key)
98 13 : .ok_or(ComputeUserInfoParseError::MissingKey(key))
99 13 : };
100 13 : let user: RoleName = get_param("user")?.into();
101 13 :
102 13 : // Project name might be passed via PG's command-line options.
103 13 : let endpoint_option = params
104 13 : .options_raw()
105 13 : .and_then(|options| {
106 7 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
107 7 : // However, if both are present, we don't exactly know which one to use.
108 7 : // Therefore we require that only one of them is present.
109 7 : options
110 7 : .filter_map(parse_endpoint_param)
111 7 : .at_most_one()
112 7 : .ok()?
113 13 : })
114 13 : .map(|name| name.into());
115 :
116 13 : let endpoint_from_domain = if let Some(sni_str) = sni {
117 7 : if let Some(cn) = common_names {
118 7 : endpoint_sni(sni_str, cn)?
119 : } else {
120 0 : None
121 : }
122 : } else {
123 6 : None
124 : };
125 :
126 12 : let endpoint = match (endpoint_option, endpoint_from_domain) {
127 : // Invariant: if we have both project name variants, they should match.
128 2 : (Some(option), Some(domain)) if option != domain => {
129 1 : Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
130 1 : domain,
131 1 : option,
132 1 : }))
133 : }
134 : // Invariant: project name may not contain certain characters.
135 11 : (a, b) => a.or(b).map(|name| {
136 7 : if project_name_valid(name.as_ref()) {
137 7 : Ok(name)
138 : } else {
139 0 : Err(ComputeUserInfoParseError::MalformedProjectName(name))
140 : }
141 11 : }),
142 : }
143 12 : .transpose()?;
144 :
145 11 : if let Some(ep) = &endpoint {
146 7 : ctx.set_endpoint_id(ep.clone());
147 7 : }
148 :
149 11 : let metrics = Metrics::get();
150 11 : debug!(%user, "credentials");
151 11 : if sni.is_some() {
152 5 : debug!("Connection with sni");
153 5 : metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
154 6 : } else if endpoint.is_some() {
155 2 : metrics
156 2 : .proxy
157 2 : .accepted_connections_by_sni
158 2 : .inc(SniKind::NoSni);
159 2 : debug!("Connection without sni");
160 : } else {
161 4 : metrics
162 4 : .proxy
163 4 : .accepted_connections_by_sni
164 4 : .inc(SniKind::PasswordHack);
165 4 : debug!("Connection with password hack");
166 : }
167 :
168 11 : let options = NeonOptions::parse_params(params);
169 11 :
170 11 : Ok(Self {
171 11 : user,
172 11 : endpoint_id: endpoint,
173 11 : options,
174 11 : })
175 13 : }
176 : }
177 :
178 10 : pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
179 10 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
180 10 : }
181 :
182 : #[derive(Debug, Clone, Eq, PartialEq)]
183 : pub(crate) enum IpPattern {
184 : Subnet(ipnet::IpNet),
185 : Range(IpAddr, IpAddr),
186 : Single(IpAddr),
187 : None,
188 : }
189 :
190 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
191 10 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
192 10 : where
193 10 : D: serde::Deserializer<'de>,
194 10 : {
195 : struct StrVisitor;
196 : impl serde::de::Visitor<'_> for StrVisitor {
197 : type Value = IpPattern;
198 :
199 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 0 : write!(
201 0 : formatter,
202 0 : "comma separated list with ip address, ip address range, or ip address subnet mask"
203 0 : )
204 0 : }
205 :
206 10 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
207 10 : where
208 10 : E: serde::de::Error,
209 10 : {
210 10 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
211 1 : warn!("Cannot parse ip pattern {v}: {e}");
212 1 : IpPattern::None
213 10 : }))
214 10 : }
215 : }
216 10 : deserializer.deserialize_str(StrVisitor)
217 10 : }
218 : }
219 :
220 : impl FromStr for IpPattern {
221 : type Err = anyhow::Error;
222 :
223 6 : fn from_str(s: &str) -> Result<Self, Self::Err> {
224 6 : parse_ip_pattern(s)
225 6 : }
226 : }
227 :
228 24 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
229 24 : if pattern.contains('/') {
230 2 : let subnet: ipnet::IpNet = pattern.parse()?;
231 1 : return Ok(IpPattern::Subnet(subnet));
232 22 : }
233 22 : if let Some((start, end)) = pattern.split_once('-') {
234 3 : let start: IpAddr = start.parse()?;
235 2 : let end: IpAddr = end.parse()?;
236 1 : return Ok(IpPattern::Range(start, end));
237 19 : }
238 19 : let addr: IpAddr = pattern.parse()?;
239 16 : Ok(IpPattern::Single(addr))
240 24 : }
241 :
242 16 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
243 16 : match pattern {
244 3 : IpPattern::Subnet(subnet) => subnet.contains(ip),
245 5 : IpPattern::Range(start, end) => start <= ip && ip <= end,
246 7 : IpPattern::Single(addr) => addr == ip,
247 1 : IpPattern::None => false,
248 : }
249 16 : }
250 :
/// Returns `true` if `name` is a non-empty string made only of ASCII letters, ASCII digits
/// and `-`.
///
/// `char::is_alphanumeric` would also accept any Unicode letter or digit (e.g. `é`, `²`,
/// Cyrillic look-alikes of Latin letters). Endpoint names come straight from client input
/// (the SNI hostname or the `options` parameter), so only the ASCII subset is allowed.
/// The emptiness check matters because an SNI such as `.example.com` yields an empty
/// subdomain, which `all` would otherwise accept vacuously.
fn project_name_valid(name: &str) -> bool {
    !name.is_empty() && name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
}
254 :
// Unit tests for startup-packet/SNI parsing and for IP allowlist parsing and matching.
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ComputeUserInfoParseError::*;
    use serde_json::json;

    use super::*;

    // Only `user` is present: no endpoint can be inferred.
    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let options = StartupMessageParams::new([("user", "john_doe")]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    // Unrelated startup parameters must not affect the result.
    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    // The endpoint is the SNI subdomain in front of a known common name.
    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");

        Ok(())
    }

    // The deprecated `project=` option still names the endpoint.
    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    // The `endpoint=` option names the endpoint.
    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    // Several `endpoint=` entries are ambiguous, so none of them is used (no error either).
    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    // Mixing `endpoint=` and `project=` is also ambiguous and yields no endpoint.
    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    // Option and SNI naming the same endpoint is accepted.
    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let sni = Some("baz.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));

        Ok(())
    }

    // Any of several configured common names may follow the endpoint label.
    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.a.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.b.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        Ok(())
    }

    // Option and SNI naming different endpoints is rejected with both names reported.
    #[test]
    fn parse_projects_different() {
        let options =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let sni = Some("second.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    // An SNI whose domain is not a configured common name is rejected, reporting that domain.
    #[test]
    fn parse_inconsistent_sni() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => {
                assert_eq!(cn, "localhost");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    // Neon-specific `neon_*` options are parsed and reflected in the cache key.
    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["localhost".into()].into());
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );

        Ok(())
    }

    // Allowlist semantics end to end, deserializing the list from JSON.
    #[test]
    fn test_check_peer_addr_is_in_list() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["8.8.8.8"])));
        // If there is an incorrect address, it will be skipped.
        assert!(check(json!(["88.8.8", "127.0.0.1"])));
    }
    // Parsing of single addresses, subnets and ranges, including malformed inputs.
    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error
        assert!(parse_ip_pattern("300.0.1.2").is_err());
        assert!(parse_ip_pattern("30.1.2").is_err());
        assert!(parse_ip_pattern("127.0.0.1/33").is_err());
        assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
        assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
        Ok(())
    }

    // Matching a single IPv4 address against each pattern kind, including range boundaries.
    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr)
        ));

        // Not success
        assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
        ));
        // Ranges are not checked for start <= end. That is fine, because an inverted range
        // matches no address at all, as asserted here.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr_prev)
        ));
        Ok(())
    }

    // An empty allowlist lets everyone in; a non-matching entry blocks the peer.
    #[test]
    fn test_connection_blocker() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["255.255.255.255"])));
    }
}
|