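//! Exponential backoff and retry helpers for fallible async operations.
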
use std::fmt::{Debug, Display};

use futures::Future;
use tokio_util::sync::CancellationToken;

pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;

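/// Waits out the backoff period computed by [`exponential_backoff_duration_seconds`]
/// for attempt `n`, returning early if `cancel` fires first.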
pub async fn exponential_backoff(
    n: u32,
    base_increment: f64,
    max_seconds: f64,
    cancel: &CancellationToken,
) {
    let backoff_duration_seconds =
        exponential_backoff_duration_seconds(n, base_increment, max_seconds);
    if backoff_duration_seconds > 0.0 {
        tracing::info!(
            "Backoff: waiting {backoff_duration_seconds} seconds before proceeding with the task",
        );

        drop(
            tokio::time::timeout(
                std::time::Duration::from_secs_f64(backoff_duration_seconds),
                cancel.cancelled(),
            )
            .await,
        )
    }
}

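/// Computes the backoff duration for attempt `n`: zero for the first attempt
/// (`n == 0`), otherwise `(1 + base_increment)^n`, capped at `max_seconds`.
/// With the default constants this yields roughly 0, 1.1, 1.21, 1.33, ... seconds,
/// first hitting the 3.0-second cap at `n == 12`.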
pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
    if n == 0 {
        0.0
    } else {
        (1.0 + base_increment).powf(f64::from(n)).min(max_seconds)
    }
}

/// Configures cancellation for a retried operation: when to cancel (the token),
/// and what kind of error to return on cancellation.
pub struct Cancel<E, CF>
where
    E: Display + Debug + 'static,
    CF: Fn() -> E,
{
    token: CancellationToken,
    on_cancel: CF,
}

impl<E, CF> Cancel<E, CF>
where
    E: Display + Debug + 'static,
    CF: Fn() -> E,
{
    pub fn new(token: CancellationToken, on_cancel: CF) -> Self {
        Self { token, on_cancel }
    }
}

/// Retries the passed operation until one of the following conditions is met:
/// - the encountered error is considered permanent (non-retryable), or
/// - the retries have been exhausted.
///
/// The `is_permanent` closure distinguishes permanent (non-retryable) errors from
/// transient ones. Once the attempt count crosses `warn_threshold`, the function
/// starts emitting log warnings. The `description` argument is added to log
/// messages; its value should identify what the `op` is doing. The `cancel`
/// argument is required: any time we are looping on retry, we should be using a
/// `CancellationToken` to drop out promptly on shutdown.
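///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because it assumes a surrounding tokio
/// runtime and a hypothetical fallible async `fetch` operation not defined here):
///
/// ```ignore
/// let token = CancellationToken::new();
/// let result = retry(
///     || fetch(),                                          // hypothetical async operation
///     |e: &io::Error| e.kind() == io::ErrorKind::NotFound, // treat NotFound as permanent
///     3,  // warn_threshold: switch from info to warn logs after the first 3 attempts
///     10, // max_retries: give up once 10 retries are exhausted
///     "fetching remote object",
///     Cancel::new(token, || io::Error::from(io::ErrorKind::Interrupted)),
/// )
/// .await;
/// ```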
pub async fn retry<T, O, F, E, CF>(
    mut op: O,
    is_permanent: impl Fn(&E) -> bool,
    warn_threshold: u32,
    max_retries: u32,
    description: &str,
    cancel: Cancel<E, CF>,
) -> Result<T, E>
where
    // Not std::error::Error because anyhow::Error doesn't implement it.
    // For context see https://github.com/dtolnay/anyhow/issues/63
    E: Display + Debug + 'static,
    O: FnMut() -> F,
    F: Future<Output = Result<T, E>>,
    CF: Fn() -> E,
{
    let mut attempts = 0;
    loop {
        if cancel.token.is_cancelled() {
            return Err((cancel.on_cancel)());
        }

        let result = op().await;
        match result {
            Ok(_) => {
                if attempts > 0 {
                    tracing::info!("{description} succeeded after {attempts} retries");
                }
                return result;
            }

            // These are "permanent" errors that should not be retried.
            Err(ref e) if is_permanent(e) => {
                return result;
            }
            // Assume that any other failure might be transient, and the operation might
            // succeed if we just keep trying.
            Err(err) if attempts < warn_threshold => {
                tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
            }
            Err(err) if attempts < max_retries => {
                tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
            }
            Err(ref err) => {
                // Operation failed `max_retries` times. Time to give up.
                tracing::warn!(
                    "{description} still failed after {attempts} retries, giving up: {err:?}"
                );
                return result;
            }
        }
        // sleep and retry
        exponential_backoff(
            attempts,
            DEFAULT_BASE_BACKOFF_SECONDS,
            DEFAULT_MAX_BACKOFF_SECONDS,
            &cancel.token,
        )
        .await;
        attempts += 1;
    }
}

#[cfg(test)]
mod tests {
    use std::io;

    use tokio::sync::Mutex;

    use super::*;

    #[test]
    fn backoff_defaults_produce_growing_backoff_sequence() {
        let mut current_backoff_value = None;

        for i in 0..10_000 {
            let new_backoff_value = exponential_backoff_duration_seconds(
                i,
                DEFAULT_BASE_BACKOFF_SECONDS,
                DEFAULT_MAX_BACKOFF_SECONDS,
            );

            if let Some(old_backoff_value) = current_backoff_value.replace(new_backoff_value) {
                assert!(
                    old_backoff_value <= new_backoff_value,
                    "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
                )
            }
        }

        assert_eq!(
            current_backoff_value.expect("Should have produced backoff values to compare"),
            DEFAULT_MAX_BACKOFF_SECONDS,
            "Given a big enough number of retries, backoff should reach its allowed max value"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_always_error() {
        let count = Mutex::new(0);
        let err_result = retry(
            || async {
                *count.lock().await += 1;
                Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| false,
            1,
            1,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await;

        assert!(err_result.is_err());

        assert_eq!(*count.lock().await, 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_ok_after_err() {
        let count = Mutex::new(0);
        retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    Ok(())
                } else {
                    *locked += 1;
                    Err(io::Error::from(io::ErrorKind::Other))
                }
            },
            |_e| false,
            2,
            2,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await
        .unwrap();
    }

    #[tokio::test(start_paused = true)]
    async fn dont_retry_permanent_errors() {
        let count = Mutex::new(0);
        let _ = retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    Ok(())
                } else {
                    *locked += 1;
                    Err(io::Error::from(io::ErrorKind::Other))
                }
            },
            |_e| true,
            2,
            2,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await
        .unwrap_err();

        assert_eq!(*count.lock().await, 1);
    }
}