Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use std::collections::HashSet;
4 : use std::net::IpAddr;
5 : use std::str::FromStr;
6 :
7 : use itertools::Itertools;
8 : use thiserror::Error;
9 : use tracing::{debug, warn};
10 :
11 : use crate::auth::password_hack::parse_endpoint_param;
12 : use crate::context::RequestContext;
13 : use crate::error::{ReportableError, UserFacingError};
14 : use crate::metrics::{Metrics, SniGroup, SniKind};
15 : use crate::pqproto::StartupMessageParams;
16 : use crate::proxy::NeonOptions;
17 : use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
18 : use crate::types::{EndpointId, RoleName};
19 :
20 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
21 : pub(crate) enum ComputeUserInfoParseError {
22 : #[error("Parameter '{0}' is missing in startup packet.")]
23 : MissingKey(&'static str),
24 :
25 : #[error(
26 : "Inconsistent project name inferred from \
27 : SNI ('{}') and project option ('{}').",
28 : .domain, .option,
29 : )]
30 : InconsistentProjectNames {
31 : domain: EndpointId,
32 : option: EndpointId,
33 : },
34 :
35 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
36 : MalformedProjectName(EndpointId),
37 : }
38 :
39 : impl UserFacingError for ComputeUserInfoParseError {}
40 :
41 : impl ReportableError for ComputeUserInfoParseError {
42 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
43 0 : crate::error::ErrorKind::User
44 0 : }
45 : }
46 :
47 : /// Various client credentials which we use for authentication.
48 : /// Note that we don't store any kind of client key or password here.
49 : #[derive(Debug, Clone, PartialEq, Eq)]
50 : pub(crate) struct ComputeUserInfoMaybeEndpoint {
51 : pub(crate) user: RoleName,
52 : pub(crate) endpoint_id: Option<EndpointId>,
53 : pub(crate) options: NeonOptions,
54 : }
55 :
56 : impl ComputeUserInfoMaybeEndpoint {
57 : #[inline]
58 0 : pub(crate) fn endpoint(&self) -> Option<&str> {
59 0 : self.endpoint_id.as_deref()
60 0 : }
61 : }
62 :
63 28 : pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
64 28 : let (subdomain, common_name) = sni.split_once('.')?;
65 28 : if !common_names.contains(common_name) {
66 2 : return None;
67 26 : }
68 26 : if subdomain == SERVERLESS_DRIVER_SNI || subdomain == AUTH_BROKER_SNI {
69 0 : return None;
70 26 : }
71 26 : Some(EndpointId::from(subdomain))
72 28 : }
73 :
74 : impl ComputeUserInfoMaybeEndpoint {
75 14 : pub(crate) fn parse(
76 14 : ctx: &RequestContext,
77 14 : params: &StartupMessageParams,
78 14 : sni: Option<&str>,
79 14 : common_names: Option<&HashSet<String>>,
80 14 : ) -> Result<Self, ComputeUserInfoParseError> {
81 : // Some parameters are stored in the startup message.
82 14 : let get_param = |key| {
83 14 : params
84 14 : .get(key)
85 14 : .ok_or(ComputeUserInfoParseError::MissingKey(key))
86 14 : };
87 14 : let user: RoleName = get_param("user")?.into();
88 :
89 : // Project name might be passed via PG's command-line options.
90 14 : let endpoint_option = params
91 14 : .options_raw()
92 14 : .and_then(|options| {
93 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
94 : // However, if both are present, we don't exactly know which one to use.
95 : // Therefore we require that only one of them is present.
96 8 : options
97 8 : .filter_map(parse_endpoint_param)
98 8 : .at_most_one()
99 8 : .ok()?
100 8 : })
101 14 : .map(|name| name.into());
102 :
103 14 : let endpoint_from_domain =
104 14 : sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
105 :
106 14 : let endpoint = match (endpoint_option, endpoint_from_domain) {
107 : // Invariant: if we have both project name variants, they should match.
108 2 : (Some(option), Some(domain)) if option != domain => {
109 1 : Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
110 1 : domain,
111 1 : option,
112 1 : }))
113 : }
114 : // Invariant: project name may not contain certain characters.
115 13 : (a, b) => a.or(b).map(|name| {
116 8 : if project_name_valid(name.as_ref()) {
117 8 : Ok(name)
118 : } else {
119 0 : Err(ComputeUserInfoParseError::MalformedProjectName(name))
120 : }
121 8 : }),
122 : }
123 14 : .transpose()?;
124 :
125 13 : if let Some(ep) = &endpoint {
126 8 : ctx.set_endpoint_id(ep.clone());
127 8 : }
128 :
129 13 : let metrics = Metrics::get();
130 13 : debug!(%user, "credentials");
131 :
132 13 : let protocol = ctx.protocol();
133 13 : let kind = if sni.is_some() {
134 7 : debug!("Connection with sni");
135 7 : SniKind::Sni
136 6 : } else if endpoint.is_some() {
137 2 : debug!("Connection without sni");
138 2 : SniKind::NoSni
139 : } else {
140 4 : debug!("Connection with password hack");
141 4 : SniKind::PasswordHack
142 : };
143 :
144 13 : metrics
145 13 : .proxy
146 13 : .accepted_connections_by_sni
147 13 : .inc(SniGroup { protocol, kind });
148 :
149 13 : let options = NeonOptions::parse_params(params);
150 :
151 13 : Ok(Self {
152 13 : user,
153 13 : endpoint_id: endpoint,
154 13 : options,
155 13 : })
156 14 : }
157 : }
158 :
159 10 : pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
160 10 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
161 10 : }
162 :
163 : #[derive(Debug, Clone, Eq, PartialEq)]
164 : pub(crate) enum IpPattern {
165 : Subnet(ipnet::IpNet),
166 : Range(IpAddr, IpAddr),
167 : Single(IpAddr),
168 : None,
169 : }
170 :
171 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
172 10 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 10 : where
174 10 : D: serde::Deserializer<'de>,
175 : {
176 : struct StrVisitor;
177 : impl serde::de::Visitor<'_> for StrVisitor {
178 : type Value = IpPattern;
179 :
180 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 0 : write!(
182 0 : formatter,
183 0 : "comma separated list with ip address, ip address range, or ip address subnet mask"
184 : )
185 0 : }
186 :
187 10 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
188 10 : where
189 10 : E: serde::de::Error,
190 : {
191 10 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
192 1 : warn!("Cannot parse ip pattern {v}: {e}");
193 1 : IpPattern::None
194 1 : }))
195 10 : }
196 : }
197 10 : deserializer.deserialize_str(StrVisitor)
198 10 : }
199 : }
200 :
201 : impl FromStr for IpPattern {
202 : type Err = anyhow::Error;
203 :
204 2 : fn from_str(s: &str) -> Result<Self, Self::Err> {
205 2 : parse_ip_pattern(s)
206 2 : }
207 : }
208 :
209 20 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
210 20 : if pattern.contains('/') {
211 2 : let subnet: ipnet::IpNet = pattern.parse()?;
212 1 : return Ok(IpPattern::Subnet(subnet));
213 18 : }
214 18 : if let Some((start, end)) = pattern.split_once('-') {
215 3 : let start: IpAddr = start.parse()?;
216 2 : let end: IpAddr = end.parse()?;
217 1 : return Ok(IpPattern::Range(start, end));
218 15 : }
219 15 : let addr: IpAddr = pattern.parse()?;
220 12 : Ok(IpPattern::Single(addr))
221 20 : }
222 :
223 16 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
224 16 : match pattern {
225 3 : IpPattern::Subnet(subnet) => subnet.contains(ip),
226 5 : IpPattern::Range(start, end) => start <= ip && ip <= end,
227 7 : IpPattern::Single(addr) => addr == ip,
228 1 : IpPattern::None => false,
229 : }
230 16 : }
231 :
/// Reports whether `name` is acceptable as an endpoint (project) identifier.
///
/// The name must be non-empty and consist solely of alphanumeric characters
/// (as defined by [`char::is_alphanumeric`]) and hyphens.
fn project_name_valid(name: &str) -> bool {
    // Without the emptiness check an empty name passes vacuously (`all` on an
    // empty iterator is `true`). Such a name arises, for example, from an SNI
    // of `.<common name>`.
    !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
235 :
#[cfg(test)]
mod tests {
    use ComputeUserInfoParseError::*;
    use serde_json::json;

    use super::*;

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let options = StartupMessageParams::new([("user", "john_doe")]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);
        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");

        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        let ctx = RequestContext::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let sni = Some("baz.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));

        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.a.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.b.com");
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let options =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let sni = Some("second.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_malformed_project_name() {
        // `_` is neither alphanumeric nor a hyphen, so the endpoint derived
        // from this SNI must be rejected.
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo_bar.localhost");
        let common_names = Some(["localhost".into()].into());

        let ctx = RequestContext::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        assert_eq!(err, MalformedProjectName("foo_bar".into()));
    }

    #[test]
    fn parse_unknown_sni() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert!(info.endpoint_id.is_none());
    }

    #[test]
    fn parse_unknown_sni_with_options() {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "endpoint=foo-bar-baz-1234"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
    }

    #[test]
    fn endpoint_sni_skips_reserved_subdomains() {
        let common_names: HashSet<String> = ["localhost".into()].into();

        let driver = format!("{SERVERLESS_DRIVER_SNI}.localhost");
        assert!(endpoint_sni(&driver, &common_names).is_none());

        let broker = format!("{AUTH_BROKER_SNI}.localhost");
        assert!(endpoint_sni(&broker, &common_names).is_none());

        assert_eq!(
            endpoint_sni("foo.localhost", &common_names).as_deref(),
            Some("foo")
        );
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["localhost".into()].into());
        let ctx = RequestContext::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );

        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["8.8.8.8"])));
        // If there is an incorrect address, it will be skipped.
        assert!(check(json!(["88.8.8", "127.0.0.1"])));
    }

    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error
        assert!(parse_ip_pattern("300.0.1.2").is_err());
        assert!(parse_ip_pattern("30.1.2").is_err());
        assert!(parse_ip_pattern("127.0.0.1/33").is_err());
        assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
        assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr)
        ));

        // Not success
        assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
        ));
        // Ranges are not validated to have start <= end. That is harmless,
        // because a reversed range simply matches no address.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr_prev)
        ));
        Ok(())
    }

    #[test]
    fn test_connection_blocker() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["255.255.255.255"])));
    }
}
|