LCOV - code coverage report
Current view: top level - libs/utils/src - backoff.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 93.2 % 147 137
Test Date: 2023-09-06 10:18:01 Functions: 65.6 % 122 80

            Line data    Source code
       1              : use std::fmt::{Debug, Display};
       2              : 
       3              : use futures::Future;
       4              : use tokio_util::sync::CancellationToken;
       5              : 
       6              : pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
       7              : pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
       8              : 
       9         6199 : pub async fn exponential_backoff(
      10         6199 :     n: u32,
      11         6199 :     base_increment: f64,
      12         6199 :     max_seconds: f64,
      13         6199 :     cancel: &CancellationToken,
      14         6199 : ) {
      15         6199 :     let backoff_duration_seconds =
      16         6199 :         exponential_backoff_duration_seconds(n, base_increment, max_seconds);
      17         6199 :     if backoff_duration_seconds > 0.0 {
      18            8 :         tracing::info!(
      19            8 :             "Backoff: waiting {backoff_duration_seconds} seconds before processing with the task",
      20            8 :         );
      21              : 
      22              :         drop(
      23            9 :             tokio::time::timeout(
      24            9 :                 std::time::Duration::from_secs_f64(backoff_duration_seconds),
      25            9 :                 cancel.cancelled(),
      26            9 :             )
      27            7 :             .await,
      28              :         )
      29         6190 :     }
      30         6198 : }
      31              : 
      32        16199 : pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
      33        16199 :     if n == 0 {
      34         6191 :         0.0
      35              :     } else {
      36        10008 :         (1.0 + base_increment).powf(f64::from(n)).min(max_seconds)
      37              :     }
      38        16199 : }
      39              : 
      40              : /// Configure cancellation for a retried operation: when to cancel (the token), and
      41              : /// what kind of error to return on cancellation
      42              : pub struct Cancel<E, CF>
      43              : where
      44              :     E: Display + Debug + 'static,
      45              :     CF: Fn() -> E,
      46              : {
      47              :     token: CancellationToken,
      48              :     on_cancel: CF,
      49              : }
      50              : 
      51              : impl<E, CF> Cancel<E, CF>
      52              : where
      53              :     E: Display + Debug + 'static,
      54              :     CF: Fn() -> E,
      55              : {
      56         2450 :     pub fn new(token: CancellationToken, on_cancel: CF) -> Self {
      57         2450 :         Self { token, on_cancel }
      58         2450 :     }
      59              : }
      60              : 
      61              : /// retries passed operation until one of the following conditions are met:
      62              : /// Encountered error is considered as permanent (non-retryable)
      63              : /// Retries have been exhausted.
      64              : /// `is_permanent` closure should be used to provide distinction between permanent/non-permanent errors
      65              : /// When attempts cross `warn_threshold` function starts to emit log warnings.
      66              : /// `description` argument is added to log messages. Its value should identify the `op` is doing
      67              : /// `cancel` argument is required: any time we are looping on retry, we should be using a CancellationToken
      68              : /// to drop out promptly on shutdown.
      69         2450 : pub async fn retry<T, O, F, E, CF>(
      70         2450 :     mut op: O,
      71         2450 :     is_permanent: impl Fn(&E) -> bool,
      72         2450 :     warn_threshold: u32,
      73         2450 :     max_retries: u32,
      74         2450 :     description: &str,
      75         2450 :     cancel: Cancel<E, CF>,
      76         2450 : ) -> Result<T, E>
      77         2450 : where
      78         2450 :     // Not std::error::Error because anyhow::Error doesnt implement it.
      79         2450 :     // For context see https://github.com/dtolnay/anyhow/issues/63
      80         2450 :     E: Display + Debug + 'static,
      81         2450 :     O: FnMut() -> F,
      82         2450 :     F: Future<Output = Result<T, E>>,
      83         2450 :     CF: Fn() -> E,
      84         2450 : {
      85         2450 :     let mut attempts = 0;
      86              :     loop {
      87         2824 :         if cancel.token.is_cancelled() {
      88            0 :             return Err((cancel.on_cancel)());
      89         2824 :         }
      90              : 
      91       362365 :         let result = op().await;
      92            1 :         match result {
      93              :             Ok(_) => {
      94         2063 :                 if attempts > 0 {
      95          317 :                     tracing::info!("{description} succeeded after {attempts} retries");
      96         1745 :                 }
      97         2063 :                 return result;
      98              :             }
      99              : 
     100              :             // These are "permanent" errors that should not be retried.
     101          759 :             Err(ref e) if is_permanent(e) => {
     102          384 :                 return result;
     103              :             }
     104              :             // Assume that any other failure might be transient, and the operation might
     105              :             // succeed if we just keep trying.
     106          375 :             Err(err) if attempts < warn_threshold => {
     107          371 :                 tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
     108              :             }
     109            1 :             Err(err) if attempts < max_retries => {
     110            0 :                 tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
     111              :             }
     112            1 :             Err(ref err) => {
     113              :                 // Operation failed `max_attempts` times. Time to give up.
     114            0 :                 tracing::warn!(
     115            0 :                     "{description} still failed after {attempts} retries, giving up: {err:?}"
     116            0 :                 );
     117            1 :                 return result;
     118              :             }
     119              :         }
     120              :         // sleep and retry
     121          374 :         exponential_backoff(
     122          374 :             attempts,
     123          374 :             DEFAULT_BASE_BACKOFF_SECONDS,
     124          374 :             DEFAULT_MAX_BACKOFF_SECONDS,
     125          374 :             &cancel.token,
     126          374 :         )
     127            1 :         .await;
     128          374 :         attempts += 1;
     129              :     }
     130         2448 : }
     131              : 
     132              : #[cfg(test)]
     133              : mod tests {
     134              :     use std::io;
     135              : 
     136              :     use tokio::sync::Mutex;
     137              : 
     138              :     use super::*;
     139              : 
     140            1 :     #[test]
     141            1 :     fn backoff_defaults_produce_growing_backoff_sequence() {
     142            1 :         let mut current_backoff_value = None;
     143              : 
     144        10001 :         for i in 0..10_000 {
     145        10000 :             let new_backoff_value = exponential_backoff_duration_seconds(
     146        10000 :                 i,
     147        10000 :                 DEFAULT_BASE_BACKOFF_SECONDS,
     148        10000 :                 DEFAULT_MAX_BACKOFF_SECONDS,
     149        10000 :             );
     150              : 
     151        10000 :             if let Some(old_backoff_value) = current_backoff_value.replace(new_backoff_value) {
     152         9999 :                 assert!(
     153         9999 :                     old_backoff_value <= new_backoff_value,
     154            0 :                     "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
     155              :                 )
     156            1 :             }
     157              :         }
     158              : 
     159            1 :         assert_eq!(
     160            1 :             current_backoff_value.expect("Should have produced backoff values to compare"),
     161              :             DEFAULT_MAX_BACKOFF_SECONDS,
     162            0 :             "Given big enough of retries, backoff should reach its allowed max value"
     163              :         );
     164            1 :     }
     165              : 
     166            1 :     #[tokio::test(start_paused = true)]
     167            1 :     async fn retry_always_error() {
     168            1 :         let count = Mutex::new(0);
     169            1 :         let err_result = retry(
     170            2 :             || async {
     171            2 :                 *count.lock().await += 1;
     172            2 :                 Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
     173            2 :             },
     174            2 :             |_e| false,
     175            1 :             1,
     176            1 :             1,
     177            1 :             "work",
     178            1 :             Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
     179            1 :         )
     180            0 :         .await;
     181              : 
     182            1 :         assert!(err_result.is_err());
     183              : 
     184            1 :         assert_eq!(*count.lock().await, 2);
     185              :     }
     186              : 
     187            1 :     #[tokio::test(start_paused = true)]
     188            1 :     async fn retry_ok_after_err() {
     189            1 :         let count = Mutex::new(0);
     190            1 :         retry(
     191            3 :             || async {
     192            3 :                 let mut locked = count.lock().await;
     193            3 :                 if *locked > 1 {
     194            1 :                     Ok(())
     195              :                 } else {
     196            2 :                     *locked += 1;
     197            2 :                     Err(io::Error::from(io::ErrorKind::Other))
     198              :                 }
     199            3 :             },
     200            2 :             |_e| false,
     201            1 :             2,
     202            1 :             2,
     203            1 :             "work",
     204            1 :             Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
     205            1 :         )
     206            1 :         .await
     207            1 :         .unwrap();
     208              :     }
     209              : 
     210            1 :     #[tokio::test(start_paused = true)]
     211            1 :     async fn dont_retry_permanent_errors() {
     212            1 :         let count = Mutex::new(0);
     213            1 :         let _ = retry(
     214            1 :             || async {
     215            1 :                 let mut locked = count.lock().await;
     216            1 :                 if *locked > 1 {
     217            0 :                     Ok(())
     218              :                 } else {
     219            1 :                     *locked += 1;
     220            1 :                     Err(io::Error::from(io::ErrorKind::Other))
     221              :                 }
     222            1 :             },
     223            1 :             |_e| true,
     224            1 :             2,
     225            1 :             2,
     226            1 :             "work",
     227            1 :             Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
     228            1 :         )
     229            0 :         .await
     230            1 :         .unwrap_err();
     231              : 
     232            1 :         assert_eq!(*count.lock().await, 1);
     233              :     }
     234              : }
        

Generated by: LCOV version 2.1-beta