Line data Source code
1 : //! A group of high-level tests for connection establishing logic and auth.
2 : #![allow(clippy::unimplemented)]
3 :
4 : mod mitm;
5 :
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use anyhow::{Context, bail};
10 : use async_trait::async_trait;
11 : use http::StatusCode;
12 : use postgres_client::config::SslMode;
13 : use postgres_client::tls::{MakeTlsConnect, NoTls};
14 : use rstest::rstest;
15 : use rustls::crypto::ring;
16 : use rustls::pki_types;
17 : use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
18 : use tokio::time::Instant;
19 : use tracing_test::traced_test;
20 :
21 : use super::retry::CouldRetry;
22 : use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
23 : use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
24 : use crate::context::RequestContext;
25 : use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
26 : use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
27 : use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
28 : use crate::error::ErrorKind;
29 : use crate::pglb::ERR_INSECURE_CONNECTION;
30 : use crate::pglb::handshake::{HandshakeData, handshake};
31 : use crate::pqproto::BeMessage;
32 : use crate::proxy::NeonOptions;
33 : use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner};
34 : use crate::proxy::retry::retry_after;
35 : use crate::stream::{PqStream, Stream};
36 : use crate::tls::client_config::compute_client_config_with_certs;
37 : use crate::tls::server_config::CertResolver;
38 : use crate::types::{BranchId, EndpointId, ProjectId};
39 : use crate::{auth, compute, sasl, scram};
40 :
41 : /// Generate a set of TLS certificates: CA + server.
42 21 : fn generate_certs(
43 21 : hostname: &str,
44 21 : common_name: &str,
45 21 : ) -> anyhow::Result<(
46 21 : pki_types::CertificateDer<'static>,
47 21 : pki_types::CertificateDer<'static>,
48 21 : pki_types::PrivateKeyDer<'static>,
49 21 : )> {
50 21 : let ca_key = rcgen::KeyPair::generate()?;
51 21 : let ca = {
52 21 : let mut params = rcgen::CertificateParams::default();
53 21 : params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
54 21 : params.self_signed(&ca_key)?
55 : };
56 :
57 21 : let cert_key = rcgen::KeyPair::generate()?;
58 21 : let cert = {
59 21 : let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
60 21 : params.distinguished_name = rcgen::DistinguishedName::new();
61 21 : params
62 21 : .distinguished_name
63 21 : .push(rcgen::DnType::CommonName, common_name);
64 21 : params.signed_by(&cert_key, &ca, &ca_key)?
65 : };
66 :
67 21 : Ok((
68 21 : ca.der().clone(),
69 21 : cert.der().clone(),
70 21 : pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
71 21 : ))
72 21 : }
73 :
74 : struct ClientConfig<'a> {
75 : config: Arc<rustls::ClientConfig>,
76 : hostname: &'a str,
77 : }
78 :
79 : type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
80 :
81 : impl ClientConfig<'_> {
82 20 : fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
83 20 : Ok(crate::tls::postgres_rustls::make_tls_connect(
84 20 : &self.config,
85 20 : self.hostname,
86 0 : )?)
87 20 : }
88 : }
89 :
90 : /// Generate TLS certificates and build rustls configs for client and server.
91 21 : fn generate_tls_config<'a>(
92 21 : hostname: &'a str,
93 21 : common_name: &'a str,
94 21 : ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
95 21 : let (ca, cert, key) = generate_certs(hostname, common_name)?;
96 :
97 21 : let tls_config = {
98 21 : let config =
99 21 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
100 21 : .with_safe_default_protocol_versions()
101 21 : .context("ring should support the default protocol versions")?
102 21 : .with_no_client_auth()
103 21 : .with_single_cert(vec![cert.clone()], key.clone_key())?;
104 :
105 21 : let cert_resolver = CertResolver::new(key, vec![cert])?;
106 :
107 21 : let common_names = cert_resolver.get_common_names();
108 :
109 21 : let config = Arc::new(config);
110 :
111 21 : TlsConfig {
112 21 : http_config: config.clone(),
113 21 : pg_config: config,
114 21 : common_names,
115 21 : cert_resolver: Arc::new(cert_resolver),
116 21 : }
117 : };
118 :
119 21 : let client_config = {
120 21 : let config = Arc::new(compute_client_config_with_certs([ca]));
121 :
122 21 : ClientConfig { config, hostname }
123 : };
124 :
125 21 : Ok((client_config, tls_config))
126 21 : }
127 :
128 : #[async_trait]
129 : trait TestAuth: Sized {
130 2 : async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
131 : self,
132 : stream: &mut PqStream<Stream<S>>,
133 2 : ) -> anyhow::Result<()> {
134 2 : stream.write_message(BeMessage::AuthenticationOk);
135 2 : Ok(())
136 4 : }
137 : }
138 :
139 : struct NoAuth;
140 : impl TestAuth for NoAuth {}
141 :
142 : struct Scram(scram::ServerSecret);
143 :
144 : impl Scram {
145 11 : async fn new(password: &str) -> anyhow::Result<Self> {
146 11 : let secret = scram::ServerSecret::build(password)
147 11 : .await
148 11 : .context("failed to generate scram secret")?;
149 11 : Ok(Scram(secret))
150 11 : }
151 :
152 1 : fn mock() -> Self {
153 1 : Scram(scram::ServerSecret::mock(rand::random()))
154 1 : }
155 : }
156 :
157 : #[async_trait]
158 : impl TestAuth for Scram {
159 12 : async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
160 : self,
161 : stream: &mut PqStream<Stream<S>>,
162 12 : ) -> anyhow::Result<()> {
163 12 : let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
164 12 : .authenticate()
165 12 : .await?;
166 :
167 : use sasl::Outcome::*;
168 6 : match outcome {
169 5 : Success(_) => Ok(()),
170 1 : Failure(reason) => bail!("autentication failed with an error: {reason}"),
171 : }
172 24 : }
173 : }
174 :
175 : /// A dummy proxy impl which performs a handshake and reports auth success.
176 15 : async fn dummy_proxy(
177 15 : client: impl AsyncRead + AsyncWrite + Unpin + Send,
178 15 : tls: Option<TlsConfig>,
179 15 : auth: impl TestAuth + Send,
180 15 : ) -> anyhow::Result<()> {
181 15 : let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
182 14 : HandshakeData::Startup(stream, _) => stream,
183 0 : HandshakeData::Cancel(_) => bail!("cancellation not supported"),
184 : };
185 :
186 14 : auth.authenticate(&mut stream).await?;
187 :
188 7 : stream.write_message(BeMessage::ParameterStatus {
189 7 : name: b"client_encoding",
190 7 : value: b"UTF8",
191 7 : });
192 7 : stream.write_message(BeMessage::ReadyForQuery);
193 7 : stream.flush().await?;
194 :
195 7 : Ok(())
196 15 : }
197 :
198 : #[tokio::test]
199 1 : async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
200 1 : let (client, server) = tokio::io::duplex(1024);
201 :
202 1 : let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
203 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
204 :
205 1 : let client_err = postgres_client::Config::new("test".to_owned(), 5432)
206 1 : .user("john_doe")
207 1 : .dbname("earth")
208 1 : .ssl_mode(SslMode::Disable)
209 1 : .tls_and_authenticate(server, NoTls)
210 1 : .await
211 1 : .err() // -> Option<E>
212 1 : .context("client shouldn't be able to connect")?;
213 :
214 1 : assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
215 :
216 1 : let server_err = proxy
217 1 : .await?
218 1 : .err() // -> Option<E>
219 1 : .context("server shouldn't accept client")?;
220 :
221 1 : assert!(client_err.to_string().contains(&server_err.to_string()));
222 :
223 2 : Ok(())
224 1 : }
225 :
226 : #[tokio::test]
227 1 : async fn handshake_tls() -> anyhow::Result<()> {
228 1 : let (client, server) = tokio::io::duplex(1024);
229 :
230 1 : let (client_config, server_config) =
231 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
232 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
233 :
234 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
235 1 : .user("john_doe")
236 1 : .dbname("earth")
237 1 : .ssl_mode(SslMode::Require)
238 1 : .tls_and_authenticate(server, client_config.make_tls_connect()?)
239 1 : .await?;
240 :
241 1 : proxy.await?
242 1 : }
243 :
244 : #[tokio::test]
245 1 : async fn handshake_raw() -> anyhow::Result<()> {
246 1 : let (client, server) = tokio::io::duplex(1024);
247 :
248 1 : let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
249 :
250 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
251 1 : .user("john_doe")
252 1 : .dbname("earth")
253 1 : .set_param("options", "project=generic-project-name")
254 1 : .ssl_mode(SslMode::Prefer)
255 1 : .tls_and_authenticate(server, NoTls)
256 1 : .await?;
257 :
258 1 : proxy.await?
259 1 : }
260 :
261 : #[tokio::test]
262 1 : async fn keepalive_is_inherited() -> anyhow::Result<()> {
263 : use tokio::net::{TcpListener, TcpStream};
264 :
265 1 : let listener = TcpListener::bind("127.0.0.1:0").await?;
266 1 : let port = listener.local_addr()?.port();
267 1 : socket2::SockRef::from(&listener).set_keepalive(true)?;
268 :
269 1 : let t = tokio::spawn(async move {
270 1 : let (client, _) = listener.accept().await?;
271 1 : let keepalive = socket2::SockRef::from(&client).keepalive()?;
272 1 : anyhow::Ok(keepalive)
273 1 : });
274 :
275 1 : TcpStream::connect(("127.0.0.1", port)).await?;
276 1 : assert!(t.await??, "keepalive should be inherited");
277 :
278 2 : Ok(())
279 1 : }
280 :
281 : #[rstest]
282 : #[case("password_foo")]
283 : #[case("pwd-bar")]
284 : #[case("")]
285 : #[tokio::test]
286 : async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
287 : let (client, server) = tokio::io::duplex(1024);
288 :
289 : let (client_config, server_config) =
290 : generate_tls_config("generic-project-name.localhost", "localhost")?;
291 : let proxy = tokio::spawn(dummy_proxy(
292 : client,
293 : Some(server_config),
294 : Scram::new(password).await?,
295 : ));
296 :
297 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
298 : .channel_binding(postgres_client::config::ChannelBinding::Require)
299 : .user("user")
300 : .dbname("db")
301 : .password(password)
302 : .ssl_mode(SslMode::Require)
303 : .tls_and_authenticate(server, client_config.make_tls_connect()?)
304 : .await?;
305 :
306 : proxy.await?
307 : }
308 :
309 : #[tokio::test]
310 1 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
311 1 : let (client, server) = tokio::io::duplex(1024);
312 :
313 1 : let (client_config, server_config) =
314 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
315 1 : let proxy = tokio::spawn(dummy_proxy(
316 1 : client,
317 1 : Some(server_config),
318 1 : Scram::new("password").await?,
319 : ));
320 :
321 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
322 1 : .channel_binding(postgres_client::config::ChannelBinding::Disable)
323 1 : .user("user")
324 1 : .dbname("db")
325 1 : .password("password")
326 1 : .ssl_mode(SslMode::Require)
327 1 : .tls_and_authenticate(server, client_config.make_tls_connect()?)
328 1 : .await?;
329 :
330 1 : proxy.await?
331 1 : }
332 :
333 : #[tokio::test]
334 1 : async fn scram_auth_mock() -> anyhow::Result<()> {
335 1 : let (client, server) = tokio::io::duplex(1024);
336 :
337 1 : let (client_config, server_config) =
338 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
339 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
340 :
341 : use rand::Rng;
342 : use rand::distr::Alphanumeric;
343 1 : let password: String = rand::rng()
344 1 : .sample_iter(&Alphanumeric)
345 1 : .take(rand::random::<u8>() as usize)
346 1 : .map(char::from)
347 1 : .collect();
348 :
349 1 : let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
350 1 : .user("user")
351 1 : .dbname("db")
352 1 : .password(&password) // no password will match the mocked secret
353 1 : .ssl_mode(SslMode::Require)
354 1 : .tls_and_authenticate(server, client_config.make_tls_connect()?)
355 1 : .await
356 1 : .err() // -> Option<E>
357 1 : .context("client shouldn't be able to connect")?;
358 :
359 1 : let _server_err = proxy
360 1 : .await?
361 1 : .err() // -> Option<E>
362 1 : .context("server shouldn't accept client")?;
363 :
364 2 : Ok(())
365 1 : }
366 :
/// Summing `retry_after` over retries `1..max_retries` with a 1s base delay
/// and a backoff factor of 2 must come out within 0.1s of 15 seconds.
#[test]
fn connect_compute_total_wait() {
    let config = RetryConfig {
        base_delay: Duration::from_secs(1),
        max_retries: 5,
        backoff_factor: 2.0,
    };

    let total_wait: Duration = (1..config.max_retries)
        .map(|num_retries| retry_after(num_retries, config))
        .sum();

    assert!((total_wait.as_secs_f64() - 15.0).abs() < 0.1);
}
380 :
/// One scripted step consumed by [`TestConnectMechanism`].
///
/// The `Wake*` variants are served by `wake_compute`; the remaining variants
/// are served by `connect_once`. A step consumed by the wrong call panics.
#[derive(Clone, Copy, Debug)]
enum ConnectAction {
    /// `wake_compute` returns node info inserted into the mechanism's cache.
    Wake,
    /// `wake_compute` returns node info that is not cached.
    WakeCold,
    /// `wake_compute` fails with an error for which `could_retry()` is false.
    WakeFail,
    /// `wake_compute` fails with an error for which `could_retry()` is true.
    WakeRetry,
    /// `connect_once` succeeds.
    Connect,
    /// `connect_once` -> Err, could_retry = true, should_retry_wake_compute = true
    Retry,
    /// `connect_once` -> Err, could_retry = true, should_retry_wake_compute = false
    RetryNoWake,
    /// `connect_once` -> Err, could_retry = false, should_retry_wake_compute = true
    Fail,
    /// `connect_once` -> Err, could_retry = false, should_retry_wake_compute = false
    FailNoWake,
}
397 :
398 : #[derive(Clone)]
399 : struct TestConnectMechanism {
400 : counter: Arc<std::sync::Mutex<usize>>,
401 : sequence: Vec<ConnectAction>,
402 : cache: &'static NodeInfoCache,
403 : }
404 :
405 : impl TestConnectMechanism {
406 11 : fn verify(&self) {
407 11 : let counter = self.counter.lock().unwrap();
408 11 : assert_eq!(
409 11 : *counter,
410 11 : self.sequence.len(),
411 0 : "sequence does not proceed to the end"
412 : );
413 11 : }
414 : }
415 :
416 : impl TestConnectMechanism {
417 13 : fn new(sequence: Vec<ConnectAction>) -> Self {
418 13 : Self {
419 13 : counter: Arc::new(std::sync::Mutex::new(0)),
420 13 : sequence,
421 13 : cache: Box::leak(Box::new(NodeInfoCache::new(
422 13 : "test",
423 13 : 1,
424 13 : Duration::from_secs(100),
425 13 : false,
426 13 : ))),
427 13 : }
428 13 : }
429 : }
430 :
431 : #[derive(Debug)]
432 : struct TestConnection;
433 :
434 : impl ConnectMechanism for TestConnectMechanism {
435 : type Connection = TestConnection;
436 :
437 25 : async fn connect_once(
438 25 : &self,
439 25 : _ctx: &RequestContext,
440 25 : _node_info: &control_plane::CachedNodeInfo,
441 25 : _config: &ComputeConfig,
442 25 : ) -> Result<Self::Connection, compute::ConnectionError> {
443 25 : let mut counter = self.counter.lock().unwrap();
444 25 : let action = self.sequence[*counter];
445 25 : *counter += 1;
446 25 : match action {
447 8 : ConnectAction::Connect => Ok(TestConnection),
448 10 : ConnectAction::Retry => Err(compute::ConnectionError::TestError {
449 10 : retryable: true,
450 10 : wakeable: true,
451 10 : kind: ErrorKind::Compute,
452 10 : }),
453 2 : ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError {
454 2 : retryable: true,
455 2 : wakeable: false,
456 2 : kind: ErrorKind::Compute,
457 2 : }),
458 4 : ConnectAction::Fail => Err(compute::ConnectionError::TestError {
459 4 : retryable: false,
460 4 : wakeable: true,
461 4 : kind: ErrorKind::Compute,
462 4 : }),
463 1 : ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError {
464 1 : retryable: false,
465 1 : wakeable: false,
466 1 : kind: ErrorKind::Compute,
467 1 : }),
468 0 : x => panic!("expecting action {x:?}, connect is called instead"),
469 : }
470 25 : }
471 : }
472 :
473 : impl TestControlPlaneClient for TestConnectMechanism {
474 21 : fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
475 21 : let mut counter = self.counter.lock().unwrap();
476 21 : let action = self.sequence[*counter];
477 21 : *counter += 1;
478 21 : match action {
479 17 : ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
480 1 : ConnectAction::WakeCold => Ok(CachedNodeInfo::new_uncached(
481 1 : helper_create_uncached_node_info(),
482 1 : )),
483 : ConnectAction::WakeFail => {
484 1 : let err = control_plane::errors::ControlPlaneError::Message(Box::new(
485 1 : ControlPlaneErrorMessage {
486 1 : http_status_code: StatusCode::BAD_REQUEST,
487 1 : error: "TEST".into(),
488 1 : status: None,
489 1 : },
490 1 : ));
491 1 : assert!(!err.could_retry());
492 1 : Err(control_plane::errors::WakeComputeError::ControlPlane(err))
493 : }
494 : ConnectAction::WakeRetry => {
495 2 : let err = control_plane::errors::ControlPlaneError::Message(Box::new(
496 2 : ControlPlaneErrorMessage {
497 2 : http_status_code: StatusCode::BAD_REQUEST,
498 2 : error: "TEST".into(),
499 2 : status: Some(Status {
500 2 : code: "error".into(),
501 2 : message: "error".into(),
502 2 : details: Details {
503 2 : error_info: None,
504 2 : retry_info: Some(control_plane::messages::RetryInfo {
505 2 : retry_at: Instant::now() + Duration::from_millis(1),
506 2 : }),
507 2 : user_facing_message: None,
508 2 : },
509 2 : }),
510 2 : },
511 2 : ));
512 2 : assert!(err.could_retry());
513 2 : Err(control_plane::errors::WakeComputeError::ControlPlane(err))
514 : }
515 0 : x => panic!("expecting action {x:?}, wake_compute is called instead"),
516 : }
517 21 : }
518 :
519 0 : fn get_access_control(
520 0 : &self,
521 0 : ) -> Result<control_plane::EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
522 0 : unimplemented!("not used in tests")
523 : }
524 :
525 0 : fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient> {
526 0 : Box::new(self.clone())
527 0 : }
528 : }
529 :
530 18 : fn helper_create_uncached_node_info() -> NodeInfo {
531 18 : NodeInfo {
532 18 : conn_info: compute::ConnectInfo {
533 18 : host: "test".into(),
534 18 : port: 5432,
535 18 : ssl_mode: SslMode::Disable,
536 18 : host_addr: None,
537 18 : },
538 18 : aux: MetricsAuxInfo {
539 18 : endpoint_id: (&EndpointId::from("endpoint")).into(),
540 18 : project_id: (&ProjectId::from("project")).into(),
541 18 : branch_id: (&BranchId::from("branch")).into(),
542 18 : compute_id: "compute".into(),
543 18 : cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
544 18 : },
545 18 : }
546 18 : }
547 :
548 17 : fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
549 17 : let node = helper_create_uncached_node_info();
550 17 : let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
551 17 : node2.map(|()| node)
552 17 : }
553 :
554 13 : fn helper_create_connect_info(
555 13 : mechanism: &TestConnectMechanism,
556 13 : ) -> auth::Backend<'static, ComputeUserInfo> {
557 13 : auth::Backend::ControlPlane(
558 13 : MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
559 13 : ComputeUserInfo {
560 13 : endpoint: "endpoint".into(),
561 13 : user: "user".into(),
562 13 : options: NeonOptions::parse_options_raw(""),
563 13 : },
564 13 : )
565 13 : }
566 :
567 13 : fn config() -> ComputeConfig {
568 13 : let retry = RetryConfig {
569 13 : base_delay: Duration::from_secs(1),
570 13 : max_retries: 5,
571 13 : backoff_factor: 2.0,
572 13 : };
573 :
574 13 : ComputeConfig {
575 13 : retry,
576 13 : tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
577 13 : timeout: Duration::from_secs(2),
578 13 : }
579 13 : }
580 :
581 : #[tokio::test]
582 1 : async fn connect_to_compute_success() {
583 1 : let _ = env_logger::try_init();
584 : use ConnectAction::*;
585 1 : let ctx = RequestContext::test();
586 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
587 1 : let user_info = helper_create_connect_info(&mechanism);
588 1 : let config = config();
589 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
590 1 : .await
591 1 : .unwrap();
592 1 : mechanism.verify();
593 1 : }
594 :
595 : #[tokio::test]
596 1 : async fn connect_to_compute_retry() {
597 1 : let _ = env_logger::try_init();
598 : use ConnectAction::*;
599 1 : let ctx = RequestContext::test();
600 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
601 1 : let user_info = helper_create_connect_info(&mechanism);
602 1 : let config = config();
603 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
604 1 : .await
605 1 : .unwrap();
606 1 : mechanism.verify();
607 1 : }
608 :
609 : /// Test that we don't retry if the error is not retryable.
610 : #[tokio::test]
611 1 : async fn connect_to_compute_non_retry_1() {
612 1 : let _ = env_logger::try_init();
613 : use ConnectAction::*;
614 1 : let ctx = RequestContext::test();
615 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
616 1 : let user_info = helper_create_connect_info(&mechanism);
617 1 : let config = config();
618 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
619 1 : .await
620 1 : .unwrap_err();
621 1 : mechanism.verify();
622 1 : }
623 :
624 : /// Even for non-retryable errors, we should retry at least once.
625 : #[tokio::test]
626 1 : async fn connect_to_compute_non_retry_2() {
627 1 : let _ = env_logger::try_init();
628 : use ConnectAction::*;
629 1 : let ctx = RequestContext::test();
630 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
631 1 : let user_info = helper_create_connect_info(&mechanism);
632 1 : let config = config();
633 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
634 1 : .await
635 1 : .unwrap();
636 1 : mechanism.verify();
637 1 : }
638 :
639 : /// Retry for at most `NUM_RETRIES_CONNECT` times.
640 : #[tokio::test]
641 1 : async fn connect_to_compute_non_retry_3() {
642 1 : let _ = env_logger::try_init();
643 1 : tokio::time::pause();
644 : use ConnectAction::*;
645 1 : let ctx = RequestContext::test();
646 1 : let mechanism =
647 1 : TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
648 1 : let user_info = helper_create_connect_info(&mechanism);
649 1 : let wake_compute_retry_config = RetryConfig {
650 1 : base_delay: Duration::from_secs(1),
651 1 : max_retries: 1,
652 1 : backoff_factor: 2.0,
653 1 : };
654 1 : let config = config();
655 1 : connect_to_compute_inner(
656 1 : &ctx,
657 1 : &mechanism,
658 1 : &user_info,
659 1 : wake_compute_retry_config,
660 1 : &config,
661 1 : )
662 1 : .await
663 1 : .unwrap_err();
664 1 : mechanism.verify();
665 1 : }
666 :
667 : /// Should retry wake compute.
668 : #[tokio::test]
669 1 : async fn wake_retry() {
670 1 : let _ = env_logger::try_init();
671 : use ConnectAction::*;
672 1 : let ctx = RequestContext::test();
673 1 : let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
674 1 : let user_info = helper_create_connect_info(&mechanism);
675 1 : let config = config();
676 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
677 1 : .await
678 1 : .unwrap();
679 1 : mechanism.verify();
680 1 : }
681 :
682 : /// Wake failed with a non-retryable error.
683 : #[tokio::test]
684 1 : async fn wake_non_retry() {
685 1 : let _ = env_logger::try_init();
686 : use ConnectAction::*;
687 1 : let ctx = RequestContext::test();
688 1 : let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
689 1 : let user_info = helper_create_connect_info(&mechanism);
690 1 : let config = config();
691 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
692 1 : .await
693 1 : .unwrap_err();
694 1 : mechanism.verify();
695 1 : }
696 :
697 : #[tokio::test]
698 : #[traced_test]
699 1 : async fn fail_but_wake_invalidates_cache() {
700 1 : let ctx = RequestContext::test();
701 1 : let mech = TestConnectMechanism::new(vec![
702 1 : ConnectAction::Wake,
703 1 : ConnectAction::Fail,
704 1 : ConnectAction::Wake,
705 1 : ConnectAction::Connect,
706 : ]);
707 1 : let user = helper_create_connect_info(&mech);
708 1 : let cfg = config();
709 :
710 1 : connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
711 1 : .await
712 1 : .unwrap();
713 :
714 1 : assert!(logs_contain(
715 1 : "invalidating stalled compute node info cache entry"
716 1 : ));
717 1 : }
718 :
719 : #[tokio::test]
720 : #[traced_test]
721 1 : async fn fail_no_wake_skips_cache_invalidation() {
722 1 : let ctx = RequestContext::test();
723 1 : let mech = TestConnectMechanism::new(vec![
724 1 : ConnectAction::Wake,
725 1 : ConnectAction::RetryNoWake,
726 1 : ConnectAction::Connect,
727 : ]);
728 1 : let user = helper_create_connect_info(&mech);
729 1 : let cfg = config();
730 :
731 1 : connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
732 1 : .await
733 1 : .unwrap();
734 :
735 1 : assert!(!logs_contain(
736 1 : "invalidating stalled compute node info cache entry"
737 1 : ));
738 1 : }
739 :
740 : #[tokio::test]
741 : #[traced_test]
742 1 : async fn retry_but_wake_invalidates_cache() {
743 1 : let _ = env_logger::try_init();
744 : use ConnectAction::*;
745 :
746 1 : let ctx = RequestContext::test();
747 : // Wake → Retry (retryable + wakeable) → Wake → Connect
748 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
749 1 : let user_info = helper_create_connect_info(&mechanism);
750 1 : let cfg = config();
751 :
752 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
753 1 : .await
754 1 : .unwrap();
755 1 : mechanism.verify();
756 :
757 : // Because Retry has wakeable=true, we should see invalidate_cache
758 1 : assert!(logs_contain(
759 1 : "invalidating stalled compute node info cache entry"
760 1 : ));
761 1 : }
762 :
763 : #[tokio::test]
764 : #[traced_test]
765 1 : async fn retry_no_wake_skips_invalidation() {
766 1 : let _ = env_logger::try_init();
767 : use ConnectAction::*;
768 :
769 1 : let ctx = RequestContext::test();
770 : // Wake → RetryNoWake (retryable + NOT wakeable)
771 1 : let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake, Fail]);
772 1 : let user_info = helper_create_connect_info(&mechanism);
773 1 : let cfg = config();
774 :
775 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
776 1 : .await
777 1 : .unwrap_err();
778 1 : mechanism.verify();
779 :
780 : // Because RetryNoWake has wakeable=false, we must NOT see invalidate_cache
781 1 : assert!(!logs_contain(
782 1 : "invalidating stalled compute node info cache entry"
783 1 : ));
784 1 : }
785 :
786 : #[tokio::test]
787 : #[traced_test]
788 1 : async fn retry_no_wake_error_fast() {
789 1 : let _ = env_logger::try_init();
790 : use ConnectAction::*;
791 :
792 1 : let ctx = RequestContext::test();
793 : // Wake → FailNoWake (not retryable + NOT wakeable)
794 1 : let mechanism = TestConnectMechanism::new(vec![Wake, FailNoWake]);
795 1 : let user_info = helper_create_connect_info(&mechanism);
796 1 : let cfg = config();
797 :
798 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
799 1 : .await
800 1 : .unwrap_err();
801 1 : mechanism.verify();
802 :
803 : // Because FailNoWake has wakeable=false, we must NOT see invalidate_cache
804 1 : assert!(!logs_contain(
805 1 : "invalidating stalled compute node info cache entry"
806 1 : ));
807 1 : }
808 :
809 : #[tokio::test]
810 : #[traced_test]
811 1 : async fn retry_cold_wake_skips_invalidation() {
812 1 : let _ = env_logger::try_init();
813 : use ConnectAction::*;
814 :
815 1 : let ctx = RequestContext::test();
816 : // WakeCold → FailNoWake (not retryable + NOT wakeable)
817 1 : let mechanism = TestConnectMechanism::new(vec![WakeCold, Retry, Connect]);
818 1 : let user_info = helper_create_connect_info(&mechanism);
819 1 : let cfg = config();
820 :
821 1 : connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
822 1 : .await
823 1 : .unwrap();
824 1 : mechanism.verify();
825 1 : }
|