LCOV - code coverage report
Current view: top level - proxy/src - waiters.rs (source / functions) Coverage Total Hit
Test: 792183ae0ef4f1f8b22e9ac7e8748740ab73f873.info Lines: 94.6 % 56 53
Test Date: 2024-06-26 01:04:33 Functions: 27.6 % 29 8

            Line data    Source code
       1              : use hashbrown::HashMap;
       2              : use parking_lot::Mutex;
       3              : use pin_project_lite::pin_project;
       4              : use std::pin::Pin;
       5              : use std::task;
       6              : use thiserror::Error;
       7              : use tokio::sync::oneshot;
       8              : 
       9            0 : #[derive(Debug, Error)]
      10              : pub enum RegisterError {
      11              :     #[error("Waiter `{0}` already registered")]
      12              :     Occupied(String),
      13              : }
      14              : 
      15            0 : #[derive(Debug, Error)]
      16              : pub enum NotifyError {
      17              :     #[error("Notify failed: waiter `{0}` not registered")]
      18              :     NotFound(String),
      19              : 
      20              :     #[error("Notify failed: channel hangup")]
      21              :     Hangup,
      22              : }
      23              : 
      24            0 : #[derive(Debug, Error)]
      25              : pub enum WaitError {
      26              :     #[error("Wait failed: channel hangup")]
      27              :     Hangup,
      28              : }
      29              : 
      30              : pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
      31              : 
      32              : impl<T> Default for Waiters<T> {
      33            2 :     fn default() -> Self {
      34            2 :         Waiters(Default::default())
      35            2 :     }
      36              : }
      37              : 
      38              : impl<T> Waiters<T> {
      39            2 :     pub fn register(&self, key: String) -> Result<Waiter<T>, RegisterError> {
      40            2 :         let (tx, rx) = oneshot::channel();
      41            2 : 
      42            2 :         self.0
      43            2 :             .lock()
      44            2 :             .try_insert(key.clone(), tx)
      45            2 :             .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
      46              : 
      47            2 :         Ok(Waiter {
      48            2 :             receiver: rx,
      49            2 :             guard: DropKey {
      50            2 :                 registry: self,
      51            2 :                 key,
      52            2 :             },
      53            2 :         })
      54            2 :     }
      55              : 
      56            2 :     pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
      57            2 :     where
      58            2 :         T: Send + Sync,
      59            2 :     {
      60            2 :         let tx = self
      61            2 :             .0
      62            2 :             .lock()
      63            2 :             .remove(key)
      64            2 :             .ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
      65              : 
      66            2 :         tx.send(value).map_err(|_| NotifyError::Hangup)
      67            2 :     }
      68              : }
      69              : 
      70              : struct DropKey<'a, T> {
      71              :     key: String,
      72              :     registry: &'a Waiters<T>,
      73              : }
      74              : 
      75              : impl<'a, T> Drop for DropKey<'a, T> {
      76            2 :     fn drop(&mut self) {
      77            2 :         self.registry.0.lock().remove(&self.key);
      78            2 :     }
      79              : }
      80              : 
      81              : pin_project! {
      82              :     pub struct Waiter<'a, T> {
      83              :         #[pin]
      84              :         receiver: oneshot::Receiver<T>,
      85              :         guard: DropKey<'a, T>,
      86              :     }
      87              : }
      88              : 
      89              : impl<T> std::future::Future for Waiter<'_, T> {
      90              :     type Output = Result<T, WaitError>;
      91              : 
      92            4 :     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
      93            4 :         self.project()
      94            4 :             .receiver
      95            4 :             .poll(cx)
      96            4 :             .map_err(|_| WaitError::Hangup)
      97            4 :     }
      98              : }
      99              : 
     100              : #[cfg(test)]
     101              : mod tests {
     102              :     use super::*;
     103              :     use std::sync::Arc;
     104              : 
     105              :     #[tokio::test]
     106            2 :     async fn test_waiter() -> anyhow::Result<()> {
     107            2 :         let waiters = Arc::new(Waiters::default());
     108            2 : 
     109            2 :         let key = "Key";
     110            2 :         let waiter = waiters.register(key.to_owned())?;
     111            2 : 
     112            2 :         let waiters = Arc::clone(&waiters);
     113            2 :         let notifier = tokio::spawn(async move {
     114            2 :             waiters.notify(key, Default::default())?;
     115            2 :             Ok(())
     116            2 :         });
     117            2 : 
     118            2 :         waiter.await?;
     119            2 :         notifier.await?
     120            2 :     }
     121              : }
        

Generated by: LCOV version 2.1-beta