LCOV - code coverage report
Current view: top level - pageserver/src/context - optional_counter.rs (source / functions) Coverage Total Hit
Test: 960803fca14b2e843c565dddf575f7017d250bc3.info Lines: 85.2 % 61 52
Test Date: 2024-06-22 23:41:44 Functions: 90.9 % 11 10

            Line data    Source code
       1              : use std::{
       2              :     sync::atomic::{AtomicU32, Ordering},
       3              :     time::Duration,
       4              : };
       5              : 
       6              : #[derive(Debug)]
       7              : pub struct CounterU32 {
       8              :     inner: AtomicU32,
       9              : }
      10              : impl Default for CounterU32 {
      11      6123927 :     fn default() -> Self {
      12      6123927 :         Self {
      13      6123927 :             inner: AtomicU32::new(u32::MAX),
      14      6123927 :         }
      15      6123927 :     }
      16              : }
      17              : impl CounterU32 {
      18           12 :     pub fn open(&self) -> Result<(), &'static str> {
      19           12 :         match self
      20           12 :             .inner
      21           12 :             .compare_exchange(u32::MAX, 0, Ordering::Relaxed, Ordering::Relaxed)
      22              :         {
      23           12 :             Ok(_) => Ok(()),
      24            0 :             Err(_) => Err("open() called on clsoed state"),
      25              :         }
      26           12 :     }
      27           12 :     pub fn close(&self) -> Result<u32, &'static str> {
      28           12 :         match self.inner.swap(u32::MAX, Ordering::Relaxed) {
      29            0 :             u32::MAX => Err("close() called on closed state"),
      30           12 :             x => Ok(x),
      31              :         }
      32           12 :     }
      33              : 
      34            2 :     pub fn add(&self, count: u32) -> Result<(), &'static str> {
      35            2 :         if count == 0 {
      36            0 :             return Ok(());
      37            2 :         }
      38            2 :         let mut had_err = None;
      39            2 :         self.inner
      40            2 :             .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| match cur {
      41              :                 u32::MAX => {
      42            0 :                     had_err = Some("add() called on closed state");
      43            0 :                     None
      44              :                 }
      45            2 :                 x => {
      46            2 :                     let (new, overflowed) = x.overflowing_add(count);
      47            2 :                     if new == u32::MAX || overflowed {
      48            0 :                         had_err = Some("add() overflowed the counter");
      49            0 :                         None
      50              :                     } else {
      51            2 :                         Some(new)
      52              :                     }
      53              :                 }
      54            2 :             })
      55            2 :             .map_err(|_| had_err.expect("we set it whenever the function returns None"))
      56            2 :             .map(|_| ())
      57            2 :     }
      58              : }
      59              : 
      60              : #[derive(Default, Debug)]
      61              : pub struct MicroSecondsCounterU32 {
      62              :     inner: CounterU32,
      63              : }
      64              : 
      65              : impl MicroSecondsCounterU32 {
      66           12 :     pub fn open(&self) -> Result<(), &'static str> {
      67           12 :         self.inner.open()
      68           12 :     }
      69            2 :     pub fn add(&self, duration: Duration) -> Result<(), &'static str> {
      70            2 :         match duration.as_micros().try_into() {
      71            2 :             Ok(x) => self.inner.add(x),
      72            0 :             Err(_) => Err("add(): duration conversion error"),
      73              :         }
      74            2 :     }
      75           12 :     pub fn close_and_checked_sub_from(&self, from: Duration) -> Result<Duration, &'static str> {
      76           12 :         let val = self.inner.close()?;
      77           12 :         let val = Duration::from_micros(val as u64);
      78           12 :         let subbed = match from.checked_sub(val) {
      79           12 :             Some(v) => v,
      80            0 :             None => return Err("Duration::checked_sub"),
      81              :         };
      82           12 :         Ok(subbed)
      83           12 :     }
      84              : }
      85              : 
      86              : #[cfg(test)]
      87              : mod tests {
      88              : 
      89              :     use super::*;
      90              : 
      91              :     #[test]
      92            2 :     fn test_basic() {
      93            2 :         let counter = MicroSecondsCounterU32::default();
      94            2 :         counter.open().unwrap();
      95            2 :         counter.add(Duration::from_micros(23)).unwrap();
      96            2 :         let res = counter
      97            2 :             .close_and_checked_sub_from(Duration::from_micros(42))
      98            2 :             .unwrap();
      99            2 :         assert_eq!(res, Duration::from_micros(42 - 23));
     100            2 :     }
     101              : }
        

Generated by: LCOV version 2.1-beta