Line data Source code
1 : mod classic;
2 : mod hacks;
3 : pub mod jwt;
4 : pub mod local;
5 : mod web;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 : use std::time::Duration;
10 :
11 : use ipnet::{Ipv4Net, Ipv6Net};
12 : use local::LocalBackend;
13 : use tokio::io::{AsyncRead, AsyncWrite};
14 : use tokio_postgres::config::AuthKeys;
15 : use tracing::{info, warn};
16 : pub(crate) use web::WebAuthError;
17 :
18 : use crate::auth::credentials::check_peer_addr_is_in_list;
19 : use crate::auth::{validate_password_and_exchange, AuthError};
20 : use crate::cache::Cached;
21 : use crate::console::errors::GetAuthInfoError;
22 : use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
23 : use crate::console::{AuthSecret, NodeInfo};
24 : use crate::context::RequestMonitoring;
25 : use crate::intern::EndpointIdInt;
26 : use crate::metrics::Metrics;
27 : use crate::proxy::connect_compute::ComputeConnectBackend;
28 : use crate::proxy::NeonOptions;
29 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
30 : use crate::stream::Stream;
31 : use crate::{
32 : auth::{self, ComputeUserInfoMaybeEndpoint},
33 : config::AuthenticationConfig,
34 : console::{
35 : self,
36 : provider::{CachedAllowedIps, CachedNodeInfo},
37 : Api,
38 : },
39 : stream, url,
40 : };
41 : use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
42 :
/// A value that is either stored inline or borrowed for the lifetime `'a`.
///
/// This fills the role of [`std::borrow::Cow`] without the `T: ToOwned`
/// bound: nothing here ever needs to turn a borrow into an owned copy.
pub enum MaybeOwned<'a, T> {
    /// The value is held directly by this wrapper.
    Owned(T),
    /// The value lives elsewhere; only a shared reference is kept.
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Gives shared access to the wrapped value, whichever variant holds it.
    fn deref(&self) -> &T {
        match self {
            Self::Owned(owned) => owned,
            Self::Borrowed(borrowed) => *borrowed,
        }
    }
}
59 :
60 : /// This type serves two purposes:
61 : ///
62 : /// * When `T` is `()`, it's just a regular auth backend selector
63 : /// which we use in [`crate::config::ProxyConfig`].
64 : ///
65 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
66 : /// this helps us provide the credentials only to those auth
67 : /// backends which require them for the authentication process.
68 : pub enum Backend<'a, T, D> {
69 : /// Cloud API (V2).
70 : Console(MaybeOwned<'a, ConsoleBackend>, T),
71 : /// Authentication via a web browser.
72 : Web(MaybeOwned<'a, url::ApiUrl>, D),
73 : /// Local proxy uses configured auth credentials and does not wake compute
74 : Local(MaybeOwned<'a, LocalBackend>),
75 : }
76 :
/// Test-only substitute for the console API calls an auth backend performs.
///
/// NOTE(review): presumably this is what backs `ConsoleBackend::Test` — confirm
/// against the `console::provider` module.
#[cfg(test)]
pub(crate) trait TestBackend: Send + Sync + 'static {
    /// Produces the node info that a real backend would obtain by waking compute.
    fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;

    /// Produces the allowed-IP list together with an optional role secret.
    fn get_allowed_ips_and_secret(
        &self,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
}
84 :
85 : impl std::fmt::Display for Backend<'_, (), ()> {
86 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 0 : match self {
88 0 : Self::Console(api, ()) => match &**api {
89 0 : ConsoleBackend::Console(endpoint) => {
90 0 : fmt.debug_tuple("Console").field(&endpoint.url()).finish()
91 : }
92 : #[cfg(any(test, feature = "testing"))]
93 0 : ConsoleBackend::Postgres(endpoint) => {
94 0 : fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
95 : }
96 : #[cfg(test)]
97 0 : ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
98 : },
99 0 : Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(),
100 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
101 : }
102 0 : }
103 : }
104 :
105 : impl<T, D> Backend<'_, T, D> {
106 : /// Very similar to [`std::option::Option::as_ref`].
107 : /// This helps us pass structured config to async tasks.
108 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> {
109 0 : match self {
110 0 : Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
111 0 : Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x),
112 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
113 : }
114 0 : }
115 : }
116 :
117 : impl<'a, T, D> Backend<'a, T, D> {
118 : /// Very similar to [`std::option::Option::map`].
119 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
120 : /// a function to a contained value.
121 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> {
122 0 : match self {
123 0 : Self::Console(c, x) => Backend::Console(c, f(x)),
124 0 : Self::Web(c, x) => Backend::Web(c, x),
125 0 : Self::Local(l) => Backend::Local(l),
126 : }
127 0 : }
128 : }
129 : impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
130 : /// Very similar to [`std::option::Option::transpose`].
131 : /// This is most useful for error handling.
132 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T, D>, E> {
133 0 : match self {
134 0 : Self::Console(c, x) => x.map(|x| Backend::Console(c, x)),
135 0 : Self::Web(c, x) => Ok(Backend::Web(c, x)),
136 0 : Self::Local(l) => Ok(Backend::Local(l)),
137 : }
138 0 : }
139 : }
140 :
141 : pub(crate) struct ComputeCredentials {
142 : pub(crate) info: ComputeUserInfo,
143 : pub(crate) keys: ComputeCredentialKeys,
144 : }
145 :
146 : #[derive(Debug, Clone)]
147 : pub(crate) struct ComputeUserInfoNoEndpoint {
148 : pub(crate) user: RoleName,
149 : pub(crate) options: NeonOptions,
150 : }
151 :
152 : #[derive(Debug, Clone)]
153 : pub(crate) struct ComputeUserInfo {
154 : pub(crate) endpoint: EndpointId,
155 : pub(crate) user: RoleName,
156 : pub(crate) options: NeonOptions,
157 : }
158 :
159 : impl ComputeUserInfo {
160 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
161 2 : self.options.get_cache_key(&self.endpoint)
162 2 : }
163 : }
164 :
165 : pub(crate) enum ComputeCredentialKeys {
166 : Password(Vec<u8>),
167 : AuthKeys(AuthKeys),
168 : None,
169 : }
170 :
171 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
172 : // user name
173 : type Error = ComputeUserInfoNoEndpoint;
174 :
175 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
176 3 : match user_info.endpoint_id {
177 1 : None => Err(ComputeUserInfoNoEndpoint {
178 1 : user: user_info.user,
179 1 : options: user_info.options,
180 1 : }),
181 2 : Some(endpoint) => Ok(ComputeUserInfo {
182 2 : endpoint,
183 2 : user: user_info.user,
184 2 : options: user_info.options,
185 2 : }),
186 : }
187 3 : }
188 : }
189 :
190 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
191 : pub struct MaskedIp(IpAddr);
192 :
193 : impl MaskedIp {
194 15 : fn new(value: IpAddr, prefix: u8) -> Self {
195 15 : match value {
196 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
197 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
198 11 : )),
199 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
200 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
201 4 : )),
202 : }
203 15 : }
204 : }
205 :
206 : // This can't be just per IP because that would limit some PaaS that share IP addresses
207 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
208 :
209 : impl RateBucketInfo {
210 : /// All of these are per endpoint-maskedip pair.
211 : /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
212 : ///
213 : /// First bucket: 1000mcpus total per endpoint-ip pair
214 : /// * 4096000 requests per second with 1 hash rounds.
215 : /// * 1000 requests per second with 4096 hash rounds.
216 : /// * 6.8 requests per second with 600000 hash rounds.
217 : pub const DEFAULT_AUTH_SET: [Self; 3] = [
218 : Self::new(1000 * 4096, Duration::from_secs(1)),
219 : Self::new(600 * 4096, Duration::from_secs(60)),
220 : Self::new(300 * 4096, Duration::from_secs(600)),
221 : ];
222 : }
223 :
224 : impl AuthenticationConfig {
225 3 : pub(crate) fn check_rate_limit(
226 3 : &self,
227 3 : ctx: &RequestMonitoring,
228 3 : config: &AuthenticationConfig,
229 3 : secret: AuthSecret,
230 3 : endpoint: &EndpointId,
231 3 : is_cleartext: bool,
232 3 : ) -> auth::Result<AuthSecret> {
233 3 : // we have validated the endpoint exists, so let's intern it.
234 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
235 :
236 : // only count the full hash count if password hack or websocket flow.
237 : // in other words, if proxy needs to run the hashing
238 3 : let password_weight = if is_cleartext {
239 2 : match &secret {
240 : #[cfg(any(test, feature = "testing"))]
241 0 : AuthSecret::Md5(_) => 1,
242 2 : AuthSecret::Scram(s) => s.iterations + 1,
243 : }
244 : } else {
245 : // validating scram takes just 1 hmac_sha_256 operation.
246 1 : 1
247 : };
248 :
249 3 : let limit_not_exceeded = self.rate_limiter.check(
250 3 : (
251 3 : endpoint_int,
252 3 : MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
253 3 : ),
254 3 : password_weight,
255 3 : );
256 3 :
257 3 : if !limit_not_exceeded {
258 0 : warn!(
259 : enabled = self.rate_limiter_enabled,
260 0 : "rate limiting authentication"
261 : );
262 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
263 0 : Metrics::get()
264 0 : .proxy
265 0 : .endpoints_auth_rate_limits
266 0 : .get_metric()
267 0 : .measure(endpoint);
268 0 :
269 0 : if self.rate_limiter_enabled {
270 0 : return Err(auth::AuthError::too_many_connections());
271 0 : }
272 3 : }
273 :
274 3 : Ok(secret)
275 3 : }
276 : }
277 :
278 : /// True to its name, this function encapsulates our current auth trade-offs.
279 : /// Here, we choose the appropriate auth flow based on circumstances.
280 : ///
281 : /// All authentication flows will emit an AuthenticationOk message if successful.
282 3 : async fn auth_quirks(
283 3 : ctx: &RequestMonitoring,
284 3 : api: &impl console::Api,
285 3 : user_info: ComputeUserInfoMaybeEndpoint,
286 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
287 3 : allow_cleartext: bool,
288 3 : config: &'static AuthenticationConfig,
289 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
290 3 : ) -> auth::Result<ComputeCredentials> {
291 : // If there's no project so far, that entails that client doesn't
292 : // support SNI or other means of passing the endpoint (project) name.
293 : // We now expect to see a very specific payload in the place of password.
294 3 : let (info, unauthenticated_password) = match user_info.try_into() {
295 1 : Err(info) => {
296 1 : let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
297 :
298 1 : ctx.set_endpoint_id(res.info.endpoint.clone());
299 1 : let password = match res.keys {
300 1 : ComputeCredentialKeys::Password(p) => p,
301 : ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
302 0 : unreachable!("password hack should return a password")
303 : }
304 : };
305 1 : (res.info, Some(password))
306 : }
307 2 : Ok(info) => (info, None),
308 : };
309 :
310 3 : info!("fetching user's authentication info");
311 3 : let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
312 :
313 : // check allowed list
314 3 : if config.ip_allowlist_check_enabled
315 3 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
316 : {
317 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
318 3 : }
319 3 :
320 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
321 0 : return Err(AuthError::too_many_connections());
322 3 : }
323 3 : let cached_secret = match maybe_secret {
324 3 : Some(secret) => secret,
325 0 : None => api.get_role_secret(ctx, &info).await?,
326 : };
327 3 : let (cached_entry, secret) = cached_secret.take_value();
328 :
329 3 : let secret = if let Some(secret) = secret {
330 3 : config.check_rate_limit(
331 3 : ctx,
332 3 : config,
333 3 : secret,
334 3 : &info.endpoint,
335 3 : unauthenticated_password.is_some() || allow_cleartext,
336 0 : )?
337 : } else {
338 : // If we don't have an authentication secret, we mock one to
339 : // prevent malicious probing (possible due to missing protocol steps).
340 : // This mocked secret will never lead to successful authentication.
341 0 : info!("authentication info not found, mocking it");
342 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
343 : };
344 :
345 3 : match authenticate_with_secret(
346 3 : ctx,
347 3 : secret,
348 3 : info,
349 3 : client,
350 3 : unauthenticated_password,
351 3 : allow_cleartext,
352 3 : config,
353 3 : )
354 5 : .await
355 : {
356 3 : Ok(keys) => Ok(keys),
357 0 : Err(e) => {
358 0 : if e.is_auth_failed() {
359 0 : // The password could have been changed, so we invalidate the cache.
360 0 : cached_entry.invalidate();
361 0 : }
362 0 : Err(e)
363 : }
364 : }
365 3 : }
366 :
367 3 : async fn authenticate_with_secret(
368 3 : ctx: &RequestMonitoring,
369 3 : secret: AuthSecret,
370 3 : info: ComputeUserInfo,
371 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
372 3 : unauthenticated_password: Option<Vec<u8>>,
373 3 : allow_cleartext: bool,
374 3 : config: &'static AuthenticationConfig,
375 3 : ) -> auth::Result<ComputeCredentials> {
376 3 : if let Some(password) = unauthenticated_password {
377 1 : let ep = EndpointIdInt::from(&info.endpoint);
378 :
379 1 : let auth_outcome =
380 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
381 1 : let keys = match auth_outcome {
382 1 : crate::sasl::Outcome::Success(key) => key,
383 0 : crate::sasl::Outcome::Failure(reason) => {
384 0 : info!("auth backend failed with an error: {reason}");
385 0 : return Err(auth::AuthError::auth_failed(&*info.user));
386 : }
387 : };
388 :
389 : // we have authenticated the password
390 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
391 :
392 1 : return Ok(ComputeCredentials { info, keys });
393 2 : }
394 2 :
395 2 : // -- the remaining flows are self-authenticating --
396 2 :
397 2 : // Perform cleartext auth if we're allowed to do that.
398 2 : // Currently, we use it for websocket connections (latency).
399 2 : if allow_cleartext {
400 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
401 2 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
402 1 : }
403 1 :
404 1 : // Finally, proceed with the main auth flow (SCRAM-based).
405 2 : classic::authenticate(ctx, info, client, config, secret).await
406 3 : }
407 :
408 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
409 : /// Get username from the credentials.
410 0 : pub(crate) fn get_user(&self) -> &str {
411 0 : match self {
412 0 : Self::Console(_, user_info) => &user_info.user,
413 0 : Self::Web(_, ()) => "web",
414 0 : Self::Local(_) => "local",
415 : }
416 0 : }
417 :
418 : /// Authenticate the client via the requested backend, possibly using credentials.
419 0 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
420 : pub(crate) async fn authenticate(
421 : self,
422 : ctx: &RequestMonitoring,
423 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
424 : allow_cleartext: bool,
425 : config: &'static AuthenticationConfig,
426 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
427 : ) -> auth::Result<Backend<'a, ComputeCredentials, NodeInfo>> {
428 : let res = match self {
429 : Self::Console(api, user_info) => {
430 : info!(
431 : user = &*user_info.user,
432 : project = user_info.endpoint(),
433 : "performing authentication using the console"
434 : );
435 :
436 : let credentials = auth_quirks(
437 : ctx,
438 : &*api,
439 : user_info,
440 : client,
441 : allow_cleartext,
442 : config,
443 : endpoint_rate_limiter,
444 : )
445 : .await?;
446 : Backend::Console(api, credentials)
447 : }
448 : // NOTE: this auth backend doesn't use client credentials.
449 : Self::Web(url, ()) => {
450 : info!("performing web authentication");
451 :
452 : let info = web::authenticate(ctx, &url, client).await?;
453 :
454 : Backend::Web(url, info)
455 : }
456 : Self::Local(_) => {
457 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
458 : }
459 : };
460 :
461 : info!("user successfully authenticated");
462 : Ok(res)
463 : }
464 : }
465 :
466 : impl Backend<'_, ComputeUserInfo, &()> {
467 0 : pub(crate) async fn get_role_secret(
468 0 : &self,
469 0 : ctx: &RequestMonitoring,
470 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
471 0 : match self {
472 0 : Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
473 0 : Self::Web(_, ()) => Ok(Cached::new_uncached(None)),
474 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
475 : }
476 0 : }
477 :
478 0 : pub(crate) async fn get_allowed_ips_and_secret(
479 0 : &self,
480 0 : ctx: &RequestMonitoring,
481 0 : ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
482 0 : match self {
483 0 : Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
484 0 : Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
485 0 : Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
486 : }
487 0 : }
488 : }
489 :
490 : #[async_trait::async_trait]
491 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> {
492 0 : async fn wake_compute(
493 0 : &self,
494 0 : ctx: &RequestMonitoring,
495 0 : ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
496 0 : match self {
497 0 : Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
498 0 : Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())),
499 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
500 : }
501 0 : }
502 :
503 0 : fn get_keys(&self) -> &ComputeCredentialKeys {
504 0 : match self {
505 0 : Self::Console(_, creds) => &creds.keys,
506 0 : Self::Web(_, _) => &ComputeCredentialKeys::None,
507 0 : Self::Local(_) => &ComputeCredentialKeys::None,
508 : }
509 0 : }
510 : }
511 :
512 : #[async_trait::async_trait]
513 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> {
514 13 : async fn wake_compute(
515 13 : &self,
516 13 : ctx: &RequestMonitoring,
517 13 : ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
518 13 : match self {
519 13 : Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
520 0 : Self::Web(_, ()) => {
521 0 : unreachable!("web auth flow doesn't support waking the compute")
522 : }
523 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
524 : }
525 26 : }
526 :
527 6 : fn get_keys(&self) -> &ComputeCredentialKeys {
528 6 : match self {
529 6 : Self::Console(_, creds) => &creds.keys,
530 0 : Self::Web(_, ()) => &ComputeCredentialKeys::None,
531 0 : Self::Local(_) => &ComputeCredentialKeys::None,
532 : }
533 6 : }
534 : }
535 :
/// Unit tests for IP masking, the default auth rate-limit buckets, and the
/// three `auth_quirks` flows (SCRAM, cleartext, password hack).
#[cfg(test)]
mod tests {
    use std::{net::IpAddr, sync::Arc, time::Duration};

