Line data Source code
1 : use hashbrown::HashMap;
2 : use parking_lot::Mutex;
3 : use pin_project_lite::pin_project;
4 : use std::pin::Pin;
5 : use std::task;
6 : use thiserror::Error;
7 : use tokio::sync::oneshot;
8 :
9 0 : #[derive(Debug, Error)]
10 : pub(crate) enum RegisterError {
11 : #[error("Waiter `{0}` already registered")]
12 : Occupied(String),
13 : }
14 :
15 0 : #[derive(Debug, Error)]
16 : pub(crate) enum NotifyError {
17 : #[error("Notify failed: waiter `{0}` not registered")]
18 : NotFound(String),
19 :
20 : #[error("Notify failed: channel hangup")]
21 : Hangup,
22 : }
23 :
24 0 : #[derive(Debug, Error)]
25 : pub(crate) enum WaitError {
26 : #[error("Wait failed: channel hangup")]
27 : Hangup,
28 : }
29 :
30 : pub(crate) struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
31 :
32 : impl<T> Default for Waiters<T> {
33 6 : fn default() -> Self {
34 6 : Waiters(Mutex::default())
35 6 : }
36 : }
37 :
38 : impl<T> Waiters<T> {
39 6 : pub(crate) fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
40 6 : let (tx, rx) = oneshot::channel();
41 6 :
42 6 : self.0
43 6 : .lock()
44 6 : .try_insert(key.clone(), tx)
45 6 : .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
46 :
47 6 : Ok(Waiter {
48 6 : receiver: rx,
49 6 : guard: DropKey {
50 6 : registry: self,
51 6 : key,
52 6 : },
53 6 : })
54 6 : }
55 :
56 6 : pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
57 6 : where
58 6 : T: Send + Sync,
59 6 : {
60 6 : let tx = self
61 6 : .0
62 6 : .lock()
63 6 : .remove(key)
64 6 : .ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
65 :
66 6 : tx.send(value).map_err(|_| NotifyError::Hangup)
67 6 : }
68 : }
69 :
70 : struct DropKey<'a, T> {
71 : key: String,
72 : registry: &'a Waiters<T>,
73 : }
74 :
75 : impl<'a, T> Drop for DropKey<'a, T> {
76 6 : fn drop(&mut self) {
77 6 : self.registry.0.lock().remove(&self.key);
78 6 : }
79 : }
80 :
81 : pin_project! {
82 : pub(crate) struct Waiter<'a, T> {
83 : #[pin]
84 : receiver: oneshot::Receiver<T>,
85 : guard: DropKey<'a, T>,
86 : }
87 : }
88 :
89 : impl<T> std::future::Future for Waiter<'_, T> {
90 : type Output = Result<T, WaitError>;
91 :
92 12 : fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
93 12 : self.project()
94 12 : .receiver
95 12 : .poll(cx)
96 12 : .map_err(|_| WaitError::Hangup)
97 12 : }
98 : }
99 :
100 : #[cfg(test)]
101 : mod tests {
102 : use super::*;
103 : use std::sync::Arc;
104 :
105 : #[tokio::test]
106 6 : async fn test_waiter() -> anyhow::Result<()> {
107 6 : let waiters = Arc::new(Waiters::default());
108 6 :
109 6 : let key = "Key";
110 6 : let waiter = waiters.register(key.to_owned())?;
111 6 :
112 6 : let waiters = Arc::clone(&waiters);
113 6 : let notifier = tokio::spawn(async move {
114 6 : waiters.notify(key, ())?;
115 6 : Ok(())
116 6 : });
117 6 :
118 6 : waiter.await?;
119 6 : notifier.await?
120 6 : }
121 : }
|