TLA Line data Source code
1 : #![warn(missing_docs)]
2 :
3 : use std::cmp::{Eq, Ordering, PartialOrd};
4 : use std::collections::BinaryHeap;
5 : use std::fmt::Debug;
6 : use std::mem;
7 : use std::sync::Mutex;
8 : use std::time::Duration;
9 : use tokio::sync::watch::{channel, Receiver, Sender};
10 : use tokio::time::timeout;
11 :
12 : /// An error happened while waiting for a number
13 CBC 4 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
14 : pub enum SeqWaitError {
15 : /// The wait timeout was reached
16 : #[error("seqwait timeout was reached")]
17 : Timeout,
18 :
19 : /// [`SeqWait::shutdown`] was called
20 : #[error("SeqWait::shutdown was called")]
21 : Shutdown,
22 : }
23 :
/// A counter whose value only ever moves forward.
///
/// `SeqWait<S, V>` is generic over its storage `S` instead of holding a bare
/// number. This lets other related fields (e.g. `prev_record_lsn`) live under
/// the same mutex as the counter. Any storage type can be used as long as it
/// exposes a counter of type `V` through this trait.
pub trait MonotonicCounter<V> {
    /// Move the counter to `new_val`, verifying that it does not go backwards.
    ///
    /// `new_val` is the new absolute value, not an increment.
    fn cnt_advance(&mut self, new_val: V);

