Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering};
4 : use std::collections::BinaryHeap;
5 : use std::mem;
6 : use std::sync::Mutex;
7 : use std::time::Duration;
8 : use tokio::sync::watch::{self, channel};
9 : use tokio::time::timeout;
10 :
11 : /// An error happened while waiting for a number
12 0 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
13 : pub enum SeqWaitError {
14 : /// The wait timeout was reached
15 : #[error("seqwait timeout was reached")]
16 : Timeout,
17 :
18 : /// [`SeqWait::shutdown`] was called
19 : #[error("SeqWait::shutdown was called")]
20 : Shutdown,
21 : }
22 :
/// A value that only ever moves forward.
///
/// `SeqWait<S, V>` keeps its state behind a single mutex, and it is often
/// convenient to keep a few related fields under that same mutex (for
/// example a `prev_record_lsn`). To allow that, `SeqWait` is generic over any
/// storage type `S` that can expose a counter; `V` is the type of the
/// exposed counter.
pub trait MonotonicCounter<V> {
    /// Move the counter to `new_val` and check that it goes forward.
    ///
    /// N.B.: `new_val` is the actual new value, not a difference to add.
    fn cnt_advance(&mut self, new_val: V);

    /// Return the current counter value.
    fn cnt_value(&self) -> V;
}
36 :
37 : /// Heap of waiters, lowest numbers pop first.
38 : struct Waiters<V>
39 : where
40 : V: Ord,
41 : {
42 : heap: BinaryHeap<Waiter<V>>,
43 : /// Number of the first waiter in the heap, or None if there are no waiters.
44 : status_channel: watch::Sender<Option<V>>,
45 : }
46 :
47 : impl<V> Waiters<V>
48 : where
49 : V: Ord + Copy,
50 : {
51 1266 : fn new() -> Self {
52 1266 : Waiters {
53 1266 : heap: BinaryHeap::new(),
54 1266 : status_channel: channel(None).0,
55 1266 : }
56 1266 : }
57 :
58 : /// `status_channel` contains the number of the first waiter in the heap.
59 : /// This function should be called whenever waiters heap changes.
60 14413312 : fn update_status(&self) {
61 14413312 : let first_waiter = self.heap.peek().map(|w| w.wake_num);
62 14413312 : let _ = self.status_channel.send_replace(first_waiter);
63 14413312 : }
64 :
65 : /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
66 5 : fn add(&mut self, num: V) -> watch::Receiver<()> {
67 5 : let (tx, rx) = channel(());
68 5 : self.heap.push(Waiter {
69 5 : wake_num: num,
70 5 : wake_channel: tx,
71 5 : });
72 5 : self.update_status();
73 5 : rx
74 5 : }
75 :
76 : /// Pop all waiters <= num from the heap. Collect channels in a vector,
77 : /// so that caller can wake them up.
78 14413281 : fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
79 14413281 : let mut wake_these = Vec::new();
80 14413285 : while let Some(n) = self.heap.peek() {
81 4 : if n.wake_num > num {
82 0 : break;
83 4 : }
84 4 : wake_these.push(self.heap.pop().unwrap().wake_channel);
85 : }
86 14413281 : self.update_status();
87 14413281 : wake_these
88 14413281 : }
89 :
90 : /// Used on shutdown to efficiently drop all waiters.
91 26 : fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
92 26 : let heap = mem::take(&mut self.heap);
93 26 : self.update_status();
94 26 : heap
95 26 : }
96 : }
97 :
98 : struct Waiter<T>
99 : where
100 : T: Ord,
101 : {
102 : wake_num: T, // wake me when this number arrives ...
103 : wake_channel: watch::Sender<()>, // ... by sending a message to this channel
104 : }
105 :
106 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
107 : // to get that.
108 : impl<T: Ord> PartialOrd for Waiter<T> {
109 1 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
110 1 : Some(self.cmp(other))
111 1 : }
112 : }
113 :
114 : impl<T: Ord> Ord for Waiter<T> {
115 1 : fn cmp(&self, other: &Self) -> Ordering {
116 1 : other.wake_num.cmp(&self.wake_num)
117 1 : }
118 : }
119 :
120 : impl<T: Ord> PartialEq for Waiter<T> {
121 0 : fn eq(&self, other: &Self) -> bool {
122 0 : other.wake_num == self.wake_num
123 0 : }
124 : }
125 :
126 : impl<T: Ord> Eq for Waiter<T> {}
127 :
128 : /// Internal components of a `SeqWait`
129 : struct SeqWaitInt<S, V>
130 : where
131 : S: MonotonicCounter<V>,
132 : V: Ord,
133 : {
134 : waiters: Waiters<V>,
135 : current: S,
136 : shutdown: bool,
137 : }
138 :
139 : /// A tool for waiting on a sequence number
140 : ///
141 : /// This provides a way to wait the arrival of a number.
142 : /// As soon as the number arrives by another caller calling
143 : /// [`advance`], then the waiter will be woken up.
144 : ///
145 : /// This implementation takes a blocking Mutex on both [`wait_for`]
146 : /// and [`advance`], meaning there may be unexpected executor blocking
147 : /// due to thread scheduling unfairness. There are probably better
148 : /// implementations, but we can probably live with this for now.
149 : ///
150 : /// [`wait_for`]: SeqWait::wait_for
151 : /// [`advance`]: SeqWait::advance
152 : ///
153 : /// `S` means Storage, `V` is type of counter that this storage exposes.
154 : ///
155 : pub struct SeqWait<S, V>
156 : where
157 : S: MonotonicCounter<V>,
158 : V: Ord,
159 : {
160 : internal: Mutex<SeqWaitInt<S, V>>,
161 : }
162 :
163 : impl<S, V> SeqWait<S, V>
164 : where
165 : S: MonotonicCounter<V> + Copy,
166 : V: Ord + Copy,
167 : {
168 : /// Create a new `SeqWait`, initialized to a particular number
169 1266 : pub fn new(starting_num: S) -> Self {
170 1266 : let internal = SeqWaitInt {
171 1266 : waiters: Waiters::new(),
172 1266 : current: starting_num,
173 1266 : shutdown: false,
174 1266 : };
175 1266 : SeqWait {
176 1266 : internal: Mutex::new(internal),
177 1266 : }
178 1266 : }
179 :
180 : /// Shut down a `SeqWait`, causing all waiters (present and
181 : /// future) to return an error.
182 26 : pub fn shutdown(&self) {
183 26 : let waiters = {
184 26 : // Prevent new waiters; wake all those that exist.
185 26 : // Wake everyone with an error.
186 26 : let mut internal = self.internal.lock().unwrap();
187 26 :
188 26 : // Block any future waiters from starting
189 26 : internal.shutdown = true;
190 26 :
191 26 : // Take all waiters to drop them later.
192 26 : internal.waiters.take_all()
193 26 :
194 26 : // Drop the lock as we exit this scope.
195 26 : };
196 26 :
197 26 : // When we drop the waiters list, each Receiver will
198 26 : // be woken with an error.
199 26 : // This drop doesn't need to be explicit; it's done
200 26 : // here to make it easier to read the code and understand
201 26 : // the order of events.
202 26 : drop(waiters);
203 26 : }
204 :
205 : /// Wait for a number to arrive
206 : ///
207 : /// This call won't complete until someone has called `advance`
208 : /// with a number greater than or equal to the one we're waiting for.
209 : ///
210 : /// This function is async cancellation-safe.
211 4 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
212 4 : match self.queue_for_wait(num) {
213 1 : Ok(None) => Ok(()),
214 3 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
215 0 : Err(e) => Err(e),
216 : }
217 4 : }
218 :
219 : /// Wait for a number to arrive
220 : ///
221 : /// This call won't complete until someone has called `advance`
222 : /// with a number greater than or equal to the one we're waiting for.
223 : ///
224 : /// If that hasn't happened after the specified timeout duration,
225 : /// [`SeqWaitError::Timeout`] will be returned.
226 : ///
227 : /// This function is async cancellation-safe.
228 675666 : pub async fn wait_for_timeout(
229 675666 : &self,
230 675666 : num: V,
231 675666 : timeout_duration: Duration,
232 675666 : ) -> Result<(), SeqWaitError> {
233 675666 : match self.queue_for_wait(num) {
234 675664 : Ok(None) => Ok(()),
235 2 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
236 0 : Ok(Ok(())) => Ok(()),
237 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
238 2 : Err(_) => Err(SeqWaitError::Timeout),
239 : },
240 0 : Err(e) => Err(e),
241 : }
242 675666 : }
243 :
244 : /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
245 0 : pub fn would_wait_for(&self, num: V) -> Result<(), V> {
246 0 : let internal = self.internal.lock().unwrap();
247 0 : let cnt = internal.current.cnt_value();
248 0 : drop(internal);
249 0 : if cnt >= num {
250 0 : Ok(())
251 : } else {
252 0 : Err(cnt)
253 : }
254 0 : }
255 :
256 : /// Register and return a channel that will be notified when a number arrives,
257 : /// or None, if it has already arrived.
258 675670 : fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
259 675670 : let mut internal = self.internal.lock().unwrap();
260 675670 : if internal.current.cnt_value() >= num {
261 675665 : return Ok(None);
262 5 : }
263 5 : if internal.shutdown {
264 0 : return Err(SeqWaitError::Shutdown);
265 5 : }
266 5 :
267 5 : // Add waiter channel to the queue.
268 5 : let rx = internal.waiters.add(num);
269 5 : // Drop the lock as we exit this scope.
270 5 : Ok(Some(rx))
271 675670 : }
272 :
273 : /// Announce a new number has arrived
274 : ///
275 : /// All waiters at this value or below will be woken.
276 : ///
277 : /// Returns the old number.
278 15837346 : pub fn advance(&self, num: V) -> V {
279 : let old_value;
280 14413281 : let wake_these = {
281 15837346 : let mut internal = self.internal.lock().unwrap();
282 15837346 :
283 15837346 : old_value = internal.current.cnt_value();
284 15837346 : if old_value >= num {
285 1424065 : return old_value;
286 14413281 : }
287 14413281 : internal.current.cnt_advance(num);
288 14413281 :
289 14413281 : // Pop all waiters <= num from the heap.
290 14413281 : internal.waiters.pop_leq(num)
291 : };
292 :
293 14413285 : for tx in wake_these {
294 4 : // This can fail if there are no receivers.
295 4 : // We don't care; discard the error.
296 4 : let _ = tx.send(());
297 4 : }
298 14413281 : old_value
299 15837346 : }
300 :
301 : /// Read the current value, without waiting.
302 833773 : pub fn load(&self) -> S {
303 833773 : self.internal.lock().unwrap().current
304 833773 : }
305 :
306 : /// Get a Receiver for the current status.
307 : ///
308 : /// The current status is the number of the first waiter in the queue,
309 : /// or None if there are no waiters.
310 : ///
311 : /// This receiver will be notified whenever the status changes.
312 : /// It is useful for receiving notifications when the first waiter
313 : /// starts waiting for a number, or when there are no more waiters left.
314 0 : pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
315 0 : self.internal
316 0 : .lock()
317 0 : .unwrap()
318 0 : .waiters
319 0 : .status_channel
320 0 : .subscribe()
321 0 : }
322 : }
323 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // A plain `i32` is the simplest possible counter storage for the tests.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }
        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    /// Two tasks wait for 42; the main task advances to 99 after a delay and
    /// the first task then advances further to 100.
    #[tokio::test]
    async fn seqwait() {
        let shared = Arc::new(SeqWait::new(0));
        let for_first = Arc::clone(&shared);
        let for_second = Arc::clone(&shared);

        let first_task = tokio::task::spawn(async move {
            for_first.wait_for(42).await.expect("wait_for 42");
            let previous = for_first.advance(100);
            assert_eq!(previous, 99);
            // Nobody ever advances to 999, so this has to time out.
            for_first
                .wait_for_timeout(999, Duration::from_millis(100))
                .await
                .expect_err("no 999");
        });
        let second_task = tokio::task::spawn(async move {
            for_second.wait_for(42).await.expect("wait_for 42");
            for_second.wait_for(0).await.expect("wait_for 0");
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        let previous = shared.advance(99);
        assert_eq!(previous, 0);
        shared.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(shared.advance(98), 100);
        assert_eq!(shared.load(), 100);

        first_task.await.unwrap();
        second_task.await.unwrap();

        shared.shutdown();
    }

    /// A waiter with a very short timeout gives up long before the number arrives.
    #[tokio::test]
    async fn seqwait_timeout() {
        let shared = Arc::new(SeqWait::new(0));
        let for_waiter = Arc::clone(&shared);

        let waiter_task = tokio::task::spawn(async move {
            let short = Duration::from_millis(1);
            let outcome = for_waiter.wait_for_timeout(42, short).await;
            assert_eq!(outcome, Err(SeqWaitError::Timeout));
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let previous = shared.advance(99);
        assert_eq!(previous, 0);
        waiter_task.await.unwrap();

        shared.shutdown();
    }
}
|