Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use crate::{
4 : auth::password_hack::parse_endpoint_param,
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::{Metrics, SniKind},
8 : proxy::NeonOptions,
9 : serverless::SERVERLESS_DRIVER_SNI,
10 : EndpointId, RoleName,
11 : };
12 : use itertools::Itertools;
13 : use pq_proto::StartupMessageParams;
14 : use std::{collections::HashSet, net::IpAddr, str::FromStr};
15 : use thiserror::Error;
16 : use tracing::{info, warn};
17 :
18 0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
19 : pub enum ComputeUserInfoParseError {
20 : #[error("Parameter '{0}' is missing in startup packet.")]
21 : MissingKey(&'static str),
22 :
23 : #[error(
24 : "Inconsistent project name inferred from \
25 : SNI ('{}') and project option ('{}').",
26 : .domain, .option,
27 : )]
28 : InconsistentProjectNames {
29 : domain: EndpointId,
30 : option: EndpointId,
31 : },
32 :
33 : #[error(
34 : "Common name inferred from SNI ('{}') is not known",
35 : .cn,
36 : )]
37 : UnknownCommonName { cn: String },
38 :
39 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
40 : MalformedProjectName(EndpointId),
41 : }
42 :
43 : impl UserFacingError for ComputeUserInfoParseError {}
44 :
45 : impl ReportableError for ComputeUserInfoParseError {
46 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
47 0 : crate::error::ErrorKind::User
48 0 : }
49 : }
50 :
51 : /// Various client credentials which we use for authentication.
52 : /// Note that we don't store any kind of client key or password here.
53 : #[derive(Debug, Clone, PartialEq, Eq)]
54 : pub struct ComputeUserInfoMaybeEndpoint {
55 : pub user: RoleName,
56 : pub endpoint_id: Option<EndpointId>,
57 : pub options: NeonOptions,
58 : }
59 :
60 : impl ComputeUserInfoMaybeEndpoint {
61 : #[inline]
62 0 : pub fn endpoint(&self) -> Option<&str> {
63 0 : self.endpoint_id.as_deref()
64 0 : }
65 : }
66 :
67 14 : pub fn endpoint_sni(
68 14 : sni: &str,
69 14 : common_names: &HashSet<String>,
70 14 : ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
71 14 : let Some((subdomain, common_name)) = sni.split_once('.') else {
72 0 : return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
73 : };
74 14 : if !common_names.contains(common_name) {
75 2 : return Err(ComputeUserInfoParseError::UnknownCommonName {
76 2 : cn: common_name.into(),
77 2 : });
78 12 : }
79 12 : if subdomain == SERVERLESS_DRIVER_SNI {
80 0 : return Ok(None);
81 12 : }
82 12 : Ok(Some(EndpointId::from(subdomain)))
83 14 : }
84 :
85 : impl ComputeUserInfoMaybeEndpoint {
86 26 : pub fn parse(
87 26 : ctx: &mut RequestMonitoring,
88 26 : params: &StartupMessageParams,
89 26 : sni: Option<&str>,
90 26 : common_names: Option<&HashSet<String>>,
91 26 : ) -> Result<Self, ComputeUserInfoParseError> {
92 26 : use ComputeUserInfoParseError::*;
93 26 :
94 26 : // Some parameters are stored in the startup message.
95 26 : let get_param = |key| params.get(key).ok_or(MissingKey(key));
96 26 : let user: RoleName = get_param("user")?.into();
97 26 :
98 26 : // Project name might be passed via PG's command-line options.
99 26 : let endpoint_option = params
100 26 : .options_raw()
101 26 : .and_then(|options| {
102 14 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
103 14 : // However, if both are present, we don't exactly know which one to use.
104 14 : // Therefore we require that only one of them is present.
105 14 : options
106 14 : .filter_map(parse_endpoint_param)
107 14 : .at_most_one()
108 14 : .ok()?
109 26 : })
110 26 : .map(|name| name.into());
111 :
112 26 : let endpoint_from_domain = if let Some(sni_str) = sni {
113 14 : if let Some(cn) = common_names {
114 14 : endpoint_sni(sni_str, cn)?
115 : } else {
116 0 : None
117 : }
118 : } else {
119 12 : None
120 : };
121 :
122 24 : let endpoint = match (endpoint_option, endpoint_from_domain) {
123 : // Invariant: if we have both project name variants, they should match.
124 4 : (Some(option), Some(domain)) if option != domain => {
125 2 : Some(Err(InconsistentProjectNames { domain, option }))
126 : }
127 : // Invariant: project name may not contain certain characters.
128 22 : (a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
129 0 : false => Err(MalformedProjectName(name)),
130 14 : true => Ok(name),
131 22 : }),
132 : }
133 24 : .transpose()?;
134 :
135 22 : if let Some(ep) = &endpoint {
136 14 : ctx.set_endpoint_id(ep.clone());
137 14 : }
138 :
139 22 : let metrics = Metrics::get();
140 22 : info!(%user, "credentials");
141 22 : if sni.is_some() {
142 10 : info!("Connection with sni");
143 10 : metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
144 12 : } else if endpoint.is_some() {
145 4 : metrics
146 4 : .proxy
147 4 : .accepted_connections_by_sni
148 4 : .inc(SniKind::NoSni);
149 4 : info!("Connection without sni");
150 : } else {
151 8 : metrics
152 8 : .proxy
153 8 : .accepted_connections_by_sni
154 8 : .inc(SniKind::PasswordHack);
155 8 : info!("Connection with password hack");
156 : }
157 :
158 22 : let options = NeonOptions::parse_params(params);
159 22 :
160 22 : Ok(Self {
161 22 : user,
162 22 : endpoint_id: endpoint,
163 22 : options,
164 22 : })
165 26 : }
166 : }
167 :
168 14 : pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
169 14 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
170 14 : }
171 :
172 : #[derive(Debug, Clone, Eq, PartialEq)]
173 : pub enum IpPattern {
174 : Subnet(ipnet::IpNet),
175 : Range(IpAddr, IpAddr),
176 : Single(IpAddr),
177 : None,
178 : }
179 :
180 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
181 12 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
182 12 : where
183 12 : D: serde::Deserializer<'de>,
184 12 : {
185 12 : struct StrVisitor;
186 12 : impl<'de> serde::de::Visitor<'de> for StrVisitor {
187 12 : type Value = IpPattern;
188 12 :
189 12 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
190 0 : write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
191 0 : }
192 12 :
193 12 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
194 12 : where
195 12 : E: serde::de::Error,
196 12 : {
197 12 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
198 2 : warn!("Cannot parse ip pattern {v}: {e}");
199 12 : IpPattern::None
200 12 : }))
201 12 : }
202 12 : }
203 12 : deserializer.deserialize_str(StrVisitor)
204 12 : }
205 : }
206 :
207 : impl FromStr for IpPattern {
208 : type Err = anyhow::Error;
209 :
210 12 : fn from_str(s: &str) -> Result<Self, Self::Err> {
211 12 : parse_ip_pattern(s)
212 12 : }
213 : }
214 :
215 40 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
216 40 : if pattern.contains('/') {
217 4 : let subnet: ipnet::IpNet = pattern.parse()?;
218 2 : return Ok(IpPattern::Subnet(subnet));
219 36 : }
220 36 : if let Some((start, end)) = pattern.split_once('-') {
221 6 : let start: IpAddr = start.parse()?;
222 4 : let end: IpAddr = end.parse()?;
223 2 : return Ok(IpPattern::Range(start, end));
224 30 : }
225 30 : let addr: IpAddr = pattern.parse()?;
226 24 : Ok(IpPattern::Single(addr))
227 40 : }
228 :
229 28 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
230 28 : match pattern {
231 6 : IpPattern::Subnet(subnet) => subnet.contains(ip),
232 10 : IpPattern::Range(start, end) => start <= ip && ip <= end,
233 10 : IpPattern::Single(addr) => addr == ip,
234 2 : IpPattern::None => false,
235 : }
236 28 : }
237 :
/// Checks that an endpoint name consists only of alphanumeric characters and `-`.
///
/// "Alphanumeric" follows [`char::is_alphanumeric`], so non-ASCII letters and digits are
/// accepted too. An empty name passes vacuously.
fn project_name_valid(name: &str) -> bool {
    !name.chars().any(|ch| ch != '-' && !ch.is_alphanumeric())
}
241 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ComputeUserInfoParseError::*;

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let options = StartupMessageParams::new([("user", "john_doe")]);
        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_missing_user() {
        // `user` is the one startup parameter that must be present.
        let options = StartupMessageParams::new([("database", "world")]);
        let mut ctx = RequestMonitoring::test();
        let err = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)
            .expect_err("should fail");
        assert_eq!(err, MissingKey("user"));
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // should be ignored
            ("foo", "bar"),        // should be ignored
        ]);
        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.endpoint(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");

        Ok(())
    }

    #[test]
    fn parse_sni_ignored_without_common_names() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        // Without a list of common names the SNI cannot be interpreted, so it is ignored.
        let sni = Some("foo.localhost");

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_serverless_driver_sni() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        // The serverless-driver subdomain does not name an endpoint.
        let sni = format!("{}.localhost", SERVERLESS_DRIVER_SNI);
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(
            &mut ctx,
            &options,
            Some(sni.as_str()),
            common_names.as_ref(),
        )?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let sni = Some("baz.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));

        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.a.com");
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.b.com");
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let options =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let sni = Some("second.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_inconsistent_sni() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => {
                assert_eq!(cn, "localhost");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_sni_without_domain() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        // An SNI without any `.` has no common-name part; the whole SNI is reported.
        let sni = Some("localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => {
                assert_eq!(cn, "localhost");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_malformed_project_name() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo_bar.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            MalformedProjectName(name) => {
                assert_eq!(name, "foo_bar");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["localhost".into()].into());
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );

        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["8.8.8.8"])));
        // If there is an incorrect address, it will be skipped.
        assert!(check(json!(["88.8.8", "127.0.0.1"])));
        // A non-empty list of only incorrect addresses admits nobody.
        assert!(!check(json!(["88.8.8"])));
    }

    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error
        assert!(parse_ip_pattern("300.0.1.2").is_err());
        assert!(parse_ip_pattern("30.1.2").is_err());
        assert!(parse_ip_pattern("127.0.0.1/33").is_err());
        assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
        assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
        Ok(())
    }

    #[test]
    fn test_parse_ip_v6() -> anyhow::Result<()> {
        let localhost = IpAddr::from([0u16, 0, 0, 0, 0, 0, 0, 1]);
        let next = IpAddr::from([0u16, 0, 0, 0, 0, 0, 0, 2]);
        // Ok
        assert_eq!(parse_ip_pattern("::1")?, IpPattern::Single(localhost));
        assert_eq!(
            parse_ip_pattern("::1/128")?,
            IpPattern::Subnet(ipnet::IpNet::new(localhost, 128)?)
        );
        assert_eq!(
            parse_ip_pattern("::1-::2")?,
            IpPattern::Range(localhost, next)
        );

        // Error
        assert!(parse_ip_pattern("::1/129").is_err());
        assert!(parse_ip_pattern("::g").is_err());
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr)
        ));

        // Not success
        assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
        ));
        // Ranges are not checked for start <= end, but that is fine: an inverted range
        // matches nothing.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr_prev)
        ));
        assert!(!check_ip(&peer_addr, &IpPattern::None));
        Ok(())
    }

    #[test]
    fn test_check_ipv6() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([0x2001u16, 0xdb8, 0, 0, 0, 0, 0, 1]);
        let network = IpAddr::from([0x2001u16, 0xdb8, 0, 0, 0, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(network, 32)?)
        ));

        // An IPv6 peer never matches IPv4 patterns.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([255, 255, 255, 255]))
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(IpAddr::from([0, 0, 0, 0]), 0)?)
        ));
        Ok(())
    }
}
|