Line data Source code
1 : mod classic;
2 : mod console_redirect;
3 : mod hacks;
4 : pub mod jwt;
5 : pub mod local;
6 :
7 : use std::sync::Arc;
8 :
9 : pub use console_redirect::ConsoleRedirectBackend;
10 : pub(crate) use console_redirect::ConsoleRedirectError;
11 : use local::LocalBackend;
12 : use postgres_client::config::AuthKeys;
13 : use serde::{Deserialize, Serialize};
14 : use tokio::io::{AsyncRead, AsyncWrite};
15 : use tracing::{debug, info};
16 :
17 : use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
18 : use crate::cache::Cached;
19 : use crate::config::AuthenticationConfig;
20 : use crate::context::RequestContext;
21 : use crate::control_plane::client::ControlPlaneClient;
22 : use crate::control_plane::errors::GetAuthInfoError;
23 : use crate::control_plane::messages::EndpointRateLimitConfig;
24 : use crate::control_plane::{
25 : self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
26 : RoleAccessControl,
27 : };
28 : use crate::intern::EndpointIdInt;
29 : use crate::pqproto::BeMessage;
30 : use crate::proxy::NeonOptions;
31 : use crate::proxy::wake_compute::WakeComputeBackend;
32 : use crate::rate_limiter::EndpointRateLimiter;
33 : use crate::stream::Stream;
34 : use crate::types::{EndpointCacheKey, EndpointId, RoleName};
35 : use crate::{scram, stream};
36 :
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
///
/// Either holds a `T` by value or borrows one for `'a`; in both cases the value is
/// reachable through [`std::ops::Deref`].
pub enum MaybeOwned<'a, T> {
    /// The value is stored inline.
    Owned(T),
    /// The value lives elsewhere and is only borrowed.
    Borrowed(&'a T),
}

