Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering};
4 : use std::collections::BinaryHeap;
5 : use std::fmt::Debug;
6 : use std::mem;
7 : use std::sync::Mutex;
8 : use std::time::Duration;
9 : use tokio::sync::watch::{channel, Receiver, Sender};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 0 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A counter whose value only ever moves forward.
///
/// `SeqWait<S>` is generic over its storage type so that callers can keep
/// other fields under the same mutex as the counter (for example a
/// `prev_record_lsn`). Any type able to expose a counter can serve as that
/// storage; `V` is the type of the exposed counter.
pub trait MonotonicCounter<V> {
    /// Move the counter forward to `new_val`, checking that it does not go
    /// backwards.
    ///
    /// N.B.: `new_val` is the new absolute value, not an increment.
    fn cnt_advance(&mut self, new_val: V);

    /// Return the current counter value.
    fn cnt_value(&self) -> V;
}
37 :
38 : /// Internal components of a `SeqWait`
39 : struct SeqWaitInt<S, V>
40 : where
41 : S: MonotonicCounter<V>,
42 : V: Ord,
43 : {
44 : waiters: BinaryHeap<Waiter<V>>,
45 : current: S,
46 : shutdown: bool,
47 : }
48 :
49 : struct Waiter<T>
50 : where
51 : T: Ord,
52 : {
53 : wake_num: T, // wake me when this number arrives ...
54 : wake_channel: Sender<()>, // ... by sending a message to this channel
55 : }
56 :
57 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
58 : // to get that.
59 : impl<T: Ord> PartialOrd for Waiter<T> {
60 2 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 2 : Some(self.cmp(other))
62 2 : }
63 : }
64 :
65 : impl<T: Ord> Ord for Waiter<T> {
66 2 : fn cmp(&self, other: &Self) -> Ordering {
67 2 : other.wake_num.cmp(&self.wake_num)
68 2 : }
69 : }
70 :
71 : impl<T: Ord> PartialEq for Waiter<T> {
72 0 : fn eq(&self, other: &Self) -> bool {
73 0 : other.wake_num == self.wake_num
74 0 : }
75 : }
76 :
77 : impl<T: Ord> Eq for Waiter<T> {}
78 :
79 : /// A tool for waiting on a sequence number
80 : ///
81 : /// This provides a way to wait the arrival of a number.
82 : /// As soon as the number arrives by another caller calling
83 : /// [`advance`], then the waiter will be woken up.
84 : ///
85 : /// This implementation takes a blocking Mutex on both [`wait_for`]
86 : /// and [`advance`], meaning there may be unexpected executor blocking
87 : /// due to thread scheduling unfairness. There are probably better
88 : /// implementations, but we can probably live with this for now.
89 : ///
90 : /// [`wait_for`]: SeqWait::wait_for
91 : /// [`advance`]: SeqWait::advance
92 : ///
93 : /// `S` means Storage, `V` is type of counter that this storage exposes.
94 : ///
95 : pub struct SeqWait<S, V>
96 : where
97 : S: MonotonicCounter<V>,
98 : V: Ord,
99 : {
100 : internal: Mutex<SeqWaitInt<S, V>>,
101 : }
102 :
103 : impl<S, V> SeqWait<S, V>
104 : where
105 : S: MonotonicCounter<V> + Copy,
106 : V: Ord + Copy,
107 : {
108 : /// Create a new `SeqWait`, initialized to a particular number
109 368 : pub fn new(starting_num: S) -> Self {
110 368 : let internal = SeqWaitInt {
111 368 : waiters: BinaryHeap::new(),
112 368 : current: starting_num,
113 368 : shutdown: false,
114 368 : };
115 368 : SeqWait {
116 368 : internal: Mutex::new(internal),
117 368 : }
118 368 : }
119 :
120 : /// Shut down a `SeqWait`, causing all waiters (present and
121 : /// future) to return an error.
122 12 : pub fn shutdown(&self) {
123 12 : let waiters = {
124 12 : // Prevent new waiters; wake all those that exist.
125 12 : // Wake everyone with an error.
126 12 : let mut internal = self.internal.lock().unwrap();
127 12 :
128 12 : // Block any future waiters from starting
129 12 : internal.shutdown = true;
130 12 :
131 12 : // This will steal the entire waiters map.
132 12 : // When we drop it all waiters will be woken.
133 12 : mem::take(&mut internal.waiters)
134 12 :
135 12 : // Drop the lock as we exit this scope.
136 12 : };
137 12 :
138 12 : // When we drop the waiters list, each Receiver will
139 12 : // be woken with an error.
140 12 : // This drop doesn't need to be explicit; it's done
141 12 : // here to make it easier to read the code and understand
142 12 : // the order of events.
143 12 : drop(waiters);
144 12 : }
145 :
146 : /// Wait for a number to arrive
147 : ///
148 : /// This call won't complete until someone has called `advance`
149 : /// with a number greater than or equal to the one we're waiting for.
150 : ///
151 : /// This function is async cancellation-safe.
152 8 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
153 8 : match self.queue_for_wait(num) {
154 2 : Ok(None) => Ok(()),
155 6 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
156 0 : Err(e) => Err(e),
157 : }
158 8 : }
159 :
160 : /// Wait for a number to arrive
161 : ///
162 : /// This call won't complete until someone has called `advance`
163 : /// with a number greater than or equal to the one we're waiting for.
164 : ///
165 : /// If that hasn't happened after the specified timeout duration,
166 : /// [`SeqWaitError::Timeout`] will be returned.
167 : ///
168 : /// This function is async cancellation-safe.
169 226809 : pub async fn wait_for_timeout(
170 226809 : &self,
171 226809 : num: V,
172 226809 : timeout_duration: Duration,
173 226809 : ) -> Result<(), SeqWaitError> {
174 226809 : match self.queue_for_wait(num) {
175 226805 : Ok(None) => Ok(()),
176 4 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
177 0 : Ok(Ok(())) => Ok(()),
178 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
179 4 : Err(_) => Err(SeqWaitError::Timeout),
180 : },
181 0 : Err(e) => Err(e),
182 : }
183 226809 : }
184 :
185 : /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
186 0 : pub fn would_wait_for(&self, num: V) -> Result<(), V> {
187 0 : let internal = self.internal.lock().unwrap();
188 0 : let cnt = internal.current.cnt_value();
189 0 : drop(internal);
190 0 : if cnt >= num {
191 0 : Ok(())
192 : } else {
193 0 : Err(cnt)
194 : }
195 0 : }
196 :
197 : /// Register and return a channel that will be notified when a number arrives,
198 : /// or None, if it has already arrived.
199 226817 : fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
200 226817 : let mut internal = self.internal.lock().unwrap();
201 226817 : if internal.current.cnt_value() >= num {
202 226807 : return Ok(None);
203 10 : }
204 10 : if internal.shutdown {
205 0 : return Err(SeqWaitError::Shutdown);
206 10 : }
207 10 :
208 10 : // Create a new channel.
209 10 : let (tx, rx) = channel(());
210 10 : internal.waiters.push(Waiter {
211 10 : wake_num: num,
212 10 : wake_channel: tx,
213 10 : });
214 10 : // Drop the lock as we exit this scope.
215 10 : Ok(Some(rx))
216 226817 : }
217 :
218 : /// Announce a new number has arrived
219 : ///
220 : /// All waiters at this value or below will be woken.
221 : ///
222 : /// Returns the old number.
223 4122944 : pub fn advance(&self, num: V) -> V {
224 : let old_value;
225 3648254 : let wake_these = {
226 4122944 : let mut internal = self.internal.lock().unwrap();
227 4122944 :
228 4122944 : old_value = internal.current.cnt_value();
229 4122944 : if old_value >= num {
230 474690 : return old_value;
231 3648254 : }
232 3648254 : internal.current.cnt_advance(num);
233 3648254 :
234 3648254 : // Pop all waiters <= num from the heap. Collect them in a vector, and
235 3648254 : // wake them up after releasing the lock.
236 3648254 : let mut wake_these = Vec::new();
237 3648262 : while let Some(n) = internal.waiters.peek() {
238 8 : if n.wake_num > num {
239 0 : break;
240 8 : }
241 8 : wake_these.push(internal.waiters.pop().unwrap().wake_channel);
242 : }
243 3648254 : wake_these
244 : };
245 :
246 3648262 : for tx in wake_these {
247 8 : // This can fail if there are no receivers.
248 8 : // We don't care; discard the error.
249 8 : let _ = tx.send(());
250 8 : }
251 3648254 : old_value
252 4122944 : }
253 :
254 : /// Read the current value, without waiting.
255 3922434 : pub fn load(&self) -> S {
256 3922434 : self.internal.lock().unwrap().current
257 3922434 : }
258 : }
259 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }
        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    #[tokio::test]
    async fn seqwait() {
        let seq = Arc::new(SeqWait::new(0));
        let seq2 = Arc::clone(&seq);
        let seq3 = Arc::clone(&seq);
        let jh1 = tokio::task::spawn(async move {
            seq2.wait_for(42).await.expect("wait_for 42");
            let old = seq2.advance(100);
            assert_eq!(old, 99);
            // Nobody ever advances to 999, so this must time out
            // (not fail with some other error).
            let res = seq2
                .wait_for_timeout(999, Duration::from_millis(100))
                .await;
            assert_eq!(res, Err(SeqWaitError::Timeout));
        });
        let jh2 = tokio::task::spawn(async move {
            seq3.wait_for(42).await.expect("wait_for 42");
            seq3.wait_for(0).await.expect("wait_for 0");
        });
        tokio::time::sleep(Duration::from_millis(200)).await;
        let old = seq.advance(99);
        assert_eq!(old, 0);
        seq.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(seq.advance(98), 100);
        assert_eq!(seq.load(), 100);

        jh1.await.unwrap();
        jh2.await.unwrap();

        seq.shutdown();
    }

    #[tokio::test]
    async fn seqwait_timeout() {
        let seq = Arc::new(SeqWait::new(0));
        let seq2 = Arc::clone(&seq);
        let jh = tokio::task::spawn(async move {
            let timeout = Duration::from_millis(1);
            let res = seq2.wait_for_timeout(42, timeout).await;
            assert_eq!(res, Err(SeqWaitError::Timeout));
        });
        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let old = seq.advance(99);
        assert_eq!(old, 0);
        jh.await.unwrap();

        seq.shutdown();
    }

    /// Shutdown wakes queued waiters with an error and rejects new waiters
    /// for numbers that have not arrived. Numbers that already arrived still
    /// succeed.
    #[tokio::test]
    async fn seqwait_shutdown() {
        let seq = Arc::new(SeqWait::new(0));
        let seq2 = Arc::clone(&seq);
        let jh = tokio::task::spawn(async move { seq2.wait_for(42).await });

        // Give the spawned task a chance to queue itself. If it has not done
        // so by the time shutdown runs, it is rejected up front instead, so
        // the assertion below holds either way.
        tokio::time::sleep(Duration::from_millis(100)).await;
        seq.shutdown();
        assert_eq!(jh.await.unwrap(), Err(SeqWaitError::Shutdown));

        // New waiters for numbers that have not arrived are rejected.
        assert_eq!(seq.wait_for(42).await, Err(SeqWaitError::Shutdown));
        assert_eq!(
            seq.wait_for_timeout(42, Duration::from_millis(100)).await,
            Err(SeqWaitError::Shutdown)
        );

        // Numbers that have already arrived still succeed after shutdown.
        assert_eq!(seq.wait_for(0).await, Ok(()));
        assert_eq!(
            seq.wait_for_timeout(0, Duration::from_millis(100)).await,
            Ok(())
        );
    }
}
|