Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering, PartialOrd};
4 : use std::collections::BinaryHeap;
5 : use std::fmt::Debug;
6 : use std::mem;
7 : use std::sync::Mutex;
8 : use std::time::Duration;
9 : use tokio::sync::watch::{channel, Receiver, Sender};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 2 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A value that is only ever moved forward.
///
/// `SeqWait<S, V>` keeps its state behind a single mutex, and it is often
/// convenient to keep a few related fields (e.g. prev_record_lsn) under that
/// same lock. To allow that, `SeqWait` accepts any storage type `S` that can
/// expose a counter; `V` is the type of the exposed counter.
pub trait MonotonicCounter<V> {
    /// Return the counter's current value.
    fn cnt_value(&self) -> V;

    /// Move the counter to `new_val`, checking that this is a forward step.
    ///
    /// N.B.: `new_val` is the new absolute value, not an increment.
    fn cnt_advance(&mut self, new_val: V);
}
37 :
38 : /// Internal components of a `SeqWait`
39 : struct SeqWaitInt<S, V>
40 : where
41 : S: MonotonicCounter<V>,
42 : V: Ord,
43 : {
44 : waiters: BinaryHeap<Waiter<V>>,
45 : current: S,
46 : shutdown: bool,
47 : }
48 :
49 : struct Waiter<T>
50 : where
51 : T: Ord,
52 : {
53 : wake_num: T, // wake me when this number arrives ...
54 : wake_channel: Sender<()>, // ... by sending a message to this channel
55 : }
56 :
57 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
58 : // to get that.
59 : impl<T: Ord> PartialOrd for Waiter<T> {
60 96551 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 96551 : Some(self.cmp(other))
62 96551 : }
63 : }
64 :
65 : impl<T: Ord> Ord for Waiter<T> {
66 96551 : fn cmp(&self, other: &Self) -> Ordering {
67 96551 : other.wake_num.cmp(&self.wake_num)
68 96551 : }
69 : }
70 :
71 : impl<T: Ord> PartialEq for Waiter<T> {
72 0 : fn eq(&self, other: &Self) -> bool {
73 0 : other.wake_num == self.wake_num
74 0 : }
75 : }
76 :
77 : impl<T: Ord> Eq for Waiter<T> {}
78 :
79 : /// A tool for waiting on a sequence number
80 : ///
81 : /// This provides a way to wait the arrival of a number.
82 : /// As soon as the number arrives by another caller calling
83 : /// [`advance`], then the waiter will be woken up.
84 : ///
85 : /// This implementation takes a blocking Mutex on both [`wait_for`]
86 : /// and [`advance`], meaning there may be unexpected executor blocking
87 : /// due to thread scheduling unfairness. There are probably better
88 : /// implementations, but we can probably live with this for now.
89 : ///
90 : /// [`wait_for`]: SeqWait::wait_for
91 : /// [`advance`]: SeqWait::advance
92 : ///
93 : /// `S` means Storage, `V` is type of counter that this storage exposes.
94 : ///
95 : pub struct SeqWait<S, V>
96 : where
97 : S: MonotonicCounter<V>,
98 : V: Ord,
99 : {
100 : internal: Mutex<SeqWaitInt<S, V>>,
101 : }
102 :
103 : impl<S, V> SeqWait<S, V>
104 : where
105 : S: MonotonicCounter<V> + Copy,
106 : V: Ord + Copy,
107 : {
108 : /// Create a new `SeqWait`, initialized to a particular number
109 2584 : pub fn new(starting_num: S) -> Self {
110 2584 : let internal = SeqWaitInt {
111 2584 : waiters: BinaryHeap::new(),
112 2584 : current: starting_num,
113 2584 : shutdown: false,
114 2584 : };
115 2584 : SeqWait {
116 2584 : internal: Mutex::new(internal),
117 2584 : }
118 2584 : }
119 :
120 : /// Shut down a `SeqWait`, causing all waiters (present and
121 : /// future) to return an error.
122 830 : pub fn shutdown(&self) {
123 830 : let waiters = {
124 830 : // Prevent new waiters; wake all those that exist.
125 830 : // Wake everyone with an error.
126 830 : let mut internal = self.internal.lock().unwrap();
127 830 :
128 830 : // Block any future waiters from starting
129 830 : internal.shutdown = true;
130 830 :
131 830 : // This will steal the entire waiters map.
132 830 : // When we drop it all waiters will be woken.
133 830 : mem::take(&mut internal.waiters)
134 830 :
135 830 : // Drop the lock as we exit this scope.
136 830 : };
137 830 :
138 830 : // When we drop the waiters list, each Receiver will
139 830 : // be woken with an error.
140 830 : // This drop doesn't need to be explicit; it's done
141 830 : // here to make it easier to read the code and understand
142 830 : // the order of events.
143 830 : drop(waiters);
144 830 : }
145 :
146 : /// Wait for a number to arrive
147 : ///
148 : /// This call won't complete until someone has called `advance`
149 : /// with a number greater than or equal to the one we're waiting for.
150 : ///
151 : /// This function is async cancellation-safe.
152 492 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
153 485 : match self.queue_for_wait(num) {
154 2 : Ok(None) => Ok(()),
155 483 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
156 0 : Err(e) => Err(e),
157 : }
158 9 : }
159 :
160 : /// Wait for a number to arrive
161 : ///
162 : /// This call won't complete until someone has called `advance`
163 : /// with a number greater than or equal to the one we're waiting for.
164 : ///
165 : /// If that hasn't happened after the specified timeout duration,
166 : /// [`SeqWaitError::Timeout`] will be returned.
167 : ///
168 : /// This function is async cancellation-safe.
169 1358731 : pub async fn wait_for_timeout(
170 1358731 : &self,
171 1358731 : num: V,
172 1358731 : timeout_duration: Duration,
173 1358731 : ) -> Result<(), SeqWaitError> {
174 1358731 : match self.queue_for_wait(num) {
175 1298941 : Ok(None) => Ok(()),
176 113027 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
177 59760 : Ok(Ok(())) => Ok(()),
178 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
179 28 : Err(_) => Err(SeqWaitError::Timeout),
180 : },
181 0 : Err(e) => Err(e),
182 : }
183 1358729 : }
184 :
185 : /// Register and return a channel that will be notified when a number arrives,
186 : /// or None, if it has already arrived.
187 1359216 : fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
188 1359216 : let mut internal = self.internal.lock().unwrap();
189 1359216 : if internal.current.cnt_value() >= num {
190 1298943 : return Ok(None);
191 60273 : }
192 60273 : if internal.shutdown {
193 0 : return Err(SeqWaitError::Shutdown);
194 60273 : }
195 60273 :
196 60273 : // Create a new channel.
197 60273 : let (tx, rx) = channel(());
198 60273 : internal.waiters.push(Waiter {
199 60273 : wake_num: num,
200 60273 : wake_channel: tx,
201 60273 : });
202 60273 : // Drop the lock as we exit this scope.
203 60273 : Ok(Some(rx))
204 1359216 : }
205 :
206 : /// Announce a new number has arrived
207 : ///
208 : /// All waiters at this value or below will be woken.
209 : ///
210 : /// Returns the old number.
211 76570822 : pub fn advance(&self, num: V) -> V {
212 : let old_value;
213 74771397 : let wake_these = {
214 76570822 : let mut internal = self.internal.lock().unwrap();
215 76570822 :
216 76570822 : old_value = internal.current.cnt_value();
217 76570822 : if old_value >= num {
218 1799425 : return old_value;
219 74771397 : }
220 74771397 : internal.current.cnt_advance(num);
221 74771397 :
222 74771397 : // Pop all waiters <= num from the heap. Collect them in a vector, and
223 74771397 : // wake them up after releasing the lock.
224 74771397 : let mut wake_these = Vec::new();
225 74831166 : while let Some(n) = internal.waiters.peek() {
226 12310981 : if n.wake_num > num {
227 12251212 : break;
228 59769 : }
229 59769 : wake_these.push(internal.waiters.pop().unwrap().wake_channel);
230 : }
231 74771397 : wake_these
232 : };
233 :
234 74831166 : for tx in wake_these {
235 59769 : // This can fail if there are no receivers.
236 59769 : // We don't care; discard the error.
237 59769 : let _ = tx.send(());
238 59769 : }
239 74771397 : old_value
240 76570822 : }
241 :
242 : /// Read the current value, without waiting.
243 10871532 : pub fn load(&self) -> S {
244 10871532 : self.internal.lock().unwrap().current
245 10871532 : }
246 : }
247 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    // A plain `i32` serves as both the storage and the counter in the tests.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }
        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    /// Two tasks wait for 42; the main task advances to 99 and one of the
    /// tasks advances further to 100.
    #[tokio::test]
    async fn seqwait() {
        let seq = Arc::new(SeqWait::new(0));

        let first = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                seq.wait_for(42).await.expect("wait_for 42");
                let old = seq.advance(100);
                assert_eq!(old, 99);
                seq.wait_for_timeout(999, Duration::from_millis(100))
                    .await
                    .expect_err("no 999");
            })
        };
        let second = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                seq.wait_for(42).await.expect("wait_for 42");
                seq.wait_for(0).await.expect("wait_for 0");
            })
        };

        // Give both tasks time to queue up before the number arrives.
        tokio::time::sleep(Duration::from_millis(200)).await;
        let old = seq.advance(99);
        assert_eq!(old, 0);
        seq.wait_for(100).await.expect("wait_for 100");

        // Calling advance with a smaller value is a no-op
        assert_eq!(seq.advance(98), 100);
        assert_eq!(seq.load(), 100);

        first.await.unwrap();
        second.await.unwrap();

        seq.shutdown();
    }

    /// A waiter with a 1 ms timeout gives up long before the number arrives.
    #[tokio::test]
    async fn seqwait_timeout() {
        let seq = Arc::new(SeqWait::new(0));

        let waiter = {
            let seq = Arc::clone(&seq);
            tokio::task::spawn(async move {
                let brief = Duration::from_millis(1);
                let res = seq.wait_for_timeout(42, brief).await;
                assert_eq!(res, Err(SeqWaitError::Timeout));
            })
        };

        tokio::time::sleep(Duration::from_millis(200)).await;
        // This will attempt to wake, but nothing will happen
        // because the waiter already dropped its Receiver.
        let old = seq.advance(99);
        assert_eq!(old, 0);
        waiter.await.unwrap();

        seq.shutdown();
    }
}
|