Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 : use std::time::Duration;
10 :
11 : pub use console_redirect::ConsoleRedirectBackend;
12 : pub(crate) use console_redirect::WebAuthError;
13 : use ipnet::{Ipv4Net, Ipv6Net};
14 : use local::LocalBackend;
15 : use tokio::io::{AsyncRead, AsyncWrite};
16 : use tokio_postgres::config::AuthKeys;
17 : use tracing::{info, warn};
18 :
19 : use crate::auth::credentials::check_peer_addr_is_in_list;
20 : use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
21 : use crate::cache::Cached;
22 : use crate::config::AuthenticationConfig;
23 : use crate::context::RequestMonitoring;
24 : use crate::control_plane::errors::GetAuthInfoError;
25 : use crate::control_plane::provider::{
26 : CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
27 : };
28 : use crate::control_plane::{self, Api, AuthSecret};
29 : use crate::intern::EndpointIdInt;
30 : use crate::metrics::Metrics;
31 : use crate::proxy::connect_compute::ComputeConnectBackend;
32 : use crate::proxy::NeonOptions;
33 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
34 : use crate::stream::Stream;
35 : use crate::{scram, stream, EndpointCacheKey, EndpointId, RoleName};
36 :
/// A value that is either held directly or borrowed for `'a`.
///
/// Fills the same role as [`std::borrow::Cow`], but without requiring
/// `T: ToOwned`: nothing here ever converts the borrow into an owned value.
pub enum MaybeOwned<'a, T> {
    Owned(T),
    Borrowed(&'a T),
}