impl<T> std::ops::Deref for MaybeOwned<'_, T> {
    type Target = T;

    /// Returns a shared reference to the wrapped value, regardless of ownership.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Borrowed(reference) => *reference,
            Self::Owned(value) => value,
        }
    }
}
53 :
54 : /// This type serves two purposes:
55 : ///
56 : /// * When `T` is `()`, it's just a regular auth backend selector
57 : /// which we use in [`crate::config::ProxyConfig`].
58 : ///
59 : /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
60 : /// this helps us provide the credentials only to those auth
61 : /// backends which require them for the authentication process.
62 : pub enum Backend<'a, T> {
63 : /// Cloud API (V2).
64 : ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
65 : /// Local proxy uses configured auth credentials and does not wake compute
66 : Local(MaybeOwned<'a, LocalBackend>),
67 : }
68 :
69 : impl std::fmt::Display for Backend<'_, ()> {
70 0 : fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 0 : match self {
72 0 : Self::ControlPlane(api, ()) => match &**api {
73 0 : ControlPlaneClient::ProxyV1(endpoint) => fmt
74 0 : .debug_tuple("ControlPlane::ProxyV1")
75 0 : .field(&endpoint.url())
76 0 : .finish(),
77 : #[cfg(any(test, feature = "testing"))]
78 0 : ControlPlaneClient::PostgresMock(endpoint) => {
79 0 : let url = endpoint.url();
80 0 : match url::Url::parse(url) {
81 0 : Ok(mut url) => {
82 0 : let _ = url.set_password(Some("_redacted_"));
83 0 : let url = url.as_str();
84 0 : fmt.debug_tuple("ControlPlane::PostgresMock")
85 0 : .field(&url)
86 0 : .finish()
87 : }
88 0 : Err(_) => fmt
89 0 : .debug_tuple("ControlPlane::PostgresMock")
90 0 : .field(&url)
91 0 : .finish(),
92 : }
93 : }
94 : #[cfg(test)]
95 0 : ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
96 : },
97 0 : Self::Local(_) => fmt.debug_tuple("Local").finish(),
98 : }
99 0 : }
100 : }
101 :
102 : impl<T> Backend<'_, T> {
103 : /// Very similar to [`std::option::Option::as_ref`].
104 : /// This helps us pass structured config to async tasks.
105 0 : pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
106 0 : match self {
107 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
108 0 : Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
109 : }
110 0 : }
111 :
112 0 : pub(crate) fn get_api(&self) -> &ControlPlaneClient {
113 0 : match self {
114 0 : Self::ControlPlane(api, _) => api,
115 0 : Self::Local(_) => panic!("Local backend has no API"),
116 : }
117 0 : }
118 :
119 0 : pub(crate) fn is_local_proxy(&self) -> bool {
120 0 : matches!(self, Self::Local(_))
121 0 : }
122 : }
123 :
124 : impl<'a, T> Backend<'a, T> {
125 : /// Very similar to [`std::option::Option::map`].
126 : /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
127 : /// a function to a contained value.
128 0 : pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
129 0 : match self {
130 0 : Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
131 0 : Self::Local(l) => Backend::Local(l),
132 : }
133 0 : }
134 : }
135 : impl<'a, T, E> Backend<'a, Result<T, E>> {
136 : /// Very similar to [`std::option::Option::transpose`].
137 : /// This is most useful for error handling.
138 0 : pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
139 0 : match self {
140 0 : Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
141 0 : Self::Local(l) => Ok(Backend::Local(l)),
142 : }
143 0 : }
144 : }
145 :
146 : pub(crate) struct ComputeCredentials {
147 : pub(crate) info: ComputeUserInfo,
148 : pub(crate) keys: ComputeCredentialKeys,
149 : }
150 :
151 : #[derive(Debug, Clone)]
152 : pub(crate) struct ComputeUserInfoNoEndpoint {
153 : pub(crate) user: RoleName,
154 : pub(crate) options: NeonOptions,
155 : }
156 :
157 0 : #[derive(Debug, Clone, Default, Serialize, Deserialize)]
158 : pub(crate) struct ComputeUserInfo {
159 : pub(crate) endpoint: EndpointId,
160 : pub(crate) user: RoleName,
161 : pub(crate) options: NeonOptions,
162 : }
163 :
164 : impl ComputeUserInfo {
165 2 : pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
166 2 : self.options.get_cache_key(&self.endpoint)
167 2 : }
168 : }
169 :
170 : #[cfg_attr(test, derive(Debug))]
171 : pub(crate) enum ComputeCredentialKeys {
172 : AuthKeys(AuthKeys),
173 : JwtPayload(Vec<u8>),
174 : None,
175 : }
176 :
177 : impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
178 : // user name
179 : type Error = ComputeUserInfoNoEndpoint;
180 :
181 3 : fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
182 3 : match user_info.endpoint_id {
183 1 : None => Err(ComputeUserInfoNoEndpoint {
184 1 : user: user_info.user,
185 1 : options: user_info.options,
186 1 : }),
187 2 : Some(endpoint) => Ok(ComputeUserInfo {
188 2 : endpoint,
189 2 : user: user_info.user,
190 2 : options: user_info.options,
191 2 : }),
192 : }
193 3 : }
194 : }
195 :
196 : /// True to its name, this function encapsulates our current auth trade-offs.
197 : /// Here, we choose the appropriate auth flow based on circumstances.
198 : ///
199 : /// All authentication flows will emit an AuthenticationOk message if successful.
200 3 : async fn auth_quirks(
201 3 : ctx: &RequestContext,
202 3 : api: &impl control_plane::ControlPlaneApi,
203 3 : user_info: ComputeUserInfoMaybeEndpoint,
204 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
205 3 : allow_cleartext: bool,
206 3 : config: &'static AuthenticationConfig,
207 3 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
208 3 : ) -> auth::Result<ComputeCredentials> {
209 : // If there's no project so far, that entails that client doesn't
210 : // support SNI or other means of passing the endpoint (project) name.
211 : // We now expect to see a very specific payload in the place of password.
212 3 : let (info, unauthenticated_password) = match user_info.try_into() {
213 1 : Err(info) => {
214 1 : let (info, password) =
215 1 : hacks::password_hack_no_authentication(ctx, info, client).await?;
216 1 : ctx.set_endpoint_id(info.endpoint.clone());
217 1 : (info, Some(password))
218 : }
219 2 : Ok(info) => (info, None),
220 : };
221 :
222 3 : debug!("fetching authentication info and allowlists");
223 :
224 3 : let access_controls = api
225 3 : .get_endpoint_access_control(ctx, &info.endpoint, &info.user)
226 3 : .await?;
227 :
228 3 : access_controls.check(
229 3 : ctx,
230 3 : config.ip_allowlist_check_enabled,
231 3 : config.is_vpc_acccess_proxy,
232 3 : )?;
233 :
234 3 : access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
235 :
236 3 : let role_access = api
237 3 : .get_role_access_control(ctx, &info.endpoint, &info.user)
238 3 : .await?;
239 :
240 3 : let secret = if let Some(secret) = role_access.secret {
241 3 : secret
242 : } else {
243 : // If we don't have an authentication secret, we mock one to
244 : // prevent malicious probing (possible due to missing protocol steps).
245 : // This mocked secret will never lead to successful authentication.
246 0 : info!("authentication info not found, mocking it");
247 0 : AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
248 : };
249 :
250 3 : match authenticate_with_secret(
251 3 : ctx,
252 3 : secret,
253 3 : info,
254 3 : client,
255 3 : unauthenticated_password,
256 3 : allow_cleartext,
257 3 : config,
258 3 : )
259 3 : .await
260 : {
261 3 : Ok(keys) => Ok(keys),
262 0 : Err(e) => Err(e),
263 : }
264 3 : }
265 :
266 3 : async fn authenticate_with_secret(
267 3 : ctx: &RequestContext,
268 3 : secret: AuthSecret,
269 3 : info: ComputeUserInfo,
270 3 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
271 3 : unauthenticated_password: Option<Vec<u8>>,
272 3 : allow_cleartext: bool,
273 3 : config: &'static AuthenticationConfig,
274 3 : ) -> auth::Result<ComputeCredentials> {
275 3 : if let Some(password) = unauthenticated_password {
276 1 : let ep = EndpointIdInt::from(&info.endpoint);
277 :
278 1 : let auth_outcome =
279 1 : validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
280 1 : let keys = match auth_outcome {
281 1 : crate::sasl::Outcome::Success(key) => key,
282 0 : crate::sasl::Outcome::Failure(reason) => {
283 0 : info!("auth backend failed with an error: {reason}");
284 0 : return Err(auth::AuthError::password_failed(&*info.user));
285 : }
286 : };
287 :
288 : // we have authenticated the password
289 1 : client.write_message(BeMessage::AuthenticationOk);
290 1 :
291 1 : return Ok(ComputeCredentials { info, keys });
292 2 : }
293 2 :
294 2 : // -- the remaining flows are self-authenticating --
295 2 :
296 2 : // Perform cleartext auth if we're allowed to do that.
297 2 : // Currently, we use it for websocket connections (latency).
298 2 : if allow_cleartext {
299 1 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
300 1 : return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
301 1 : }
302 1 :
303 1 : // Finally, proceed with the main auth flow (SCRAM-based).
304 1 : classic::authenticate(ctx, info, client, config, secret).await
305 3 : }
306 :
307 : impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
308 : /// Get username from the credentials.
309 0 : pub(crate) fn get_user(&self) -> &str {
310 0 : match self {
311 0 : Self::ControlPlane(_, user_info) => &user_info.user,
312 0 : Self::Local(_) => "local",
313 : }
314 0 : }
315 :
316 : /// Authenticate the client via the requested backend, possibly using credentials.
317 : #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
318 : pub(crate) async fn authenticate(
319 : self,
320 : ctx: &RequestContext,
321 : client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
322 : allow_cleartext: bool,
323 : config: &'static AuthenticationConfig,
324 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
325 : ) -> auth::Result<Backend<'a, ComputeCredentials>> {
326 : let res = match self {
327 : Self::ControlPlane(api, user_info) => {
328 : debug!(
329 : user = &*user_info.user,
330 : project = user_info.endpoint(),
331 : "performing authentication using the console"
332 : );
333 :
334 : let auth_res = auth_quirks(
335 : ctx,
336 : &*api,
337 : user_info.clone(),
338 : client,
339 : allow_cleartext,
340 : config,
341 : endpoint_rate_limiter,
342 : )
343 : .await;
344 : match auth_res {
345 : Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
346 : Err(e) => {
347 : // The password could have been changed, so we invalidate the cache.
348 : // We should only invalidate the cache if the TTL might have expired.
349 : if e.is_password_failed() {
350 : #[allow(irrefutable_let_patterns)]
351 : if let ControlPlaneClient::ProxyV1(api) = &*api {
352 : if let Some(ep) = &user_info.endpoint_id {
353 : api.caches
354 : .project_info
355 : .maybe_invalidate_role_secret(ep, &user_info.user);
356 : }
357 : }
358 : }
359 :
360 : Err(e)
361 : }
362 : }
363 : }
364 : Self::Local(_) => {
365 : return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
366 : }
367 : };
368 :
369 : // TODO: replace with some metric
370 : info!("user successfully authenticated");
371 : res
372 : }
373 : }
374 :
375 : impl Backend<'_, ComputeUserInfo> {
376 0 : pub(crate) async fn get_role_secret(
377 0 : &self,
378 0 : ctx: &RequestContext,
379 0 : ) -> Result<RoleAccessControl, GetAuthInfoError> {
380 0 : match self {
381 0 : Self::ControlPlane(api, user_info) => {
382 0 : api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
383 0 : .await
384 : }
385 0 : Self::Local(_) => Ok(RoleAccessControl { secret: None }),
386 : }
387 0 : }
388 :
389 0 : pub(crate) async fn get_endpoint_access_control(
390 0 : &self,
391 0 : ctx: &RequestContext,
392 0 : ) -> Result<EndpointAccessControl, GetAuthInfoError> {
393 0 : match self {
394 0 : Self::ControlPlane(api, user_info) => {
395 0 : api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
396 0 : .await
397 : }
398 0 : Self::Local(_) => Ok(EndpointAccessControl {
399 0 : allowed_ips: Arc::new(vec![]),
400 0 : allowed_vpce: Arc::new(vec![]),
401 0 : flags: AccessBlockerFlags::default(),
402 0 : rate_limits: EndpointRateLimitConfig::default(),
403 0 : }),
404 : }
405 0 : }
406 : }
407 :
408 : #[async_trait::async_trait]
409 : impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
410 19 : async fn wake_compute(
411 19 : &self,
412 19 : ctx: &RequestContext,
413 19 : ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
414 19 : match self {
415 19 : Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
416 0 : Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
417 : }
418 38 : }
419 : }
420 :
#[cfg(test)]
mod tests {
    #![allow(clippy::unimplemented, clippy::unwrap_used)]

