Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 :
10 : pub use console_redirect::ConsoleRedirectBackend;
11 : pub(crate) use console_redirect::ConsoleRedirectError;
12 : use ipnet::{Ipv4Net, Ipv6Net};
13 : use local::LocalBackend;
14 : use postgres_client::config::AuthKeys;
15 : use tokio::io::{AsyncRead, AsyncWrite};
16 : use tracing::{debug, info, warn};
17 :
18 : use crate::auth::credentials::check_peer_addr_is_in_list;
19 : use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
20 : use crate::cache::Cached;
21 : use crate::config::AuthenticationConfig;
22 : use crate::context::RequestContext;
23 : use crate::control_plane::client::ControlPlaneClient;
24 : use crate::control_plane::errors::GetAuthInfoError;
25 : use crate::control_plane::{
26 : self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
27 : };
28 : use crate::intern::EndpointIdInt;
29 : use crate::metrics::Metrics;
30 : use crate::proxy::connect_compute::ComputeConnectBackend;
31 : use crate::proxy::NeonOptions;
32 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
33 : use crate::stream::Stream;
34 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
35 : use crate::{scram, stream};
36 :
/// A value that is either owned outright or borrowed from elsewhere.
///
/// This plays the same role as [`std::borrow::Cow`], but it does not require
/// `T: ToOwned`, because we never need to turn a borrow into an owned value.
pub enum MaybeOwned<'a, T> {
    /// The wrapper holds the value itself.
    Owned(T),
    /// The wrapper holds a shared reference to a value owned by someone else.
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Returns a reference to the wrapped value, whichever way it is held.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Borrowed(reference) => *reference,
            Self::Owned(value) => value,
        }
    }
}
53 :
54 : /// This type serves two purposes:
55 : ///
56 : /// * When `T` is `()`, it's just a regular auth backend selector
57 : /// which we use in [`crate::config::ProxyConfig`].
58 : ///
59 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
60 : /// this helps us provide the credentials only to those auth
61 : /// backends which require them for the authentication process.
62 : pub enum Backend<'a, T> {
63 : /// Cloud API (V2).
64 : ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
65 : /// Local proxy uses configured auth credentials and does not wake compute
66 : Local(MaybeOwned<'a, LocalBackend>),
67 : }
68 :
69 : impl std::fmt::Display for Backend<'_, ()> {
70 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 0 : match self {
72 0 : Self::ControlPlane(api, ()) => match &**api {
73 0 : ControlPlaneClient::ProxyV1(endpoint) => fmt
74 0 : .debug_tuple("ControlPlane::ProxyV1")
75 0 : .field(&endpoint.url())
76 0 : .finish(),
77 0 : ControlPlaneClient::Neon(endpoint) => fmt
78 0 : .debug_tuple("ControlPlane::Neon")
79 0 : .field(&endpoint.url())
80 0 : .finish(),
81 : #[cfg(any(test, feature = "testing"))]
82 0 : ControlPlaneClient::PostgresMock(endpoint) => fmt
83 0 : .debug_tuple("ControlPlane::PostgresMock")
84 0 : .field(&endpoint.url())
85 0 : .finish(),
86 : #[cfg(test)]
87 0 : ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
88 : },
89 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
90 : }
91 0 : }
92 : }
93 :
94 : impl<T> Backend<'_, T> {
95 : /// Very similar to [`std::option::Option::as_ref`].
96 : /// This helps us pass structured config to async tasks.
97 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
98 0 : match self {
99 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
100 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
101 : }
102 0 : }
103 : }
104 :
105 : impl<'a, T> Backend<'a, T> {
106 : /// Very similar to [`std::option::Option::map`].
107 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
108 : /// a function to a contained value.
109 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
110 0 : match self {
111 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
112 0 : Self::Local(l) => Backend::Local(l),
113 : }
114 0 : }
115 : }
116 : impl<'a, T, E> Backend<'a, Result<T, E>> {
117 : /// Very similar to [`std::option::Option::transpose`].
118 : /// This is most useful for error handling.
119 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
120 0 : match self {
121 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
122 0 : Self::Local(l) => Ok(Backend::Local(l)),
123 : }
124 0 : }
125 : }
126 :
127 : pub(crate) struct ComputeCredentials {
128 : pub(crate) info: ComputeUserInfo,
129 : pub(crate) keys: ComputeCredentialKeys,
130 : }
131 :
132 : #[derive(Debug, Clone)]
133 : pub(crate) struct ComputeUserInfoNoEndpoint {
134 : pub(crate) user: RoleName,
135 : pub(crate) options: NeonOptions,
136 : }
137 :
138 : #[derive(Debug, Clone)]
139 : pub(crate) struct ComputeUserInfo {
140 : pub(crate) endpoint: EndpointId,
141 : pub(crate) user: RoleName,
142 : pub(crate) options: NeonOptions,
143 : }
144 :
145 : impl ComputeUserInfo {
146 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
147 2 : self.options.get_cache_key(&self.endpoint)
148 2 : }
149 : }
150 :
151 : #[cfg_attr(test, derive(Debug))]
152 : pub(crate) enum ComputeCredentialKeys {
153 : #[cfg(any(test, feature = "testing"))]
154 : Password(Vec<u8>),
155 : AuthKeys(AuthKeys),
156 : JwtPayload(Vec<u8>),
157 : None,
158 : }
159 :
160 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
161 : // user name
162 : type Error = ComputeUserInfoNoEndpoint;
163 :
164 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
165 3 : match user_info.endpoint_id {
166 1 : None => Err(ComputeUserInfoNoEndpoint {
167 1 : user: user_info.user,
168 1 : options: user_info.options,
169 1 : }),
170 2 : Some(endpoint) => Ok(ComputeUserInfo {
171 2 : endpoint,
172 2 : user: user_info.user,
173 2 : options: user_info.options,
174 2 : }),
175 : }
176 3 : }
177 : }
178 :
179 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
180 : pub struct MaskedIp(IpAddr);
181 :
182 : impl MaskedIp {
183 15 : fn new(value: IpAddr, prefix: u8) -> Self {
184 15 : match value {
185 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
186 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
187 11 : )),
188 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
189 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
190 4 : )),
191 : }
192 15 : }
193 : }
194 :
195 : // This can't be just per IP because that would limit some PaaS that share IP addresses
196 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
197 :
198 : impl AuthenticationConfig {
199 3 : pub(crate) fn check_rate_limit(
200 3 : &self,
201 3 : ctx: &RequestContext,
202 3 : secret: AuthSecret,
203 3 : endpoint: &EndpointId,
204 3 : is_cleartext: bool,
205 3 : ) -> auth::Result<AuthSecret> {
206 3 : // we have validated the endpoint exists, so let's intern it.
207 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
208 :
209 : // only count the full hash count if password hack or websocket flow.
210 : // in other words, if proxy needs to run the hashing
211 3 : let password_weight = if is_cleartext {
212 2 : match &secret {
213 : #[cfg(any(test, feature = "testing"))]
214 0 : AuthSecret::Md5(_) => 1,
215 2 : AuthSecret::Scram(s) => s.iterations + 1,
216 : }
217 : } else {
218 : // validating scram takes just 1 hmac_sha_256 operation.
219 1 : 1
220 : };
221 :
222 3 : let limit_not_exceeded = self.rate_limiter.check(
223 3 : (
224 3 : endpoint_int,
225 3 : MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
226 3 : ),
227 3 : password_weight,
228 3 : );
229 3 :
230 3 : if !limit_not_exceeded {
231 0 : warn!(
232 : enabled = self.rate_limiter_enabled,
233 0 : "rate limiting authentication"
234 : );
235 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
236 0 : Metrics::get()
237 0 : .proxy
238 0 : .endpoints_auth_rate_limits
239 0 : .get_metric()
240 0 : .measure(endpoint);
241 0 :
242 0 : if self.rate_limiter_enabled {
243 0 : return Err(auth::AuthError::too_many_connections());
244 0 : }
245 3 : }
246 :
247 3 : Ok(secret)
248 3 : }
249 : }
250 :
251 : /// True to its name, this function encapsulates our current auth trade-offs.
252 : /// Here, we choose the appropriate auth flow based on circumstances.
253 : ///
254 : /// All authentication flows will emit an AuthenticationOk message if successful.
255 3 : async fn auth_quirks(
256 3 : ctx: &RequestContext,
257 3 : api: &impl control_plane::ControlPlaneApi,
258 3 : user_info: ComputeUserInfoMaybeEndpoint,
259 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
260 3 : allow_cleartext: bool,
261 3 : config: &'static AuthenticationConfig,
262 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
263 3 : ) -> auth::Result<ComputeCredentials> {
264 : // If there's no project so far, that entails that client doesn't
265 : // support SNI or other means of passing the endpoint (project) name.
266 : // We now expect to see a very specific payload in the place of password.
267 3 : let (info, unauthenticated_password) = match user_info.try_into() {
268 1 : Err(info) => {
269 1 : let (info, password) =
270 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
271 1 : ctx.set_endpoint_id(info.endpoint.clone());
272 1 : (info, Some(password))
273 : }
274 2 : Ok(info) => (info, None),
275 : };
276 :
277 3 : debug!("fetching user's authentication info");
278 3 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
279 :
280 : // check allowed list
281 3 : if config.ip_allowlist_check_enabled
282 3 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
283 : {
284 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
285 3 : }
286 3 :
287 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
288 0 : return Err(AuthError::too_many_connections());
289 3 : }
290 3 : let cached_secret = match maybe_secret {
291 3 : Some(secret) => secret,
292 0 : None => api.get_role_secret(ctx, &info).await?,
293 : };
294 3 : let (cached_entry, secret) = cached_secret.take_value();
295 :
296 3 : let secret = if let Some(secret) = secret {
297 3 : config.check_rate_limit(
298 3 : ctx,
299 3 : secret,
300 3 : &info.endpoint,
301 3 : unauthenticated_password.is_some() || allow_cleartext,
302 0 : )?
303 : } else {
304 : // If we don't have an authentication secret, we mock one to
305 : // prevent malicious probing (possible due to missing protocol steps).
306 : // This mocked secret will never lead to successful authentication.
307 0 : info!("authentication info not found, mocking it");
308 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
309 : };
310 :
311 3 : match authenticate_with_secret(
312 3 : ctx,
313 3 : secret,
314 3 : info,
315 3 : client,
316 3 : unauthenticated_password,
317 3 : allow_cleartext,
318 3 : config,
319 3 : )
320 3 : .await
321 : {
322 3 : Ok(keys) => Ok(keys),
323 0 : Err(e) => {
324 0 : if e.is_password_failed() {
325 0 : // The password could have been changed, so we invalidate the cache.
326 0 : cached_entry.invalidate();
327 0 : }
328 0 : Err(e)
329 : }
330 : }
331 3 : }
332 :
333 3 : async fn authenticate_with_secret(
334 3 : ctx: &RequestContext,
335 3 : secret: AuthSecret,
336 3 : info: ComputeUserInfo,
337 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
338 3 : unauthenticated_password: Option<Vec<u8>>,
339 3 : allow_cleartext: bool,
340 3 : config: &'static AuthenticationConfig,
341 3 : ) -> auth::Result<ComputeCredentials> {
342 3 : if let Some(password) = unauthenticated_password {
343 1 : let ep = EndpointIdInt::from(&info.endpoint);
344 :
345 1 : let auth_outcome =
346 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
347 1 : let keys = match auth_outcome {
348 1 : crate::sasl::Outcome::Success(key) => key,
349 0 : crate::sasl::Outcome::Failure(reason) => {
350 0 : info!("auth backend failed with an error: {reason}");
351 0 : return Err(auth::AuthError::password_failed(&*info.user));
352 : }
353 : };
354 :
355 : // we have authenticated the password
356 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
357 :
358 1 : return Ok(ComputeCredentials { info, keys });
359 2 : }
360 2 :
361 2 : // -- the remaining flows are self-authenticating --
362 2 :
363 2 : // Perform cleartext auth if we're allowed to do that.
364 2 : // Currently, we use it for websocket connections (latency).
365 2 : if allow_cleartext {
366 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
367 1 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
368 1 : }
369 1 :
370 1 : // Finally, proceed with the main auth flow (SCRAM-based).
371 1 : classic::authenticate(ctx, info, client, config, secret).await
372 3 : }
373 :
374 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
375 : /// Get username from the credentials.
376 0 : pub(crate) fn get_user(&self) -> &str {
377 0 : match self {
378 0 : Self::ControlPlane(_, user_info) => &user_info.user,
379 0 : Self::Local(_) => "local",
380 : }
381 0 : }
382 :
383 : /// Authenticate the client via the requested backend, possibly using credentials.
384 0 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
385 : pub(crate) async fn authenticate(
386 : self,
387 : ctx: &RequestContext,
388 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
389 : allow_cleartext: bool,
390 : config: &'static AuthenticationConfig,
391 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
392 : ) -> auth::Result<Backend<'a, ComputeCredentials>> {
393 : let res = match self {
394 : Self::ControlPlane(api, user_info) => {
395 : debug!(
396 : user = &*user_info.user,
397 : project = user_info.endpoint(),
398 : "performing authentication using the console"
399 : );
400 :
401 : let credentials = auth_quirks(
402 : ctx,
403 : &*api,
404 : user_info,
405 : client,
406 : allow_cleartext,
407 : config,
408 : endpoint_rate_limiter,
409 : )
410 : .await?;
411 : Backend::ControlPlane(api, credentials)
412 : }
413 : Self::Local(_) => {
414 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
415 : }
416 : };
417 :
418 : // TODO: replace with some metric
419 : info!("user successfully authenticated");
420 : Ok(res)
421 : }
422 : }
423 :
424 : impl Backend<'_, ComputeUserInfo> {
425 0 : pub(crate) async fn get_role_secret(
426 0 : &self,
427 0 : ctx: &RequestContext,
428 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
429 0 : match self {
430 0 : Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
431 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
432 : }
433 0 : }
434 :
435 0 : pub(crate) async fn get_allowed_ips_and_secret(
436 0 : &self,
437 0 : ctx: &RequestContext,
438 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
439 0 : match self {
440 0 : Self::ControlPlane(api, user_info) => {
441 0 : api.get_allowed_ips_and_secret(ctx, user_info).await
442 : }
443 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
444 : }
445 0 : }
446 : }
447 :
448 : #[async_trait::async_trait]
449 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
450 13 : async fn wake_compute(
451 13 : &self,
452 13 : ctx: &RequestContext,
453 13 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
454 13 : match self {
455 13 : Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
456 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
457 : }
458 26 : }
459 :
460 6 : fn get_keys(&self) -> &ComputeCredentialKeys {
461 6 : match self {
462 6 : Self::ControlPlane(_, creds) => &creds.keys,
463 0 : Self::Local(_) => &ComputeCredentialKeys::None,
464 : }
465 6 : }
466 : }
467 :
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::sync::Arc;
    use std::time::Duration;

    use bytes::BytesMut;
    use control_plane::AuthSecret;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
    use postgres_protocol::message::backend::Message as PgMessage;
    use postgres_protocol::message::frontend;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use super::jwt::JwkCache;
    use super::{auth_quirks, AuthRateLimiter};
    use crate::auth::backend::MaskedIp;
    use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
    use crate::config::AuthenticationConfig;
    use crate::context::RequestContext;
    use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
    use crate::proxy::NeonOptions;
    use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
    use crate::scram::threadpool::ThreadPool;
    use crate::scram::ServerSecret;
    use crate::stream::{PqStream, Stream};

    /// Control-plane stand-in that serves a fixed IP allowlist and role secret.
    struct Auth {
        ips: Vec<IpPattern>,
        secret: AuthSecret,
    }

    impl control_plane::ControlPlaneApi for Auth {
        async fn get_role_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
            let secret = self.secret.clone();
            Ok(CachedRoleSecret::new_uncached(Some(secret)))
        }

        async fn get_allowed_ips_and_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<
            (CachedAllowedIps, Option<CachedRoleSecret>),
            control_plane::errors::GetAuthInfoError,
        > {
            let ips = CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()));
            let secret = CachedRoleSecret::new_uncached(Some(self.secret.clone()));
            Ok((ips, Some(secret)))
        }

        async fn get_endpoint_jwks(
            &self,
            _ctx: &RequestContext,
            _endpoint: crate::types::EndpointId,
        ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
        {
            unimplemented!()
        }

        async fn wake_compute(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    /// Auth config shared by all `auth_quirks_*` tests (rate limiting enabled).
    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        jwks_cache: JwkCache::default(),
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        rate_limiter_enabled: true,
        rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
        rate_limit_ip_subnet: 64,
        ip_allowlist_check_enabled: true,
        is_auth_broker: false,
        accept_jwts: false,
        console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
    });

    /// Mock control plane with no IP restrictions and a SCRAM secret for
    /// "my-secret-password".
    async fn scram_api() -> Auth {
        let secret = ServerSecret::build("my-secret-password").await.unwrap();
        Auth {
            ips: vec![],
            secret: AuthSecret::Scram(secret),
        }
    }

    /// User info for role "conrad" with default options and the given endpoint.
    fn conrad(endpoint_id: Option<&'static str>) -> ComputeUserInfoMaybeEndpoint {
        ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: endpoint_id.map(Into::into),
            options: NeonOptions::default(),
        }
    }

    /// Fresh endpoint rate limiter with the default buckets and 64 shards.
    fn endpoint_rate_limiter() -> Arc<EndpointRateLimiter> {
        Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ))
    }

    /// Reads from `r` into `b` until a complete backend message can be parsed.
    async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
        loop {
            r.read_buf(&mut *b).await.unwrap();
            if let Some(message) = PgMessage::parse(&mut *b).unwrap() {
                return message;
            }
        }
    }

    #[test]
    fn masked_ip() {
        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());

        let same_subnet = |x: IpAddr, y: IpAddr, prefix: u8| {
            MaskedIp::new(x, prefix) == MaskedIp::new(y, prefix)
        };

        // IPv4: prefixes longer than 32 bits leave the address untouched.
        assert!(!same_subnet(ip_a, ip_b, 64));
        assert!(!same_subnet(ip_a, ip_b, 32));
        assert!(same_subnet(ip_a, ip_b, 30));
        assert!(same_subnet(ip_c, ip_d, 30));

        // IPv6: the two addresses differ only after the first 64 bits.
        assert!(!same_subnet(ip_e, ip_f, 128));
        assert!(same_subnet(ip_e, ip_f, 64));
    }

    #[test]
    fn test_default_auth_rate_limit_set() {
        // these values used to exceed u32::MAX
        let expected = [
            RateBucketInfo {
                interval: Duration::from_secs(1),
                max_rpi: 1000 * 4096,
            },
            RateBucketInfo {
                interval: Duration::from_secs(60),
                max_rpi: 600 * 4096 * 60,
            },
            RateBucketInfo {
                interval: Duration::from_secs(600),
                max_rpi: 300 * 4096 * 600,
            },
        ];
        assert_eq!(RateBucketInfo::DEFAULT_AUTH_SET, expected);

        // Every bucket must round-trip through its string form.
        for bucket in RateBucketInfo::DEFAULT_AUTH_SET {
            let reparsed = bucket.to_string().parse().unwrap();
            assert_eq!(bucket, reparsed);
        }
    }

    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = scram_api().await;
        let user_info = conrad(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer scram
            let PgMessage::AuthenticationSasl(offer) = read_message(&mut client, &mut read).await
            else {
                panic!("wrong message");
            };
            let mechanisms: Vec<&str> = offer.mechanisms().collect().unwrap();
            assert_eq!(mechanisms, ["SCRAM-SHA-256"]);

            // client sends client-first-message
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            let PgMessage::AuthenticationSaslContinue(server_first) =
                read_message(&mut client, &mut read).await
            else {
                panic!("wrong message");
            };
            scram.update(server_first.data()).await.unwrap();

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            let PgMessage::AuthenticationSaslFinal(server_final) =
                read_message(&mut client, &mut read).await
            else {
                panic!("wrong message");
            };
            scram.finish(server_final.data()).unwrap();
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            false,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = scram_api().await;
        let user_info = conrad(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            let offer = read_message(&mut client, &mut read).await;
            assert!(
                matches!(offer, PgMessage::AuthenticationCleartextPassword),
                "wrong message"
            );

            // client responds with password
            write.clear();
            frontend::password_message(b"my-secret-password", &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = scram_api().await;
        // No endpoint: it must be recovered from the password payload.
        let user_info = conrad(None);

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();

            // server should offer cleartext
            let offer = read_message(&mut client, &mut read).await;
            assert!(
                matches!(offer, PgMessage::AuthenticationCleartextPassword),
                "wrong message"
            );

            // client responds with password
            let mut write = BytesMut::new();
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter(),
        )
        .await
        .unwrap();

        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|