Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 :
10 : pub use console_redirect::ConsoleRedirectBackend;
11 : pub(crate) use console_redirect::ConsoleRedirectError;
12 : use ipnet::{Ipv4Net, Ipv6Net};
13 : use local::LocalBackend;
14 : use postgres_client::config::AuthKeys;
15 : use serde::{Deserialize, Serialize};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tracing::{debug, info, warn};
18 :
19 : use crate::auth::credentials::check_peer_addr_is_in_list;
20 : use crate::auth::{
21 : self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern,
22 : };
23 : use crate::cache::Cached;
24 : use crate::config::AuthenticationConfig;
25 : use crate::context::RequestContext;
26 : use crate::control_plane::client::ControlPlaneClient;
27 : use crate::control_plane::errors::GetAuthInfoError;
28 : use crate::control_plane::{
29 : self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
30 : };
31 : use crate::intern::EndpointIdInt;
32 : use crate::metrics::Metrics;
33 : use crate::proxy::connect_compute::ComputeConnectBackend;
34 : use crate::proxy::NeonOptions;
35 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
36 : use crate::stream::Stream;
37 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
38 : use crate::{scram, stream};
39 :
/// A lightweight ownership wrapper: holds either a `T` by value or a shared
/// reference to one.
///
/// It plays the role of [`std::borrow::Cow`] without requiring `T: ToOwned`,
/// because nothing here ever needs to turn a borrowed value into an owned one.
pub enum MaybeOwned<'a, T> {
    /// The wrapper owns the value.
    Owned(T),
    /// The wrapper refers to a value owned elsewhere.
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Yields a shared reference to the wrapped value, however it is held.
    fn deref(&self) -> &Self::Target {
        match *self {
            Self::Owned(ref value) => value,
            Self::Borrowed(value) => value,
        }
    }
}
56 :
57 : /// This type serves two purposes:
58 : ///
59 : /// * When `T` is `()`, it's just a regular auth backend selector
60 : /// which we use in [`crate::config::ProxyConfig`].
61 : ///
62 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
63 : /// this helps us provide the credentials only to those auth
64 : /// backends which require them for the authentication process.
65 : pub enum Backend<'a, T> {
66 : /// Cloud API (V2).
67 : ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
68 : /// Local proxy uses configured auth credentials and does not wake compute
69 : Local(MaybeOwned<'a, LocalBackend>),
70 : }
71 :
72 : impl std::fmt::Display for Backend<'_, ()> {
73 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 0 : match self {
75 0 : Self::ControlPlane(api, ()) => match &**api {
76 0 : ControlPlaneClient::ProxyV1(endpoint) => fmt
77 0 : .debug_tuple("ControlPlane::ProxyV1")
78 0 : .field(&endpoint.url())
79 0 : .finish(),
80 : #[cfg(any(test, feature = "testing"))]
81 0 : ControlPlaneClient::PostgresMock(endpoint) => fmt
82 0 : .debug_tuple("ControlPlane::PostgresMock")
83 0 : .field(&endpoint.url())
84 0 : .finish(),
85 : #[cfg(test)]
86 0 : ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
87 : },
88 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
89 : }
90 0 : }
91 : }
92 :
93 : impl<T> Backend<'_, T> {
94 : /// Very similar to [`std::option::Option::as_ref`].
95 : /// This helps us pass structured config to async tasks.
96 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
97 0 : match self {
98 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
99 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
100 : }
101 0 : }
102 : }
103 :
104 : impl<'a, T> Backend<'a, T> {
105 : /// Very similar to [`std::option::Option::map`].
106 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
107 : /// a function to a contained value.
108 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
109 0 : match self {
110 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
111 0 : Self::Local(l) => Backend::Local(l),
112 : }
113 0 : }
114 : }
115 : impl<'a, T, E> Backend<'a, Result<T, E>> {
116 : /// Very similar to [`std::option::Option::transpose`].
117 : /// This is most useful for error handling.
118 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
119 0 : match self {
120 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
121 0 : Self::Local(l) => Ok(Backend::Local(l)),
122 : }
123 0 : }
124 : }
125 :
126 : pub(crate) struct ComputeCredentials {
127 : pub(crate) info: ComputeUserInfo,
128 : pub(crate) keys: ComputeCredentialKeys,
129 : }
130 :
131 : #[derive(Debug, Clone)]
132 : pub(crate) struct ComputeUserInfoNoEndpoint {
133 : pub(crate) user: RoleName,
134 : pub(crate) options: NeonOptions,
135 : }
136 :
137 0 : #[derive(Debug, Clone, Default, Serialize, Deserialize)]
138 : pub(crate) struct ComputeUserInfo {
139 : pub(crate) endpoint: EndpointId,
140 : pub(crate) user: RoleName,
141 : pub(crate) options: NeonOptions,
142 : }
143 :
144 : impl ComputeUserInfo {
145 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
146 2 : self.options.get_cache_key(&self.endpoint)
147 2 : }
148 : }
149 :
150 : #[cfg_attr(test, derive(Debug))]
151 : pub(crate) enum ComputeCredentialKeys {
152 : #[cfg(any(test, feature = "testing"))]
153 : Password(Vec<u8>),
154 : AuthKeys(AuthKeys),
155 : JwtPayload(Vec<u8>),
156 : None,
157 : }
158 :
159 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
160 : // user name
161 : type Error = ComputeUserInfoNoEndpoint;
162 :
163 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
164 3 : match user_info.endpoint_id {
165 1 : None => Err(ComputeUserInfoNoEndpoint {
166 1 : user: user_info.user,
167 1 : options: user_info.options,
168 1 : }),
169 2 : Some(endpoint) => Ok(ComputeUserInfo {
170 2 : endpoint,
171 2 : user: user_info.user,
172 2 : options: user_info.options,
173 2 : }),
174 : }
175 3 : }
176 : }
177 :
178 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
179 : pub struct MaskedIp(IpAddr);
180 :
181 : impl MaskedIp {
182 15 : fn new(value: IpAddr, prefix: u8) -> Self {
183 15 : match value {
184 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
185 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
186 11 : )),
187 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
188 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
189 4 : )),
190 : }
191 15 : }
192 : }
193 :
194 : // This can't be just per IP because that would limit some PaaS that share IP addresses
195 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
196 :
197 : impl AuthenticationConfig {
198 3 : pub(crate) fn check_rate_limit(
199 3 : &self,
200 3 : ctx: &RequestContext,
201 3 : secret: AuthSecret,
202 3 : endpoint: &EndpointId,
203 3 : is_cleartext: bool,
204 3 : ) -> auth::Result<AuthSecret> {
205 3 : // we have validated the endpoint exists, so let's intern it.
206 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
207 :
208 : // only count the full hash count if password hack or websocket flow.
209 : // in other words, if proxy needs to run the hashing
210 3 : let password_weight = if is_cleartext {
211 2 : match &secret {
212 : #[cfg(any(test, feature = "testing"))]
213 0 : AuthSecret::Md5(_) => 1,
214 2 : AuthSecret::Scram(s) => s.iterations + 1,
215 : }
216 : } else {
217 : // validating scram takes just 1 hmac_sha_256 operation.
218 1 : 1
219 : };
220 :
221 3 : let limit_not_exceeded = self.rate_limiter.check(
222 3 : (
223 3 : endpoint_int,
224 3 : MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
225 3 : ),
226 3 : password_weight,
227 3 : );
228 3 :
229 3 : if !limit_not_exceeded {
230 0 : warn!(
231 : enabled = self.rate_limiter_enabled,
232 0 : "rate limiting authentication"
233 : );
234 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
235 0 : Metrics::get()
236 0 : .proxy
237 0 : .endpoints_auth_rate_limits
238 0 : .get_metric()
239 0 : .measure(endpoint);
240 0 :
241 0 : if self.rate_limiter_enabled {
242 0 : return Err(auth::AuthError::too_many_connections());
243 0 : }
244 3 : }
245 :
246 3 : Ok(secret)
247 3 : }
248 : }
249 :
250 : #[async_trait::async_trait]
251 : pub(crate) trait BackendIpAllowlist {
252 : async fn get_allowed_ips(
253 : &self,
254 : ctx: &RequestContext,
255 : user_info: &ComputeUserInfo,
256 : ) -> auth::Result<Vec<auth::IpPattern>>;
257 : }
258 :
259 : /// True to its name, this function encapsulates our current auth trade-offs.
260 : /// Here, we choose the appropriate auth flow based on circumstances.
261 : ///
262 : /// All authentication flows will emit an AuthenticationOk message if successful.
263 3 : async fn auth_quirks(
264 3 : ctx: &RequestContext,
265 3 : api: &impl control_plane::ControlPlaneApi,
266 3 : user_info: ComputeUserInfoMaybeEndpoint,
267 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
268 3 : allow_cleartext: bool,
269 3 : config: &'static AuthenticationConfig,
270 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
271 3 : ) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
272 : // If there's no project so far, that entails that client doesn't
273 : // support SNI or other means of passing the endpoint (project) name.
274 : // We now expect to see a very specific payload in the place of password.
275 3 : let (info, unauthenticated_password) = match user_info.try_into() {
276 1 : Err(info) => {
277 1 : let (info, password) =
278 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
279 1 : ctx.set_endpoint_id(info.endpoint.clone());
280 1 : (info, Some(password))
281 : }
282 2 : Ok(info) => (info, None),
283 : };
284 :
285 3 : debug!("fetching user's authentication info");
286 3 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
287 :
288 : // check allowed list
289 3 : if config.ip_allowlist_check_enabled
290 3 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
291 : {
292 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
293 3 : }
294 3 :
295 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
296 0 : return Err(AuthError::too_many_connections());
297 3 : }
298 3 : let cached_secret = match maybe_secret {
299 3 : Some(secret) => secret,
300 0 : None => api.get_role_secret(ctx, &info).await?,
301 : };
302 3 : let (cached_entry, secret) = cached_secret.take_value();
303 :
304 3 : let secret = if let Some(secret) = secret {
305 3 : config.check_rate_limit(
306 3 : ctx,
307 3 : secret,
308 3 : &info.endpoint,
309 3 : unauthenticated_password.is_some() || allow_cleartext,
310 0 : )?
311 : } else {
312 : // If we don't have an authentication secret, we mock one to
313 : // prevent malicious probing (possible due to missing protocol steps).
314 : // This mocked secret will never lead to successful authentication.
315 0 : info!("authentication info not found, mocking it");
316 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
317 : };
318 :
319 3 : match authenticate_with_secret(
320 3 : ctx,
321 3 : secret,
322 3 : info,
323 3 : client,
324 3 : unauthenticated_password,
325 3 : allow_cleartext,
326 3 : config,
327 3 : )
328 3 : .await
329 : {
330 3 : Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
331 0 : Err(e) => {
332 0 : if e.is_password_failed() {
333 0 : // The password could have been changed, so we invalidate the cache.
334 0 : cached_entry.invalidate();
335 0 : }
336 0 : Err(e)
337 : }
338 : }
339 3 : }
340 :
341 3 : async fn authenticate_with_secret(
342 3 : ctx: &RequestContext,
343 3 : secret: AuthSecret,
344 3 : info: ComputeUserInfo,
345 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
346 3 : unauthenticated_password: Option<Vec<u8>>,
347 3 : allow_cleartext: bool,
348 3 : config: &'static AuthenticationConfig,
349 3 : ) -> auth::Result<ComputeCredentials> {
350 3 : if let Some(password) = unauthenticated_password {
351 1 : let ep = EndpointIdInt::from(&info.endpoint);
352 :
353 1 : let auth_outcome =
354 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
355 1 : let keys = match auth_outcome {
356 1 : crate::sasl::Outcome::Success(key) => key,
357 0 : crate::sasl::Outcome::Failure(reason) => {
358 0 : info!("auth backend failed with an error: {reason}");
359 0 : return Err(auth::AuthError::password_failed(&*info.user));
360 : }
361 : };
362 :
363 : // we have authenticated the password
364 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
365 :
366 1 : return Ok(ComputeCredentials { info, keys });
367 2 : }
368 2 :
369 2 : // -- the remaining flows are self-authenticating --
370 2 :
371 2 : // Perform cleartext auth if we're allowed to do that.
372 2 : // Currently, we use it for websocket connections (latency).
373 2 : if allow_cleartext {
374 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
375 1 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
376 1 : }
377 1 :
378 1 : // Finally, proceed with the main auth flow (SCRAM-based).
379 1 : classic::authenticate(ctx, info, client, config, secret).await
380 3 : }
381 :
382 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
383 : /// Get username from the credentials.
384 0 : pub(crate) fn get_user(&self) -> &str {
385 0 : match self {
386 0 : Self::ControlPlane(_, user_info) => &user_info.user,
387 0 : Self::Local(_) => "local",
388 : }
389 0 : }
390 :
391 : /// Authenticate the client via the requested backend, possibly using credentials.
392 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
393 : pub(crate) async fn authenticate(
394 : self,
395 : ctx: &RequestContext,
396 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
397 : allow_cleartext: bool,
398 : config: &'static AuthenticationConfig,
399 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
400 : ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
401 : let res = match self {
402 : Self::ControlPlane(api, user_info) => {
403 : debug!(
404 : user = &*user_info.user,
405 : project = user_info.endpoint(),
406 : "performing authentication using the console"
407 : );
408 :
409 : let (credentials, ip_allowlist) = auth_quirks(
410 : ctx,
411 : &*api,
412 : user_info,
413 : client,
414 : allow_cleartext,
415 : config,
416 : endpoint_rate_limiter,
417 : )
418 : .await?;
419 : Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
420 : }
421 : Self::Local(_) => {
422 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
423 : }
424 : };
425 :
426 : // TODO: replace with some metric
427 : info!("user successfully authenticated");
428 : res
429 : }
430 : }
431 :
432 : impl Backend<'_, ComputeUserInfo> {
433 0 : pub(crate) async fn get_role_secret(
434 0 : &self,
435 0 : ctx: &RequestContext,
436 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
437 0 : match self {
438 0 : Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
439 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
440 : }
441 0 : }
442 :
443 0 : pub(crate) async fn get_allowed_ips_and_secret(
444 0 : &self,
445 0 : ctx: &RequestContext,
446 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
447 0 : match self {
448 0 : Self::ControlPlane(api, user_info) => {
449 0 : api.get_allowed_ips_and_secret(ctx, user_info).await
450 : }
451 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
452 : }
453 0 : }
454 : }
455 :
456 : #[async_trait::async_trait]
457 : impl BackendIpAllowlist for Backend<'_, ()> {
458 0 : async fn get_allowed_ips(
459 0 : &self,
460 0 : ctx: &RequestContext,
461 0 : user_info: &ComputeUserInfo,
462 0 : ) -> auth::Result<Vec<auth::IpPattern>> {
463 0 : let auth_data = match self {
464 0 : Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await,
465 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
466 : };
467 :
468 0 : auth_data
469 0 : .map(|(ips, _)| ips.as_ref().clone())
470 0 : .map_err(|e| e.into())
471 0 : }
472 : }
473 :
474 : #[async_trait::async_trait]
475 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
476 13 : async fn wake_compute(
477 13 : &self,
478 13 : ctx: &RequestContext,
479 13 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
480 13 : match self {
481 13 : Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
482 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
483 : }
484 26 : }
485 :
486 6 : fn get_keys(&self) -> &ComputeCredentialKeys {
487 6 : match self {
488 6 : Self::ControlPlane(_, creds) => &creds.keys,
489 0 : Self::Local(_) => &ComputeCredentialKeys::None,
490 : }
491 6 : }
492 : }
493 :
#[cfg(test)]
mod tests {
    #![allow(clippy::unimplemented, clippy::unwrap_used)]

