Line data Source code
1 : mod classic;
2 : mod hacks;
3 : mod link;
4 :
5 : use std::net::IpAddr;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use ipnet::{Ipv4Net, Ipv6Net};
10 : pub use link::LinkAuthError;
11 : use tokio::io::{AsyncRead, AsyncWrite};
12 : use tokio_postgres::config::AuthKeys;
13 : use tracing::{info, warn};
14 :
15 : use crate::auth::credentials::check_peer_addr_is_in_list;
16 : use crate::auth::{validate_password_and_exchange, AuthError};
17 : use crate::cache::Cached;
18 : use crate::console::errors::GetAuthInfoError;
19 : use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
20 : use crate::console::{AuthSecret, NodeInfo};
21 : use crate::context::RequestMonitoring;
22 : use crate::intern::EndpointIdInt;
23 : use crate::metrics::Metrics;
24 : use crate::proxy::connect_compute::ComputeConnectBackend;
25 : use crate::proxy::NeonOptions;
26 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
27 : use crate::stream::Stream;
28 : use crate::{
29 : auth::{self, ComputeUserInfoMaybeEndpoint},
30 : config::AuthenticationConfig,
31 : console::{
32 : self,
33 : provider::{CachedAllowedIps, CachedNodeInfo},
34 : Api,
35 : },
36 : stream, url,
37 : };
38 : use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
39 :
/// Either an owned `T` or a shared borrow of one.
///
/// Plays the role of [`std::borrow::Cow`] without requiring `T: ToOwned`,
/// because nothing here ever needs to turn a borrow into an owned value.
pub enum MaybeOwned<'a, T> {
    /// The value is owned by this wrapper.
    Owned(T),
    /// The value lives elsewhere for at least `'a`.
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Gives access to the wrapped value regardless of who owns it.
    fn deref(&self) -> &T {
        match *self {
            Self::Owned(ref value) => value,
            Self::Borrowed(value) => value,
        }
    }
}
56 :
57 : /// This type serves two purposes:
58 : ///
59 : /// * When `T` is `()`, it's just a regular auth backend selector
60 : /// which we use in [`crate::config::ProxyConfig`].
61 : ///
62 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
63 : /// this helps us provide the credentials only to those auth
64 : /// backends which require them for the authentication process.
65 : pub enum BackendType<'a, T, D> {
66 : /// Cloud API (V2).
67 : Console(MaybeOwned<'a, ConsoleBackend>, T),
68 : /// Authentication via a web browser.
69 : Link(MaybeOwned<'a, url::ApiUrl>, D),
70 : }
71 :
72 : pub trait TestBackend: Send + Sync + 'static {
73 : fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
74 : fn get_allowed_ips_and_secret(
75 : &self,
76 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
77 : fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
78 : }
79 :
80 : impl std::fmt::Display for BackendType<'_, (), ()> {
81 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 0 : use BackendType::*;
83 0 : match self {
84 0 : Console(api, _) => match &**api {
85 0 : ConsoleBackend::Console(endpoint) => {
86 0 : fmt.debug_tuple("Console").field(&endpoint.url()).finish()
87 : }
88 : #[cfg(any(test, feature = "testing"))]
89 0 : ConsoleBackend::Postgres(endpoint) => {
90 0 : fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
91 : }
92 : #[cfg(test)]
93 0 : ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
94 : },
95 0 : Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
96 : }
97 0 : }
98 : }
99 :
100 : impl<T, D> BackendType<'_, T, D> {
101 : /// Very similar to [`std::option::Option::as_ref`].
102 : /// This helps us pass structured config to async tasks.
103 0 : pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
104 0 : use BackendType::*;
105 0 : match self {
106 0 : Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
107 0 : Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
108 : }
109 0 : }
110 : }
111 :
112 : impl<'a, T, D> BackendType<'a, T, D> {
113 : /// Very similar to [`std::option::Option::map`].
114 : /// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
115 : /// a function to a contained value.
116 0 : pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
117 0 : use BackendType::*;
118 0 : match self {
119 0 : Console(c, x) => Console(c, f(x)),
120 0 : Link(c, x) => Link(c, x),
121 : }
122 0 : }
123 : }
124 : impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
125 : /// Very similar to [`std::option::Option::transpose`].
126 : /// This is most useful for error handling.
127 0 : pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
128 0 : use BackendType::*;
129 0 : match self {
130 0 : Console(c, x) => x.map(|x| Console(c, x)),
131 0 : Link(c, x) => Ok(Link(c, x)),
132 : }
133 0 : }
134 : }
135 :
136 : pub struct ComputeCredentials {
137 : pub info: ComputeUserInfo,
138 : pub keys: ComputeCredentialKeys,
139 : }
140 :
141 : #[derive(Debug, Clone)]
142 : pub struct ComputeUserInfoNoEndpoint {
143 : pub user: RoleName,
144 : pub options: NeonOptions,
145 : }
146 :
147 : #[derive(Debug, Clone)]
148 : pub struct ComputeUserInfo {
149 : pub endpoint: EndpointId,
150 : pub user: RoleName,
151 : pub options: NeonOptions,
152 : }
153 :
154 : impl ComputeUserInfo {
155 4 : pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
156 4 : self.options.get_cache_key(&self.endpoint)
157 4 : }
158 : }
159 :
160 : pub enum ComputeCredentialKeys {
161 : Password(Vec<u8>),
162 : AuthKeys(AuthKeys),
163 : }
164 :
165 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
166 : // user name
167 : type Error = ComputeUserInfoNoEndpoint;
168 :
169 6 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
170 6 : match user_info.endpoint_id {
171 2 : None => Err(ComputeUserInfoNoEndpoint {
172 2 : user: user_info.user,
173 2 : options: user_info.options,
174 2 : }),
175 4 : Some(endpoint) => Ok(ComputeUserInfo {
176 4 : endpoint,
177 4 : user: user_info.user,
178 4 : options: user_info.options,
179 4 : }),
180 : }
181 6 : }
182 : }
183 :
184 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
185 : pub struct MaskedIp(IpAddr);
186 :
187 : impl MaskedIp {
188 30 : fn new(value: IpAddr, prefix: u8) -> Self {
189 30 : match value {
190 22 : IpAddr::V4(v4) => Self(IpAddr::V4(
191 22 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
192 22 : )),
193 8 : IpAddr::V6(v6) => Self(IpAddr::V6(
194 8 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
195 8 : )),
196 : }
197 30 : }
198 : }
199 :
200 : // This can't be just per IP because that would limit some PaaS that share IP addresses
201 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
202 :
203 : impl RateBucketInfo {
204 : /// All of these are per endpoint-maskedip pair.
205 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
206 : ///
207 : /// First bucket: 1000mcpus total per endpoint-ip pair
208 : /// * 4096000 requests per second with 1 hash rounds.
209 : /// * 1000 requests per second with 4096 hash rounds.
210 : /// * 6.8 requests per second with 600000 hash rounds.
211 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
212 : Self::new(1000 * 4096, Duration::from_secs(1)),
213 : Self::new(600 * 4096, Duration::from_secs(60)),
214 : Self::new(300 * 4096, Duration::from_secs(600)),
215 : ];
216 : }
217 :
218 : impl AuthenticationConfig {
219 6 : pub fn check_rate_limit(
220 6 : &self,
221 6 : ctx: &mut RequestMonitoring,
222 6 : config: &AuthenticationConfig,
223 6 : secret: AuthSecret,
224 6 : endpoint: &EndpointId,
225 6 : is_cleartext: bool,
226 6 : ) -> auth::Result<AuthSecret> {
227 6 : // we have validated the endpoint exists, so let's intern it.
228 6 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
229 :
230 : // only count the full hash count if password hack or websocket flow.
231 : // in other words, if proxy needs to run the hashing
232 6 : let password_weight = if is_cleartext {
233 4 : match &secret {
234 : #[cfg(any(test, feature = "testing"))]
235 0 : AuthSecret::Md5(_) => 1,
236 4 : AuthSecret::Scram(s) => s.iterations + 1,
237 : }
238 : } else {
239 : // validating scram takes just 1 hmac_sha_256 operation.
240 2 : 1
241 : };
242 :
243 6 : let limit_not_exceeded = self.rate_limiter.check(
244 6 : (
245 6 : endpoint_int,
246 6 : MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
247 6 : ),
248 6 : password_weight,
249 6 : );
250 6 :
251 6 : if !limit_not_exceeded {
252 0 : warn!(
253 : enabled = self.rate_limiter_enabled,
254 0 : "rate limiting authentication"
255 : );
256 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
257 0 : Metrics::get()
258 0 : .proxy
259 0 : .endpoints_auth_rate_limits
260 0 : .get_metric()
261 0 : .measure(endpoint);
262 0 :
263 0 : if self.rate_limiter_enabled {
264 0 : return Err(auth::AuthError::too_many_connections());
265 0 : }
266 6 : }
267 :
268 6 : Ok(secret)
269 6 : }
270 : }
271 :
272 : /// True to its name, this function encapsulates our current auth trade-offs.
273 : /// Here, we choose the appropriate auth flow based on circumstances.
274 : ///
275 : /// All authentication flows will emit an AuthenticationOk message if successful.
276 6 : async fn auth_quirks(
277 6 : ctx: &mut RequestMonitoring,
278 6 : api: &impl console::Api,
279 6 : user_info: ComputeUserInfoMaybeEndpoint,
280 6 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
281 6 : allow_cleartext: bool,
282 6 : config: &'static AuthenticationConfig,
283 6 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
284 6 : ) -> auth::Result<ComputeCredentials> {
285 : // If there's no project so far, that entails that client doesn't
286 : // support SNI or other means of passing the endpoint (project) name.
287 : // We now expect to see a very specific payload in the place of password.
288 6 : let (info, unauthenticated_password) = match user_info.try_into() {
289 2 : Err(info) => {
290 2 : let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
291 :
292 2 : ctx.set_endpoint_id(res.info.endpoint.clone());
293 2 : let password = match res.keys {
294 2 : ComputeCredentialKeys::Password(p) => p,
295 0 : _ => unreachable!("password hack should return a password"),
296 : };
297 2 : (res.info, Some(password))
298 : }
299 4 : Ok(info) => (info, None),
300 : };
301 :
302 6 : info!("fetching user's authentication info");
303 6 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
304 :
305 : // check allowed list
306 6 : if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
307 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
308 6 : }
309 6 :
310 6 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
311 0 : return Err(AuthError::too_many_connections());
312 6 : }
313 6 : let cached_secret = match maybe_secret {
314 6 : Some(secret) => secret,
315 0 : None => api.get_role_secret(ctx, &info).await?,
316 : };
317 6 : let (cached_entry, secret) = cached_secret.take_value();
318 :
319 6 : let secret = match secret {
320 6 : Some(secret) => config.check_rate_limit(
321 6 : ctx,
322 6 : config,
323 6 : secret,
324 6 : &info.endpoint,
325 6 : unauthenticated_password.is_some() || allow_cleartext,
326 0 : )?,
327 : None => {
328 : // If we don't have an authentication secret, we mock one to
329 : // prevent malicious probing (possible due to missing protocol steps).
330 : // This mocked secret will never lead to successful authentication.
331 0 : info!("authentication info not found, mocking it");
332 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
333 : }
334 : };
335 :
336 6 : match authenticate_with_secret(
337 6 : ctx,
338 6 : secret,
339 6 : info,
340 6 : client,
341 6 : unauthenticated_password,
342 6 : allow_cleartext,
343 6 : config,
344 6 : )
345 10 : .await
346 : {
347 6 : Ok(keys) => Ok(keys),
348 0 : Err(e) => {
349 0 : if e.is_auth_failed() {
350 0 : // The password could have been changed, so we invalidate the cache.
351 0 : cached_entry.invalidate();
352 0 : }
353 0 : Err(e)
354 : }
355 : }
356 6 : }
357 :
358 6 : async fn authenticate_with_secret(
359 6 : ctx: &mut RequestMonitoring,
360 6 : secret: AuthSecret,
361 6 : info: ComputeUserInfo,
362 6 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
363 6 : unauthenticated_password: Option<Vec<u8>>,
364 6 : allow_cleartext: bool,
365 6 : config: &'static AuthenticationConfig,
366 6 : ) -> auth::Result<ComputeCredentials> {
367 6 : if let Some(password) = unauthenticated_password {
368 2 : let ep = EndpointIdInt::from(&info.endpoint);
369 :
370 2 : let auth_outcome =
371 2 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
372 2 : let keys = match auth_outcome {
373 2 : crate::sasl::Outcome::Success(key) => key,
374 0 : crate::sasl::Outcome::Failure(reason) => {
375 0 : info!("auth backend failed with an error: {reason}");
376 0 : return Err(auth::AuthError::auth_failed(&*info.user));
377 : }
378 : };
379 :
380 : // we have authenticated the password
381 2 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
382 :
383 2 : return Ok(ComputeCredentials { info, keys });
384 4 : }
385 4 :
386 4 : // -- the remaining flows are self-authenticating --
387 4 :
388 4 : // Perform cleartext auth if we're allowed to do that.
389 4 : // Currently, we use it for websocket connections (latency).
390 4 : if allow_cleartext {
391 2 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
392 4 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
393 2 : }
394 2 :
395 2 : // Finally, proceed with the main auth flow (SCRAM-based).
396 4 : classic::authenticate(ctx, info, client, config, secret).await
397 6 : }
398 :
399 : impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
400 : /// Get compute endpoint name from the credentials.
401 0 : pub fn get_endpoint(&self) -> Option<EndpointId> {
402 0 : use BackendType::*;
403 0 :
404 0 : match self {
405 0 : Console(_, user_info) => user_info.endpoint_id.clone(),
406 0 : Link(_, _) => Some("link".into()),
407 : }
408 0 : }
409 :
410 : /// Get username from the credentials.
411 0 : pub fn get_user(&self) -> &str {
412 0 : use BackendType::*;
413 0 :
414 0 : match self {
415 0 : Console(_, user_info) => &user_info.user,
416 0 : Link(_, _) => "link",
417 : }
418 0 : }
419 :
420 : /// Authenticate the client via the requested backend, possibly using credentials.
421 0 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
422 : pub async fn authenticate(
423 : self,
424 : ctx: &mut RequestMonitoring,
425 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
426 : allow_cleartext: bool,
427 : config: &'static AuthenticationConfig,
428 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
429 : ) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
430 : use BackendType::*;
431 :
432 : let res = match self {
433 : Console(api, user_info) => {
434 : info!(
435 : user = &*user_info.user,
436 : project = user_info.endpoint(),
437 : "performing authentication using the console"
438 : );
439 :
440 : let credentials = auth_quirks(
441 : ctx,
442 : &*api,
443 : user_info,
444 : client,
445 : allow_cleartext,
446 : config,
447 : endpoint_rate_limiter,
448 : )
449 : .await?;
450 : BackendType::Console(api, credentials)
451 : }
452 : // NOTE: this auth backend doesn't use client credentials.
453 : Link(url, _) => {
454 : info!("performing link authentication");
455 :
456 : let info = link::authenticate(ctx, &url, client).await?;
457 :
458 : BackendType::Link(url, info)
459 : }
460 : };
461 :
462 : info!("user successfully authenticated");
463 : Ok(res)
464 : }
465 : }
466 :
467 : impl BackendType<'_, ComputeUserInfo, &()> {
468 0 : pub async fn get_role_secret(
469 0 : &self,
470 0 : ctx: &mut RequestMonitoring,
471 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
472 0 : use BackendType::*;
473 0 : match self {
474 0 : Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
475 0 : Link(_, _) => Ok(Cached::new_uncached(None)),
476 : }
477 0 : }
478 :
479 0 : pub async fn get_allowed_ips_and_secret(
480 0 : &self,
481 0 : ctx: &mut RequestMonitoring,
482 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
483 0 : use BackendType::*;
484 0 : match self {
485 0 : Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
486 0 : Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
487 : }
488 0 : }
489 : }
490 :
491 : #[async_trait::async_trait]
492 : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
493 0 : async fn wake_compute(
494 0 : &self,
495 0 : ctx: &mut RequestMonitoring,
496 0 : ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
497 0 : use BackendType::*;
498 0 :
499 0 : match self {
500 0 : Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
501 0 : Link(_, info) => Ok(Cached::new_uncached(info.clone())),
502 0 : }
503 0 : }
504 :
505 0 : fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
506 0 : match self {
507 0 : BackendType::Console(_, creds) => Some(&creds.keys),
508 0 : BackendType::Link(_, _) => None,
509 : }
510 0 : }
511 : }
512 :
513 : #[async_trait::async_trait]
514 : impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
515 26 : async fn wake_compute(
516 26 : &self,
517 26 : ctx: &mut RequestMonitoring,
518 26 : ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
519 26 : use BackendType::*;
520 26 :
521 26 : match self {
522 26 : Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
523 26 : Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
524 26 : }
525 26 : }
526 :
527 12 : fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
528 12 : match self {
529 12 : BackendType::Console(_, creds) => Some(&creds.keys),
530 0 : BackendType::Link(_, _) => None,
531 : }
532 12 : }
533 : }
534 :
#[cfg(test)]
mod tests {
    use std::{net::IpAddr, sync::Arc, time::Duration};

