Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering};
4 : use std::collections::BinaryHeap;
5 : use std::mem;
6 : use std::sync::Mutex;
7 : use std::time::Duration;
8 :
9 : use tokio::sync::watch::{self, channel};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A value that only ever moves forward.
///
/// `SeqWait<S>` keeps its state under a single mutex, and it is often handy
/// to store other fields next to the counter under that same mutex
/// (e.g. prev_record_lsn). For that reason `SeqWait` accepts any storage type
/// that can expose a counter; `V` is the type of the exposed counter.
pub trait MonotonicCounter<V> {
    /// Get the current counter value.
    fn cnt_value(&self) -> V;

    /// Bump the counter to `new_val` and check that it goes forward.
    ///
    /// N.B.: `new_val` is the actual new value, not a difference.
    fn cnt_advance(&mut self, new_val: V);
}
37 :
38 : /// Heap of waiters, lowest numbers pop first.
39 : struct Waiters<V>
40 : where
41 : V: Ord,
42 : {
43 : heap: BinaryHeap<Waiter<V>>,
44 : /// Number of the first waiter in the heap, or None if there are no waiters.
45 : status_channel: watch::Sender<Option<V>>,
46 : }
47 :
48 : impl<V> Waiters<V>
49 : where
50 : V: Ord + Copy,
51 : {
52 26592 : fn new() -> Self {
53 26592 : Waiters {
54 26592 : heap: BinaryHeap::new(),
55 26592 : status_channel: channel(None).0,
56 26592 : }
57 26592 : }
58 :
59 : /// `status_channel` contains the number of the first waiter in the heap.
60 : /// This function should be called whenever waiters heap changes.
61 30 : fn update_status(&self) {
62 30 : let first_waiter = self.heap.peek().map(|w| w.wake_num);
63 30 : let _ = self.status_channel.send_replace(first_waiter);
64 30 : }
65 :
66 : /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
67 5 : fn add(&mut self, num: V) -> watch::Receiver<()> {
68 5 : let (tx, rx) = channel(());
69 5 : self.heap.push(Waiter {
70 5 : wake_num: num,
71 5 : wake_channel: tx,
72 5 : });
73 5 : self.update_status();
74 5 : rx
75 5 : }
76 :
77 : /// Pop all waiters <= num from the heap. Collect channels in a vector,
78 : /// so that caller can wake them up.
79 9608923 : fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
80 9608923 : let mut wake_these = Vec::new();
81 9608927 : while let Some(n) = self.heap.peek() {
82 4 : if n.wake_num > num {
83 0 : break;
84 4 : }
85 4 : wake_these.push(self.heap.pop().unwrap().wake_channel);
86 : }
87 9608923 : if !wake_these.is_empty() {
88 3 : self.update_status();
89 3 : }
90 9608923 : wake_these
91 9608923 : }
92 :
93 : /// Used on shutdown to efficiently drop all waiters.
94 22 : fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
95 22 : let heap = mem::take(&mut self.heap);
96 22 : self.update_status();
97 22 : heap
98 22 : }
99 : }
100 :
101 : struct Waiter<T>
102 : where
103 : T: Ord,
104 : {
105 : wake_num: T, // wake me when this number arrives ...
106 : wake_channel: watch::Sender<()>, // ... by sending a message to this channel
107 : }
108 :
109 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
110 : // to get that.
111 : impl<T: Ord> PartialOrd for Waiter<T> {
112 1 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
113 1 : Some(self.cmp(other))
114 1 : }
115 : }
116 :
117 : impl<T: Ord> Ord for Waiter<T> {
118 1 : fn cmp(&self, other: &Self) -> Ordering {
119 1 : other.wake_num.cmp(&self.wake_num)
120 1 : }
121 : }
122 :
123 : impl<T: Ord> PartialEq for Waiter<T> {
124 0 : fn eq(&self, other: &Self) -> bool {
125 0 : other.wake_num == self.wake_num
126 0 : }
127 : }
128 :
129 : impl<T: Ord> Eq for Waiter<T> {}
130 :
131 : /// Internal components of a `SeqWait`
132 : struct SeqWaitInt<S, V>
133 : where
134 : S: MonotonicCounter<V>,
135 : V: Ord,
136 : {
137 : waiters: Waiters<V>,
138 : current: S,
139 : shutdown: bool,
140 : }
141 :
142 : /// A tool for waiting on a sequence number
143 : ///
144 : /// This provides a way to wait the arrival of a number.
145 : /// As soon as the number arrives by another caller calling
146 : /// [`advance`], then the waiter will be woken up.
147 : ///
148 : /// This implementation takes a blocking Mutex on both [`wait_for`]
149 : /// and [`advance`], meaning there may be unexpected executor blocking
150 : /// due to thread scheduling unfairness. There are probably better
151 : /// implementations, but we can probably live with this for now.
152 : ///
153 : /// [`wait_for`]: SeqWait::wait_for
154 : /// [`advance`]: SeqWait::advance
155 : ///
156 : /// `S` means Storage, `V` is type of counter that this storage exposes.
157 : ///
158 : pub struct SeqWait<S, V>
159 : where
160 : S: MonotonicCounter<V>,
161 : V: Ord,
162 : {
163 : internal: Mutex<SeqWaitInt<S, V>>,
164 : }
165 :
166 : impl<S, V> SeqWait<S, V>
167 : where
168 : S: MonotonicCounter<V> + Copy,
169 : V: Ord + Copy,
170 : {
171 : /// Create a new `SeqWait`, initialized to a particular number
172 26592 : pub fn new(starting_num: S) -> Self {
173 26592 : let internal = SeqWaitInt {
174 26592 : waiters: Waiters::new(),
175 26592 : current: starting_num,
176 26592 : shutdown: false,
177 26592 : };
178 26592 : SeqWait {
179 26592 : internal: Mutex::new(internal),
180 26592 : }
181 26592 : }
182 :
183 : /// Shut down a `SeqWait`, causing all waiters (present and
184 : /// future) to return an error.
185 22 : pub fn shutdown(&self) {
186 22 : let waiters = {
187 22 : // Prevent new waiters; wake all those that exist.
188 22 : // Wake everyone with an error.
189 22 : let mut internal = self.internal.lock().unwrap();
190 22 :
191 22 : // Block any future waiters from starting
192 22 : internal.shutdown = true;
193 22 :
194 22 : // Take all waiters to drop them later.
195 22 : internal.waiters.take_all()
196 22 :
197 22 : // Drop the lock as we exit this scope.
198 22 : };
199 22 :
200 22 : // When we drop the waiters list, each Receiver will
201 22 : // be woken with an error.
202 22 : // This drop doesn't need to be explicit; it's done
203 22 : // here to make it easier to read the code and understand
204 22 : // the order of events.
205 22 : drop(waiters);
206 22 : }
207 :
208 : /// Wait for a number to arrive
209 : ///
210 : /// This call won't complete until someone has called `advance`
211 : /// with a number greater than or equal to the one we're waiting for.
212 : ///
213 : /// This function is async cancellation-safe.
214 4 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
215 4 : match self.queue_for_wait(num) {
216 1 : Ok(None) => Ok(()),
217 3 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
218 0 : Err(e) => Err(e),
219 : }
220 4 : }
221 :
222 : /// Wait for a number to arrive
223 : ///
224 : /// This call won't complete until someone has called `advance`
225 : /// with a number greater than or equal to the one we're waiting for.
226 : ///
227 : /// If that hasn't happened after the specified timeout duration,
228 : /// [`SeqWaitError::Timeout`] will be returned.
229 : ///
230 : /// This function is async cancellation-safe.
231 449973 : pub async fn wait_for_timeout(
232 449973 : &self,
233 449973 : num: V,
234 449973 : timeout_duration: Duration,
235 449973 : ) -> Result<(), SeqWaitError> {
236 449973 : match self.queue_for_wait(num) {
237 449971 : Ok(None) => Ok(()),
238 2 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
239 0 : Ok(Ok(())) => Ok(()),
240 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
241 2 : Err(_) => Err(SeqWaitError::Timeout),
242 : },
243 0 : Err(e) => Err(e),
244 : }
245 2 : }
246 :
247 : /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
248 0 : pub fn would_wait_for(&self, num: V) -> Result<(), V> {
249 0 : let internal = self.internal.lock().unwrap();
250 0 : let cnt = internal.current.cnt_value();
251 0 : drop(internal);
252 0 : if cnt >= num { Ok(()) } else { Err(cnt) }
253 0 : }
254 :
255 : /// Register and return a channel that will be notified when a number arrives,
256 : /// or None, if it has already arrived.
257 449977 : fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
258 449977 : let mut internal = self.internal.lock().unwrap();
259 449977 : if internal.current.cnt_value() >= num {
260 449972 : return Ok(None);
261 5 : }
262 5 : if internal.shutdown {
263 0 : return Err(SeqWaitError::Shutdown);
264 5 : }
265 5 :
266 5 : // Add waiter channel to the queue.
267 5 : let rx = internal.waiters.add(num);
268 5 : // Drop the lock as we exit this scope.
269 5 : Ok(Some(rx))
270 6 : }
271 :
272 : /// Announce a new number has arrived
273 : ///
274 : /// All waiters at this value or below will be woken.
275 : ///
276 : /// Returns the old number.
277 10558300 : pub fn advance(&self, num: V) -> V {
278 : let old_value;
279 9608923 : let wake_these = {
280 10558300 : let mut internal = self.internal.lock().unwrap();
281 10558300 :
282 10558300 : old_value = internal.current.cnt_value();
283 10558300 : if old_value >= num {
284 949377 : return old_value;
285 9608923 : }
286 9608923 : internal.current.cnt_advance(num);
287 9608923 :
288 9608923 : // Pop all waiters <= num from the heap.
289 9608923 : internal.waiters.pop_leq(num)
290 : };
291 :
292 9608927 : for tx in wake_these {
293 4 : // This can fail if there are no receivers.
294 4 : // We don't care; discard the error.
295 4 : let _ = tx.send(());
296 4 : }
297 9608923 : old_value
298 4 : }
299 :
300 : /// Read the current value, without waiting.
301 551837 : pub fn load(&self) -> S {
302 551837 : self.internal.lock().unwrap().current
303 551837 : }
304 :
305 : /// Get a Receiver for the current status.
306 : ///
307 : /// The current status is the number of the first waiter in the queue,
308 : /// or None if there are no waiters.
309 : ///
310 : /// This receiver will be notified whenever the status changes.
311 : /// It is useful for receiving notifications when the first waiter
312 : /// starts waiting for a number, or when there are no more waiters left.
313 0 : pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
314 0 : self.internal
315 0 : .lock()
316 0 : .unwrap()
317 0 : .waiters
318 0 : .status_channel
319 0 : .subscribe()
320 0 : }
321 : }
322 :
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    // A plain `i32` serves as its own counter storage in these tests.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }
        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    #[tokio::test]
    async fn seqwait() {
        let seq = Arc::new(SeqWait::new(0));

        // First task: waits for 42, pushes the counter on to 100, and then
        // times out waiting for a number that never arrives.
        let first_task = tokio::task::spawn({
            let seq = Arc::clone(&seq);
            async move {
                seq.wait_for(42).await.expect("wait_for 42");
                let old = seq.advance(100);
                assert_eq!(old, 99);
                seq.wait_for_timeout(999, Duration::from_millis(100))
                    .await
                    .expect_err("no 999");
            }
        });

        // Second task: waits for 42 and then for a number that is already there.
        let second_task = tokio::task::spawn({
            let seq = Arc::clone(&seq);
            async move {
                seq.wait_for(42).await.expect("wait_for 42");
                seq.wait_for(0).await.expect("wait_for 0");
            }
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        let old = seq.advance(99);
        assert_eq!(old, 0);
        seq.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(seq.advance(98), 100);
        assert_eq!(seq.load(), 100);

        first_task.await.unwrap();
        second_task.await.unwrap();

        seq.shutdown();
    }

    #[tokio::test]
    async fn seqwait_timeout() {
        let seq = Arc::new(SeqWait::new(0));

        let waiter_task = tokio::task::spawn({
            let seq = Arc::clone(&seq);
            async move {
                let res = seq.wait_for_timeout(42, Duration::from_millis(1)).await;
                assert_eq!(res, Err(SeqWaitError::Timeout));
            }
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let old = seq.advance(99);
        assert_eq!(old, 0);
        waiter_task.await.unwrap();

        seq.shutdown();
    }
}
|