|             Line data    Source code 
       1              : #![warn(missing_docs)]
       2              : 
       3              : use std::cmp::{Eq, Ordering};
       4              : use std::collections::BinaryHeap;
       5              : use std::mem;
       6              : use std::sync::Mutex;
       7              : use std::time::Duration;
       8              : use tokio::sync::watch::{self, channel};
       9              : use tokio::time::timeout;
      10              : 
      11              : /// Error returned when waiting for a number on a [`SeqWait`] does not succeed
      12            0 : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      13              : pub enum SeqWaitError {
      14              :     /// The timeout passed to [`SeqWait::wait_for_timeout`] elapsed before the number arrived
      15              :     #[error("seqwait timeout was reached")]
      16              :     Timeout,
      17              : 
      18              :     /// [`SeqWait::shutdown`] was called before the number arrived
      19              :     #[error("SeqWait::shutdown was called")]
      20              :     Shutdown,
      21              : }
      22              : 
      23              : /// A value that exposes a monotonically increasing counter
      24              : ///
      25              : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      26              : /// (e.g. store prev_record_lsn). So we allow `SeqWait` to be parametrized with
      27              : /// any type that can expose a counter. `V` is the type of the exposed counter.
      28              : pub trait MonotonicCounter<V> {
      29              :     /// Bump the counter to `new_val` and check that it goes forward.
      30              :     /// N.B.: `new_val` is the actual new value, not a difference.
      31              :     fn cnt_advance(&mut self, new_val: V);
      32              : 
      33              :     /// Get the current counter value
      34              :     fn cnt_value(&self) -> V;
      35              : }
      36              : 
      37              : /// Heap of pending waiters; the waiter with the lowest number pops first.
      38              : struct Waiters<V>
      39              : where
      40              :     V: Ord,
      41              : {
      42              :     heap: BinaryHeap<Waiter<V>>, // acts as a min-heap because `Waiter`'s `Ord` impl is reversed
      43              :     /// Publishes the number of the first waiter in the heap, or `None` if there are no waiters.
      44              :     status_channel: watch::Sender<Option<V>>,
      45              : }
      46              : 
      47              : impl<V> Waiters<V>
      48              : where
      49              :     V: Ord + Copy,
      50              : {
      51          427 :     fn new() -> Self { // empty heap, status initialised to `None`
      52          427 :         Waiters {
      53          427 :             heap: BinaryHeap::new(),
      54          427 :             status_channel: channel(None).0, // keep only the sender; the initial receiver is dropped
      55          427 :         }
      56          427 :     }
      57              : 
      58              :     /// Publish the number of the first waiter in the heap (`None` if it is empty) on `status_channel`.
      59              :     /// This function should be called whenever the waiters heap changes.
      60      4804404 :     fn update_status(&self) {
      61      4804404 :         let first_waiter = self.heap.peek().map(|w| w.wake_num);
      62      4804404 :         let _ = self.status_channel.send_replace(first_waiter); // the returned previous status is not needed
      63      4804404 :     }
      64              : 
      65              :     /// Add a new waiter for `num` to the heap and return a receiver that is notified when the number arrives.
      66           10 :     fn add(&mut self, num: V) -> watch::Receiver<()> {
      67           10 :         let (tx, rx) = channel(());
      68           10 :         self.heap.push(Waiter {
      69           10 :             wake_num: num,
      70           10 :             wake_channel: tx,
      71           10 :         });
      72           10 :         self.update_status();
      73           10 :         rx
      74           10 :     }
      75              : 
      76              :     /// Pop all waiters with `wake_num <= num` from the heap. Collect their channels in a vector,
      77              :     /// so that the caller can wake them up (the senders are returned, nothing is sent here).
      78      4804382 :     fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
      79      4804382 :         let mut wake_these = Vec::new();
      80      4804390 :         while let Some(n) = self.heap.peek() {
      81            8 :             if n.wake_num > num { // the top is the smallest number, so no remaining waiter qualifies
      82            0 :                 break;
      83            8 :             }
      84            8 :             wake_these.push(self.heap.pop().unwrap().wake_channel); // `unwrap` is fine: `peek` just returned `Some`
      85              :         }
      86      4804382 :         self.update_status();
      87      4804382 :         wake_these
      88      4804382 :     }
      89              : 
      90              :     /// Take the whole heap out, leaving an empty one behind. Used on shutdown to drop all waiters at once.
      91           12 :     fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
      92           12 :         let heap = mem::take(&mut self.heap);
      93           12 :         self.update_status(); // the heap is empty now, so this publishes `None`
      94           12 :         heap
      95           12 :     }
      96              : }
      97              : 
      98              : struct Waiter<T> // one pending wait: the awaited number plus the channel used to wake the waiting task
      99              : where
     100              :     T: Ord,
     101              : {
     102              :     wake_num: T,                     // wake the waiter once this number has arrived ...
     103              :     wake_channel: watch::Sender<()>, // ... by sending a message on this channel
     104              : }
     105              : 
     106              : // `BinaryHeap` is a max-heap, and we want a min-heap. `Waiter` therefore compares in
     107              : // reverse order; `partial_cmp` simply defers to the reversed `cmp`.
     108              : impl<T: Ord> PartialOrd for Waiter<T> {
     109            2 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
     110            2 :         Some(self.cmp(other))
     111            2 :     }
     112              : }
     113              : 
     114              : impl<T: Ord> Ord for Waiter<T> {
     115            2 :     fn cmp(&self, other: &Self) -> Ordering {
     116            2 :         other.wake_num.cmp(&self.wake_num) // operands swapped on purpose: a smaller `wake_num` ranks higher
     117            2 :     }
     118              : }
     119              : 
     120              : impl<T: Ord> PartialEq for Waiter<T> {
     121            0 :     fn eq(&self, other: &Self) -> bool { // equal when waiting for the same number; the channel is ignored
     122            0 :         other.wake_num == self.wake_num
     123            0 :     }
     124              : }
     125              : 
     126              : impl<T: Ord> Eq for Waiter<T> {} // marker impl; equality compares `wake_num` only
     127              : 
     128              : /// Internal components of a `SeqWait`, all kept under the same mutex
     129              : struct SeqWaitInt<S, V>
     130              : where
     131              :     S: MonotonicCounter<V>,
     132              :     V: Ord,
     133              : {
     134              :     waiters: Waiters<V>, // waiters whose number has not arrived yet
     135              :     current: S, // storage that exposes the current counter value
     136              :     shutdown: bool, // set by `SeqWait::shutdown`; new waiters are rejected once it is true
     137              : }
     138              : 
     139              : /// A tool for waiting on a sequence number
     140              : ///
     141              : /// This provides a way to wait for the arrival of a number.
     142              : /// As soon as another caller makes the number arrive by calling
     143              : /// [`advance`], the waiter will be woken up.
     144              : ///
     145              : /// This implementation takes a blocking Mutex on both [`wait_for`]
     146              : /// and [`advance`], meaning there may be unexpected executor blocking
     147              : /// due to thread scheduling unfairness. There are probably better
     148              : /// implementations, but we can probably live with this for now.
     149              : ///
     150              : /// [`wait_for`]: SeqWait::wait_for
     151              : /// [`advance`]: SeqWait::advance
     152              : ///
     153              : /// `S` is the storage type, `V` is the type of the counter that this storage exposes.
     154              : ///
     155              : pub struct SeqWait<S, V>
     156              : where
     157              :     S: MonotonicCounter<V>,
     158              :     V: Ord,
     159              : {
     160              :     internal: Mutex<SeqWaitInt<S, V>>, // std (blocking) mutex guarding the counter, the waiters and the shutdown flag
     161              : }
     162              : 
     163              : impl<S, V> SeqWait<S, V>
     164              : where
     165              :     S: MonotonicCounter<V> + Copy,
     166              :     V: Ord + Copy,
     167              : {
     168              :     /// Create a new `SeqWait` whose counter storage is initialized to `starting_num`
     169          427 :     pub fn new(starting_num: S) -> Self {
     170          427 :         let internal = SeqWaitInt {
     171          427 :             waiters: Waiters::new(),
     172          427 :             current: starting_num,
     173          427 :             shutdown: false,
     174          427 :         };
     175          427 :         SeqWait {
     176          427 :             internal: Mutex::new(internal),
     177          427 :         }
     178          427 :     }
     179              : 
     180              :     /// Shut down a `SeqWait`, causing all waiters (present and
     181              :     /// future) to return [`SeqWaitError::Shutdown`]; waits for numbers that have already arrived still succeed.
     182           12 :     pub fn shutdown(&self) {
     183           12 :         let waiters = {
     184           12 :             // Prevent new waiters and take the existing ones out of the heap,
     185           12 :             // so that they can be woken with an error below.
     186           12 :             let mut internal = self.internal.lock().unwrap();
     187           12 : 
     188           12 :             // Block any future waiters from starting
     189           12 :             internal.shutdown = true;
     190           12 : 
     191           12 :             // Take all waiters to drop them later.
     192           12 :             internal.waiters.take_all()
     193           12 : 
     194           12 :             // Drop the lock as we exit this scope.
     195           12 :         };
     196           12 : 
     197           12 :         // When we drop the waiters list, each Sender is dropped and the
     198           12 :         // matching Receiver is woken with an error.
     199           12 :         // This drop doesn't need to be explicit; it's done
     200           12 :         // here to make it easier to read the code and understand
     201           12 :         // the order of events.
     202           12 :         drop(waiters);
     203           12 :     }
     204              : 
     205              :     /// Wait for a number to arrive
     206              :     ///
     207              :     /// This call won't complete until someone has called `advance`
     208              :     /// with a number greater than or equal to the one we're waiting for.
     209              :     ///
     210              :     /// This function is async cancellation-safe; a cancelled wait leaves its heap entry behind until `advance` or `shutdown` removes it.
     211            8 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     212            8 :         match self.queue_for_wait(num) {
     213            2 :             Ok(None) => Ok(()), // the number has already arrived
     214            6 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown), // any error from `changed()` is reported as shutdown
     215            0 :             Err(e) => Err(e),
     216              :         }
     217            8 :     }
     218              : 
     219              :     /// Wait for a number to arrive, giving up after `timeout_duration`
     220              :     ///
     221              :     /// This call won't complete until someone has called `advance`
     222              :     /// with a number greater than or equal to the one we're waiting for.
     223              :     ///
     224              :     /// If that hasn't happened after the specified timeout duration,
     225              :     /// [`SeqWaitError::Timeout`] will be returned.
     226              :     ///
     227              :     /// This function is async cancellation-safe; a timed-out or cancelled wait leaves its heap entry behind until `advance` or `shutdown` removes it.
     228       224844 :     pub async fn wait_for_timeout(
     229       224844 :         &self,
     230       224844 :         num: V,
     231       224844 :         timeout_duration: Duration,
     232       224844 :     ) -> Result<(), SeqWaitError> {
     233       224844 :         match self.queue_for_wait(num) {
     234       224840 :             Ok(None) => Ok(()), // the number has already arrived
     235            4 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     236            0 :                 Ok(Ok(())) => Ok(()), // notified before the timeout
     237            0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown), // `changed()` failed: reported as shutdown
     238            4 :                 Err(_) => Err(SeqWaitError::Timeout), // the timeout elapsed first
     239              :             },
     240            0 :             Err(e) => Err(e),
     241              :         }
     242       224844 :     }
     243              : 
     244              :     /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`: `Ok(())` means `num` has already arrived, `Err(current)` carries the current counter value that is still below `num`. The shutdown flag is not consulted.
     245            0 :     pub fn would_wait_for(&self, num: V) -> Result<(), V> {
     246            0 :         let internal = self.internal.lock().unwrap();
     247            0 :         let cnt = internal.current.cnt_value();
     248            0 :         drop(internal); // release the lock before comparing
     249            0 :         if cnt >= num {
     250            0 :             Ok(())
     251              :         } else {
     252            0 :             Err(cnt)
     253              :         }
     254            0 :     }
     255              : 
     256              :     /// Register and return a channel that will be notified when a number arrives,
     257              :     /// or None, if it has already arrived. Fails with [`SeqWaitError::Shutdown`] if it has not arrived and `shutdown` was called.
     258       224852 :     fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
     259       224852 :         let mut internal = self.internal.lock().unwrap();
     260       224852 :         if internal.current.cnt_value() >= num { // checked before the shutdown flag, so arrived numbers succeed even after shutdown
     261       224842 :             return Ok(None);
     262           10 :         }
     263           10 :         if internal.shutdown {
     264            0 :             return Err(SeqWaitError::Shutdown);
     265           10 :         }
     266           10 : 
     267           10 :         // Add waiter channel to the queue.
     268           10 :         let rx = internal.waiters.add(num);
     269           10 :         // The lock is dropped when `internal` goes out of scope at the end of this function.
     270           10 :         Ok(Some(rx))
     271       224852 :     }
     272              : 
     273              :     /// Announce a new number has arrived
     274              :     ///
     275              :     /// All waiters at this value or below will be woken.
     276              :     ///
     277              :     /// Returns the old number. If `num` is not greater than the old number, nothing changes and nobody is woken.
     278      5279072 :     pub fn advance(&self, num: V) -> V {
     279              :         let old_value;
     280      4804382 :         let wake_these = {
     281      5279072 :             let mut internal = self.internal.lock().unwrap();
     282      5279072 : 
     283      5279072 :             old_value = internal.current.cnt_value();
     284      5279072 :             if old_value >= num {
     285       474690 :                 return old_value;
     286      4804382 :             }
     287      4804382 :             internal.current.cnt_advance(num);
     288      4804382 : 
     289      4804382 :             // Pop all waiters <= num from the heap; they are woken below, after the lock is released.
     290      4804382 :             internal.waiters.pop_leq(num)
     291              :         };
     292              : 
     293      4804390 :         for tx in wake_these {
     294            8 :             // This can fail if there are no receivers (e.g. the waiter gave up).
     295            8 :             // We don't care; discard the error.
     296            8 :             let _ = tx.send(());
     297            8 :         }
     298      4804382 :         old_value
     299      5279072 :     }
     300              : 
     301              :     /// Read the current value (a copy of the whole storage `S`), without waiting.
     302       276772 :     pub fn load(&self) -> S {
     303       276772 :         self.internal.lock().unwrap().current
     304       276772 :     }
     305              : 
     306              :     /// Get a Receiver for the current status.
     307              :     ///
     308              :     /// The current status is the number of the first waiter in the queue,
     309              :     /// or None if there are no waiters.
     310              :     ///
     311              :     /// This receiver will be notified whenever the status changes.
     312              :     /// It is useful for receiving notifications when the first waiter
     313              :     /// starts waiting for a number, or when there are no more waiters left.
     314            0 :     pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
     315            0 :         self.internal
     316            0 :             .lock()
     317            0 :             .unwrap()
     318            0 :             .waiters
     319            0 :             .status_channel
     320            0 :             .subscribe()
     321            0 :     }
     322              : }
     323              : 
     324              : #[cfg(test)]
     325              : mod tests {
     326              :     use super::*;
     327              :     use std::sync::Arc;
     328              : 
     329              :     impl MonotonicCounter<i32> for i32 { // a plain `i32` serves as the counter storage in these tests
     330            6 :         fn cnt_advance(&mut self, val: i32) {
     331            6 :             assert!(*self <= val);
     332            6 :             *self = val;
     333            6 :         }
     334           20 :         fn cnt_value(&self) -> i32 {
     335           20 :             *self
     336           20 :         }
     337              :     }
     338              : 
     339              :     #[tokio::test]
     340            2 :     async fn seqwait() { // two tasks wait for 42; the main task advances to 99 and then waits for 100
     341            2 :         let seq = Arc::new(SeqWait::new(0));
     342            2 :         let seq2 = Arc::clone(&seq);
     343            2 :         let seq3 = Arc::clone(&seq);
     344            2 :         let jh1 = tokio::task::spawn(async move {
     345            2 :             seq2.wait_for(42).await.expect("wait_for 42");
     346            2 :             let old = seq2.advance(100);
     347            2 :             assert_eq!(old, 99); // the main task advanced to 99 before this task was woken
     348            2 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     349            2 :                 .await
     350            2 :                 .expect_err("no 999");
     351            2 :         });
     352            2 :         let jh2 = tokio::task::spawn(async move {
     353            2 :             seq3.wait_for(42).await.expect("wait_for 42");
     354            2 :             seq3.wait_for(0).await.expect("wait_for 0"); // 0 has already arrived
     355            2 :         });
     356            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     357            2 :         let old = seq.advance(99);
     358            2 :         assert_eq!(old, 0);
     359            2 :         seq.wait_for(100).await.expect("wait_for 100");
     360            2 : 
     361            2 :         // Calling advance with a smaller value is a no-op that returns the current value
     362            2 :         assert_eq!(seq.advance(98), 100);
     363            2 :         assert_eq!(seq.load(), 100);
     364            2 : 
     365            2 :         jh1.await.unwrap();
     366            2 :         jh2.await.unwrap();
     367            2 : 
     368            2 :         seq.shutdown();
     369            2 :     }
     370              : 
     371              :     #[tokio::test]
     372            2 :     async fn seqwait_timeout() { // a 1 ms wait for 42 must time out because nobody advances in time
     373            2 :         let seq = Arc::new(SeqWait::new(0));
     374            2 :         let seq2 = Arc::clone(&seq);
     375            2 :         let jh = tokio::task::spawn(async move {
     376            2 :             let timeout = Duration::from_millis(1);
     377            2 :             let res = seq2.wait_for_timeout(42, timeout).await;
     378            2 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     379            2 :         });
     380            2 :         tokio::time::sleep(Duration::from_millis(200)).await;
     381            2 :         // This will attempt to wake, but nothing will happen
     382            2 :         // because the waiter already dropped its Receiver.
     383            2 :         let old = seq.advance(99);
     384            2 :         assert_eq!(old, 0);
     385            2 :         jh.await.unwrap();
     386            2 : 
     387            2 :         seq.shutdown();
     388            2 :     }
     389              : }
         |