LCOV - code coverage report
Current view: top level - proxy/src - waiters.rs (source / functions) Coverage Total Hit
Test: 49aa928ec5b4b510172d8b5c6d154da28e70a46c.info Lines: 94.6 % 56 53
Test Date: 2024-11-13 18:23:39 Functions: 27.6 % 29 8

            Line data    Source code
       1              : use std::pin::Pin;
       2              : use std::task;
       3              : 
       4              : use hashbrown::HashMap;
       5              : use parking_lot::Mutex;
       6              : use pin_project_lite::pin_project;
       7              : use thiserror::Error;
       8              : use tokio::sync::oneshot;
       9              : 
      10            0 : #[derive(Debug, Error)]
      11              : pub(crate) enum RegisterError {
      12              :     #[error("Waiter `{0}` already registered")]
      13              :     Occupied(String),
      14              : }
      15              : 
      16            0 : #[derive(Debug, Error)]
      17              : pub(crate) enum NotifyError {
      18              :     #[error("Notify failed: waiter `{0}` not registered")]
      19              :     NotFound(String),
      20              : 
      21              :     #[error("Notify failed: channel hangup")]
      22              :     Hangup,
      23              : }
      24              : 
      25            0 : #[derive(Debug, Error)]
      26              : pub(crate) enum WaitError {
      27              :     #[error("Wait failed: channel hangup")]
      28              :     Hangup,
      29              : }
      30              : 
      31              : pub(crate) struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
      32              : 
      33              : impl<T> Default for Waiters<T> {
      34            1 :     fn default() -> Self {
      35            1 :         Waiters(Mutex::default())
      36            1 :     }
      37              : }
      38              : 
      39              : impl<T> Waiters<T> {
      40            1 :     pub(crate) fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
      41            1 :         let (tx, rx) = oneshot::channel();
      42            1 : 
      43            1 :         self.0
      44            1 :             .lock()
      45            1 :             .try_insert(key.clone(), tx)
      46            1 :             .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
      47              : 
      48            1 :         Ok(Waiter {
      49            1 :             receiver: rx,
      50            1 :             guard: DropKey {
      51            1 :                 registry: self,
      52            1 :                 key,
      53            1 :             },
      54            1 :         })
      55            1 :     }
      56              : 
      57            1 :     pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
      58            1 :     where
      59            1 :         T: Send + Sync,
      60            1 :     {
      61            1 :         let tx = self
      62            1 :             .0
      63            1 :             .lock()
      64            1 :             .remove(key)
      65            1 :             .ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
      66              : 
      67            1 :         tx.send(value).map_err(|_| NotifyError::Hangup)
      68            1 :     }
      69              : }
      70              : 
      71              : struct DropKey<'a, T> {
      72              :     key: String,
      73              :     registry: &'a Waiters<T>,
      74              : }
      75              : 
      76              : impl<T> Drop for DropKey<'_, T> {
      77            1 :     fn drop(&mut self) {
      78            1 :         self.registry.0.lock().remove(&self.key);
      79            1 :     }
      80              : }
      81              : 
      82              : pin_project! {
      83              :     pub(crate) struct Waiter<'a, T> {
      84              :         #[pin]
      85              :         receiver: oneshot::Receiver<T>,
      86              :         guard: DropKey<'a, T>,
      87              :     }
      88              : }
      89              : 
      90              : impl<T> std::future::Future for Waiter<'_, T> {
      91              :     type Output = Result<T, WaitError>;
      92              : 
      93            2 :     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
      94            2 :         self.project()
      95            2 :             .receiver
      96            2 :             .poll(cx)
      97            2 :             .map_err(|_| WaitError::Hangup)
      98            2 :     }
      99              : }
     100              : 
     101              : #[cfg(test)]
     102              : mod tests {
     103              :     use std::sync::Arc;
     104              : 
     105              :     use super::*;
     106              : 
     107              :     #[tokio::test]
     108            1 :     async fn test_waiter() -> anyhow::Result<()> {
     109            1 :         let waiters = Arc::new(Waiters::default());
     110            1 : 
     111            1 :         let key = "Key";
     112            1 :         let waiter = waiters.register(key.to_owned())?;
     113            1 : 
     114            1 :         let waiters = Arc::clone(&waiters);
     115            1 :         let notifier = tokio::spawn(async move {
     116            1 :             waiters.notify(key, ())?;
     117            1 :             Ok(())
     118            1 :         });
     119            1 : 
     120            1 :         waiter.await?;
     121            1 :         notifier.await?
     122            1 :     }
     123              : }
        

Generated by: LCOV version 2.1-beta