Line data Source code
1 : use std::fmt::{Debug, Display};
2 : use std::time::Duration;
3 :
4 : use futures::Future;
5 : use tokio_util::sync::CancellationToken;
6 :
/// Default `base_increment` that `retry` passes to `exponential_backoff`.
/// The delay for attempt `n > 0` is computed as `(1.0 + base_increment)^n` seconds,
/// so this value controls the growth rate rather than being a delay by itself.
7 : pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
/// Default upper bound, in seconds, that `retry` applies to a single backoff delay.
8 : pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
9 :
10 1061 : pub async fn exponential_backoff(
11 1061 : n: u32,
12 1061 : base_increment: f64,
13 1061 : max_seconds: f64,
14 1061 : cancel: &CancellationToken,
15 1061 : ) {
16 1061 : let backoff_duration_seconds =
17 1061 : exponential_backoff_duration_seconds(n, base_increment, max_seconds);
18 1061 : if backoff_duration_seconds > 0.0 {
19 1 : tracing::info!(
20 0 : "Backoff: waiting {backoff_duration_seconds} seconds before processing with the task",
21 : );
22 :
23 : drop(
24 1 : tokio::time::timeout(
25 1 : std::time::Duration::from_secs_f64(backoff_duration_seconds),
26 1 : cancel.cancelled(),
27 1 : )
28 1 : .await,
29 : )
30 2 : }
31 3 : }
32 :
33 0 : pub fn exponential_backoff_duration(n: u32, base_increment: f64, max_seconds: f64) -> Duration {
34 0 : let seconds = exponential_backoff_duration_seconds(n, base_increment, max_seconds);
35 0 : Duration::from_secs_f64(seconds)
36 0 : }
37 :
/// Computes the backoff delay, in seconds, for attempt number `n`.
///
/// * `n == 0` always yields `0.0`, i.e. no delay.
/// * Otherwise the delay is `(1.0 + base_increment)^n`, capped at `max_seconds`.
pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
    match n {
        0 => 0.0,
        attempt => {
            let growth_factor = 1.0 + base_increment;
            let uncapped_seconds = growth_factor.powf(f64::from(attempt));
            uncapped_seconds.min(max_seconds)
        }
    }
}
45 :
46 : /// Retries passed operation until one of the following conditions are met:
47 : /// - encountered error is considered as permanent (non-retryable)
48 : /// - retries have been exhausted
49 : /// - cancellation token has been cancelled
50 : ///
51 : /// `is_permanent` closure should be used to provide distinction between permanent/non-permanent
52 : /// errors. When attempts cross `warn_threshold` function starts to emit log warnings.
53 : /// `description` argument is added to log messages. Its value should identify the `op` is doing
54 : /// `cancel` cancels new attempts and the backoff sleep.
55 : ///
56 : /// If attempts fail, they are being logged with `{:#}` which works for anyhow, but does not work
57 : /// for any other error type. Final failed attempt is logged with `{:?}`.
58 : ///
59 : /// Returns `None` if cancellation was noticed during backoff or the terminal result.
60 2569 : pub async fn retry<T, O, F, E>(
61 2569 : mut op: O,
62 2569 : is_permanent: impl Fn(&E) -> bool,
63 2569 : warn_threshold: u32,
64 2569 : max_retries: u32,
65 2569 : description: &str,
66 2569 : cancel: &CancellationToken,
67 2569 : ) -> Option<Result<T, E>>
68 2569 : where
69 2569 : // Not std::error::Error because anyhow::Error doesnt implement it.
70 2569 : // For context see https://github.com/dtolnay/anyhow/issues/63
71 2569 : E: Display + Debug + 'static,
72 2569 : O: FnMut() -> F,
73 2569 : F: Future<Output = Result<T, E>>,
74 2569 : {
75 2569 : let mut attempts = 0;
76 : loop {
77 2588 : if cancel.is_cancelled() {
78 0 : return None;
79 6 : }
80 :
81 2588 : let result = op().await;
82 1 : match &result {
83 : Ok(_) => {
84 1183 : if attempts > 0 {
85 17 : tracing::info!("{description} succeeded after {attempts} retries");
86 0 : }
87 1183 : return Some(result);
88 : }
89 :
90 : // These are "permanent" errors that should not be retried.
91 1405 : Err(e) if is_permanent(e) => {
92 1385 : return Some(result);
93 : }
94 : // Assume that any other failure might be transient, and the operation might
95 : // succeed if we just keep trying.
96 20 : Err(err) if attempts < warn_threshold => {
97 19 : tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
98 : }
99 1 : Err(err) if attempts < max_retries => {
100 0 : tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
101 : }
102 1 : Err(err) => {
103 1 : // Operation failed `max_attempts` times. Time to give up.
104 1 : tracing::warn!(
105 0 : "{description} still failed after {attempts} retries, giving up: {err:?}"
106 : );
107 1 : return Some(result);
108 : }
109 : }
110 : // sleep and retry
111 19 : exponential_backoff(
112 19 : attempts,
113 19 : DEFAULT_BASE_BACKOFF_SECONDS,
114 19 : DEFAULT_MAX_BACKOFF_SECONDS,
115 19 : cancel,
116 19 : )
117 19 : .await;
118 19 : attempts += 1;
119 : }
120 3 : }
121 :
#[cfg(test)]
mod tests {
    use std::io;

