LCOV - code coverage report
Current view: top level - proxy/src/proxy/tests - mod.rs (source / functions) Coverage Total Hit
Test: 6a14b070dc6eeeeb359cfa8817925ac37a02fab4.info Lines: 95.2 % 496 472
Test Date: 2025-03-31 22:46:13 Functions: 90.5 % 63 57

            Line data    Source code
       1              : //! A group of high-level tests for connection establishing logic and auth.
       2              : #![allow(clippy::unimplemented, clippy::unwrap_used)]
       3              : 
       4              : mod mitm;
       5              : 
       6              : use std::time::Duration;
       7              : 
       8              : use anyhow::{Context, bail};
       9              : use async_trait::async_trait;
      10              : use http::StatusCode;
      11              : use postgres_client::config::SslMode;
      12              : use postgres_client::tls::{MakeTlsConnect, NoTls};
      13              : use retry::{ShouldRetryWakeCompute, retry_after};
      14              : use rstest::rstest;
      15              : use rustls::crypto::ring;
      16              : use rustls::pki_types;
      17              : use tokio::io::DuplexStream;
      18              : 
      19              : use super::connect_compute::ConnectMechanism;
      20              : use super::retry::CouldRetry;
      21              : use super::*;
      22              : use crate::auth::backend::{
      23              :     ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
      24              : };
      25              : use crate::config::{ComputeConfig, RetryConfig};
      26              : use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
      27              : use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
      28              : use crate::control_plane::{
      29              :     self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
      30              : };
      31              : use crate::error::ErrorKind;
      32              : use crate::tls::client_config::compute_client_config_with_certs;
      33              : use crate::tls::postgres_rustls::MakeRustlsConnect;
      34              : use crate::tls::server_config::CertResolver;
      35              : use crate::types::{BranchId, EndpointId, ProjectId};
      36              : use crate::{sasl, scram};
      37              : 
      38              : /// Generate a set of TLS certificates: CA + server.
      39           21 : fn generate_certs(
      40           21 :     hostname: &str,
      41           21 :     common_name: &str,
      42           21 : ) -> anyhow::Result<(
      43           21 :     pki_types::CertificateDer<'static>,
      44           21 :     pki_types::CertificateDer<'static>,
      45           21 :     pki_types::PrivateKeyDer<'static>,
      46           21 : )> {
      47           21 :     let ca_key = rcgen::KeyPair::generate()?;
      48           21 :     let ca = {
      49           21 :         let mut params = rcgen::CertificateParams::default();
      50           21 :         params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
      51           21 :         params.self_signed(&ca_key)?
      52              :     };
      53              : 
      54           21 :     let cert_key = rcgen::KeyPair::generate()?;
      55           21 :     let cert = {
      56           21 :         let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
      57           21 :         params.distinguished_name = rcgen::DistinguishedName::new();
      58           21 :         params
      59           21 :             .distinguished_name
      60           21 :             .push(rcgen::DnType::CommonName, common_name);
      61           21 :         params.signed_by(&cert_key, &ca, &ca_key)?
      62              :     };
      63              : 
      64           21 :     Ok((
      65           21 :         ca.der().clone(),
      66           21 :         cert.der().clone(),
      67           21 :         pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
      68           21 :     ))
      69           21 : }
      70              : 
      71              : struct ClientConfig<'a> {
      72              :     config: Arc<rustls::ClientConfig>,
      73              :     hostname: &'a str,
      74              : }
      75              : 
      76              : type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
      77              : 
      78              : impl ClientConfig<'_> {
      79           20 :     fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
      80           20 :         let mut mk = MakeRustlsConnect::new(self.config);
      81           20 :         let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
      82           20 :         Ok(tls)
      83           20 :     }
      84              : }
      85              : 
      86              : /// Generate TLS certificates and build rustls configs for client and server.
      87           21 : fn generate_tls_config<'a>(
      88           21 :     hostname: &'a str,
      89           21 :     common_name: &'a str,
      90           21 : ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
      91           21 :     let (ca, cert, key) = generate_certs(hostname, common_name)?;
      92              : 
      93           21 :     let tls_config = {
      94           21 :         let config =
      95           21 :             rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
      96           21 :                 .with_safe_default_protocol_versions()
      97           21 :                 .context("ring should support the default protocol versions")?
      98           21 :                 .with_no_client_auth()
      99           21 :                 .with_single_cert(vec![cert.clone()], key.clone_key())?;
     100              : 
     101           21 :         let mut cert_resolver = CertResolver::new();
     102           21 :         cert_resolver.add_cert(key, vec![cert], true)?;
     103              : 
     104           21 :         let common_names = cert_resolver.get_common_names();
     105           21 : 
     106           21 :         let config = Arc::new(config);
     107           21 : 
     108           21 :         TlsConfig {
     109           21 :             http_config: config.clone(),
     110           21 :             pg_config: config,
     111           21 :             common_names,
     112           21 :             cert_resolver: Arc::new(cert_resolver),
     113           21 :         }
     114           21 :     };
     115           21 : 
     116           21 :     let client_config = {
     117           21 :         let config = Arc::new(compute_client_config_with_certs([ca]));
     118           21 : 
     119           21 :         ClientConfig { config, hostname }
     120           21 :     };
     121           21 : 
     122           21 :     Ok((client_config, tls_config))
     123           21 : }
     124              : 
     125              : #[async_trait]
     126              : trait TestAuth: Sized {
     127            2 :     async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
     128            2 :         self,
     129            2 :         stream: &mut PqStream<Stream<S>>,
     130            2 :     ) -> anyhow::Result<()> {
     131            2 :         stream.write_message_noflush(&Be::AuthenticationOk)?;
     132            2 :         Ok(())
     133            4 :     }
     134              : }
     135              : 
     136              : struct NoAuth;
     137              : impl TestAuth for NoAuth {}
     138              : 
     139              : struct Scram(scram::ServerSecret);
     140              : 
     141              : impl Scram {
     142           11 :     async fn new(password: &str) -> anyhow::Result<Self> {
     143           11 :         let secret = scram::ServerSecret::build(password)
     144           11 :             .await
     145           11 :             .context("failed to generate scram secret")?;
     146           11 :         Ok(Scram(secret))
     147           11 :     }
     148              : 
     149            1 :     fn mock() -> Self {
     150            1 :         Scram(scram::ServerSecret::mock(rand::random()))
     151            1 :     }
     152              : }
     153              : 
     154              : #[async_trait]
     155              : impl TestAuth for Scram {
     156           12 :     async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
     157           12 :         self,
     158           12 :         stream: &mut PqStream<Stream<S>>,
     159           12 :     ) -> anyhow::Result<()> {
     160           12 :         let outcome = auth::AuthFlow::new(stream)
     161           12 :             .begin(auth::Scram(&self.0, &RequestContext::test()))
     162           12 :             .await?
     163           12 :             .authenticate()
     164           12 :             .await?;
     165              : 
     166              :         use sasl::Outcome::*;
     167            6 :         match outcome {
     168            5 :             Success(_) => Ok(()),
     169            1 :             Failure(reason) => bail!("autentication failed with an error: {reason}"),
     170              :         }
     171           24 :     }
     172              : }
     173              : 
     174              : /// A dummy proxy impl which performs a handshake and reports auth success.
     175           15 : async fn dummy_proxy(
     176           15 :     client: impl AsyncRead + AsyncWrite + Unpin + Send,
     177           15 :     tls: Option<TlsConfig>,
     178           15 :     auth: impl TestAuth + Send,
     179           15 : ) -> anyhow::Result<()> {
     180           15 :     let (client, _) = read_proxy_protocol(client).await?;
     181           15 :     let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
     182           14 :         HandshakeData::Startup(stream, _) => stream,
     183            0 :         HandshakeData::Cancel(_) => bail!("cancellation not supported"),
     184              :     };
     185              : 
     186           14 :     auth.authenticate(&mut stream).await?;
     187              : 
     188            7 :     stream
     189            7 :         .write_message_noflush(&Be::CLIENT_ENCODING)?
     190            7 :         .write_message(&Be::ReadyForQuery)
     191            7 :         .await?;
     192              : 
     193            7 :     Ok(())
     194           15 : }
     195              : 
     196              : #[tokio::test]
     197            1 : async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
     198            1 :     let (client, server) = tokio::io::duplex(1024);
     199            1 : 
     200            1 :     let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
     201            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
     202            1 : 
     203            1 :     let client_err = postgres_client::Config::new("test".to_owned(), 5432)
     204            1 :         .user("john_doe")
     205            1 :         .dbname("earth")
     206            1 :         .ssl_mode(SslMode::Disable)
     207            1 :         .connect_raw(server, NoTls)
     208            1 :         .await
     209            1 :         .err() // -> Option<E>
     210            1 :         .context("client shouldn't be able to connect")?;
     211            1 : 
     212            1 :     assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
     213            1 : 
     214            1 :     let server_err = proxy
     215            1 :         .await?
     216            1 :         .err() // -> Option<E>
     217            1 :         .context("server shouldn't accept client")?;
     218            1 : 
     219            1 :     assert!(client_err.to_string().contains(&server_err.to_string()));
     220            1 : 
     221            1 :     Ok(())
     222            1 : }
     223              : 
     224              : #[tokio::test]
     225            1 : async fn handshake_tls() -> anyhow::Result<()> {
     226            1 :     let (client, server) = tokio::io::duplex(1024);
     227            1 : 
     228            1 :     let (client_config, server_config) =
     229            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     230            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
     231            1 : 
     232            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     233            1 :         .user("john_doe")
     234            1 :         .dbname("earth")
     235            1 :         .ssl_mode(SslMode::Require)
     236            1 :         .connect_raw(server, client_config.make_tls_connect()?)
     237            1 :         .await?;
     238            1 : 
     239            1 :     proxy.await?
     240            1 : }
     241              : 
     242              : #[tokio::test]
     243            1 : async fn handshake_raw() -> anyhow::Result<()> {
     244            1 :     let (client, server) = tokio::io::duplex(1024);
     245            1 : 
     246            1 :     let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
     247            1 : 
     248            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     249            1 :         .user("john_doe")
     250            1 :         .dbname("earth")
     251            1 :         .set_param("options", "project=generic-project-name")
     252            1 :         .ssl_mode(SslMode::Prefer)
     253            1 :         .connect_raw(server, NoTls)
     254            1 :         .await?;
     255            1 : 
     256            1 :     proxy.await?
     257            1 : }
     258              : 
     259              : #[tokio::test]
     260            1 : async fn keepalive_is_inherited() -> anyhow::Result<()> {
     261            1 :     use tokio::net::{TcpListener, TcpStream};
     262            1 : 
     263            1 :     let listener = TcpListener::bind("127.0.0.1:0").await?;
     264            1 :     let port = listener.local_addr()?.port();
     265            1 :     socket2::SockRef::from(&listener).set_keepalive(true)?;
     266            1 : 
     267            1 :     let t = tokio::spawn(async move {
     268            1 :         let (client, _) = listener.accept().await?;
     269            1 :         let keepalive = socket2::SockRef::from(&client).keepalive()?;
     270            1 :         anyhow::Ok(keepalive)
     271            1 :     });
     272            1 : 
     273            1 :     TcpStream::connect(("127.0.0.1", port)).await?;
     274            1 :     assert!(t.await??, "keepalive should be inherited");
     275            1 : 
     276            1 :     Ok(())
     277            1 : }
     278              : 
     279            3 : #[rstest]
     280              : #[case("password_foo")]
     281              : #[case("pwd-bar")]
     282              : #[case("")]
     283              : #[tokio::test]
     284              : async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
     285              :     let (client, server) = tokio::io::duplex(1024);
     286              : 
     287              :     let (client_config, server_config) =
     288              :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     289              :     let proxy = tokio::spawn(dummy_proxy(
     290              :         client,
     291              :         Some(server_config),
     292              :         Scram::new(password).await?,
     293              :     ));
     294              : 
     295              :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     296              :         .channel_binding(postgres_client::config::ChannelBinding::Require)
     297              :         .user("user")
     298              :         .dbname("db")
     299              :         .password(password)
     300              :         .ssl_mode(SslMode::Require)
     301              :         .connect_raw(server, client_config.make_tls_connect()?)
     302              :         .await?;
     303              : 
     304              :     proxy.await?
     305              : }
     306              : 
     307              : #[tokio::test]
     308            1 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
     309            1 :     let (client, server) = tokio::io::duplex(1024);
     310            1 : 
     311            1 :     let (client_config, server_config) =
     312            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     313            1 :     let proxy = tokio::spawn(dummy_proxy(
     314            1 :         client,
     315            1 :         Some(server_config),
     316            1 :         Scram::new("password").await?,
     317            1 :     ));
     318            1 : 
     319            1 :     let _conn = postgres_client::Config::new("test".to_owned(), 5432)
     320            1 :         .channel_binding(postgres_client::config::ChannelBinding::Disable)
     321            1 :         .user("user")
     322            1 :         .dbname("db")
     323            1 :         .password("password")
     324            1 :         .ssl_mode(SslMode::Require)
     325            1 :         .connect_raw(server, client_config.make_tls_connect()?)
     326            1 :         .await?;
     327            1 : 
     328            1 :     proxy.await?
     329            1 : }
     330              : 
     331              : #[tokio::test]
     332            1 : async fn scram_auth_mock() -> anyhow::Result<()> {
     333            1 :     let (client, server) = tokio::io::duplex(1024);
     334            1 : 
     335            1 :     let (client_config, server_config) =
     336            1 :         generate_tls_config("generic-project-name.localhost", "localhost")?;
     337            1 :     let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
     338            1 : 
     339            1 :     use rand::Rng;
     340            1 :     use rand::distributions::Alphanumeric;
     341            1 :     let password: String = rand::thread_rng()
     342            1 :         .sample_iter(&Alphanumeric)
     343            1 :         .take(rand::random::<u8>() as usize)
     344            1 :         .map(char::from)
     345            1 :         .collect();
     346            1 : 
     347            1 :     let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
     348            1 :         .user("user")
     349            1 :         .dbname("db")
     350            1 :         .password(&password) // no password will match the mocked secret
     351            1 :         .ssl_mode(SslMode::Require)
     352            1 :         .connect_raw(server, client_config.make_tls_connect()?)
     353            1 :         .await
     354            1 :         .err() // -> Option<E>
     355            1 :         .context("client shouldn't be able to connect")?;
     356            1 : 
     357            1 :     let _server_err = proxy
     358            1 :         .await?
     359            1 :         .err() // -> Option<E>
     360            1 :         .context("server shouldn't accept client")?;
     361            1 : 
     362            1 :     Ok(())
     363            1 : }
     364              : 
     365              : #[test]
     366            1 : fn connect_compute_total_wait() {
     367            1 :     let mut total_wait = tokio::time::Duration::ZERO;
     368            1 :     let config = RetryConfig {
     369            1 :         base_delay: Duration::from_secs(1),
     370            1 :         max_retries: 5,
     371            1 :         backoff_factor: 2.0,
     372            1 :     };
     373            4 :     for num_retries in 1..config.max_retries {
     374            4 :         total_wait += retry_after(num_retries, config);
     375            4 :     }
     376            1 :     assert!(f64::abs(total_wait.as_secs_f64() - 15.0) < 0.1);
     377            1 : }
     378              : 
     379              : #[derive(Clone, Copy, Debug)]
     380              : enum ConnectAction {
     381              :     Wake,
     382              :     WakeFail,
     383              :     WakeRetry,
     384              :     Connect,
     385              :     Retry,
     386              :     Fail,
     387              : }
     388              : 
     389              : #[derive(Clone)]
     390              : struct TestConnectMechanism {
     391              :     counter: Arc<std::sync::Mutex<usize>>,
     392              :     sequence: Vec<ConnectAction>,
     393              :     cache: &'static NodeInfoCache,
     394              : }
     395              : 
     396              : impl TestConnectMechanism {
     397            7 :     fn verify(&self) {
     398            7 :         let counter = self.counter.lock().unwrap();
     399            7 :         assert_eq!(
     400            7 :             *counter,
     401            7 :             self.sequence.len(),
     402            0 :             "sequence does not proceed to the end"
     403              :         );
     404            7 :     }
     405              : }
     406              : 
     407              : impl TestConnectMechanism {
     408            7 :     fn new(sequence: Vec<ConnectAction>) -> Self {
     409            7 :         Self {
     410            7 :             counter: Arc::new(std::sync::Mutex::new(0)),
     411            7 :             sequence,
     412            7 :             cache: Box::leak(Box::new(NodeInfoCache::new(
     413            7 :                 "test",
     414            7 :                 1,
     415            7 :                 Duration::from_secs(100),
     416            7 :                 false,
     417            7 :             ))),
     418            7 :         }
     419            7 :     }
     420              : }
     421              : 
     422              : #[derive(Debug)]
     423              : struct TestConnection;
     424              : 
     425              : #[derive(Debug)]
     426              : struct TestConnectError {
     427              :     retryable: bool,
     428              :     kind: crate::error::ErrorKind,
     429              : }
     430              : 
     431              : impl ReportableError for TestConnectError {
     432            0 :     fn get_error_kind(&self) -> crate::error::ErrorKind {
     433            0 :         self.kind
     434            0 :     }
     435              : }
     436              : 
     437              : impl std::fmt::Display for TestConnectError {
     438            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     439            0 :         write!(f, "{self:?}")
     440            0 :     }
     441              : }
     442              : 
     443              : impl std::error::Error for TestConnectError {}
     444              : 
     445              : impl CouldRetry for TestConnectError {
     446            5 :     fn could_retry(&self) -> bool {
     447            5 :         self.retryable
     448            5 :     }
     449              : }
     450              : impl ShouldRetryWakeCompute for TestConnectError {
     451            4 :     fn should_retry_wake_compute(&self) -> bool {
     452            4 :         true
     453            4 :     }
     454              : }
     455              : 
     456              : #[async_trait]
     457              : impl ConnectMechanism for TestConnectMechanism {
     458              :     type Connection = TestConnection;
     459              :     type ConnectError = TestConnectError;
     460              :     type Error = anyhow::Error;
     461              : 
     462           14 :     async fn connect_once(
     463           14 :         &self,
     464           14 :         _ctx: &RequestContext,
     465           14 :         _node_info: &control_plane::CachedNodeInfo,
     466           14 :         _config: &ComputeConfig,
     467           14 :     ) -> Result<Self::Connection, Self::ConnectError> {
     468           14 :         let mut counter = self.counter.lock().unwrap();
     469           14 :         let action = self.sequence[*counter];
     470           14 :         *counter += 1;
     471           14 :         match action {
     472            4 :             ConnectAction::Connect => Ok(TestConnection),
     473            8 :             ConnectAction::Retry => Err(TestConnectError {
     474            8 :                 retryable: true,
     475            8 :                 kind: ErrorKind::Compute,
     476            8 :             }),
     477            2 :             ConnectAction::Fail => Err(TestConnectError {
     478            2 :                 retryable: false,
     479            2 :                 kind: ErrorKind::Compute,
     480            2 :             }),
     481            0 :             x => panic!("expecting action {x:?}, connect is called instead"),
     482              :         }
     483           28 :     }
     484              : 
     485           10 :     fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
     486              : }
     487              : 
     488              : impl TestControlPlaneClient for TestConnectMechanism {
     489           13 :     fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
     490           13 :         let mut counter = self.counter.lock().unwrap();
     491           13 :         let action = self.sequence[*counter];
     492           13 :         *counter += 1;
     493           13 :         match action {
     494           10 :             ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
     495              :             ConnectAction::WakeFail => {
     496            1 :                 let err = control_plane::errors::ControlPlaneError::Message(Box::new(
     497            1 :                     ControlPlaneErrorMessage {
     498            1 :                         http_status_code: StatusCode::BAD_REQUEST,
     499            1 :                         error: "TEST".into(),
     500            1 :                         status: None,
     501            1 :                     },
     502            1 :                 ));
     503            1 :                 assert!(!err.could_retry());
     504            1 :                 Err(control_plane::errors::WakeComputeError::ControlPlane(err))
     505              :             }
     506              :             ConnectAction::WakeRetry => {
     507            2 :                 let err = control_plane::errors::ControlPlaneError::Message(Box::new(
     508            2 :                     ControlPlaneErrorMessage {
     509            2 :                         http_status_code: StatusCode::BAD_REQUEST,
     510            2 :                         error: "TEST".into(),
     511            2 :                         status: Some(Status {
     512            2 :                             code: "error".into(),
     513            2 :                             message: "error".into(),
     514            2 :                             details: Details {
     515            2 :                                 error_info: None,
     516            2 :                                 retry_info: Some(control_plane::messages::RetryInfo {
     517            2 :                                     retry_delay_ms: 1,
     518            2 :                                 }),
     519            2 :                                 user_facing_message: None,
     520            2 :                             },
     521            2 :                         }),
     522            2 :                     },
     523            2 :                 ));
     524            2 :                 assert!(err.could_retry());
     525            2 :                 Err(control_plane::errors::WakeComputeError::ControlPlane(err))
     526              :             }
     527            0 :             x => panic!("expecting action {x:?}, wake_compute is called instead"),
     528              :         }
     529           13 :     }
     530              : 
     531            0 :     fn get_allowed_ips(&self) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
     532            0 :         unimplemented!("not used in tests")
     533              :     }
     534              : 
     535            0 :     fn get_allowed_vpc_endpoint_ids(
     536            0 :         &self,
     537            0 :     ) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
     538            0 :         unimplemented!("not used in tests")
     539              :     }
     540              : 
     541            0 :     fn get_block_public_or_vpc_access(
     542            0 :         &self,
     543            0 :     ) -> Result<control_plane::CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError>
     544            0 :     {
     545            0 :         unimplemented!("not used in tests")
     546              :     }
     547              : 
     548            0 :     fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient> {
     549            0 :         Box::new(self.clone())
     550            0 :     }
     551              : }
     552              : 
     553           10 : fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
     554           10 :     let node = NodeInfo {
     555           10 :         config: compute::ConnCfg::new("test".to_owned(), 5432),
     556           10 :         aux: MetricsAuxInfo {
     557           10 :             endpoint_id: (&EndpointId::from("endpoint")).into(),
     558           10 :             project_id: (&ProjectId::from("project")).into(),
     559           10 :             branch_id: (&BranchId::from("branch")).into(),
     560           10 :             compute_id: "compute".into(),
     561           10 :             cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
     562           10 :         },
     563           10 :     };
     564           10 :     let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
     565           10 :     node2.map(|()| node)
     566           10 : }
     567              : 
     568            7 : fn helper_create_connect_info(
     569            7 :     mechanism: &TestConnectMechanism,
     570            7 : ) -> auth::Backend<'static, ComputeCredentials> {
     571            7 :     let user_info = auth::Backend::ControlPlane(
     572            7 :         MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
     573            7 :         ComputeCredentials {
     574            7 :             info: ComputeUserInfo {
     575            7 :                 endpoint: "endpoint".into(),
     576            7 :                 user: "user".into(),
     577            7 :                 options: NeonOptions::parse_options_raw(""),
     578            7 :             },
     579            7 :             keys: ComputeCredentialKeys::Password("password".into()),
     580            7 :         },
     581            7 :     );
     582            7 :     user_info
     583            7 : }
     584              : 
     585            7 : fn config() -> ComputeConfig {
     586            7 :     let retry = RetryConfig {
     587            7 :         base_delay: Duration::from_secs(1),
     588            7 :         max_retries: 5,
     589            7 :         backoff_factor: 2.0,
     590            7 :     };
     591            7 : 
     592            7 :     ComputeConfig {
     593            7 :         retry,
     594            7 :         tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
     595            7 :         timeout: Duration::from_secs(2),
     596            7 :     }
     597            7 : }
     598              : 
     599              : #[tokio::test]
     600            1 : async fn connect_to_compute_success() {
     601            1 :     let _ = env_logger::try_init();
     602            1 :     use ConnectAction::*;
     603            1 :     let ctx = RequestContext::test();
     604            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
     605            1 :     let user_info = helper_create_connect_info(&mechanism);
     606            1 :     let config = config();
     607            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     608            1 :         .await
     609            1 :         .unwrap();
     610            1 :     mechanism.verify();
     611            1 : }
     612              : 
     613              : #[tokio::test]
     614            1 : async fn connect_to_compute_retry() {
     615            1 :     let _ = env_logger::try_init();
     616            1 :     use ConnectAction::*;
     617            1 :     let ctx = RequestContext::test();
     618            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
     619            1 :     let user_info = helper_create_connect_info(&mechanism);
     620            1 :     let config = config();
     621            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     622            1 :         .await
     623            1 :         .unwrap();
     624            1 :     mechanism.verify();
     625            1 : }
     626              : 
     627              : /// Test that we don't retry if the error is not retryable.
     628              : #[tokio::test]
     629            1 : async fn connect_to_compute_non_retry_1() {
     630            1 :     let _ = env_logger::try_init();
     631            1 :     use ConnectAction::*;
     632            1 :     let ctx = RequestContext::test();
     633            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
     634            1 :     let user_info = helper_create_connect_info(&mechanism);
     635            1 :     let config = config();
     636            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     637            1 :         .await
     638            1 :         .unwrap_err();
     639            1 :     mechanism.verify();
     640            1 : }
     641              : 
     642              : /// Even for non-retryable errors, we should retry at least once.
     643              : #[tokio::test]
     644            1 : async fn connect_to_compute_non_retry_2() {
     645            1 :     let _ = env_logger::try_init();
     646            1 :     use ConnectAction::*;
     647            1 :     let ctx = RequestContext::test();
     648            1 :     let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
     649            1 :     let user_info = helper_create_connect_info(&mechanism);
     650            1 :     let config = config();
     651            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     652            1 :         .await
     653            1 :         .unwrap();
     654            1 :     mechanism.verify();
     655            1 : }
     656              : 
     657              : /// Retry for at most `NUM_RETRIES_CONNECT` times.
     658              : #[tokio::test]
     659            1 : async fn connect_to_compute_non_retry_3() {
     660            1 :     let _ = env_logger::try_init();
     661            1 :     tokio::time::pause();
     662            1 :     use ConnectAction::*;
     663            1 :     let ctx = RequestContext::test();
     664            1 :     let mechanism =
     665            1 :         TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
     666            1 :     let user_info = helper_create_connect_info(&mechanism);
     667            1 :     let wake_compute_retry_config = RetryConfig {
     668            1 :         base_delay: Duration::from_secs(1),
     669            1 :         max_retries: 1,
     670            1 :         backoff_factor: 2.0,
     671            1 :     };
     672            1 :     let config = config();
     673            1 :     connect_to_compute(
     674            1 :         &ctx,
     675            1 :         &mechanism,
     676            1 :         &user_info,
     677            1 :         wake_compute_retry_config,
     678            1 :         &config,
     679            1 :     )
     680            1 :     .await
     681            1 :     .unwrap_err();
     682            1 :     mechanism.verify();
     683            1 : }
     684              : 
     685              : /// Should retry wake compute.
     686              : #[tokio::test]
     687            1 : async fn wake_retry() {
     688            1 :     let _ = env_logger::try_init();
     689            1 :     use ConnectAction::*;
     690            1 :     let ctx = RequestContext::test();
     691            1 :     let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
     692            1 :     let user_info = helper_create_connect_info(&mechanism);
     693            1 :     let config = config();
     694            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     695            1 :         .await
     696            1 :         .unwrap();
     697            1 :     mechanism.verify();
     698            1 : }
     699              : 
     700              : /// Wake failed with a non-retryable error.
     701              : #[tokio::test]
     702            1 : async fn wake_non_retry() {
     703            1 :     let _ = env_logger::try_init();
     704            1 :     use ConnectAction::*;
     705            1 :     let ctx = RequestContext::test();
     706            1 :     let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
     707            1 :     let user_info = helper_create_connect_info(&mechanism);
     708            1 :     let config = config();
     709            1 :     connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
     710            1 :         .await
     711            1 :         .unwrap_err();
     712            1 :     mechanism.verify();
     713            1 : }
        

Generated by: LCOV version 2.1-beta