Line data Source code
1 : use std::{
2 : sync::atomic::{AtomicU32, Ordering},
3 : time::Duration,
4 : };
5 :
/// A `u32` counter with an explicit open/closed lifecycle, kept in one atomic.
///
/// The value `u32::MAX` is reserved as the "closed" sentinel, so an open
/// counter holds a value in `0..u32::MAX`. A counter built via [`Default`]
/// starts out closed. All atomic accesses use `Ordering::Relaxed`.
#[derive(Debug)]
pub struct CounterU32 {
    inner: AtomicU32,
}

impl Default for CounterU32 {
    /// Creates a counter in the closed state.
    fn default() -> Self {
        Self {
            inner: AtomicU32::new(u32::MAX),
        }
    }
}

impl CounterU32 {
    /// Transitions the counter from closed to open, resetting its value to 0.
    ///
    /// # Errors
    ///
    /// Fails, leaving the counter untouched, if it is not currently closed
    /// (i.e. it has already been opened).
    pub fn open(&self) -> Result<(), &'static str> {
        self.inner
            .compare_exchange(u32::MAX, 0, Ordering::Relaxed, Ordering::Relaxed)
            .map(|_| ())
            // The exchange only fails when the stored value is not the closed
            // sentinel, which means the counter is already open.
            .map_err(|_| "open() called on open state")
    }

    /// Closes the counter and returns the value accumulated while it was open.
    ///
    /// # Errors
    ///
    /// Fails if the counter was already closed.
    pub fn close(&self) -> Result<u32, &'static str> {
        match self.inner.swap(u32::MAX, Ordering::Relaxed) {
            u32::MAX => Err("close() called on closed state"),
            accumulated => Ok(accumulated),
        }
    }

    /// Adds `count` to an open counter.
    ///
    /// A `count` of 0 returns `Ok(())` immediately without inspecting the
    /// counter, so it succeeds even while the counter is closed.
    ///
    /// # Errors
    ///
    /// Fails, leaving the counter untouched, if the counter is closed, or if
    /// the sum would wrap around or land on the reserved `u32::MAX` sentinel.
    pub fn add(&self, count: u32) -> Result<(), &'static str> {
        if count == 0 {
            return Ok(());
        }
        self.inner
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                if cur == u32::MAX {
                    // Closed: refuse the update.
                    return None;
                }
                // Reject both arithmetic overflow and a sum equal to the sentinel.
                cur.checked_add(count).filter(|&sum| sum != u32::MAX)
            })
            .map(|_| ())
            // `fetch_update` hands back the value the closure rejected, which
            // is enough to tell the two failure causes apart.
            .map_err(|rejected| {
                if rejected == u32::MAX {
                    "add() called on closed state"
                } else {
                    "add() overflowed the counter"
                }
            })
    }
}
59 :
60 : #[derive(Default, Debug)]
61 : pub struct MicroSecondsCounterU32 {
62 : inner: CounterU32,
63 : }
64 :
65 : impl MicroSecondsCounterU32 {
66 12 : pub fn open(&self) -> Result<(), &'static str> {
67 12 : self.inner.open()
68 12 : }
69 2 : pub fn add(&self, duration: Duration) -> Result<(), &'static str> {
70 2 : match duration.as_micros().try_into() {
71 2 : Ok(x) => self.inner.add(x),
72 0 : Err(_) => Err("add(): duration conversion error"),
73 : }
74 2 : }
75 12 : pub fn close_and_checked_sub_from(&self, from: Duration) -> Result<Duration, &'static str> {
76 12 : let val = self.inner.close()?;
77 12 : let val = Duration::from_micros(val as u64);
78 12 : let subbed = match from.checked_sub(val) {
79 12 : Some(v) => v,
80 0 : None => return Err("Duration::checked_sub"),
81 : };
82 12 : Ok(subbed)
83 12 : }
84 : }
85 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Open, record 23µs, then close against a 42µs total: 19µs must remain.
    #[test]
    fn test_basic() {
        let total = Duration::from_micros(42);
        let recorded = Duration::from_micros(23);

        let counter = MicroSecondsCounterU32::default();
        counter.open().unwrap();
        counter.add(recorded).unwrap();

        let remaining = counter.close_and_checked_sub_from(total).unwrap();
        assert_eq!(remaining, Duration::from_micros(42 - 23));
    }
}
|