    use bytes::BytesMut;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::{
        authentication::sasl::{ChannelBinding, ScramSha256},
        message::{backend::Message as PgMessage, frontend},
    };
    use provider::AuthSecret;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use crate::{
        auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
        config::AuthenticationConfig,
        console::{
            self,
            provider::{self, CachedAllowedIps, CachedRoleSecret},
            CachedNodeInfo,
        },
        context::RequestMonitoring,
        proxy::NeonOptions,
        rate_limiter::{EndpointRateLimiter, RateBucketInfo},
        scram::{threadpool::ThreadPool, ServerSecret},
        stream::{PqStream, Stream},
    };

    use super::{auth_quirks, AuthRateLimiter};

    /// Mock console API that serves a fixed IP allow-list and role secret.
    struct Auth {
        ips: Vec<IpPattern>,
        secret: AuthSecret,
    }

    impl console::Api for Auth {
        async fn get_role_secret(
            &self,
            _ctx: &mut RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
            Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
        }

        async fn get_allowed_ips_and_secret(
            &self,
            _ctx: &mut RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
        {
            Ok((
                CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
                Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
            ))
        }

        async fn wake_compute(
            &self,
            _ctx: &mut RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        rate_limiter_enabled: true,
        rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
        rate_limit_ip_subnet: 64,
    });

