LCOV - code coverage report
Current view: top level - proxy/src/proxy/tests - mod.rs (source / functions) Coverage Total Hit
Test: ef1c66bb4fbe62e3fa18f8b9d22d3134c7ecd2da.info Lines: 97.7 % 530 518
Test Date: 2025-07-25 10:34:39 Functions: 96.9 % 65 63

            Line data    Source code
       1              : //! A group of high-level tests for connection establishing logic and auth.
       2              : #![allow(clippy::unimplemented)]
       3              : 
       4              : mod mitm;
       5              : 
       6              : use std::sync::Arc;
       7              : use std::time::Duration;
       8              : 
       9              : use anyhow::{Context, bail};
      10              : use async_trait::async_trait;
      11              : use http::StatusCode;
      12              : use postgres_client::config::SslMode;
      13              : use postgres_client::tls::{MakeTlsConnect, NoTls};
      14              : use rstest::rstest;
      15              : use rustls::crypto::ring;
      16              : use rustls::pki_types;
      17              : use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
      18              : use tokio::time::Instant;
      19              : use tracing_test::traced_test;
      20              : 
      21              : use super::retry::CouldRetry;
      22              : use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
      23              : use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
      24              : use crate::context::RequestContext;
      25              : use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
      26              : use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
      27              : use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
      28              : use crate::error::ErrorKind;
      29              : use crate::pglb::ERR_INSECURE_CONNECTION;
      30              : use crate::pglb::handshake::{HandshakeData, handshake};
      31              : use crate::pqproto::BeMessage;
      32              : use crate::proxy::NeonOptions;
      33              : use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner};
      34              : use crate::proxy::retry::retry_after;
      35              : use crate::stream::{PqStream, Stream};
      36              : use crate::tls::client_config::compute_client_config_with_certs;
      37              : use crate::tls::server_config::CertResolver;
      38              : use crate::types::{BranchId, EndpointId, ProjectId};
      39              : use crate::{auth, compute, sasl, scram};
      40              : 
      41              : /// Generate a set of TLS certificates: CA + server.
                        ///
                        /// `hostname` is the single name handed to `rcgen::CertificateParams::new`
                        /// for the server certificate, and `common_name` becomes the only entry
                        /// (the CN) of its distinguished name.
                        ///
                        /// Returns `(ca_cert, server_cert, server_key)`: both certificates as DER,
                        /// the server key as PKCS#8 DER. Any rcgen failure is propagated with `?`.
      42           21 : fn generate_certs(
      43           21 :     hostname: &str,
      44           21 :     common_name: &str,
      45           21 : ) -> anyhow::Result<(
      46           21 :     pki_types::CertificateDer<'static>,
      47           21 :     pki_types::CertificateDer<'static>,
      48           21 :     pki_types::PrivateKeyDer<'static>,
      49           21 : )> {
                            // Self-signed CA certificate, flagged as a CA with unconstrained basic
                            // constraints.
      50           21 :     let ca_key = rcgen::KeyPair::generate()?;
      51           21 :     let ca = {
      52           21 :         let mut params = rcgen::CertificateParams::default();
      53           21 :         params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
      54           21 :         params.self_signed(&ca_key)?
      55              :     };
      56              : 
                            // Server (leaf) certificate: its own key pair, signed by the CA created above.
      57           21 :     let cert_key = rcgen::KeyPair::generate()?;
      58           21 :     let cert = {
      59           21 :         let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
                                // Start from an empty distinguished name so `common_name` is its only entry.
      60           21 :         params.distinguished_name = rcgen::DistinguishedName::new();
      61           21 :         params
      62           21 :             .distinguished_name
      63           21 :             .push(rcgen::DnType::CommonName, common_name);
      64           21 :         params.signed_by(&cert_key, &ca, &ca_key)?
      65              :     };
      66              : 
      67           21 :     Ok((
      68           21 :         ca.der().clone(),
      69           21 :         cert.der().clone(),
      70           21 :         pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
      71           21 :     ))
      72           21 : }
      73              : 
      74              : struct ClientConfig<'a> {
                            // Client-side TLS settings returned by `generate_tls_config`.

                            /// rustls client configuration handed to `make_tls_connect`.
      75              :     config: Arc<rustls::ClientConfig>,
                            /// Host name handed to `make_tls_connect` together with `config`.
      76              :     hostname: &'a str,
      77              : }
      78              : 
                        /// The connector type that `ComputeConfig` provides through `MakeTlsConnect`
                        /// for a stream of type `S`.
      79              : type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
      80              : 
      81              : impl ClientConfig<'_> {
                            /// Consumes the config and builds a connector for a `DuplexStream` by
                            /// delegating to `crate::tls::postgres_rustls::make_tls_connect`.
                            ///
                            /// The `?` inside `Ok(..)` converts the helper's error into `anyhow::Error`.
      82           20 :     fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
      83           20 :         Ok(crate::tls::postgres_rustls::make_tls_connect(
      84           20 :             &self.config,
      85           20 :             self.hostname,
      86            0 :         )?)
      87           20 :     }
      88              : }
      89              : 
      90              : /// Generate TLS certificates and build rustls configs for client and server.
                        ///
                        /// Calls `generate_certs(hostname, common_name)` and returns
                        /// `(client_config, server_tls_config)`. The client half keeps `hostname` so
                        /// that `ClientConfig::make_tls_connect` can use it later.
                        ///
                        /// Fails if certificate generation, the rustls server builder or
                        /// `CertResolver::new` fails.
      91           21 : fn generate_tls_config<'a>(
      92           21 :     hostname: &'a str,
      93           21 :     common_name: &'a str,
      94           21 : ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
      95           21 :     let (ca, cert, key) = generate_certs(hostname, common_name)?;
      96              : 
      97           21 :     let tls_config = {
                                // The same certificate and key feed both the rustls `ServerConfig` and the
                                // `CertResolver`, hence `cert.clone()` / `key.clone_key()` here.
      98           21 :         let config =
      99           21 :             rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
     100           21 :                 .with_safe_default_protocol_versions()
     101           21 :                 .context("ring should support the default protocol versions")?
     102           21 :                 .with_no_client_auth()
     103           21 :                 .with_single_cert(vec![cert.clone()], key.clone_key())?;
     104              : 
     105           21 :         let cert_resolver = CertResolver::new(key, vec![cert])?;
     106              : 
     107           21 :         let common_names = cert_resolver.get_common_names();
     108              : 
     109           21 :         let config = Arc::new(config);
     110              : 
                                // One shared `ServerConfig` serves as both `http_config` and `pg_config`.
     111           21 :         TlsConfig {
     112           21 :             http_config: config.clone(),
     113           21 :             pg_config: config,
     114           21 :             common_names,
     115           21 :             cert_resolver: Arc::new(cert_resolver),
     116           21 :         }
     117              :     };
     118              : 
     119           21 :     let client_config = {
                                // NOTE(review): presumably registers `ca` as the trusted root for the
                                // client; confirm in `compute_client_config_with_certs`.
     120           21 :         let config = Arc::new(compute_client_config_with_certs([ca]));
     121              : 
     122           21 :         ClientConfig { config, hostname }
     123              :     };
     124              : 
     125           21 :     Ok((client_config, tls_config))
     126           21 : }
     127              : 
     128              : #[async_trait]
                        /// Authentication step that `dummy_proxy` runs once the handshake is done.
                        ///
                        /// The provided `authenticate` performs no credential check: it writes
                        /// `AuthenticationOk` to the stream and returns `Ok(())`. Nothing is flushed
                        /// inside this method.
     129              : trait TestAuth: Sized {
     130            2 :     async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
     131              :         self,
     132              :         stream: &mut PqStream<Stream<S>>,
     133            2 :     ) -> anyhow::Result<()> {
     134            2 :         stream.write_message(BeMessage::AuthenticationOk);
     135            2 :         Ok(())
     136            4 :     }
     137              : }
     138              : 
     139              : struct NoAuth;
                        // Empty impl on purpose: `NoAuth` keeps the default `TestAuth::authenticate`,
                        // which accepts the client without checking any credentials.
     140              : impl TestAuth for NoAuth {}
     141              : 
     142              : struct Scram(scram::ServerSecret);
                        // `Scram` wraps the SCRAM server secret that its `TestAuth` impl passes to
                        // `auth::Scram`.
     143              : 
     144              : impl Scram {
                            /// Builds the secret from `password` via `scram::ServerSecret::build`,
                            /// attaching the context "failed to generate scram secret" on failure.
     145           11 :     async fn new(password: &str) -> anyhow::Result<Self> {
     146           11 :         let secret = scram::ServerSecret::build(password)
     147           11 :             .await
     148           11 :             .context("failed to generate scram secret")?;
     149           11 :         Ok(Scram(secret))
     150           11 :     }
     151              : 
                            /// Wraps `scram::ServerSecret::mock`, fed with a `rand::random()` value.
                            /// The `scram_auth_mock` test expects that no client password
                            /// authenticates against it.
     152            1 :     fn mock() -> Self {
     153            1 :         Scram(scram::ServerSecret::mock(rand::random()))
     154            1 :     }
     155              : }
     156              : 
     157              : #[async_trait]
                        // SCRAM-backed authentication for the test proxy.
     158              : impl TestAuth for Scram {
                            /// Runs `auth::AuthFlow` with `auth::Scram` over `stream`, using the wrapped
                            /// secret and a test `RequestContext`.
                            ///
                            /// Errors raised by the flow itself propagate through `?`. A completed
                            /// exchange yields `Ok(())` on `Success` and an error carrying the reason
                            /// on `Failure`.
     159           12 :     async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
     160              :         self,
     161              :         stream: &mut PqStream<Stream<S>>,
     162           12 :     ) -> anyhow::Result<()> {
     163           12 :         let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
     164           12 :             .authenticate()
     165           12 :             .await?;
     166              : 
     167              :         use sasl::Outcome::*;
     168            6 :         match outcome {
     169            5 :             Success(_) => Ok(()),
                                    // NOTE(review): "autentication" in the message below is misspelled
                                    // ("authentication"); left untouched here because it is runtime text.
     170            1 :             Failure(reason) => bail!("autentication failed with an error: {reason}"),
     171              :         }
     172           24 :     }
     173              : }
     174              : 
     175              : /// A dummy proxy impl which performs a handshake and reports auth success.
                        ///
                        /// Steps: run `handshake` on `client` (with `tls` when given), reject cancel
                        /// requests, run `auth.authenticate`, then write `ParameterStatus`
                        /// (`client_encoding` = `UTF8`) and `ReadyForQuery`, and flush the stream.
                        ///
                        /// Returns an error if the handshake, the authentication or the flush fails,
                        /// or if the client sent a cancel request.
     176           15 : async fn dummy_proxy(
     177           15 :     client: impl AsyncRead + AsyncWrite + Unpin + Send,
     178           15 :     tls: Option<TlsConfig>,
     179           15 :     auth: impl TestAuth + Send,
     180           15 : ) -> anyhow::Result<()> {
                            // NOTE(review): the meaning of the trailing `false` argument to `handshake`
                            // is not visible from this file.
     181           15 :     let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
     182           14 :         HandshakeData::Startup(stream, _) => stream,
     183            0 :         HandshakeData::Cancel(_) => bail!("cancellation not supported"),
     184              :     };
     185              : 
     186           14 :     auth.authenticate(&mut stream).await?;
     187              : 
                            // Reached only after successful authentication; the flush below is the only
                            // explicit flush in this function.
     188            7 :     stream.write_message(BeMessage::ParameterStatus {
     189            7 :         name: b"client_encoding",
     190            7 :         value: b"UTF8",
     191            7 :     });
     192            7 :     stream.write_message(BeMessage::ReadyForQuery);
     193            7 :     stream.flush().await?;
     194              : 
     195            7 :     Ok(())
     196           15 : }
     197              : 
     198              : #[tokio::test]
                        /// A proxy configured with TLS must turn away a client that disables TLS.
                        ///
                        /// Checks that the client-side error mentions `ERR_INSECURE_CONNECTION` and
                        /// that it contains the text of the error returned by the proxy task.
     199            1 : async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
     200            1 :     let (client, server) = tokio::io::duplex(1024);
     201              : 
                            // Only the server half of the generated config is used; the client side
                            // connects with `NoTls`.
     202            1 :     let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
     203            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
     204              : 
     205            1 :     let client_err = postgres_client::Config::new("test".to_owned(), 5432)
     206            1 :         .user("john_doe")
     207            1 :         .dbname("earth")
     208            1 :         .ssl_mode(SslMode::Disable)
     209            1 :         .tls_and_authenticate(server, NoTls)
     210            1 :         .await
     211            1 :         .err() // -> Option<E>
     212            1 :         .context("client shouldn't be able to connect")?;
     213              : 
     214            1 :     assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
     215              : 
     216            1 :     let server_err = proxy
     217            1 :         .await?
     218            1 :         .err() // -> Option<E>
     219            1 :         .context("server shouldn't accept client")?;
     220              : 
                            // The error shown to the client must include the message the proxy failed with.
     221            1 :     assert!(client_err.to_string().contains(&server_err.to_string()));
     222              : 
     223            2 :     Ok(())
     224            1 : }
     225              : 
     226              : #[tokio::test]
                        /// TLS handshake succeeds when the client requires TLS and uses the client
                        /// config generated together with the proxy's server config.
     227            1 : async fn handshake_tls() -> anyhow::Result<()> {
     228            1 :     let (client, server) = tokio::io::duplex(1024);
     229              : 
     230            1 :     let (client_config, server_config) =
     231            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     232            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
     233              : 
     234            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     235            1 :         .user("john_doe")
     236            1 :         .dbname("earth")
     237            1 :         .ssl_mode(SslMode::Require)
     238            1 :         .tls_and_authenticate(server, client_config.make_tls_connect()?)
     239            1 :         .await?;
     240              : 
                            // `?` surfaces a join error; the proxy task's own result is the test result.
     241            1 :     proxy.await?
     242            1 : }
     243              : 
     244              : #[tokio::test]
                        /// Plain-text handshake: the proxy gets no TLS config and the client connects
                        /// with `SslMode::Prefer` and `NoTls`, passing the project name through the
                        /// `options` startup parameter.
     245            1 : async fn handshake_raw() -> anyhow::Result<()> {
     246            1 :     let (client, server) = tokio::io::duplex(1024);
     247              : 
     248            1 :     let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
     249              : 
     250            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     251            1 :         .user("john_doe")
     252            1 :         .dbname("earth")
     253            1 :         .set_param("options", "project=generic-project-name")
     254            1 :         .ssl_mode(SslMode::Prefer)
     255            1 :         .tls_and_authenticate(server, NoTls)
     256            1 :         .await?;
     257              : 
                            // `?` surfaces a join error; the proxy task's own result is the test result.
     258            1 :     proxy.await?
     259            1 : }
     260              : 
     261              : #[tokio::test]
                        /// Enables keepalive on a listening TCP socket and asserts that a socket
                        /// accepted from that listener reports keepalive as enabled.
     262            1 : async fn keepalive_is_inherited() -> anyhow::Result<()> {
     263              :     use tokio::net::{TcpListener, TcpStream};
     264              : 
                            // Bound to port 0; the actual port is read back from `local_addr`.
     265            1 :     let listener = TcpListener::bind("127.0.0.1:0").await?;
     266            1 :     let port = listener.local_addr()?.port();
     267            1 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     268              : 
                            // The accepting side runs in its own task and reports the accepted socket's
                            // keepalive flag back through the join handle.
     269            1 :     let t = tokio::spawn(async move {
     270            1 :         let (client, _) = listener.accept().await?;
     271            1 :         let keepalive = socket2::SockRef::from(&client).keepalive()?;
     272            1 :         anyhow::Ok(keepalive)
     273            1 :     });
     274              : 
     275            1 :     TcpStream::connect(("127.0.0.1", port)).await?;
                            // First `?`: join error; second `?`: error from inside the task.
     276            1 :     assert!(t.await??, "keepalive should be inherited");
     277              : 
     278            2 :     Ok(())
     279            1 : }
     280              : 
     281              : #[rstest]
     282              : #[case("password_foo")]
     283              : #[case("pwd-bar")]
     284              : #[case("")]
     285              : #[tokio::test]
                        /// SCRAM over TLS with channel binding required succeeds when the client
                        /// sends the same password the proxy's secret was built from.
                        ///
                        /// Runs once per `#[case]` above, including the empty password.
     286              : async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
     287              :     let (client, server) = tokio::io::duplex(1024);
     288              : 
     289              :     let (client_config, server_config) =
     290              :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     291              :     let proxy = tokio::spawn(dummy_proxy(
     292              :         client,
     293              :         Some(server_config),
     294              :         Scram::new(password).await?,
     295              :     ));
     296              : 
     297              :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     298              :         .channel_binding(postgres_client::config::ChannelBinding::Require)
     299              :         .user("user")
     300              :         .dbname("db")
     301              :         .password(password)
     302              :         .ssl_mode(SslMode::Require)
     303              :         .tls_and_authenticate(server, client_config.make_tls_connect()?)
     304              :         .await?;
     305              : 
                            // `?` surfaces a join error; the proxy task's own result is the test result.
     306              :     proxy.await?
     307              : }
     308              : 
     309              : #[tokio::test]
                        /// SCRAM over TLS still succeeds when the client sets
                        /// `ChannelBinding::Disable`; both sides use the password "password".
     310            1 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
     311            1 :     let (client, server) = tokio::io::duplex(1024);
     312              : 
     313            1 :     let (client_config, server_config) =
     314            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     315            1 :     let proxy = tokio::spawn(dummy_proxy(
     316            1 :         client,
     317            1 :         Some(server_config),
     318            1 :         Scram::new("password").await?,
     319              :     ));
     320              : 
     321            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     322            1 :         .channel_binding(postgres_client::config::ChannelBinding::Disable)
     323            1 :         .user("user")
     324            1 :         .dbname("db")
     325            1 :         .password("password")
     326            1 :         .ssl_mode(SslMode::Require)
     327            1 :         .tls_and_authenticate(server, client_config.make_tls_connect()?)
     328            1 :         .await?;
     329              : 
                            // `?` surfaces a join error; the proxy task's own result is the test result.
     330            1 :     proxy.await?
     331            1 : }
     332              : 
     333              : #[tokio::test]
                        /// With a mocked SCRAM secret on the proxy, a client using a random password
                        /// must fail to connect, and the proxy task must return an error as well.
                        ///
                        /// Only the presence of both errors is checked, not their contents.
     334            1 : async fn scram_auth_mock() -> anyhow::Result<()> {
     335            1 :     let (client, server) = tokio::io::duplex(1024);
     336              : 
     337            1 :     let (client_config, server_config) =
     338            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     339            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
     340              : 
     341              :     use rand::Rng;
     342              :     use rand::distr::Alphanumeric;
                            // Random alphanumeric password; its length is a random `u8`, i.e. anywhere
                            // from 0 to 255 characters.
     343            1 :     let password: String = rand::rng()
     344            1 :         .sample_iter(&Alphanumeric)
     345            1 :         .take(rand::random::<u8>() as usize)
     346            1 :         .map(char::from)
     347            1 :         .collect();
     348              : 
     349            1 :     let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
     350            1 :         .user("user")
     351            1 :         .dbname("db")
     352            1 :         .password(&password) // no password will match the mocked secret
     353            1 :         .ssl_mode(SslMode::Require)
     354            1 :         .tls_and_authenticate(server, client_config.make_tls_connect()?)
     355            1 :         .await
     356            1 :         .err() // -> Option<E>
     357            1 :         .context("client shouldn't be able to connect")?;
     358              : 
     359            1 :     let _server_err = proxy
     360            1 :         .await?
     361            1 :         .err() // -> Option<E>
     362            1 :         .context("server shouldn't accept client")?;
     363              : 
     364            2 :     Ok(())
     365            1 : }
     366              : 
     367              : #[test]
                        /// Sums `retry_after` over retry numbers 1 through 4 for a 1 s base delay and
                        /// a backoff factor of 2.0, and expects roughly 15 seconds in total.
     368            1 : fn connect_compute_total_wait() {
     369            1 :     let mut total_wait = tokio::time::Duration::ZERO;
     370            1 :     let config = RetryConfig {
     371            1 :         base_delay: Duration::from_secs(1),
     372            1 :         max_retries: 5,
     373            1 :         backoff_factor: 2.0,
     374            1 :     };
                            // The range is half-open, so `max_retries` itself (5) is not included.
     375            4 :     for num_retries in 1..config.max_retries {
     376            4 :         total_wait += retry_after(num_retries, config);
     377            4 :     }
                            // 15 = 1 + 2 + 4 + 8, which matches a delay of
                            // `base_delay * backoff_factor^(n - 1)`; the actual formula lives in
                            // `retry_after` and is not visible here.
     378            1 :     assert!(f64::abs(total_wait.as_secs_f64() - 15.0) < 0.1);
     379            1 : }
     380              : 
     381              : #[derive(Clone, Copy, Debug)]
                        /// One scripted step for `TestConnectMechanism`.
                        ///
                        /// The `Wake*` variants are consumed by `wake_compute`; all other variants
                        /// are consumed by `connect_once`. Either method panics when the next
                        /// scripted step belongs to the other one.
     382              : enum ConnectAction {
                            /// `wake_compute` returns node info that was inserted into the mechanism's cache.
     383              :     Wake,
                            /// `wake_compute` returns node info built with `CachedNodeInfo::new_uncached`.
     384              :     WakeCold,
                            /// `wake_compute` fails with a control-plane error asserted to be non-retryable.
     385              :     WakeFail,
                            /// `wake_compute` fails with a control-plane error that carries retry info
                            /// and is asserted to be retryable.
     386              :     WakeRetry,
                            /// `connect_once` succeeds.
     387              :     Connect,
     388              :     // connect_once -> Err, could_retry = true, should_retry_wake_compute = true
     389              :     Retry,
     390              :     // connect_once -> Err, could_retry = true, should_retry_wake_compute = false
     391              :     RetryNoWake,
     392              :     // connect_once -> Err, could_retry = false, should_retry_wake_compute = true
     393              :     Fail,
     394              :     // connect_once -> Err, could_retry = false, should_retry_wake_compute = false
     395              :     FailNoWake,
     396              : }
     397              : 
     398              : #[derive(Clone)]
                        /// Scripted stand-in for both the connect mechanism and the control-plane
                        /// client: each call to `connect_once` or `wake_compute` consumes the next
                        /// entry of `sequence`.
     399              : struct TestConnectMechanism {
                            /// Index of the next step in `sequence`; behind an `Arc`, so clones of the
                            /// mechanism advance the same counter.
     400              :     counter: Arc<std::sync::Mutex<usize>>,
                            /// The expected order of calls and the result each one should produce.
     401              :     sequence: Vec<ConnectAction>,
                            /// Cache that `ConnectAction::Wake` inserts node info into.
     402              :     cache: &'static NodeInfoCache,
     403              : }
     404              : 
     405              : impl TestConnectMechanism {
                            /// Asserts that every scripted step was consumed, i.e. the counter equals
                            /// the length of `sequence`. Panics otherwise.
     406           11 :     fn verify(&self) {
     407           11 :         let counter = self.counter.lock().unwrap();
     408           11 :         assert_eq!(
     409           11 :             *counter,
     410           11 :             self.sequence.len(),
     411            0 :             "sequence does not proceed to the end"
     412              :         );
     413           11 :     }
     414              : }
     415              : 
     416              : impl TestConnectMechanism {
                            /// Creates a mechanism that will play back `sequence`, starting at step 0.
                            ///
                            /// The `NodeInfoCache` is allocated and leaked with `Box::leak` to obtain
                            /// the `'static` reference the `cache` field requires.
     417           13 :     fn new(sequence: Vec<ConnectAction>) -> Self {
     418           13 :         Self {
     419           13 :             counter: Arc::new(std::sync::Mutex::new(0)),
     420           13 :             sequence,
                                    // NOTE(review): the meaning of the cache arguments ("test", 1, 100 s,
                                    // false) is defined by `NodeInfoCache::new`, which is not shown here.
     421           13 :             cache: Box::leak(Box::new(NodeInfoCache::new(
     422           13 :                 "test",
     423           13 :                 1,
     424           13 :                 Duration::from_secs(100),
     425           13 :                 false,
     426           13 :             ))),
     427           13 :         }
     428           13 :     }
     429              : }
     430              : 
     431              : #[derive(Debug)]
                        /// Unit placeholder used as `ConnectMechanism::Connection` by
                        /// `TestConnectMechanism`; returned when a scripted `Connect` step succeeds.
     432              : struct TestConnection;
     433              : 
     434              : impl ConnectMechanism for TestConnectMechanism {
     435              :     type Connection = TestConnection;
     436              : 
                            /// Consumes the next scripted step and turns it into a connect result.
                            ///
                            /// `Connect` yields `Ok(TestConnection)`; `Retry`, `RetryNoWake`, `Fail` and
                            /// `FailNoWake` yield a `ConnectionError::TestError` whose `retryable` and
                            /// `wakeable` flags differ per variant. All arguments are ignored.
                            ///
                            /// Panics if the next step is a `Wake*` action, if the sequence is already
                            /// exhausted (out-of-bounds index), or if the mutex is poisoned.
     437           25 :     async fn connect_once(
     438           25 :         &self,
     439           25 :         _ctx: &RequestContext,
     440           25 :         _node_info: &control_plane::CachedNodeInfo,
     441           25 :         _config: &ComputeConfig,
     442           25 :     ) -> Result<Self::Connection, compute::ConnectionError> {
                                // The guard lives until return; this body contains no `.await`.
     443           25 :         let mut counter = self.counter.lock().unwrap();
     444           25 :         let action = self.sequence[*counter];
     445           25 :         *counter += 1;
     446           25 :         match action {
     447            8 :             ConnectAction::Connect => Ok(TestConnection),
     448           10 :             ConnectAction::Retry => Err(compute::ConnectionError::TestError {
     449           10 :                 retryable: true,
     450           10 :                 wakeable: true,
     451           10 :                 kind: ErrorKind::Compute,
     452           10 :             }),
     453            2 :             ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError {
     454            2 :                 retryable: true,
     455            2 :                 wakeable: false,
     456            2 :                 kind: ErrorKind::Compute,
     457            2 :             }),
     458            4 :             ConnectAction::Fail => Err(compute::ConnectionError::TestError {
     459            4 :                 retryable: false,
     460            4 :                 wakeable: true,
     461            4 :                 kind: ErrorKind::Compute,
     462            4 :             }),
     463            1 :             ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError {
     464            1 :                 retryable: false,
     465            1 :                 wakeable: false,
     466            1 :                 kind: ErrorKind::Compute,
     467            1 :             }),
                                    // Any `Wake*` step here means the code under test called connect
                                    // where the script expected a wake-up.
     468            0 :             x => panic!("expecting action {x:?}, connect is called instead"),
     469              :         }
     470           25 :     }
     471              : }
     472              : 
     473              : impl TestControlPlaneClient for TestConnectMechanism {
                            /// Consumes the next scripted step and turns it into a wake-compute result.
                            ///
                            /// `Wake` returns node info inserted into `self.cache`; `WakeCold` returns
                            /// uncached node info; `WakeFail` and `WakeRetry` return a control-plane
                            /// error with status 400 and message "TEST", the latter with retry info
                            /// attached.
                            ///
                            /// Panics if the next step is not a `Wake*` action, if the sequence is
                            /// already exhausted (out-of-bounds index), or if the mutex is poisoned.
     474           21 :     fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     475           21 :         let mut counter = self.counter.lock().unwrap();
     476           21 :         let action = self.sequence[*counter];
     477           21 :         *counter += 1;
     478           21 :         match action {
     479           17 :             ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
     480            1 :             ConnectAction::WakeCold => Ok(CachedNodeInfo::new_uncached(
     481            1 :                 helper_create_uncached_node_info(),
     482            1 :             )),
     483              :             ConnectAction::WakeFail => {
     484            1 :                 let err = control_plane::errors::ControlPlaneError::Message(Box::new(
     485            1 :                     ControlPlaneErrorMessage {
     486            1 :                         http_status_code: StatusCode::BAD_REQUEST,
     487            1 :                         error: "TEST".into(),
     488            1 :                         status: None,
     489            1 :                     },
     490            1 :                 ));
                                        // Guards the fixture itself: without a `status`, this error must
                                        // be classified as non-retryable.
     491            1 :                 assert!(!err.could_retry());
     492            1 :                 Err(control_plane::errors::WakeComputeError::ControlPlane(err))
     493              :             }
     494              :             ConnectAction::WakeRetry => {
     495            2 :                 let err = control_plane::errors::ControlPlaneError::Message(Box::new(
     496            2 :                     ControlPlaneErrorMessage {
     497            2 :                         http_status_code: StatusCode::BAD_REQUEST,
     498            2 :                         error: "TEST".into(),
     499            2 :                         status: Some(Status {
     500            2 :                             code: "error".into(),
     501            2 :                             message: "error".into(),
     502            2 :                             details: Details {
     503            2 :                                 error_info: None,
     504            2 :                                 retry_info: Some(control_plane::messages::RetryInfo {
     505            2 :                                     retry_at: Instant::now() + Duration::from_millis(1),
     506            2 :                                 }),
     507            2 :                                 user_facing_message: None,
     508            2 :                             },
     509            2 :                         }),
     510            2 :                     },
     511            2 :                 ));
                                        // Guards the fixture itself: with `retry_info` present, this error
                                        // must be classified as retryable.
     512            2 :                 assert!(err.could_retry());
     513            2 :                 Err(control_plane::errors::WakeComputeError::ControlPlane(err))
     514              :             }
     515            0 :             x => panic!("expecting action {x:?}, wake_compute is called instead"),
     516              :         }
     517           21 :     }
     518              : 
                            /// Not implemented; panics via `unimplemented!` if called.
     519            0 :     fn get_access_control(
     520            0 :         &self,
     521            0 :     ) -> Result<control_plane::EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
     522            0 :         unimplemented!("not used in tests")
     523              :     }
     524              : 
                            /// Returns a boxed clone of this mechanism; the clone shares the step
                            /// counter because `counter` is an `Arc`.
     525            0 :     fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient> {
     526            0 :         Box::new(self.clone())
     527            0 :     }
     528              : }
     529              : 
     530           18 : fn helper_create_uncached_node_info() -> NodeInfo {
                            // Fixed `NodeInfo` fixture: host "test", port 5432, TLS disabled, no
                            // pre-resolved host address, and placeholder endpoint/project/branch/compute
                            // ids with a `Warm` cold-start marker.
     531           18 :     NodeInfo {
     532           18 :         conn_info: compute::ConnectInfo {
     533           18 :             host: "test".into(),
     534           18 :             port: 5432,
     535           18 :             ssl_mode: SslMode::Disable,
     536           18 :             host_addr: None,
     537           18 :         },
     538           18 :         aux: MetricsAuxInfo {
     539           18 :             endpoint_id: (&EndpointId::from("endpoint")).into(),
     540           18 :             project_id: (&ProjectId::from("project")).into(),
     541           18 :             branch_id: (&BranchId::from("branch")).into(),
     542           18 :             compute_id: "compute".into(),
     543           18 :             cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     544           18 :         },
     545           18 :     }
     546           18 : }
     547              : 
     548           17 : fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
                            // Inserts the fixture node info into `cache` under the fixed key "key".
     549           17 :     let node = helper_create_uncached_node_info();
     550           17 :     let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
                            // `insert_unit` hands back a value wrapping `()`; `map` puts the node info
                            // itself in its place to form the returned `CachedNodeInfo`.
     551           17 :     node2.map(|()| node)
     552           17 : }
     553              : 
     554           13 : fn helper_create_connect_info(
     555           13 :     mechanism: &TestConnectMechanism,
     556           13 : ) -> auth::Backend<'static, ComputeUserInfo> {
                            // Wraps a clone of `mechanism` as the test control-plane client and pairs it
                            // with fixed user info (endpoint "endpoint", user "user", empty options).
                            // The clone shares the step counter with the original via its `Arc`.
     557           13 :     auth::Backend::ControlPlane(
     558           13 :         MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
     559           13 :         ComputeUserInfo {
     560           13 :             endpoint: "endpoint".into(),
     561           13 :             user: "user".into(),
     562           13 :             options: NeonOptions::parse_options_raw(""),
     563           13 :         },
     564           13 :     )
     565           13 : }
     566              : 
     567           13 : fn config() -> ComputeConfig {
     568           13 :     let retry = RetryConfig {
     569           13 :         base_delay: Duration::from_secs(1),
     570           13 :         max_retries: 5,
     571           13 :         backoff_factor: 2.0,
     572           13 :     };
     573              : 
     574           13 :     ComputeConfig {
     575           13 :         retry,
     576           13 :         tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
     577           13 :         timeout: Duration::from_secs(2),
     578           13 :     }
     579           13 : }
     580              : 
     581              : #[tokio::test]
     582            1 : async fn connect_to_compute_success() {
     583            1 :     let _ = env_logger::try_init();
     584              :     use ConnectAction::*;
     585            1 :     let ctx = RequestContext::test();
     586            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
     587            1 :     let user_info = helper_create_connect_info(&mechanism);
     588            1 :     let config = config();
     589            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     590            1 :         .await
     591            1 :         .unwrap();
     592            1 :     mechanism.verify();
     593            1 : }
     594              : 
     595              : #[tokio::test]
     596            1 : async fn connect_to_compute_retry() {
     597            1 :     let _ = env_logger::try_init();
     598              :     use ConnectAction::*;
     599            1 :     let ctx = RequestContext::test();
     600            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
     601            1 :     let user_info = helper_create_connect_info(&mechanism);
     602            1 :     let config = config();
     603            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     604            1 :         .await
     605            1 :         .unwrap();
     606            1 :     mechanism.verify();
     607            1 : }
     608              : 
     609              : /// Test that we don't retry if the error is not retryable.
     610              : #[tokio::test]
     611            1 : async fn connect_to_compute_non_retry_1() {
     612            1 :     let _ = env_logger::try_init();
     613              :     use ConnectAction::*;
     614            1 :     let ctx = RequestContext::test();
     615            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
     616            1 :     let user_info = helper_create_connect_info(&mechanism);
     617            1 :     let config = config();
     618            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     619            1 :         .await
     620            1 :         .unwrap_err();
     621            1 :     mechanism.verify();
     622            1 : }
     623              : 
     624              : /// Even for non-retryable errors, we should retry at least once.
     625              : #[tokio::test]
     626            1 : async fn connect_to_compute_non_retry_2() {
     627            1 :     let _ = env_logger::try_init();
     628              :     use ConnectAction::*;
     629            1 :     let ctx = RequestContext::test();
     630            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
     631            1 :     let user_info = helper_create_connect_info(&mechanism);
     632            1 :     let config = config();
     633            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     634            1 :         .await
     635            1 :         .unwrap();
     636            1 :     mechanism.verify();
     637            1 : }
     638              : 
     639              : /// Retry for at most `NUM_RETRIES_CONNECT` times.
     640              : #[tokio::test]
     641            1 : async fn connect_to_compute_non_retry_3() {
     642            1 :     let _ = env_logger::try_init();
     643            1 :     tokio::time::pause();
     644              :     use ConnectAction::*;
     645            1 :     let ctx = RequestContext::test();
     646            1 :     let mechanism =
     647            1 :         TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
     648            1 :     let user_info = helper_create_connect_info(&mechanism);
     649            1 :     let wake_compute_retry_config = RetryConfig {
     650            1 :         base_delay: Duration::from_secs(1),
     651            1 :         max_retries: 1,
     652            1 :         backoff_factor: 2.0,
     653            1 :     };
     654            1 :     let config = config();
     655            1 :     connect_to_compute_inner(
     656            1 :         &ctx,
     657            1 :         &mechanism,
     658            1 :         &user_info,
     659            1 :         wake_compute_retry_config,
     660            1 :         &config,
     661            1 :     )
     662            1 :     .await
     663            1 :     .unwrap_err();
     664            1 :     mechanism.verify();
     665            1 : }
     666              : 
     667              : /// Should retry wake compute.
     668              : #[tokio::test]
     669            1 : async fn wake_retry() {
     670            1 :     let _ = env_logger::try_init();
     671              :     use ConnectAction::*;
     672            1 :     let ctx = RequestContext::test();
     673            1 :     let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
     674            1 :     let user_info = helper_create_connect_info(&mechanism);
     675            1 :     let config = config();
     676            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     677            1 :         .await
     678            1 :         .unwrap();
     679            1 :     mechanism.verify();
     680            1 : }
     681              : 
     682              : /// Wake failed with a non-retryable error.
     683              : #[tokio::test]
     684            1 : async fn wake_non_retry() {
     685            1 :     let _ = env_logger::try_init();
     686              :     use ConnectAction::*;
     687            1 :     let ctx = RequestContext::test();
     688            1 :     let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
     689            1 :     let user_info = helper_create_connect_info(&mechanism);
     690            1 :     let config = config();
     691            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
     692            1 :         .await
     693            1 :         .unwrap_err();
     694            1 :     mechanism.verify();
     695            1 : }
     696              : 
     697              : #[tokio::test]
     698              : #[traced_test]
     699            1 : async fn fail_but_wake_invalidates_cache() {
     700            1 :     let ctx = RequestContext::test();
     701            1 :     let mech = TestConnectMechanism::new(vec![
     702            1 :         ConnectAction::Wake,
     703            1 :         ConnectAction::Fail,
     704            1 :         ConnectAction::Wake,
     705            1 :         ConnectAction::Connect,
     706              :     ]);
     707            1 :     let user = helper_create_connect_info(&mech);
     708            1 :     let cfg = config();
     709              : 
     710            1 :     connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
     711            1 :         .await
     712            1 :         .unwrap();
     713              : 
     714            1 :     assert!(logs_contain(
     715            1 :         "invalidating stalled compute node info cache entry"
     716            1 :     ));
     717            1 : }
     718              : 
     719              : #[tokio::test]
     720              : #[traced_test]
     721            1 : async fn fail_no_wake_skips_cache_invalidation() {
     722            1 :     let ctx = RequestContext::test();
     723            1 :     let mech = TestConnectMechanism::new(vec![
     724            1 :         ConnectAction::Wake,
     725            1 :         ConnectAction::RetryNoWake,
     726            1 :         ConnectAction::Connect,
     727              :     ]);
     728            1 :     let user = helper_create_connect_info(&mech);
     729            1 :     let cfg = config();
     730              : 
     731            1 :     connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
     732            1 :         .await
     733            1 :         .unwrap();
     734              : 
     735            1 :     assert!(!logs_contain(
     736            1 :         "invalidating stalled compute node info cache entry"
     737            1 :     ));
     738            1 : }
     739              : 
     740              : #[tokio::test]
     741              : #[traced_test]
     742            1 : async fn retry_but_wake_invalidates_cache() {
     743            1 :     let _ = env_logger::try_init();
     744              :     use ConnectAction::*;
     745              : 
     746            1 :     let ctx = RequestContext::test();
     747              :     // Wake → Retry (retryable + wakeable) → Wake → Connect
     748            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
     749            1 :     let user_info = helper_create_connect_info(&mechanism);
     750            1 :     let cfg = config();
     751              : 
     752            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
     753            1 :         .await
     754            1 :         .unwrap();
     755            1 :     mechanism.verify();
     756              : 
     757              :     // Because Retry has wakeable=true, we should see invalidate_cache
     758            1 :     assert!(logs_contain(
     759            1 :         "invalidating stalled compute node info cache entry"
     760            1 :     ));
     761            1 : }
     762              : 
     763              : #[tokio::test]
     764              : #[traced_test]
     765            1 : async fn retry_no_wake_skips_invalidation() {
     766            1 :     let _ = env_logger::try_init();
     767              :     use ConnectAction::*;
     768              : 
     769            1 :     let ctx = RequestContext::test();
     770              :     // Wake → RetryNoWake (retryable + NOT wakeable) → Fail
     771            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake, Fail]);
     772            1 :     let user_info = helper_create_connect_info(&mechanism);
     773            1 :     let cfg = config();
     774              : 
     775            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
     776            1 :         .await
     777            1 :         .unwrap_err();
     778            1 :     mechanism.verify();
     779              : 
     780              :     // Because RetryNoWake has wakeable=false, we must NOT see invalidate_cache
     781            1 :     assert!(!logs_contain(
     782            1 :         "invalidating stalled compute node info cache entry"
     783            1 :     ));
     784            1 : }
     785              : 
     786              : #[tokio::test]
     787              : #[traced_test]
     788            1 : async fn retry_no_wake_error_fast() {
     789            1 :     let _ = env_logger::try_init();
     790              :     use ConnectAction::*;
     791              : 
     792            1 :     let ctx = RequestContext::test();
     793              :     // Wake → FailNoWake (not retryable + NOT wakeable)
     794            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, FailNoWake]);
     795            1 :     let user_info = helper_create_connect_info(&mechanism);
     796            1 :     let cfg = config();
     797              : 
     798            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
     799            1 :         .await
     800            1 :         .unwrap_err();
     801            1 :     mechanism.verify();
     802              : 
     803              :     // Because FailNoWake has wakeable=false, we must NOT see invalidate_cache
     804            1 :     assert!(!logs_contain(
     805            1 :         "invalidating stalled compute node info cache entry"
     806            1 :     ));
     807            1 : }
     808              : 
     809              : #[tokio::test]
     810              : #[traced_test]
     811            1 : async fn retry_cold_wake_skips_invalidation() {
     812            1 :     let _ = env_logger::try_init();
     813              :     use ConnectAction::*;
     814              : 
     815            1 :     let ctx = RequestContext::test();
     816              :     // WakeCold → Retry (retryable + wakeable) → Connect, with no second Wake in between
     817            1 :     let mechanism = TestConnectMechanism::new(vec![WakeCold, Retry, Connect]);
     818            1 :     let user_info = helper_create_connect_info(&mechanism);
     819            1 :     let cfg = config();
     820              : 
     821            1 :     connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
     822            1 :         .await
     823            1 :         .unwrap();
     824            1 :     mechanism.verify();
     825            1 : }
        

Generated by: LCOV version 2.1-beta