Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::net::IpAddr;
8 : use std::sync::Arc;
9 :
10 : pub use console_redirect::ConsoleRedirectBackend;
11 : pub(crate) use console_redirect::ConsoleRedirectError;
12 : use ipnet::{Ipv4Net, Ipv6Net};
13 : use local::LocalBackend;
14 : use postgres_client::config::AuthKeys;
15 : use serde::{Deserialize, Serialize};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tracing::{debug, info, warn};
18 :
19 : use crate::auth::credentials::check_peer_addr_is_in_list;
20 : use crate::auth::{
21 : self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
22 : };
23 : use crate::cache::Cached;
24 : use crate::config::AuthenticationConfig;
25 : use crate::context::RequestContext;
26 : use crate::control_plane::client::ControlPlaneClient;
27 : use crate::control_plane::errors::GetAuthInfoError;
28 : use crate::control_plane::{
29 : self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
30 : CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
31 : };
32 : use crate::intern::EndpointIdInt;
33 : use crate::metrics::Metrics;
34 : use crate::protocol2::ConnectionInfoExtra;
35 : use crate::proxy::NeonOptions;
36 : use crate::proxy::connect_compute::ComputeConnectBackend;
37 : use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
38 : use crate::stream::Stream;
39 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
40 : use crate::{scram, stream};
41 :
/// A value that is either held directly or borrowed for `'a`.
///
/// Fills the role of [`std::borrow::Cow`] without demanding `T: ToOwned`,
/// because that conversion is not needed here.
pub enum MaybeOwned<'a, T> {
    Owned(T),
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Yields a shared reference to the wrapped value, whichever variant holds it.
    fn deref(&self) -> &T {
        match *self {
            Self::Borrowed(shared) => shared,
            Self::Owned(ref held) => held,
        }
    }
}
58 :
59 : /// This type serves two purposes:
60 : ///
61 : /// * When `T` is `()`, it's just a regular auth backend selector
62 : /// which we use in [`crate::config::ProxyConfig`].
63 : ///
64 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
65 : /// this helps us provide the credentials only to those auth
66 : /// backends which require them for the authentication process.
67 : pub enum Backend<'a, T> {
68 : /// Cloud API (V2).
69 : ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
70 : /// Local proxy uses configured auth credentials and does not wake compute
71 : Local(MaybeOwned<'a, LocalBackend>),
72 : }
73 :
74 : impl std::fmt::Display for Backend<'_, ()> {
75 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 0 : match self {
77 0 : Self::ControlPlane(api, ()) => match &**api {
78 0 : ControlPlaneClient::ProxyV1(endpoint) => fmt
79 0 : .debug_tuple("ControlPlane::ProxyV1")
80 0 : .field(&endpoint.url())
81 0 : .finish(),
82 : #[cfg(any(test, feature = "testing"))]
83 0 : ControlPlaneClient::PostgresMock(endpoint) => {
84 0 : let url = endpoint.url();
85 0 : match url::Url::parse(url) {
86 0 : Ok(mut url) => {
87 0 : let _ = url.set_password(Some("_redacted_"));
88 0 : let url = url.as_str();
89 0 : fmt.debug_tuple("ControlPlane::PostgresMock")
90 0 : .field(&url)
91 0 : .finish()
92 : }
93 0 : Err(_) => fmt
94 0 : .debug_tuple("ControlPlane::PostgresMock")
95 0 : .field(&url)
96 0 : .finish(),
97 : }
98 : }
99 : #[cfg(test)]
100 0 : ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
101 : },
102 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
103 : }
104 0 : }
105 : }
106 :
107 : impl<T> Backend<'_, T> {
108 : /// Very similar to [`std::option::Option::as_ref`].
109 : /// This helps us pass structured config to async tasks.
110 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
111 0 : match self {
112 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
113 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
114 : }
115 0 : }
116 :
117 0 : pub(crate) fn get_api(&self) -> &ControlPlaneClient {
118 0 : match self {
119 0 : Self::ControlPlane(api, _) => api,
120 0 : Self::Local(_) => panic!("Local backend has no API"),
121 : }
122 0 : }
123 :
124 0 : pub(crate) fn is_local_proxy(&self) -> bool {
125 0 : matches!(self, Self::Local(_))
126 0 : }
127 : }
128 :
129 : impl<'a, T> Backend<'a, T> {
130 : /// Very similar to [`std::option::Option::map`].
131 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
132 : /// a function to a contained value.
133 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
134 0 : match self {
135 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
136 0 : Self::Local(l) => Backend::Local(l),
137 : }
138 0 : }
139 : }
140 : impl<'a, T, E> Backend<'a, Result<T, E>> {
141 : /// Very similar to [`std::option::Option::transpose`].
142 : /// This is most useful for error handling.
143 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
144 0 : match self {
145 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
146 0 : Self::Local(l) => Ok(Backend::Local(l)),
147 : }
148 0 : }
149 : }
150 :
151 : pub(crate) struct ComputeCredentials {
152 : pub(crate) info: ComputeUserInfo,
153 : pub(crate) keys: ComputeCredentialKeys,
154 : }
155 :
156 : #[derive(Debug, Clone)]
157 : pub(crate) struct ComputeUserInfoNoEndpoint {
158 : pub(crate) user: RoleName,
159 : pub(crate) options: NeonOptions,
160 : }
161 :
162 0 : #[derive(Debug, Clone, Default, Serialize, Deserialize)]
163 : pub(crate) struct ComputeUserInfo {
164 : pub(crate) endpoint: EndpointId,
165 : pub(crate) user: RoleName,
166 : pub(crate) options: NeonOptions,
167 : }
168 :
169 : impl ComputeUserInfo {
170 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
171 2 : self.options.get_cache_key(&self.endpoint)
172 2 : }
173 : }
174 :
175 : #[cfg_attr(test, derive(Debug))]
176 : pub(crate) enum ComputeCredentialKeys {
177 : #[cfg(any(test, feature = "testing"))]
178 : Password(Vec<u8>),
179 : AuthKeys(AuthKeys),
180 : JwtPayload(Vec<u8>),
181 : None,
182 : }
183 :
184 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
185 : // user name
186 : type Error = ComputeUserInfoNoEndpoint;
187 :
188 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
189 3 : match user_info.endpoint_id {
190 1 : None => Err(ComputeUserInfoNoEndpoint {
191 1 : user: user_info.user,
192 1 : options: user_info.options,
193 1 : }),
194 2 : Some(endpoint) => Ok(ComputeUserInfo {
195 2 : endpoint,
196 2 : user: user_info.user,
197 2 : options: user_info.options,
198 2 : }),
199 : }
200 3 : }
201 : }
202 :
203 : #[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
204 : pub struct MaskedIp(IpAddr);
205 :
206 : impl MaskedIp {
207 15 : fn new(value: IpAddr, prefix: u8) -> Self {
208 15 : match value {
209 11 : IpAddr::V4(v4) => Self(IpAddr::V4(
210 11 : Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
211 11 : )),
212 4 : IpAddr::V6(v6) => Self(IpAddr::V6(
213 4 : Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
214 4 : )),
215 : }
216 15 : }
217 : }
218 :
// The limiter cannot be keyed by IP address alone: that would throttle PaaS platforms that share IP addresses.
220 : pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
221 :
222 : impl AuthenticationConfig {
223 3 : pub(crate) fn check_rate_limit(
224 3 : &self,
225 3 : ctx: &RequestContext,
226 3 : secret: AuthSecret,
227 3 : endpoint: &EndpointId,
228 3 : is_cleartext: bool,
229 3 : ) -> auth::Result<AuthSecret> {
230 3 : // we have validated the endpoint exists, so let's intern it.
231 3 : let endpoint_int = EndpointIdInt::from(endpoint.normalize());
232 :
233 : // only count the full hash count if password hack or websocket flow.
234 : // in other words, if proxy needs to run the hashing
235 3 : let password_weight = if is_cleartext {
236 2 : match &secret {
237 : #[cfg(any(test, feature = "testing"))]
238 0 : AuthSecret::Md5(_) => 1,
239 2 : AuthSecret::Scram(s) => s.iterations + 1,
240 : }
241 : } else {
242 : // validating scram takes just 1 hmac_sha_256 operation.
243 1 : 1
244 : };
245 :
246 3 : let limit_not_exceeded = self.rate_limiter.check(
247 3 : (
248 3 : endpoint_int,
249 3 : MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
250 3 : ),
251 3 : password_weight,
252 3 : );
253 3 :
254 3 : if !limit_not_exceeded {
255 0 : warn!(
256 : enabled = self.rate_limiter_enabled,
257 0 : "rate limiting authentication"
258 : );
259 0 : Metrics::get().proxy.requests_auth_rate_limits_total.inc();
260 0 : Metrics::get()
261 0 : .proxy
262 0 : .endpoints_auth_rate_limits
263 0 : .get_metric()
264 0 : .measure(endpoint);
265 0 :
266 0 : if self.rate_limiter_enabled {
267 0 : return Err(auth::AuthError::too_many_connections());
268 0 : }
269 3 : }
270 :
271 3 : Ok(secret)
272 3 : }
273 : }
274 :
275 : /// True to its name, this function encapsulates our current auth trade-offs.
276 : /// Here, we choose the appropriate auth flow based on circumstances.
277 : ///
278 : /// All authentication flows will emit an AuthenticationOk message if successful.
279 3 : async fn auth_quirks(
280 3 : ctx: &RequestContext,
281 3 : api: &impl control_plane::ControlPlaneApi,
282 3 : user_info: ComputeUserInfoMaybeEndpoint,
283 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
284 3 : allow_cleartext: bool,
285 3 : config: &'static AuthenticationConfig,
286 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
287 3 : ) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
288 : // If there's no project so far, that entails that client doesn't
289 : // support SNI or other means of passing the endpoint (project) name.
290 : // We now expect to see a very specific payload in the place of password.
291 3 : let (info, unauthenticated_password) = match user_info.try_into() {
292 1 : Err(info) => {
293 1 : let (info, password) =
294 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
295 1 : ctx.set_endpoint_id(info.endpoint.clone());
296 1 : (info, Some(password))
297 : }
298 2 : Ok(info) => (info, None),
299 : };
300 :
301 3 : debug!("fetching authentication info and allowlists");
302 :
303 : // check allowed list
304 3 : let allowed_ips = if config.ip_allowlist_check_enabled {
305 3 : let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
306 3 : if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
307 0 : return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
308 3 : }
309 3 : allowed_ips
310 : } else {
311 0 : Cached::new_uncached(Arc::new(vec![]))
312 : };
313 :
314 : // check if a VPC endpoint ID is coming in and if yes, if it's allowed
315 3 : let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
316 3 : if config.is_vpc_acccess_proxy {
317 0 : if access_blocks.vpc_access_blocked {
318 0 : return Err(AuthError::NetworkNotAllowed);
319 0 : }
320 :
321 0 : let incoming_vpc_endpoint_id = match ctx.extra() {
322 0 : None => return Err(AuthError::MissingEndpointName),
323 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
324 0 : Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
325 : };
326 0 : let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
327 : // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
328 0 : if !allowed_vpc_endpoint_ids.is_empty()
329 0 : && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
330 : {
331 0 : return Err(AuthError::vpc_endpoint_id_not_allowed(
332 0 : incoming_vpc_endpoint_id,
333 0 : ));
334 0 : }
335 3 : } else if access_blocks.public_access_blocked {
336 0 : return Err(AuthError::NetworkNotAllowed);
337 3 : }
338 :
339 3 : if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
340 0 : return Err(AuthError::too_many_connections());
341 3 : }
342 3 : let cached_secret = api.get_role_secret(ctx, &info).await?;
343 3 : let (cached_entry, secret) = cached_secret.take_value();
344 :
345 3 : let secret = if let Some(secret) = secret {
346 3 : config.check_rate_limit(
347 3 : ctx,
348 3 : secret,
349 3 : &info.endpoint,
350 3 : unauthenticated_password.is_some() || allow_cleartext,
351 0 : )?
352 : } else {
353 : // If we don't have an authentication secret, we mock one to
354 : // prevent malicious probing (possible due to missing protocol steps).
355 : // This mocked secret will never lead to successful authentication.
356 0 : info!("authentication info not found, mocking it");
357 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
358 : };
359 :
360 3 : match authenticate_with_secret(
361 3 : ctx,
362 3 : secret,
363 3 : info,
364 3 : client,
365 3 : unauthenticated_password,
366 3 : allow_cleartext,
367 3 : config,
368 3 : )
369 3 : .await
370 : {
371 3 : Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
372 0 : Err(e) => {
373 0 : if e.is_password_failed() {
374 0 : // The password could have been changed, so we invalidate the cache.
375 0 : cached_entry.invalidate();
376 0 : }
377 0 : Err(e)
378 : }
379 : }
380 3 : }
381 :
382 3 : async fn authenticate_with_secret(
383 3 : ctx: &RequestContext,
384 3 : secret: AuthSecret,
385 3 : info: ComputeUserInfo,
386 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
387 3 : unauthenticated_password: Option<Vec<u8>>,
388 3 : allow_cleartext: bool,
389 3 : config: &'static AuthenticationConfig,
390 3 : ) -> auth::Result<ComputeCredentials> {
391 3 : if let Some(password) = unauthenticated_password {
392 1 : let ep = EndpointIdInt::from(&info.endpoint);
393 :
394 1 : let auth_outcome =
395 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
396 1 : let keys = match auth_outcome {
397 1 : crate::sasl::Outcome::Success(key) => key,
398 0 : crate::sasl::Outcome::Failure(reason) => {
399 0 : info!("auth backend failed with an error: {reason}");
400 0 : return Err(auth::AuthError::password_failed(&*info.user));
401 : }
402 : };
403 :
404 : // we have authenticated the password
405 1 : client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
406 :
407 1 : return Ok(ComputeCredentials { info, keys });
408 2 : }
409 2 :
410 2 : // -- the remaining flows are self-authenticating --
411 2 :
412 2 : // Perform cleartext auth if we're allowed to do that.
413 2 : // Currently, we use it for websocket connections (latency).
414 2 : if allow_cleartext {
415 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
416 1 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
417 1 : }
418 1 :
419 1 : // Finally, proceed with the main auth flow (SCRAM-based).
420 1 : classic::authenticate(ctx, info, client, config, secret).await
421 3 : }
422 :
423 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
424 : /// Get username from the credentials.
425 0 : pub(crate) fn get_user(&self) -> &str {
426 0 : match self {
427 0 : Self::ControlPlane(_, user_info) => &user_info.user,
428 0 : Self::Local(_) => "local",
429 : }
430 0 : }
431 :
432 : /// Authenticate the client via the requested backend, possibly using credentials.
433 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
434 : pub(crate) async fn authenticate(
435 : self,
436 : ctx: &RequestContext,
437 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
438 : allow_cleartext: bool,
439 : config: &'static AuthenticationConfig,
440 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
441 : ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
442 : let res = match self {
443 : Self::ControlPlane(api, user_info) => {
444 : debug!(
445 : user = &*user_info.user,
446 : project = user_info.endpoint(),
447 : "performing authentication using the console"
448 : );
449 :
450 : let (credentials, ip_allowlist) = auth_quirks(
451 : ctx,
452 : &*api,
453 : user_info,
454 : client,
455 : allow_cleartext,
456 : config,
457 : endpoint_rate_limiter,
458 : )
459 : .await?;
460 : Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
461 : }
462 : Self::Local(_) => {
463 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
464 : }
465 : };
466 :
467 : // TODO: replace with some metric
468 : info!("user successfully authenticated");
469 : res
470 : }
471 : }
472 :
473 : impl Backend<'_, ComputeUserInfo> {
474 0 : pub(crate) async fn get_role_secret(
475 0 : &self,
476 0 : ctx: &RequestContext,
477 0 : ) -> Result<CachedRoleSecret, GetAuthInfoError> {
478 0 : match self {
479 0 : Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
480 0 : Self::Local(_) => Ok(Cached::new_uncached(None)),
481 : }
482 0 : }
483 :
484 0 : pub(crate) async fn get_allowed_ips(
485 0 : &self,
486 0 : ctx: &RequestContext,
487 0 : ) -> Result<CachedAllowedIps, GetAuthInfoError> {
488 0 : match self {
489 0 : Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
490 0 : Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
491 : }
492 0 : }
493 :
494 0 : pub(crate) async fn get_allowed_vpc_endpoint_ids(
495 0 : &self,
496 0 : ctx: &RequestContext,
497 0 : ) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
498 0 : match self {
499 0 : Self::ControlPlane(api, user_info) => {
500 0 : api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
501 : }
502 0 : Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
503 : }
504 0 : }
505 :
506 0 : pub(crate) async fn get_block_public_or_vpc_access(
507 0 : &self,
508 0 : ctx: &RequestContext,
509 0 : ) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
510 0 : match self {
511 0 : Self::ControlPlane(api, user_info) => {
512 0 : api.get_block_public_or_vpc_access(ctx, user_info).await
513 : }
514 0 : Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
515 : }
516 0 : }
517 : }
518 :
519 : #[async_trait::async_trait]
520 : impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
521 19 : async fn wake_compute(
522 19 : &self,
523 19 : ctx: &RequestContext,
524 19 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
525 19 : match self {
526 19 : Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
527 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
528 : }
529 38 : }
530 :
531 10 : fn get_keys(&self) -> &ComputeCredentialKeys {
532 10 : match self {
533 10 : Self::ControlPlane(_, creds) => &creds.keys,
534 0 : Self::Local(_) => &ComputeCredentialKeys::None,
535 : }
536 10 : }
537 : }
538 :
539 : #[cfg(test)]
540 : mod tests {
541 : #![allow(clippy::unimplemented, clippy::unwrap_used)]
542 :
543 : use std::net::IpAddr;
544 : use std::sync::Arc;
545 : use std::time::Duration;
546 :
547 : use bytes::BytesMut;
548 : use control_plane::AuthSecret;
549 : use fallible_iterator::FallibleIterator;
550 : use once_cell::sync::Lazy;
551 : use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
552 : use postgres_protocol::message::backend::Message as PgMessage;
553 : use postgres_protocol::message::frontend;
554 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
555 :
556 : use super::jwt::JwkCache;
557 : use super::{AuthRateLimiter, auth_quirks};
558 : use crate::auth::backend::MaskedIp;
559 : use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
560 : use crate::config::AuthenticationConfig;
561 : use crate::context::RequestContext;
562 : use crate::control_plane::{
563 : self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
564 : CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
565 : };
566 : use crate::proxy::NeonOptions;
567 : use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
568 : use crate::scram::ServerSecret;
569 : use crate::scram::threadpool::ThreadPool;
570 : use crate::stream::{PqStream, Stream};
571 :
572 : struct Auth {
573 : ips: Vec<IpPattern>,
574 : vpc_endpoint_ids: Vec<String>,
575 : access_blocker_flags: AccessBlockerFlags,
576 : secret: AuthSecret,
577 : }
578 :
579 : impl control_plane::ControlPlaneApi for Auth {
580 3 : async fn get_role_secret(
581 3 : &self,
582 3 : _ctx: &RequestContext,
583 3 : _user_info: &super::ComputeUserInfo,
584 3 : ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
585 3 : Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
586 3 : }
587 :
588 3 : async fn get_allowed_ips(
589 3 : &self,
590 3 : _ctx: &RequestContext,
591 3 : _user_info: &super::ComputeUserInfo,
592 3 : ) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
593 3 : Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
594 3 : }
595 :
596 0 : async fn get_allowed_vpc_endpoint_ids(
597 0 : &self,
598 0 : _ctx: &RequestContext,
599 0 : _user_info: &super::ComputeUserInfo,
600 0 : ) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
601 0 : Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
602 0 : self.vpc_endpoint_ids.clone(),
603 0 : )))
604 0 : }
605 :
606 3 : async fn get_block_public_or_vpc_access(
607 3 : &self,
608 3 : _ctx: &RequestContext,
609 3 : _user_info: &super::ComputeUserInfo,
610 3 : ) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
611 3 : Ok(CachedAccessBlockerFlags::new_uncached(
612 3 : self.access_blocker_flags.clone(),
613 3 : ))
614 3 : }
615 :
616 0 : async fn get_endpoint_jwks(
617 0 : &self,
618 0 : _ctx: &RequestContext,
619 0 : _endpoint: crate::types::EndpointId,
620 0 : ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
621 0 : {
622 0 : unimplemented!()
623 : }
624 :
625 0 : async fn wake_compute(
626 0 : &self,
627 0 : _ctx: &RequestContext,
628 0 : _user_info: &super::ComputeUserInfo,
629 0 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
630 0 : unimplemented!()
631 : }
632 : }
633 :
634 3 : static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
635 3 : jwks_cache: JwkCache::default(),
636 3 : thread_pool: ThreadPool::new(1),
637 3 : scram_protocol_timeout: std::time::Duration::from_secs(5),
638 3 : rate_limiter_enabled: true,
639 3 : rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
640 3 : rate_limit_ip_subnet: 64,
641 3 : ip_allowlist_check_enabled: true,
642 3 : is_vpc_acccess_proxy: false,
643 3 : is_auth_broker: false,
644 3 : accept_jwts: false,
645 3 : console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
646 3 : });
647 :
648 5 : async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
649 : loop {
650 7 : r.read_buf(&mut *b).await.unwrap();
651 7 : if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
652 5 : break m;
653 2 : }
654 : }
655 5 : }
656 :
/// Masking keeps addresses apart or merges them depending on the prefix; a
/// prefix that is too long for IPv4 leaves the address untouched.
#[test]
fn masked_ip() {
    let v4 = |octets: [u8; 4]| IpAddr::V4(octets.into());
    let v6 = |text: &str| IpAddr::V6(text.parse().unwrap());

    let ip_a = v4([127, 0, 0, 1]);
    let ip_b = v4([127, 0, 0, 2]);
    let ip_c = v4([192, 168, 1, 101]);
    let ip_d = v4([192, 168, 1, 102]);
    let ip_e = v6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd");
    let ip_f = v6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd");

    // IPv4
    assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
    assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
    assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
    assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));

    // IPv6
    assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
    assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
}
674 :
/// Pins the default auth rate-limit buckets and checks that each one
/// survives a round trip through its string form.
#[test]
fn test_default_auth_rate_limit_set() {
    // these values used to exceed u32::MAX
    let expected = [
        RateBucketInfo {
            interval: Duration::from_secs(1),
            max_rpi: 1000 * 4096,
        },
        RateBucketInfo {
            interval: Duration::from_secs(60),
            max_rpi: 600 * 4096 * 60,
        },
        RateBucketInfo {
            interval: Duration::from_secs(600),
            max_rpi: 300 * 4096 * 600,
        },
    ];
    assert_eq!(RateBucketInfo::DEFAULT_AUTH_SET, expected);

    for bucket in RateBucketInfo::DEFAULT_AUTH_SET {
        let reparsed: RateBucketInfo = bucket.to_string().parse().unwrap();
        assert_eq!(bucket, reparsed);
    }
}
701 :
702 : #[tokio::test]
703 1 : async fn auth_quirks_scram() {
704 1 : let (mut client, server) = tokio::io::duplex(1024);
705 1 : let mut stream = PqStream::new(Stream::from_raw(server));
706 1 :
707 1 : let ctx = RequestContext::test();
708 1 : let api = Auth {
709 1 : ips: vec![],
710 1 : vpc_endpoint_ids: vec![],
711 1 : access_blocker_flags: AccessBlockerFlags::default(),
712 1 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
713 1 : };
714 1 :
715 1 : let user_info = ComputeUserInfoMaybeEndpoint {
716 1 : user: "conrad".into(),
717 1 : endpoint_id: Some("endpoint".into()),
718 1 : options: NeonOptions::default(),
719 1 : };
720 1 :
721 1 : let handle = tokio::spawn(async move {
722 1 : let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
723 1 :
724 1 : let mut read = BytesMut::new();
725 1 :
726 1 : // server should offer scram
727 1 : match read_message(&mut client, &mut read).await {
728 1 : PgMessage::AuthenticationSasl(a) => {
729 1 : let options: Vec<&str> = a.mechanisms().collect().unwrap();
730 1 : assert_eq!(options, ["SCRAM-SHA-256"]);
731 1 : }
732 1 : _ => panic!("wrong message"),
733 1 : }
734 1 :
735 1 : // client sends client-first-message
736 1 : let mut write = BytesMut::new();
737 1 : frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
738 1 : client.write_all(&write).await.unwrap();
739 1 :
740 1 : // server response with server-first-message
741 1 : match read_message(&mut client, &mut read).await {
742 1 : PgMessage::AuthenticationSaslContinue(a) => {
743 1 : scram.update(a.data()).await.unwrap();
744 1 : }
745 1 : _ => panic!("wrong message"),
746 1 : }
747 1 :
748 1 : // client response with client-final-message
749 1 : write.clear();
750 1 : frontend::sasl_response(scram.message(), &mut write).unwrap();
751 1 : client.write_all(&write).await.unwrap();
752 1 :
753 1 : // server response with server-final-message
754 1 : match read_message(&mut client, &mut read).await {
755 1 : PgMessage::AuthenticationSaslFinal(a) => {
756 1 : scram.finish(a.data()).unwrap();
757 1 : }
758 1 : _ => panic!("wrong message"),
759 1 : }
760 1 : });
761 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
762 1 : EndpointRateLimiter::DEFAULT,
763 1 : 64,
764 1 : ));
765 1 :
766 1 : let _creds = auth_quirks(
767 1 : &ctx,
768 1 : &api,
769 1 : user_info,
770 1 : &mut stream,
771 1 : false,
772 1 : &CONFIG,
773 1 : endpoint_rate_limiter,
774 1 : )
775 1 : .await
776 1 : .unwrap();
777 1 :
778 1 : // flush the final server message
779 1 : stream.flush().await.unwrap();
780 1 :
781 1 : handle.await.unwrap();
782 1 : }
783 :
784 : #[tokio::test]
785 1 : async fn auth_quirks_cleartext() {
786 1 : let (mut client, server) = tokio::io::duplex(1024);
787 1 : let mut stream = PqStream::new(Stream::from_raw(server));
788 1 :
789 1 : let ctx = RequestContext::test();
790 1 : let api = Auth {
791 1 : ips: vec![],
792 1 : vpc_endpoint_ids: vec![],
793 1 : access_blocker_flags: AccessBlockerFlags::default(),
794 1 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
795 1 : };
796 1 :
797 1 : let user_info = ComputeUserInfoMaybeEndpoint {
798 1 : user: "conrad".into(),
799 1 : endpoint_id: Some("endpoint".into()),
800 1 : options: NeonOptions::default(),
801 1 : };
802 1 :
803 1 : let handle = tokio::spawn(async move {
804 1 : let mut read = BytesMut::new();
805 1 : let mut write = BytesMut::new();
806 1 :
807 1 : // server should offer cleartext
808 1 : match read_message(&mut client, &mut read).await {
809 1 : PgMessage::AuthenticationCleartextPassword => {}
810 1 : _ => panic!("wrong message"),
811 1 : }
812 1 :
813 1 : // client responds with password
814 1 : write.clear();
815 1 : frontend::password_message(b"my-secret-password", &mut write).unwrap();
816 1 : client.write_all(&write).await.unwrap();
817 1 : });
818 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
819 1 : EndpointRateLimiter::DEFAULT,
820 1 : 64,
821 1 : ));
822 1 :
823 1 : let _creds = auth_quirks(
824 1 : &ctx,
825 1 : &api,
826 1 : user_info,
827 1 : &mut stream,
828 1 : true,
829 1 : &CONFIG,
830 1 : endpoint_rate_limiter,
831 1 : )
832 1 : .await
833 1 : .unwrap();
834 1 :
835 1 : handle.await.unwrap();
836 1 : }
837 :
838 : #[tokio::test]
839 1 : async fn auth_quirks_password_hack() {
840 1 : let (mut client, server) = tokio::io::duplex(1024);
841 1 : let mut stream = PqStream::new(Stream::from_raw(server));
842 1 :
843 1 : let ctx = RequestContext::test();
844 1 : let api = Auth {
845 1 : ips: vec![],
846 1 : vpc_endpoint_ids: vec![],
847 1 : access_blocker_flags: AccessBlockerFlags::default(),
848 1 : secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
849 1 : };
850 1 :
851 1 : let user_info = ComputeUserInfoMaybeEndpoint {
852 1 : user: "conrad".into(),
853 1 : endpoint_id: None,
854 1 : options: NeonOptions::default(),
855 1 : };
856 1 :
857 1 : let handle = tokio::spawn(async move {
858 1 : let mut read = BytesMut::new();
859 1 :
860 1 : // server should offer cleartext
861 1 : match read_message(&mut client, &mut read).await {
862 1 : PgMessage::AuthenticationCleartextPassword => {}
863 1 : _ => panic!("wrong message"),
864 1 : }
865 1 :
866 1 : // client responds with password
867 1 : let mut write = BytesMut::new();
868 1 : frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
869 1 : .unwrap();
870 1 : client.write_all(&write).await.unwrap();
871 1 : });
872 1 :
873 1 : let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
874 1 : EndpointRateLimiter::DEFAULT,
875 1 : 64,
876 1 : ));
877 1 :
878 1 : let creds = auth_quirks(
879 1 : &ctx,
880 1 : &api,
881 1 : user_info,
882 1 : &mut stream,
883 1 : true,
884 1 : &CONFIG,
885 1 : endpoint_rate_limiter,
886 1 : )
887 1 : .await
888 1 : .unwrap();
889 1 :
890 1 : assert_eq!(creds.0.info.endpoint, "my-endpoint");
891 1 :
892 1 : handle.await.unwrap();
893 1 : }
894 : }
|