    /// Reads from `r` into `b` until one complete backend message can be parsed.
    async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
        loop {
            r.read_buf(&mut *b).await.unwrap();
            if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
                break m;
            }
        }
    }

    /// Mock API with no IP allow-list entries whose role secret is the SCRAM
    /// secret for `"my-secret-password"`.
    async fn mock_api() -> Auth {
        Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        }
    }

    /// User info for role `conrad`, optionally naming an endpoint.
    fn user_info(endpoint_id: Option<&'static str>) -> ComputeUserInfoMaybeEndpoint {
        ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: endpoint_id.map(|ep| ep.into()),
            options: NeonOptions::default(),
        }
    }

    /// A fresh endpoint rate limiter, so tests don't share limiter state.
    fn endpoint_rate_limiter() -> Arc<EndpointRateLimiter> {
        Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET))
    }

    #[test]
    fn masked_ip() {
        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());

        assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
        assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
        assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
        assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));

        assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
        assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
    }

    #[test]
    fn test_default_auth_rate_limit_set() {
        // these values used to exceed u32::MAX
        assert_eq!(
            RateBucketInfo::DEFAULT_AUTH_SET,
            [
                RateBucketInfo {
                    interval: Duration::from_secs(1),
                    max_rpi: 1000 * 4096,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(60),
                    max_rpi: 600 * 4096 * 60,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(600),
                    max_rpi: 300 * 4096 * 600,
                }
            ]
        );

        for x in RateBucketInfo::DEFAULT_AUTH_SET {
            let y = x.to_string().parse().unwrap();
            assert_eq!(x, y);
        }
    }

    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let mut ctx = RequestMonitoring::test();
        let api = mock_api().await;

        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());

            let mut read = BytesMut::new();

            // server should offer scram
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSasl(a) => {
                    let options: Vec<&str> = a.mechanisms().collect().unwrap();
                    assert_eq!(options, ["SCRAM-SHA-256"]);
                }
                _ => panic!("wrong message"),
            }

            // client sends client-first-message
            let mut write = BytesMut::new();
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslContinue(a) => {
                    scram.update(a.data()).await.unwrap();
                }
                _ => panic!("wrong message"),
            }

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslFinal(a) => {
                    scram.finish(a.data()).unwrap();
                }
                _ => panic!("wrong message"),
            }
        });

        let _creds = auth_quirks(
            &mut ctx,
            &api,
            user_info(Some("endpoint")),
            &mut stream,
            false,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let mut ctx = RequestMonitoring::test();
        let api = mock_api().await;

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            let mut write = BytesMut::new();
            frontend::password_message(b"my-secret-password", &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });

        let _creds = auth_quirks(
            &mut ctx,
            &api,
            user_info(Some("endpoint")),
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let mut ctx = RequestMonitoring::test();
        let api = mock_api().await;

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with the endpoint smuggled into the password
            let mut write = BytesMut::new();
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let creds = auth_quirks(
            &mut ctx,
            &api,
            user_info(None),
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|