    use tokio::sync::Mutex;

    use super::*;

    /// Test operation that fails until it has recorded two failures in `count`,
    /// and succeeds on every call after that.
    async fn fail_until_two_failures(count: &Mutex<i32>) -> Result<(), io::Error> {
        let mut failures = count.lock().await;
        if *failures <= 1 {
            *failures += 1;
            return Err(io::Error::from(io::ErrorKind::Other));
        }
        Ok(())
    }

    /// With the default parameters the backoff never shrinks as the attempt number
    /// grows, and eventually saturates at the configured maximum.
    #[test]
    fn backoff_defaults_produce_growing_backoff_sequence() {
        let backoff_values: Vec<f64> = (0..10_000)
            .map(|n| {
                exponential_backoff_duration_seconds(
                    n,
                    DEFAULT_BASE_BACKOFF_SECONDS,
                    DEFAULT_MAX_BACKOFF_SECONDS,
                )
            })
            .collect();

        for (i, new_backoff_value) in backoff_values.iter().copied().enumerate().skip(1) {
            let old_backoff_value = backoff_values[i - 1];
            assert!(
                old_backoff_value <= new_backoff_value,
                "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
            );
        }

        assert_eq!(
            backoff_values
                .last()
                .copied()
                .expect("Should have produced backoff values to compare"),
            DEFAULT_MAX_BACKOFF_SECONDS,
            "Given big enough of retries, backoff should reach its allowed max value"
        );
    }

    /// An operation that always fails is attempted once plus `max_retries` times.
    #[tokio::test(start_paused = true)]
    async fn retry_always_error() {
        let count = Mutex::new(0);
        let outcome = retry(
            || async {
                let mut calls = count.lock().await;
                *calls += 1;
                Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| false,
            1,
            1,
            "work",
            &CancellationToken::new(),
        )
        .await;

        outcome
            .expect("not cancelled")
            .expect_err("it can only fail");

        assert_eq!(*count.lock().await, 2);
    }

    /// Transient failures are retried until the operation eventually succeeds.
    #[tokio::test(start_paused = true)]
    async fn retry_ok_after_err() {
        let count = Mutex::new(0);
        let outcome = retry(
            || fail_until_two_failures(&count),
            |_e| false,
            2,
            2,
            "work",
            &CancellationToken::new(),
        )
        .await;

        outcome
            .expect("not cancelled")
            .expect("success on second try");
    }

    /// An error classified as permanent is returned after a single attempt.
    #[tokio::test(start_paused = true)]
    async fn dont_retry_permanent_errors() {
        let count = Mutex::new(0);
        let outcome = retry(
            || fail_until_two_failures(&count),
            |_e| true,
            2,
            2,
            "work",
            &CancellationToken::new(),
        )
        .await;

        let _ = outcome
            .expect("was not cancellation")
            .expect_err("it was permanent error");

        assert_eq!(*count.lock().await, 1);
    }
}
|