Line data Source code
1 : //! A group of high-level tests for connection establishing logic and auth.
2 : #![allow(clippy::unimplemented, clippy::unwrap_used)]
3 :
4 : mod mitm;
5 :
6 : use std::time::Duration;
7 :
8 : use anyhow::{Context, bail};
9 : use async_trait::async_trait;
10 : use http::StatusCode;
11 : use postgres_client::config::SslMode;
12 : use postgres_client::tls::{MakeTlsConnect, NoTls};
13 : use retry::{ShouldRetryWakeCompute, retry_after};
14 : use rstest::rstest;
15 : use rustls::crypto::ring;
16 : use rustls::pki_types;
17 : use tokio::io::DuplexStream;
18 :
19 : use super::connect_compute::ConnectMechanism;
20 : use super::retry::CouldRetry;
21 : use super::*;
22 : use crate::auth::backend::{
23 : ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
24 : };
25 : use crate::config::{ComputeConfig, RetryConfig};
26 : use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
27 : use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
28 : use crate::control_plane::{
29 : self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
30 : };
31 : use crate::error::ErrorKind;
32 : use crate::tls::client_config::compute_client_config_with_certs;
33 : use crate::tls::postgres_rustls::MakeRustlsConnect;
34 : use crate::tls::server_config::CertResolver;
35 : use crate::types::{BranchId, EndpointId, ProjectId};
36 : use crate::{sasl, scram};
37 :
38 : /// Generate a set of TLS certificates: CA + server.
39 21 : fn generate_certs(
40 21 : hostname: &str,
41 21 : common_name: &str,
42 21 : ) -> anyhow::Result<(
43 21 : pki_types::CertificateDer<'static>,
44 21 : pki_types::CertificateDer<'static>,
45 21 : pki_types::PrivateKeyDer<'static>,
46 21 : )> {
47 21 : let ca_key = rcgen::KeyPair::generate()?;
48 21 : let ca = {
49 21 : let mut params = rcgen::CertificateParams::default();
50 21 : params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
51 21 : params.self_signed(&ca_key)?
52 : };
53 :
54 21 : let cert_key = rcgen::KeyPair::generate()?;
55 21 : let cert = {
56 21 : let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
57 21 : params.distinguished_name = rcgen::DistinguishedName::new();
58 21 : params
59 21 : .distinguished_name
60 21 : .push(rcgen::DnType::CommonName, common_name);
61 21 : params.signed_by(&cert_key, &ca, &ca_key)?
62 : };
63 :
64 21 : Ok((
65 21 : ca.der().clone(),
66 21 : cert.der().clone(),
67 21 : pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
68 21 : ))
69 21 : }
70 :
71 : struct ClientConfig<'a> {
72 : config: Arc<rustls::ClientConfig>,
73 : hostname: &'a str,
74 : }
75 :
76 : type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
77 :
78 : impl ClientConfig<'_> {
79 20 : fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
80 20 : let mut mk = MakeRustlsConnect::new(self.config);
81 20 : let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
82 20 : Ok(tls)
83 20 : }
84 : }
85 :
86 : /// Generate TLS certificates and build rustls configs for client and server.
87 21 : fn generate_tls_config<'a>(
88 21 : hostname: &'a str,
89 21 : common_name: &'a str,
90 21 : ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
91 21 : let (ca, cert, key) = generate_certs(hostname, common_name)?;
92 :
93 21 : let tls_config = {
94 21 : let config =
95 21 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
96 21 : .with_safe_default_protocol_versions()
97 21 : .context("ring should support the default protocol versions")?
98 21 : .with_no_client_auth()
99 21 : .with_single_cert(vec![cert.clone()], key.clone_key())?;
100 :
101 21 : let mut cert_resolver = CertResolver::new();
102 21 : cert_resolver.add_cert(key, vec![cert], true)?;
103 :
104 21 : let common_names = cert_resolver.get_common_names();
105 21 :
106 21 : let config = Arc::new(config);
107 21 :
108 21 : TlsConfig {
109 21 : http_config: config.clone(),
110 21 : pg_config: config,
111 21 : common_names,
112 21 : cert_resolver: Arc::new(cert_resolver),
113 21 : }
114 21 : };
115 21 :
116 21 : let client_config = {
117 21 : let config = Arc::new(compute_client_config_with_certs([ca]));
118 21 :
119 21 : ClientConfig { config, hostname }
120 21 : };
121 21 :
122 21 : Ok((client_config, tls_config))
123 21 : }
124 :
125 : #[async_trait]
126 : trait TestAuth: Sized {
127 2 : async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
128 2 : self,
129 2 : stream: &mut PqStream<Stream<S>>,
130 2 : ) -> anyhow::Result<()> {
131 2 : stream.write_message_noflush(&Be::AuthenticationOk)?;
132 2 : Ok(())
133 4 : }
134 : }
135 :
136 : struct NoAuth;
137 : impl TestAuth for NoAuth {}
138 :
139 : struct Scram(scram::ServerSecret);
140 :
141 : impl Scram {
142 11 : async fn new(password: &str) -> anyhow::Result<Self> {
143 11 : let secret = scram::ServerSecret::build(password)
144 11 : .await
145 11 : .context("failed to generate scram secret")?;
146 11 : Ok(Scram(secret))
147 11 : }
148 :
149 1 : fn mock() -> Self {
150 1 : Scram(scram::ServerSecret::mock(rand::random()))
151 1 : }
152 : }
153 :
154 : #[async_trait]
155 : impl TestAuth for Scram {
156 12 : async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
157 12 : self,
158 12 : stream: &mut PqStream<Stream<S>>,
159 12 : ) -> anyhow::Result<()> {
160 12 : let outcome = auth::AuthFlow::new(stream)
161 12 : .begin(auth::Scram(&self.0, &RequestContext::test()))
162 12 : .await?
163 12 : .authenticate()
164 12 : .await?;
165 :
166 : use sasl::Outcome::*;
167 6 : match outcome {
168 5 : Success(_) => Ok(()),
169 1 : Failure(reason) => bail!("autentication failed with an error: {reason}"),
170 : }
171 24 : }
172 : }
173 :
174 : /// A dummy proxy impl which performs a handshake and reports auth success.
175 15 : async fn dummy_proxy(
176 15 : client: impl AsyncRead + AsyncWrite + Unpin + Send,
177 15 : tls: Option<TlsConfig>,
178 15 : auth: impl TestAuth + Send,
179 15 : ) -> anyhow::Result<()> {
180 15 : let (client, _) = read_proxy_protocol(client).await?;
181 15 : let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
182 14 : HandshakeData::Startup(stream, _) => stream,
183 0 : HandshakeData::Cancel(_) => bail!("cancellation not supported"),
184 : };
185 :
186 14 : auth.authenticate(&mut stream).await?;
187 :
188 7 : stream
189 7 : .write_message_noflush(&Be::CLIENT_ENCODING)?
190 7 : .write_message(&Be::ReadyForQuery)
191 7 : .await?;
192 :
193 7 : Ok(())
194 15 : }
195 :
196 : #[tokio::test]
197 1 : async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
198 1 : let (client, server) = tokio::io::duplex(1024);
199 1 :
200 1 : let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
201 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
202 1 :
203 1 : let client_err = postgres_client::Config::new("test".to_owned(), 5432)
204 1 : .user("john_doe")
205 1 : .dbname("earth")
206 1 : .ssl_mode(SslMode::Disable)
207 1 : .connect_raw(server, NoTls)
208 1 : .await
209 1 : .err() // -> Option<E>
210 1 : .context("client shouldn't be able to connect")?;
211 1 :
212 1 : assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
213 1 :
214 1 : let server_err = proxy
215 1 : .await?
216 1 : .err() // -> Option<E>
217 1 : .context("server shouldn't accept client")?;
218 1 :
219 1 : assert!(client_err.to_string().contains(&server_err.to_string()));
220 1 :
221 1 : Ok(())
222 1 : }
223 :
224 : #[tokio::test]
225 1 : async fn handshake_tls() -> anyhow::Result<()> {
226 1 : let (client, server) = tokio::io::duplex(1024);
227 1 :
228 1 : let (client_config, server_config) =
229 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
230 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
231 1 :
232 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
233 1 : .user("john_doe")
234 1 : .dbname("earth")
235 1 : .ssl_mode(SslMode::Require)
236 1 : .connect_raw(server, client_config.make_tls_connect()?)
237 1 : .await?;
238 1 :
239 1 : proxy.await?
240 1 : }
241 :
242 : #[tokio::test]
243 1 : async fn handshake_raw() -> anyhow::Result<()> {
244 1 : let (client, server) = tokio::io::duplex(1024);
245 1 :
246 1 : let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
247 1 :
248 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
249 1 : .user("john_doe")
250 1 : .dbname("earth")
251 1 : .set_param("options", "project=generic-project-name")
252 1 : .ssl_mode(SslMode::Prefer)
253 1 : .connect_raw(server, NoTls)
254 1 : .await?;
255 1 :
256 1 : proxy.await?
257 1 : }
258 :
259 : #[tokio::test]
260 1 : async fn keepalive_is_inherited() -> anyhow::Result<()> {
261 1 : use tokio::net::{TcpListener, TcpStream};
262 1 :
263 1 : let listener = TcpListener::bind("127.0.0.1:0").await?;
264 1 : let port = listener.local_addr()?.port();
265 1 : socket2::SockRef::from(&listener).set_keepalive(true)?;
266 1 :
267 1 : let t = tokio::spawn(async move {
268 1 : let (client, _) = listener.accept().await?;
269 1 : let keepalive = socket2::SockRef::from(&client).keepalive()?;
270 1 : anyhow::Ok(keepalive)
271 1 : });
272 1 :
273 1 : TcpStream::connect(("127.0.0.1", port)).await?;
274 1 : assert!(t.await??, "keepalive should be inherited");
275 1 :
276 1 : Ok(())
277 1 : }
278 :
279 3 : #[rstest]
280 : #[case("password_foo")]
281 : #[case("pwd-bar")]
282 : #[case("")]
283 : #[tokio::test]
284 : async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
285 : let (client, server) = tokio::io::duplex(1024);
286 :
287 : let (client_config, server_config) =
288 : generate_tls_config("generic-project-name.localhost", "localhost")?;
289 : let proxy = tokio::spawn(dummy_proxy(
290 : client,
291 : Some(server_config),
292 : Scram::new(password).await?,
293 : ));
294 :
295 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
296 : .channel_binding(postgres_client::config::ChannelBinding::Require)
297 : .user("user")
298 : .dbname("db")
299 : .password(password)
300 : .ssl_mode(SslMode::Require)
301 : .connect_raw(server, client_config.make_tls_connect()?)
302 : .await?;
303 :
304 : proxy.await?
305 : }
306 :
307 : #[tokio::test]
308 1 : async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
309 1 : let (client, server) = tokio::io::duplex(1024);
310 1 :
311 1 : let (client_config, server_config) =
312 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
313 1 : let proxy = tokio::spawn(dummy_proxy(
314 1 : client,
315 1 : Some(server_config),
316 1 : Scram::new("password").await?,
317 1 : ));
318 1 :
319 1 : let _conn = postgres_client::Config::new("test".to_owned(), 5432)
320 1 : .channel_binding(postgres_client::config::ChannelBinding::Disable)
321 1 : .user("user")
322 1 : .dbname("db")
323 1 : .password("password")
324 1 : .ssl_mode(SslMode::Require)
325 1 : .connect_raw(server, client_config.make_tls_connect()?)
326 1 : .await?;
327 1 :
328 1 : proxy.await?
329 1 : }
330 :
331 : #[tokio::test]
332 1 : async fn scram_auth_mock() -> anyhow::Result<()> {
333 1 : let (client, server) = tokio::io::duplex(1024);
334 1 :
335 1 : let (client_config, server_config) =
336 1 : generate_tls_config("generic-project-name.localhost", "localhost")?;
337 1 : let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
338 1 :
339 1 : use rand::Rng;
340 1 : use rand::distributions::Alphanumeric;
341 1 : let password: String = rand::thread_rng()
342 1 : .sample_iter(&Alphanumeric)
343 1 : .take(rand::random::<u8>() as usize)
344 1 : .map(char::from)
345 1 : .collect();
346 1 :
347 1 : let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
348 1 : .user("user")
349 1 : .dbname("db")
350 1 : .password(&password) // no password will match the mocked secret
351 1 : .ssl_mode(SslMode::Require)
352 1 : .connect_raw(server, client_config.make_tls_connect()?)
353 1 : .await
354 1 : .err() // -> Option<E>
355 1 : .context("client shouldn't be able to connect")?;
356 1 :
357 1 : let _server_err = proxy
358 1 : .await?
359 1 : .err() // -> Option<E>
360 1 : .context("server shouldn't accept client")?;
361 1 :
362 1 : Ok(())
363 1 : }
364 :
365 : #[test]
366 1 : fn connect_compute_total_wait() {
367 1 : let mut total_wait = tokio::time::Duration::ZERO;
368 1 : let config = RetryConfig {
369 1 : base_delay: Duration::from_secs(1),
370 1 : max_retries: 5,
371 1 : backoff_factor: 2.0,
372 1 : };
373 4 : for num_retries in 1..config.max_retries {
374 4 : total_wait += retry_after(num_retries, config);
375 4 : }
376 1 : assert!(f64::abs(total_wait.as_secs_f64() - 15.0) < 0.1);
377 1 : }
378 :
379 : #[derive(Clone, Copy, Debug)]
380 : enum ConnectAction {
381 : Wake,
382 : WakeFail,
383 : WakeRetry,
384 : Connect,
385 : Retry,
386 : Fail,
387 : }
388 :
389 : #[derive(Clone)]
390 : struct TestConnectMechanism {
391 : counter: Arc<std::sync::Mutex<usize>>,
392 : sequence: Vec<ConnectAction>,
393 : cache: &'static NodeInfoCache,
394 : }
395 :
396 : impl TestConnectMechanism {
397 7 : fn verify(&self) {
398 7 : let counter = self.counter.lock().unwrap();
399 7 : assert_eq!(
400 7 : *counter,
401 7 : self.sequence.len(),
402 0 : "sequence does not proceed to the end"
403 : );
404 7 : }
405 : }
406 :
407 : impl TestConnectMechanism {
408 7 : fn new(sequence: Vec<ConnectAction>) -> Self {
409 7 : Self {
410 7 : counter: Arc::new(std::sync::Mutex::new(0)),
411 7 : sequence,
412 7 : cache: Box::leak(Box::new(NodeInfoCache::new(
413 7 : "test",
414 7 : 1,
415 7 : Duration::from_secs(100),
416 7 : false,
417 7 : ))),
418 7 : }
419 7 : }
420 : }
421 :
422 : #[derive(Debug)]
423 : struct TestConnection;
424 :
425 : #[derive(Debug)]
426 : struct TestConnectError {
427 : retryable: bool,
428 : kind: crate::error::ErrorKind,
429 : }
430 :
431 : impl ReportableError for TestConnectError {
432 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
433 0 : self.kind
434 0 : }
435 : }
436 :
437 : impl std::fmt::Display for TestConnectError {
438 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 0 : write!(f, "{self:?}")
440 0 : }
441 : }
442 :
443 : impl std::error::Error for TestConnectError {}
444 :
445 : impl CouldRetry for TestConnectError {
446 5 : fn could_retry(&self) -> bool {
447 5 : self.retryable
448 5 : }
449 : }
450 : impl ShouldRetryWakeCompute for TestConnectError {
451 4 : fn should_retry_wake_compute(&self) -> bool {
452 4 : true
453 4 : }
454 : }
455 :
456 : #[async_trait]
457 : impl ConnectMechanism for TestConnectMechanism {
458 : type Connection = TestConnection;
459 : type ConnectError = TestConnectError;
460 : type Error = anyhow::Error;
461 :
462 14 : async fn connect_once(
463 14 : &self,
464 14 : _ctx: &RequestContext,
465 14 : _node_info: &control_plane::CachedNodeInfo,
466 14 : _config: &ComputeConfig,
467 14 : ) -> Result<Self::Connection, Self::ConnectError> {
468 14 : let mut counter = self.counter.lock().unwrap();
469 14 : let action = self.sequence[*counter];
470 14 : *counter += 1;
471 14 : match action {
472 4 : ConnectAction::Connect => Ok(TestConnection),
473 8 : ConnectAction::Retry => Err(TestConnectError {
474 8 : retryable: true,
475 8 : kind: ErrorKind::Compute,
476 8 : }),
477 2 : ConnectAction::Fail => Err(TestConnectError {
478 2 : retryable: false,
479 2 : kind: ErrorKind::Compute,
480 2 : }),
481 0 : x => panic!("expecting action {x:?}, connect is called instead"),
482 : }
483 28 : }
484 :
485 10 : fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
486 : }
487 :
488 : impl TestControlPlaneClient for TestConnectMechanism {
489 13 : fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
490 13 : let mut counter = self.counter.lock().unwrap();
491 13 : let action = self.sequence[*counter];
492 13 : *counter += 1;
493 13 : match action {
494 10 : ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
495 : ConnectAction::WakeFail => {
496 1 : let err = control_plane::errors::ControlPlaneError::Message(Box::new(
497 1 : ControlPlaneErrorMessage {
498 1 : http_status_code: StatusCode::BAD_REQUEST,
499 1 : error: "TEST".into(),
500 1 : status: None,
501 1 : },
502 1 : ));
503 1 : assert!(!err.could_retry());
504 1 : Err(control_plane::errors::WakeComputeError::ControlPlane(err))
505 : }
506 : ConnectAction::WakeRetry => {
507 2 : let err = control_plane::errors::ControlPlaneError::Message(Box::new(
508 2 : ControlPlaneErrorMessage {
509 2 : http_status_code: StatusCode::BAD_REQUEST,
510 2 : error: "TEST".into(),
511 2 : status: Some(Status {
512 2 : code: "error".into(),
513 2 : message: "error".into(),
514 2 : details: Details {
515 2 : error_info: None,
516 2 : retry_info: Some(control_plane::messages::RetryInfo {
517 2 : retry_delay_ms: 1,
518 2 : }),
519 2 : user_facing_message: None,
520 2 : },
521 2 : }),
522 2 : },
523 2 : ));
524 2 : assert!(err.could_retry());
525 2 : Err(control_plane::errors::WakeComputeError::ControlPlane(err))
526 : }
527 0 : x => panic!("expecting action {x:?}, wake_compute is called instead"),
528 : }
529 13 : }
530 :
531 0 : fn get_allowed_ips(&self) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
532 0 : unimplemented!("not used in tests")
533 : }
534 :
535 0 : fn get_allowed_vpc_endpoint_ids(
536 0 : &self,
537 0 : ) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
538 0 : unimplemented!("not used in tests")
539 : }
540 :
541 0 : fn get_block_public_or_vpc_access(
542 0 : &self,
543 0 : ) -> Result<control_plane::CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError>
544 0 : {
545 0 : unimplemented!("not used in tests")
546 : }
547 :
548 0 : fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient> {
549 0 : Box::new(self.clone())
550 0 : }
551 : }
552 :
553 10 : fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
554 10 : let node = NodeInfo {
555 10 : config: compute::ConnCfg::new("test".to_owned(), 5432),
556 10 : aux: MetricsAuxInfo {
557 10 : endpoint_id: (&EndpointId::from("endpoint")).into(),
558 10 : project_id: (&ProjectId::from("project")).into(),
559 10 : branch_id: (&BranchId::from("branch")).into(),
560 10 : compute_id: "compute".into(),
561 10 : cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
562 10 : },
563 10 : };
564 10 : let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
565 10 : node2.map(|()| node)
566 10 : }
567 :
568 7 : fn helper_create_connect_info(
569 7 : mechanism: &TestConnectMechanism,
570 7 : ) -> auth::Backend<'static, ComputeCredentials> {
571 7 : let user_info = auth::Backend::ControlPlane(
572 7 : MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
573 7 : ComputeCredentials {
574 7 : info: ComputeUserInfo {
575 7 : endpoint: "endpoint".into(),
576 7 : user: "user".into(),
577 7 : options: NeonOptions::parse_options_raw(""),
578 7 : },
579 7 : keys: ComputeCredentialKeys::Password("password".into()),
580 7 : },
581 7 : );
582 7 : user_info
583 7 : }
584 :
585 7 : fn config() -> ComputeConfig {
586 7 : let retry = RetryConfig {
587 7 : base_delay: Duration::from_secs(1),
588 7 : max_retries: 5,
589 7 : backoff_factor: 2.0,
590 7 : };
591 7 :
592 7 : ComputeConfig {
593 7 : retry,
594 7 : tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
595 7 : timeout: Duration::from_secs(2),
596 7 : }
597 7 : }
598 :
599 : #[tokio::test]
600 1 : async fn connect_to_compute_success() {
601 1 : let _ = env_logger::try_init();
602 1 : use ConnectAction::*;
603 1 : let ctx = RequestContext::test();
604 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
605 1 : let user_info = helper_create_connect_info(&mechanism);
606 1 : let config = config();
607 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
608 1 : .await
609 1 : .unwrap();
610 1 : mechanism.verify();
611 1 : }
612 :
613 : #[tokio::test]
614 1 : async fn connect_to_compute_retry() {
615 1 : let _ = env_logger::try_init();
616 1 : use ConnectAction::*;
617 1 : let ctx = RequestContext::test();
618 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
619 1 : let user_info = helper_create_connect_info(&mechanism);
620 1 : let config = config();
621 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
622 1 : .await
623 1 : .unwrap();
624 1 : mechanism.verify();
625 1 : }
626 :
627 : /// Test that we don't retry if the error is not retryable.
628 : #[tokio::test]
629 1 : async fn connect_to_compute_non_retry_1() {
630 1 : let _ = env_logger::try_init();
631 1 : use ConnectAction::*;
632 1 : let ctx = RequestContext::test();
633 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
634 1 : let user_info = helper_create_connect_info(&mechanism);
635 1 : let config = config();
636 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
637 1 : .await
638 1 : .unwrap_err();
639 1 : mechanism.verify();
640 1 : }
641 :
642 : /// Even for non-retryable errors, we should retry at least once.
643 : #[tokio::test]
644 1 : async fn connect_to_compute_non_retry_2() {
645 1 : let _ = env_logger::try_init();
646 1 : use ConnectAction::*;
647 1 : let ctx = RequestContext::test();
648 1 : let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
649 1 : let user_info = helper_create_connect_info(&mechanism);
650 1 : let config = config();
651 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
652 1 : .await
653 1 : .unwrap();
654 1 : mechanism.verify();
655 1 : }
656 :
657 : /// Retry for at most `NUM_RETRIES_CONNECT` times.
658 : #[tokio::test]
659 1 : async fn connect_to_compute_non_retry_3() {
660 1 : let _ = env_logger::try_init();
661 1 : tokio::time::pause();
662 1 : use ConnectAction::*;
663 1 : let ctx = RequestContext::test();
664 1 : let mechanism =
665 1 : TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
666 1 : let user_info = helper_create_connect_info(&mechanism);
667 1 : let wake_compute_retry_config = RetryConfig {
668 1 : base_delay: Duration::from_secs(1),
669 1 : max_retries: 1,
670 1 : backoff_factor: 2.0,
671 1 : };
672 1 : let config = config();
673 1 : connect_to_compute(
674 1 : &ctx,
675 1 : &mechanism,
676 1 : &user_info,
677 1 : wake_compute_retry_config,
678 1 : &config,
679 1 : )
680 1 : .await
681 1 : .unwrap_err();
682 1 : mechanism.verify();
683 1 : }
684 :
685 : /// Should retry wake compute.
686 : #[tokio::test]
687 1 : async fn wake_retry() {
688 1 : let _ = env_logger::try_init();
689 1 : use ConnectAction::*;
690 1 : let ctx = RequestContext::test();
691 1 : let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
692 1 : let user_info = helper_create_connect_info(&mechanism);
693 1 : let config = config();
694 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
695 1 : .await
696 1 : .unwrap();
697 1 : mechanism.verify();
698 1 : }
699 :
700 : /// Wake failed with a non-retryable error.
701 : #[tokio::test]
702 1 : async fn wake_non_retry() {
703 1 : let _ = env_logger::try_init();
704 1 : use ConnectAction::*;
705 1 : let ctx = RequestContext::test();
706 1 : let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
707 1 : let user_info = helper_create_connect_info(&mechanism);
708 1 : let config = config();
709 1 : connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
710 1 : .await
711 1 : .unwrap_err();
712 1 : mechanism.verify();
713 1 : }
|