Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 : use std::time::Duration;
10 :
11 : pub use console_redirect::ConsoleRedirectBackend;
12 : pub(crate) use console_redirect::ConsoleRedirectError;
13 : use ipnet::{Ipv4Net, Ipv6Net};
14 : use local::LocalBackend;
15 : use tokio::io::{AsyncRead, AsyncWrite};
16 : use tokio_postgres::config::AuthKeys;
17 : use tracing::{debug, info, warn};
18 :
19 : use crate::auth::credentials::check_peer_addr_is_in_list;
20 : use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
21 : use crate::cache::Cached;
22 : use crate::config::AuthenticationConfig;
23 : use crate::context::RequestContext;
24 : use crate::control_plane::client::ControlPlaneClient;
25 : use crate::control_plane::errors::GetAuthInfoError;
26 : use crate::control_plane::{
27 : self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
28 : };
29 : use crate::intern::EndpointIdInt;
30 : use crate::metrics::Metrics;
31 : use crate::proxy::connect_compute::ComputeConnectBackend;
32 : use crate::proxy::NeonOptions;
33 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
34 : use crate::stream::Stream;
35 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
36 : use crate::{scram, stream};
37 :
/// A value that is either owned outright or borrowed for `'a`.
///
/// A lighter alternative to [`std::borrow::Cow`]: it does not require
/// `T: ToOwned`, because we never need to turn the borrowed form into an
/// owned one.
pub enum MaybeOwned<'a, T> {
    /// The wrapper holds the value itself.
    Owned(T),
    /// The value lives elsewhere and is borrowed for `'a`.
    Borrowed(&'a T),
}
43 :
44 : impl<T> std::ops::Deref for MaybeOwned<'_, T> {
45 : type Target = T;
46 :
47 13 : fn deref(&self) -> &Self::Target {
48 13 : match self {
49 13 : MaybeOwned::Owned(t) => t,
50 0 : MaybeOwned::Borrowed(t) => t,
51 : }
52 13 : }
53 : }
54 :
55 : /// This type serves two purposes:
56 : ///
57 : /// * When `T` is `()`, it's just a regular auth backend selector
58 : /// which we use in [`crate::config::ProxyConfig`].
59 : ///
60 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
61 : /// this helps us provide the credentials only to those auth
62 : /// backends which require them for the authentication process.
63 : pub enum Backend<'a, T> {
64 : /// Cloud API (V2).
65 : ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
66 : /// Local proxy uses configured auth credentials and does not wake compute
67 : Local(MaybeOwned<'a, LocalBackend>),
68 : }
69 :
70 : impl std::fmt::Display for Backend<'_, ()> {
71 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 0 : match self {
73 0 : Self::ControlPlane(api, ()) => match &**api {
74 0 : ControlPlaneClient::Neon(endpoint) => fmt
75 0 : .debug_tuple("ControlPlane::Neon")
76 0 : .field(&endpoint.url())
77 0 : .finish(),
78 : #[cfg(any(test, feature = "testing"))]
79 0 : ControlPlaneClient::PostgresMock(endpoint) => fmt
80 0 : .debug_tuple("ControlPlane::PostgresMock")
81 0 : .field(&endpoint.url())
82 0 : .finish(),
83 : #[cfg(test)]
84 0 : ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
85 : },
86 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
87 : }
88 0 : }
89 : }
90 :
91 : impl<T> Backend<'_, T> {
92 : /// Very similar to [`std::option::Option::as_ref`].
93 : /// This helps us pass structured config to async tasks.
94 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
95 0 : match self {
96 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
97 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
98 : }
99 0 : }
100 : }
101 :
102 : impl<'a, T> Backend<'a, T> {
103 : /// Very similar to [`std::option::Option::map`].
104 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
105 : /// a function to a contained value.
106 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
107 0 : match self {
108 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
109 0 : Self::Local(l) => Backend::Local(l),
110 : }
111 0 : }
112 : }
113 : impl<'a, T, E> Backend<'a, Result<T, E>> {
114 : /// Very similar to [`std::option::Option::transpose`].
115 : /// This is most useful for error handling.
116 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
117 0 : match self {
118 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
119 0 : Self::Local(l) => Ok(Backend::Local(l)),
120 : }
121 0 : }
122 : }
123 :
124 : pub(crate) struct ComputeCredentials {
125 : pub(crate) info: ComputeUserInfo,
126 : pub(crate) keys: ComputeCredentialKeys,
127 : }
128 :
129 : #[derive(Debug, Clone)]
130 : pub(crate) struct ComputeUserInfoNoEndpoint {
131 : pub(crate) user: RoleName,
132 : pub(crate) options: NeonOptions,
133 : }
134 :
135 : #[derive(Debug, Clone)]
136 : pub(crate) struct ComputeUserInfo {
137 : pub(crate) endpoint: EndpointId,
138 : pub(crate) user: RoleName,
139 : pub(crate) options: NeonOptions,
140 : }
141 :
142 : impl ComputeUserInfo {
143 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
144 2 : self.options.get_cache_key(&self.endpoint)
145 2 : }
146 : }
147 :
148 : #[cfg_attr(test, derive(Debug))]
149 : pub(crate) enum ComputeCredentialKeys {
150 : #[cfg(any(test, feature = "testing"))]
151 : Password(Vec<u8>),
152 : AuthKeys(AuthKeys),
153 : JwtPayload(Vec<u8>),
154 : None,
155 : }
156 :
157 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
158 : // user name
159 : type Error = ComputeUserInfoNoEndpoint;
160 :
161 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
162 3 : match user_info.endpoint_id {
163 1 : None => Err(ComputeUserInfoNoEndpoint {
164 1 : user: user_info.user,
165 1 : options: user_info.options,
166 1 : }),
167 2 : Some(endpoint) => Ok(ComputeUserInfo {
168 2 : endpoint,
169 2 : user: user_info.user,
170 2 : options: user_info.options,
171 2 : }),
172 : }
173 3 : }
174 : }
175 :
/// A peer IP address with its low bits cleared down to a configured prefix.
/// Nearby addresses therefore share one auth rate-limit bucket.
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
pub struct MaskedIp(IpAddr);
178 :
179 : impl MaskedIp {
180 15 : fn new(value: IpAddr, prefix: u8) -> Self {
181 15 : match value {
182 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
183 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
184 11 : )),
185 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
186 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
187 4 : )),
188 : }
189 15 : }
190 : }
191 :
192 : // This can't be just per IP because that would limit some PaaS that share IP addresses
193 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
194 :
195 : impl RateBucketInfo {
196 : /// All of these are per endpoint-maskedip pair.
197 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
198 : ///
199 : /// First bucket: 1000mcpus total per endpoint-ip pair
200 : /// * 4096000 requests per second with 1 hash rounds.
201 : /// * 1000 requests per second with 4096 hash rounds.
202 : /// * 6.8 requests per second with 600000 hash rounds.
203 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
204 : Self::new(1000 * 4096, Duration::from_secs(1)),
205 : Self::new(600 * 4096, Duration::from_secs(60)),
206 : Self::new(300 * 4096, Duration::from_secs(600)),
207 : ];
208 : }
209 :
210 : impl AuthenticationConfig {
211 3 : pub(crate) fn check_rate_limit(
212 3 : &self,
213 3 : ctx: &RequestContext,
214 3 : secret: AuthSecret,
215 3 : endpoint: &EndpointId,
216 3 : is_cleartext: bool,
217 3 : ) -> auth::Result<AuthSecret> {
218 3 : // we have validated the endpoint exists, so let's intern it.
219 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
220 :
221 : // only count the full hash count if password hack or websocket flow.
222 : // in other words, if proxy needs to run the hashing
223 3 : let password_weight = if is_cleartext {
224 2 : match &secret {
225 : #[cfg(any(test, feature = "testing"))]
226 0 : AuthSecret::Md5(_) => 1,
227 2 : AuthSecret::Scram(s) => s.iterations + 1,
228 : }
229 : } else {
230 : // validating scram takes just 1 hmac_sha_256 operation.
231 1 : 1
232 : };
233 :
234 3 : let limit_not_exceeded = self.rate_limiter.check(
235 3 : (
236 3 : endpoint_int,
237 3 : MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
238 3 : ),
239 3 : password_weight,
240 3 : );
241 3 :
242 3 : if !limit_not_exceeded {
243 0 : warn!(
244 : enabled = self.rate_limiter_enabled,
245 0 : "rate limiting authentication"
246 : );
247 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
248 0 : Metrics::get()
249 0 : .proxy
250 0 : .endpoints_auth_rate_limits
251 0 : .get_metric()
252 0 : .measure(endpoint);
253 0 :
254 0 : if self.rate_limiter_enabled {
255 0 : return Err(auth::AuthError::too_many_connections());
256 0 : }
257 3 : }
258 :
259 3 : Ok(secret)
260 3 : }
261 : }
262 :
263 : /// True to its name, this function encapsulates our current auth trade-offs.
264 : /// Here, we choose the appropriate auth flow based on circumstances.
265 : ///
266 : /// All authentication flows will emit an AuthenticationOk message if successful.
267 3 : async fn auth_quirks(
268 3 : ctx: &RequestContext,
269 3 : api: &impl control_plane::ControlPlaneApi,
270 3 : user_info: ComputeUserInfoMaybeEndpoint,
271 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
272 3 : allow_cleartext: bool,
273 3 : config: &'static AuthenticationConfig,
274 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
275 3 : ) -> auth::Result<ComputeCredentials> {
276 : // If there's no project so far, that entails that client doesn't
277 : // support SNI or other means of passing the endpoint (project) name.
278 : // We now expect to see a very specific payload in the place of password.
279 3 : let (info, unauthenticated_password) = match user_info.try_into() {
280 1 : Err(info) => {
281 1 : let (info, password) =
282 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
283 1 : ctx.set_endpoint_id(info.endpoint.clone());
284 1 : (info, Some(password))
285 : }
286 2 : Ok(info) => (info, None),
287 : };
288 :
289 3 : debug!("fetching user's authentication info");
290 3 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
291 :
292 : // check allowed list
293 3 : if config.ip_allowlist_check_enabled
294 3 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
295 : {
296 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
297 3 : }
298 3 :
299 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
300 0 : return Err(AuthError::too_many_connections());
301 3 : }
302 3 : let cached_secret = match maybe_secret {
303 3 : Some(secret) => secret,
304 0 : None => api.get_role_secret(ctx, &info).await?,
305 : };
306 3 : let (cached_entry, secret) = cached_secret.take_value();
307 :
308 3 : let secret = if let Some(secret) = secret {
309 3 : config.check_rate_limit(
310 3 : ctx,
311 3 : secret,
312 3 : &info.endpoint,
313 3 : unauthenticated_password.is_some() || allow_cleartext,
314 0 : )?
315 : } else {
316 : // If we don't have an authentication secret, we mock one to
317 : // prevent malicious probing (possible due to missing protocol steps).
318 : // This mocked secret will never lead to successful authentication.
319 0 : info!("authentication info not found, mocking it");
320 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
321 : };
322 :
323 3 : match authenticate_with_secret(
324 3 : ctx,
325 3 : secret,
326 3 : info,
327 3 : client,
328 3 : unauthenticated_password,
329 3 : allow_cleartext,
330 3 : config,
331 3 : )
332 5 : .await
333 : {
334 3 : Ok(keys) => Ok(keys),
335 0 : Err(e) => {
336 0 : if e.is_password_failed() {
337 0 : // The password could have been changed, so we invalidate the cache.
338 0 : cached_entry.invalidate();
339 0 : }
340 0 : Err(e)
341 : }
342 : }
343 3 : }
344 :
345 3 : async fn authenticate_with_secret(
346 3 : ctx: &RequestContext,
347 3 : secret: AuthSecret,
348 3 : info: ComputeUserInfo,
349 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
350 3 : unauthenticated_password: Option<Vec<u8>>,
351 3 : allow_cleartext: bool,
352 3 : config: &'static AuthenticationConfig,
353 3 : ) -> auth::Result<ComputeCredentials> {
354 3 : if let Some(password) = unauthenticated_password {
355 1 : let ep = EndpointIdInt::from(&info.endpoint);
356 :
357 1 : let auth_outcome =
358 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
359 1 : let keys = match auth_outcome {
360 1 : crate::sasl::Outcome::Success(key) => key,
361 0 : crate::sasl::Outcome::Failure(reason) => {
362 0 : info!("auth backend failed with an error: {reason}");
363 0 : return Err(auth::AuthError::password_failed(&*info.user));
364 : }
365 : };
366 :
367 : // we have authenticated the password
368 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
369 :
370 1 : return Ok(ComputeCredentials { info, keys });
371 2 : }
372 2 :
373 2 : // -- the remaining flows are self-authenticating --
374 2 :
375 2 : // Perform cleartext auth if we're allowed to do that.
376 2 : // Currently, we use it for websocket connections (latency).
377 2 : if allow_cleartext {
378 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
379 2 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
380 1 : }
381 1 :
382 1 : // Finally, proceed with the main auth flow (SCRAM-based).
383 2 : classic::authenticate(ctx, info, client, config, secret).await
384 3 : }
385 :
386 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
387 : /// Get username from the credentials.
388 0 : pub(crate) fn get_user(&self) -> &str {
389 0 : match self {
390 0 : Self::ControlPlane(_, user_info) => &user_info.user,
391 0 : Self::Local(_) => "local",
392 : }
393 0 : }
394 :
395 : /// Authenticate the client via the requested backend, possibly using credentials.
396 0 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
397 : pub(crate) async fn authenticate(
398 : self,
399 : ctx: &RequestContext,
400 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
401 : allow_cleartext: bool,
402 : config: &'static AuthenticationConfig,
403 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
404 : ) -> auth::Result<Backend<'a, ComputeCredentials>> {
405 : let res = match self {
406 : Self::ControlPlane(api, user_info) => {
407 : debug!(
408 : user = &*user_info.user,
409 : project = user_info.endpoint(),
410 : "performing authentication using the console"
411 : );
412 :
413 : let credentials = auth_quirks(
414 : ctx,
415 : &*api,
416 : user_info,
417 : client,
418 : allow_cleartext,
419 : config,
420 : endpoint_rate_limiter,
421 : )
422 : .await?;
423 : Backend::ControlPlane(api, credentials)
424 : }
425 : Self::Local(_) => {
426 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
427 : }
428 : };
429 :
430 : // TODO: replace with some metric
431 : info!("user successfully authenticated");
432 : Ok(res)
433 : }
434 : }
435 :
436 : impl Backend<'_, ComputeUserInfo> {
437 0 : pub(crate) async fn get_role_secret(
438 0 : &self,
439 0 : ctx: &RequestContext,
440 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
441 0 : match self {
442 0 : Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
443 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
444 : }
445 0 : }
446 :
447 0 : pub(crate) async fn get_allowed_ips_and_secret(
448 0 : &self,
449 0 : ctx: &RequestContext,
450 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
451 0 : match self {
452 0 : Self::ControlPlane(api, user_info) => {
453 0 : api.get_allowed_ips_and_secret(ctx, user_info).await
454 : }
455 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
456 : }
457 0 : }
458 : }
459 :
460 : #[async_trait::async_trait]
461 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
462 13 : async fn wake_compute(
463 13 : &self,
464 13 : ctx: &RequestContext,
465 13 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
466 13 : match self {
467 13 : Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
468 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
469 : }
470 26 : }
471 :
472 6 : fn get_keys(&self) -> &ComputeCredentialKeys {
473 6 : match self {
474 6 : Self::ControlPlane(_, creds) => &creds.keys,
475 0 : Self::Local(_) => &ComputeCredentialKeys::None,
476 : }
477 6 : }
478 : }
479 :
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::sync::Arc;
    use std::time::Duration;

    use bytes::BytesMut;
    use control_plane::AuthSecret;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
    use postgres_protocol::message::backend::Message as PgMessage;
    use postgres_protocol::message::frontend;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use super::jwt::JwkCache;
    use super::{auth_quirks, AuthRateLimiter};
    use crate::auth::backend::MaskedIp;
    use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
    use crate::config::AuthenticationConfig;
    use crate::context::RequestContext;
    use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
    use crate::proxy::NeonOptions;
    use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
    use crate::scram::threadpool::ThreadPool;
    use crate::scram::ServerSecret;
    use crate::stream::{PqStream, Stream};

    /// In-memory stand-in for the control plane. It always answers with the
    /// same IP allowlist and role secret.
    struct Auth {
        ips: Vec<IpPattern>,
        secret: AuthSecret,
    }

    impl control_plane::ControlPlaneApi for Auth {
        async fn get_role_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
            Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
        }

        // Returns the secret together with the allowlist, so `auth_quirks`
        // never needs the separate `get_role_secret` lookup in these tests.
        async fn get_allowed_ips_and_secret(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<
            (CachedAllowedIps, Option<CachedRoleSecret>),
            control_plane::errors::GetAuthInfoError,
        > {
            Ok((
                CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
                Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
            ))
        }

        // Not used by these tests.
        async fn get_endpoint_jwks(
            &self,
            _ctx: &RequestContext,
            _endpoint: crate::types::EndpointId,
        ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
        {
            unimplemented!()
        }

        // Not used by these tests.
        async fn wake_compute(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    /// Auth config shared by the tests. It enforces rate limiting with the
    /// default auth buckets, masks peer IPs to /64, and has IP allowlist
    /// checks enabled.
    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        jwks_cache: JwkCache::default(),
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        rate_limiter_enabled: true,
        rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
        rate_limit_ip_subnet: 64,
        ip_allowlist_check_enabled: true,
        is_auth_broker: false,
        accept_jwts: false,
        console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
    });

    /// Reads from `r` into `b` until a complete backend message can be parsed,
    /// then returns that message. Panics on I/O or parse errors.
    ///
    /// NOTE(review): if the peer closes before a full message arrives,
    /// `read_buf` keeps returning 0 and this loops forever.
    async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
        loop {
            r.read_buf(&mut *b).await.unwrap();
            if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
                break m;
            }
        }
    }

    #[test]
    fn masked_ip() {
        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());

        // /64 is not a valid IPv4 prefix, so IPv4 addresses stay unmasked.
        assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
        // /32 keeps every bit of an IPv4 address.
        assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
        // /30 clears the low two bits, which merges neighbouring addresses.
        assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
        assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));

        // IPv6: /128 keeps everything; /64 drops the lower half, where e and f differ.
        assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
        assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
    }

    #[test]
    fn test_default_auth_rate_limit_set() {
        // these values used to exceed u32::MAX
        // `max_rpi` is the per-second budget multiplied by the interval in seconds.
        assert_eq!(
            RateBucketInfo::DEFAULT_AUTH_SET,
            [
                RateBucketInfo {
                    interval: Duration::from_secs(1),
                    max_rpi: 1000 * 4096,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(60),
                    max_rpi: 600 * 4096 * 60,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(600),
                    max_rpi: 300 * 4096 * 600,
                }
            ]
        );

        // Display and FromStr must round-trip.
        for x in RateBucketInfo::DEFAULT_AUTH_SET {
            let y = x.to_string().parse().unwrap();
            assert_eq!(x, y);
        }
    }

    /// Endpoint known and cleartext not allowed: `auth_quirks` must run a
    /// full SCRAM-SHA-256 exchange. The spawned task plays the client.
    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: Some("endpoint".into()),
            options: NeonOptions::default(),
        };

        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());

            let mut read = BytesMut::new();

            // server should offer scram
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSasl(a) => {
                    let options: Vec<&str> = a.mechanisms().collect().unwrap();
                    assert_eq!(options, ["SCRAM-SHA-256"]);
                }
                _ => panic!("wrong message"),
            }

            // client sends client-first-message
            let mut write = BytesMut::new();
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslContinue(a) => {
                    scram.update(a.data()).await.unwrap();
                }
                _ => panic!("wrong message"),
            }

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslFinal(a) => {
                    scram.finish(a.data()).unwrap();
                }
                _ => panic!("wrong message"),
            }
        });
        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        // allow_cleartext = false selects the classic SCRAM flow.
        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            false,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    /// Endpoint known and cleartext allowed: the server asks for a cleartext
    /// password and validates it against the SCRAM secret.
    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: Some("endpoint".into()),
            options: NeonOptions::default(),
        };

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            write.clear();
            frontend::password_message(b"my-secret-password", &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });
        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        // allow_cleartext = true selects the cleartext flow.
        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    /// No endpoint supplied: the client sends `endpoint=<id>;<password>` as
    /// its password, and the endpoint must be recovered from that payload.
    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: None,
            options: NeonOptions::default(),
        };

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            let mut write = BytesMut::new();
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        let creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        // The endpoint was taken from the password payload.
        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|