    /// Return the counter's current value.
    fn cnt_value(&self) -> V;
}
37 :
38 : /// Internal components of a `SeqWait`
39 : struct SeqWaitInt<S, V>
40 : where
41 : S: MonotonicCounter<V>,
42 : V: Ord,
43 : {
44 : waiters: BinaryHeap<Waiter<V>>,
45 : current: S,
46 : shutdown: bool,
47 : }
48 :
49 : struct Waiter<T>
50 : where
51 : T: Ord,
52 : {
53 : wake_num: T, // wake me when this number arrives ...
54 : wake_channel: Sender<()>, // ... by sending a message to this channel
55 : }
56 :
57 : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
58 : // to get that.
59 : impl<T: Ord> PartialOrd for Waiter<T> {
60 90426 : fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
61 90426 : Some(self.cmp(other))
62 90426 : }
63 : }
64 :
65 : impl<T: Ord> Ord for Waiter<T> {
66 90426 : fn cmp(&self, other: &Self) -> Ordering {
67 90426 : other.wake_num.cmp(&self.wake_num)
68 90426 : }
69 : }
70 :
71 : impl<T: Ord> PartialEq for Waiter<T> {
72 UBC 0 : fn eq(&self, other: &Self) -> bool {
73 0 : other.wake_num == self.wake_num
74 0 : }
75 : }
76 :
77 : impl<T: Ord> Eq for Waiter<T> {}
78 :
79 : /// A tool for waiting on a sequence number
80 : ///
81 : /// This provides a way to wait the arrival of a number.
82 : /// As soon as the number arrives by another caller calling
83 : /// [`advance`], then the waiter will be woken up.
84 : ///
85 : /// This implementation takes a blocking Mutex on both [`wait_for`]
86 : /// and [`advance`], meaning there may be unexpected executor blocking
87 : /// due to thread scheduling unfairness. There are probably better
88 : /// implementations, but we can probably live with this for now.
89 : ///
90 : /// [`wait_for`]: SeqWait::wait_for
91 : /// [`advance`]: SeqWait::advance
92 : ///
93 : /// `S` means Storage, `V` is type of counter that this storage exposes.
94 : ///
95 : pub struct SeqWait<S, V>
96 : where
97 : S: MonotonicCounter<V>,
98 : V: Ord,
99 : {
100 : internal: Mutex<SeqWaitInt<S, V>>,
101 : }
102 :
103 : impl<S, V> SeqWait<S, V>
104 : where
105 : S: MonotonicCounter<V> + Copy,
106 : V: Ord + Copy,
107 : {
108 : /// Create a new `SeqWait`, initialized to a particular number
109 CBC 1304 : pub fn new(starting_num: S) -> Self {
110 1304 : let internal = SeqWaitInt {
111 1304 : waiters: BinaryHeap::new(),
112 1304 : current: starting_num,
113 1304 : shutdown: false,
114 1304 : };
115 1304 : SeqWait {
116 1304 : internal: Mutex::new(internal),
117 1304 : }
118 1304 : }
119 :
120 : /// Shut down a `SeqWait`, causing all waiters (present and
121 : /// future) to return an error.
122 2 : pub fn shutdown(&self) {
123 2 : let waiters = {
124 2 : // Prevent new waiters; wake all those that exist.
125 2 : // Wake everyone with an error.
126 2 : let mut internal = self.internal.lock().unwrap();
127 2 :
128 2 : // This will steal the entire waiters map.
129 2 : // When we drop it all waiters will be woken.
130 2 : mem::take(&mut internal.waiters)
131 2 :
132 2 : // Drop the lock as we exit this scope.
133 2 : };
134 2 :
135 2 : // When we drop the waiters list, each Receiver will
136 2 : // be woken with an error.
137 2 : // This drop doesn't need to be explicit; it's done
138 2 : // here to make it easier to read the code and understand
139 2 : // the order of events.
140 2 : drop(waiters);
141 2 : }
142 :
143 : /// Wait for a number to arrive
144 : ///
145 : /// This call won't complete until someone has called `advance`
146 : /// with a number greater than or equal to the one we're waiting for.
147 : ///
148 : /// This function is async cancellation-safe.
149 4 : pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
150 4 : match self.queue_for_wait(num) {
151 1 : Ok(None) => Ok(()),
152 3 : Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
153 UBC 0 : Err(e) => Err(e),
154 : }
155 CBC 4 : }
156 :
157 : /// Wait for a number to arrive
158 : ///
159 : /// This call won't complete until someone has called `advance`
160 : /// with a number greater than or equal to the one we're waiting for.
161 : ///
162 : /// If that hasn't happened after the specified timeout duration,
163 : /// [`SeqWaitError::Timeout`] will be returned.
164 : ///
165 : /// This function is async cancellation-safe.
166 1278622 : pub async fn wait_for_timeout(
167 1278622 : &self,
168 1278622 : num: V,
169 1278622 : timeout_duration: Duration,
170 1278622 : ) -> Result<(), SeqWaitError> {
171 1278622 : match self.queue_for_wait(num) {
172 1176301 : Ok(None) => Ok(()),
173 145153 : Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
174 102314 : Ok(Ok(())) => Ok(()),
175 UBC 0 : Ok(Err(_)) => Err(SeqWaitError::Shutdown),
176 CBC 6 : Err(_) => Err(SeqWaitError::Timeout),
177 : },
178 UBC 0 : Err(e) => Err(e),
179 : }
180 CBC 1278621 : }
181 :
182 : /// Register and return a channel that will be notified when a number arrives,
183 : /// or None, if it has already arrived.
184 1278626 : fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
185 1278626 : let mut internal = self.internal.lock().unwrap();
186 1278626 : if internal.current.cnt_value() >= num {
187 1176302 : return Ok(None);
188 102324 : }
189 102324 : if internal.shutdown {
190 UBC 0 : return Err(SeqWaitError::Shutdown);
191 CBC 102324 : }
192 102324 :
193 102324 : // Create a new channel.
194 102324 : let (tx, rx) = channel(());
195 102324 : internal.waiters.push(Waiter {
196 102324 : wake_num: num,
197 102324 : wake_channel: tx,
198 102324 : });
199 102324 : // Drop the lock as we exit this scope.
200 102324 : Ok(Some(rx))
201 1278626 : }
202 :
203 : /// Announce a new number has arrived
204 : ///
205 : /// All waiters at this value or below will be woken.
206 : ///
207 : /// Returns the old number.
208 69330607 : pub fn advance(&self, num: V) -> V {
209 : let old_value;
210 69330606 : let wake_these = {
211 69330607 : let mut internal = self.internal.lock().unwrap();
212 69330607 :
213 69330607 : old_value = internal.current.cnt_value();
214 69330607 : if old_value >= num {
215 1 : return old_value;
216 69330606 : }
217 69330606 : internal.current.cnt_advance(num);
218 69330606 :
219 69330606 : // Pop all waiters <= num from the heap. Collect them in a vector, and
220 69330606 : // wake them up after releasing the lock.
221 69330606 : let mut wake_these = Vec::new();
222 69432924 : while let Some(n) = internal.waiters.peek() {
223 12212997 : if n.wake_num > num {
224 12110679 : break;
225 102318 : }
226 102318 : wake_these.push(internal.waiters.pop().unwrap().wake_channel);
227 : }
228 69330606 : wake_these
229 : };
230 :
231 69432924 : for tx in wake_these {
232 102318 : // This can fail if there are no receivers.
233 102318 : // We don't care; discard the error.
234 102318 : let _ = tx.send(());
235 102318 : }
236 69330606 : old_value
237 69330607 : }
238 :
239 : /// Read the current value, without waiting.
240 82268228 : pub fn load(&self) -> S {
241 82268228 : self.internal.lock().unwrap().current
242 82268228 : }
243 : }
244 :
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    // Lets a bare `i32` act as the counter storage in these tests.
    impl MonotonicCounter<i32> for i32 {
        fn cnt_advance(&mut self, val: i32) {
            assert!(*self <= val);
            *self = val;
        }

        fn cnt_value(&self) -> i32 {
            *self
        }
    }

    #[tokio::test]
    async fn seqwait() {
        let shared = Arc::new(SeqWait::new(0));
        let for_first = Arc::clone(&shared);
        let for_second = Arc::clone(&shared);

        // Waits for 42, pushes the counter to 100, then checks that 999 does
        // not arrive within the timeout.
        let first = tokio::task::spawn(async move {
            for_first.wait_for(42).await.expect("wait_for 42");
            assert_eq!(for_first.advance(100), 99);
            let outcome = for_first
                .wait_for_timeout(999, Duration::from_millis(100))
                .await;
            outcome.expect_err("no 999");
        });

        // Waits for 42, then for a number that is already in the past.
        let second = tokio::task::spawn(async move {
            for_second.wait_for(42).await.expect("wait_for 42");
            for_second.wait_for(0).await.expect("wait_for 0");
        });

        // Give the spawned tasks a chance to start waiting before advancing.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(shared.advance(99), 0);
        shared.wait_for(100).await.expect("wait_for 100");

        // Advancing to a smaller value is a no-op that reports the current value.
        assert_eq!(shared.advance(98), 100);
        assert_eq!(shared.load(), 100);

        first.await.unwrap();
        second.await.unwrap();

        shared.shutdown();
    }

    #[tokio::test]
    async fn seqwait_timeout() {
        let shared = Arc::new(SeqWait::new(0));
        let for_waiter = Arc::clone(&shared);
        let waiter = tokio::task::spawn(async move {
            let res = for_waiter
                .wait_for_timeout(42, Duration::from_millis(1))
                .await;
            assert_eq!(res, Err(SeqWaitError::Timeout));
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        // The waiter has already timed out and dropped its `Receiver`, so this
        // wake-up attempt reaches nobody.
        assert_eq!(shared.advance(99), 0);
        waiter.await.unwrap();

        shared.shutdown();
    }
}
|