Line data Source code
1 : use std::{
2 : sync::atomic::{AtomicU32, Ordering},
3 : time::Duration,
4 : };
5 :
/// A `u32` counter with an explicit open/closed lifecycle.
///
/// The value `u32::MAX` is reserved as the "closed" sentinel, so an open
/// counter can hold any value in `0..u32::MAX`. A counter built with
/// [`Default`] starts closed; [`CounterU32::open`] resets it to zero,
/// [`CounterU32::add`] accumulates into it, and [`CounterU32::close`] returns
/// the accumulated total and marks it closed again.
///
/// Every operation uses `Ordering::Relaxed`: each one is atomic on the counter
/// itself but does not order any other memory accesses.
#[derive(Debug)]
pub struct CounterU32 {
    /// Current count while open, or `u32::MAX` while closed.
    inner: AtomicU32,
}

impl Default for CounterU32 {
    /// Creates a counter in the closed state.
    fn default() -> Self {
        Self {
            inner: AtomicU32::new(u32::MAX),
        }
    }
}

impl CounterU32 {
    /// Transitions the counter from closed to open and resets its value to zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the counter is already open; the current count is
    /// left untouched in that case.
    pub fn open(&self) -> Result<(), &'static str> {
        self.inner
            .compare_exchange(u32::MAX, 0, Ordering::Relaxed, Ordering::Relaxed)
            .map(|_| ())
            // The exchange only fails when the stored value is not the closed
            // sentinel, i.e. when the counter is already open.
            .map_err(|_| "open() called on open state")
    }

    /// Closes the counter and returns the value accumulated since `open()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the counter was already closed.
    pub fn close(&self) -> Result<u32, &'static str> {
        match self.inner.swap(u32::MAX, Ordering::Relaxed) {
            u32::MAX => Err("close() called on closed state"),
            x => Ok(x),
        }
    }

    /// Atomically adds `count` to an open counter.
    ///
    /// Adding zero is a no-op that always succeeds, even on a closed counter.
    ///
    /// # Errors
    ///
    /// Leaves the counter unchanged and returns an error if:
    /// - the counter is closed, or
    /// - the sum would overflow `u32` or land exactly on `u32::MAX`, which is
    ///   reserved as the closed sentinel.
    pub fn add(&self, count: u32) -> Result<(), &'static str> {
        if count == 0 {
            return Ok(());
        }
        // Records why the update closure refused to produce a new value, so the
        // caller gets a specific message instead of just the observed value.
        let mut err = None;
        self.inner
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                if cur == u32::MAX {
                    err = Some("add() called on closed state");
                    return None;
                }
                match cur.checked_add(count) {
                    Some(new) if new != u32::MAX => Some(new),
                    _ => {
                        err = Some("add() overflowed the counter");
                        None
                    }
                }
            })
            .map(|_| ())
            .map_err(|_| err.expect("the closure sets `err` whenever it returns None"))
    }
}
59 :
60 : #[derive(Default, Debug)]
61 : pub struct MicroSecondsCounterU32 {
62 : inner: CounterU32,
63 : }
64 :
65 : impl MicroSecondsCounterU32 {
66 36 : pub fn open(&self) -> Result<(), &'static str> {
67 36 : self.inner.open()
68 36 : }
69 6 : pub fn add(&self, duration: Duration) -> Result<(), &'static str> {
70 6 : match duration.as_micros().try_into() {
71 6 : Ok(x) => self.inner.add(x),
72 0 : Err(_) => Err("add(): duration conversion error"),
73 : }
74 6 : }
75 36 : pub fn close_and_checked_sub_from(&self, from: Duration) -> Result<Duration, &'static str> {
76 36 : let val = self.inner.close()?;
77 36 : let val = Duration::from_micros(val as u64);
78 36 : let subbed = match from.checked_sub(val) {
79 36 : Some(v) => v,
80 0 : None => return Err("Duration::checked_sub"),
81 : };
82 36 : Ok(subbed)
83 36 : }
84 : }
85 :
#[cfg(test)]
mod tests {

    use super::*;

    /// Happy path: open, accumulate, close, and subtract from a larger duration.
    #[test]
    fn test_basic() {
        let counter = MicroSecondsCounterU32::default();
        counter.open().unwrap();
        counter.add(Duration::from_micros(23)).unwrap();
        let res = counter
            .close_and_checked_sub_from(Duration::from_micros(42))
            .unwrap();
        assert_eq!(res, Duration::from_micros(42 - 23));
    }

    /// Lifecycle misuse is reported as an error rather than silently accepted.
    #[test]
    fn test_lifecycle_errors() {
        let counter = MicroSecondsCounterU32::default();
        // A default-constructed counter starts closed.
        assert!(counter.add(Duration::from_micros(1)).is_err());
        assert!(counter.close_and_checked_sub_from(Duration::ZERO).is_err());

        counter.open().unwrap();
        assert!(counter.open().is_err());
        assert_eq!(
            counter.close_and_checked_sub_from(Duration::ZERO).unwrap(),
            Duration::ZERO
        );
        assert!(counter.close_and_checked_sub_from(Duration::ZERO).is_err());
    }

    /// The accumulated total may not reach `u32::MAX` microseconds, and a
    /// single duration that does not fit in a `u32` is rejected up front.
    #[test]
    fn test_add_overflow_and_conversion() {
        let max = u64::from(u32::MAX);
        let counter = MicroSecondsCounterU32::default();
        counter.open().unwrap();
        counter.add(Duration::from_micros(max - 1)).unwrap();
        assert!(counter.add(Duration::from_micros(1)).is_err());
        assert!(counter.add(Duration::from_micros(max + 1)).is_err());
        // Failed adds leave the accumulated value unchanged.
        let res = counter
            .close_and_checked_sub_from(Duration::from_micros(max))
            .unwrap();
        assert_eq!(res, Duration::from_micros(1));
    }

    /// Subtracting more accumulated time than `from` holds is an error.
    #[test]
    fn test_sub_underflow() {
        let counter = MicroSecondsCounterU32::default();
        counter.open().unwrap();
        counter.add(Duration::from_micros(10)).unwrap();
        assert!(counter
            .close_and_checked_sub_from(Duration::from_micros(5))
            .is_err());
    }
}
|