Line data Source code
1 : use std::pin::Pin;
2 : use std::task;
3 :
4 : use hashbrown::HashMap;
5 : use parking_lot::Mutex;
6 : use pin_project_lite::pin_project;
7 : use thiserror::Error;
8 : use tokio::sync::oneshot;
9 :
10 0 : #[derive(Debug, Error)]
11 : pub(crate) enum RegisterError {
12 : #[error("Waiter `{0}` already registered")]
13 : Occupied(String),
14 : }
15 :
16 0 : #[derive(Debug, Error)]
17 : pub(crate) enum NotifyError {
18 : #[error("Notify failed: waiter `{0}` not registered")]
19 : NotFound(String),
20 :
21 : #[error("Notify failed: channel hangup")]
22 : Hangup,
23 : }
24 :
25 0 : #[derive(Debug, Error)]
26 : pub(crate) enum WaitError {
27 : #[error("Wait failed: channel hangup")]
28 : Hangup,
29 : }
30 :
31 : pub(crate) struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
32 :
33 : impl<T> Default for Waiters<T> {
34 1 : fn default() -> Self {
35 1 : Waiters(Mutex::default())
36 1 : }
37 : }
38 :
39 : impl<T> Waiters<T> {
40 1 : pub(crate) fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
41 1 : let (tx, rx) = oneshot::channel();
42 1 :
43 1 : self.0
44 1 : .lock()
45 1 : .try_insert(key.clone(), tx)
46 1 : .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
47 :
48 1 : Ok(Waiter {
49 1 : receiver: rx,
50 1 : guard: DropKey {
51 1 : registry: self,
52 1 : key,
53 1 : },
54 1 : })
55 1 : }
56 :
57 1 : pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
58 1 : where
59 1 : T: Send + Sync,
60 1 : {
61 1 : let tx = self
62 1 : .0
63 1 : .lock()
64 1 : .remove(key)
65 1 : .ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
66 :
67 1 : tx.send(value).map_err(|_| NotifyError::Hangup)
68 1 : }
69 : }
70 :
71 : struct DropKey<'a, T> {
72 : key: String,
73 : registry: &'a Waiters<T>,
74 : }
75 :
76 : impl<T> Drop for DropKey<'_, T> {
77 1 : fn drop(&mut self) {
78 1 : self.registry.0.lock().remove(&self.key);
79 1 : }
80 : }
81 :
82 : pin_project! {
83 : pub(crate) struct Waiter<'a, T> {
84 : #[pin]
85 : receiver: oneshot::Receiver<T>,
86 : guard: DropKey<'a, T>,
87 : }
88 : }
89 :
90 : impl<T> std::future::Future for Waiter<'_, T> {
91 : type Output = Result<T, WaitError>;
92 :
93 2 : fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
94 2 : self.project()
95 2 : .receiver
96 2 : .poll(cx)
97 2 : .map_err(|_| WaitError::Hangup)
98 2 : }
99 : }
100 :
101 : #[cfg(test)]
102 : mod tests {
103 : use std::sync::Arc;
104 :
105 : use super::*;
106 :
107 : #[tokio::test]
108 1 : async fn test_waiter() -> anyhow::Result<()> {
109 1 : let waiters = Arc::new(Waiters::default());
110 1 :
111 1 : let key = "Key";
112 1 : let waiter = waiters.register(key.to_owned())?;
113 1 :
114 1 : let waiters = Arc::clone(&waiters);
115 1 : let notifier = tokio::spawn(async move {
116 1 : waiters.notify(key, ())?;
117 1 : Ok(())
118 1 : });
119 1 :
120 1 : waiter.await?;
121 1 : notifier.await?
122 1 : }
123 : }
|