Line data Source code
1 : //! User credentials used in authentication.
2 :
3 : use crate::{
4 : auth::password_hack::parse_endpoint_param,
5 : context::RequestMonitoring,
6 : error::{ReportableError, UserFacingError},
7 : metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
8 : proxy::NeonOptions,
9 : serverless::SERVERLESS_DRIVER_SNI,
10 : EndpointId, RoleName,
11 : };
12 : use itertools::Itertools;
13 : use pq_proto::StartupMessageParams;
14 : use smol_str::SmolStr;
15 : use std::{collections::HashSet, net::IpAddr, str::FromStr};
16 : use thiserror::Error;
17 : use tracing::{info, warn};
18 :
19 0 : #[derive(Debug, Error, PartialEq, Eq, Clone)]
20 : pub enum ComputeUserInfoParseError {
21 : #[error("Parameter '{0}' is missing in startup packet.")]
22 : MissingKey(&'static str),
23 :
24 : #[error(
25 : "Inconsistent project name inferred from \
26 : SNI ('{}') and project option ('{}').",
27 : .domain, .option,
28 : )]
29 : InconsistentProjectNames {
30 : domain: EndpointId,
31 : option: EndpointId,
32 : },
33 :
34 : #[error(
35 : "Common name inferred from SNI ('{}') is not known",
36 : .cn,
37 : )]
38 : UnknownCommonName { cn: String },
39 :
40 : #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
41 : MalformedProjectName(EndpointId),
42 : }
43 :
44 : impl UserFacingError for ComputeUserInfoParseError {}
45 :
46 : impl ReportableError for ComputeUserInfoParseError {
47 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
48 0 : crate::error::ErrorKind::User
49 0 : }
50 : }
51 :
52 : /// Various client credentials which we use for authentication.
53 : /// Note that we don't store any kind of client key or password here.
54 0 : #[derive(Debug, Clone, PartialEq, Eq)]
55 : pub struct ComputeUserInfoMaybeEndpoint {
56 : pub user: RoleName,
57 : pub endpoint_id: Option<EndpointId>,
58 : pub options: NeonOptions,
59 : }
60 :
61 : impl ComputeUserInfoMaybeEndpoint {
62 : #[inline]
63 49 : pub fn endpoint(&self) -> Option<&str> {
64 49 : self.endpoint_id.as_deref()
65 49 : }
66 : }
67 :
68 93 : pub fn endpoint_sni(
69 93 : sni: &str,
70 93 : common_names: &HashSet<String>,
71 93 : ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
72 93 : let Some((subdomain, common_name)) = sni.split_once('.') else {
73 0 : return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
74 : };
75 93 : if !common_names.contains(common_name) {
76 2 : return Err(ComputeUserInfoParseError::UnknownCommonName {
77 2 : cn: common_name.into(),
78 2 : });
79 91 : }
80 91 : if subdomain == SERVERLESS_DRIVER_SNI {
81 0 : return Ok(None);
82 91 : }
83 91 : Ok(Some(EndpointId::from(subdomain)))
84 93 : }
85 :
86 : impl ComputeUserInfoMaybeEndpoint {
87 75 : pub fn parse(
88 75 : ctx: &mut RequestMonitoring,
89 75 : params: &StartupMessageParams,
90 75 : sni: Option<&str>,
91 75 : common_names: Option<&HashSet<String>>,
92 75 : ) -> Result<Self, ComputeUserInfoParseError> {
93 75 : use ComputeUserInfoParseError::*;
94 75 :
95 75 : // Some parameters are stored in the startup message.
96 75 : let get_param = |key| params.get(key).ok_or(MissingKey(key));
97 75 : let user: RoleName = get_param("user")?.into();
98 75 :
99 75 : // record the values if we have them
100 75 : ctx.set_application(params.get("application_name").map(SmolStr::from));
101 75 : ctx.set_user(user.clone());
102 75 : if let Some(dbname) = params.get("database") {
103 51 : ctx.set_dbname(dbname.into());
104 51 : }
105 :
106 : // Project name might be passed via PG's command-line options.
107 75 : let endpoint_option = params
108 75 : .options_raw()
109 75 : .and_then(|options| {
110 63 : // We support both `project` (deprecated) and `endpoint` options for backward compatibility.
111 63 : // However, if both are present, we don't exactly know which one to use.
112 63 : // Therefore we require that only one of them is present.
113 63 : options
114 63 : .filter_map(parse_endpoint_param)
115 63 : .at_most_one()
116 63 : .ok()?
117 75 : })
118 75 : .map(|name| name.into());
119 :
120 75 : let endpoint_from_domain = if let Some(sni_str) = sni {
121 46 : if let Some(cn) = common_names {
122 46 : endpoint_sni(sni_str, cn)?
123 : } else {
124 0 : None
125 : }
126 : } else {
127 29 : None
128 : };
129 :
130 73 : let endpoint = match (endpoint_option, endpoint_from_domain) {
131 : // Invariant: if we have both project name variants, they should match.
132 4 : (Some(option), Some(domain)) if option != domain => {
133 2 : Some(Err(InconsistentProjectNames { domain, option }))
134 : }
135 : // Invariant: project name may not contain certain characters.
136 71 : (a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
137 0 : false => Err(MalformedProjectName(name)),
138 60 : true => Ok(name),
139 71 : }),
140 : }
141 73 : .transpose()?;
142 :
143 71 : if let Some(ep) = &endpoint {
144 60 : ctx.set_endpoint_id(ep.clone());
145 60 : tracing::Span::current().record("ep", &tracing::field::display(ep));
146 60 : }
147 :
148 71 : info!(%user, project = endpoint.as_deref(), "credentials");
149 71 : if sni.is_some() {
150 42 : info!("Connection with sni");
151 42 : NUM_CONNECTION_ACCEPTED_BY_SNI
152 42 : .with_label_values(&["sni"])
153 42 : .inc();
154 29 : } else if endpoint.is_some() {
155 18 : NUM_CONNECTION_ACCEPTED_BY_SNI
156 18 : .with_label_values(&["no_sni"])
157 18 : .inc();
158 18 : info!("Connection without sni");
159 : } else {
160 11 : NUM_CONNECTION_ACCEPTED_BY_SNI
161 11 : .with_label_values(&["password_hack"])
162 11 : .inc();
163 11 : info!("Connection with password hack");
164 : }
165 :
166 71 : let options = NeonOptions::parse_params(params);
167 71 :
168 71 : Ok(Self {
169 71 : user,
170 71 : endpoint_id: endpoint,
171 71 : options,
172 71 : })
173 75 : }
174 : }
175 :
176 99 : pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
177 99 : ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
178 99 : }
179 :
/// One entry of an IP allow-list, as matched by `check_ip`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum IpPattern {
    /// A network written in CIDR notation (`addr/prefix`).
    Subnet(ipnet::IpNet),
    /// An inclusive address range written as `start-end`.
    Range(IpAddr, IpAddr),
    /// Exactly one address.
    Single(IpAddr),
    /// Stand-in for an entry that failed to parse during deserialization; matches nothing.
    None,
}
187 :
188 : impl<'de> serde::de::Deserialize<'de> for IpPattern {
189 12 : fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
190 12 : where
191 12 : D: serde::Deserializer<'de>,
192 12 : {
193 12 : struct StrVisitor;
194 12 : impl<'de> serde::de::Visitor<'de> for StrVisitor {
195 12 : type Value = IpPattern;
196 12 :
197 12 : fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
198 0 : write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
199 0 : }
200 12 :
201 12 : fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
202 12 : where
203 12 : E: serde::de::Error,
204 12 : {
205 12 : Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
206 2 : warn!("Cannot parse ip pattern {v}: {e}");
207 12 : IpPattern::None
208 12 : }))
209 12 : }
210 12 : }
211 12 : deserializer.deserialize_str(StrVisitor)
212 12 : }
213 : }
214 :
215 : impl FromStr for IpPattern {
216 : type Err = anyhow::Error;
217 :
218 32 : fn from_str(s: &str) -> Result<Self, Self::Err> {
219 32 : parse_ip_pattern(s)
220 32 : }
221 : }
222 :
223 60 : fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
224 60 : if pattern.contains('/') {
225 4 : let subnet: ipnet::IpNet = pattern.parse()?;
226 2 : return Ok(IpPattern::Subnet(subnet));
227 56 : }
228 56 : if let Some((start, end)) = pattern.split_once('-') {
229 6 : let start: IpAddr = start.parse()?;
230 4 : let end: IpAddr = end.parse()?;
231 2 : return Ok(IpPattern::Range(start, end));
232 50 : }
233 50 : let addr: IpAddr = pattern.parse()?;
234 44 : Ok(IpPattern::Single(addr))
235 60 : }
236 :
237 40 : fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
238 40 : match pattern {
239 6 : IpPattern::Subnet(subnet) => subnet.contains(ip),
240 10 : IpPattern::Range(start, end) => start <= ip && ip <= end,
241 22 : IpPattern::Single(addr) => addr == ip,
242 2 : IpPattern::None => false,
243 : }
244 40 : }
245 :
/// Returns `true` if `name` is acceptable as a project/endpoint name.
///
/// The name must be non-empty and consist only of `-` and characters for which
/// [`char::is_alphanumeric`] holds (this includes non-ASCII letters and digits).
/// The empty name used to pass vacuously, so e.g. an SNI of `.localhost` produced a blank
/// endpoint id; it is now rejected.
fn project_name_valid(name: &str) -> bool {
    !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
249 :
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ComputeUserInfoParseError::*;

    #[test]
    fn parse_bare_minimum() -> anyhow::Result<()> {
        // According to postgresql, only `user` should be required.
        // With nothing else supplied, no endpoint can be determined.
        let options = StartupMessageParams::new([("user", "john_doe")]);
        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_excessive() -> anyhow::Result<()> {
        // Extra startup parameters must not change the parsed user or endpoint.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("database", "world"), // only recorded on `ctx`; doesn't affect the result
            ("foo", "bar"),        // should be ignored
        ]);
        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id, None);

        Ok(())
    }

    #[test]
    fn parse_project_from_sni() -> anyhow::Result<()> {
        // `<endpoint>.<common name>`: the subdomain before the first '.' is the endpoint.
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("foo.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
        assert_eq!(user_info.options.get_cache_key("foo"), "foo");

        Ok(())
    }

    #[test]
    fn parse_project_from_options() -> anyhow::Result<()> {
        // The deprecated `project=` key inside `options` is still honoured.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 project=bar -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_endpoint_from_options() -> anyhow::Result<()> {
        // The current `endpoint=` key inside `options`.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));

        Ok(())
    }

    #[test]
    fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
        // Several endpoint entries are ambiguous, so none of them is used.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            (
                "options",
                "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
            ),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
        // `endpoint=` together with `project=` is just as ambiguous: neither is used.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
        ]);

        let mut ctx = RequestMonitoring::test();
        let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
        assert_eq!(user_info.user, "john_doe");
        assert!(user_info.endpoint_id.is_none());

        Ok(())
    }

    #[test]
    fn parse_projects_identical() -> anyhow::Result<()> {
        // SNI and options may both name the endpoint as long as they agree.
        let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);

        let sni = Some("baz.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));

        Ok(())
    }

    #[test]
    fn parse_multi_common_names() -> anyhow::Result<()> {
        // Any of the configured common names may serve as the SNI suffix.
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.a.com");
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        let common_names = Some(["a.com".into(), "b.com".into()].into());
        let sni = Some("p1.b.com");
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));

        Ok(())
    }

    #[test]
    fn parse_projects_different() {
        // SNI and options naming different endpoints is an error reporting both names.
        let options =
            StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);

        let sni = Some("second.localhost");
        let common_names = Some(["localhost".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            InconsistentProjectNames { domain, option } => {
                assert_eq!(option, "first");
                assert_eq!(domain, "second");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_inconsistent_sni() {
        // An SNI suffix that is not a configured common name is rejected; the error
        // reports the suffix.
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let mut ctx = RequestMonitoring::test();
        let err =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
                .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => {
                assert_eq!(cn, "localhost");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
    fn parse_neon_options() -> anyhow::Result<()> {
        // `neon_*` options are parsed and reflected in the cache key.
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["localhost".into()].into());
        let mut ctx = RequestMonitoring::test();
        let user_info =
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
            user_info.options.get_cache_key("project"),
            "project endpoint_type:read_write lsn:0/2"
        );

        Ok(())
    }

    #[test]
    fn test_check_peer_addr_is_in_list() {
        // Deserializes `v` as an allow-list and checks 127.0.0.1 against it.
        fn check(v: serde_json::Value) -> bool {
            let peer_addr = IpAddr::from([127, 0, 0, 1]);
            let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
            check_peer_addr_is_in_list(&peer_addr, &ip_list)
        }

        // An empty list admits everyone.
        assert!(check(json!([])));
        assert!(check(json!(["127.0.0.1"])));
        assert!(!check(json!(["8.8.8.8"])));
        // An unparsable entry becomes `IpPattern::None` and is skipped rather than
        // failing the whole list.
        assert!(check(json!(["88.8.8", "127.0.0.1"])));
    }
    #[test]
    fn test_parse_ip_v4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        // Ok
        assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
        assert_eq!(
            parse_ip_pattern("127.0.0.1/31")?,
            IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
        );
        assert_eq!(
            parse_ip_pattern("0.0.0.0-200.0.1.2")?,
            IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        );

        // Error: out-of-range octet, too few octets, prefix > 32, malformed range ends.
        assert!(parse_ip_pattern("300.0.1.2").is_err());
        assert!(parse_ip_pattern("30.1.2").is_err());
        assert!(parse_ip_pattern("127.0.0.1/33").is_err());
        assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
        assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
        Ok(())
    }

    #[test]
    fn test_check_ipv4() -> anyhow::Result<()> {
        let peer_addr = IpAddr::from([127, 0, 0, 1]);
        let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
        let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
        // Success
        assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
        ));
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
        ));
        // Range bounds are inclusive.
        assert!(check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr)
        ));

        // Not success
        assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
        ));
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
        ));
        // Ranges are not validated to have start <= end. That is fine, because an
        // inverted range simply matches nothing.
        assert!(!check_ip(
            &peer_addr,
            &IpPattern::Range(peer_addr, peer_addr_prev)
        ));
        Ok(())
    }
}
|