    use std::sync::Arc;

    use bytes::BytesMut;
    use control_plane::AuthSecret;
    use fallible_iterator::FallibleIterator;
    use once_cell::sync::Lazy;
    use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
    use postgres_protocol::message::backend::Message as PgMessage;
    use postgres_protocol::message::frontend;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};

    use super::auth_quirks;
    use super::jwt::JwkCache;
    use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
    use crate::config::AuthenticationConfig;
    use crate::context::RequestContext;
    use crate::control_plane::messages::EndpointRateLimitConfig;
    use crate::control_plane::{
        self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
    };
    use crate::proxy::NeonOptions;
    use crate::rate_limiter::EndpointRateLimiter;
    use crate::scram::ServerSecret;
    use crate::scram::threadpool::ThreadPool;
    use crate::stream::{PqStream, Stream};

    /// In-memory control plane returning fixed access controls and a fixed role secret.
    struct Auth {
        ips: Vec<IpPattern>,
        vpc_endpoint_ids: Vec<String>,
        access_blocker_flags: AccessBlockerFlags,
        secret: AuthSecret,
    }

    impl control_plane::ControlPlaneApi for Auth {
        async fn get_role_access_control(
            &self,
            _ctx: &RequestContext,
            _endpoint: &crate::types::EndpointId,
            _role: &crate::types::RoleName,
        ) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
            let secret = Some(self.secret.clone());
            Ok(RoleAccessControl { secret })
        }

        async fn get_endpoint_access_control(
            &self,
            _ctx: &RequestContext,
            _endpoint: &crate::types::EndpointId,
            _role: &crate::types::RoleName,
        ) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
            Ok(EndpointAccessControl {
                allowed_ips: Arc::new(self.ips.clone()),
                allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
                flags: self.access_blocker_flags,
                rate_limits: EndpointRateLimitConfig::default(),
            })
        }

        async fn get_endpoint_jwks(
            &self,
            _ctx: &RequestContext,
            _endpoint: &crate::types::EndpointId,
        ) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
        {
            unimplemented!()
        }

        async fn wake_compute(
            &self,
            _ctx: &RequestContext,
            _user_info: &super::ComputeUserInfo,
        ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
            unimplemented!()
        }
    }

    static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
        jwks_cache: JwkCache::default(),
        thread_pool: ThreadPool::new(1),
        scram_protocol_timeout: std::time::Duration::from_secs(5),
        ip_allowlist_check_enabled: true,
        is_vpc_acccess_proxy: false,
        is_auth_broker: false,
        accept_jwts: false,
        console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
    });

    /// Mock control plane with no IP/VPC restrictions and a SCRAM secret for
    /// `"my-secret-password"`.
    async fn mock_control_plane() -> Auth {
        Auth {
            ips: vec![],
            vpc_endpoint_ids: vec![],
            access_blocker_flags: AccessBlockerFlags::default(),
            secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
        }
    }

    /// Credentials for user `conrad` with default options and the given endpoint.
    fn conrad(endpoint_id: Option<&str>) -> ComputeUserInfoMaybeEndpoint {
        ComputeUserInfoMaybeEndpoint {
            user: "conrad".into(),
            endpoint_id: endpoint_id.map(|id| id.into()),
            options: NeonOptions::default(),
        }
    }

    /// Endpoint rate limiter with the default limits and 64 shards.
    fn rate_limiter() -> Arc<EndpointRateLimiter> {
        Arc::new(EndpointRateLimiter::new_with_shards(
            EndpointRateLimiter::DEFAULT,
            64,
        ))
    }

    /// Reads from `reader` into `buf` until one complete backend message can be parsed.
    async fn read_message(reader: &mut (impl AsyncRead + Unpin), buf: &mut BytesMut) -> PgMessage {
        loop {
            reader.read_buf(&mut *buf).await.unwrap();
            if let Some(message) = PgMessage::parse(&mut *buf).unwrap() {
                break message;
            }
        }
    }

    #[tokio::test]
    async fn auth_quirks_scram() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_control_plane().await;
        let user_info = conrad(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer scram
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSasl(a) => {
                    let options: Vec<&str> = a.mechanisms().collect().unwrap();
                    assert_eq!(options, ["SCRAM-SHA-256"]);
                }
                _ => panic!("wrong message"),
            }

            // client sends client-first-message
            frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-first-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslContinue(a) => {
                    scram.update(a.data()).await.unwrap();
                }
                _ => panic!("wrong message"),
            }

            // client response with client-final-message
            write.clear();
            frontend::sasl_response(scram.message(), &mut write).unwrap();
            client.write_all(&write).await.unwrap();

            // server response with server-final-message
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationSaslFinal(a) => {
                    scram.finish(a.data()).unwrap();
                }
                _ => panic!("wrong message"),
            }
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            false,
            &CONFIG,
            rate_limiter(),
        )
        .await
        .unwrap();

        // flush the final server message
        stream.flush().await.unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_cleartext() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_control_plane().await;
        let user_info = conrad(Some("endpoint"));

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            frontend::password_message(b"my-secret-password", &mut write).unwrap();
            client.write_all(&write).await.unwrap();
        });

        let _creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            rate_limiter(),
        )
        .await
        .unwrap();

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn auth_quirks_password_hack() {
        let (mut client, server) = tokio::io::duplex(1024);
        let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));

        let ctx = RequestContext::test();
        let api = mock_control_plane().await;
        // No endpoint: it must be recovered from the password payload.
        let user_info = conrad(None);

        let handle = tokio::spawn(async move {
            let mut read = BytesMut::new();
            let mut write = BytesMut::new();

            // server should offer cleartext
            match read_message(&mut client, &mut read).await {
                PgMessage::AuthenticationCleartextPassword => {}
                _ => panic!("wrong message"),
            }

            // client responds with password
            frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
                .unwrap();
            client.write_all(&write).await.unwrap();
        });

        let creds = auth_quirks(
            &ctx,
            &api,
            user_info,
            &mut stream,
            true,
            &CONFIG,
            rate_limiter(),
        )
        .await
        .unwrap();

        assert_eq!(creds.info.endpoint, "my-endpoint");

        handle.await.unwrap();
    }
}
|