/// Both variants dereference to the same `T`, so users of a `MaybeOwned`
/// do not need to know which variant they hold.
impl<'a, T> std::ops::Deref for MaybeOwned<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match *self {
            Self::Owned(ref owned) => owned,
            Self::Borrowed(borrowed) => borrowed,
        }
    }
}
53 :
54 : /// This type serves two purposes:
55 : ///
56 : /// * When `T` is `()`, it's just a regular auth backend selector
57 : /// which we use in [`crate::config::ProxyConfig`].
58 : ///
59 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
60 : /// this helps us provide the credentials only to those auth
61 : /// backends which require them for the authentication process.
///
/// Only the `ControlPlane` variant carries a `T`; `Local` has no payload of
/// that type. The backend itself is stored as a [`MaybeOwned`], so the same
/// enum can either own its backend or borrow one for `'a`.
62 : pub enum Backend<'a, T> {
63 : /// Cloud API (V2).
64 : ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
65 : /// Local proxy uses configured auth credentials and does not wake compute
66 : Local(MaybeOwned<'a, LocalBackend>),
67 : }
68 :
/// Test-only, object-safe stand-in for a control plane backend.
///
/// Unlike the real API calls, these methods are synchronous and take no
/// request context or user info.
69 : #[cfg(test)]
70 : pub(crate) trait TestBackend: Send + Sync + 'static {
/// Produces the node info a wake-compute call should yield.
71 : fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
/// Produces the allowed-IP list and, optionally, the role secret.
72 : fn get_allowed_ips_and_secret(
73 : &self,
74 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), control_plane::errors::GetAuthInfoError>;
/// Clones `self` behind a fresh box; this is what backs `Clone for Box<dyn TestBackend>`.
75 : fn dyn_clone(&self) -> Box<dyn TestBackend>;
76 : }
77 :
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
    /// Clones the boxed trait object by delegating to
    /// [`TestBackend::dyn_clone`] on the value inside the box.
    fn clone(&self) -> Self {
        let inner: &dyn TestBackend = &**self;
        inner.dyn_clone()
    }
}
84 :
85 : impl std::fmt::Display for Backend<'_, ()> {
86 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 0 : match self {
88 0 : Self::ControlPlane(api, ()) => match &**api {
89 0 : ControlPlaneBackend::Management(endpoint) => fmt
90 0 : .debug_tuple("ControlPlane::Management")
91 0 : .field(&endpoint.url())
92 0 : .finish(),
93 : #[cfg(any(test, feature = "testing"))]
94 0 : ControlPlaneBackend::PostgresMock(endpoint) => fmt
95 0 : .debug_tuple("ControlPlane::PostgresMock")
96 0 : .field(&endpoint.url())
97 0 : .finish(),
98 : #[cfg(test)]
99 0 : ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
100 : },
101 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
102 : }
103 0 : }
104 : }
105 :
106 : impl<T> Backend<'_, T> {
107 : /// Very similar to [`std::option::Option::as_ref`].
108 : /// This helps us pass structured config to async tasks.
109 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
110 0 : match self {
111 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
112 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
113 : }
114 0 : }
115 : }
116 :
117 : impl<'a, T> Backend<'a, T> {
118 : /// Very similar to [`std::option::Option::map`].
119 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
120 : /// a function to a contained value.
121 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
122 0 : match self {
123 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
124 0 : Self::Local(l) => Backend::Local(l),
125 : }
126 0 : }
127 : }
128 : impl<'a, T, E> Backend<'a, Result<T, E>> {
129 : /// Very similar to [`std::option::Option::transpose`].
130 : /// This is most useful for error handling.
131 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
132 0 : match self {
133 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
134 0 : Self::Local(l) => Ok(Backend::Local(l)),
135 : }
136 0 : }
137 : }
138 :
/// What the authentication flows in this module return on success:
/// the resolved user info together with the key material obtained for it.
139 : pub(crate) struct ComputeCredentials {
/// The user, endpoint and options the keys were obtained for.
140 : pub(crate) info: ComputeUserInfo,
/// Key material produced by the authentication flow.
141 : pub(crate) keys: ComputeCredentialKeys,
142 : }
143 :
/// User info for which no endpoint is known.
///
/// This is the error value of converting a [`ComputeUserInfoMaybeEndpoint`]
/// whose `endpoint_id` is `None` into a [`ComputeUserInfo`].
144 : #[derive(Debug, Clone)]
145 : pub(crate) struct ComputeUserInfoNoEndpoint {
146 : pub(crate) user: RoleName,
147 : pub(crate) options: NeonOptions,
148 : }
149 :
150 : #[derive(Debug, Clone)]
151 : pub(crate) struct ComputeUserInfo {
152 : pub(crate) endpoint: EndpointId,
153 : pub(crate) user: RoleName,
154 : pub(crate) options: NeonOptions,
155 : }
156 :
157 : impl ComputeUserInfo {
158 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
159 2 : self.options.get_cache_key(&self.endpoint)
160 2 : }
161 : }
162 :
/// Key material carried by [`ComputeCredentials`].
163 : #[cfg_attr(test, derive(Debug))]
164 : pub(crate) enum ComputeCredentialKeys {
/// Raw password bytes; only compiled for tests and the `testing` feature.
165 : #[cfg(any(test, feature = "testing"))]
166 : Password(Vec<u8>),
167 : AuthKeys(AuthKeys),
168 : JwtPayload(Vec<u8>),
/// No keys at all; this is what the `Local` backend reports from `get_keys`.
169 : None,
170 : }
171 :
172 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
173 : // user name
174 : type Error = ComputeUserInfoNoEndpoint;
175 :
176 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
177 3 : match user_info.endpoint_id {
178 1 : None => Err(ComputeUserInfoNoEndpoint {
179 1 : user: user_info.user,
180 1 : options: user_info.options,
181 1 : }),
182 2 : Some(endpoint) => Ok(ComputeUserInfo {
183 2 : endpoint,
184 2 : user: user_info.user,
185 2 : options: user_info.options,
186 2 : }),
187 : }
188 3 : }
189 : }
190 :
191 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
192 : pub struct MaskedIp(IpAddr);
193 :
194 : impl MaskedIp {
195 15 : fn new(value: IpAddr, prefix: u8) -> Self {
196 15 : match value {
197 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
198 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
199 11 : )),
200 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
201 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
202 4 : )),
203 : }
204 15 : }
205 : }
206 :
/// Rate limiter for authentication attempts, keyed by (endpoint, masked client IP).
207 : // This can't be just per IP because that would limit some PaaS that share IP addresses
208 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
209 :
210 : impl RateBucketInfo {
211 : /// All of these are per endpoint-maskedip pair.
212 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
213 : ///
214 : /// First bucket: 1000mcpus total per endpoint-ip pair
215 : /// * 4096000 requests per second with 1 hash rounds.
216 : /// * 1000 requests per second with 4096 hash rounds.
217 : /// * 6.8 requests per second with 600000 hash rounds.
///
/// Second bucket: a rate of `600 * 4096` per second over a 60 second interval.
///
/// Third bucket: a rate of `300 * 4096` per second over a 600 second interval.
///
/// The first argument of `Self::new` is a per-second rate: the unit test for
/// this constant expects `max_rpi` to be that rate multiplied by the
/// interval length in seconds.
218 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
219 : Self::new(1000 * 4096, Duration::from_secs(1)),
220 : Self::new(600 * 4096, Duration::from_secs(60)),
221 : Self::new(300 * 4096, Duration::from_secs(600)),
222 : ];
223 : }
224 :
225 : impl AuthenticationConfig {
226 3 : pub(crate) fn check_rate_limit(
227 3 : &self,
228 3 : ctx: &RequestMonitoring,
229 3 : secret: AuthSecret,
230 3 : endpoint: &EndpointId,
231 3 : is_cleartext: bool,
232 3 : ) -> auth::Result<AuthSecret> {
233 3 : // we have validated the endpoint exists, so let's intern it.
234 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
235 :
236 : // only count the full hash count if password hack or websocket flow.
237 : // in other words, if proxy needs to run the hashing
238 3 : let password_weight = if is_cleartext {
239 2 : match &secret {
240 : #[cfg(any(test, feature = "testing"))]
241 0 : AuthSecret::Md5(_) => 1,
242 2 : AuthSecret::Scram(s) => s.iterations + 1,
243 : }
244 : } else {
245 : // validating scram takes just 1 hmac_sha_256 operation.
246 1 : 1
247 : };
248 :
249 3 : let limit_not_exceeded = self.rate_limiter.check(
250 3 : (
251 3 : endpoint_int,
252 3 : MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
253 3 : ),
254 3 : password_weight,
255 3 : );
256 3 :
257 3 : if !limit_not_exceeded {
258 0 : warn!(
259 : enabled = self.rate_limiter_enabled,
260 0 : "rate limiting authentication"
261 : );
262 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
263 0 : Metrics::get()
264 0 : .proxy
265 0 : .endpoints_auth_rate_limits
266 0 : .get_metric()
267 0 : .measure(endpoint);
268 0 :
269 0 : if self.rate_limiter_enabled {
270 0 : return Err(auth::AuthError::too_many_connections());
271 0 : }
272 3 : }
273 :
274 3 : Ok(secret)
275 3 : }
276 : }
277 :
278 : /// True to its name, this function encapsulates our current auth trade-offs.
279 : /// Here, we choose the appropriate auth flow based on circumstances.
280 : ///
281 : /// All authentication flows will emit an AuthenticationOk message if successful.
282 3 : async fn auth_quirks(
283 3 : ctx: &RequestMonitoring,
284 3 : api: &impl control_plane::Api,
285 3 : user_info: ComputeUserInfoMaybeEndpoint,
286 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
287 3 : allow_cleartext: bool,
288 3 : config: &'static AuthenticationConfig,
289 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
290 3 : ) -> auth::Result<ComputeCredentials> {
291 : // If there's no project so far, that entails that client doesn't
292 : // support SNI or other means of passing the endpoint (project) name.
293 : // We now expect to see a very specific payload in the place of password.
294 3 : let (info, unauthenticated_password) = match user_info.try_into() {
295 1 : Err(info) => {
296 1 : let (info, password) =
297 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
298 1 : ctx.set_endpoint_id(info.endpoint.clone());
299 1 : (info, Some(password))
300 : }
301 2 : Ok(info) => (info, None),
302 : };
303 :
304 3 : info!("fetching user's authentication info");
305 3 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
306 :
307 : // check allowed list
308 3 : if config.ip_allowlist_check_enabled
309 3 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
310 : {
311 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
312 3 : }
313 3 :
314 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
315 0 : return Err(AuthError::too_many_connections());
316 3 : }
317 3 : let cached_secret = match maybe_secret {
318 3 : Some(secret) => secret,
319 0 : None => api.get_role_secret(ctx, &info).await?,
320 : };
321 3 : let (cached_entry, secret) = cached_secret.take_value();
322 :
323 3 : let secret = if let Some(secret) = secret {
324 3 : config.check_rate_limit(
325 3 : ctx,
326 3 : secret,
327 3 : &info.endpoint,
328 3 : unauthenticated_password.is_some() || allow_cleartext,
329 0 : )?
330 : } else {
331 : // If we don't have an authentication secret, we mock one to
332 : // prevent malicious probing (possible due to missing protocol steps).
333 : // This mocked secret will never lead to successful authentication.
334 0 : info!("authentication info not found, mocking it");
335 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
336 : };
337 :
338 3 : match authenticate_with_secret(
339 3 : ctx,
340 3 : secret,
341 3 : info,
342 3 : client,
343 3 : unauthenticated_password,
344 3 : allow_cleartext,
345 3 : config,
346 3 : )
347 5 : .await
348 : {
349 3 : Ok(keys) => Ok(keys),
350 0 : Err(e) => {
351 0 : if e.is_auth_failed() {
352 0 : // The password could have been changed, so we invalidate the cache.
353 0 : cached_entry.invalidate();
354 0 : }
355 0 : Err(e)
356 : }
357 : }
358 3 : }
359 :
360 3 : async fn authenticate_with_secret(
361 3 : ctx: &RequestMonitoring,
362 3 : secret: AuthSecret,
363 3 : info: ComputeUserInfo,
364 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
365 3 : unauthenticated_password: Option<Vec<u8>>,
366 3 : allow_cleartext: bool,
367 3 : config: &'static AuthenticationConfig,
368 3 : ) -> auth::Result<ComputeCredentials> {
369 3 : if let Some(password) = unauthenticated_password {
370 1 : let ep = EndpointIdInt::from(&info.endpoint);
371 :
372 1 : let auth_outcome =
373 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
374 1 : let keys = match auth_outcome {
375 1 : crate::sasl::Outcome::Success(key) => key,
376 0 : crate::sasl::Outcome::Failure(reason) => {
377 0 : info!("auth backend failed with an error: {reason}");
378 0 : return Err(auth::AuthError::auth_failed(&*info.user));
379 : }
380 : };
381 :
382 : // we have authenticated the password
383 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
384 :
385 1 : return Ok(ComputeCredentials { info, keys });
386 2 : }
387 2 :
388 2 : // -- the remaining flows are self-authenticating --
389 2 :
390 2 : // Perform cleartext auth if we're allowed to do that.
391 2 : // Currently, we use it for websocket connections (latency).
392 2 : if allow_cleartext {
393 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
394 2 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
395 1 : }
396 1 :
397 1 : // Finally, proceed with the main auth flow (SCRAM-based).
398 2 : classic::authenticate(ctx, info, client, config, secret).await
399 3 : }
400 :
401 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
402 : /// Get username from the credentials.
403 0 : pub(crate) fn get_user(&self) -> &str {
404 0 : match self {
405 0 : Self::ControlPlane(_, user_info) => &user_info.user,
406 0 : Self::Local(_) => "local",
407 : }
408 0 : }
409 :
410 : /// Authenticate the client via the requested backend, possibly using credentials.
411 0 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
412 : pub(crate) async fn authenticate(
413 : self,
414 : ctx: &RequestMonitoring,
415 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
416 : allow_cleartext: bool,
417 : config: &'static AuthenticationConfig,
418 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
419 : ) -> auth::Result<Backend<'a, ComputeCredentials>> {
420 : let res = match self {
421 : Self::ControlPlane(api, user_info) => {
422 : info!(
423 : user = &*user_info.user,
424 : project = user_info.endpoint(),
425 : "performing authentication using the console"
426 : );
427 :
428 : let credentials = auth_quirks(
429 : ctx,
430 : &*api,
431 : user_info,
432 : client,
433 : allow_cleartext,
434 : config,
435 : endpoint_rate_limiter,
436 : )
437 : .await?;
438 : Backend::ControlPlane(api, credentials)
439 : }
440 : Self::Local(_) => {
441 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
442 : }
443 : };
444 :
445 : info!("user successfully authenticated");
446 : Ok(res)
447 : }
448 : }
449 :
450 : impl Backend<'_, ComputeUserInfo> {
451 0 : pub(crate) async fn get_role_secret(
452 0 : &self,
453 0 : ctx: &RequestMonitoring,
454 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
455 0 : match self {
456 0 : Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
457 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
458 : }
459 0 : }
460 :
461 0 : pub(crate) async fn get_allowed_ips_and_secret(
462 0 : &self,
463 0 : ctx: &RequestMonitoring,
464 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
465 0 : match self {
466 0 : Self::ControlPlane(api, user_info) => {
467 0 : api.get_allowed_ips_and_secret(ctx, user_info).await
468 : }
469 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
470 : }
471 0 : }
472 : }
473 :
474 : #[async_trait::async_trait]
475 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
476 13 : async fn wake_compute(
477 13 : &self,
478 13 : ctx: &RequestMonitoring,
479 13 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
480 13 : match self {
481 13 : Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
482 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
483 : }
484 26 : }
485 :
486 6 : fn get_keys(&self) -> &ComputeCredentialKeys {
487 6 : match self {
488 6 : Self::ControlPlane(_, creds) => &creds.keys,
489 0 : Self::Local(_) => &ComputeCredentialKeys::None,
490 : }
491 6 : }
492 : }
493 :
494 : #[cfg(test)]
495 : mod tests {
496 : use std::net::IpAddr;
497 : use std::sync::Arc;
498 : use std::time::Duration;
499 :
500 : use bytes::BytesMut;
501 : use fallible_iterator::FallibleIterator;
502 : use once_cell::sync::Lazy;
503 : use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
504 : use postgres_protocol::message::backend::Message as PgMessage;
505 : use postgres_protocol::message::frontend;
506 : use provider::AuthSecret;
507 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
508 :
509 : use super::jwt::JwkCache;
510 : use super::{auth_quirks, AuthRateLimiter};
511 : use crate::auth::backend::MaskedIp;
512 : use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
513 : use crate::config::AuthenticationConfig;
514 : use crate::context::RequestMonitoring;
515 : use crate::control_plane::provider::{self, CachedAllowedIps, CachedRoleSecret};
516 : use crate::control_plane::{self, CachedNodeInfo};
517 : use crate::proxy::NeonOptions;
518 : use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
519 : use crate::scram::threadpool::ThreadPool;
520 : use crate::scram::ServerSecret;
521 : use crate::stream::{PqStream, Stream};
522 :
/// Control plane test double that answers every lookup from fixed data.
523 : struct Auth {
/// Allowed-IP patterns returned by `get_allowed_ips_and_secret`.
524 : ips: Vec<IpPattern>,
/// Secret returned for every role.
525 : secret: AuthSecret,
526 : }
527 :
528 : impl control_plane::Api for Auth {
529 0 : async fn get_role_secret(
530 0 : &self,
531 0 : _ctx: &RequestMonitoring,
532 0 : _user_info: &super::ComputeUserInfo,
533 0 : ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
534 0 : Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
535 0 : }
536 :
537 3 : async fn get_allowed_ips_and_secret(
538 3 : &self,
539 3 : _ctx: &RequestMonitoring,
540 3 : _user_info: &super::ComputeUserInfo,
541 3 : ) -> Result<
542 3 : (CachedAllowedIps, Option<CachedRoleSecret>),
543 3 : control_plane::errors::GetAuthInfoError,
544 3 : > {
545 3 : Ok((
546 3 : CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
547 3 : Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
548 3 : ))
549 3 : }
550 :
551 0 : async fn get_endpoint_jwks(
552 0 : &self,
553 0 : _ctx: &RequestMonitoring,
554 0 : _endpoint: crate::EndpointId,
555 0 : ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
556 0 : {
557 0 : unimplemented!()
558 : }
559 :
560 0 : async fn wake_compute(
561 0 : &self,
562 0 : _ctx: &RequestMonitoring,
563 0 : _user_info: &super::ComputeUserInfo,
564 0 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
565 0 : unimplemented!()
566 : }
567 : }
568 :
569 3 : static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
570 3 : jwks_cache: JwkCache::default(),
571 3 : thread_pool: ThreadPool::new(1),
572 3 : scram_protocol_timeout: std::time::Duration::from_secs(5),
573 3 : rate_limiter_enabled: true,
574 3 : rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
575 3 : rate_limit_ip_subnet: 64,
576 3 : ip_allowlist_check_enabled: true,
577 3 : is_auth_broker: false,
578 3 : accept_jwts: false,
579 3 : webauth_confirmation_timeout: std::time::Duration::from_secs(5),
580 3 : });
581 :
582 5 : async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
583 : loop {
584 7 : r.read_buf(&mut *b).await.unwrap();
585 7 : if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
586 5 : break m;
587 2 : }
588 : }
589 5 : }
590 :
591 : #[test]
592 1 : fn masked_ip() {
593 1 : let ip_a = IpAddr::V4([127, 0, 0, 1].into());
594 1 : let ip_b = IpAddr::V4([127, 0, 0, 2].into());
595 1 : let ip_c = IpAddr::V4([192, 168, 1, 101].into());
596 1 : let ip_d = IpAddr::V4([192, 168, 1, 102].into());
597 1 : let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
598 1 : let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
599 1 :
600 1 : assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
601 1 : assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
602 1 : assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
603 1 : assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
604 :
605 1 : assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
606 1 : assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
607 1 : }
608 :
609 : #[test]
610 1 : fn test_default_auth_rate_limit_set() {
611 1 : // these values used to exceed u32::MAX
612 1 : assert_eq!(
613 1 : RateBucketInfo::DEFAULT_AUTH_SET,
614 1 : [
615 1 : RateBucketInfo {
616 1 : interval: Duration::from_secs(1),
617 1 : max_rpi: 1000 * 4096,
618 1 : },
619 1 : RateBucketInfo {
620 1 : interval: Duration::from_secs(60),
621 1 : max_rpi: 600 * 4096 * 60,
622 1 : },
623 1 : RateBucketInfo {
624 1 : interval: Duration::from_secs(600),
625 1 : max_rpi: 300 * 4096 * 600,
626 1 : }
627 1 : ]
628 1 : );
629 :
630 4 : for x in RateBucketInfo::DEFAULT_AUTH_SET {
631 3 : let y = x.to_string().parse().unwrap();
632 3 : assert_eq!(x, y);
633 : }
634 1 : }
635 :
636 : #[tokio::test]
637 1 : async fn auth_quirks_scram() {
638 1 : let (mut client, server) = tokio::io::duplex(1024);
639 1 : let mut stream = PqStream::new(Stream::from_raw(server));
640 1 :
641 1 : let ctx = RequestMonitoring::test();
642 1 : let api = Auth {
643 1 : ips: vec![],
644 3 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
645 1 : };
646 1 :
647 1 : let user_info = ComputeUserInfoMaybeEndpoint {
648 1 : user: "conrad".into(),
649 1 : endpoint_id: Some("endpoint".into()),
650 1 : options: NeonOptions::default(),
651 1 : };
652 1 :
653 1 : let handle = tokio::spawn(async move {
654 1 : let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
655 1 :
656 1 : let mut read = BytesMut::new();
657 1 :
658 1 : // server should offer scram
659 1 : match read_message(&mut client, &mut read).await {
660 1 : PgMessage::AuthenticationSasl(a) => {
661 1 : let options: Vec<&str> = a.mechanisms().collect().unwrap();
662 1 : assert_eq!(options, ["SCRAM-SHA-256"]);
663 1 : }
664 1 : _ => panic!("wrong message"),
665 1 : }
666 1 :
667 1 : // client sends client-first-message
668 1 : let mut write = BytesMut::new();
669 1 : frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
670 1 : client.write_all(&write).await.unwrap();
671 1 :
672 1 : // server response with server-first-message
673 1 : match read_message(&mut client, &mut read).await {
674 1 : PgMessage::AuthenticationSaslContinue(a) => {
675 3 : scram.update(a.data()).await.unwrap();
676 1 : }
677 1 : _ => panic!("wrong message"),
678 1 : }
679 1 :
680 1 : // client response with client-final-message
681 1 : write.clear();
682 1 : frontend::sasl_response(scram.message(), &mut write).unwrap();
683 1 : client.write_all(&write).await.unwrap();
684 1 :
685 1 : // server response with server-final-message
686 1 : match read_message(&mut client, &mut read).await {
687 1 : PgMessage::AuthenticationSaslFinal(a) => {
688 1 : scram.finish(a.data()).unwrap();
689 1 : }
690 1 : _ => panic!("wrong message"),
691 1 : }
692 1 : });
693 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
694 1 : EndpointRateLimiter::DEFAULT,
695 1 : 64,
696 1 : ));
697 1 :
698 1 : let _creds = auth_quirks(
699 1 : &ctx,
700 1 : &api,
701 1 : user_info,
702 1 : &mut stream,
703 1 : false,
704 1 : &CONFIG,
705 1 : endpoint_rate_limiter,
706 1 : )
707 2 : .await
708 1 : .unwrap();
709 1 :
710 1 : handle.await.unwrap();
711 1 : }
712 :
713 : #[tokio::test]
714 1 : async fn auth_quirks_cleartext() {
715 1 : let (mut client, server) = tokio::io::duplex(1024);
716 1 : let mut stream = PqStream::new(Stream::from_raw(server));
717 1 :
718 1 : let ctx = RequestMonitoring::test();
719 1 : let api = Auth {
720 1 : ips: vec![],
721 3 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
722 1 : };
723 1 :
724 1 : let user_info = ComputeUserInfoMaybeEndpoint {
725 1 : user: "conrad".into(),
726 1 : endpoint_id: Some("endpoint".into()),
727 1 : options: NeonOptions::default(),
728 1 : };
729 1 :
730 1 : let handle = tokio::spawn(async move {
731 1 : let mut read = BytesMut::new();
732 1 : let mut write = BytesMut::new();
733 1 :
734 1 : // server should offer cleartext
735 1 : match read_message(&mut client, &mut read).await {
736 1 : PgMessage::AuthenticationCleartextPassword => {}
737 1 : _ => panic!("wrong message"),
738 1 : }
739 1 :
740 1 : // client responds with password
741 1 : write.clear();
742 1 : frontend::password_message(b"my-secret-password", &mut write).unwrap();
743 1 : client.write_all(&write).await.unwrap();
744 1 : });
745 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
746 1 : EndpointRateLimiter::DEFAULT,
747 1 : 64,
748 1 : ));
749 1 :
750 1 : let _creds = auth_quirks(
751 1 : &ctx,
752 1 : &api,
753 1 : user_info,
754 1 : &mut stream,
755 1 : true,
756 1 : &CONFIG,
757 1 : endpoint_rate_limiter,
758 1 : )
759 2 : .await
760 1 : .unwrap();
761 1 :
762 1 : handle.await.unwrap();
763 1 : }
764 :
765 : #[tokio::test]
766 1 : async fn auth_quirks_password_hack() {
767 1 : let (mut client, server) = tokio::io::duplex(1024);
768 1 : let mut stream = PqStream::new(Stream::from_raw(server));
769 1 :
770 1 : let ctx = RequestMonitoring::test();
771 1 : let api = Auth {
772 1 : ips: vec![],
773 3 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
774 1 : };
775 1 :
776 1 : let user_info = ComputeUserInfoMaybeEndpoint {
777 1 : user: "conrad".into(),
778 1 : endpoint_id: None,
779 1 : options: NeonOptions::default(),
780 1 : };
781 1 :
782 1 : let handle = tokio::spawn(async move {
783 1 : let mut read = BytesMut::new();
784 1 :
785 1 : // server should offer cleartext
786 1 : match read_message(&mut client, &mut read).await {
787 1 : PgMessage::AuthenticationCleartextPassword => {}
788 1 : _ => panic!("wrong message"),
789 1 : }
790 1 :
791 1 : // client responds with password
792 1 : let mut write = BytesMut::new();
793 1 : frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
794 1 : .unwrap();
795 1 : client.write_all(&write).await.unwrap();
796 1 : });
797 1 :
798 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
799 1 : EndpointRateLimiter::DEFAULT,
800 1 : 64,
801 1 : ));
802 1 :
803 1 : let creds = auth_quirks(
804 1 : &ctx,
805 1 : &api,
806 1 : user_info,
807 1 : &mut stream,
808 1 : true,
809 1 : &CONFIG,
810 1 : endpoint_rate_limiter,
811 1 : )
812 2 : .await
813 1 : .unwrap();
814 1 :
815 1 : assert_eq!(creds.info.endpoint, "my-endpoint");
816 1 :
817 1 : handle.await.unwrap();
818 1 : }
819 : }
|