Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering};
4 : use std::collections::BinaryHeap;
5 : use std::mem;
6 : use std::sync::Mutex;
7 : use std::time::Duration;
8 :
9 : use tokio::sync::watch::{self, channel};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A value that only ever moves forward.
///
/// `SeqWait<S, V>` keeps an `S` under its mutex, which makes it convenient to
/// store related fields next to the counter itself (e.g. a `prev_record_lsn`).
/// `S` can therefore be any type that is able to expose a counter; `V` is the
/// type of that exposed counter.
pub trait MonotonicCounter<V> {
    /// Move the counter to `new_val`, checking that it goes forward.
    ///
    /// N.B.: `new_val` is the new absolute value, not an increment.
    fn cnt_advance(&mut self, new_val: V);

    /// Return the current counter value.
    fn cnt_value(&self) -> V;
}
37 :
38 : /// Heap of waiters, lowest numbers pop first.
39 : struct Waiters<V>
40 : where
41 : V: Ord,
42 : {
43 : heap: BinaryHeap<Waiter<V>>,
44 : /// Number of the first waiter in the heap, or None if there are no waiters.
45 : status_channel: watch::Sender<Option<V>>,
46 : }
47 :
48 : impl<V> Waiters<V>
49 : where
50 : V: Ord + Copy,
51 : {
52 25923 : fn new() -> Self {
53 25923 : Waiters {
54 25923 : heap: BinaryHeap::new(),
55 25923 : status_channel: channel(None).0,
56 25923 : }
57 2 : }
58 :
59 : /// `status_channel` contains the number of the first waiter in the heap.
60 : /// This function should be called whenever waiters heap changes.
61 15 : fn update_status(&self) {
62 15 : let first_waiter = self.heap.peek().map(|w| w.wake_num);
63 15 : let _ = self.status_channel.send_replace(first_waiter);
64 10 : }
65 :
66 : /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
67 5 : fn add(&mut self, num: V) -> watch::Receiver<()> {
68 5 : let (tx, rx) = channel(());
69 5 : self.heap.push(Waiter {
70 5 : wake_num: num,
71 5 : wake_channel: tx,
72 5 : });
73 5 : self.update_status();
74 5 : rx
75 5 : }
76 :
77 : /// Pop all waiters <= num from the heap. Collect channels in a vector,
78 : /// so that caller can wake them up.
79 2402250 : fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
80 2402250 : let mut wake_these = Vec::new();
81 2402254 : while let Some(n) = self.heap.peek() {
82 4 : if n.wake_num > num {
83 0 : break;
84 4 : }
85 4 : wake_these.push(self.heap.pop().unwrap().wake_channel);
86 : }
87 2402250 : if !wake_these.is_empty() {
88 3 : self.update_status();
89 3 : }
90 2402250 : wake_these
91 3 : }
92 :
93 : /// Used on shutdown to efficiently drop all waiters.
94 7 : fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
95 7 : let heap = mem::take(&mut self.heap);
96 7 : self.update_status();
97 7 : heap
98 2 : }
99 : }
100 :
101 : struct Waiter<T>
102 : where
103 : T: Ord,
104 : {
105 : wake_num: T, // wake me when this number arrives ...
106 : wake_channel: watch::Sender<()>, // ... by sending a message to this channel
107 : }
108 :
109 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
110 : // to get that.
111 : impl<T: Ord> PartialOrd for Waiter<T> {
112 1 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
113 1 : Some(self.cmp(other))
114 1 : }
115 : }
116 :
117 : impl<T: Ord> Ord for Waiter<T> {
118 1 : fn cmp(&self, other: &Self) -> Ordering {
119 1 : other.wake_num.cmp(&self.wake_num)
120 1 : }
121 : }
122 :
123 : impl<T: Ord> PartialEq for Waiter<T> {
124 0 : fn eq(&self, other: &Self) -> bool {
125 0 : other.wake_num == self.wake_num
126 0 : }
127 : }
128 :
129 : impl<T: Ord> Eq for Waiter<T> {}
130 :
131 : /// Internal components of a `SeqWait`
132 : struct SeqWaitInt<S, V>
133 : where
134 : S: MonotonicCounter<V>,
135 : V: Ord,
136 : {
137 : waiters: Waiters<V>,
138 : current: S,
139 : shutdown: bool,
140 : }
141 :
142 : /// A tool for waiting on a sequence number
143 : ///
144 : /// This provides a way to wait the arrival of a number.
145 : /// As soon as the number arrives by another caller calling
146 : /// [`advance`], then the waiter will be woken up.
147 : ///
148 : /// This implementation takes a blocking Mutex on both [`wait_for`]
149 : /// and [`advance`], meaning there may be unexpected executor blocking
150 : /// due to thread scheduling unfairness. There are probably better
151 : /// implementations, but we can probably live with this for now.
152 : ///
153 : /// [`wait_for`]: SeqWait::wait_for
154 : /// [`advance`]: SeqWait::advance
155 : ///
156 : /// `S` means Storage, `V` is type of counter that this storage exposes.
157 : ///
158 : pub struct SeqWait<S, V>
159 : where
160 : S: MonotonicCounter<V>,
161 : V: Ord,
162 : {
163 : internal: Mutex<SeqWaitInt<S, V>>,
164 : }
165 :
166 : impl<S, V> SeqWait<S, V>
167 : where
168 : S: MonotonicCounter<V> + Copy,
169 : V: Ord + Copy,
170 : {
171 : /// Create a new `SeqWait`, initialized to a particular number
172 25923 : pub fn new(starting_num: S) -> Self {
173 25923 : let internal = SeqWaitInt {
174 25923 : waiters: Waiters::new(),
175 25923 : current: starting_num,
176 25923 : shutdown: false,
177 25923 : };
178 25923 : SeqWait {
179 25923 : internal: Mutex::new(internal),
180 25923 : }
181 2 : }
182 :
183 : /// Shut down a `SeqWait`, causing all waiters (present and
184 : /// future) to return an error.
185 7 : pub fn shutdown(&self) {
186 7 : let waiters = {
187 : // Prevent new waiters; wake all those that exist.
188 : // Wake everyone with an error.
189 7 : let mut internal = self.internal.lock().unwrap();
190 :
191 : // Block any future waiters from starting
192 7 : internal.shutdown = true;
193 :
194 : // Take all waiters to drop them later.
195 7 : internal.waiters.take_all()
196 :
197 : // Drop the lock as we exit this scope.
198 : };
199 :
200 : // When we drop the waiters list, each Receiver will
201 : // be woken with an error.
202 : // This drop doesn't need to be explicit; it's done
203 : // here to make it easier to read the code and understand
204 : // the order of events.
205 7 : drop(waiters);
206 2 : }
207 :
208 : /// Wait for a number to arrive
209 : ///
210 : /// This call won't complete until someone has called `advance`
211 : /// with a number greater than or equal to the one we're waiting for.
212 : ///
213 : /// This function is async cancellation-safe.
214 4 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
215 4 : match self.queue_for_wait(num) {
216 1 : Ok(None) => Ok(()),
217 3 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
218 0 : Err(e) => Err(e),
219 : }
220 4 : }
221 :
222 : /// Wait for a number to arrive
223 : ///
224 : /// This call won't complete until someone has called `advance`
225 : /// with a number greater than or equal to the one we're waiting for.
226 : ///
227 : /// If that hasn't happened after the specified timeout duration,
228 : /// [`SeqWaitError::Timeout`] will be returned.
229 : ///
230 : /// This function is async cancellation-safe.
231 112982 : pub async fn wait_for_timeout(
232 112982 : &self,
233 112982 : num: V,
234 112982 : timeout_duration: Duration,
235 112982 : ) -> Result<(), SeqWaitError> {
236 112982 : match self.queue_for_wait(num) {
237 112980 : Ok(None) => Ok(()),
238 2 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
239 0 : Ok(Ok(())) => Ok(()),
240 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
241 2 : Err(_) => Err(SeqWaitError::Timeout),
242 : },
243 0 : Err(e) => Err(e),
244 : }
245 2 : }
246 :
247 : /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
248 0 : pub fn would_wait_for(&self, num: V) -> Result<(), V> {
249 0 : let internal = self.internal.lock().unwrap();
250 0 : let cnt = internal.current.cnt_value();
251 0 : drop(internal);
252 0 : if cnt >= num { Ok(()) } else { Err(cnt) }
253 0 : }
254 :
255 : /// Register and return a channel that will be notified when a number arrives,
256 : /// or None, if it has already arrived.
257 112986 : fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
258 112986 : let mut internal = self.internal.lock().unwrap();
259 112986 : if internal.current.cnt_value() >= num {
260 112981 : return Ok(None);
261 5 : }
262 5 : if internal.shutdown {
263 0 : return Err(SeqWaitError::Shutdown);
264 5 : }
265 :
266 : // Add waiter channel to the queue.
267 5 : let rx = internal.waiters.add(num);
268 : // Drop the lock as we exit this scope.
269 5 : Ok(Some(rx))
270 6 : }
271 :
272 : /// Announce a new number has arrived
273 : ///
274 : /// All waiters at this value or below will be woken.
275 : ///
276 : /// Returns the old number.
277 2639595 : pub fn advance(&self, num: V) -> V {
278 : let old_value;
279 2402250 : let wake_these = {
280 2639595 : let mut internal = self.internal.lock().unwrap();
281 :
282 2639595 : old_value = internal.current.cnt_value();
283 2639595 : if old_value >= num {
284 237345 : return old_value;
285 3 : }
286 2402250 : internal.current.cnt_advance(num);
287 :
288 : // Pop all waiters <= num from the heap.
289 2402250 : internal.waiters.pop_leq(num)
290 : };
291 :
292 2402254 : for tx in wake_these {
293 4 : // This can fail if there are no receivers.
294 4 : // We don't care; discard the error.
295 4 : let _ = tx.send(());
296 4 : }
297 2402250 : old_value
298 4 : }
299 :
300 : /// Read the current value, without waiting.
301 138943 : pub fn load(&self) -> S {
302 138943 : self.internal.lock().unwrap().current
303 1 : }
304 :
305 : /// Get a Receiver for the current status.
306 : ///
307 : /// The current status is the number of the first waiter in the queue,
308 : /// or None if there are no waiters.
309 : ///
310 : /// This receiver will be notified whenever the status changes.
311 : /// It is useful for receiving notifications when the first waiter
312 : /// starts waiting for a number, or when there are no more waiters left.
313 0 : pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
314 0 : self.internal
315 0 : .lock()
316 0 : .unwrap()
317 0 : .waiters
318 0 : .status_channel
319 0 : .subscribe()
320 0 : }
321 : }
322 :
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    // A plain `i32` acts as its own counter storage in these tests.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }

        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    /// Two tasks wait for 42; the main task advances to 99, one of the tasks
    /// advances further to 100, and a wait for 999 is expected to time out.
    #[tokio::test]
    async fn seqwait() {
        let seq = Arc::new(SeqWait::new(0));

        let first_task = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                seq.wait_for(42).await.expect("wait_for 42");
                let previous = seq.advance(100);
                assert_eq!(previous, 99);
                seq.wait_for_timeout(999, Duration::from_millis(100))
                    .await
                    .expect_err("no 999");
            })
        };
        let second_task = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                seq.wait_for(42).await.expect("wait_for 42");
                seq.wait_for(0).await.expect("wait_for 0");
            })
        };

        tokio::time::sleep(Duration::from_millis(200)).await;
        let previous = seq.advance(99);
        assert_eq!(previous, 0);
        seq.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(seq.advance(98), 100);
        assert_eq!(seq.load(), 100);

        first_task.await.unwrap();
        second_task.await.unwrap();

        seq.shutdown();
    }

    /// A wait with a 1 ms timeout gives up long before the number arrives.
    #[tokio::test]
    async fn seqwait_timeout() {
        let seq = Arc::new(SeqWait::new(0));

        let waiter_task = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                let short_timeout = Duration::from_millis(1);
                let outcome = seq.wait_for_timeout(42, short_timeout).await;
                assert_eq!(outcome, Err(SeqWaitError::Timeout));
            })
        };

        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let previous = seq.advance(99);
        assert_eq!(previous, 0);
        waiter_task.await.unwrap();

        seq.shutdown();
    }
}
|