Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering};
4 : use std::collections::BinaryHeap;
5 : use std::mem;
6 : use std::sync::Mutex;
7 : use std::time::Duration;
8 : use tokio::sync::watch::{self, channel};
9 : use tokio::time::timeout;
10 :
11 : /// An error happened while waiting for a number
12 0 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
13 : pub enum SeqWaitError {
14 : /// The wait timeout was reached
15 : #[error("seqwait timeout was reached")]
16 : Timeout,
17 :
18 : /// [`SeqWait::shutdown`] was called
19 : #[error("SeqWait::shutdown was called")]
20 : Shutdown,
21 : }
22 :
/// A counter that only ever moves forward.
///
/// `SeqWait<S, V>` is generic over a storage type `S` instead of a bare
/// number, so that callers can keep extra state (for example
/// `prev_record_lsn`) under the same mutex as the counter. The storage
/// exposes its counter as a value of type `V`.
pub trait MonotonicCounter<V> {
    /// Move the counter to `new_val`, checking that it does not go backwards.
    ///
    /// N.B.: `new_val` is the absolute new value, not an increment.
    fn cnt_advance(&mut self, new_val: V);

    /// Return the current counter value.
    fn cnt_value(&self) -> V;
}
36 :
37 : /// Heap of waiters, lowest numbers pop first.
38 : struct Waiters<V>
39 : where
40 : V: Ord,
41 : {
42 : heap: BinaryHeap<Waiter<V>>,
43 : /// Number of the first waiter in the heap, or None if there are no waiters.
44 : status_channel: watch::Sender<Option<V>>,
45 : }
46 :
47 : impl<V> Waiters<V>
48 : where
49 : V: Ord + Copy,
50 : {
51 25458 : fn new() -> Self {
52 25458 : Waiters {
53 25458 : heap: BinaryHeap::new(),
54 25458 : status_channel: channel(None).0,
55 25458 : }
56 25458 : }
57 :
58 : /// `status_channel` contains the number of the first waiter in the heap.
59 : /// This function should be called whenever waiters heap changes.
60 20 : fn update_status(&self) {
61 20 : let first_waiter = self.heap.peek().map(|w| w.wake_num);
62 20 : let _ = self.status_channel.send_replace(first_waiter);
63 20 : }
64 :
65 : /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
66 5 : fn add(&mut self, num: V) -> watch::Receiver<()> {
67 5 : let (tx, rx) = channel(());
68 5 : self.heap.push(Waiter {
69 5 : wake_num: num,
70 5 : wake_channel: tx,
71 5 : });
72 5 : self.update_status();
73 5 : rx
74 5 : }
75 :
76 : /// Pop all waiters <= num from the heap. Collect channels in a vector,
77 : /// so that caller can wake them up.
78 4804419 : fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
79 4804419 : let mut wake_these = Vec::new();
80 4804423 : while let Some(n) = self.heap.peek() {
81 4 : if n.wake_num > num {
82 0 : break;
83 4 : }
84 4 : wake_these.push(self.heap.pop().unwrap().wake_channel);
85 : }
86 4804419 : if !wake_these.is_empty() {
87 3 : self.update_status();
88 4804416 : }
89 4804419 : wake_these
90 4804419 : }
91 :
92 : /// Used on shutdown to efficiently drop all waiters.
93 12 : fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
94 12 : let heap = mem::take(&mut self.heap);
95 12 : self.update_status();
96 12 : heap
97 12 : }
98 : }
99 :
100 : struct Waiter<T>
101 : where
102 : T: Ord,
103 : {
104 : wake_num: T, // wake me when this number arrives ...
105 : wake_channel: watch::Sender<()>, // ... by sending a message to this channel
106 : }
107 :
108 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
109 : // to get that.
110 : impl<T: Ord> PartialOrd for Waiter<T> {
111 1 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
112 1 : Some(self.cmp(other))
113 1 : }
114 : }
115 :
116 : impl<T: Ord> Ord for Waiter<T> {
117 1 : fn cmp(&self, other: &Self) -> Ordering {
118 1 : other.wake_num.cmp(&self.wake_num)
119 1 : }
120 : }
121 :
122 : impl<T: Ord> PartialEq for Waiter<T> {
123 0 : fn eq(&self, other: &Self) -> bool {
124 0 : other.wake_num == self.wake_num
125 0 : }
126 : }
127 :
128 : impl<T: Ord> Eq for Waiter<T> {}
129 :
130 : /// Internal components of a `SeqWait`
131 : struct SeqWaitInt<S, V>
132 : where
133 : S: MonotonicCounter<V>,
134 : V: Ord,
135 : {
136 : waiters: Waiters<V>,
137 : current: S,
138 : shutdown: bool,
139 : }
140 :
141 : /// A tool for waiting on a sequence number
142 : ///
143 : /// This provides a way to wait the arrival of a number.
144 : /// As soon as the number arrives by another caller calling
145 : /// [`advance`], then the waiter will be woken up.
146 : ///
147 : /// This implementation takes a blocking Mutex on both [`wait_for`]
148 : /// and [`advance`], meaning there may be unexpected executor blocking
149 : /// due to thread scheduling unfairness. There are probably better
150 : /// implementations, but we can probably live with this for now.
151 : ///
152 : /// [`wait_for`]: SeqWait::wait_for
153 : /// [`advance`]: SeqWait::advance
154 : ///
155 : /// `S` means Storage, `V` is type of counter that this storage exposes.
156 : ///
157 : pub struct SeqWait<S, V>
158 : where
159 : S: MonotonicCounter<V>,
160 : V: Ord,
161 : {
162 : internal: Mutex<SeqWaitInt<S, V>>,
163 : }
164 :
165 : impl<S, V> SeqWait<S, V>
166 : where
167 : S: MonotonicCounter<V> + Copy,
168 : V: Ord + Copy,
169 : {
170 : /// Create a new `SeqWait`, initialized to a particular number
171 25458 : pub fn new(starting_num: S) -> Self {
172 25458 : let internal = SeqWaitInt {
173 25458 : waiters: Waiters::new(),
174 25458 : current: starting_num,
175 25458 : shutdown: false,
176 25458 : };
177 25458 : SeqWait {
178 25458 : internal: Mutex::new(internal),
179 25458 : }
180 25458 : }
181 :
182 : /// Shut down a `SeqWait`, causing all waiters (present and
183 : /// future) to return an error.
184 12 : pub fn shutdown(&self) {
185 12 : let waiters = {
186 12 : // Prevent new waiters; wake all those that exist.
187 12 : // Wake everyone with an error.
188 12 : let mut internal = self.internal.lock().unwrap();
189 12 :
190 12 : // Block any future waiters from starting
191 12 : internal.shutdown = true;
192 12 :
193 12 : // Take all waiters to drop them later.
194 12 : internal.waiters.take_all()
195 12 :
196 12 : // Drop the lock as we exit this scope.
197 12 : };
198 12 :
199 12 : // When we drop the waiters list, each Receiver will
200 12 : // be woken with an error.
201 12 : // This drop doesn't need to be explicit; it's done
202 12 : // here to make it easier to read the code and understand
203 12 : // the order of events.
204 12 : drop(waiters);
205 12 : }
206 :
207 : /// Wait for a number to arrive
208 : ///
209 : /// This call won't complete until someone has called `advance`
210 : /// with a number greater than or equal to the one we're waiting for.
211 : ///
212 : /// This function is async cancellation-safe.
213 4 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
214 4 : match self.queue_for_wait(num) {
215 1 : Ok(None) => Ok(()),
216 3 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
217 0 : Err(e) => Err(e),
218 : }
219 4 : }
220 :
221 : /// Wait for a number to arrive
222 : ///
223 : /// This call won't complete until someone has called `advance`
224 : /// with a number greater than or equal to the one we're waiting for.
225 : ///
226 : /// If that hasn't happened after the specified timeout duration,
227 : /// [`SeqWaitError::Timeout`] will be returned.
228 : ///
229 : /// This function is async cancellation-safe.
230 226558 : pub async fn wait_for_timeout(
231 226558 : &self,
232 226558 : num: V,
233 226558 : timeout_duration: Duration,
234 226558 : ) -> Result<(), SeqWaitError> {
235 226558 : match self.queue_for_wait(num) {
236 226556 : Ok(None) => Ok(()),
237 2 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
238 0 : Ok(Ok(())) => Ok(()),
239 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
240 2 : Err(_) => Err(SeqWaitError::Timeout),
241 : },
242 0 : Err(e) => Err(e),
243 : }
244 226558 : }
245 :
246 : /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
247 0 : pub fn would_wait_for(&self, num: V) -> Result<(), V> {
248 0 : let internal = self.internal.lock().unwrap();
249 0 : let cnt = internal.current.cnt_value();
250 0 : drop(internal);
251 0 : if cnt >= num {
252 0 : Ok(())
253 : } else {
254 0 : Err(cnt)
255 : }
256 0 : }
257 :
258 : /// Register and return a channel that will be notified when a number arrives,
259 : /// or None, if it has already arrived.
260 226562 : fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
261 226562 : let mut internal = self.internal.lock().unwrap();
262 226562 : if internal.current.cnt_value() >= num {
263 226557 : return Ok(None);
264 5 : }
265 5 : if internal.shutdown {
266 0 : return Err(SeqWaitError::Shutdown);
267 5 : }
268 5 :
269 5 : // Add waiter channel to the queue.
270 5 : let rx = internal.waiters.add(num);
271 5 : // Drop the lock as we exit this scope.
272 5 : Ok(Some(rx))
273 226562 : }
274 :
275 : /// Announce a new number has arrived
276 : ///
277 : /// All waiters at this value or below will be woken.
278 : ///
279 : /// Returns the old number.
280 5279108 : pub fn advance(&self, num: V) -> V {
281 : let old_value;
282 4804419 : let wake_these = {
283 5279108 : let mut internal = self.internal.lock().unwrap();
284 5279108 :
285 5279108 : old_value = internal.current.cnt_value();
286 5279108 : if old_value >= num {
287 474689 : return old_value;
288 4804419 : }
289 4804419 : internal.current.cnt_advance(num);
290 4804419 :
291 4804419 : // Pop all waiters <= num from the heap.
292 4804419 : internal.waiters.pop_leq(num)
293 : };
294 :
295 4804423 : for tx in wake_these {
296 4 : // This can fail if there are no receivers.
297 4 : // We don't care; discard the error.
298 4 : let _ = tx.send(());
299 4 : }
300 4804419 : old_value
301 5279108 : }
302 :
303 : /// Read the current value, without waiting.
304 275665 : pub fn load(&self) -> S {
305 275665 : self.internal.lock().unwrap().current
306 275665 : }
307 :
308 : /// Get a Receiver for the current status.
309 : ///
310 : /// The current status is the number of the first waiter in the queue,
311 : /// or None if there are no waiters.
312 : ///
313 : /// This receiver will be notified whenever the status changes.
314 : /// It is useful for receiving notifications when the first waiter
315 : /// starts waiting for a number, or when there are no more waiters left.
316 0 : pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
317 0 : self.internal
318 0 : .lock()
319 0 : .unwrap()
320 0 : .waiters
321 0 : .status_channel
322 0 : .subscribe()
323 0 : }
324 : }
325 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A plain `i32` is its own storage: the simplest possible counter.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }
        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    #[tokio::test]
    async fn seqwait() {
        let seq = Arc::new(SeqWait::new(0));
        let seq2 = Arc::clone(&seq);
        let seq3 = Arc::clone(&seq);
        let jh1 = tokio::task::spawn(async move {
            seq2.wait_for(42).await.expect("wait_for 42");
            let old = seq2.advance(100);
            assert_eq!(old, 99);
            // Nobody ever advances to 999, so this must be a timeout
            // (and specifically not a shutdown).
            assert_eq!(
                seq2.wait_for_timeout(999, Duration::from_millis(100)).await,
                Err(SeqWaitError::Timeout),
                "no 999"
            );
        });
        let jh2 = tokio::task::spawn(async move {
            seq3.wait_for(42).await.expect("wait_for 42");
            seq3.wait_for(0).await.expect("wait_for 0");
        });
        tokio::time::sleep(Duration::from_millis(200)).await;
        let old = seq.advance(99);
        assert_eq!(old, 0);
        seq.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(seq.advance(98), 100);
        assert_eq!(seq.load(), 100);

        jh1.await.unwrap();
        jh2.await.unwrap();

        seq.shutdown();
    }

    #[tokio::test]
    async fn seqwait_timeout() {
        let seq = Arc::new(SeqWait::new(0));
        let seq2 = Arc::clone(&seq);
        let jh = tokio::task::spawn(async move {
            let timeout = Duration::from_millis(1);
            let res = seq2.wait_for_timeout(42, timeout).await;
            assert_eq!(res, Err(SeqWaitError::Timeout));
        });
        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let old = seq.advance(99);
        assert_eq!(old, 0);
        jh.await.unwrap();

        seq.shutdown();
    }

    /// Present and future waiters for unreached numbers fail on shutdown;
    /// numbers that already arrived still succeed.
    #[tokio::test]
    async fn seqwait_shutdown() {
        let seq = Arc::new(SeqWait::new(0));
        let mut status = seq.status_receiver();

        let seq2 = Arc::clone(&seq);
        let jh = tokio::task::spawn(async move { seq2.wait_for(42).await });

        // Wait until the spawned task has registered itself as a waiter, so
        // that shutdown really hits a waiter that is already queued.
        status.changed().await.expect("status sender alive");
        assert_eq!(*status.borrow_and_update(), Some(42));

        seq.shutdown();
        assert_eq!(jh.await.unwrap(), Err(SeqWaitError::Shutdown));
        assert_eq!(*status.borrow(), None);

        // New waits for numbers that have not arrived fail immediately...
        assert_eq!(seq.wait_for(1).await, Err(SeqWaitError::Shutdown));
        assert_eq!(
            seq.wait_for_timeout(1, Duration::from_secs(10)).await,
            Err(SeqWaitError::Shutdown)
        );
        // ...while numbers that already arrived are still reported as such.
        assert_eq!(seq.wait_for(0).await, Ok(()));
    }

    /// `would_wait_for` and `status_receiver` follow the earliest pending waiter.
    #[tokio::test]
    async fn seqwait_would_wait_for_and_status() {
        let seq = Arc::new(SeqWait::new(0));
        assert_eq!(seq.would_wait_for(0), Ok(()));
        assert_eq!(seq.would_wait_for(5), Err(0));

        let mut status = seq.status_receiver();
        assert_eq!(*status.borrow(), None);

        let seq2 = Arc::clone(&seq);
        let jh = tokio::task::spawn(async move { seq2.wait_for(5).await });

        // The spawned waiter registering itself publishes its number.
        status.changed().await.expect("status sender alive");
        assert_eq!(*status.borrow_and_update(), Some(5));

        assert_eq!(seq.advance(5), 0);
        assert_eq!(jh.await.unwrap(), Ok(()));
        assert_eq!(*status.borrow(), None);
        assert_eq!(seq.would_wait_for(5), Ok(()));
    }
}
|