    use bytes::BytesMut;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::{
        authentication::sasl::{ChannelBinding, ScramSha256},
        message::{backend::Message as PgMessage, frontend},
    };
    use provider::AuthSecret;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use crate::{
        auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
        config::AuthenticationConfig,
        console::{
            self,
            provider::{self, CachedAllowedIps, CachedRoleSecret},
            CachedNodeInfo,
        },
        context::RequestMonitoring,
        proxy::NeonOptions,
        rate_limiter::{EndpointRateLimiter, RateBucketInfo},
        scram::{threadpool::ThreadPool, ServerSecret},
        stream::{PqStream, Stream},
    };

    use super::{auth_quirks, AuthRateLimiter};

    /// In-memory stand-in for the console API: every lookup returns the same
    /// allowlist and role secret.
    struct Auth {
        // Returned as the allowed-IP list for every lookup.
        ips: Vec<IpPattern>,
        // Returned as the role secret for every lookup.
        secret: AuthSecret,
    }

    impl console::Api for Auth {
        async fn get_role_secret(
            &self,
            _ctx: &RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
            Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
        }

        async fn get_allowed_ips_and_secret(
            &self,
            _ctx: &RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
        {
            // The secret is always included, so `auth_quirks` never needs the
            // separate `get_role_secret` call in these tests.
            Ok((
                CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
                Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
            ))
        }

        // Not exercised: `auth_quirks` never wakes compute.
        async fn wake_compute(
            &self,
            _ctx: &RequestMonitoring,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    /// Shared auth config: rate limiting and the IP allowlist check are both
    /// enabled, and IPs are masked to a /64 for rate-limit keys.
    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        rate_limiter_enabled: true,
        rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
        rate_limit_ip_subnet: 64,
        ip_allowlist_check_enabled: true,
    });