    use std::net::IpAddr;
    use std::sync::Arc;
    use std::time::Duration;

    use bytes::BytesMut;
    use control_plane::AuthSecret;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
    use postgres_protocol::message::backend::Message as PgMessage;
    use postgres_protocol::message::frontend;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use super::jwt::JwkCache;
    use super::{auth_quirks, AuthRateLimiter};
    use crate::auth::backend::MaskedIp;
    use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
    use crate::config::AuthenticationConfig;
    use crate::context::RequestContext;
    use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
    use crate::proxy::NeonOptions;
    use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
    use crate::scram::threadpool::ThreadPool;
    use crate::scram::ServerSecret;
    use crate::stream::{PqStream, Stream};

    /// Password behind the mock role secret, also used by the simulated clients.
    const PASSWORD: &str = "my-secret-password";

    /// Control-plane stand-in that serves a fixed IP allowlist and role secret.
    struct Auth {
        ips: Vec<IpPattern>,
        secret: AuthSecret,
    }

    impl control_plane::ControlPlaneApi for Auth {
        async fn get_role_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
            let secret = self.secret.clone();
            Ok(CachedRoleSecret::new_uncached(Some(secret)))
        }

        async fn get_allowed_ips_and_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<
            (CachedAllowedIps, Option<CachedRoleSecret>),
            control_plane::errors::GetAuthInfoError,
        > {
            let allowed_ips = CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()));
            let role_secret = CachedRoleSecret::new_uncached(Some(self.secret.clone()));
            Ok((allowed_ips, Some(role_secret)))
        }

        async fn get_endpoint_jwks(
            &self,
            _ctx: &RequestContext,
            _endpoint: crate::types::EndpointId,
        ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
        {
            unimplemented!()
        }

        async fn wake_compute(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    /// Shared config: rate limiting and IP allowlist checks on, JWTs off.
    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        jwks_cache: JwkCache::default(),
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        rate_limiter_enabled: true,
        rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
        rate_limit_ip_subnet: 64,
        ip_allowlist_check_enabled: true,
        is_auth_broker: false,
        accept_jwts: false,
        console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
    });

    /// Reads from `r` into `b` until one complete backend message parses.
    async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
        loop {
            r.read_buf(&mut *b).await.unwrap();
            if let Some(message) = PgMessage::parse(&mut *b).unwrap() {
                return message;
            }
        }
    }

    /// Mock API with an empty allowlist and a SCRAM secret built from `PASSWORD`.
    async fn mock_api() -> Auth {
        let secret = ServerSecret::build(PASSWORD).await.unwrap();
        Auth {
            ips: vec![],
            secret: AuthSecret::Scram(secret),
        }
    }

    /// Fresh endpoint rate limiter with the default buckets and 64 shards.
    fn endpoint_rate_limiter() -> Arc<EndpointRateLimiter> {
        Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ))
    }

    /// Client info for role `conrad`, optionally naming an endpoint.
    fn conrad_user_info(endpoint_id: Option<&str>) -> ComputeUserInfoMaybeEndpoint {
        ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: endpoint_id.map(Into::into),
            options: NeonOptions::default(),
        }
    }

    #[test]
    fn masked_ip() {
        let mask = MaskedIp::new;

        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());

        // IPv4: a prefix wider than the address leaves it unmasked.
        assert_ne!(mask(ip_a, 64), mask(ip_b, 64));
        assert_ne!(mask(ip_a, 32), mask(ip_b, 32));
        assert_eq!(mask(ip_a, 30), mask(ip_b, 30));
        assert_eq!(mask(ip_c, 30), mask(ip_d, 30));

        // IPv6
        assert_ne!(mask(ip_e, 128), mask(ip_f, 128));
        assert_eq!(mask(ip_e, 64), mask(ip_f, 64));
    }

    #[test]
    fn test_default_auth_rate_limit_set() {
        // These values used to exceed u32::MAX.
        let expected = [
            RateBucketInfo {
                interval: Duration::from_secs(1),
                max_rpi: 1000 * 4096,
            },
            RateBucketInfo {
                interval: Duration::from_secs(60),
                max_rpi: 600 * 4096 * 60,
            },
            RateBucketInfo {
                interval: Duration::from_secs(600),
                max_rpi: 300 * 4096 * 600,
            },
        ];
        assert_eq!(RateBucketInfo::DEFAULT_AUTH_SET, expected);

        // Every bucket must survive a to_string/parse round trip.
        for bucket in RateBucketInfo::DEFAULT_AUTH_SET {
            let parsed: RateBucketInfo = bucket.to_string().parse().unwrap();
            assert_eq!(bucket, parsed);
        }
    }

    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_api().await;
        let user_info = conrad_user_info(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(PASSWORD.as_bytes(), ChannelBinding::unsupported());
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer scram
            let message = read_message(&mut client, &mut read).await;
            let PgMessage::AuthenticationSasl(sasl) = message else {
                panic!("wrong message");
            };
            let mechanisms: Vec<&str> = sasl.mechanisms().collect().unwrap();
            assert_eq!(mechanisms, ["SCRAM-SHA-256"]);

            // client sends client-first-message
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            let message = read_message(&mut client, &mut read).await;
            let PgMessage::AuthenticationSaslContinue(server_first) = message else {
                panic!("wrong message");
            };
            scram.update(server_first.data()).await.unwrap();

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            let message = read_message(&mut client, &mut read).await;
            let PgMessage::AuthenticationSaslFinal(server_final) = message else {
                panic!("wrong message");
            };
            scram.finish(server_final.data()).unwrap();
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            false,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        // flush the final server message
        stream.flush().await.unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_api().await;
        let user_info = conrad_user_info(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            let message = read_message(&mut client, &mut read).await;
            assert!(
                matches!(message, PgMessage::AuthenticationCleartextPassword),
                "wrong message"
            );

            // client responds with password
            frontend::password_message(PASSWORD.as_bytes(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_api().await;
        // No endpoint: it must be recovered from the password payload.
        let user_info = conrad_user_info(None);

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            let message = read_message(&mut client, &mut read).await;
            assert!(
                matches!(message, PgMessage::AuthenticationCleartextPassword),
                "wrong message"
            );

            // client responds with password
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let (creds, _ip_allowlist) = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|