|             Line data    Source code 
       1              : use std::fmt::{Debug, Display};
       2              : use std::time::Duration;
       3              : 
       4              : use futures::Future;
       5              : use tokio_util::sync::CancellationToken;
       6              : 
       7              : pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
       8              : pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
       9              : 
      10          589 : pub async fn exponential_backoff(
      11          589 :     n: u32,
      12          589 :     base_increment: f64,
      13          589 :     max_seconds: f64,
      14          589 :     cancel: &CancellationToken,
      15          589 : ) {
      16          589 :     let backoff_duration_seconds =
      17          589 :         exponential_backoff_duration_seconds(n, base_increment, max_seconds);
      18          589 :     if backoff_duration_seconds > 0.0 {
      19            1 :         tracing::info!(
      20            0 :             "Backoff: waiting {backoff_duration_seconds} seconds before processing with the task",
      21              :         );
      22              : 
      23            1 :         drop(
      24            1 :             tokio::time::timeout(
      25            1 :                 std::time::Duration::from_secs_f64(backoff_duration_seconds),
      26            1 :                 cancel.cancelled(),
      27            1 :             )
      28            1 :             .await,
      29              :         )
      30          588 :     }
      31          589 : }
      32              : 
      33            0 : pub fn exponential_backoff_duration(n: u32, base_increment: f64, max_seconds: f64) -> Duration {
      34            0 :     let seconds = exponential_backoff_duration_seconds(n, base_increment, max_seconds);
      35            0 :     Duration::from_secs_f64(seconds)
      36            0 : }
      37              : 
      38        10634 : pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
      39        10634 :     if n == 0 {
      40          589 :         0.0
      41              :     } else {
      42        10045 :         (1.0 + base_increment).powf(f64::from(n)).min(max_seconds)
      43              :     }
      44        10634 : }
      45              : 
      46              : /// Retries passed operation until one of the following conditions are met:
      47              : /// - encountered error is considered as permanent (non-retryable)
      48              : /// - retries have been exhausted
      49              : /// - cancellation token has been cancelled
      50              : ///
      51              : /// `is_permanent` closure should be used to provide distinction between permanent/non-permanent
      52              : /// errors. When attempts cross `warn_threshold` function starts to emit log warnings.
      53              : /// `description` argument is added to log messages. Its value should identify the `op` is doing
      54              : /// `cancel` cancels new attempts and the backoff sleep.
      55              : ///
      56              : /// If attempts fail, they are being logged with `{:#}` which works for anyhow, but does not work
      57              : /// for any other error type. Final failed attempt is logged with `{:?}`.
      58              : ///
      59              : /// Returns `None` if cancellation was noticed during backoff or the terminal result.
      60          957 : pub async fn retry<T, O, F, E>(
      61          957 :     mut op: O,
      62          957 :     is_permanent: impl Fn(&E) -> bool,
      63          957 :     warn_threshold: u32,
      64          957 :     max_retries: u32,
      65          957 :     description: &str,
      66          957 :     cancel: &CancellationToken,
      67          957 : ) -> Option<Result<T, E>>
      68          957 : where
      69          957 :     // Not std::error::Error because anyhow::Error doesnt implement it.
      70          957 :     // For context see https://github.com/dtolnay/anyhow/issues/63
      71          957 :     E: Display + Debug + 'static,
      72          957 :     O: FnMut() -> F,
      73          957 :     F: Future<Output = Result<T, E>>,
      74            3 : {
      75          957 :     let mut attempts = 0;
      76              :     loop {
      77          975 :         if cancel.is_cancelled() {
      78            0 :             return None;
      79            6 :         }
      80              : 
      81          975 :         let result = op().await;
      82            1 :         match &result {
      83              :             Ok(_) => {
      84          585 :                 if attempts > 0 {
      85           16 :                     tracing::info!("{description} succeeded after {attempts} retries");
      86            0 :                 }
      87          585 :                 return Some(result);
      88              :             }
      89              : 
      90              :             // These are "permanent" errors that should not be retried.
      91          390 :             Err(e) if is_permanent(e) => {
      92          371 :                 return Some(result);
      93              :             }
      94              :             // Assume that any other failure might be transient, and the operation might
      95              :             // succeed if we just keep trying.
      96           19 :             Err(err) if attempts < warn_threshold => {
      97           18 :                 tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
      98              :             }
      99            1 :             Err(err) if attempts < max_retries => {
     100            0 :                 tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
     101              :             }
     102            1 :             Err(err) => {
     103              :                 // Operation failed `max_attempts` times. Time to give up.
     104            1 :                 tracing::warn!(
     105            0 :                     "{description} still failed after {attempts} retries, giving up: {err:?}"
     106              :                 );
     107            1 :                 return Some(result);
     108              :             }
     109              :         }
     110              :         // sleep and retry
     111           18 :         exponential_backoff(
     112           18 :             attempts,
     113           18 :             DEFAULT_BASE_BACKOFF_SECONDS,
     114           18 :             DEFAULT_MAX_BACKOFF_SECONDS,
     115           18 :             cancel,
     116           18 :         )
     117           18 :         .await;
     118           18 :         attempts += 1;
     119              :     }
     120            3 : }
     121              : 
     122              : #[cfg(test)]
     123              : mod tests {
     124              :     use std::io;
     125              : 
     126              :     use tokio::sync::Mutex;
     127              : 
     128              :     use super::*;
     129              : 
     130              :     #[test]
     131            1 :     fn backoff_defaults_produce_growing_backoff_sequence() {
     132            1 :         let mut current_backoff_value = None;
     133              : 
     134        10001 :         for i in 0..10_000 {
     135        10000 :             let new_backoff_value = exponential_backoff_duration_seconds(
     136        10000 :                 i,
     137              :                 DEFAULT_BASE_BACKOFF_SECONDS,
     138              :                 DEFAULT_MAX_BACKOFF_SECONDS,
     139              :             );
     140              : 
     141        10000 :             if let Some(old_backoff_value) = current_backoff_value.replace(new_backoff_value) {
     142         9999 :                 assert!(
     143         9999 :                     old_backoff_value <= new_backoff_value,
     144            0 :                     "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
     145              :                 )
     146            1 :             }
     147              :         }
     148              : 
     149            1 :         assert_eq!(
     150            1 :             current_backoff_value.expect("Should have produced backoff values to compare"),
     151              :             DEFAULT_MAX_BACKOFF_SECONDS,
     152            0 :             "Given big enough of retries, backoff should reach its allowed max value"
     153              :         );
     154            1 :     }
     155              : 
     156              :     #[tokio::test(start_paused = true)]
     157            1 :     async fn retry_always_error() {
     158            1 :         let count = Mutex::new(0);
     159            1 :         retry(
     160            2 :             || async {
     161            2 :                 *count.lock().await += 1;
     162            2 :                 Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
     163            4 :             },
     164              :             |_e| false,
     165              :             1,
     166              :             1,
     167            1 :             "work",
     168            1 :             &CancellationToken::new(),
     169              :         )
     170            1 :         .await
     171            1 :         .expect("not cancelled")
     172            1 :         .expect_err("it can only fail");
     173              : 
     174            1 :         assert_eq!(*count.lock().await, 2);
     175            1 :     }
     176              : 
     177              :     #[tokio::test(start_paused = true)]
     178            1 :     async fn retry_ok_after_err() {
     179            1 :         let count = Mutex::new(0);
     180            1 :         retry(
     181            3 :             || async {
     182            3 :                 let mut locked = count.lock().await;
     183            3 :                 if *locked > 1 {
     184            1 :                     Ok(())
     185            1 :                 } else {
     186            2 :                     *locked += 1;
     187            2 :                     Err(io::Error::from(io::ErrorKind::Other))
     188            1 :                 }
     189            6 :             },
     190            1 :             |_e| false,
     191            1 :             2,
     192            1 :             2,
     193            1 :             "work",
     194            1 :             &CancellationToken::new(),
     195            1 :         )
     196            1 :         .await
     197            1 :         .expect("not cancelled")
     198            1 :         .expect("success on second try");
     199            1 :     }
     200              : 
     201              :     #[tokio::test(start_paused = true)]
     202            1 :     async fn dont_retry_permanent_errors() {
     203            1 :         let count = Mutex::new(0);
     204            1 :         let _ = retry(
     205            1 :             || async {
     206            1 :                 let mut locked = count.lock().await;
     207            1 :                 if *locked > 1 {
     208            0 :                     Ok(())
     209              :                 } else {
     210            1 :                     *locked += 1;
     211            1 :                     Err(io::Error::from(io::ErrorKind::Other))
     212              :                 }
     213            2 :             },
     214              :             |_e| true,
     215              :             2,
     216              :             2,
     217            1 :             "work",
     218            1 :             &CancellationToken::new(),
     219              :         )
     220            1 :         .await
     221            1 :         .expect("was not cancellation")
     222            1 :         .expect_err("it was permanent error");
     223              : 
     224            1 :         assert_eq!(*count.lock().await, 1);
     225            1 :     }
     226              : }
         |