LCOV - code coverage report
Current view: top level - libs/utils/src - seqwait.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 88.5 % 200 177
Test Date: 2025-03-12 18:05:02 Functions: 47.6 % 82 39

            Line data    Source code
       1              : #![warn(missing_docs)]
       2              : 
       3              : use std::cmp::{Eq, Ordering};
       4              : use std::collections::BinaryHeap;
       5              : use std::mem;
       6              : use std::sync::Mutex;
       7              : use std::time::Duration;
       8              : 
       9              : use tokio::sync::watch::{self, channel};
      10              : use tokio::time::timeout;
      11              : 
      12              : /// An error happened while waiting for a number
      13              : #[derive(Debug, PartialEq, Eq, thiserror::Error)]
      14              : pub enum SeqWaitError {
      15              :     /// The wait timeout was reached
      16              :     #[error("seqwait timeout was reached")]
      17              :     Timeout,
      18              : 
      19              :     /// [`SeqWait::shutdown`] was called
      20              :     #[error("SeqWait::shutdown was called")]
      21              :     Shutdown,
      22              : }
      23              : 
      24              : /// Monotonically increasing value
      25              : ///
      26              : /// It is handy to store some other fields under the same mutex in `SeqWait<S>`
      27              : /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
      28              : /// any type that can expose counter. `V` is the type of exposed counter.
      29              : pub trait MonotonicCounter<V> {
      30              :     /// Bump counter value and check that it goes forward
      31              :     /// N.B.: new_val is an actual new value, not a difference.
      32              :     fn cnt_advance(&mut self, new_val: V);
      33              : 
      34              :     /// Get counter value
      35              :     fn cnt_value(&self) -> V;
      36              : }
      37              : 
      38              : /// Heap of waiters, lowest numbers pop first.
      39              : struct Waiters<V>
      40              : where
      41              :     V: Ord,
      42              : {
      43              :     heap: BinaryHeap<Waiter<V>>,
      44              :     /// Number of the first waiter in the heap, or None if there are no waiters.
      45              :     status_channel: watch::Sender<Option<V>>,
      46              : }
      47              : 
      48              : impl<V> Waiters<V>
      49              : where
      50              :     V: Ord + Copy,
      51              : {
      52        26592 :     fn new() -> Self {
      53        26592 :         Waiters {
      54        26592 :             heap: BinaryHeap::new(),
      55        26592 :             status_channel: channel(None).0,
      56        26592 :         }
      57        26592 :     }
      58              : 
      59              :     /// `status_channel` contains the number of the first waiter in the heap.
      60              :     /// This function should be called whenever waiters heap changes.
      61           30 :     fn update_status(&self) {
      62           30 :         let first_waiter = self.heap.peek().map(|w| w.wake_num);
      63           30 :         let _ = self.status_channel.send_replace(first_waiter);
      64           30 :     }
      65              : 
      66              :     /// Add new waiter to the heap, return a channel that will be notified when the number arrives.
      67            5 :     fn add(&mut self, num: V) -> watch::Receiver<()> {
      68            5 :         let (tx, rx) = channel(());
      69            5 :         self.heap.push(Waiter {
      70            5 :             wake_num: num,
      71            5 :             wake_channel: tx,
      72            5 :         });
      73            5 :         self.update_status();
      74            5 :         rx
      75            5 :     }
      76              : 
      77              :     /// Pop all waiters <= num from the heap. Collect channels in a vector,
      78              :     /// so that caller can wake them up.
      79      9608923 :     fn pop_leq(&mut self, num: V) -> Vec<watch::Sender<()>> {
      80      9608923 :         let mut wake_these = Vec::new();
      81      9608927 :         while let Some(n) = self.heap.peek() {
      82            4 :             if n.wake_num > num {
      83            0 :                 break;
      84            4 :             }
      85            4 :             wake_these.push(self.heap.pop().unwrap().wake_channel);
      86              :         }
      87      9608923 :         if !wake_these.is_empty() {
      88            3 :             self.update_status();
      89            3 :         }
      90      9608923 :         wake_these
      91      9608923 :     }
      92              : 
      93              :     /// Used on shutdown to efficiently drop all waiters.
      94           22 :     fn take_all(&mut self) -> BinaryHeap<Waiter<V>> {
      95           22 :         let heap = mem::take(&mut self.heap);
      96           22 :         self.update_status();
      97           22 :         heap
      98           22 :     }
      99              : }
     100              : 
     101              : struct Waiter<T>
     102              : where
     103              :     T: Ord,
     104              : {
     105              :     wake_num: T,                     // wake me when this number arrives ...
     106              :     wake_channel: watch::Sender<()>, // ... by sending a message to this channel
     107              : }
     108              : 
     109              : // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
     110              : // to get that.
     111              : impl<T: Ord> PartialOrd for Waiter<T> {
     112            1 :     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
     113            1 :         Some(self.cmp(other))
     114            1 :     }
     115              : }
     116              : 
     117              : impl<T: Ord> Ord for Waiter<T> {
     118            1 :     fn cmp(&self, other: &Self) -> Ordering {
     119            1 :         other.wake_num.cmp(&self.wake_num)
     120            1 :     }
     121              : }
     122              : 
     123              : impl<T: Ord> PartialEq for Waiter<T> {
     124            0 :     fn eq(&self, other: &Self) -> bool {
     125            0 :         other.wake_num == self.wake_num
     126            0 :     }
     127              : }
     128              : 
     129              : impl<T: Ord> Eq for Waiter<T> {}
     130              : 
     131              : /// Internal components of a `SeqWait`
     132              : struct SeqWaitInt<S, V>
     133              : where
     134              :     S: MonotonicCounter<V>,
     135              :     V: Ord,
     136              : {
     137              :     waiters: Waiters<V>,
     138              :     current: S,
     139              :     shutdown: bool,
     140              : }
     141              : 
     142              : /// A tool for waiting on a sequence number
     143              : ///
     144              : /// This provides a way to wait the arrival of a number.
     145              : /// As soon as the number arrives by another caller calling
     146              : /// [`advance`], then the waiter will be woken up.
     147              : ///
     148              : /// This implementation takes a blocking Mutex on both [`wait_for`]
     149              : /// and [`advance`], meaning there may be unexpected executor blocking
     150              : /// due to thread scheduling unfairness. There are probably better
     151              : /// implementations, but we can probably live with this for now.
     152              : ///
     153              : /// [`wait_for`]: SeqWait::wait_for
     154              : /// [`advance`]: SeqWait::advance
     155              : ///
     156              : /// `S` means Storage, `V` is type of counter that this storage exposes.
     157              : ///
     158              : pub struct SeqWait<S, V>
     159              : where
     160              :     S: MonotonicCounter<V>,
     161              :     V: Ord,
     162              : {
     163              :     internal: Mutex<SeqWaitInt<S, V>>,
     164              : }
     165              : 
     166              : impl<S, V> SeqWait<S, V>
     167              : where
     168              :     S: MonotonicCounter<V> + Copy,
     169              :     V: Ord + Copy,
     170              : {
     171              :     /// Create a new `SeqWait`, initialized to a particular number
     172        26592 :     pub fn new(starting_num: S) -> Self {
     173        26592 :         let internal = SeqWaitInt {
     174        26592 :             waiters: Waiters::new(),
     175        26592 :             current: starting_num,
     176        26592 :             shutdown: false,
     177        26592 :         };
     178        26592 :         SeqWait {
     179        26592 :             internal: Mutex::new(internal),
     180        26592 :         }
     181        26592 :     }
     182              : 
     183              :     /// Shut down a `SeqWait`, causing all waiters (present and
     184              :     /// future) to return an error.
     185           22 :     pub fn shutdown(&self) {
     186           22 :         let waiters = {
     187           22 :             // Prevent new waiters; wake all those that exist.
     188           22 :             // Wake everyone with an error.
     189           22 :             let mut internal = self.internal.lock().unwrap();
     190           22 : 
     191           22 :             // Block any future waiters from starting
     192           22 :             internal.shutdown = true;
     193           22 : 
     194           22 :             // Take all waiters to drop them later.
     195           22 :             internal.waiters.take_all()
     196           22 : 
     197           22 :             // Drop the lock as we exit this scope.
     198           22 :         };
     199           22 : 
     200           22 :         // When we drop the waiters list, each Receiver will
     201           22 :         // be woken with an error.
     202           22 :         // This drop doesn't need to be explicit; it's done
     203           22 :         // here to make it easier to read the code and understand
     204           22 :         // the order of events.
     205           22 :         drop(waiters);
     206           22 :     }
     207              : 
     208              :     /// Wait for a number to arrive
     209              :     ///
     210              :     /// This call won't complete until someone has called `advance`
     211              :     /// with a number greater than or equal to the one we're waiting for.
     212              :     ///
     213              :     /// This function is async cancellation-safe.
     214            4 :     pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
     215            4 :         match self.queue_for_wait(num) {
     216            1 :             Ok(None) => Ok(()),
     217            3 :             Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown),
     218            0 :             Err(e) => Err(e),
     219              :         }
     220            4 :     }
     221              : 
     222              :     /// Wait for a number to arrive
     223              :     ///
     224              :     /// This call won't complete until someone has called `advance`
     225              :     /// with a number greater than or equal to the one we're waiting for.
     226              :     ///
     227              :     /// If that hasn't happened after the specified timeout duration,
     228              :     /// [`SeqWaitError::Timeout`] will be returned.
     229              :     ///
     230              :     /// This function is async cancellation-safe.
     231       449973 :     pub async fn wait_for_timeout(
     232       449973 :         &self,
     233       449973 :         num: V,
     234       449973 :         timeout_duration: Duration,
     235       449973 :     ) -> Result<(), SeqWaitError> {
     236       449973 :         match self.queue_for_wait(num) {
     237       449971 :             Ok(None) => Ok(()),
     238            2 :             Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await {
     239            0 :                 Ok(Ok(())) => Ok(()),
     240            0 :                 Ok(Err(_)) => Err(SeqWaitError::Shutdown),
     241            2 :                 Err(_) => Err(SeqWaitError::Timeout),
     242              :             },
     243            0 :             Err(e) => Err(e),
     244              :         }
     245            2 :     }
     246              : 
     247              :     /// Check if [`Self::wait_for`] or [`Self::wait_for_timeout`] would wait if called with `num`.
     248            0 :     pub fn would_wait_for(&self, num: V) -> Result<(), V> {
     249            0 :         let internal = self.internal.lock().unwrap();
     250            0 :         let cnt = internal.current.cnt_value();
     251            0 :         drop(internal);
     252            0 :         if cnt >= num { Ok(()) } else { Err(cnt) }
     253            0 :     }
     254              : 
     255              :     /// Register and return a channel that will be notified when a number arrives,
     256              :     /// or None, if it has already arrived.
     257       449977 :     fn queue_for_wait(&self, num: V) -> Result<Option<watch::Receiver<()>>, SeqWaitError> {
     258       449977 :         let mut internal = self.internal.lock().unwrap();
     259       449977 :         if internal.current.cnt_value() >= num {
     260       449972 :             return Ok(None);
     261            5 :         }
     262            5 :         if internal.shutdown {
     263            0 :             return Err(SeqWaitError::Shutdown);
     264            5 :         }
     265            5 : 
     266            5 :         // Add waiter channel to the queue.
     267            5 :         let rx = internal.waiters.add(num);
     268            5 :         // Drop the lock as we exit this scope.
     269            5 :         Ok(Some(rx))
     270            6 :     }
     271              : 
     272              :     /// Announce a new number has arrived
     273              :     ///
     274              :     /// All waiters at this value or below will be woken.
     275              :     ///
     276              :     /// Returns the old number.
     277     10558300 :     pub fn advance(&self, num: V) -> V {
     278              :         let old_value;
     279      9608923 :         let wake_these = {
     280     10558300 :             let mut internal = self.internal.lock().unwrap();
     281     10558300 : 
     282     10558300 :             old_value = internal.current.cnt_value();
     283     10558300 :             if old_value >= num {
     284       949377 :                 return old_value;
     285      9608923 :             }
     286      9608923 :             internal.current.cnt_advance(num);
     287      9608923 : 
     288      9608923 :             // Pop all waiters <= num from the heap.
     289      9608923 :             internal.waiters.pop_leq(num)
     290              :         };
     291              : 
     292      9608927 :         for tx in wake_these {
     293            4 :             // This can fail if there are no receivers.
     294            4 :             // We don't care; discard the error.
     295            4 :             let _ = tx.send(());
     296            4 :         }
     297      9608923 :         old_value
     298            4 :     }
     299              : 
     300              :     /// Read the current value, without waiting.
     301       551837 :     pub fn load(&self) -> S {
     302       551837 :         self.internal.lock().unwrap().current
     303       551837 :     }
     304              : 
     305              :     /// Get a Receiver for the current status.
     306              :     ///
     307              :     /// The current status is the number of the first waiter in the queue,
     308              :     /// or None if there are no waiters.
     309              :     ///
     310              :     /// This receiver will be notified whenever the status changes.
     311              :     /// It is useful for receiving notifications when the first waiter
     312              :     /// starts waiting for a number, or when there are no more waiters left.
     313            0 :     pub fn status_receiver(&self) -> watch::Receiver<Option<V>> {
     314            0 :         self.internal
     315            0 :             .lock()
     316            0 :             .unwrap()
     317            0 :             .waiters
     318            0 :             .status_channel
     319            0 :             .subscribe()
     320            0 :     }
     321              : }
     322              : 
     323              : #[cfg(test)]
     324              : mod tests {
     325              :     use std::sync::Arc;
     326              : 
     327              :     use super::*;
     328              : 
     329              :     impl MonotonicCounter<i32> for i32 {
     330            3 :         fn cnt_advance(&mut self, val: i32) {
     331            3 :             assert!(*self <= val);
     332            3 :             *self = val;
     333            3 :         }
     334           10 :         fn cnt_value(&self) -> i32 {
     335           10 :             *self
     336           10 :         }
     337              :     }
     338              : 
     339              :     #[tokio::test]
     340            1 :     async fn seqwait() {
     341            1 :         let seq = Arc::new(SeqWait::new(0));
     342            1 :         let seq2 = Arc::clone(&seq);
     343            1 :         let seq3 = Arc::clone(&seq);
     344            1 :         let jh1 = tokio::task::spawn(async move {
     345            1 :             seq2.wait_for(42).await.expect("wait_for 42");
     346            1 :             let old = seq2.advance(100);
     347            1 :             assert_eq!(old, 99);
     348            1 :             seq2.wait_for_timeout(999, Duration::from_millis(100))
     349            1 :                 .await
     350            1 :                 .expect_err("no 999");
     351            1 :         });
     352            1 :         let jh2 = tokio::task::spawn(async move {
     353            1 :             seq3.wait_for(42).await.expect("wait_for 42");
     354            1 :             seq3.wait_for(0).await.expect("wait_for 0");
     355            1 :         });
     356            1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     357            1 :         let old = seq.advance(99);
     358            1 :         assert_eq!(old, 0);
     359            1 :         seq.wait_for(100).await.expect("wait_for 100");
     360            1 : 
     361            1 :         // Calling advance with a smaller value is a no-op
     362            1 :         assert_eq!(seq.advance(98), 100);
     363            1 :         assert_eq!(seq.load(), 100);
     364            1 : 
     365            1 :         jh1.await.unwrap();
     366            1 :         jh2.await.unwrap();
     367            1 : 
     368            1 :         seq.shutdown();
     369            1 :     }
     370              : 
     371              :     #[tokio::test]
     372            1 :     async fn seqwait_timeout() {
     373            1 :         let seq = Arc::new(SeqWait::new(0));
     374            1 :         let seq2 = Arc::clone(&seq);
     375            1 :         let jh = tokio::task::spawn(async move {
     376            1 :             let timeout = Duration::from_millis(1);
     377            1 :             let res = seq2.wait_for_timeout(42, timeout).await;
     378            1 :             assert_eq!(res, Err(SeqWaitError::Timeout));
     379            1 :         });
     380            1 :         tokio::time::sleep(Duration::from_millis(200)).await;
     381            1 :         // This will attempt to wake, but nothing will happen
     382            1 :         // because the waiter already dropped its Receiver.
     383            1 :         let old = seq.advance(99);
     384            1 :         assert_eq!(old, 0);
     385            1 :         jh.await.unwrap();
     386            1 : 
     387            1 :         seq.shutdown();
     388            1 :     }
     389              : }
        

Generated by: LCOV version 2.1-beta