Line data Source code
1 : use std::fmt::{Debug, Display};
2 :
3 : use futures::Future;
4 : use tokio_util::sync::CancellationToken;
5 :
/// Default `base_increment` that [`retry`] passes to [`exponential_backoff`]: the backoff for
/// attempt `n > 0` is `(1.0 + base_increment).powf(n)` seconds before capping.
pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
/// Default upper bound, in seconds, that [`retry`] passes to [`exponential_backoff`] as
/// `max_seconds`; a single backoff sleep never exceeds this value.
pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
8 :
9 410 : pub async fn exponential_backoff(
10 410 : n: u32,
11 410 : base_increment: f64,
12 410 : max_seconds: f64,
13 410 : cancel: &CancellationToken,
14 410 : ) {
15 410 : let backoff_duration_seconds =
16 410 : exponential_backoff_duration_seconds(n, base_increment, max_seconds);
17 410 : if backoff_duration_seconds > 0.0 {
18 1 : tracing::info!(
19 0 : "Backoff: waiting {backoff_duration_seconds} seconds before processing with the task",
20 : );
21 :
22 : drop(
23 1 : tokio::time::timeout(
24 1 : std::time::Duration::from_secs_f64(backoff_duration_seconds),
25 1 : cancel.cancelled(),
26 1 : )
27 1 : .await,
28 : )
29 409 : }
30 410 : }
31 :
/// Computes the backoff duration, in seconds, for attempt number `n`.
///
/// Attempt `0` never waits and yields `0.0`. Every later attempt yields
/// `(1.0 + base_increment)` raised to the power `n`, capped at `max_seconds`.
pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
    match n {
        0 => 0.0,
        attempt => {
            let growth_factor = 1.0 + base_increment;
            let uncapped_seconds = growth_factor.powf(f64::from(attempt));
            uncapped_seconds.min(max_seconds)
        }
    }
}
39 :
40 : /// Retries passed operation until one of the following conditions are met:
41 : /// - encountered error is considered as permanent (non-retryable)
42 : /// - retries have been exhausted
43 : /// - cancellation token has been cancelled
44 : ///
45 : /// `is_permanent` closure should be used to provide distinction between permanent/non-permanent
46 : /// errors. When attempts cross `warn_threshold` function starts to emit log warnings.
47 : /// `description` argument is added to log messages. Its value should identify the `op` is doing
48 : /// `cancel` cancels new attempts and the backoff sleep.
49 : ///
50 : /// If attempts fail, they are being logged with `{:#}` which works for anyhow, but does not work
51 : /// for any other error type. Final failed attempt is logged with `{:?}`.
52 : ///
53 : /// Returns `None` if cancellation was noticed during backoff or the terminal result.
54 419 : pub async fn retry<T, O, F, E>(
55 419 : mut op: O,
56 419 : is_permanent: impl Fn(&E) -> bool,
57 419 : warn_threshold: u32,
58 419 : max_retries: u32,
59 419 : description: &str,
60 419 : cancel: &CancellationToken,
61 419 : ) -> Option<Result<T, E>>
62 419 : where
63 419 : // Not std::error::Error because anyhow::Error doesnt implement it.
64 419 : // For context see https://github.com/dtolnay/anyhow/issues/63
65 419 : E: Display + Debug + 'static,
66 419 : O: FnMut() -> F,
67 419 : F: Future<Output = Result<T, E>>,
68 419 : {
69 419 : let mut attempts = 0;
70 : loop {
71 436 : if cancel.is_cancelled() {
72 0 : return None;
73 436 : }
74 :
75 2020 : let result = op().await;
76 1 : match &result {
77 : Ok(_) => {
78 403 : if attempts > 0 {
79 15 : tracing::info!("{description} succeeded after {attempts} retries");
80 388 : }
81 403 : return Some(result);
82 : }
83 :
84 : // These are "permanent" errors that should not be retried.
85 33 : Err(e) if is_permanent(e) => {
86 15 : return Some(result);
87 : }
88 : // Assume that any other failure might be transient, and the operation might
89 : // succeed if we just keep trying.
90 18 : Err(err) if attempts < warn_threshold => {
91 17 : tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
92 : }
93 1 : Err(err) if attempts < max_retries => {
94 0 : tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
95 : }
96 1 : Err(err) => {
97 1 : // Operation failed `max_attempts` times. Time to give up.
98 1 : tracing::warn!(
99 0 : "{description} still failed after {attempts} retries, giving up: {err:?}"
100 : );
101 1 : return Some(result);
102 : }
103 : }
104 : // sleep and retry
105 17 : exponential_backoff(
106 17 : attempts,
107 17 : DEFAULT_BASE_BACKOFF_SECONDS,
108 17 : DEFAULT_MAX_BACKOFF_SECONDS,
109 17 : cancel,
110 17 : )
111 1 : .await;
112 17 : attempts += 1;
113 : }
114 419 : }
115 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use tokio::sync::Mutex;

    /// With the default parameters the backoff never shrinks from one attempt to the next and
    /// eventually reaches the configured maximum.
    #[test]
    fn backoff_defaults_produce_growing_backoff_sequence() {
        let backoff_values: Vec<f64> = (0..10_000)
            .map(|attempt| {
                exponential_backoff_duration_seconds(
                    attempt,
                    DEFAULT_BASE_BACKOFF_SECONDS,
                    DEFAULT_MAX_BACKOFF_SECONDS,
                )
            })
            .collect();

        // `i` is the attempt number of the newer value in each neighbouring pair.
        for (i, pair) in (1_u32..).zip(backoff_values.windows(2)) {
            let (old_backoff_value, new_backoff_value) = (pair[0], pair[1]);
            assert!(
                old_backoff_value <= new_backoff_value,
                "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
            );
        }

        let last_backoff_value = backoff_values
            .last()
            .copied()
            .expect("Should have produced backoff values to compare");
        assert_eq!(
            last_backoff_value, DEFAULT_MAX_BACKOFF_SECONDS,
            "Given big enough of retries, backoff should reach its allowed max value"
        );
    }

    /// An operation that always fails with a retryable error is run twice when one retry is
    /// allowed, and the error is handed back to the caller.
    #[tokio::test(start_paused = true)]
    async fn retry_always_error() {
        let count = Mutex::new(0);
        let cancel = CancellationToken::new();

        let outcome = retry(
            || async {
                *count.lock().await += 1;
                Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| false,
            1,
            1,
            "work",
            &cancel,
        )
        .await;

        outcome
            .expect("not cancelled")
            .expect_err("it can only fail");

        assert_eq!(*count.lock().await, 2);
    }

    /// An operation that fails twice and then succeeds yields the success when two retries
    /// are allowed.
    #[tokio::test(start_paused = true)]
    async fn retry_ok_after_err() {
        let count = Mutex::new(0);
        let cancel = CancellationToken::new();

        let outcome = retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    return Ok(());
                }
                *locked += 1;
                Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| false,
            2,
            2,
            "work",
            &cancel,
        )
        .await;

        outcome
            .expect("not cancelled")
            .expect("success on second try");
    }

    /// An error classified as permanent is returned after a single invocation of the
    /// operation.
    #[tokio::test(start_paused = true)]
    async fn dont_retry_permanent_errors() {
        let count = Mutex::new(0);
        let cancel = CancellationToken::new();

        let outcome = retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    return Ok(());
                }
                *locked += 1;
                Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| true,
            2,
            2,
            "work",
            &cancel,
        )
        .await;

        let _ = outcome
            .expect("was not cancellation")
            .expect_err("it was permanent error");

        assert_eq!(*count.lock().await, 1);
    }
}
|