Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use crate::{
4 : auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
5 : metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
6 : EndpointId, RoleName,
7 : };
8 : use itertools::Itertools;
9 : use pq_proto::StartupMessageParams;
10 : use smol_str::SmolStr;
11 : use std::{collections::HashSet, net::IpAddr, str::FromStr};
12 : use thiserror::Error;
13 : use tracing::{info, warn};
14 :
15 0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
16 : pub enum ComputeUserInfoParseError {
17 : #[error("Parameter '{0}' is missing in startup packet.")]
18 : MissingKey(&'static str),
19 :
20 : #[error(
21 : "Inconsistent project name inferred from \
22 : SNI ('{}') and project option ('{}').",
23 : .domain, .option,
24 : )]
25 : InconsistentProjectNames {
26 : domain: EndpointId,
27 : option: EndpointId,
28 : },
29 :
30 : #[error(
31 : "Common name inferred from SNI ('{}') is not known",
32 : .cn,
33 : )]
34 : UnknownCommonName { cn: String },
35 :
36 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
37 : MalformedProjectName(EndpointId),
38 : }
39 :
40 : impl UserFacingError for ComputeUserInfoParseError {}
41 :
42 : /// Various client credentials which we use for authentication.
43 : /// Note that we don't store any kind of client key or password here.
44 0 : #[derive(Debug, Clone, PartialEq, Eq)]
45 : pub struct ComputeUserInfoMaybeEndpoint {
46 : pub user: RoleName,
47 : pub endpoint_id: Option<EndpointId>,
48 : pub options: NeonOptions,
49 : }
50 :
51 : impl ComputeUserInfoMaybeEndpoint {
52 : #[inline]
53 47 : pub fn endpoint(&self) -> Option<&str> {
54 47 : self.endpoint_id.as_deref()
55 47 : }
56 : }
57 :
58 90 : pub fn endpoint_sni(
59 90 : sni: &str,
60 90 : common_names: &HashSet<String>,
61 90 : ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
62 90 : let Some((subdomain, common_name)) = sni.split_once('.') else {
63 0 : return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
64 : };
65 90 : if !common_names.contains(common_name) {
66 2 : return Err(ComputeUserInfoParseError::UnknownCommonName {
67 2 : cn: common_name.into(),
68 2 : });
69 88 : }
70 88 : if subdomain == SERVERLESS_DRIVER_SNI {
71 0 : return Ok(None);
72 88 : }
73 88 : Ok(Some(EndpointId::from(subdomain)))
74 90 : }
75 :
76 : impl ComputeUserInfoMaybeEndpoint {
77 73 : pub fn parse(
78 73 : ctx: &mut RequestMonitoring,
79 73 : params: &StartupMessageParams,
80 73 : sni: Option<&str>,
81 73 : common_names: Option<&HashSet<String>>,
82 73 : ) -> Result<Self, ComputeUserInfoParseError> {
83 73 : use ComputeUserInfoParseError::*;
84 73 :
85 73 : // Some parameters are stored in the startup message.
86 73 : let get_param = |key| params.get(key).ok_or(MissingKey(key));
87 73 : let user: RoleName = get_param("user")?.into();
88 73 :
89 73 : // record the values if we have them
90 73 : ctx.set_application(params.get("application_name").map(SmolStr::from));
91 73 : ctx.set_user(user.clone());
92 73 :
93 73 : // Project name might be passed via PG's command-line options.
94 73 : let endpoint_option = params
95 73 : .options_raw()
96 73 : .and_then(|options| {
97 61 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
98 61 : // However, if both are present, we don't exactly know which one to use.
99 61 : // Therefore we require that only one of them is present.
100 61 : options
101 61 : .filter_map(parse_endpoint_param)
102 61 : .at_most_one()
103 61 : .ok()?
104 73 : })
105 73 : .map(|name| name.into());
106 :
107 73 : let endpoint_from_domain = if let Some(sni_str) = sni {
108 44 : if let Some(cn) = common_names {
109 44 : endpoint_sni(sni_str, cn)?
110 : } else {
111 0 : None
112 : }
113 : } else {
114 29 : None
115 : };
116 :
117 71 : let endpoint = match (endpoint_option, endpoint_from_domain) {
118 : // Invariant: if we have both project name variants, they should match.
119 4 : (Some(option), Some(domain)) if option != domain => {
120 2 : Some(Err(InconsistentProjectNames { domain, option }))
121 : }
122 : // Invariant: project name may not contain certain characters.
123 69 : (a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
124 0 : false => Err(MalformedProjectName(name)),
125 58 : true => Ok(name),
126 69 : }),
127 : }
128 71 : .transpose()?;
129 :
130 69 : if let Some(ep) = &endpoint {
131 58 : ctx.set_endpoint_id(ep.clone());
132 58 : tracing::Span::current().record("ep", &tracing::field::display(ep));
133 58 : }
134 :
135 69 : info!(%user, project = endpoint.as_deref(), "credentials");
136 69 : if sni.is_some() {
137 40 : info!("Connection with sni");
138 40 : NUM_CONNECTION_ACCEPTED_BY_SNI
139 40 : .with_label_values(&["sni"])
140 40 : .inc();
141 29 : } else if endpoint.is_some() {
142 18 : NUM_CONNECTION_ACCEPTED_BY_SNI
143 18 : .with_label_values(&["no_sni"])
144 18 : .inc();
145 18 : info!("Connection without sni");
146 : } else {
147 11 : NUM_CONNECTION_ACCEPTED_BY_SNI
148 11 : .with_label_values(&["password_hack"])
149 11 : .inc();
150 11 : info!("Connection with password hack");
151 : }
152 :
153 69 : let options = NeonOptions::parse_params(params);
154 69 :
155 69 : Ok(Self {
156 69 : user,
157 69 : endpoint_id: endpoint,
158 69 : options,
159 69 : })
160 73 : }
161 : }
162 :
163 92 : pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
164 92 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
165 92 : }
166 :
167 6 : #[derive(Debug, Clone, Eq, PartialEq)]
168 : pub enum IpPattern {
169 : Subnet(ipnet::IpNet),
170 : Range(IpAddr, IpAddr),
171 : Single(IpAddr),
172 : None,
173 : }
174 :
175 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
176 12 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
177 12 : where
178 12 : D: serde::Deserializer<'de>,
179 12 : {
180 12 : struct StrVisitor;
181 12 : impl<'de> serde::de::Visitor<'de> for StrVisitor {
182 12 : type Value = IpPattern;
183 12 :
184 12 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
185 0 : write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
186 0 : }
187 12 :
188 12 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
189 12 : where
190 12 : E: serde::de::Error,
191 12 : {
192 12 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
193 2 : warn!("Cannot parse ip pattern {v}: {e}");
194 12 : IpPattern::None
195 12 : }))
196 12 : }
197 12 : }
198 12 : deserializer.deserialize_str(StrVisitor)
199 12 : }
200 : }
201 :
202 : impl FromStr for IpPattern {
203 : type Err = anyhow::Error;
204 :
205 30 : fn from_str(s: &str) -> Result<Self, Self::Err> {
206 30 : parse_ip_pattern(s)
207 30 : }
208 : }
209 :
210 58 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
211 58 : if pattern.contains('/') {
212 4 : let subnet: ipnet::IpNet = pattern.parse()?;
213 2 : return Ok(IpPattern::Subnet(subnet));
214 54 : }
215 54 : if let Some((start, end)) = pattern.split_once('-') {
216 6 : let start: IpAddr = start.parse()?;
217 4 : let end: IpAddr = end.parse()?;
218 2 : return Ok(IpPattern::Range(start, end));
219 48 : }
220 48 : let addr: IpAddr = pattern.parse()?;
221 42 : Ok(IpPattern::Single(addr))
222 58 : }
223 :
224 40 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
225 40 : match pattern {
226 6 : IpPattern::Subnet(subnet) => subnet.contains(ip),
227 10 : IpPattern::Range(start, end) => start <= ip && ip <= end,
228 22 : IpPattern::Single(addr) => addr == ip,
229 2 : IpPattern::None => false,
230 : }
231 40 : }
232 :
/// Returns whether `name` is acceptable as an endpoint (project) name.
///
/// A valid name is non-empty and consists only of ASCII letters, ASCII digits and `-`.
/// The check is ASCII-only on purpose: `char::is_alphanumeric` would also accept
/// characters such as full-width letters or `²`, which can never appear in the SNI
/// hostname label an endpoint is addressed by. An empty name is rejected, because it
/// would otherwise be accepted vacuously.
fn project_name_valid(name: &str) -> bool {
    !name.is_empty() && name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
}
236 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ComputeUserInfoParseError::*;

    /// Runs [`ComputeUserInfoMaybeEndpoint::parse`] against a fresh test request context.
    fn parse_startup(
        params: &StartupMessageParams,
        sni: Option<&str>,
        common_names: Option<&HashSet<String>>,
    ) -> Result<ComputeUserInfoMaybeEndpoint, ComputeUserInfoParseError> {
        let mut ctx = RequestMonitoring::test();
        ComputeUserInfoMaybeEndpoint::parse(&mut ctx, params, sni, common_names)
    }

    /// Builds the set of accepted TLS common names.
    fn cn_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id, None);
        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        // Parameters we do not recognise (`database`, `foo`) must simply be ignored.
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"),
            ("foo", "bar"),
        ]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id, None);
        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let names = cn_set(&["localhost"]);
        let info = parse_startup(&params, Some("foo.localhost"), Some(&names))?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(info.options.get_cache_key("foo"), "foo");
        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id.as_deref(), Some("bar"));
        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id.as_deref(), Some("bar"));
        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        // Several endpoint entries are ambiguous, so none of them is used.
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert!(info.endpoint_id.is_none());
        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        // `endpoint` and `project` together are just as ambiguous.
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);
        let info = parse_startup(&params, None, None)?;
        assert_eq!(info.user, "john_doe");
        assert!(info.endpoint_id.is_none());
        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
        let names = cn_set(&["localhost"]);
        let info = parse_startup(&params, Some("baz.localhost"), Some(&names))?;
        assert_eq!(info.user, "john_doe");
        assert_eq!(info.endpoint_id.as_deref(), Some("baz"));
        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let names = cn_set(&["a.com", "b.com"]);
        for sni in ["p1.a.com", "p1.b.com"] {
            let info = parse_startup(&params, Some(sni), Some(&names))?;
            assert_eq!(info.endpoint_id.as_deref(), Some("p1"));
        }
        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        let params =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
        let names = cn_set(&["localhost"]);
        let err = parse_startup(&params, Some("second.localhost"), Some(&names))
            .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            other => panic!("bad error: {other:?}"),
        }
    }

    #[test]
    fn parse_inconsistent_sni() {
        let params = StartupMessageParams::new([("user", "john_doe")]);
        let names = cn_set(&["example.com"]);
        let err = parse_startup(&params, Some("project.localhost"), Some(&names))
            .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => assert_eq!(cn, "localhost"),
            other => panic!("bad error: {other:?}"),
        }
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        let params = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);
        let names = cn_set(&["localhost"]);
        let info = parse_startup(&params, Some("project.localhost"), Some(&names))?;
        assert_eq!(info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );
        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        /// Deserializes `list` as an allow-list and checks 127.0.0.1 against it.
        fn localhost_allowed(list: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let patterns: Vec<IpPattern> = serde_json::from_value(list).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &patterns)
        }

        assert!(localhost_allowed(json!([])));
        assert!(localhost_allowed(json!(["127.0.0.1"])));
        assert!(!localhost_allowed(json!(["8.8.8.8"])));
        // An entry that fails to parse is skipped rather than rejecting the whole list.
        assert!(localhost_allowed(json!(["88.8.8", "127.0.0.1"])));
    }

    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let localhost = IpAddr::from([127, 0, 0, 1]);

        // Well-formed patterns.
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(localhost));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(localhost, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Malformed patterns.
        for bad in [
            "300.0.1.2",
            "30.1.2",
            "127.0.0.1/33",
            "127.0.0.1-127.0.3",
            "1234.0.0.1-127.0.3.0",
        ] {
            assert!(parse_ip_pattern(bad).is_err(), "{bad} should be rejected");
        }
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer = IpAddr::from([127, 0, 0, 1]);
        let next = IpAddr::from([127, 0, 0, 2]);
        let prev = IpAddr::from([127, 0, 0, 0]);

        let matching = [
            IpPattern::Single(peer),
            IpPattern::Subnet(ipnet::IpNet::new(prev, 31)?),
            IpPattern::Subnet(ipnet::IpNet::new(next, 30)?),
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2])),
            IpPattern::Range(peer, peer),
        ];
        for pattern in &matching {
            assert!(check_ip(&peer, pattern), "{pattern:?} should match");
        }

        let non_matching = [
            IpPattern::Single(prev),
            IpPattern::Subnet(ipnet::IpNet::new(next, 31)?),
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), prev),
            IpPattern::Range(next, IpAddr::from([128, 0, 0, 0])),
            // No check enforces range start <= end. That is fine as long as every such
            // range matches nothing.
            IpPattern::Range(peer, prev),
        ];
        for pattern in &non_matching {
            assert!(!check_ip(&peer, pattern), "{pattern:?} should not match");
        }
        Ok(())
    }
}
|