    /// Reads from `r` into `b` until one complete backend message can be parsed,
    /// then returns it. Panics on I/O or parse errors.
    async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
        loop {
            r.read_buf(&mut *b).await.unwrap();
            if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
                break m;
            }
        }
    }

    #[test]
    fn masked_ip() {
        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());

        // A /64 prefix is out of range for IPv4, so the addresses stay distinct.
        assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
        assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
        assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
        assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));

        // ip_e and ip_f differ only after the first 64 bits.
        assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
        assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
    }

    #[test]
    fn test_default_auth_rate_limit_set() {
        // these values used to exceed u32::MAX
        // `max_rpi` is the per-second rate multiplied by the interval in seconds.
        assert_eq!(
            RateBucketInfo::DEFAULT_AUTH_SET,
            [
                RateBucketInfo {
                    interval: Duration::from_secs(1),
                    max_rpi: 1000 * 4096,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(60),
                    max_rpi: 600 * 4096 * 60,
                },
                RateBucketInfo {
                    interval: Duration::from_secs(600),
                    max_rpi: 300 * 4096 * 600,
                }
            ]
        );

        // Each bucket must round-trip through its string form.
        for x in RateBucketInfo::DEFAULT_AUTH_SET {
            let y = x.to_string().parse().unwrap();
            assert_eq!(x, y);
        }
    }

    /// With an endpoint known and cleartext disallowed, the server must run the
    /// full SCRAM-SHA-256 exchange against the stored secret.
    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestMonitoring::test();
        // An empty allowlist is expected to pass the (enabled) allowlist check;
        // presumably it means "allow all" — confirm in `check_peer_addr_is_in_list`.
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: Some("endpoint".into()),
            options: NeonOptions::default(),
        };

        // Simulated client side of the SCRAM handshake.
        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());

            let mut read = BytesMut::new();

            // server should offer scram
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSasl(a) => {
                    let options: Vec<&str> = a.mechanisms().collect().unwrap();
                    assert_eq!(options, ["SCRAM-SHA-256"]);
                }
                _ => panic!("wrong message"),
            }

            // client sends client-first-message
            let mut write = BytesMut::new();
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslContinue(a) => {
                    scram.update(a.data()).await.unwrap();
                }
                _ => panic!("wrong message"),
            }

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslFinal(a) => {
                    scram.finish(a.data()).unwrap();
                }
                _ => panic!("wrong message"),
            }
        });
        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            false,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    /// With cleartext allowed, the server requests a cleartext password and
    /// verifies it against the stored SCRAM secret.
    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestMonitoring::test();
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: Some("endpoint".into()),
            options: NeonOptions::default(),
        };

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            write.clear();
            frontend::password_message(b"my-secret-password", &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });
        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    /// Without an endpoint, the client supplies it inside the password payload
    /// (`endpoint=<name>;<password>`), and the resolved endpoint ends up in the
    /// returned credentials.
    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new(Stream::from_raw(server));

        let ctx = RequestMonitoring::test();
        let api = Auth {
            ips: vec![],
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        };

        let user_info = ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: None,
            options: NeonOptions::default(),
        };

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            let mut write = BytesMut::new();
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ));

        let creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            endpoint_rate_limiter,
        )
        .await
        .unwrap();

        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|