LCOV - code coverage report
Current view: top level - proxy/src - waiters.rs (source / functions) Coverage Total Hit
Test: 90b23405d17e36048d3bb64e314067f397803f1b.info Lines: 94.6 % 56 53
Test Date: 2024-09-20 13:14:58 Functions: 27.6 % 29 8

            Line data    Source code
       1              : use hashbrown::HashMap;
       2              : use parking_lot::Mutex;
       3              : use pin_project_lite::pin_project;
       4              : use std::pin::Pin;
       5              : use std::task;
       6              : use thiserror::Error;
       7              : use tokio::sync::oneshot;
       8              : 
       9            0 : #[derive(Debug, Error)]
      10              : pub(crate) enum RegisterError {
      11              :     #[error("Waiter `{0}` already registered")]
      12              :     Occupied(String),
      13              : }
      14              : 
      15            0 : #[derive(Debug, Error)]
      16              : pub(crate) enum NotifyError {
      17              :     #[error("Notify failed: waiter `{0}` not registered")]
      18              :     NotFound(String),
      19              : 
      20              :     #[error("Notify failed: channel hangup")]
      21              :     Hangup,
      22              : }
      23              : 
      24            0 : #[derive(Debug, Error)]
      25              : pub(crate) enum WaitError {
      26              :     #[error("Wait failed: channel hangup")]
      27              :     Hangup,
      28              : }
      29              : 
      30              : pub(crate) struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
      31              : 
      32              : impl<T> Default for Waiters<T> {
      33            1 :     fn default() -> Self {
      34            1 :         Waiters(Mutex::default())
      35            1 :     }
      36              : }
      37              : 
      38              : impl<T> Waiters<T> {
      39            1 :     pub(crate) fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
      40            1 :         let (tx, rx) = oneshot::channel();
      41            1 : 
      42            1 :         self.0
      43            1 :             .lock()
      44            1 :             .try_insert(key.clone(), tx)
      45            1 :             .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
      46              : 
      47            1 :         Ok(Waiter {
      48            1 :             receiver: rx,
      49            1 :             guard: DropKey {
      50            1 :                 registry: self,
      51            1 :                 key,
      52            1 :             },
      53            1 :         })
      54            1 :     }
      55              : 
      56            1 :     pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
      57            1 :     where
      58            1 :         T: Send + Sync,
      59            1 :     {
      60            1 :         let tx = self
      61            1 :             .0
      62            1 :             .lock()
      63            1 :             .remove(key)
      64            1 :             .ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
      65              : 
      66            1 :         tx.send(value).map_err(|_| NotifyError::Hangup)
      67            1 :     }
      68              : }
      69              : 
      70              : struct DropKey<'a, T> {
      71              :     key: String,
      72              :     registry: &'a Waiters<T>,
      73              : }
      74              : 
      75              : impl<'a, T> Drop for DropKey<'a, T> {
      76            1 :     fn drop(&mut self) {
      77            1 :         self.registry.0.lock().remove(&self.key);
      78            1 :     }
      79              : }
      80              : 
      81              : pin_project! {
      82              :     pub(crate) struct Waiter<'a, T> {
      83              :         #[pin]
      84              :         receiver: oneshot::Receiver<T>,
      85              :         guard: DropKey<'a, T>,
      86              :     }
      87              : }
      88              : 
      89              : impl<T> std::future::Future for Waiter<'_, T> {
      90              :     type Output = Result<T, WaitError>;
      91              : 
      92            2 :     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
      93            2 :         self.project()
      94            2 :             .receiver
      95            2 :             .poll(cx)
      96            2 :             .map_err(|_| WaitError::Hangup)
      97            2 :     }
      98              : }
      99              : 
     100              : #[cfg(test)]
     101              : mod tests {
     102              :     use super::*;
     103              :     use std::sync::Arc;
     104              : 
     105              :     #[tokio::test]
     106            1 :     async fn test_waiter() -> anyhow::Result<()> {
     107            1 :         let waiters = Arc::new(Waiters::default());
     108            1 : 
     109            1 :         let key = "Key";
     110            1 :         let waiter = waiters.register(key.to_owned())?;
     111            1 : 
     112            1 :         let waiters = Arc::clone(&waiters);
     113            1 :         let notifier = tokio::spawn(async move {
     114            1 :             waiters.notify(key, ())?;
     115            1 :             Ok(())
     116            1 :         });
     117            1 : 
     118            1 :         waiter.await?;
     119            1 :         notifier.await?
     120            1 :     }
     121              : }
        

Generated by: LCOV version 2.1-beta