Line data Source code
1 : use std::fmt::{Debug, Display};
2 : use std::time::Duration;
3 :
4 : use futures::Future;
5 : use tokio_util::sync::CancellationToken;
6 :
/// Default `base_increment` that `retry` passes to `exponential_backoff`.
///
/// Despite the name this is not a delay by itself: the delay before the retry following
/// `n` failed attempts is `(1 + base_increment)^n` seconds (and zero for `n == 0`).
7 : pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
/// Default upper bound, in seconds, for a single backoff delay; `retry` passes it as
/// `max_seconds`.
8 : pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
9 :
10 1035 : pub async fn exponential_backoff(
11 1035 : n: u32,
12 1035 : base_increment: f64,
13 1035 : max_seconds: f64,
14 1035 : cancel: &CancellationToken,
15 1035 : ) {
16 1035 : let backoff_duration_seconds =
17 1035 : exponential_backoff_duration_seconds(n, base_increment, max_seconds);
18 1035 : if backoff_duration_seconds > 0.0 {
19 1 : tracing::info!(
20 0 : "Backoff: waiting {backoff_duration_seconds} seconds before processing with the task",
21 : );
22 :
23 : drop(
24 1 : tokio::time::timeout(
25 1 : std::time::Duration::from_secs_f64(backoff_duration_seconds),
26 1 : cancel.cancelled(),
27 1 : )
28 1 : .await,
29 : )
30 1034 : }
31 1035 : }
32 :
33 0 : pub fn exponential_backoff_duration(n: u32, base_increment: f64, max_seconds: f64) -> Duration {
34 0 : let seconds = exponential_backoff_duration_seconds(n, base_increment, max_seconds);
35 0 : Duration::from_secs_f64(seconds)
36 0 : }
37 :
/// Computes the backoff delay, in seconds, for attempt number `n`.
///
/// Attempt `0` never waits. For every later attempt the delay is
/// `(1 + base_increment)^n`, capped at `max_seconds`.
pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
    match n {
        0 => 0.0,
        attempt => {
            let growth_factor = 1.0 + base_increment;
            let uncapped = growth_factor.powf(f64::from(attempt));
            uncapped.min(max_seconds)
        }
    }
}
45 :
46 : /// Retries passed operation until one of the following conditions are met:
47 : /// - encountered error is considered as permanent (non-retryable)
48 : /// - retries have been exhausted
49 : /// - cancellation token has been cancelled
50 : ///
51 : /// `is_permanent` closure should be used to provide distinction between permanent/non-permanent
52 : /// errors. When attempts cross `warn_threshold` function starts to emit log warnings.
53 : /// `description` argument is added to log messages. Its value should identify the `op` is doing
54 : /// `cancel` cancels new attempts and the backoff sleep.
55 : ///
56 : /// If attempts fail, they are being logged with `{:#}` which works for anyhow, but does not work
57 : /// for any other error type. Final failed attempt is logged with `{:?}`.
58 : ///
59 : /// Returns `None` if cancellation was noticed during backoff or the terminal result.
60 2529 : pub async fn retry<T, O, F, E>(
61 2529 : mut op: O,
62 2529 : is_permanent: impl Fn(&E) -> bool,
63 2529 : warn_threshold: u32,
64 2529 : max_retries: u32,
65 2529 : description: &str,
66 2529 : cancel: &CancellationToken,
67 2529 : ) -> Option<Result<T, E>>
68 2529 : where
69 2529 : // Not std::error::Error because anyhow::Error doesnt implement it.
70 2529 : // For context see https://github.com/dtolnay/anyhow/issues/63
71 2529 : E: Display + Debug + 'static,
72 2529 : O: FnMut() -> F,
73 2529 : F: Future<Output = Result<T, E>>,
74 2529 : {
75 2529 : let mut attempts = 0;
76 : loop {
77 2546 : if cancel.is_cancelled() {
78 0 : return None;
79 2546 : }
80 :
81 2546 : let result = op().await;
82 1 : match &result {
83 : Ok(_) => {
84 1167 : if attempts > 0 {
85 15 : tracing::info!("{description} succeeded after {attempts} retries");
86 1152 : }
87 1167 : return Some(result);
88 : }
89 :
90 : // These are "permanent" errors that should not be retried.
91 1379 : Err(e) if is_permanent(e) => {
92 1361 : return Some(result);
93 : }
94 : // Assume that any other failure might be transient, and the operation might
95 : // succeed if we just keep trying.
96 18 : Err(err) if attempts < warn_threshold => {
97 17 : tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
98 : }
99 1 : Err(err) if attempts < max_retries => {
100 0 : tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
101 : }
102 1 : Err(err) => {
103 1 : // Operation failed `max_attempts` times. Time to give up.
104 1 : tracing::warn!(
105 0 : "{description} still failed after {attempts} retries, giving up: {err:?}"
106 : );
107 1 : return Some(result);
108 : }
109 : }
110 : // sleep and retry
111 17 : exponential_backoff(
112 17 : attempts,
113 17 : DEFAULT_BASE_BACKOFF_SECONDS,
114 17 : DEFAULT_MAX_BACKOFF_SECONDS,
115 17 : cancel,
116 17 : )
117 17 : .await;
118 17 : attempts += 1;
119 : }
120 2529 : }
121 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use tokio::sync::Mutex;

    /// With the default parameters the delays must never shrink from one attempt to the
    /// next, and must eventually settle on the configured maximum.
    #[test]
    fn backoff_defaults_produce_growing_backoff_sequence() {
        let backoff_values: Vec<f64> = (0..10_000)
            .map(|attempt| {
                exponential_backoff_duration_seconds(
                    attempt,
                    DEFAULT_BASE_BACKOFF_SECONDS,
                    DEFAULT_MAX_BACKOFF_SECONDS,
                )
            })
            .collect();

        for (index, pair) in backoff_values.windows(2).enumerate() {
            // `i` is the attempt number of the newer value in the pair.
            let i = index + 1;
            let (old_backoff_value, new_backoff_value) = (pair[0], pair[1]);
            assert!(
                old_backoff_value <= new_backoff_value,
                "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
            );
        }

        assert_eq!(
            backoff_values
                .last()
                .copied()
                .expect("Should have produced backoff values to compare"),
            DEFAULT_MAX_BACKOFF_SECONDS,
            "Given big enough of retries, backoff should reach its allowed max value"
        );
    }

    /// An operation that always fails with a retryable error runs once plus one retry
    /// when `max_retries` is 1, and the last error is handed back.
    #[tokio::test(start_paused = true)]
    async fn retry_always_error() {
        let invocations = Mutex::new(0);
        let always_failing = || async {
            *invocations.lock().await += 1;
            Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
        };

        retry(
            always_failing,
            |_e| false,
            1,
            1,
            "work",
            &CancellationToken::new(),
        )
        .await
        .expect("not cancelled")
        .expect_err("it can only fail");

        assert_eq!(*invocations.lock().await, 2);
    }

    /// An operation that fails twice and then succeeds yields `Ok` within two retries.
    #[tokio::test(start_paused = true)]
    async fn retry_ok_after_err() {
        let failures = Mutex::new(0);
        let fails_twice_then_succeeds = || async {
            let mut seen = failures.lock().await;
            if *seen <= 1 {
                *seen += 1;
                Err(io::Error::from(io::ErrorKind::Other))
            } else {
                Ok(())
            }
        };

        retry(
            fails_twice_then_succeeds,
            |_e| false,
            2,
            2,
            "work",
            &CancellationToken::new(),
        )
        .await
        .expect("not cancelled")
        .expect("success on second try");
    }

    /// When every error is classified as permanent, the operation runs exactly once.
    #[tokio::test(start_paused = true)]
    async fn dont_retry_permanent_errors() {
        let failures = Mutex::new(0);
        let fails_twice_then_succeeds = || async {
            let mut seen = failures.lock().await;
            if *seen <= 1 {
                *seen += 1;
                Err(io::Error::from(io::ErrorKind::Other))
            } else {
                Ok(())
            }
        };

        let _ = retry(
            fails_twice_then_succeeds,
            |_e| true,
            2,
            2,
            "work",
            &CancellationToken::new(),
        )
        .await
        .expect("was not cancellation")
        .expect_err("it was permanent error");

        assert_eq!(*failures.lock().await, 1);
    }
}
|