use std::fmt::{Debug, Display};

use futures::Future;
use tokio_util::sync::CancellationToken;

pub const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
pub const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;

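/// Sleep for the backoff duration computed from attempt number `n`, returning
/// early if `cancel` fires first. For `n == 0` the computed duration is zero,
/// so the call returns immediately.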
pub async fn exponential_backoff(
    n: u32,
    base_increment: f64,
    max_seconds: f64,
    cancel: &CancellationToken,
) {
    let backoff_duration_seconds =
        exponential_backoff_duration_seconds(n, base_increment, max_seconds);
    if backoff_duration_seconds > 0.0 {
        tracing::info!(
            "Backoff: waiting {backoff_duration_seconds} seconds before proceeding with the task",
        );

        // Sleep for the backoff duration, waking early if the token is
        // cancelled; either outcome is fine, so the timeout result is dropped.
        drop(
            tokio::time::timeout(
                std::time::Duration::from_secs_f64(backoff_duration_seconds),
                cancel.cancelled(),
            )
            .await,
        )
    }
}

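/// Compute the backoff duration for attempt `n`: zero for the first attempt,
/// otherwise `(1.0 + base_increment)^n`, capped at `max_seconds`.
///
/// A minimal sketch of the resulting values with the default parameters
/// (marked `ignore` because the crate path is an assumption, not taken from
/// this file):
///
/// ```ignore
/// // Assumed import path; adjust to wherever this module lives.
/// use utils::backoff::exponential_backoff_duration_seconds;
///
/// // First attempt: no backoff at all.
/// assert_eq!(exponential_backoff_duration_seconds(0, 0.1, 3.0), 0.0);
/// // Second attempt: 1.1^1 = 1.1 seconds.
/// assert!((exponential_backoff_duration_seconds(1, 0.1, 3.0) - 1.1).abs() < 1e-9);
/// // Large n: capped at the 3.0-second maximum.
/// assert_eq!(exponential_backoff_duration_seconds(100, 0.1, 3.0), 3.0);
/// ```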
pub fn exponential_backoff_duration_seconds(n: u32, base_increment: f64, max_seconds: f64) -> f64 {
    if n == 0 {
        0.0
    } else {
        (1.0 + base_increment).powf(f64::from(n)).min(max_seconds)
    }
}

/// Configure cancellation for a retried operation: when to cancel (the token), and
/// what kind of error to return on cancellation.
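///
/// A minimal construction sketch (illustrative only; `io::Error` as the error
/// type mirrors the tests below, and the message text is made up):
///
/// ```ignore
/// use std::io;
/// use tokio_util::sync::CancellationToken;
///
/// let token = CancellationToken::new();
/// // On cancellation, surface an `Interrupted` I/O error to the caller.
/// let cancel = Cancel::new(token.clone(), || {
///     io::Error::new(io::ErrorKind::Interrupted, "operation cancelled")
/// });
/// ```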
pub struct Cancel<E, CF>
where
    E: Display + Debug + 'static,
    CF: Fn() -> E,
{
    token: CancellationToken,
    on_cancel: CF,
}

impl<E, CF> Cancel<E, CF>
where
    E: Display + Debug + 'static,
    CF: Fn() -> E,
{
    pub fn new(token: CancellationToken, on_cancel: CF) -> Self {
        Self { token, on_cancel }
    }
}

/// Retries the passed operation until one of the following conditions is met:
/// - the encountered error is considered permanent (non-retryable), or
/// - retries have been exhausted.
///
/// The `is_permanent` closure is used to distinguish permanent from transient errors.
/// Once the attempt count reaches `warn_threshold`, failures are logged as warnings
/// instead of info messages.
/// The `description` argument is added to log messages; its value should identify what `op` is doing.
/// The `cancel` argument is required: any time we are looping on retry, we should be using a
/// `CancellationToken` to drop out promptly on shutdown.
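///
/// A minimal usage sketch (illustrative; `fetch_something` is a hypothetical
/// operation, not part of this module):
///
/// ```ignore
/// use std::io;
/// use tokio_util::sync::CancellationToken;
///
/// let token = CancellationToken::new();
/// let result = retry(
///     // The operation to retry: here, a hypothetical fallible fetch.
///     || async { fetch_something().await },
///     // Treat no error as permanent; every failure is retried.
///     |_e: &io::Error| false,
///     3,  // warn_threshold: failures are logged as warnings from here on
///     10, // max_retries
///     "fetching something",
///     Cancel::new(token, || io::Error::new(io::ErrorKind::Interrupted, "cancelled")),
/// )
/// .await;
/// ```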
pub async fn retry<T, O, F, E, CF>(
    mut op: O,
    is_permanent: impl Fn(&E) -> bool,
    warn_threshold: u32,
    max_retries: u32,
    description: &str,
    cancel: Cancel<E, CF>,
) -> Result<T, E>
where
    // Not std::error::Error because anyhow::Error doesn't implement it.
    // For context see https://github.com/dtolnay/anyhow/issues/63
    E: Display + Debug + 'static,
    O: FnMut() -> F,
    F: Future<Output = Result<T, E>>,
    CF: Fn() -> E,
{
    let mut attempts = 0;
    loop {
        if cancel.token.is_cancelled() {
            return Err((cancel.on_cancel)());
        }

        let result = op().await;
        match result {
            Ok(_) => {
                if attempts > 0 {
                    tracing::info!("{description} succeeded after {attempts} retries");
                }
                return result;
            }

            // These are "permanent" errors that should not be retried.
            Err(ref e) if is_permanent(e) => {
                return result;
            }
            // Assume that any other failure might be transient, and the operation might
            // succeed if we just keep trying.
            Err(err) if attempts < warn_threshold => {
                tracing::info!("{description} failed, will retry (attempt {attempts}): {err:#}");
            }
            Err(err) if attempts < max_retries => {
                tracing::warn!("{description} failed, will retry (attempt {attempts}): {err:#}");
            }
            Err(ref err) => {
                // Operation failed `max_retries` times. Time to give up.
                tracing::warn!(
                    "{description} still failed after {attempts} retries, giving up: {err:?}"
                );
                return result;
            }
        }
        // sleep and retry
        exponential_backoff(
            attempts,
            DEFAULT_BASE_BACKOFF_SECONDS,
            DEFAULT_MAX_BACKOFF_SECONDS,
            &cancel.token,
        )
        .await;
        attempts += 1;
    }
}

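// The tests below use `start_paused = true`, which starts the tokio runtime
// with its clock paused: timers auto-advance while the runtime is otherwise
// idle, so the backoff sleeps complete instantly instead of taking wall-clock
// time.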
#[cfg(test)]
mod tests {
    use std::io;

    use tokio::sync::Mutex;

    use super::*;

    #[test]
    fn backoff_defaults_produce_growing_backoff_sequence() {
        let mut current_backoff_value = None;

        for i in 0..10_000 {
            let new_backoff_value = exponential_backoff_duration_seconds(
                i,
                DEFAULT_BASE_BACKOFF_SECONDS,
                DEFAULT_MAX_BACKOFF_SECONDS,
            );

            if let Some(old_backoff_value) = current_backoff_value.replace(new_backoff_value) {
                assert!(
                    old_backoff_value <= new_backoff_value,
                    "{i}th backoff value {new_backoff_value} is smaller than the previous one {old_backoff_value}"
                )
            }
        }

        assert_eq!(
            current_backoff_value.expect("Should have produced backoff values to compare"),
            DEFAULT_MAX_BACKOFF_SECONDS,
            "Given enough retries, backoff should reach its allowed max value"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_always_error() {
        let count = Mutex::new(0);
        let err_result = retry(
            || async {
                *count.lock().await += 1;
                Result::<(), io::Error>::Err(io::Error::from(io::ErrorKind::Other))
            },
            |_e| false,
            1,
            1,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await;

        assert!(err_result.is_err());

        assert_eq!(*count.lock().await, 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_ok_after_err() {
        let count = Mutex::new(0);
        retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    Ok(())
                } else {
                    *locked += 1;
                    Err(io::Error::from(io::ErrorKind::Other))
                }
            },
            |_e| false,
            2,
            2,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await
        .unwrap();
    }

    #[tokio::test(start_paused = true)]
    async fn dont_retry_permanent_errors() {
        let count = Mutex::new(0);
        let _ = retry(
            || async {
                let mut locked = count.lock().await;
                if *locked > 1 {
                    Ok(())
                } else {
                    *locked += 1;
                    Err(io::Error::from(io::ErrorKind::Other))
                }
            },
            |_e| true,
            2,
            2,
            "work",
            Cancel::new(CancellationToken::new(), || -> io::Error { unreachable!() }),
        )
        .await
        .unwrap_err();

        assert_eq!(*count.lock().await, 1);
    }
}