Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use crate::{
4 : auth::password_hack::parse_endpoint_param,
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::{Metrics, SniKind},
8 : proxy::NeonOptions,
9 : serverless::SERVERLESS_DRIVER_SNI,
10 : EndpointId, RoleName,
11 : };
12 : use itertools::Itertools;
13 : use pq_proto::StartupMessageParams;
14 : use std::{collections::HashSet, net::IpAddr, str::FromStr};
15 : use thiserror::Error;
16 : use tracing::{info, warn};
17 :
18 0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
19 : pub(crate) enum ComputeUserInfoParseError {
20 : #[error("Parameter '{0}' is missing in startup packet.")]
21 : MissingKey(&'static str),
22 :
23 : #[error(
24 : "Inconsistent project name inferred from \
25 : SNI ('{}') and project option ('{}').",
26 : .domain, .option,
27 : )]
28 : InconsistentProjectNames {
29 : domain: EndpointId,
30 : option: EndpointId,
31 : },
32 :
33 : #[error(
34 : "Common name inferred from SNI ('{}') is not known",
35 : .cn,
36 : )]
37 : UnknownCommonName { cn: String },
38 :
39 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
40 : MalformedProjectName(EndpointId),
41 : }
42 :
43 : impl UserFacingError for ComputeUserInfoParseError {}
44 :
45 : impl ReportableError for ComputeUserInfoParseError {
46 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
47 0 : crate::error::ErrorKind::User
48 0 : }
49 : }
50 :
51 : /// Various client credentials which we use for authentication.
52 : /// Note that we don't store any kind of client key or password here.
53 : #[derive(Debug, Clone, PartialEq, Eq)]
54 : pub(crate) struct ComputeUserInfoMaybeEndpoint {
55 : pub(crate) user: RoleName,
56 : pub(crate) endpoint_id: Option<EndpointId>,
57 : pub(crate) options: NeonOptions,
58 : }
59 :
60 : impl ComputeUserInfoMaybeEndpoint {
61 : #[inline]
62 0 : pub(crate) fn endpoint(&self) -> Option<&str> {
63 0 : self.endpoint_id.as_deref()
64 0 : }
65 : }
66 :
67 27 : pub(crate) fn endpoint_sni(
68 27 : sni: &str,
69 27 : common_names: &HashSet<String>,
70 27 : ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
71 27 : let Some((subdomain, common_name)) = sni.split_once('.') else {
72 0 : return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
73 : };
74 27 : if !common_names.contains(common_name) {
75 1 : return Err(ComputeUserInfoParseError::UnknownCommonName {
76 1 : cn: common_name.into(),
77 1 : });
78 26 : }
79 26 : if subdomain == SERVERLESS_DRIVER_SNI {
80 0 : return Ok(None);
81 26 : }
82 26 : Ok(Some(EndpointId::from(subdomain)))
83 27 : }
84 :
85 : impl ComputeUserInfoMaybeEndpoint {
86 13 : pub(crate) fn parse(
87 13 : ctx: &RequestMonitoring,
88 13 : params: &StartupMessageParams,
89 13 : sni: Option<&str>,
90 13 : common_names: Option<&HashSet<String>>,
91 13 : ) -> Result<Self, ComputeUserInfoParseError> {
92 13 : // Some parameters are stored in the startup message.
93 13 : let get_param = |key| {
94 13 : params
95 13 : .get(key)
96 13 : .ok_or(ComputeUserInfoParseError::MissingKey(key))
97 13 : };
98 13 : let user: RoleName = get_param("user")?.into();
99 13 :
100 13 : // Project name might be passed via PG's command-line options.
101 13 : let endpoint_option = params
102 13 : .options_raw()
103 13 : .and_then(|options| {
104 7 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
105 7 : // However, if both are present, we don't exactly know which one to use.
106 7 : // Therefore we require that only one of them is present.
107 7 : options
108 7 : .filter_map(parse_endpoint_param)
109 7 : .at_most_one()
110 7 : .ok()?
111 13 : })
112 13 : .map(|name| name.into());
113 :
114 13 : let endpoint_from_domain = if let Some(sni_str) = sni {
115 7 : if let Some(cn) = common_names {
116 7 : endpoint_sni(sni_str, cn)?
117 : } else {
118 0 : None
119 : }
120 : } else {
121 6 : None
122 : };
123 :
124 12 : let endpoint = match (endpoint_option, endpoint_from_domain) {
125 : // Invariant: if we have both project name variants, they should match.
126 2 : (Some(option), Some(domain)) if option != domain => {
127 1 : Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
128 1 : domain,
129 1 : option,
130 1 : }))
131 : }
132 : // Invariant: project name may not contain certain characters.
133 11 : (a, b) => a.or(b).map(|name| {
134 7 : if project_name_valid(name.as_ref()) {
135 7 : Ok(name)
136 : } else {
137 0 : Err(ComputeUserInfoParseError::MalformedProjectName(name))
138 : }
139 11 : }),
140 : }
141 12 : .transpose()?;
142 :
143 11 : if let Some(ep) = &endpoint {
144 7 : ctx.set_endpoint_id(ep.clone());
145 7 : }
146 :
147 11 : let metrics = Metrics::get();
148 11 : info!(%user, "credentials");
149 11 : if sni.is_some() {
150 5 : info!("Connection with sni");
151 5 : metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
152 6 : } else if endpoint.is_some() {
153 2 : metrics
154 2 : .proxy
155 2 : .accepted_connections_by_sni
156 2 : .inc(SniKind::NoSni);
157 2 : info!("Connection without sni");
158 : } else {
159 4 : metrics
160 4 : .proxy
161 4 : .accepted_connections_by_sni
162 4 : .inc(SniKind::PasswordHack);
163 4 : info!("Connection with password hack");
164 : }
165 :
166 11 : let options = NeonOptions::parse_params(params);
167 11 :
168 11 : Ok(Self {
169 11 : user,
170 11 : endpoint_id: endpoint,
171 11 : options,
172 11 : })
173 13 : }
174 : }
175 :
176 10 : pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
177 10 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
178 10 : }
179 :
180 : #[derive(Debug, Clone, Eq, PartialEq)]
181 : pub(crate) enum IpPattern {
182 : Subnet(ipnet::IpNet),
183 : Range(IpAddr, IpAddr),
184 : Single(IpAddr),
185 : None,
186 : }
187 :
188 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
189 8 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
190 8 : where
191 8 : D: serde::Deserializer<'de>,
192 8 : {
193 : struct StrVisitor;
194 : impl<'de> serde::de::Visitor<'de> for StrVisitor {
195 : type Value = IpPattern;
196 :
197 0 : fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 0 : write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
199 0 : }
200 :
201 8 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
202 8 : where
203 8 : E: serde::de::Error,
204 8 : {
205 8 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
206 1 : warn!("Cannot parse ip pattern {v}: {e}");
207 1 : IpPattern::None
208 8 : }))
209 8 : }
210 : }
211 8 : deserializer.deserialize_str(StrVisitor)
212 8 : }
213 : }
214 :
215 : impl FromStr for IpPattern {
216 : type Err = anyhow::Error;
217 :
218 6 : fn from_str(s: &str) -> Result<Self, Self::Err> {
219 6 : parse_ip_pattern(s)
220 6 : }
221 : }
222 :
223 22 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
224 22 : if pattern.contains('/') {
225 2 : let subnet: ipnet::IpNet = pattern.parse()?;
226 1 : return Ok(IpPattern::Subnet(subnet));
227 20 : }
228 20 : if let Some((start, end)) = pattern.split_once('-') {
229 3 : let start: IpAddr = start.parse()?;
230 2 : let end: IpAddr = end.parse()?;
231 1 : return Ok(IpPattern::Range(start, end));
232 17 : }
233 17 : let addr: IpAddr = pattern.parse()?;
234 14 : Ok(IpPattern::Single(addr))
235 22 : }
236 :
237 16 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
238 16 : match pattern {
239 3 : IpPattern::Subnet(subnet) => subnet.contains(ip),
240 5 : IpPattern::Range(start, end) => start <= ip && ip <= end,
241 7 : IpPattern::Single(addr) => addr == ip,
242 1 : IpPattern::None => false,
243 : }
244 16 : }
245 :
/// Reports whether `name` is acceptable as an endpoint (project) name.
///
/// The name must be non-empty and consist only of `-` and characters for
/// which `char::is_alphanumeric` holds (so non-ASCII letters and digits are
/// accepted too).
fn project_name_valid(name: &str) -> bool {
    // `Iterator::all` is vacuously true for an empty string, which would let an
    // empty endpoint id (e.g. from an SNI hostname `.<common name>`) pass.
    !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
249 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ComputeUserInfoParseError::*;

    /// The single common name used by most SNI tests.
    const LOCALHOST: &[&str] = &["localhost"];

    /// Runs `ComputeUserInfoMaybeEndpoint::parse` against a fresh test context,
    /// building the set of accepted common names from `common_names`.
    fn parse_with(
        params: &StartupMessageParams,
        sni: Option<&str>,
        common_names: Option<&[&str]>,
    ) -> Result<ComputeUserInfoMaybeEndpoint, ComputeUserInfoParseError> {
        let common_names: Option<HashSet<String>> =
            common_names.map(|names| names.iter().map(|name| name.to_string()).collect());
        let ctx = RequestMonitoring::test();
        ComputeUserInfoMaybeEndpoint::parse(&ctx, params, sni, common_names.as_ref())
    }

    /// Parses a packet for `john_doe` carrying the given `options` string (no
    /// SNI) and returns the endpoint id that was selected.
    fn endpoint_from_options(options: &str) -> anyhow::Result<Option<EndpointId>> {
        let params = StartupMessageParams::new([("user", "john_doe"), ("options", options)]);
        let user_info = parse_with(&params, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        Ok(user_info.endpoint_id)
    }

    /// Deserializes `list` as an allow-list and reports whether 127.0.0.1 passes it.
    fn localhost_allowed(list: serde_json::Value) -> bool {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let ip_list: Vec<IpPattern> = serde_json::from_value(list).unwrap();
        check_peer_addr_is_in_list(&peer_addr, &ip_list)
    }

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let user_info = parse_with(&params, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);
        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);
        let user_info = parse_with(&params, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);
        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let user_info = parse_with(&params, Some("foo.localhost"), Some(LOCALHOST))?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");
        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let endpoint = endpoint_from_options("-ckey=1 project=bar -c geqo=off")?;
        assert_eq!(endpoint.as_deref(), Some("bar"));
        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let endpoint = endpoint_from_options("-ckey=1 endpoint=bar -c geqo=off")?;
        assert_eq!(endpoint.as_deref(), Some("bar"));
        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let endpoint = endpoint_from_options(
            "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
        )?;
        assert!(endpoint.is_none());
        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let endpoint = endpoint_from_options("-ckey=1 endpoint=bar project=foo -c geqo=off")?;
        assert!(endpoint.is_none());
        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
        let user_info = parse_with(&params, Some("baz.localhost"), Some(LOCALHOST))?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let common_names: &[&str] = &["a.com", "b.com"];
        for sni in ["p1.a.com", "p1.b.com"] {
            let user_info = parse_with(&params, Some(sni), Some(common_names))?;
            assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
        }
        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let params =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
        let err = parse_with(&params, Some("second.localhost"), Some(LOCALHOST))
            .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            other => panic!("bad error: {other:?}"),
        }
    }

    #[test]
    fn parse_inconsistent_sni() {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let common_names: &[&str] = &["example.com"];
        let err = parse_with(&params, Some("project.localhost"), Some(common_names))
            .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => assert_eq!(cn, "localhost"),
            other => panic!("bad error: {other:?}"),
        }
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);
        let user_info = parse_with(&params, Some("project.localhost"), Some(LOCALHOST))?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );
        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        assert!(localhost_allowed(json!([])));
        assert!(localhost_allowed(json!(["127.0.0.1"])));
        assert!(!localhost_allowed(json!(["8.8.8.8"])));
        // If there is an incorrect address, it will be skipped.
        assert!(localhost_allowed(json!(["88.8.8", "127.0.0.1"])));
    }

    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let localhost = IpAddr::from([127, 0, 0, 1]);

        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(localhost));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(localhost, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        let malformed = [
            "300.0.1.2",
            "30.1.2",
            "127.0.0.1/33",
            "127.0.0.1-127.0.3",
            "1234.0.0.1-127.0.3.0",
        ];
        for input in malformed {
            assert!(parse_ip_pattern(input).is_err(), "{input:?} should be rejected");
        }
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let localhost = IpAddr::from([127, 0, 0, 1]);
        let above = IpAddr::from([127, 0, 0, 2]);
        let below = IpAddr::from([127, 0, 0, 0]);

        let matching = [
            IpPattern::Single(localhost),
            IpPattern::Subnet(ipnet::IpNet::new(below, 31)?),
            IpPattern::Subnet(ipnet::IpNet::new(above, 30)?),
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2])),
            IpPattern::Range(localhost, localhost),
        ];
        for pattern in &matching {
            assert!(check_ip(&localhost, pattern), "{pattern:?} should match");
        }

        let non_matching = [
            IpPattern::Single(below),
            IpPattern::Subnet(ipnet::IpNet::new(above, 31)?),
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), below),
            IpPattern::Range(above, IpAddr::from([128, 0, 0, 0])),
            // There is no check that for range start <= end. But it's fine as long
            // as for all these cases the result is false.
            IpPattern::Range(localhost, below),
        ];
        for pattern in &non_matching {
            assert!(!check_ip(&localhost, pattern), "{pattern:?} should not match");
        }
        Ok(())
    }

    #[test]
    fn test_connection_blocker() {
        assert!(localhost_allowed(json!([])));
        assert!(localhost_allowed(json!(["127.0.0.1"])));
        assert!(!localhost_allowed(json!(["255.255.255.255"])));
    }
}
|