Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering, PartialOrd};
4 : use std::collections::BinaryHeap;
5 : use std::fmt::Debug;
6 : use std::mem;
7 : use std::sync::Mutex;
8 : use std::time::Duration;
9 : use tokio::sync::watch::{channel, Receiver, Sender};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 2 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A value that only ever moves forward
///
/// `SeqWait<S, V>` keeps its counter storage `S` under a mutex, and it is
/// often convenient to keep a few related fields under that same mutex
/// (e.g. prev_record_lsn). For that reason the storage may be any type that
/// can expose a counter; `V` is the type of the exposed counter.
pub trait MonotonicCounter<V> {
    /// Return the current counter value
    fn cnt_value(&self) -> V;

    /// Move the counter to `new_val`, checking that it does not go backwards
    ///
    /// N.B.: `new_val` is the new absolute value, not an increment.
    fn cnt_advance(&mut self, new_val: V);
}
37 :
38 : /// Internal components of a `SeqWait`
39 : struct SeqWaitInt<S, V>
40 : where
41 : S: MonotonicCounter<V>,
42 : V: Ord,
43 : {
44 : waiters: BinaryHeap<Waiter<V>>,
45 : current: S,
46 : shutdown: bool,
47 : }
48 :
49 : struct Waiter<T>
50 : where
51 : T: Ord,
52 : {
53 : wake_num: T, // wake me when this number arrives ...
54 : wake_channel: Sender<()>, // ... by sending a message to this channel
55 : }
56 :
57 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
58 : // to get that.
59 : impl<T: Ord> PartialOrd for Waiter<T> {
60 2 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 2 : Some(self.cmp(other))
62 2 : }
63 : }
64 :
65 : impl<T: Ord> Ord for Waiter<T> {
66 2 : fn cmp(&self, other: &Self) -> Ordering {
67 2 : other.wake_num.cmp(&self.wake_num)
68 2 : }
69 : }
70 :
71 : impl<T: Ord> PartialEq for Waiter<T> {
72 0 : fn eq(&self, other: &Self) -> bool {
73 0 : other.wake_num == self.wake_num
74 0 : }
75 : }
76 :
77 : impl<T: Ord> Eq for Waiter<T> {}
78 :
79 : /// A tool for waiting on a sequence number
80 : ///
81 : /// This provides a way to wait the arrival of a number.
82 : /// As soon as the number arrives by another caller calling
83 : /// [`advance`], then the waiter will be woken up.
84 : ///
85 : /// This implementation takes a blocking Mutex on both [`wait_for`]
86 : /// and [`advance`], meaning there may be unexpected executor blocking
87 : /// due to thread scheduling unfairness. There are probably better
88 : /// implementations, but we can probably live with this for now.
89 : ///
90 : /// [`wait_for`]: SeqWait::wait_for
91 : /// [`advance`]: SeqWait::advance
92 : ///
93 : /// `S` means Storage, `V` is type of counter that this storage exposes.
94 : ///
95 : pub struct SeqWait<S, V>
96 : where
97 : S: MonotonicCounter<V>,
98 : V: Ord,
99 : {
100 : internal: Mutex<SeqWaitInt<S, V>>,
101 : }
102 :
103 : impl<S, V> SeqWait<S, V>
104 : where
105 : S: MonotonicCounter<V> + Copy,
106 : V: Ord + Copy,
107 : {
108 : /// Create a new `SeqWait`, initialized to a particular number
109 304 : pub fn new(starting_num: S) -> Self {
110 304 : let internal = SeqWaitInt {
111 304 : waiters: BinaryHeap::new(),
112 304 : current: starting_num,
113 304 : shutdown: false,
114 304 : };
115 304 : SeqWait {
116 304 : internal: Mutex::new(internal),
117 304 : }
118 304 : }
119 :
120 : /// Shut down a `SeqWait`, causing all waiters (present and
121 : /// future) to return an error.
122 18 : pub fn shutdown(&self) {
123 18 : let waiters = {
124 18 : // Prevent new waiters; wake all those that exist.
125 18 : // Wake everyone with an error.
126 18 : let mut internal = self.internal.lock().unwrap();
127 18 :
128 18 : // Block any future waiters from starting
129 18 : internal.shutdown = true;
130 18 :
131 18 : // This will steal the entire waiters map.
132 18 : // When we drop it all waiters will be woken.
133 18 : mem::take(&mut internal.waiters)
134 18 :
135 18 : // Drop the lock as we exit this scope.
136 18 : };
137 18 :
138 18 : // When we drop the waiters list, each Receiver will
139 18 : // be woken with an error.
140 18 : // This drop doesn't need to be explicit; it's done
141 18 : // here to make it easier to read the code and understand
142 18 : // the order of events.
143 18 : drop(waiters);
144 18 : }
145 :
146 : /// Wait for a number to arrive
147 : ///
148 : /// This call won't complete until someone has called `advance`
149 : /// with a number greater than or equal to the one we're waiting for.
150 : ///
151 : /// This function is async cancellation-safe.
152 8 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
153 8 : match self.queue_for_wait(num) {
154 2 : Ok(None) => Ok(()),
155 6 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
156 0 : Err(e) => Err(e),
157 : }
158 8 : }
159 :
160 : /// Wait for a number to arrive
161 : ///
162 : /// This call won't complete until someone has called `advance`
163 : /// with a number greater than or equal to the one we're waiting for.
164 : ///
165 : /// If that hasn't happened after the specified timeout duration,
166 : /// [`SeqWaitError::Timeout`] will be returned.
167 : ///
168 : /// This function is async cancellation-safe.
169 226771 : pub async fn wait_for_timeout(
170 226771 : &self,
171 226771 : num: V,
172 226771 : timeout_duration: Duration,
173 226771 : ) -> Result<(), SeqWaitError> {
174 226771 : match self.queue_for_wait(num) {
175 226767 : Ok(None) => Ok(()),
176 4 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
177 0 : Ok(Ok(())) => Ok(()),
178 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
179 4 : Err(_) => Err(SeqWaitError::Timeout),
180 : },
181 0 : Err(e) => Err(e),
182 : }
183 226771 : }
184 :
185 : /// Register and return a channel that will be notified when a number arrives,
186 : /// or None, if it has already arrived.
187 226779 : fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
188 226779 : let mut internal = self.internal.lock().unwrap();
189 226779 : if internal.current.cnt_value() >= num {
190 226769 : return Ok(None);
191 10 : }
192 10 : if internal.shutdown {
193 0 : return Err(SeqWaitError::Shutdown);
194 10 : }
195 10 :
196 10 : // Create a new channel.
197 10 : let (tx, rx) = channel(());
198 10 : internal.waiters.push(Waiter {
199 10 : wake_num: num,
200 10 : wake_channel: tx,
201 10 : });
202 10 : // Drop the lock as we exit this scope.
203 10 : Ok(Some(rx))
204 226779 : }
205 :
206 : /// Announce a new number has arrived
207 : ///
208 : /// All waiters at this value or below will be woken.
209 : ///
210 : /// Returns the old number.
211 3102920 : pub fn advance(&self, num: V) -> V {
212 : let old_value;
213 2628230 : let wake_these = {
214 3102920 : let mut internal = self.internal.lock().unwrap();
215 3102920 :
216 3102920 : old_value = internal.current.cnt_value();
217 3102920 : if old_value >= num {
218 474690 : return old_value;
219 2628230 : }
220 2628230 : internal.current.cnt_advance(num);
221 2628230 :
222 2628230 : // Pop all waiters <= num from the heap. Collect them in a vector, and
223 2628230 : // wake them up after releasing the lock.
224 2628230 : let mut wake_these = Vec::new();
225 2628238 : while let Some(n) = internal.waiters.peek() {
226 8 : if n.wake_num > num {
227 0 : break;
228 8 : }
229 8 : wake_these.push(internal.waiters.pop().unwrap().wake_channel);
230 : }
231 2628230 : wake_these
232 : };
233 :
234 2628238 : for tx in wake_these {
235 8 : // This can fail if there are no receivers.
236 8 : // We don't care; discard the error.
237 8 : let _ = tx.send(());
238 8 : }
239 2628230 : old_value
240 3102920 : }
241 :
242 : /// Read the current value, without waiting.
243 2901270 : pub fn load(&self) -> S {
244 2901270 : self.internal.lock().unwrap().current
245 2901270 : }
246 : }
247 :
248 : #[cfg(test)]
249 : mod tests {
250 : use super::*;
251 : use std::sync::Arc;
252 : use std::time::Duration;
253 :
254 : impl MonotonicCounter<i32> for i32 {
255 6 : fn cnt_advance(&mut self, val: i32) {
256 6 : assert!(*self <= val);
257 6 : *self = val;
258 6 : }
259 20 : fn cnt_value(&self) -> i32 {
260 20 : *self
261 20 : }
262 : }
263 :
264 2 : #[tokio::test]
265 2 : async fn seqwait() {
266 2 : let seq = Arc::new(SeqWait::new(0));
267 2 : let seq2 = Arc::clone(&seq);
268 2 : let seq3 = Arc::clone(&seq);
269 2 : let jh1 = tokio::task::spawn(async move {
270 2 : seq2.wait_for(42).await.expect("wait_for 42");
271 2 : let old = seq2.advance(100);
272 2 : assert_eq!(old, 99);
273 2 : seq2.wait_for_timeout(999, Duration::from_millis(100))
274 2 : .await
275 2 : .expect_err("no 999");
276 2 : });
277 2 : let jh2 = tokio::task::spawn(async move {
278 2 : seq3.wait_for(42).await.expect("wait_for 42");
279 2 : seq3.wait_for(0).await.expect("wait_for 0");
280 2 : });
281 2 : tokio::time::sleep(Duration::from_millis(200)).await;
282 2 : let old = seq.advance(99);
283 2 : assert_eq!(old, 0);
284 2 : seq.wait_for(100).await.expect("wait_for 100");
285 2 :
286 2 : // Calling advance with a smaller value is a no-op
287 2 : assert_eq!(seq.advance(98), 100);
288 2 : assert_eq!(seq.load(), 100);
289 2 :
290 2 : jh1.await.unwrap();
291 2 : jh2.await.unwrap();
292 2 :
293 2 : seq.shutdown();
294 2 : }
295 :
296 2 : #[tokio::test]
297 2 : async fn seqwait_timeout() {
298 2 : let seq = Arc::new(SeqWait::new(0));
299 2 : let seq2 = Arc::clone(&seq);
300 2 : let jh = tokio::task::spawn(async move {
301 2 : let timeout = Duration::from_millis(1);
302 2 : let res = seq2.wait_for_timeout(42, timeout).await;
303 2 : assert_eq!(res, Err(SeqWaitError::Timeout));
304 2 : });
305 2 : tokio::time::sleep(Duration::from_millis(200)).await;
306 2 : // This will attempt to wake, but nothing will happen
307 2 : // because the waiter already dropped its Receiver.
308 2 : let old = seq.advance(99);
309 2 : assert_eq!(old, 0);
310 2 : jh.await.unwrap();
311 2 :
312 2 : seq.shutdown();
313 2 : }
314 : }
|